In [None]:
# COLAB
!pip install neptune-client timm
# pip install torch-tensorrt -f https://github.com/NVIDIA/Torch-TensorRT/releases
!unzip data.zip
!mkdir artifacts/

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')

import pandas as pd 

from tqdm import tqdm

from sklearn.metrics import accuracy_score
from sklearn.metrics import average_precision_score

import warnings
warnings.simplefilter('ignore')

import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity

from functools import partial

import torch
from torch import einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from collections import OrderedDict

import torchvision 
from torchvision import transforms as T
from torchvision.io import read_image

import itertools

from sklearn.model_selection import train_test_split

import neptune.new as neptune


In [2]:
run = neptune.init(
    project="victorcallejas/Belluga",
    api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJlNDRlNTJiNC00OTQwLTQxYjgtYWZiNS02OWQ0MDcwZmU5N2YifQ=="
)

https://app.neptune.ai/victorcallejas/Belluga/e/BEL-131
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


In [3]:
device = torch.device("cuda")
print(torch.cuda.get_device_name(0), torch.cuda.get_device_properties(device))

fp16 = True
input_dtype = torch.float16 if fp16 else torch.float32

scaler =  torch.cuda.amp.GradScaler(enabled=fp16)

NVIDIA GeForce GTX 1070 with Max-Q Design _CudaDeviceProperties(name='NVIDIA GeForce GTX 1070 with Max-Q Design', major=6, minor=1, total_memory=8191MB, multi_processor_count=16)


In [4]:
# SCORING
PREDICTION_LIMIT = 20
QUERY_ID_COL = "query_id"
DATABASE_ID_COL = "database_image_id"
SCORE_COL = "score"

SCORE_THRESHOLD = 0.5

class MeanAveragePrecision:
    @classmethod
    def score(cls, predicted: pd.DataFrame, actual: pd.DataFrame, prediction_limit: int):
        """Calculates mean average precision for a ranking task.
        :param predicted: The predicted values as a dataframe with specified column names
        :param actual: The ground truth values as a dataframe with specified column names
        """
        if not predicted[SCORE_COL].between(0.0, 1.0).all():
            raise ValueError("Scores must be in range [0, 1].")
        if predicted.index.name != QUERY_ID_COL:
            raise ValueError(
                f"First column of submission must be named '{QUERY_ID_COL}', "
                f"got {predicted.index.name}."
            )
        if predicted.columns.to_list() != [DATABASE_ID_COL, SCORE_COL]:
            raise ValueError(
                f"Columns of submission must be named '{[DATABASE_ID_COL, SCORE_COL]}', "
                f"got {predicted.columns.to_list()}."
            )

        unadjusted_aps, predicted_n_pos, actual_n_pos = cls._score_per_query(
            predicted, actual, prediction_limit
        )
        adjusted_aps = unadjusted_aps.multiply(predicted_n_pos).divide(actual_n_pos)
        return adjusted_aps.mean()

    @classmethod
    def _score_per_query(
        cls, predicted: pd.DataFrame, actual: pd.DataFrame, prediction_limit: int
    ):
        """Calculates per-query mean average precision for a ranking task."""
        merged = predicted.merge(
            right=actual.assign(actual=1.0),
            how="left",
            on=[QUERY_ID_COL, DATABASE_ID_COL],
        ).fillna({"actual": 0.0})
        # Per-query raw average precisions based on predictions
        unadjusted_aps = merged.groupby(QUERY_ID_COL).apply(
            lambda df: average_precision_score(df["actual"].values, df[SCORE_COL].values)
            if df["actual"].sum()
            else 0.0
        )
        # Total ground truth positive counts for rescaling
        predicted_n_pos = merged["actual"].groupby(QUERY_ID_COL).sum().astype("int64").rename()
        actual_n_pos = actual.groupby(QUERY_ID_COL).size().clip(upper=prediction_limit)
        return unadjusted_aps, predicted_n_pos, actual_n_pos
    
    
def map_score(dataloader, model, threshold=SCORE_THRESHOLD):
    
    model.eval()
    
    sub = []
    
    sigmoid = torch.nn.Sigmoid()
    
    with torch.no_grad():        
    
        for query, reference, query_id, reference_id in tqdm(dataloader):
            
            query = query.to(device, non_blocking=True, dtype=input_dtype)
            reference = reference.to(device, non_blocking=True, dtype=input_dtype)

            with torch.cuda.amp.autocast(enabled = fp16):
                logits = sigmoid(model(query=query, reference=reference)).cpu().squeeze().tolist()
                
            sub.extend(zip(query_id, reference_id, logits))
            
    sub = pd.DataFrame(sub, columns=['query_id', 'database_image_id', 'score'])
    sub = sub[sub.score > threshold]
    sub = sub.set_index(['database_image_id']).groupby('query_id')['score'].nlargest(20).reset_index()
    sub = sub.set_index('query_id')
    
    mean_avg_prec = MeanAveragePrecision.score(
        predicted=sub, actual=dataloader.dataset.gt, prediction_limit=PREDICTION_LIMIT
    )
    
    print('MaP: ',mean_avg_prec)
    return mean_avg_prec

