# 1. Dataset

In [None]:
import torch
from torch.utils.data import Dataset
import numpy as np
import cv2
import os
from torchvision import transforms
import pandas as pd
from torch import nn
from torchvision.transforms import transforms

import random
from torch import randint


def seed_everything(seed: int):
    """Seed every random number generator used in this file.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device) and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed. ``None`` is accepted and treated as "do not seed":
            the function returns without touching any RNG state or cuDNN flag.
    """
    if seed is None:
        # torch.manual_seed(None) raises TypeError; an absent seed means
        # "leave the generators alone".
        return

    # The dataset draws class indices with ``random.randint`` so the stdlib
    # generator must be seeded as well for reproducible sampling.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune (and possibly vary) the convolution
    # algorithm between runs, which contradicts the deterministic flag above.
    torch.backends.cudnn.benchmark = False


class GaussianBlur(object):
    """Gaussian-blur a single 3-channel PIL image on the CPU.

    The blur is separable: a reflection pad followed by two depthwise
    convolutions with kernels of shape ``(k, 1)`` and ``(1, k)``. A new sigma
    is drawn uniformly from ``[0.1, 2.0)`` on every call.
    """

    def __init__(self, kernel_size):
        radius = kernel_size // 2
        # Round to an odd width so the kernel is symmetric around its centre.
        odd_size = 2 * radius + 1

        # groups=3 keeps the colour channels independent; the weights created
        # here are placeholders that __call__ overwrites every time.
        self.blur_h = nn.Conv2d(3, 3, kernel_size=(odd_size, 1),
                                stride=1, padding=0, bias=False, groups=3)
        self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, odd_size),
                                stride=1, padding=0, bias=False, groups=3)
        self.k = odd_size
        self.r = radius

        # Padding by the radius on every side keeps the output the same size
        # as the input after both unpadded convolutions.
        self.blur = nn.Sequential(nn.ReflectionPad2d(radius), self.blur_h, self.blur_v)

        self.pil_to_tensor = transforms.ToTensor()
        self.tensor_to_pil = transforms.ToPILImage()

    def __call__(self, img):
        """Return a blurred PIL copy of the PIL image ``img``."""
        batch = self.pil_to_tensor(img).unsqueeze(0)

        # Build a normalised 1-D Gaussian for a freshly sampled sigma.
        sigma = np.random.uniform(0.1, 2.0)
        offsets = np.arange(-self.r, self.r + 1)
        taps = np.exp(-np.power(offsets, 2) / (2 * sigma * sigma))
        taps = taps / taps.sum()
        kernel = torch.from_numpy(taps).view(1, -1).repeat(3, 1)

        # Load the same taps into both 1-D convolutions (one copy per channel).
        self.blur_h.weight.data.copy_(kernel.view(3, 1, self.k, 1))
        self.blur_v.weight.data.copy_(kernel.view(3, 1, 1, self.k))

        with torch.no_grad():
            blurred = self.blur(batch).squeeze()

        return self.tensor_to_pil(blurred)
    
class ContrastiveLearningViewGenerator(object):
    """Apply one transform several times to produce multiple views of an input.

    With a random ``base_transform`` each call yields ``n_views`` differently
    augmented versions of the same sample (two by default).
    """

    def __init__(self, base_transform, n_views=2):
        self.n_views = n_views
        self.base_transform = base_transform

    def __call__(self, x):
        """Return a list holding ``n_views`` transformed copies of ``x``."""
        views = []
        for _ in range(self.n_views):
            views.append(self.base_transform(x))
        return views
    
    
