### Preparation

In [None]:
!pip install facenet_pytorch
!pip install einops

In [1]:
from google.colab import drive
drive.mount('/content/drive')
import os
root_dir = "/content/drive/MyDrive" # Set appropriate directory
os.chdir(root_dir)

Mounted at /content/drive


In [None]:
!unzip /content/drive/MyDrive/data.zip -d /content

In [5]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from glob import glob
from collections import defaultdict
import torchvision
from random import choice
from torch.utils.data import Dataset
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn.functional as F
from einops import rearrange, repeat
from torch.optim.lr_scheduler import ReduceLROnPlateau
from facenet_pytorch import InceptionResnetV1
# from torch.nn import Parameter

### Config

In [6]:
class Config:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_file_path = "./data/train-relationships/train_relationships.csv"
    train_folders_path = "/content/data/train/"
    val_famillies = "F09"
    test_relationship_file = "/content/data/submissions/sample_submission.csv"
    batch_size = 64
    number_of_epochs = 100

    learning_rate = 0.0001

    MIN_NUM_PATCHES = 16
    pretrained_vits = '/content/drive/MyDrive/face_transformer/Backbone_VITs_Epoch_2_Batch_12000_Time_2021-03-17-04-05_checkpoint.pth'
    pretrained_vit = '/content/drive/MyDrive/face_transformer/Backbone_VIT_Epoch_2_Batch_20000_Time_2021-01-12-16-48_checkpoint.pth'

### Dataset

In [7]:
class KinDataset(Dataset):
    def __init__(self, relations, person_to_images_map, transform1, transform2):
        self.relations = relations
        self.transform1 = transform1
        self.transform2 = transform2
        self.person_to_images_map = person_to_images_map
        self.ppl = list(person_to_images_map.keys())

    def __len__(self):
        return len(self.relations)*2

    def __getitem__(self, idx):

        if (idx%2==0): #Positive samples
            p1, p2 = self.relations[idx//2]
            label = 1
        else:          #Negative samples
            while True:
                p1 = choice(self.ppl)
                p2 = choice(self.ppl)
                if p1 != p2 and (p1, p2) not in self.relations and (p2, p1) not in self.relations:
                    break
            label = 0

        path1, path2 = choice(self.person_to_images_map[p1]), choice(self.person_to_images_map[p2])
        im1, im2 = Image.open(path1), Image.open(path2)

        img1, img2 = self.transform1(im1), self.transform1(im2)
        img3, img4 = self.transform2(im1), self.transform2(im2)

        return img1, img2, img3, img4, label

In [8]:
print("Prepare data...")
all_images = glob(Config.train_folders_path + "*/*/*.jpg")

train_images = [x for x in all_images if Config.val_famillies not in x]
val_images = [x for x in all_images if Config.val_famillies in x]

train_person_to_images_map = defaultdict(list)

ppl = [x.split("/")[-3] + "/" + x.split("/")[-2] for x in all_images]

for x in train_images:
    train_person_to_images_map[x.split("/")[-3] + "/" + x.split("/")[-2]].append(x)

val_person_to_images_map = defaultdict(list)

for x in val_images:
    val_person_to_images_map[x.split("/")[-3] + "/" + x.split("/")[-2]].append(x)

relationships = pd.read_csv(Config.train_file_path)
relationships = list(zip(relationships.p1.values, relationships.p2.values))
relationships = [x for x in relationships if x[0] in ppl and x[1] in ppl]

train_relations = [x for x in relationships if Config.val_famillies not in x[0]]
val_relations  = [x for x in relationships if Config.val_famillies in x[0]]

train_transform = transforms.Compose([
    transforms.Resize(130),
    transforms.CenterCrop(112),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0., 0., 0.],
                         std=[1/255., 1/255., 1/255.])
])
train_transform2 = transforms.Compose([
    transforms.Resize(160),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])
val_transform = transforms.Compose([
    transforms.Resize(130),
    transforms.CenterCrop(112),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0., 0., 0.],
                         std=[1/255., 1/255., 1/255.])
])
val_transform2 = transforms.Compose([
    transforms.Resize(160),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])
trainset = KinDataset(train_relations, train_person_to_images_map, train_transform, train_transform2)
valset = KinDataset(val_relations, val_person_to_images_map, val_transform, val_transform2)

trainloader = DataLoader(trainset, batch_size=Config.batch_size, shuffle=True)
valloader = DataLoader(valset, batch_size=Config.batch_size, shuffle=False)

Prepare data...


### Model

In [9]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = -torch.finfo(dots.dtype).max
        #embed()
        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, mask_value)
            del mask

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)

        return out

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            ]))
    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            x = attn(x, mask = mask)
            #embed()
            x = ff(x)
        return x

