In [None]:
# COLAB
!pip install neptune-client
# pip install torch-tensorrt -f https://github.com/NVIDIA/Torch-TensorRT/releases
!unzip data.zip
!mkdir artifacts/

In [22]:
%load_ext autoreload
%autoreload 2

import itertools
import os
import sys
import warnings

sys.path.append('..')
warnings.simplefilter('ignore')

import neptune.new as neptune
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from sklearn.metrics import accuracy_score
from sklearn.metrics import average_precision_score
from sklearn.model_selection import train_test_split
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from torch.profiler import profile, record_function, ProfilerActivity
from torchvision import transforms as T
from torchvision.io import read_image
from tqdm import tqdm


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
# Start a Neptune run for experiment tracking.
# The API token is a secret and must not be stored in the notebook: it is read
# from the NEPTUNE_API_TOKEN environment variable. When the variable is unset,
# `None` is passed and Neptune applies its own token lookup / error reporting.
# NOTE: the token that used to be hard-coded here has been exposed and should
# be revoked in the Neptune account settings.
run = neptune.init(
    project="victorcallejas/Belluga",
    api_token=os.environ.get("NEPTUNE_API_TOKEN"),
)

https://app.neptune.ai/victorcallejas/Belluga/e/BEL-172
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


In [24]:
# Pick the GPU when one is available and fall back to the CPU otherwise, so the
# notebook can still be executed (slowly) on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(torch.cuda.get_device_name(0), torch.cuda.get_device_properties(device))
else:
    print("CUDA is not available; running on CPU.")

# dtype every input image is converted to before it is fed to the model.
input_dtype = torch.float32

NVIDIA GeForce GTX 1070 with Max-Q Design _CudaDeviceProperties(name='NVIDIA GeForce GTX 1070 with Max-Q Design', major=6, minor=1, total_memory=8191MB, multi_processor_count=16)


In [25]:
# SCORING
# Upper bound applied to the per-query ground-truth count when average
# precision is rescaled; map_score passes it on as `prediction_limit`.
PREDICTION_LIMIT = 20
# Index / column names shared by the prediction and ground-truth dataframes.
QUERY_ID_COL = "query_id"
DATABASE_ID_COL = "database_image_id"
SCORE_COL = "score"

# Default cut-off of map_score: only pairs scoring strictly above it are kept.
SCORE_THRESHOLD = 0.5

class MeanAveragePrecision:
    """Mean average precision (mAP) for a query -> database image ranking task."""

    @classmethod
    def score(cls, predicted: pd.DataFrame, actual: pd.DataFrame, prediction_limit: int):
        """Validate the prediction frame and return its mean average precision.

        :param predicted: Predictions indexed by ``QUERY_ID_COL`` with exactly the
            columns ``[DATABASE_ID_COL, SCORE_COL]``; scores must lie in [0, 1].
        :param actual: Ground truth pairs, grouped by ``QUERY_ID_COL``.
        :param prediction_limit: Cap applied to the per-query ground-truth count.
        :raises ValueError: If scores, index name or column names are invalid.
        :return: Mean over queries of the rescaled average precision.
        """
        scores_in_range = predicted[SCORE_COL].between(0.0, 1.0).all()
        if not scores_in_range:
            raise ValueError("Scores must be in range [0, 1].")

        index_name = predicted.index.name
        if index_name != QUERY_ID_COL:
            raise ValueError(
                f"First column of submission must be named '{QUERY_ID_COL}', "
                f"got {index_name}."
            )

        expected_columns = [DATABASE_ID_COL, SCORE_COL]
        found_columns = predicted.columns.to_list()
        if found_columns != expected_columns:
            raise ValueError(
                f"Columns of submission must be named '{expected_columns}', "
                f"got {found_columns}."
            )

        raw_aps, hits_predicted, hits_available = cls._score_per_query(
            predicted, actual, prediction_limit
        )
        # Rescale each query's AP by (positives among predictions) / (capped positives in truth).
        rescaled = raw_aps.multiply(hits_predicted).divide(hits_available)
        return rescaled.mean()

    @classmethod
    def _score_per_query(
        cls, predicted: pd.DataFrame, actual: pd.DataFrame, prediction_limit: int
    ):
        """Return per-query (raw AP, positives among predictions, capped ground-truth count)."""
        # Flag every predicted pair that also appears in the ground truth.
        truth_flagged = actual.assign(actual=1.0)
        merged = predicted.merge(
            right=truth_flagged,
            how="left",
            on=[QUERY_ID_COL, DATABASE_ID_COL],
        )
        merged = merged.fillna({"actual": 0.0})

        def _query_ap(group):
            # A query without any correct prediction scores 0 instead of calling sklearn.
            if not group["actual"].sum():
                return 0.0
            return average_precision_score(group["actual"].values, group[SCORE_COL].values)

        unadjusted_aps = merged.groupby(QUERY_ID_COL).apply(_query_ap)

        # Counts used by the caller for rescaling.
        positives = merged["actual"].groupby(QUERY_ID_COL).sum()
        predicted_n_pos = positives.astype("int64").rename()
        truth_sizes = actual.groupby(QUERY_ID_COL).size()
        actual_n_pos = truth_sizes.clip(upper=prediction_limit)
        return unadjusted_aps, predicted_n_pos, actual_n_pos
    
    