class MammoCompDataset(Dataset):
    """Randomly sampled pairs of mammogram images with a pairwise label.

    The metadata CSV is filtered to the requested ``phase`` (column ``split``)
    and grouped by the ``breast_birads`` column into five classes
    (``'BI-RADS 1'`` .. ``'BI-RADS 5'``, stored at indices 0..4). At
    construction time ``datalen`` pairs of image paths are drawn; how the two
    classes of a pair are chosen and how the pair is labelled depends on
    ``mode``:

    * ``'multiclass_contrastive'`` - classes 0..4; label 1 if both classes are
      equal, else 0.
    * ``'binary_contrastive'`` - class 0 versus classes 1..4; label 1 if both
      images fall on the same side of that split, else 0.
    * ``'severity_comparison'`` - classes 1..4; label 1 if the first class
      index is greater than the second, else 0. Items additionally carry a
      reference image drawn from class 0.
    * ``'preference_contrastive'`` - classes 1..4; label 2 for equal classes,
      otherwise 1 if the first class index is greater, else 0.

    Images are loaded lazily in ``__getitem__`` and each one is expanded into
    two SimCLR-augmented views.
    """

    # Sampling modes understood by the constructor.
    _MODES = ('multiclass_contrastive', 'binary_contrastive',
              'severity_comparison', 'preference_contrastive')

    @staticmethod
    def get_simclr_pipeline_transform(size, s=1):
        """Return a set of data augmentation transformations as described in the SimCLR paper."""
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        data_transforms = transforms.Compose([transforms.ToPILImage(),
                                              transforms.RandomResizedCrop(size=size),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.RandomApply([color_jitter], p=0.8),
                                              transforms.RandomGrayscale(p=0.2),
                                              GaussianBlur(kernel_size=int(0.1 * size)),
                                              transforms.ToTensor()])
        return data_transforms

    def __init__(self,
                 data_path="../VinDr_Mammo/physionet.org/files/vindr-mammo/1.0.0/images_png/",
                 metadata="../VinDr_Mammo/physionet.org/files/vindr-mammo/1.0.0/breast-level_annotations1.csv",
                 phase="train",
                 mode="binary_contrastive",
                 transform=None,
                 datalen=100,
                 certain=True,
                 seed=None):
        """Read the metadata and pre-sample ``datalen`` image pairs.

        Args:
            data_path: Root folder containing ``<study_id>/<image_id>.png``.
            metadata: CSV with at least the columns ``split``,
                ``breast_birads``, ``image_id`` and ``study_id``.
            phase: Value of the ``split`` column to keep.
            mode: One of the sampling modes listed in the class docstring.
            transform: Accepted for interface compatibility; not used.
            datalen: Number of image pairs to sample.
            certain: Stored on the instance; not otherwise used here.
            seed: Optional integer seed. When given (0 included) the RNGs are
                seeded right before sampling so the pairs are reproducible.

        Raises:
            ValueError: If ``mode`` is unknown or a required class has no
                images in the selected split.
        """
        self.phase = phase
        self.datalen = datalen  # Number of image pairs for training/testing
        self.certain = certain
        self.mode = mode
        self.data_path = data_path

        # NOTE(review): the ``transform`` argument is ignored; the SimCLR
        # two-view pipeline at 224px is always used. Confirm this is intended.
        self.transform = ContrastiveLearningViewGenerator(self.get_simclr_pipeline_transform(224), 2)
        data = pd.read_csv(metadata)
        self.data = data.loc[data['split'] == phase].reset_index()
        self.birads = [self.data.loc[self.data['breast_birads'] == f'BI-RADS {i}']
                       for i in range(1, 6)]
        self.name_of_classes = [1, 2, 3, 4, 5]
        self.len_of_classes = [len(group.index) for group in self.birads]
        self.paths1 = []
        self.paths2 = []
        self.listi1 = []
        self.listi2 = []
        self.complabels = []
        self.imagesinclass0 = self.birads[0]

        if seed is not None:
            seed_everything(seed)
            # Class indices come from the stdlib generator, so seed it here as
            # well to make the sampled pairs reproducible.
            random.seed(seed)

        for _ in range(self.datalen):
            i1, i2 = self._sample_classes(mode)
            self.listi1.append(i1)
            self.listi2.append(i2)
            self.paths1.append(self._random_path(i1))
            self.paths2.append(self._random_path(i2))
            self.complabels.append(self._pair_label(mode, i1, i2))

    @classmethod
    def _unknown_mode(cls, mode):
        """Build the exception raised for an unsupported ``mode``."""
        return ValueError(f"No mode {mode} found, please try one of: {', '.join(cls._MODES)}")

    @staticmethod
    def _sample_classes(mode):
        """Draw the pair of class indices (0-based) for one sample in ``mode``."""
        if mode == 'multiclass_contrastive':
            # Half of the pairs are forced to share a class.
            if random.randint(0, 1) == 0:
                return random.randint(0, 4), random.randint(0, 4)
            shared = random.randint(0, 4)
            return shared, shared
        if mode == 'binary_contrastive':
            case = random.randint(0, 3)
            if case == 0:
                return 0, 0
            if case == 1:
                return random.randint(1, 4), random.randint(1, 4)
            if case == 2:
                return 0, random.randint(1, 4)
            return random.randint(1, 4), 0
        if mode == 'severity_comparison':
            return random.randint(1, 4), random.randint(1, 4)
        if mode == 'preference_contrastive':
            # One draw in six forces a tie; the rest are independent.
            if random.randint(0, 5) > 4:
                shared = random.randint(1, 4)
                return shared, shared
            return random.randint(1, 4), random.randint(1, 4)
        # A real exception (not ``assert``) so the check survives ``python -O``.
        raise MammoCompDataset._unknown_mode(mode)

    @staticmethod
    def _pair_label(mode, i1, i2):
        """Return the integer label of the class pair ``(i1, i2)`` in ``mode``."""
        if mode == 'multiclass_contrastive':
            return int(i1 == i2)
        if mode == 'binary_contrastive':
            return int((i1 == 0) == (i2 == 0))
        if mode == 'severity_comparison':
            return int(i1 > i2)
        if mode == 'preference_contrastive':
            return 2 if i1 == i2 else int(i1 > i2)
        raise MammoCompDataset._unknown_mode(mode)

    def _random_path(self, class_index):
        """Return the path of a uniformly drawn image of class ``class_index``."""
        count = self.len_of_classes[class_index]
        if count == 0:
            raise ValueError(
                f"No images labelled BI-RADS {class_index + 1} in split '{self.phase}'; "
                f"cannot sample pairs for mode '{self.mode}'")
        return self.get_path(self.birads[class_index], randint(0, count, (1,))[0])

    @staticmethod
    def _read_image(path):
        """Load ``path`` with OpenCV, failing loudly instead of returning None."""
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return image

    @staticmethod
    def _row(index):
        """Convert a 0-d tensor (or plain int) into a Python row position."""
        return index.item() if hasattr(index, 'item') else index

    def get_score(self, data, index):
        """Return the BI-RADS score (last character of ``breast_birads``) as int."""
        birads = data['breast_birads'].iloc[self._row(index)]
        # int() instead of eval(): same result for a digit, no code execution.
        return int(birads[-1])

    def get_path(self, data, index):
        """Return the PNG path of row ``index`` of ``data`` under ``data_path``."""
        row = self._row(index)
        image_name = data['image_id'].iloc[row]
        study_id = data['study_id'].iloc[row]
        return os.path.join(self.data_path, study_id, image_name + '.png')

    def __getitem__(self, index):
        """Load pair ``index``.

        Returns:
            ``((viewsA, viewsB), label, (class1, class2))`` where each
            ``views*`` is the list of two augmented views of one image. In
            ``'severity_comparison'`` mode a reference image tensor is inserted
            as second element: ``((viewsA, viewsB), ref, label, (class1, class2))``.

        Raises:
            FileNotFoundError: If an image file cannot be read.
        """
        imageA = self.transform(self._read_image(self.paths1[index]))
        imageB = self.transform(self._read_image(self.paths2[index]))
        label = self.complabels[index]
        classes = (self.listi1[index], self.listi2[index])

        if self.mode == 'severity_comparison':
            ref_img = self.get_ref_images()
            return (imageA, imageB), ref_img, label, classes
        return (imageA, imageB), label, classes

    def get_ref_images(self):
        """Return a randomly drawn class-0 image as an un-augmented tensor."""
        ref_path = self.get_path(self.imagesinclass0, randint(0, len(self.imagesinclass0), (1,))[0])
        ref_img = self._read_image(ref_path)
        # No augmentation for the reference: only convert to a tensor.
        return transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])(ref_img)

    def __len__(self):
        """Return the number of pre-sampled pairs."""
        return self.datalen
    