class ViT_face(nn.Module):
    def __init__(self, *, loss_type, GPU_ID, num_class, image_size, patch_size, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
        )
        self.loss_type = loss_type
        self.GPU_ID = GPU_ID
        if self.loss_type == 'None':
            print("no loss for vit_face")
        else:
            self.loss = CosFace(in_features=dim, out_features=num_class, device_id=self.GPU_ID)

    def forward(self, img, label= None , mask = None):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x, mask)

        last_hidden_state = x.detach()

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        emb = self.mlp_head(x)
        if label is not None:
            x = self.loss(emb, label)
            return x, emb
        else:
            return emb, last_hidden_state

class ViTs_face(nn.Module):
    def __init__(self, *, loss_type, GPU_ID, num_class, image_size, patch_size, ac_patch_size,
                         pad, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * ac_patch_size ** 2
        assert num_patches > Config.MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size = patch_size
        self.soft_split = nn.Unfold(kernel_size=(ac_patch_size, ac_patch_size), stride=(self.patch_size, self.patch_size), padding=(pad, pad))


        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
        )
        self.loss_type = loss_type
        self.GPU_ID = GPU_ID
        if self.loss_type == 'None':
            print("no loss for vit_face")

    def forward(self, img, label= None , mask = None, return_lhs=False):
        p = self.patch_size
        x = self.soft_split(img)
        x = x.transpose(1, 2)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x, mask)

        last_hidden_state = x.detach()
        # print('transformer_out', x.shape)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        emb = self.mlp_head(x)
        if label is not None:
            x = self.loss(emb, label)
            return x, emb
        elif return_lhs:
            return emb, last_hidden_state
        else:
            return emb

def build_encoder(name):
    if name == 'vits':
        model = ViTs_face(
            loss_type=None,
            GPU_ID=Config.device,
            num_class=93431,
            image_size=112,
            patch_size=8,
            ac_patch_size=12,
            pad=4,
            dim=512,
            depth=20,
            heads=8,
            mlp_dim=2048,
            dropout=0.1,
            emb_dropout=0.1
        )
        model.load_state_dict(torch.load(Config.pretrained_vits, map_location=Config.device), strict=False)
    else:
        model = ViT_face(
            image_size=112,
            patch_size=8,
            loss_type=None,
            GPU_ID=Config.device,
            num_class=93431,
            dim=512,
            depth=20,
            heads=8,
            mlp_dim=2048,
            dropout=0.1,
            emb_dropout=0.1
        )
        model.load_state_dict(torch.load(Config.pretrained_vit, map_location=Config.device), strict=False)
    return model

In [10]:
class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)

In [11]:
# resnet = InceptionResnetV1(pretrained='vggface2')
# cnn = torch.nn.Sequential(*(list(resnet.children())[:-5]))
# cnn(torch.randn(2, 3, 160, 160)).shape

