In [1]:
import os
ON_KAGGLE_KERNEL = os.path.isdir("/kaggle/input")
start_dir = os.getcwd()

if ON_KAGGLE_KERNEL:
    os.chdir("/kaggle/input/utilities/")
else:
    os.chdir(f"{os.environ.get('PYTHONPATH')}/src/utils")

from common import load_structure, save_structure, load_train_file, INPUT_DATA_DIR, OUTPUT_DATA_DIR, SUB_DIR, set_seed, WANDB_KEY, WEIGHTS_DIR, load_test_file
from data_proc import create_idvs_with_one_img_in_train_split, create_train_val_loaders, split_into_train_val, encode_labels
from data_structures import WhaleDataset, TorchConfig, InferenceKNNModel
os.chdir(start_dir)

print(INPUT_DATA_DIR)

import importlib
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from tqdm import tqdm
import gc
import time 
import copy

def _import_or_install(name: str):
    try:
        globals()[name] = importlib.import_module(name)
        # importlib.import_module(name)
        print(f"{name} found and imported (version {globals()[name].__version__}).")
    except ModuleNotFoundError:
        !pip install -q --upgrade $name
        # importlib.import_module(name)
        globals()[name] = importlib.import_module(name)
        print(f"{name} Installed and imported (version {globals()[name].__version__}).")


_import_or_install("timm")
_import_or_install("wandb")

if ON_KAGGLE_KERNEL:
    try:
        wandb.login(key=WANDB_KEY)
    except Exception as e:
        print(f"WandB login failed:\n{e}")
else:
    %reload_ext autoreload
    %autoreload 2
    from IPython.core.interactiveshell import InteractiveShell
    InteractiveShell.ast_node_interactivity = 'all'



def _load_prev_weights():
    weight_file = sorted(os.listdir(WEIGHTS_DIR))[0]
    print(f"Using weight file {weight_file}.")
    return torch.load(str(WEIGHTS_DIR / weight_file), map_location="cpu")


/home/paul/projects/Happywhale_competition/data/cropped
timm found and imported (version 0.4.12).
wandb found and imported (version 0.12.10).


In [2]:
conf = TorchConfig.default()
conf.dict()
set_seed(conf.seed)

conf.epochs = 20

if not ON_KAGGLE_KERNEL:
    conf.train_batch_size = 1

{'seed': 319,
 'epochs': 10,
 'img_size': 448,
 'augm_args': {'hor_flip': {'p': 0.5},
  'ver_flip': {'p': 0.5},
  'rot': {'p': 0.5, 'limit': 30}},
 'model_name': 'tf_efficientnet_b0',
 'num_classes': 15587,
 'train_batch_size': 32,
 'valid_batch_size': 64,
 'optim': torch.optim.adam.Adam,
 'optim_args': {'lr': 0.0001, 'weight_decay': 1e-06},
 'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR,
 'scheduler_args': {'T_max': 500, 'eta_min': 1e-06},
 'n_fold': 5,
 'init_optim_': None,
 'init_sched_': None,
 's': 30.0,
 'm': 0.5,
 'ls_eps': 0.0,
 'easy_margin': False}

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [4]:
class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1)*p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)
        
    def gem(self, x, p=3, eps=1e-6):
        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
        
    def __repr__(self):
        return self.__class__.__name__ + \
                '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
                ', ' + 'eps=' + str(self.eps) + ')'

