In [None]:
# COLAB
!pip install neptune-client
# pip install torch-tensorrt -f https://github.com/NVIDIA/Torch-TensorRT/releases
!unzip data.zip
!mkdir artifacts/

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')

import pandas as pd 

from tqdm import tqdm

from sklearn.metrics import accuracy_score
from sklearn.metrics import average_precision_score

import warnings
warnings.simplefilter('ignore')

import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity

import torch
from torch import einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from collections import OrderedDict

from torchvision import transforms as T
from torchvision.io import read_image

import itertools

from sklearn.model_selection import train_test_split

import neptune.new as neptune


In [2]:
run = neptune.init(
    project="victorcallejas/Belluga",
    api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJlNDRlNTJiNC00OTQwLTQxYjgtYWZiNS02OWQ0MDcwZmU5N2YifQ=="
)

https://app.neptune.ai/victorcallejas/Belluga/e/BEL-162
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


In [4]:
device = torch.device("cuda")
print(torch.cuda.get_device_name(0), torch.cuda.get_device_properties(device))

fp16 = False
input_dtype = torch.float16 if fp16 else torch.float32

scaler =  torch.cuda.amp.GradScaler(enabled=fp16)

NVIDIA GeForce GTX 1070 with Max-Q Design _CudaDeviceProperties(name='NVIDIA GeForce GTX 1070 with Max-Q Design', major=6, minor=1, total_memory=8191MB, multi_processor_count=16)


In [5]:
# SCORING
PREDICTION_LIMIT = 20
QUERY_ID_COL = "query_id"
DATABASE_ID_COL = "database_image_id"
SCORE_COL = "score"

SCORE_THRESHOLD = 0.5

class MeanAveragePrecision:
    @classmethod
    def score(cls, predicted: pd.DataFrame, actual: pd.DataFrame, prediction_limit: int):
        """Calculates mean average precision for a ranking task.
        :param predicted: The predicted values as a dataframe with specified column names
        :param actual: The ground truth values as a dataframe with specified column names
        """
        if not predicted[SCORE_COL].between(0.0, 1.0).all():
            raise ValueError("Scores must be in range [0, 1].")
        if predicted.index.name != QUERY_ID_COL:
            raise ValueError(
                f"First column of submission must be named '{QUERY_ID_COL}', "
                f"got {predicted.index.name}."
            )
        if predicted.columns.to_list() != [DATABASE_ID_COL, SCORE_COL]:
            raise ValueError(
                f"Columns of submission must be named '{[DATABASE_ID_COL, SCORE_COL]}', "
                f"got {predicted.columns.to_list()}."
            )

        unadjusted_aps, predicted_n_pos, actual_n_pos = cls._score_per_query(
            predicted, actual, prediction_limit
        )
        adjusted_aps = unadjusted_aps.multiply(predicted_n_pos).divide(actual_n_pos)
        return adjusted_aps.mean()

    @classmethod
    def _score_per_query(
        cls, predicted: pd.DataFrame, actual: pd.DataFrame, prediction_limit: int
    ):
        """Calculates per-query mean average precision for a ranking task."""
        merged = predicted.merge(
            right=actual.assign(actual=1.0),
            how="left",
            on=[QUERY_ID_COL, DATABASE_ID_COL],
        ).fillna({"actual": 0.0})
        # Per-query raw average precisions based on predictions
        unadjusted_aps = merged.groupby(QUERY_ID_COL).apply(
            lambda df: average_precision_score(df["actual"].values, df[SCORE_COL].values)
            if df["actual"].sum()
            else 0.0
        )
        # Total ground truth positive counts for rescaling
        predicted_n_pos = merged["actual"].groupby(QUERY_ID_COL).sum().astype("int64").rename()
        actual_n_pos = actual.groupby(QUERY_ID_COL).size().clip(upper=prediction_limit)
        return unadjusted_aps, predicted_n_pos, actual_n_pos
    
    