# 2. Model

## 2.1 - ViT Encoder

In [None]:
# --------------------------------------------------------
# SimMIM
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on BEIT code bases (https://github.com/microsoft/unilm/tree/master/beit)
# Written by Yutong Lin, Zhenda Xie
# --------------------------------------------------------

import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are not given.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Project ``x`` through both linear layers.

        Dropout is applied once, after the second projection; there is
        deliberately none between the activation and ``fc2``.
        """
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention with an optional learned relative position bias.

    Queries and values can carry a learnable bias (``qkv_bias=True``); keys
    always use a zero bias. When ``window_size`` is given, the module owns a
    bias table indexed by the relative offset between patch positions, with
    three extra entries for the class-token interactions.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            # Explicit per-head width overrides the dim / heads default.
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # The projection itself is bias-free; biases are assembled in forward().
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            win_h, win_w = window_size[0], window_size[1]
            self.window_size = window_size
            # One entry per (dh, dw) offset, plus three for
            # cls->token, token->cls and cls->cls.
            self.num_relative_distance = (2 * win_h - 1) * (2 * win_w - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))

            # Pair-wise (dh, dw) offsets between all positions of the window.
            grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
            flat = torch.flatten(grid, 1)  # 2, Wh*Ww
            offsets = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
            offsets = offsets.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            # Shift both offsets to start at 0, then fold them into one index.
            offsets[:, :, 0] += win_h - 1
            offsets[:, :, 1] += win_w - 1
            offsets[:, :, 0] *= 2 * win_w - 1

            # Row/column 0 belong to the class token and use the extra entries.
            index = torch.zeros(size=(win_h * win_w + 1,) * 2, dtype=offsets.dtype)
            index[1:, 1:] = offsets.sum(-1)
            index[0, 0:] = self.num_relative_distance - 3
            index[0:, 0] = self.num_relative_distance - 2
            index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        """Attend over the token axis of ``x`` (batch, tokens, dim).

        ``rel_pos_bias``, when given, is added to the attention logits in
        addition to the module's own relative position bias (if any).
        """
        batch, tokens, _ = x.shape

        if self.q_bias is None:
            qkv_bias = None
        else:
            # Keys get a fixed zero bias between the query and value biases.
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(batch, tokens, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # indexed separately to stay torchscript-friendly

        scores = (q * self.scale) @ k.transpose(-2, -1)

        if self.relative_position_bias_table is not None:
            side = self.window_size[0] * self.window_size[1] + 1
            own_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                side, side, -1)  # tokens, tokens, nH
            own_bias = own_bias.permute(2, 0, 1).contiguous()  # nH, tokens, tokens
            scores = scores + own_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            scores = scores + rel_pos_bias

        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, -1)
        return self.proj_drop(self.proj(merged))


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each in a residual branch.

    When ``init_values`` is given, each branch output is multiplied by a
    learnable per-channel scale (``gamma_1`` / ``gamma_2``) initialised to
    that value. ``drop_path`` > 0 enables stochastic depth on both branches.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        if init_values is None:
            self.gamma_1 = None
            self.gamma_2 = None
        else:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x, rel_pos_bias=None):
        """Apply the attention branch, then the MLP branch, to ``x``."""
        # gamma_1 and gamma_2 are always set together, so one check covers both.
        scaled = self.gamma_1 is not None

        attn_out = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
        if scaled:
            attn_out = self.gamma_1 * attn_out
        x = x + self.drop_path(attn_out)

        mlp_out = self.mlp(self.norm2(x))
        if scaled:
            mlp_out = self.gamma_2 * mlp_out
        x = x + self.drop_path(mlp_out)
        return x


class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A convolution whose kernel and stride both equal the patch size maps
    every patch to an ``embed_dim``-dimensional token.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        rows = img_size[0] // patch_size[0]
        cols = img_size[1] // patch_size[1]
        self.patch_shape = (rows, cols)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = cols * rows

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        """Return patch tokens of shape (batch, num_patches, embed_dim).

        The input's spatial size must equal the configured ``img_size``.
        """
        _, _, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        feature_map = self.proj(x)
        # Flatten the patch grid into a sequence and move channels last.
        return feature_map.flatten(2).transpose(1, 2)


class RelativePositionBias(nn.Module):
    """Learned relative position bias for a patch grid plus one class token.

    The table has one row per relative (dh, dw) offset inside the window and
    three extra rows for cls->token, token->cls and cls->cls. ``forward``
    gathers it into a per-head bias matrix over all token pairs.
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        win_h, win_w = window_size[0], window_size[1]
        self.window_size = window_size
        self.num_relative_distance = (2 * win_h - 1) * (2 * win_w - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))

        # Pair-wise (dh, dw) offsets between all positions of the window.
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        offsets = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        offsets = offsets.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        # Make both offsets non-negative, then fold them into a single index.
        offsets[:, :, 0] += win_h - 1
        offsets[:, :, 1] += win_w - 1
        offsets[:, :, 0] *= 2 * win_w - 1

        # Row 0 / column 0 address the class token and use the extra rows.
        index = torch.zeros(size=(win_h * win_w + 1,) * 2, dtype=offsets.dtype)
        index[1:, 1:] = offsets.sum(-1)
        index[0, 0:] = self.num_relative_distance - 3
        index[0:, 0] = self.num_relative_distance - 2
        index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", index)

    def forward(self):
        """Return the bias as a (num_heads, tokens, tokens) tensor."""
        side = self.window_size[0] * self.window_size[1] + 1
        gathered = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = gathered.view(side, side, -1)  # tokens, tokens, nH
        return bias.permute(2, 0, 1).contiguous()  # nH, tokens, tokens