In [5]:
# DATA

IMG_SIZE = 224
ROOT_DIR = '../data/'
NORM_TRANSFORMS = torch.nn.Sequential(
    T.Resize([IMG_SIZE, IMG_SIZE]),
    T.ConvertImageDtype(input_dtype),
    T.Normalize(mean = (0.4234, 0.4272, 0.4641),
                std  = (0.2037, 0.2027, 0.2142)),
)

VAL_SPLIT = 0.05

METADATA = pd.read_csv('../data/metadata.csv')
METADATA = METADATA[METADATA.viewpoint == 'top']

TRAIN, VAL = train_test_split(METADATA, test_size=0.05, random_state=42)
TRAIN, VAL = TRAIN.reset_index(), VAL.reset_index()
#TRAIN, VAL = METADATA, METADATA
#TRAIN = METADATA

def getImages(metadata):
    IMAGES = {}
    for image_id, path in tqdm(zip(metadata.image_id, metadata.path), total=metadata.shape[0]):
        IMAGES[image_id] = NORM_TRANSFORMS(read_image(ROOT_DIR + path))
    return IMAGES

IMAGES = getImages(METADATA)


    

100%|██████████| 5434/5434 [01:50<00:00, 49.30it/s]


In [14]:
class PreTrain_BellugaDataset(torch.utils.data.Dataset):
    
    def __init__(self, metadata):
        self.metadata = metadata

    def __len__(self):
        return self.metadata.shape[0]
    
    def __getitem__(self, idx):
        return IMAGES[self.metadata.image_id[idx]]


class Eval_BellugaDataset(torch.utils.data.Dataset):
    
    def __init__(self, metadata):
        self.metadata = metadata
    
        # GROUND TRUTH
        gt = []
        for wid in self.metadata.whale_id: # query
            tmp = self.metadata[self.metadata.whale_id == wid].image_id.tolist() # get all images id
            gt.extend(list(itertools.permutations(tmp, 2)))
        self.gt = pd.DataFrame(gt,columns=['query_id','database_image_id'])
        self.gt = self.gt.set_index('query_id')
        
        # ALL QUERIES
        self.query_reference = list(itertools.permutations(self.metadata.image_id, 2))
            
    def getimage(self, image_id):
        return IMAGES[image_id]

    def __len__(self):
        return len(self.query_reference)
    
    def __getitem__(self, idx):
        query_id = self.query_reference[idx][0]
        reference_id = self.query_reference[idx][1]
        
        query = self.getimage(query_id)
        reference = self.getimage(reference_id)
        
        return query, reference, query_id, reference_id
    
    
class Train_BellugaDataset(torch.utils.data.Dataset):
    
    def __init__(self, metadata):
        self.metadata = metadata
        #self.aug = T.RandomErasing(p=0.8, scale=(0.12, 0.43), ratio=(0.3, 3.3), value=0, inplace=False)
            
    def getimage(self, image_id):
        return IMAGES[image_id]

    def __len__(self):
        return self.metadata.shape[0]
    
    def __getitem__(self, idx):
        
        anchor = self.getimage(self.metadata.image_id[idx])
        label = self.metadata.whale_id[idx]
        
        pos = self.getimage(self.metadata[self.metadata.whale_id == label].sample()['image_id'].values[0])
        neg = self.getimage(self.metadata[self.metadata.whale_id != label].sample()['image_id'].values[0])

        return anchor, pos, neg

In [15]:
# DATALOADERS 

PRETRAIN_BS = 4
TRAIN_BS = 16
INFER_BS = TRAIN_BS

NUM_WORKERS = 0


#pretrain_dataset = PreTrain_BellugaDataset(METADATA)
train_train_dataset = Train_BellugaDataset(TRAIN)
#train_eval_dataset = Eval_BellugaDataset(TRAIN)
valid_eval_dataset = Eval_BellugaDataset(VAL)

'''
pretrain_dataloader = torch.utils.data.DataLoader(
                        pretrain_dataset, 
                        batch_size=PRETRAIN_BS,
                        shuffle=True, 
                        num_workers=NUM_WORKERS,
                        pin_memory=True
                    )
'''
train_train_dataloader = torch.utils.data.DataLoader(
                        train_train_dataset, 
                        batch_size=TRAIN_BS,
                        shuffle=True, 
                        num_workers=NUM_WORKERS,
                        pin_memory=True
                    )