def map_score(dataloader, model, threshold=SCORE_THRESHOLD):
    
    model.eval()
    
    sub = []
    
    sigmoid = torch.nn.Sigmoid()
    
    with torch.no_grad():        
    
        for query, reference, query_id, reference_id in tqdm(dataloader):
            
            query = query.to(device, non_blocking=True, dtype=input_dtype)
            reference = reference.to(device, non_blocking=True, dtype=input_dtype)

            with torch.cuda.amp.autocast(enabled = fp16):
                logits = sigmoid(model(query=query, reference=reference)).cpu().squeeze().tolist()
                
            sub.extend(zip(query_id, reference_id, logits))
            
    sub = pd.DataFrame(sub, columns=['query_id', 'database_image_id', 'score'])
    sub = sub[sub.score > threshold]
    sub = sub.set_index(['database_image_id']).groupby('query_id')['score'].nlargest(20).reset_index()
    sub = sub.set_index('query_id')
    
    mean_avg_prec = MeanAveragePrecision.score(
        predicted=sub, actual=dataloader.dataset.gt, prediction_limit=PREDICTION_LIMIT
    )
    
    print('MaP: ',mean_avg_prec)
    return mean_avg_prec

In [9]:
# DATA

IMG_SIZE = 224
ROOT_DIR = '../data/'
NORM_TRANSFORMS = torch.nn.Sequential(
    T.Resize([IMG_SIZE, IMG_SIZE]),
    T.ConvertImageDtype(input_dtype),
    T.Normalize(mean = (0.4234, 0.4272, 0.4641),
                std  = (0.2037, 0.2027, 0.2142)),
)

VAL_SPLIT = 0.05

METADATA = pd.read_csv('../data/metadata.csv')

TRAIN, VAL = train_test_split(METADATA, test_size=0.05, random_state=42)
TRAIN, VAL = TRAIN.reset_index(), VAL.reset_index()
#TRAIN, VAL = METADATA, METADATA
#TRAIN = METADATA

def getImages(metadata):
    IMAGES = {}
    for image_id, path in tqdm(zip(metadata.image_id, metadata.path), total=metadata.shape[0]):
        IMAGES[image_id] = NORM_TRANSFORMS(read_image(ROOT_DIR + path))
    return IMAGES

IMAGES = getImages(METADATA)

class PreTrain_BellugaDataset(torch.utils.data.Dataset):
    
    def __init__(self, metadata):
        self.metadata = metadata

    def __len__(self):
        return self.metadata.shape[0]
    
    def __getitem__(self, idx):
        return IMAGES[self.metadata.image_id[idx]]


class Eval_BellugaDataset(torch.utils.data.Dataset):
    
    def __init__(self, metadata):
        self.metadata = metadata
    
        # GROUND TRUTH
        gt = []
        for wid in self.metadata.whale_id: # query
            tmp = self.metadata[self.metadata.whale_id == wid].image_id.tolist() # get all images id
            gt.extend(list(itertools.permutations(tmp, 2)))
        self.gt = pd.DataFrame(gt,columns=['query_id','database_image_id'])
        self.gt = self.gt.set_index('query_id')
        
        # ALL QUERIES
        self.query_reference = list(itertools.permutations(self.metadata.image_id, 2))
            
    def getimage(self, image_id):
        return IMAGES[image_id]

    def __len__(self):
        return len(self.query_reference)
    
    def __getitem__(self, idx):
        query_id = self.query_reference[idx][0]
        reference_id = self.query_reference[idx][1]
        
        query = self.getimage(query_id)
        reference = self.getimage(reference_id)
        
        return query, reference, query_id, reference_id
    
    