class VisionTransformer(nn.Module):
    """Vision Transformer encoder that returns one feature vector per image.

    Images are split into patches, a class token is prepended, optional
    absolute position embeddings are added, and the sequence is run through
    ``depth`` transformer blocks. With ``use_mean_pooling`` the output is the
    layer-normalised mean of the patch tokens; otherwise it is the normalised
    class token.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
                 use_mean_pooling=True, init_scale=0.001):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_features = embed_dim
        self.patch_size = patch_size
        self.in_chans = in_chans

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        patch_grid = self.patch_embed.patch_shape

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            # One position per patch plus one for the class token.
            self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        # A single bias module shared by every block (as opposed to the
        # per-block bias enabled through ``use_rel_pos_bias``).
        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=patch_grid, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        # Stochastic depth: drop-path rate grows linearly with block index.
        drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, depth)]
        self.use_rel_pos_bias = use_rel_pos_bias
        block_window = patch_grid if use_rel_pos_bias else None
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rates[layer], norm_layer=norm_layer,
                init_values=init_values, window_size=block_window)
            for layer in range(depth)])
        if use_mean_pooling:
            # Normalisation happens after pooling instead of per token.
            self.norm = nn.Identity()
            self.fc_norm = norm_layer(embed_dim)
        else:
            self.norm = norm_layer(embed_dim)
            self.fc_norm = None

        if self.pos_embed is not None:
            self._trunc_normal_(self.pos_embed, std=.02)
        self._trunc_normal_(self.cls_token, std=.02)

        self.apply(self._init_weights)
        self.fix_init_weight()

    def _trunc_normal_(self, tensor, mean=0., std=1.):
        """Fill ``tensor`` in place from a truncated normal distribution."""
        trunc_normal_(tensor, mean=mean, std=std)

    def fix_init_weight(self):
        """Divide each block's output projections by sqrt(2 * layer number).

        Layer numbers are 1-based, so deeper blocks start with smaller
        residual-branch weights.
        """
        for layer_number, block in enumerate(self.blocks, start=1):
            divisor = math.sqrt(2.0 * layer_number)
            block.attn.proj.weight.data.div_(divisor)
            block.mlp.fc2.weight.data.div_(divisor)

    def _init_weights(self, m):
        """Initialise one submodule (used with ``self.apply``)."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            # Linear and Conv2d share the same scheme: small truncated-normal
            # weights and a zero bias when a bias exists.
            self._trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):
        """Encode a batch of images into one feature vector per image."""
        tokens = self.patch_embed(x)
        batch_size, seq_len, _ = tokens.size()

        # Broadcast the single learned class token across the batch.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        if self.pos_embed is not None:
            tokens = tokens + self.pos_embed
        tokens = self.pos_drop(tokens)

        shared_bias = None if self.rel_pos_bias is None else self.rel_pos_bias()
        for block in self.blocks:
            tokens = block(tokens, rel_pos_bias=shared_bias)

        tokens = self.norm(tokens)
        if self.fc_norm is None:
            # Class-token readout.
            return tokens[:, 0]
        # Mean-pool the patch tokens (class token excluded), then normalise.
        patch_tokens = tokens[:, 1:, :]
        return self.fc_norm(patch_tokens.mean(1))

    def forward(self, x):
        """Alias of :meth:`forward_features`."""
        return self.forward_features(x)