'''
train_eval_dataloader = torch.utils.data.DataLoader(
                        train_eval_dataset, 
                        batch_size=INFER_BS,
                        shuffle=True, 
                        num_workers=NUM_WORKERS,
                        pin_memory=True
                    )
'''
valid_eval_dataloader = torch.utils.data.DataLoader(
                        valid_eval_dataset, 
                        batch_size=INFER_BS,
                        shuffle=False, 
                        num_workers=NUM_WORKERS,
                        pin_memory=True
                    )   


In [7]:
from torchvision.models import resnet18

class CrossCNV(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = resnet18()
        self.backbone.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.backbone.fc = nn.Linear(in_features=512, out_features=1, bias=True)

    def forward(self, query, reference):
        x = torch.cat([query, reference], dim=1)
        x = self.backbone(x)
        return x
    

def crosscnv_base_224():
    return CrossCNV()

In [8]:
model = CrossCNV().to(device)

#ckpt = torch.load('net_170.pt')
#model.load_state_dict(ckpt['model_state_dict'], strict=False)

optimizer = torch.optim.AdamW(model.parameters(), lr = 5e-4)
#optimizer.load_state_dict(ckpt['optimizer_state_dict'], )

#opt = torch.optim.SGD(model.parameters(), lr = .05)
loss_fn = torch.nn.BCEWithLogitsLoss()

In [9]:
model

CrossCNV(
  (backbone): ResNet(
    (conv1): Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_r

In [16]:
epochs = 500

for epoch_i in range(0, epochs):
    
    epoch_loss, epoch_acc = 0, 0
    
    model.train()
    
    for anchor, pos, neg in tqdm(train_train_dataloader):

        optimizer.zero_grad(True)
        
        anchor = anchor.to(device, non_blocking=True, dtype=input_dtype)
        pos = pos.to(device, non_blocking=True, dtype=input_dtype)
        neg = neg.to(device, non_blocking=True, dtype=input_dtype)
        
        query = torch.cat([anchor, anchor], dim=0)
        reference = torch.cat([pos, neg], dim=0)
        labels = torch.cat([torch.ones(pos.shape[0],1), torch.zeros(neg.shape[0],1)], dim=0).to(device)

        with torch.cuda.amp.autocast(enabled=fp16):
            logits = model(query=query, reference=reference)
            loss = loss_fn(logits, labels)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        run['running/loss'].log(loss)
        
        # accuracy
        preds = torch.nn.Sigmoid()(logits).round().detach().cpu().numpy()
        acc = accuracy_score(labels.detach().cpu().numpy(), preds)
        run['running/acc'].log(acc)
        
        epoch_loss += loss
        epoch_acc += acc
        
    run['epoch/train/loss'].log(epoch_loss / len(train_train_dataloader))
    run['epoch/train/acc'].log(epoch_acc / len(train_train_dataloader))
    
    if epoch_i % 10 == 0:
        map = map_score(valid_eval_dataloader, model)
        run['epoch5/valid/map'].log(map)
        
    torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'../artifacts/net_{epoch_i}.pt')



100%|██████████| 323/323 [01:25<00:00,  3.77it/s]
100%|██████████| 4607/4607 [03:18<00:00, 23.25it/s]


MaP:  0.07933004577343532


100%|██████████| 323/323 [01:41<00:00,  3.18it/s]
100%|██████████| 323/323 [01:41<00:00,  3.17it/s]
100%|██████████| 323/323 [01:39<00:00,  3.26it/s]
100%|██████████| 323/323 [01:42<00:00,  3.14it/s]
100%|██████████| 323/323 [01:51<00:00,  2.89it/s]
100%|██████████| 323/323 [01:35<00:00,  3.38it/s]
100%|██████████| 323/323 [01:37<00:00,  3.33it/s]
100%|██████████| 323/323 [01:42<00:00,  3.15it/s]
100%|██████████| 323/323 [01:37<00:00,  3.33it/s]
100%|██████████| 323/323 [01:35<00:00,  3.38it/s]
100%|██████████| 4607/4607 [03:47<00:00, 20.22it/s]


MaP:  0.19133828358084343


100%|██████████| 323/323 [01:45<00:00,  3.07it/s]
100%|██████████| 323/323 [01:51<00:00,  2.91it/s]
100%|██████████| 323/323 [01:45<00:00,  3.06it/s]
100%|██████████| 323/323 [01:37<00:00,  3.31it/s]
100%|██████████| 323/323 [01:34<00:00,  3.40it/s]
100%|██████████| 323/323 [01:33<00:00,  3.47it/s]
100%|██████████| 323/323 [01:36<00:00,  3.35it/s]
100%|██████████| 323/323 [01:46<00:00,  3.03it/s]
100%|██████████| 323/323 [01:50<00:00,  2.91it/s]
100%|██████████| 323/323 [01:57<00:00,  2.76it/s]
100%|██████████| 4607/4607 [04:11<00:00, 18.28it/s]


MaP:  0.23659157446851564


100%|██████████| 323/323 [01:58<00:00,  2.73it/s]
100%|██████████| 323/323 [01:46<00:00,  3.04it/s]
100%|██████████| 323/323 [01:54<00:00,  2.83it/s]
100%|██████████| 323/323 [01:47<00:00,  3.01it/s]
100%|██████████| 323/323 [01:48<00:00,  2.99it/s]
100%|██████████| 323/323 [01:45<00:00,  3.08it/s]
100%|██████████| 323/323 [01:40<00:00,  3.22it/s]
100%|██████████| 323/323 [01:35<00:00,  3.38it/s]
100%|██████████| 323/323 [01:43<00:00,  3.13it/s]
100%|██████████| 323/323 [01:40<00:00,  3.20it/s]
100%|██████████| 4607/4607 [03:40<00:00, 20.87it/s]


MaP:  0.24234265632368546


100%|██████████| 323/323 [01:37<00:00,  3.30it/s]
100%|██████████| 323/323 [01:38<00:00,  3.27it/s]
100%|██████████| 323/323 [01:40<00:00,  3.22it/s]
100%|██████████| 323/323 [01:49<00:00,  2.94it/s]
100%|██████████| 323/323 [01:34<00:00,  3.41it/s]
100%|██████████| 323/323 [01:32<00:00,  3.48it/s]
 27%|██▋       | 86/323 [00:28<01:19,  2.99it/s]


KeyboardInterrupt: 

In [None]:
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, f'../artifacts/net_{epoch_i}_83acc.pt')


In [None]:
# Optimal batch size for inference
"""
5vcpus
52 ram
12 ram gpu

model to eval, optim bactch size calculate with tqdm
"""
model.eval()

TEST_INFER_BS_INIT = 10
TEST_INFER_BS_ML = 10

NUM_WORKERS = 0

class DummyTest(torch.utils.data.Dataset):
  def __init__(self):
    super().__init__()

  def __len__(self):
    return 7000000

  def __getitem__(self, idx):
    return torch.zeros((3,224,224)), torch.zeros((3,224,224))

BS = TEST_INFER_BS_INIT = 5
while True:
  i = 0
  BS = BS + TEST_INFER_BS_ML
  dataloader = torch.utils.data.DataLoader(
                          DummyTest(), 
                          batch_size=BS,
                          shuffle=False, 
                          num_workers=0,
                          pin_memory=True
                      )   

  for batch in tqdm(dataloader, total=len(dataloader)):
    query = batch[0].to(device, non_blocking=True, dtype=torch.float32)
    reference = batch[1].to(device, non_blocking=True, dtype=torch.float32)
    with torch.no_grad():
      logits, attn, q_cls, r_cls = model(query=query, reference=reference)
      i+=1
    
    if i == 50:
      print(BS)
      print(torch.cuda.mem_get_info(device=0))
      break

In [6]:
fp16 = True 
input_dtype = torch.float16 if fp16 else torch.float32

scaler =  torch.cuda.amp.GradScaler(enabled=fp16)

model = crossvit_base_224().to(device)
input = torch.zeros((2,3,224,224), dtype=input_dtype, device=device)
loss_fn = torch.nn.BCEWithLogitsLoss()
labels = torch.ones((2,1), dtype=input_dtype, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr = 5e-4)

warmup, reps = 30, 10

for i in range(0, warmup):
    
        with torch.cuda.amp.autocast(enabled = fp16):
            logits, attn, q_cls, r_cls = model(query=input, reference=input)
            loss = loss_fn(logits, labels)
            
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
) as prof:
    print(torch.cuda.mem_get_info(0))
    for i in range(0, reps):
        
        with torch.cuda.amp.autocast(enabled = fp16):
            
            logits, attn, q_cls, r_cls = model(query=input, reference=input)
            loss = loss_fn(logits, labels)
            
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        

(6465060864, 8589737984)


In [7]:
print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=25))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          aten::reshape         5.00%     566.319ms        11.90%        1.346s     153.356us     544.024ms         4.84%        1.311s     149.324us           0 b           0 b     542.46 Mb           0 

In [16]:
# layers x bs x (q,ref) x cross_attn_depth x n_heads x 1 x tokens(inc cls)
attn[0][0][0].shape

torch.Size([2, 12, 1, 197])

In [20]:
print(len(attn))
print(len(attn[0]))
print(len(attn[0][0]))
attn[0][0][0].shape

3
2
2


torch.Size([2, 12, 1, 197])