class Train_BellugaDataset(torch.utils.data.Dataset):
    
    def __init__(self, metadata):
        self.metadata = metadata
        self.aug = T.RandomErasing(p=0.4, scale=(0.12, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)
            
    def getimage(self, image_id):
        return IMAGES[image_id]

    def __len__(self):
        return self.metadata.shape[0]
    
    def __getitem__(self, idx):
        
        anchor = self.aug(self.getimage(self.metadata.image_id[idx]))
        label = self.metadata.whale_id[idx]
        
        pos = self.aug(self.getimage(self.metadata[self.metadata.whale_id == label].sample()['image_id'].values[0]))
        neg = self.aug(self.getimage(self.metadata[self.metadata.whale_id != label].sample()['image_id'].values[0]))

        return anchor, pos, neg
    

100%|██████████| 5902/5902 [02:23<00:00, 40.99it/s]


In [10]:
# DATALOADERS 

PRETRAIN_BS = 6
TRAIN_BS = 64
INFER_BS = TRAIN_BS

NUM_WORKERS = 0


#pretrain_dataset = PreTrain_BellugaDataset(METADATA)
train_train_dataset = Train_BellugaDataset(TRAIN)
#train_eval_dataset = Eval_BellugaDataset(TRAIN)
valid_eval_dataset = Eval_BellugaDataset(VAL)

'''
pretrain_dataloader = torch.utils.data.DataLoader(
                        pretrain_dataset, 
                        batch_size=PRETRAIN_BS,
                        shuffle=True, 
                        num_workers=NUM_WORKERS,
                        pin_memory=True
                    )
'''
train_train_dataloader = torch.utils.data.DataLoader(
                        train_train_dataset, 
                        batch_size=TRAIN_BS,
                        shuffle=True, 
                        num_workers=NUM_WORKERS,
                        pin_memory=True
                    )
'''
train_eval_dataloader = torch.utils.data.DataLoader(
                        train_eval_dataset, 
                        batch_size=INFER_BS,
                        shuffle=True, 
                        num_workers=NUM_WORKERS,
                        pin_memory=True
                    )
'''
valid_eval_dataloader = torch.utils.data.DataLoader(
                        valid_eval_dataset, 
                        batch_size=INFER_BS,
                        shuffle=False, 
                        num_workers=NUM_WORKERS,
                        pin_memory=True
                    )   


In [11]:
# helpers
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# pre-layernorm
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# feedforward
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None, kv_include_self = False):
        b, n, _, h = *x.shape, self.heads
        context = default(context, x)

        if kv_include_self:
            context = torch.cat((x, context), dim = 1) # cross attention requires CLS token includes itself as key / value

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer encoder, for qall and large patches
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# projecting CLS tokens, in the case that qall and large patch tokens have different dimensions
class ProjectInOut(nn.Module):
    def __init__(self, dim_in, dim_out, fn):
        super().__init__()
        self.fn = fn

        need_projection = dim_in != dim_out
        self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity()
        self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity()

    def forward(self, x, *args, **kwargs):
        x = self.project_in(x)
        x = self.fn(x, *args, **kwargs)
        x = self.project_out(x)
        return x

# cross attention transformer
class CrossTransformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(ProjectInOut(dim, dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))))

    def forward(self, q_tokens, ref_tokens):
        (q_cls, q_patch_tokens), (ref_cls, ref_patch_tokens) = map(lambda t: (t[:, :1], t[:, 1:]), (q_tokens, ref_tokens))
        for attend in self.layers:
            q_cls = attend(q_cls, context = ref_patch_tokens, kv_include_self = True) + q_cls
            ref_cls = attend(ref_cls, context = q_patch_tokens, kv_include_self = True) + ref_cls
            
        q_tokens = torch.cat((q_cls, q_patch_tokens), dim = 1)
        ref_tokens = torch.cat((ref_cls, ref_patch_tokens), dim = 1)
        
        return q_tokens, ref_tokens