def build_vit(config):
    """Instantiate a VisionTransformer from a nested config object.

    Reads ``config.DATA.IMG_SIZE``, ``config.MODEL.DROP_RATE``,
    ``config.MODEL.DROP_PATH_RATE`` and the fields under
    ``config.MODEL.VIT``. LayerNorm always uses ``eps=1e-6``.
    """
    model_cfg = config.MODEL
    vit_cfg = model_cfg.VIT

    vit_kwargs = dict(
        img_size=config.DATA.IMG_SIZE,
        patch_size=vit_cfg.PATCH_SIZE,
        in_chans=vit_cfg.IN_CHANS,
        embed_dim=vit_cfg.EMBED_DIM,
        depth=vit_cfg.DEPTH,
        num_heads=vit_cfg.NUM_HEADS,
        mlp_ratio=vit_cfg.MLP_RATIO,
        qkv_bias=vit_cfg.QKV_BIAS,
        drop_rate=model_cfg.DROP_RATE,
        drop_path_rate=model_cfg.DROP_PATH_RATE,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_values=vit_cfg.INIT_VALUES,
        use_abs_pos_emb=vit_cfg.USE_APE,
        use_rel_pos_bias=vit_cfg.USE_RPB,
        use_shared_rel_pos_bias=vit_cfg.USE_SHARED_RPB,
        use_mean_pooling=vit_cfg.USE_MEAN_POOLING,
    )
    return VisionTransformer(**vit_kwargs)



## 2.2 - Model

In [None]:
import torch
import torch.nn as nn

class CombinedModel(nn.Module):
    """Backbone encoder followed by a three-layer MLP projection head.

    Args:
        base_model: Key of the backbone to use. Only ``"vit"`` is backed by
            an implementation; ``"swin"`` is listed but not available.
        config: Config object forwarded to ``build_vit``.
        out_dim: Width of the projection head output.

    Raises:
        ValueError: If ``base_model`` does not name an available backbone.
    """

    def __init__(self, base_model, config, out_dim):
        super(CombinedModel, self).__init__()
        self.model_dict = {"vit": build_vit(config),
                            "swin": None}

        self.backbone = self._get_basemodel(base_model)
        dim_mlp = self.backbone.embed_dim

        # add mlp projection head
        self.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp),
                                nn.ReLU(),
                                torch.nn.Dropout(0.1),
                                nn.Linear(dim_mlp, 512),
                                nn.ReLU(),
                                torch.nn.Dropout(0.1),
                                nn.Linear(512, out_dim)
                                )

    def _get_basemodel(self, model_name):
        """Return the backbone registered under ``model_name``.

        Raises:
            ValueError: If the name is unknown or its entry is ``None``
                (registered but not implemented). Failing here gives a clear
                message instead of a later ``AttributeError`` on ``None``.
        """
        model = self.model_dict.get(model_name)
        if model is None:
            available = sorted(name for name, entry in self.model_dict.items() if entry is not None)
            raise ValueError(
                f"Invalid backbone architecture {model_name!r}. "
                f"Check the config file and pass one of: {available}")
        return model

    def forward(self, x):
        """Encode ``x`` with the backbone and project it with the MLP head."""
        out = self.backbone(x)
        return self.fc(out)


In [None]:
class SimCLRModelPipeline(nn.Module):
    """Run one shared encoder over four augmented views and a reference image."""

    def __init__(self, encoder):
        super().__init__()
        # note that model requires 3 input channels, will repeat grayscale image x3
        self.encoder = encoder

    def forward_once(self, x):
        """Encode a single batch with the shared encoder."""
        return self.encoder(x)

    def forward(self, input1a, input1b, input2a, input2b, ref):
        """Encode all five inputs, in argument order, with the same weights.

        Returns:
            Tuple ``(output1a, output1b, output2a, output2b, ref_output)``.
        """
        batches = (input1a, input1b, input2a, input2b, ref)
        return tuple(self.forward_once(batch) for batch in batches)

# 3. Loss function

## 3.1 - NT_Xent