In [12]:
class SiameseNet(nn.Module):
    def __init__(self, name):
        super().__init__()

        self.encoder = build_encoder(name)
        resnet = InceptionResnetV1(pretrained='vggface2')
        self.cnn = torch.nn.Sequential(*(list(resnet.children())[:-5]))

        self.attn_pool = AttentionPool2d(3, 1792, 1792//64, 512)

        self.embed_size=512
        self.last1 = nn.Sequential(
            nn.Linear(self.embed_size*2, 512),
            nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True),
            nn.Dropout(0.4),
            nn.ReLU(),
            nn.Linear(512,1)
        )
        self.last2 = nn.Sequential(
            nn.Linear(self.embed_size*2, 512),
            nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True),
            nn.Dropout(0.4),
            nn.ReLU(),
            nn.Linear(512,1)
        )

        self.weight = nn.Parameter(torch.randn(1))

        self.fc1 = nn.Sequential(
            nn.Linear(self.embed_size*3, self.embed_size*4),
            nn.ReLU(),
            nn.Dropout(0.01),
            nn.Linear(self.embed_size*4, self.embed_size*1),
            nn.ReLU(),
            nn.Dropout(0.01),
            nn.Linear(self.embed_size*1, self.embed_size//4),
            nn.ReLU(),
        )
        self.last1 = nn.Sequential(
            nn.Linear(self.embed_size//4+1,self.embed_size//16),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.embed_size//16, 1),
        )

        self.fc2 = nn.Sequential(
            nn.Linear(self.embed_size*3, self.embed_size*4),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(self.embed_size*4, self.embed_size*1),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(self.embed_size*1, self.embed_size//4),
            nn.ReLU(),
        )
        self.last2 = nn.Sequential(
            nn.Linear(self.embed_size//4+1,self.embed_size//16),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(self.embed_size//16, 1),
        )

        flag = False # train last MHA in vit
        for name, layer in self.encoder.named_parameters():
            if '19' in name:
                flag = True
            if flag:
                layer.requires_grad_(True)
            else:
                layer.requires_grad_(False)

    def forward(self, input1, input2, input3, input4):

        emb1 = self.encoder(input1)
        emb2 = self.encoder(input2)
        emb3 = self.attn_pool(self.cnn(input3))
        emb4 = self.attn_pool(self.cnn(input4))

        diff_v = emb1 - emb2
        ssq_v = torch.pow(emb1,2) + torch.pow(emb2,2)
        mul_v = emb1 * emb2
        x1 = torch.cat([diff_v, ssq_v, mul_v],dim=-1)
        x1 = self.fc1(x1)
        cos_dis1 = 1-F.cosine_similarity(emb1, emb2, dim=-1)
        x1 = torch.cat([x1, cos_dis1.unsqueeze(1)], dim=-1)
        res1 = self.last1(x1)

        diff_c = emb3 - emb4
        ssq_c = torch.pow(emb3,2) + torch.pow(emb4,2)
        mul_c = emb3 * emb4
        x2 = torch.cat([diff_c, ssq_c, mul_c],dim=-1)
        x2 = self.fc2(x2)
        cos_dis2 = 1-F.cosine_similarity(emb3, emb4, dim=-1)
        x2 = torch.cat([x2, cos_dis2.unsqueeze(1)], dim=-1)
        res2 = self.last2(x2)

        result = self.weight*res1 + (1-self.weight)*res2
        return torch.sigmoid(result)

### Train

In [None]:
def train(net, criterion, optimizer):
    net.train()
    train_loss = 0.0
    running_loss = 0.0
    running_corrects = 0

    for i, batch in enumerate(trainloader):
        optimizer.zero_grad()

        img1, img2, img3, img4, label = batch
        img1, img2, img3, img4, label = img1.to(Config.device), img2.to(Config.device), img3.to(Config.device), img4.to(Config.device), label.float().view(-1,1).to(Config.device)
        output = net(img1, img2, img3, img4)
        preds = output>0.5

        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        running_loss += loss.item()
        running_corrects += torch.sum(preds == (label>0.5))

    train_loss /= len(trainset)
    running_corrects = running_corrects.item()/len(trainset)
    print('[{}], \ttrain loss: {:.5}\tacc: {:.5}'.format(epoch+1, train_loss, running_corrects))
    return train_loss, running_corrects

In [None]:
def validate(net, criterion, optimizer):
    net.eval()
    val_loss = 0.0
    running_corrects = 0

    for batch in valloader:
        img1, img2, img3, img4, label = batch
        img1, img2, img3, img4, label = img1.to(Config.device), img2.to(Config.device), img3.to(Config.device), img4.to(Config.device), label.float().view(-1,1).to(Config.device)
        with torch.no_grad():
            output = net(img1, img2, img3, img4)
            preds = output>0.5
            loss = criterion(output, label)

        val_loss += loss.item()
        running_corrects += torch.sum(preds == (label>0.5))

    val_loss /= len(valset)
    running_corrects = running_corrects.item()/len(valset)
    print('[{}], \tval loss: {:.5}\tacc: {:.5}'.format(epoch+1, val_loss, running_corrects))

    return val_loss, running_corrects

In [None]:
def gen_trainable(model):
    layer = str(np.random.choice(20))
    for name, param in model.named_parameters():
        if 'encoder' in name:
            if layer in name:
                param.requires_grad_(True)
            else:
                param.requires_grad_(False)
        else:
            param.requires_grad_(True)

In [None]:
print("Initialize network...")
net = SiameseNet('vits').to(Config.device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=Config.learning_rate)
scheduler = ReduceLROnPlateau(optimizer, patience=10)

Initialize network...


In [None]:
print("Start training...")

best_val_loss = 1000
best_val_acc = 0.0
best_epoch = 0

history = []
accuracy = []
for epoch in range(Config.number_of_epochs):
    if epoch%10==0:
        gen_trainable(net)
    train_loss, train_acc = train(net, criterion, optimizer)
    val_loss, val_acc = validate(net, criterion, optimizer)
    history.append((train_loss, val_loss))
    accuracy.append((train_acc,val_acc))
    scheduler.step(val_loss)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(net.state_dict(), './checkpoints/best_loss8.pth')
        print('saving...')
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(net.state_dict(), './checkpoints/best_acc8.pth')
        print('saving...')

Start training...
[1], 	train loss: 0.010175	acc: 0.62057
[1], 	val loss: 0.0096074	acc: 0.70608
saving...
saving...
[2], 	train loss: 0.0091058	acc: 0.6858
[2], 	val loss: 0.0089197	acc: 0.70946
saving...
saving...
[3], 	train loss: 0.0083708	acc: 0.72531
[3], 	val loss: 0.0091719	acc: 0.70946
[4], 	train loss: 0.0078724	acc: 0.74608
[4], 	val loss: 0.0086805	acc: 0.73649
saving...
saving...
[5], 	train loss: 0.0077151	acc: 0.76192
[5], 	val loss: 0.0085651	acc: 0.74831
saving...
saving...
[6], 	train loss: 0.0074115	acc: 0.76907
[6], 	val loss: 0.0086026	acc: 0.74155
[7], 	train loss: 0.0075149	acc: 0.76345
[7], 	val loss: 0.0079966	acc: 0.76182
saving...
saving...
[8], 	train loss: 0.0072673	acc: 0.77861
[8], 	val loss: 0.0076663	acc: 0.80068
saving...
saving...
[9], 	train loss: 0.0070035	acc: 0.78627
[9], 	val loss: 0.0075937	acc: 0.77703
saving...
[10], 	train loss: 0.0068324	acc: 0.78968
[10], 	val loss: 0.0086859	acc: 0.73142
[11], 	train loss: 0.0066079	acc: 0.80228
[11], 	val

### Submission

In [13]:
class FamilyTestDataset(Dataset):
    def __init__(self, relations, data_dir, transform, transform2):
        """
        Args:
            relations (string): Data frame with the image paths.
            data_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.relations = relations
        self.data_dir = data_dir
        self.transform = transform
        self.transform2 = transform2

    def __len__(self) -> int:
        return len(self.relations)

    def __getpair__(self, idx):
        pair = (
            os.path.join(self.data_dir, self.relations.iloc[idx, 0].split("-")[0]),
            os.path.join(self.data_dir, self.relations.iloc[idx, 0].split("-")[1]),
        )
        return pair

    def __getlabel__(self, idx) -> int:
        return self.relations.iloc[idx, 1]

    def __getitem__(self, idx):
        pair = self.__getpair__(idx)

        im1 = Image.open(pair[0])
        im2 = Image.open(pair[1])

        img1 = self.transform(im1)
        img2 = self.transform(im2)
        img3 = self.transform2(im1)
        img4 = self.transform2(im2)

        return idx, img1, img2, img3, img4
def create_test_dataloader(test_image_dir: str, test_relationship_file: str):
    df = pd.read_csv(test_relationship_file)

    transform = transforms.Compose([
        transforms.Resize(112),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0., 0., 0.],
                             std=[1/255., 1/255., 1/255.])
    ])
    transform2 = transforms.Compose([
        transforms.Resize(160),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5])
    ])
    test_dataset = FamilyTestDataset(
        relations=df, data_dir=test_image_dir, transform=transform, transform2=transform2
    )

    test_loader = DataLoader(
        test_dataset,
        shuffle=True,
        batch_size=200,
    )

    return test_loader

In [15]:
def load_classifier(path_to_model_weights: str):
    model = SiameseNet('vits')
    model.load_state_dict(torch.load(path_to_model_weights))
    return model

def create_submission(path_to_template: str, path_to_save: str, predictions):
    template = pd.read_csv(path_to_template)

    # Remember to save as floats as metric is AUC
    for row, pred in predictions.items():
        template.loc[row, "is_related"] = float(pred)

    template.to_csv(path_or_buf=path_to_save, index=False)
    return


def test_classifier(classifier, test_loader):
    predictions = {}

    classifier.to(Config.device)
    classifier.eval()
    for i, data in enumerate(test_loader):
        row, img1, img2, img3, img4 = data
        row, img1, img2, img3, img4 = row.to(Config.device), img1.to(Config.device), img2.to(Config.device), img3.to(Config.device), img4.to(Config.device)

        with torch.no_grad():
            output = classifier(img1, img2, img3, img4)

        for j in range(len(row)):
            predictions[row[j].item()] = output[j].item()

    return predictions


if __name__ == "__main__":
    path_to_model_weights = "./checkpoints/best_loss8.pth"
    path_to_template = "./data/submissions/sample_submission.csv"
    path_to_save = "./data/submissions/loss8.csv"

    classifier = load_classifier(path_to_model_weights)

    test_loader = create_test_dataloader(
        "/content/data/test",
        "/content/data/submissions/sample_submission.csv"
    )

    predictions = test_classifier(
        classifier=classifier, test_loader=test_loader
    )
    print(len(predictions))

    create_submission(
        path_to_template=path_to_template,
        path_to_save=path_to_save,
        predictions=predictions,
    )

5310