def map_score(dataloader, model, threshold=SCORE_THRESHOLD):
    """Score every (query, reference) pair of `dataloader` and report the mAP.

    The model is switched to eval mode and run without gradients. Its outputs
    are passed through a sigmoid; pairs scoring strictly above `threshold` are
    kept, then at most PREDICTION_LIMIT best references per query are scored
    against `dataloader.dataset.gt`.

    Reads the module-level `device`, `input_dtype` and `fp16`.
    NOTE(review): `fp16` is not defined in the cells above this function —
    presumably it is set before the first call; verify.

    Args:
        dataloader: yields (query, reference, query_id, reference_id) batches;
            its dataset must expose the ground truth frame as `.gt`.
        model: callable as `model(query=..., reference=...)`.
        threshold (float): minimum (exclusive) score for a pair to be kept.

    Returns:
        The mean average precision; 0.0 when no pair exceeds `threshold`.
    """
    model.eval()

    rows = []

    with torch.no_grad():

        for query, reference, query_id, reference_id in tqdm(dataloader):

            query = query.to(device, non_blocking=True, dtype=input_dtype)
            reference = reference.to(device, non_blocking=True, dtype=input_dtype)

            with torch.cuda.amp.autocast(enabled = fp16):
                probs = torch.sigmoid(model(query=query, reference=reference))

            # reshape(-1) always yields a flat list, even for a batch of one
            # (squeeze() would collapse that to a bare float and break zip).
            rows.extend(zip(query_id, reference_id, probs.cpu().reshape(-1).tolist()))

    sub = pd.DataFrame(rows, columns=[QUERY_ID_COL, DATABASE_ID_COL, SCORE_COL])
    sub = sub[sub[SCORE_COL] > threshold]

    if sub.empty:
        # Nothing was predicted, so no query can have a correct match.
        mean_avg_prec = 0.0
        print('MaP: ',mean_avg_prec)
        return mean_avg_prec

    # Keep the PREDICTION_LIMIT best references per query, indexed by query id
    # as MeanAveragePrecision.score expects.
    sub = (
        sub.set_index([DATABASE_ID_COL])
        .groupby(QUERY_ID_COL)[SCORE_COL]
        .nlargest(PREDICTION_LIMIT)
        .reset_index()
        .set_index(QUERY_ID_COL)
    )

    mean_avg_prec = MeanAveragePrecision.score(
        predicted=sub, actual=dataloader.dataset.gt, prediction_limit=PREDICTION_LIMIT
    )

    print('MaP: ',mean_avg_prec)
    return mean_avg_prec

In [27]:
# DATA

IMG_SIZE = 224
ROOT_DIR = '../data/'
# Preprocessing applied once to every image when it is loaded:
# resize -> convert to `input_dtype` -> per-channel normalisation.
NORM_TRANSFORMS = torch.nn.Sequential(
    T.Resize([IMG_SIZE, IMG_SIZE]),
    T.ConvertImageDtype(input_dtype),
    T.Normalize(mean = (0.4234, 0.4272, 0.4641),
                std  = (0.2037, 0.2027, 0.2142)),
)

# Fraction of the metadata rows held out for validation.
VAL_SPLIT = 0.05
# Number of metadata rows to keep (first rows of the file); set to None to use all of them.
MAX_ROWS = 50

METADATA = pd.read_csv(ROOT_DIR + 'metadata.csv')[:MAX_ROWS]

TRAIN, VAL = train_test_split(METADATA, test_size=VAL_SPLIT, random_state=42)
# reset_index gives both splits a fresh 0..n-1 index (the datasets look rows up
# by index label); the previous index is kept as an "index" column.
TRAIN, VAL = TRAIN.reset_index(), VAL.reset_index()
#TRAIN, VAL = METADATA, METADATA
#TRAIN = METADATA