In [None]:
import torch
import torch.nn as nn


class NT_Xent(nn.Module):
    """
    The normalized temperature-scaled cross entropy loss

    Args:
        batch_size: Expected number of positive pairs per batch; the negative
            mask for this size is precomputed. Batches of another size are
            still accepted by ``forward``.
        temperature: Divisor applied to the cosine similarities.
        device: Stored for interface compatibility; ``forward`` places its
            tensors on the device of its inputs.
    """
    def __init__(self, batch_size, temperature, device):
        super(NT_Xent, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.mask = self.mask_correlated_samples(batch_size)
        self.device = device

        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size):
        """Return a (2N, 2N) bool mask that is True only at negative pairs.

        The main diagonal (a sample with itself) and the two diagonals at
        offset +/- N (a sample with its positive partner) are False.
        """
        total = 2 * batch_size
        mask = torch.ones((total, total), dtype=torch.bool)
        mask.fill_diagonal_(False)
        first_half = torch.arange(batch_size)
        mask[first_half, first_half + batch_size] = False
        mask[first_half + batch_size, first_half] = False
        return mask

    def forward(self, z_i, z_j):
        """
        We do not sample negative examples explicitly.
        Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1)
        augmented examples within a minibatch as negative examples.

        Args:
            z_i: Embeddings of the first views, shape (N, D).
            z_j: Embeddings of the second views, shape (N, D); row k is the
                positive partner of row k of ``z_i``.

        Returns:
            Scalar loss averaged over the 2N anchors.

        Raises:
            ValueError: If ``z_i`` and ``z_j`` hold a different number of rows.
        """
        # Use the actual batch size so a final, smaller batch does not break
        # the reshapes below.
        batch_size = z_i.shape[0]
        if z_j.shape[0] != batch_size:
            raise ValueError(
                f"z_i and z_j must have the same batch size, got {batch_size} and {z_j.shape[0]}")
        if batch_size == self.batch_size:
            mask = self.mask
        else:
            mask = self.mask_correlated_samples(batch_size)

        z = torch.cat((z_i, z_j), dim=0)

        # Broadcasting (2N, 1, D) against (1, 2N, D) with dim=2 yields the full
        # (2N, 2N) matrix of pair-wise cosine similarities.
        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature

        # For N = 2 and z = [i1, i2, j1, j2] the positives sit on the
        # diagonals at offset +N ((i1, j1), (i2, j2)) and -N ((j1, i1), (j2, i2)).
        sim_i_j = torch.diag(sim, batch_size)
        sim_j_i = torch.diag(sim, -batch_size)
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(batch_size * 2, 1)

        # The mask drops self-similarities and positives, leaving 2(N - 1)
        # negatives per anchor.
        negative_samples = sim[mask.to(sim.device)].reshape(batch_size * 2, -1)

        # The positive is placed in column 0, so the target class is always 0.
        labels = torch.zeros(batch_size * 2, dtype=torch.long, device=sim.device)
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)

        # Normalize the summed loss by the number of anchors, i.e. 1/2N.
        return loss / (2 * batch_size)

## 3.2 - ConPro Loss

In [None]:
from torch import nn

class PreferenceComparisonLoss(nn.Module):
    """
    Pairwise preference loss measured against a reference embedding.

    Each of the two embeddings is compared with the reference by cosine
    similarity. The difference ``sim(output1, ref) - sim(output2, ref)`` is
    used as the logit of a binary classifier whose target is ``label``, and
    the loss is the mean binary cross-entropy.

    NOTE(review): the original local names called the similarities
    "distances"; the sign convention of the code (label 1 <=> output1 is
    *more similar* to the reference than output2) is kept as written.
    Confirm it matches how the labels are generated.
    """

    def __init__(self, margin=2.0):
        super(PreferenceComparisonLoss, self).__init__()
        # Kept for interface compatibility; the loss below does not use it.
        self.margin = margin

    def forward(self, output1, output2, label, ref):
        """Return the scalar preference loss for a batch.

        Args:
            output1: Embeddings of the first images, shape (N, D).
            output2: Embeddings of the second images, shape (N, D).
            label: Tensor of N binary targets (0 or 1); any numeric dtype.
            ref: Reference embeddings, shape (N, D).
        """
        similarity_a = nn.functional.cosine_similarity(output1, ref)
        similarity_b = nn.functional.cosine_similarity(output2, ref)
        logits = similarity_a - similarity_b

        # NLLLoss expects (N, C) log-probabilities and class indices; applied
        # to a 1-D vector of sigmoid outputs it does not compute a binary
        # cross-entropy. BCE-with-logits is the numerically stable form of
        # BCE(sigmoid(logits), label).
        target = label.to(device=logits.device, dtype=logits.dtype)
        return nn.functional.binary_cross_entropy_with_logits(logits, target)

# 4. Train pipeline

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim import lr_scheduler
import os