# multi-scale encoder
class MultiScaleEncoder(nn.Module):
    def __init__(
        self,
        *,
        enc_depth,
        dim,
        enc_params,
        cross_attn_heads,
        cross_attn_depth,
        cross_attn_dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for enc_d, cross_d in zip(enc_depth, cross_attn_depth):
            self.layers.append(nn.ModuleList([
                Transformer(dim = dim, dropout = dropout, depth = enc_d, **enc_params),
                CrossTransformer(dim = dim, depth = cross_d, heads = cross_attn_heads, dim_head = cross_attn_dim_head, dropout = dropout)
            ]))

    def forward(self, q_tokens, ref_tokens):
        for enc, cross_attend in self.layers:
            q_tokens, ref_tokens = enc(q_tokens), enc(ref_tokens)
            q_tokens, ref_tokens = cross_attend(q_tokens, ref_tokens)

        return q_tokens, ref_tokens

# patch-based image to token embedder
class ImageEmbedder(nn.Module):
    def __init__(
        self,
        *,
        dim,
        image_size,
        patch_size,
        dropout = 0.
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]

        return self.dropout(x)

# cross ViT class
class CrossViT(nn.Module):
    def __init__(
        self,
        *,
        image_size=224,
        num_classes=1,
        dim=192,
        patch_size = 16,
        enc_depth = [2,1,1],
        enc_heads = 8,
        enc_mlp_dim = 2048,
        enc_dim_head = 64,
        cross_attn_depth = [1,1,2],
        cross_attn_heads = 12,
        cross_attn_dim_head = 64,
        dropout = 0.2,
        emb_dropout = 0.1
    ):
        super().__init__()
        self.image_embedder = ImageEmbedder(dim = dim, image_size = image_size, patch_size = patch_size, dropout = emb_dropout)

        self.multi_scale_encoder = MultiScaleEncoder(
            dim = dim,
            enc_depth = enc_depth,
            cross_attn_heads = cross_attn_heads,
            cross_attn_dim_head = cross_attn_dim_head,
            cross_attn_depth = cross_attn_depth,
            enc_params = dict(
                heads = enc_heads,
                mlp_dim = enc_mlp_dim,
                dim_head = enc_dim_head
            ),
            dropout = dropout
        )

        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def forward(self, query, reference):
        q_tokens = self.image_embedder(query)
        ref_tokens = self.image_embedder(reference)

        q_tokens, ref_tokens = self.multi_scale_encoder(q_tokens, ref_tokens)

        q_cls, ref_cls = map(lambda t: t[:, 0], (q_tokens, ref_tokens))

        #cls = torch.cat([q_cls, ref_cls], dim=1)
        cls = q_cls + ref_cls
        logits = self.mlp_head(cls)

        return logits
    

def crossvit_base_224():
    
    return CrossViT(
        image_size=224,
        num_classes=1,
        dim=192,
        patch_size = 32,
        enc_depth = [1,1],
        enc_heads = 8,
        enc_mlp_dim = 2048,
        enc_dim_head = 64,
        cross_attn_depth = [1,1],
        cross_attn_heads = 12,
        cross_attn_dim_head = 64,
        dropout = 0.2,
        emb_dropout = 0.1
        )

In [12]:
model = crossvit_base_224().to(device)

#ckpt = torch.load('/kaggle/input/ckptttt/net (3).pt')
#model.load_state_dict(ckpt['model_state_dict'], strict=False)

optimizer = torch.optim.AdamW(model.parameters(), lr = 5e-5)
#optimizer.load_state_dict(ckpt['optimizer_state_dict'], )

#opt = torch.optim.SGD(model.parameters(), lr = .05)
loss_fn = torch.nn.BCEWithLogitsLoss()

In [13]:
model

CrossViT(
  (image_embedder): ImageEmbedder(
    (to_patch_embedding): Sequential(
      (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
      (1): Linear(in_features=768, out_features=192, bias=True)
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (multi_scale_encoder): MultiScaleEncoder(
    (layers): ModuleList(
      (0): ModuleList(
        (0): Transformer(
          (layers): ModuleList(
            (0): ModuleList(
              (0): PreNorm(
                (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
                (fn): Attention(
                  (attend): Softmax(dim=-1)
                  (dropout): Dropout(p=0.2, inplace=False)
                  (to_q): Linear(in_features=192, out_features=512, bias=False)
                  (to_kv): Linear(in_features=192, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=192, bias=True)
                   

In [14]:
epochs = 1000

for epoch_i in range(0, epochs):
    
    epoch_loss, epoch_acc = 0, 0
    
    model.train()
    
    for anchor, pos, neg in tqdm(train_train_dataloader):

        optimizer.zero_grad(True)
        
        anchor = anchor.to(device, non_blocking=True, dtype=input_dtype)
        pos = pos.to(device, non_blocking=True, dtype=input_dtype)
        neg = neg.to(device, non_blocking=True, dtype=input_dtype)
        
        query = torch.cat([anchor, anchor], dim=0)
        reference = torch.cat([pos, neg], dim=0)
        labels = torch.cat([torch.ones(pos.shape[0],1), torch.zeros(neg.shape[0],1)], dim=0).to(device)

        with torch.cuda.amp.autocast(enabled=fp16):
            logits = model(query=query, reference=reference)
            loss = loss_fn(logits, labels)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        run['running/loss'].log(loss)
        
        # accuracy
        preds = torch.nn.Sigmoid()(logits).round().detach().cpu().numpy()
        acc = accuracy_score(labels.detach().cpu().numpy(), preds)
        run['running/acc'].log(acc)
        
        epoch_loss += loss
        epoch_acc += acc
        
    run['epoch/train/loss'].log(epoch_loss / len(train_train_dataloader))
    run['epoch/train/acc'].log(epoch_acc / len(train_train_dataloader))
    
    if epoch_i % 20 == 0:
        map = map_score(valid_eval_dataloader, model)
        run['epoch5/valid/map'].log(map)
        
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'/kaggle/working/artifacts/net.pt')



 24%|██▍       | 21/88 [01:03<03:21,  3.01s/it]


KeyboardInterrupt: 

In [None]:
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, f'../artifacts/net_{epoch_i}_83acc.pt')


In [None]:
# Optimal batch size for inference
"""
5vcpus
52 ram
12 ram gpu

model to eval, optim bactch size calculate with tqdm
"""
model.eval()

TEST_INFER_BS_INIT = 10
TEST_INFER_BS_ML = 10

NUM_WORKERS = 0

class DummyTest(torch.utils.data.Dataset):
  def __init__(self):
    super().__init__()

  def __len__(self):
    return 7000000

  def __getitem__(self, idx):
    return torch.zeros((3,224,224)), torch.zeros((3,224,224))

BS = TEST_INFER_BS_INIT = 5
while True:
  i = 0
  BS = BS + TEST_INFER_BS_ML
  dataloader = torch.utils.data.DataLoader(
                          DummyTest(), 
                          batch_size=BS,
                          shuffle=False, 
                          num_workers=0,
                          pin_memory=True
                      )   

  for batch in tqdm(dataloader, total=len(dataloader)):
    query = batch[0].to(device, non_blocking=True, dtype=torch.float32)
    reference = batch[1].to(device, non_blocking=True, dtype=torch.float32)
    with torch.no_grad():
      logits, attn, q_cls, r_cls = model(query=query, reference=reference)
      i+=1
    
    if i == 50:
      print(BS)
      print(torch.cuda.mem_get_info(device=0))
      break

In [6]:
fp16 = True 
input_dtype = torch.float16 if fp16 else torch.float32

scaler =  torch.cuda.amp.GradScaler(enabled=fp16)

model = crossvit_base_224().to(device)
input = torch.zeros((2,3,224,224), dtype=input_dtype, device=device)
loss_fn = torch.nn.BCEWithLogitsLoss()
labels = torch.ones((2,1), dtype=input_dtype, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr = 5e-4)

warmup, reps = 30, 10

for i in range(0, warmup):
    
        with torch.cuda.amp.autocast(enabled = fp16):
            logits, attn, q_cls, r_cls = model(query=input, reference=input)
            loss = loss_fn(logits, labels)
            
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
) as prof:
    print(torch.cuda.mem_get_info(0))
    for i in range(0, reps):
        
        with torch.cuda.amp.autocast(enabled = fp16):
            
            logits, attn, q_cls, r_cls = model(query=input, reference=input)
            loss = loss_fn(logits, labels)
            
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        

(6465060864, 8589737984)


In [7]:
print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=25))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          aten::reshape         5.00%     566.319ms        11.90%        1.346s     153.356us     544.024ms         4.84%        1.311s     149.324us           0 b           0 b     542.46 Mb           0 

In [16]:
# layers x bs x (q,ref) x cross_attn_depth x n_heads x 1 x tokens(inc cls)
attn[0][0][0].shape

torch.Size([2, 12, 1, 197])

In [20]:
print(len(attn))
print(len(attn[0]))
print(len(attn[0][0]))
attn[0][0][0].shape

3
2
2


torch.Size([2, 12, 1, 197])