def getImages(metadata):
    """Read and preprocess every image listed in `metadata`.

    Each file at ROOT_DIR + path is read with `read_image` and passed through
    NORM_TRANSFORMS.

    Args:
        metadata: frame with `image_id` and `path` columns.

    Returns:
        dict mapping image_id to the preprocessed image.
    """
    id_path_pairs = zip(metadata.image_id, metadata.path)
    return {
        image_id: NORM_TRANSFORMS(read_image(ROOT_DIR + path))
        for image_id, path in tqdm(id_path_pairs, total=metadata.shape[0])
    }

# Preloaded image cache shared by all dataset classes.
IMAGES = getImages(METADATA)

class PreTrain_BellugaDataset(torch.utils.data.Dataset):
    """Serves one preloaded image per metadata row (no labels)."""

    def __init__(self, metadata):
        # Frame with an `image_id` column; rows are looked up by index label.
        self.metadata = metadata

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        image_id = self.metadata.image_id[idx]
        return IMAGES[image_id]


class Eval_BellugaDataset(torch.utils.data.Dataset):
    """Evaluation dataset enumerating every ordered (query, reference) image pair.

    Attributes:
        metadata: frame with `image_id` and `whale_id` columns.
        gt: ground truth frame indexed by 'query_id' with a 'database_image_id'
            column; one row per ordered pair of distinct images of the same whale.
        query_reference: list of all ordered pairs of distinct image ids.
    """

    def __init__(self, metadata):
        self.metadata = metadata

        # GROUND TRUTH
        # Iterate over each whale once: looping over the raw column would add
        # the pairs of a whale with k images k times, duplicating ground truth
        # rows and inflating the per-query positive counts used by the metric.
        gt = []
        for wid in self.metadata.whale_id.unique():
            same_whale_ids = self.metadata[self.metadata.whale_id == wid].image_id.tolist()
            gt.extend(itertools.permutations(same_whale_ids, 2))
        self.gt = pd.DataFrame(gt, columns=['query_id', 'database_image_id'])
        self.gt = self.gt.set_index('query_id')

        # ALL QUERIES
        self.query_reference = list(itertools.permutations(self.metadata.image_id, 2))

    def getimage(self, image_id):
        """Return the preloaded image for `image_id`."""
        return IMAGES[image_id]

    def __len__(self):
        return len(self.query_reference)

    def __getitem__(self, idx):
        """Return (query image, reference image, query id, reference id) for pair `idx`."""
        query_id, reference_id = self.query_reference[idx][0], self.query_reference[idx][1]

        query = self.getimage(query_id)
        reference = self.getimage(reference_id)

        return query, reference, query_id, reference_id
    
    