In [5]:
class ArcMarginProduct(nn.Module):
    r"""Implement of large margin arc distance: :
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        s: norm of input feature
        m: margin
        cos(theta + m)
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------
        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
        one_hot = torch.zeros(cosine.size(), device=device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output


In [6]:
class HappyWhaleModel(nn.Module):
    def __init__(self, model_name, pretrained=True):
        super(HappyWhaleModel, self).__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Identity()
        self.model.global_pool = nn.Identity()
        self.pooling = GeM()
        self.fc = ArcMarginProduct(
            in_features,
            conf.num_classes,
            s=conf.s,
            m=conf.m,
            easy_margin=conf.easy_margin,
            ls_eps=conf.ls_eps,
        )

    def forward(self, images, labels=None):
        features = self.model(images)
        # pooled_features = self.pooling(features)
        pooled_features = self.pooling(features).flatten(1)
        if labels is not None:
            output = self.fc(pooled_features, labels)
        else:
            output = self.fc(pooled_features)
        return output


In [7]:
df = create_idvs_with_one_img_in_train_split()
df_train, df_val = split_into_train_val(df)
len(df_train), len(df_val)
transforms = conf.make_transforms()
train_set = WhaleDataset(df_train, transforms["train"])
val_set = WhaleDataset(df_val, transforms["valid"])
train_loader, val_loader = create_train_val_loaders(
    train_set, val_set, conf.train_batch_size, conf.valid_batch_size
)
model = HappyWhaleModel(conf.model_name)
model.load_state_dict(_load_prev_weights())
model.to(device);

optim = conf.get_optim(model)
scheduler = conf.get_scheduler()
def criterion(outputs, labels):
    return nn.CrossEntropyLoss()(outputs, labels)

/home/paul/projects/Happywhale_competition/data/id_encoding.pickle saved.
/home/paul/projects/Happywhale_competition/data/splits/idvs_with_one_img_in_train.pickle saved.


(40589, 10444)

Using weight file Loss12.9268_epoch5.bin.


<All keys matched successfully>

HappyWhaleModel(
  (model): EfficientNet(
    (conv_stem): Conv2dSame(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (act1): SiLU(inplace=True)
    (blocks): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (act1): SiLU(inplace=True)
            (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (gate): Sigmoid()
          )
          (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_

In [8]:
def train_one_epoch(model, optimizer, dataloader, device, epoch):
    model.train()
    
    dataset_size = 0
    running_loss = 0.0
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)
        
        batch_size = images.size(0)
        
        outputs = model(images, labels)
        loss = criterion(outputs, labels)
        # loss = loss / conf['n_accumulate']
            
        loss.backward()
    
        # if (step + 1) % CONFIG['n_accumulate'] == 0:
        optimizer.step()

        optimizer.zero_grad()
                
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss / dataset_size
        
        bar.set_postfix(Epoch=epoch, Train_Loss=epoch_loss,
                        LR=optimizer.param_groups[0]['lr'])
    gc.collect()
    
    return epoch_loss

In [9]:
@torch.inference_mode()
def valid_one_epoch(model, dataloader, device, epoch):
    model.eval()
    
    dataset_size = 0
    running_loss = 0.0
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:        
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)
        
        batch_size = images.size(0)

        outputs = model(images, labels)
        loss = criterion(outputs, labels)
        
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss / dataset_size
        
        bar.set_postfix(Epoch=epoch, Valid_Loss=epoch_loss)   
    
    gc.collect()
    
    return epoch_loss

In [10]:
def run_training(model, optimizer, scheduler, num_epochs):
    if ON_KAGGLE_KERNEL:
        wandb.watch(model, log_freq=100)

    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))

    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_epoch_loss = np.inf

    for epoch in range(1, num_epochs + 1):
        gc.collect()
        train_epoch_loss = train_one_epoch(
            model,
            optimizer,
            dataloader=train_loader,
            device=device,
            epoch=epoch,
        )

        if scheduler is not None:
            scheduler.step()

        val_epoch_loss = valid_one_epoch(model, val_loader, device=device, epoch=epoch)

        # Log the metrics
        wandb.log({"Train Loss": train_epoch_loss})
        wandb.log({"Valid Loss": val_epoch_loss})

        # deep copy the model
        if val_epoch_loss <= best_epoch_loss:
            print(f"Validation Loss Improved ({best_epoch_loss} ---> {val_epoch_loss})")
            best_epoch_loss = val_epoch_loss
            run.summary["Best Loss"] = best_epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = "Loss{:.4f}_epoch{:.0f}.bin".format(best_epoch_loss, epoch)
            torch.save(model.state_dict(), PATH)
            # Save a model file from the current directory
            print(f"Model Saved.")

        print()

    end = time.time()
    time_elapsed = end - start
    print(
        "Training complete in {:.0f}h {:.0f}m {:.0f}s".format(
            time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60
        )
    )
    print("Best Loss: {:.4f}".format(best_epoch_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model


In [11]:
if ON_KAGGLE_KERNEL:
    run = wandb.init(
    project="HappyWhale",
    config=conf.dict(),
    job_type="Train",
    tags=["Standard classifier", "efficientnet_b0", "448"],
    anonymous="must",
)

    model = run_training(model, optim, scheduler, num_epochs=conf.epochs)

if ON_KAGGLE_KERNEL:
    run.finish()

In [12]:
[m for m in model.named_modules()][-10:]
[par.shape for par in model.parameters()][-5:]
test_dataset = WhaleDataset(load_test_file(), transforms["valid"], labels=False)
test_input = test_dataset[10]["image"].unsqueeze(0)
test_input.shape
model = model.to("cpu")
model.load_state_dict(_load_prev_weights())

model.fc = nn.Identity()
model.eval()
sum(p.numel() for p in model.parameters())


[('model.blocks.6.0.se.gate', Sigmoid()),
 ('model.blocks.6.0.conv_pwl',
  Conv2d(1152, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)),
 ('model.blocks.6.0.bn3',
  BatchNorm2d(320, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)),
 ('model.conv_head',
  Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)),
 ('model.bn2',
  BatchNorm2d(1280, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)),
 ('model.act2', SiLU(inplace=True)),
 ('model.global_pool', Identity()),
 ('model.classifier', Identity()),
 ('pooling', GeM(p=2.9141, eps=1e-06)),
 ('fc', ArcMarginProduct())]

[torch.Size([1280, 320, 1, 1]),
 torch.Size([1280]),
 torch.Size([1280]),
 torch.Size([1]),
 torch.Size([15587, 1280])]

torch.Size([1, 3, 448, 448])

Using weight file Loss12.9268_epoch5.bin.


<All keys matched successfully>

HappyWhaleModel(
  (model): EfficientNet(
    (conv_stem): Conv2dSame(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (act1): SiLU(inplace=True)
    (blocks): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (act1): SiLU(inplace=True)
            (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (gate): Sigmoid()
          )
          (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_

4007549

In [13]:
# InferenceKNNModel.list_all()

In [14]:
im = InferenceKNNModel(model, name="2_ArcFace_first_try_uncropped")
im.end_to_end()

Embeddings loaded from /home/paul/projects/Happywhale_competition/data/embeddings/2_ArcFace_first_try_uncropped
Found previous train embeddings for 2_ArcFace_first_try_uncropped.
Found previous test embeddings for 2_ArcFace_first_try_uncropped.
train embeddings: Already complete.
test embeddings: Already complete.
[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 78989 samples in 0.003s...
[t-SNE] Computed neighbors for 78989 samples in 157.514s...
[t-SNE] Computed conditional probabilities for sample 1000 / 78989
[t-SNE] Computed conditional probabilities for sample 2000 / 78989
[t-SNE] Computed conditional probabilities for sample 3000 / 78989
[t-SNE] Computed conditional probabilities for sample 4000 / 78989
[t-SNE] Computed conditional probabilities for sample 5000 / 78989
[t-SNE] Computed conditional probabilities for sample 6000 / 78989
[t-SNE] Computed conditional probabilities for sample 7000 / 78989
[t-SNE] Computed conditional probabilities for sample 8000 / 78989
[t-

100%|██████████| 2.29M/2.29M [00:10<00:00, 233kB/s] 


Successfully submitted to Happywhale - Whale and Dolphin IdentificationfileName                               date                 description                    status    publicScore  privateScore  
-------------------------------------  -------------------  -----------------------------  --------  -----------  ------------  
sub_2_ArcFace_first_try_uncropped.csv  2022-02-27 12:15:46  2_ArcFace_first_try_uncropped  complete  0.162        None          
sub_ArcFace_first_try_uncropped.csv    2022-02-26 20:16:40  ArcFace_first_try_uncropped    complete  0.175        None          
baseline_2.csv                         2022-02-24 12:13:38  python_subm_test               complete  0.113        None          
pytorch_classifier_thr_004.csv         2022-02-07 15:04:33  csv                            complete  0.145        None          
pytorch_classifier_thr_0.03.csv        2022-02-07 15:03:40  03                             complete  0.145        None          
pytorch_classifier_3.csv  