def _batch_loss(model, simclr_criterion, conpro_criterion, inputs, ref, labels, device):
    """Forward one batch through ``model`` and return its combined, un-normalised loss.

    ``inputs`` is indexed as ``inputs[image][view]`` with two images and two
    views each; every tensor is moved to ``device`` before the forward pass.
    The result is the sum of five SimCLR terms minus the sum of four ConPro
    terms, evaluated in a fixed order so the train and validation loops
    compute exactly the same quantity.
    """
    input_aa = inputs[0][0].to(device)
    input_ab = inputs[0][1].to(device)
    input_ba = inputs[1][0].to(device)
    input_bb = inputs[1][1].to(device)
    ref = ref.to(device)
    labels = labels.to(device)

    out_1a, out_1b, out_2a, out_2b, ref_out = model(input_aa, input_ab, input_ba, input_bb, ref)

    # Every pairing of the four view embeddings except (2a, 2b).
    simclr_loss = (
        simclr_criterion(out_1a, out_1b)
        + simclr_criterion(out_1a, out_2a)
        + simclr_criterion(out_1a, out_2b)
        + simclr_criterion(out_1b, out_2a)
        + simclr_criterion(out_1b, out_2b)
    )
    # Every (image-1 view, image-2 view) pairing, compared against the reference.
    conpro_loss = (
        conpro_criterion(out_1a, out_2a, labels, ref_out)
        + conpro_criterion(out_1a, out_2b, labels, ref_out)
        + conpro_criterion(out_1b, out_2a, labels, ref_out)
        + conpro_criterion(out_1b, out_2b, labels, ref_out)
    )
    # NOTE(review): the ConPro terms are *subtracted* from the objective.
    # Confirm this sign is intended given what PreferenceComparisonLoss returns.
    return simclr_loss - conpro_loss


def train_model(model, train_dataset, val_dataset, checkpoint_folder, num_epochs=10, batch_size=32,
                learning_rate=0.001):
    """
    Train the model using the provided datasets.

    After every epoch the current weights are written to ``last.pt`` and,
    whenever the epoch's validation loss improves on the best seen so far,
    also to ``best.pt`` (both inside ``checkpoint_folder``).

    Args:
    - model: The model to be trained; called as
      ``model(inputAa, inputAb, inputBa, inputBb, ref)`` and expected to
      return five outputs
    - train_dataset: Dataset for training; each batch must unpack as
      ``(inputs, ref, labels, _)``
    - val_dataset: Dataset for validation (same batch layout)
    - checkpoint_folder: Folder to store checkpoints (created if missing)
    - num_epochs: Number of epochs for training
    - batch_size: Batch size for training and validation
    - learning_rate: Initial learning rate for SGD (halved every 10 epochs)

    Returns:
    - model: The model with the weights of the *final* epoch (not
      necessarily the best one; reload ``best.pt`` for that)
    - train_losses: List of per-epoch mean training losses (un-normalised)
    - val_losses: List of per-epoch mean validation losses (divided by 9)

    Raises:
    - ValueError: if either dataset is too small to yield one full batch
    """
    # exist_ok avoids the check-then-create race of an explicit exists() test.
    os.makedirs(checkpoint_folder, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # drop_last=True: the NT_Xent criterion is constructed for one fixed
    # batch size, so a smaller trailing batch must never reach it.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    # Fail fast instead of hitting a ZeroDivisionError when averaging the
    # epoch loss (for validation that would only surface after a full
    # training epoch).
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError(
            f"batch_size={batch_size} leaves no full batch: "
            f"{len(train_dataset)} training / {len(val_dataset)} validation samples"
        )

    # Define loss functions and optimizer
    SimCLR_criterion = NT_Xent(batch_size, 0.07, device)
    ConPro_criterion = PreferenceComparisonLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    # Per-epoch loss histories returned to the caller
    train_losses = []
    val_losses = []

    best_val_loss = float('inf')
    # 5 SimCLR + 4 ConPro terms; used to normalise the validation loss only.
    # NOTE(review): the training loss is not divided by this, so the two
    # reported curves are on different scales.
    num_loss_terms = 9

    model = model.to(device)
    print("Training started...")
    for epoch in range(num_epochs):
        torch.cuda.empty_cache()
        print("*" * 100)
        print(f"Epoch [{epoch + 1}/{num_epochs}]:")

        # ---- Training ----
        model.train()
        running_train_loss = 0.0
        for i, (inputs, ref, labels, _) in enumerate(train_loader):
            optimizer.zero_grad()
            loss = _batch_loss(model, SimCLR_criterion, ConPro_criterion, inputs, ref, labels, device)
            loss.backward()
            optimizer.step()
            running_train_loss += loss.item()

            if i % 200 == 0:
                print(f"\t Batch [{i}/{len(train_loader)}], Train Loss: {loss.item():.4f}")

        epoch_train_loss = running_train_loss / len(train_loader)
        train_losses.append(epoch_train_loss)

        # ---- Validation ----
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for i, (inputs, ref, labels, _) in enumerate(val_loader):
                loss = _batch_loss(model, SimCLR_criterion, ConPro_criterion, inputs, ref, labels, device)
                loss = loss / num_loss_terms
                running_val_loss += loss.item()

                if i % 100 == 0:
                    print(
                        f"Epoch [{epoch + 1}/{num_epochs}], Validation Batch [{i}/{len(val_loader)}], Val Loss: {loss.item():.4f}")

        epoch_val_loss = running_val_loss / len(val_loader)
        val_losses.append(epoch_val_loss)

        # Always overwrite the "last" checkpoint so training can be resumed.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': epoch_val_loss
        }, os.path.join(checkpoint_folder, 'last.pt'))

        # Keep a separate copy of the best-so-far weights (lowest validation loss).
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': best_val_loss
            }, os.path.join(checkpoint_folder, 'best.pt'))

        # Print progress
        print(f"Validation, Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")
        print("*" * 100)
        scheduler.step()
    print("Training completed.")

    return model, train_losses, val_losses