class Train_BellugaDataset(torch.utils.data.Dataset):
    """Triplet dataset: each item is (anchor, positive, negative), all augmented.

    The positive is a random image of the anchor's whale, the negative a random
    image of any other whale. Every image goes through random erasing.
    """

    def __init__(self, metadata):
        self.metadata = metadata
        self.aug = T.RandomErasing(p=0.4, scale=(0.12, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)

    def getimage(self, image_id):
        """Return the preloaded image for `image_id`."""
        return IMAGES[image_id]

    def __len__(self):
        return len(self.metadata)

    def _augmented(self, image_id):
        # Fetch the cached image and apply the random-erasing augmentation.
        return self.aug(self.getimage(image_id))

    def _random_image_id(self, row_mask):
        # Draw one random row among those selected by `row_mask` and return its image id.
        return self.metadata[row_mask].sample()['image_id'].values[0]

    def __getitem__(self, idx):
        whale_ids = self.metadata.whale_id

        anchor = self._augmented(self.metadata.image_id[idx])
        label = whale_ids[idx]

        pos = self._augmented(self._random_image_id(whale_ids == label))
        neg = self._augmented(self._random_image_id(whale_ids != label))

        return anchor, pos, neg
    

100%|██████████| 50/50 [00:00<00:00, 105.02it/s]


In [28]:
# DATALOADERS 
# Builds the datasets and loaders used for training and validation.
# The triple-quoted blocks below are plain string literals that keep the
# pre-training loader and the train-set evaluation loader disabled.

# Batch sizes; inference reuses the training batch size.
PRETRAIN_BS = 6
TRAIN_BS = 4
INFER_BS = TRAIN_BS

# 0 workers: batches are assembled in the main process.
NUM_WORKERS = 0


#pretrain_dataset = PreTrain_BellugaDataset(METADATA)
# Triplets are sampled from the training split ...
train_train_dataset = Train_BellugaDataset(TRAIN)
#train_eval_dataset = Eval_BellugaDataset(TRAIN)
# ... and all ordered image pairs of the validation split are scored.
valid_eval_dataset = Eval_BellugaDataset(VAL)

'''
pretrain_dataloader = torch.utils.data.DataLoader(
                        pretrain_dataset, 
                        batch_size=PRETRAIN_BS,
                        shuffle=True, 
                        num_workers=NUM_WORKERS,
                        pin_memory=True
                    )
'''
# Training loader: reshuffled every epoch.
train_train_dataloader = torch.utils.data.DataLoader(
                        train_train_dataset, 
                        batch_size=TRAIN_BS,
                        shuffle=True, 
                        num_workers=NUM_WORKERS,
                        pin_memory=True
                    )
'''
train_eval_dataloader = torch.utils.data.DataLoader(
                        train_eval_dataset, 
                        batch_size=INFER_BS,
                        shuffle=True, 
                        num_workers=NUM_WORKERS,
                        pin_memory=True
                    )
'''
# Validation loader: fixed order (shuffle=False).
valid_eval_dataloader = torch.utils.data.DataLoader(
                        valid_eval_dataset, 
                        batch_size=INFER_BS,
                        shuffle=False, 
                        num_workers=NUM_WORKERS,
                        pin_memory=True
                    )   


In [None]:

class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features (int): Size of the input features.
        hidden_features (int, optional): Hidden width; falls back to `in_features` when falsy.
        out_features (int, optional): Output width; falls back to `in_features` when falsy.
        act_layer (nn.Module, optional): Activation class, instantiated without arguments.
        drop (float, optional): Dropout probability used after both linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        output_width = out_features if out_features else in_features
        hidden_width = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        # The same dropout module is applied at both positions.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


def window_partition(x, window_size):
    """Cut a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    grid_rows = height // window_size
    grid_cols = width // window_size
    tiled = x.view(batch, grid_rows, window_size, grid_cols, window_size, channels)
    # Bring the two window-grid axes next to each other before flattening them.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)


def window_reverse(windows, window_size, H, W):
    """Reassemble windows into a full feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    grid = windows.view(batch, H // window_size, W // window_size, window_size, window_size, -1)
    # Interleave grid rows/cols with in-window rows/cols, then merge them into H and W.
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, H, W, -1)


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Attention logits are cosine similarities scaled by a learnable, clamped
    per-head factor; the position bias is produced by a small MLP applied to
    log-spaced relative coordinates.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
    """

    # Upper bound of the (log-space) logit scale, i.e. exp(scale) <= 1 / 0.01.
    # Kept as a plain float so no tensor has to be created on every forward pass
    # and the clamp does not depend on the device of the parameters.
    _LOGIT_SCALE_MAX = float(np.log(1. / 0.01))

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=(0, 0)):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads

        # Learnable per-head scale, stored in log space and initialised to log(10).
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)

        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(512, num_heads, bias=False))

        # get relative_coords_table
        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
        relative_coords_table = torch.stack(
            torch.meshgrid([relative_coords_h,
                            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
        # Normalise offsets by the pre-training window when one is given,
        # otherwise by the current window.
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
        else:
            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        # Signed log2 spacing of the coordinates.
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            torch.abs(relative_coords_table) + 1.0) / np.log2(8)

        self.register_buffer("relative_coords_table", relative_coords_table)

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        # Bias-free projection; query and value biases are separate parameters
        # and the key gets no bias (see forward).
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: additive mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        Returns:
            Tensor of shape (num_windows*B, N, C).
        """
        B_, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # The key third of the bias is fixed to zero.
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # cosine attention
        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
        logit_scale = torch.clamp(self.logit_scale, max=self._LOGIT_SCALE_MAX).exp()
        attn = attn * logit_scale

        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        # Squash the bias into (0, 16).
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Add the per-window mask, broadcast over batch and heads.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, ' \
               f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        """Return the FLOP estimate for one window holding N tokens."""
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Runs (optionally shifted) window attention followed by an MLP. In both
    sub-layers the normalization is applied to the branch output before it is
    added to the residual.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
        pretrained_window_size (int): Window size in pre-training.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
            pretrained_window_size=to_2tuple(pretrained_window_size))

        # Stochastic depth is only instantiated for a positive rate.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            # Three bands per axis; each of the 3x3 band combinations gets its own id.
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # Pairwise difference of region ids: non-zero means the two
            # positions of a window carry different ids.
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # Such pairs get -100.0 (added to the attention logits), all others 0.0.
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        # Registered as a buffer (possibly None) rather than a parameter.
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """Apply the block to tokens `x` of shape (B, H*W, C); returns the same shape."""
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        # Normalize the attention output, then add it to the residual.
        x = shortcut + self.drop_path(self.norm1(x))

        # FFN
        # Same pattern: normalize the MLP output, then add it to the residual.
        x = x + self.drop_path(self.norm2(self.mlp(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        """Return the FLOP estimate of the block (norms, window attention and MLP)."""
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Concatenates each 2x2 neighbourhood of tokens along the channel axis,
    projects the 4*dim features down to 2*dim and normalizes the result.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(2 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        returns: B, H/2*W/2, 2*C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(B, H, W, C)

        # (row offset, column offset) of the four tokens of each 2x2 neighbourhood,
        # in concatenation order.
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        neighbours = [grid[:, row::2, col::2, :] for row, col in offsets]  # each B H/2 W/2 C
        merged = torch.cat(neighbours, -1)  # B H/2 W/2 4*C
        merged = merged.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        return self.norm(self.reduction(merged))

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        """Return the FLOP estimate of the reduction and the normalization."""
        H, W = self.input_resolution
        reduction_flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        norm_flops = H * W * self.dim // 2
        return reduction_flops + norm_flops


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Stacks `depth` blocks that alternate between non-shifted (even index) and
    shifted (odd index, shift of window_size // 2) window attention, optionally
    followed by a downsampling layer.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate; a
            sequence provides one rate per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        pretrained_window_size (int): Local window size in pre-training.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 pretrained_window_size=0):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # A list or a tuple carries one stochastic depth rate per block; a
        # scalar is shared by every block.
        per_block_drop_path = isinstance(drop_path, (list, tuple))

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if per_block_drop_path else drop_path,
                                 norm_layer=norm_layer,
                                 pretrained_window_size=pretrained_window_size)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        """Run all blocks (checkpointed when `use_checkpoint`), then the optional downsample."""
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        """Return the summed FLOP estimate of the blocks and the downsample layer."""
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops

    def _init_respostnorm(self):
        """Zero the weight and bias of both norm layers of every block."""
        for blk in self.blocks:
            nn.init.constant_(blk.norm1.bias, 0)
            nn.init.constant_(blk.norm1.weight, 0)
            nn.init.constant_(blk.norm2.bias, 0)
            nn.init.constant_(blk.norm2.weight, 0)


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Splits the image into non-overlapping patches with a strided convolution
    and returns one embedding per patch.

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_high = img_size[0] // patch_size[0]
        patches_wide = img_size[1] // patch_size[1]

        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = [patches_high, patches_wide]
        self.num_patches = patches_high * patches_wide

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # Kernel size == stride, so every patch is projected independently.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Map images (B, C, H, W) to patch tokens (B, Ph*Pw, embed_dim)."""
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        tokens = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is None:
            return tokens
        return self.norm(tokens)

    def flops(self):
        """Return the FLOP estimate of the projection (plus the norm, when present)."""
        Ho, Wo = self.patches_resolution
        patch_area = self.patch_size[0] * self.patch_size[1]
        total = Ho * Wo * self.embed_dim * self.in_chans * patch_area
        if self.norm is not None:
            total += Ho * Wo * self.embed_dim
        return total


class SwinTransformerV2(nn.Module):
    r""" Hierarchical Swin-style vision transformer with an optional linear head.

        The input is split into patch tokens, optionally given an absolute position
        embedding, passed through ``len(depths)`` ``BasicLayer`` stages (every stage
        except the last is built with ``PatchMerging`` as its downsample), normalised,
        average-pooled over the tokens and finally fed to ``head``.

        Reference kept from the original docstring:
        `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030
        NOTE(review): that is the first Swin paper, while the class name and the
        ``pretrained_window_sizes`` argument point at the V2 variant - confirm the citation.
    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0], **kwargs):
        super().__init__()

        n_stages = len(depths)

        self.num_classes = num_classes
        self.num_layers = n_stages
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        # The channel width doubles at every stage after the first.
        self.num_features = int(embed_dim * 2 ** (n_stages - 1))
        self.mlp_ratio = mlp_ratio

        # Patch tokenizer (non-overlapping patches).
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if patch_norm else None)
        token_grid = self.patch_embed.patches_resolution
        self.patches_resolution = token_grid

        # Optional learned absolute position embedding, one vector per patch token.
        if ape:
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, self.patch_embed.num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: one rate per block, rising linearly
        # from 0 to drop_path_rate over all blocks of all stages.
        drop_path_schedule = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        # Build the stages; `blocks_so_far` indexes into the drop-path schedule.
        self.layers = nn.ModuleList()
        blocks_so_far = 0
        for stage_idx, stage_depth in enumerate(depths):
            shrink = 2 ** stage_idx
            is_last_stage = stage_idx == n_stages - 1
            stage = BasicLayer(dim=int(embed_dim * shrink),
                               input_resolution=(token_grid[0] // shrink,
                                                 token_grid[1] // shrink),
                               depth=stage_depth,
                               num_heads=num_heads[stage_idx],
                               window_size=window_size,
                               mlp_ratio=mlp_ratio,
                               qkv_bias=qkv_bias,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=drop_path_schedule[blocks_so_far:blocks_so_far + stage_depth],
                               norm_layer=norm_layer,
                               downsample=None if is_last_stage else PatchMerging,
                               use_checkpoint=use_checkpoint,
                               pretrained_window_size=pretrained_window_sizes[stage_idx])
            self.layers.append(stage)
            blocks_so_far += stage_depth

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        if num_classes > 0:
            self.head = nn.Linear(self.num_features, num_classes)
        else:
            self.head = nn.Identity()

        # Generic init first, then each stage's own `_init_respostnorm` on top of it.
        self.apply(self._init_weights)
        for stage in self.layers:
            stage._init_respostnorm()

    def _init_weights(self, m):
        """Initialise one submodule: truncated-normal Linear weights, zero biases, unit LayerNorm scale."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names this model reports under `no_weight_decay`."""
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return the name keywords this model reports under `no_weight_decay_keywords`."""
        return {"cpb_mlp", "logit_scale", 'relative_position_bias_table'}

    def forward_features(self, x):
        """Return one pooled feature vector per image (B, num_features)."""
        tokens = self.patch_embed(x)
        if self.ape:
            tokens = tokens + self.absolute_pos_embed
        tokens = self.pos_drop(tokens)

        for stage in self.layers:
            tokens = stage(tokens)

        tokens = self.norm(tokens)                       # B L C
        pooled = self.avgpool(tokens.transpose(1, 2))    # B C 1
        return torch.flatten(pooled, 1)

    def forward(self, x):
        """Run the backbone and apply `head` to the pooled features."""
        return self.head(self.forward_features(x))

    def flops(self):
        """Sum the operation counts of the patch embedding, every stage, and the two trailing terms."""
        total = self.patch_embed.flops()
        for stage in self.layers:
            total += stage.flops()
        total += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        total += self.num_features * self.num_classes
        return total
    
    
def create_cross_swin_base():
    """Build a SwinTransformerV2 for 224x224 RGB input with a single-logit head.

    Every constructor argument is spelled out; relative to the class defaults
    only ``num_classes`` (1) and ``ape`` (True) differ.

    NOTE(review): the training cell in this notebook calls ``model(query, reference)``,
    whereas ``SwinTransformerV2.forward`` accepts one tensor - confirm how this
    model is meant to consume an image pair.
    """
    config = {
        # input / tokenizer
        'img_size': 224,
        'patch_size': 4,
        'in_chans': 3,
        # head: one output logit
        'num_classes': 1,
        # backbone layout
        'embed_dim': 96,
        'depths': [2, 2, 6, 2],
        'num_heads': [3, 6, 12, 24],
        'window_size': 7,
        'mlp_ratio': 4.,
        'qkv_bias': True,
        # regularisation
        'drop_rate': 0.,
        'attn_drop_rate': 0.,
        'drop_path_rate': 0.1,
        # normalisation / position embedding
        'norm_layer': nn.LayerNorm,
        'ape': True,
        'patch_norm': True,
        # misc
        'use_checkpoint': False,
        'pretrained_window_sizes': [0, 0, 0, 0],
    }
    return SwinTransformerV2(**config)

In [29]:
# Build the model with the factory defined above and move it to `device`.
# NOTE(review): the saved output under this cell shows a `Matchformer` module,
# not a SwinTransformerV2, so it looks stale - re-run the cell to refresh it.
model = create_cross_swin_base().to(device)
# Bare expression: the notebook renders the module tree as the cell output.
model

Matchformer(
  (backbone): Matchformer_SEA_lite(
    (AttentionBlock1): AttentionBlock(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 128, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
        (pos): Positional(
          (pa_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)
          (sigmoid): Sigmoid()
        )
        (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (block): ModuleList(
        (0): Block(
          (norm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (q): Linear(in_features=128, out_features=128, bias=True)
            (kv): Linear(in_features=128, out_features=256, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=128, out_features=128, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (sr): Conv2d(128, 128, kernel_size=(8, 8), stride=(8, 8))
       

In [30]:
# Optional warm start (disabled): restore model weights from a saved checkpoint.
#ckpt = torch.load('/kaggle/input/ckptttt/net (3).pt')
#model.load_state_dict(ckpt['model_state_dict'], strict=False)

# AdamW over all model parameters with a fixed learning rate of 5e-5.
optimizer = torch.optim.AdamW(model.parameters(), lr = 5e-5)
# Optional (disabled): restore the optimizer state from the same checkpoint.
#optimizer.load_state_dict(ckpt['optimizer_state_dict'], )

# Alternative optimizer kept for reference (disabled).
#opt = torch.optim.SGD(model.parameters(), lr = .05)
# Binary cross-entropy computed on raw logits; the training loop feeds it
# the model output directly, without applying a sigmoid first.
loss_fn = torch.nn.BCEWithLogitsLoss()

In [31]:
# Training loop: every (anchor, pos, neg) batch is turned into two groups of
# image pairs - (anchor, pos) labelled 1 and (anchor, neg) labelled 0 - and the
# model is trained with `loss_fn` on the resulting logits.
epochs = 1000

for epoch_i in range(0, epochs):

    # Plain Python floats: adding the loss *tensor* to the running total would
    # keep a device tensor and its autograd bookkeeping alive for the whole epoch.
    epoch_loss, epoch_acc = 0.0, 0.0

    model.train()

    for anchor, pos, neg in tqdm(train_train_dataloader):

        optimizer.zero_grad(set_to_none=True)

        anchor = anchor.to(device, non_blocking=True, dtype=input_dtype)
        pos = pos.to(device, non_blocking=True, dtype=input_dtype)
        neg = neg.to(device, non_blocking=True, dtype=input_dtype)

        # First half of the batch: positive pairs; second half: negative pairs.
        query = torch.cat([anchor, anchor], dim=0)
        reference = torch.cat([pos, neg], dim=0)
        # Labels are created directly on the target device (no CPU round trip).
        labels = torch.cat([torch.ones(pos.shape[0], 1, device=device),
                            torch.zeros(neg.shape[0], 1, device=device)], dim=0)

        # NOTE(review): this passes two inputs, while SwinTransformerV2.forward
        # defined in this notebook takes one; the saved traceback shows the cell
        # was last run against a Matchformer model - confirm the intended model.
        logits = model(query, reference)
        loss = loss_fn(logits, labels)

        loss.backward()
        optimizer.step()

        # Convert once to a Python float for both logging and accumulation.
        batch_loss = loss.item()
        run['running/loss'].log(batch_loss)

        # accuracy: sigmoid -> round gives hard 0/1 predictions
        preds = torch.sigmoid(logits).round().detach().cpu().numpy()
        acc = accuracy_score(labels.cpu().numpy(), preds)
        run['running/acc'].log(acc)

        epoch_loss += batch_loss
        epoch_acc += acc

    n_batches = len(train_train_dataloader)
    run['epoch/train/loss'].log(epoch_loss / n_batches)
    run['epoch/train/acc'].log(epoch_acc / n_batches)

    # Every 20th epoch (including epoch 0): score on validation and checkpoint.
    if epoch_i % 20 == 0:
        # Named `valid_map` so the builtin `map` is not shadowed in the kernel.
        valid_map = map_score(valid_eval_dataloader, model)
        run['epoch5/valid/map'].log(valid_map)

        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, '/kaggle/working/artifacts/net.pt')



  0%|          | 0/12 [00:00<?, ?it/s]


torch.Size([16, 128, 56, 56]) torch.Size([16, 192, 28, 28])
Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "c:\repos\belugas\env\lib\site-packages\IPython\core\interactiveshell.py", line 3397, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\vcall\AppData\Local\Temp\ipykernel_22024\1060586997.py", line 21, in <cell line: 3>
    logits = model(query, reference)
  File "c:\repos\belugas\env\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "c:\repos\belugas\notebooks\..\src\model\MT\matchformer.py", line 320, in forward
    x = torch.cat([feats_q, feats_r], dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 28 but got size 56 for tensor number 1 in the list.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "c:\repos\belugas\env\lib\site-packages\IPython\core\interactiveshell.py", line 1992, in showtraceback
    stb = self.InteractiveTB.structured_trace

In [None]:
# Manual checkpoint of the current model and optimizer state. The file name
# embeds the current value of `epoch_i` plus a hand-written "83acc" tag.
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, f'../artifacts/net_{epoch_i}_83acc.pt')


In [None]:
# Optimal batch size for inference
#
# Author's notes on the target machine (units are not stated in the original):
#   5 vCPUs
#   52 RAM
#   12 GPU RAM
#
# Method: with the model in eval mode, push dummy batches of growing size
# through it (watch the tqdm rate for each size) until the GPU runs out of
# memory, then report the largest batch size that completed its probe.
model.eval()

TEST_INFER_BS_INIT = 5    # starting value; the first size tried is INIT + ML
TEST_INFER_BS_ML = 10     # increment added before every probe
NUM_WORKERS = 0
PROBE_STEPS = 50          # forward passes run for each batch size

class DummyTest(torch.utils.data.Dataset):
  """Very long dataset of all-zero (query, reference) image pairs."""

  def __init__(self):
    super().__init__()

  def __len__(self):
    return 7000000

  def __getitem__(self, idx):
    return torch.zeros((3,224,224)), torch.zeros((3,224,224))

BS = TEST_INFER_BS_INIT
largest_ok_bs = None  # last batch size whose probe finished without running out of memory
while True:
  i = 0
  BS = BS + TEST_INFER_BS_ML
  dataloader = torch.utils.data.DataLoader(
                          DummyTest(),
                          batch_size=BS,
                          shuffle=False,
                          num_workers=NUM_WORKERS,
                          pin_memory=True
                      )

  out_of_memory = False
  try:
    # Only PROBE_STEPS batches are consumed, so that is the bar's total.
    for batch in tqdm(dataloader, total=PROBE_STEPS):
      query = batch[0].to(device, non_blocking=True, dtype=torch.float32)
      reference = batch[1].to(device, non_blocking=True, dtype=torch.float32)
      with torch.no_grad():
        logits, attn, q_cls, r_cls = model(query=query, reference=reference)
        i+=1

      if i == PROBE_STEPS:
        print(BS)
        print(torch.cuda.mem_get_info(device=0))
        break
  except RuntimeError as err:
    # CUDA out-of-memory surfaces as a RuntimeError; anything else is a real bug.
    if "out of memory" not in str(err):
      raise
    out_of_memory = True

  if out_of_memory:
    # Handled outside the `except` block so the failed step's traceback
    # no longer pins its tensors when the cache is released.
    torch.cuda.empty_cache()
    print(f"CUDA out of memory at batch size {BS}; largest batch size that completed: {largest_ok_bs}")
    break

  largest_ok_bs = BS

In [6]:
# Profile a mixed-precision training step on a dummy batch of two images.
# NOTE: this cell rebinds the notebook-wide `input_dtype`, `model`, `loss_fn`
# and `optimizer`; cells run afterwards will see these values.
fp16 = True
input_dtype = torch.float16 if fp16 else torch.float32

# Loss scaling for fp16 gradients; a pass-through when `fp16` is False.
scaler = torch.cuda.amp.GradScaler(enabled=fp16)

model = crossvit_base_224().to(device)
input = torch.zeros((2,3,224,224), dtype=input_dtype, device=device)
loss_fn = torch.nn.BCEWithLogitsLoss()
labels = torch.ones((2,1), dtype=input_dtype, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr = 5e-4)

# Untimed iterations first, then the iterations recorded by the profiler.
warmup, reps = 30, 10


def _amp_train_step():
    """Run one optimizer step on the dummy batch; return the model outputs and the loss."""
    # Clear gradients from the previous step; without this every backward
    # pass would add onto the gradients of all earlier iterations.
    optimizer.zero_grad(set_to_none=True)

    with torch.cuda.amp.autocast(enabled = fp16):
        logits, attn, q_cls, r_cls = model(query=input, reference=input)
        loss = loss_fn(logits, labels)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return logits, attn, q_cls, r_cls, loss


for _ in range(0, warmup):
    _amp_train_step()

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
) as prof:
    print(torch.cuda.mem_get_info(0))
    for _ in range(0, reps):
        # Outputs stay bound at notebook level; later cells inspect `attn`.
        logits, attn, q_cls, r_cls, loss = _amp_train_step()
        

(6465060864, 8589737984)


In [7]:
# Show the 25 profiler rows with the highest self CPU time.
# NOTE(review): `group_by_stack_n=5` groups by recorded call stacks, but the
# `profile(...)` call in this notebook does not pass `with_stack=True` -
# confirm whether stack grouping is actually wanted here.
print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=25))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          aten::reshape         5.00%     566.319ms        11.90%        1.346s     153.356us     544.024ms         4.84%        1.311s     149.324us           0 b           0 b     542.46 Mb           0 

In [16]:
# Author's description of the nesting of `attn`, outermost level first:
# layers x bs x (q,ref) x cross_attn_depth x n_heads x 1 x tokens(inc cls)
# Indexing three levels deep and reading `.shape` shows the remaining tensor dimensions.
attn[0][0][0].shape

torch.Size([2, 12, 1, 197])

In [20]:
# Length of each of the three outer nesting levels of `attn`,
# followed by the shape of the tensor found at the innermost level.
print(len(attn))
print(len(attn[0]))
print(len(attn[0][0]))
attn[0][0][0].shape

3
2
2


torch.Size([2, 12, 1, 197])