In [None]:
import datetime
# Timestamp taken once when this cell runs; it is interpolated into the
# checkpoint folder name below so every run writes to its own directory.
now = datetime.datetime.now()

# Run configuration read by the dataset/model construction and the
# train_model(...) call further down.
config = {
    "annotation_data_path": "/kaggle/input/mammo-224-224-ver2/split_data.csv",  # passed to MammoCompDataset as `metadata`
    "image_folder_path": "/kaggle/input/mammo-224-224-ver2/Processed_Images",  # passed to MammoCompDataset as `data_path`
    "model_encoder": "vit",  # encoder type identifier
    "data_length": 50000,  # `datalen` of the training set; the validation set uses a tenth of it
    "embedding_dim": 256,  # `out_dim` handed to CombinedModel
    "learning_rate":0.1,  # initial learning rate forwarded to train_model
    "num_epoch": 50,  # forwarded to train_model as `num_epochs`
    "batch_size": 16,  # forwarded to train_model
    "model_config": "/kaggle/input/vit-config/simmim_pretrain__vit_base__img224__800ep.yaml",  # YAML file parsed into the model config
    "checkpoint": "/kaggle/input/pretrained-vit-encoder-model/pytorch/vit-meta-research-pretrained/1/vit_base_image224_800ep.pt",  # weights loaded into encoder.backbone; a falsy value skips loading
    # str(datetime) renders as "YYYY-MM-DD HH:MM:SS.ffffff", so this folder
    # name contains a space and colons.
    "checkpoint_folder": f"/kaggle/working/ViTBasedModel_{now}"
}

class AttrDict(dict):
    """A dictionary whose keys can also be read as attributes (``cfg.key``).

    Values that are themselves ``dict`` instances are wrapped recursively at
    construction time, so nested mappings support chained attribute access
    (``cfg.a.b.c``). Dicts stored inside lists or other containers are left
    untouched. Only attribute *reads* are redirected: assigning ``cfg.x = 1``
    sets an ordinary instance attribute, not a key. Keys that collide with
    ``dict`` method names (e.g. ``"items"``) are reachable only via indexing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Iterate over a snapshot so that writing back into the dict never
        # happens while its live items() view is being traversed.
        for key, value in list(self.items()):
            if isinstance(value, dict):
                self[key] = AttrDict(value)

    def __getattr__(self, item):
        """Return ``self[item]``; raise ``AttributeError`` if the key is absent."""
        try:
            return self[item]
        except KeyError:
            # `from None` hides the internal KeyError so the traceback shows a
            # plain missing-attribute error, as for any other object.
            raise AttributeError(f"'AttrDict' object has no attribute '{item}'") from None


In [None]:
import yaml

# Training and validation datasets share the image folder and annotation
# file; they differ only in `phase` and in the number of samples requested.
train_dataset = MammoCompDataset(
    data_path=config["image_folder_path"],
    metadata=config["annotation_data_path"],
    phase="training",
    mode="severity_comparison",
    datalen=config["data_length"],
    seed=0,
)
valid_dataset = MammoCompDataset(
    data_path=config["image_folder_path"],
    metadata=config["annotation_data_path"],
    phase="valid",
    mode="severity_comparison",
    datalen=config["data_length"] // 10,
    seed=0,
)

# Parse the YAML model configuration and expose it with attribute access.
with open(config["model_config"], 'r') as file:
    data = yaml.safe_load(file)
model_config = AttrDict(data)

# Take the encoder type from the config instead of repeating the literal.
encoder = CombinedModel(config["model_encoder"], model_config, out_dim=config["embedding_dim"])

if config["checkpoint"]:
    # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only
    # machine; train_model moves the whole model to the active device later.
    # NOTE: torch.load unpickles the file -- only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(config["checkpoint"], map_location="cpu")
    print("Checkpoint: {}".format(config["checkpoint"]))
    encoder.backbone.load_state_dict(checkpoint)
SimCLR_model = SimCLRModelPipeline(encoder)

# Keep the returned loss histories instead of discarding them.
SimCLR_model, train_losses, val_losses = train_model(
    model=SimCLR_model,
    train_dataset=train_dataset,
    val_dataset=valid_dataset,
    num_epochs=config["num_epoch"],
    batch_size=config["batch_size"],
    learning_rate=config["learning_rate"],
    checkpoint_folder=config["checkpoint_folder"],
)