In [None]:
# Mount Google Drive inside the Colab VM at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install transformers
!pip install timm

In [None]:
!gdown --folder https://drive.google.com/drive/folders/115l82GBgu6RETopB_Qy36hn4TVwXI2t5?usp=sharing

In [None]:
!mkdir "/content/Datasets"
!mkdir "/content/Datasets/images"
!mkdir "/content/Datasets/images/train"
!mkdir "/content/Datasets/images/val"
!mkdir "/content/Datasets/images/test"
!mkdir "/content/Datasets/labels"
!mkdir "/content/Datasets/labels/val"
!mkdir "/content/Datasets/labels/train"

!unzip "/content/Advanced/CV/Train.zip" -d "/content/Datasets/images/train"
!unzip "/content/Advanced/CV/Validation.zip" -d "/content/Datasets/images/val"
!unzip "/content/Advanced/CV/Test.zip" -d "/content/Datasets/images/test"

!unzip "/content/Advanced/CV/train_labels.zip" -d "/content/Datasets/labels/train"
!unzip "/content/Advanced/CV/val_labels.zip" -d "/content/Datasets/labels/val"

In [None]:
import cv2
import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm 
import random
import torch
import pandas as pd

# Working directory for every generated artefact (crops, CSVs, features, checkpoints).
ROOT = r"." # Change this as required - tmp location
train_images_path = "/content/Datasets/images/train"
train_labels_path = "/content/Datasets/labels/train"
val_images_path = "/content/Datasets/images/val"
val_labels_path = "/content/Datasets/labels/val"

# Seed every RNG used further down (pair sampling, weight init, dropout) so that
# a fresh "Restart & Run All" reproduces the same pairs and training run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# **Generate Folders for Train Crops and Validation Crops**

In [None]:
import os

# One sub-folder per plushie class id. os.makedirs(..., exist_ok=True) also creates
# the missing REID_Data parent (os.mkdir could not, and the old bare `except` hid
# that failure) and makes the cell safe to re-run.
train_crops_path = ROOT + "/REID_Data/train" # Change this as required
num_plushies = 200
os.makedirs(train_crops_path, exist_ok=True)

for i in range(num_plushies):
    os.makedirs(train_crops_path + f"/{i}", exist_ok=True)

val_crops_path = ROOT + "/REID_Data/val" # Change this as required
num_plushies = 10
os.makedirs(val_crops_path, exist_ok=True)

for i in range(num_plushies):
    os.makedirs(val_crops_path + f"/{i}", exist_ok=True)

# **Generate Train and Validation Crops**

In [None]:
import cv2
import numpy as np
import pandas as pd
import os

images_path = train_images_path
labels_path = train_labels_path
annotated_images_path = train_crops_path

# Next free crop index per category, read from disk once per category instead of
# listing the folder again for every single crop.
next_crop_index = {}

for label_name in tqdm(os.listdir(labels_path)):
    if label_name[-4:] != ".txt":
        continue
    image_name = label_name[:-4] + ".png"

    image_path = os.path.join(images_path, image_name)
    label_path = os.path.join(labels_path, label_name)

    # Each row is "cat xc yc w h" with the box given as fractions of the image size.
    # sep=r"\s+" replaces the deprecated delim_whitespace=True.
    try:
        df = pd.read_csv(label_path, sep=r"\s+", header=None)
    except pd.errors.EmptyDataError:
        # A label file without any box: nothing to crop.
        continue
    df.columns = ["cat", "xc", "yc", "w", "h"]

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None (no exception) for a missing or unreadable file.
        print(f"Skipping {image_name}: image could not be read")
        continue
    img_h, img_w = img.shape[:2]

    for i in range(df.shape[0]):
        bb = df.iloc[i]
        cat = str(int(bb["cat"]))
        # Clamp to the image: a negative corner would otherwise be interpreted by
        # numpy as an index counted from the end and yield a wrong or empty crop.
        x0 = max(int((bb["xc"] - bb["w"]/2) * img_w), 0)
        y0 = max(int((bb["yc"] - bb["h"]/2) * img_h), 0)
        x1 = min(int((bb["xc"] + bb["w"]/2) * img_w), img_w)
        y1 = min(int((bb["yc"] + bb["h"]/2) * img_h), img_h)
        if x1 <= x0 or y1 <= y0:
            # Degenerate box: cv2.imwrite cannot encode an empty image.
            continue

        cropped_img = img[y0:y1, x0:x1]
        cat_dir = os.path.join(annotated_images_path, cat)
        if cat not in next_crop_index:
            next_crop_index[cat] = len(os.listdir(cat_dir))
        annotated_img_name = f"{cat}_{next_crop_index[cat]}.png"
        # Advance the index only on success so crop numbering stays gap-free;
        # the pair generation relies on files 0..count-1 existing.
        if cv2.imwrite(os.path.join(cat_dir, annotated_img_name), cropped_img):
            next_crop_index[cat] += 1

In [None]:
images_path = val_images_path
labels_path = val_labels_path
annotated_images_path = val_crops_path

# Next free crop index per category, read from disk once per category instead of
# listing the folder again for every single crop.
next_crop_index = {}

for label_name in tqdm(os.listdir(labels_path)):
    if label_name[-4:] != ".txt":
        continue
    image_name = label_name[:-4] + ".png"

    image_path = os.path.join(images_path, image_name)
    label_path = os.path.join(labels_path, label_name)

    # Each row is "cat xc yc w h" with the box given as fractions of the image size.
    # sep=r"\s+" replaces the deprecated delim_whitespace=True.
    try:
        df = pd.read_csv(label_path, sep=r"\s+", header=None)
    except pd.errors.EmptyDataError:
        # A label file without any box: nothing to crop.
        continue
    df.columns = ["cat", "xc", "yc", "w", "h"]

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None (no exception) for a missing or unreadable file.
        print(f"Skipping {image_name}: image could not be read")
        continue
    img_h, img_w = img.shape[:2]

    for i in range(df.shape[0]):
        bb = df.iloc[i]
        cat = str(int(bb["cat"]))
        # Clamp to the image: a negative corner would otherwise be interpreted by
        # numpy as an index counted from the end and yield a wrong or empty crop.
        x0 = max(int((bb["xc"] - bb["w"]/2) * img_w), 0)
        y0 = max(int((bb["yc"] - bb["h"]/2) * img_h), 0)
        x1 = min(int((bb["xc"] + bb["w"]/2) * img_w), img_w)
        y1 = min(int((bb["yc"] + bb["h"]/2) * img_h), img_h)
        if x1 <= x0 or y1 <= y0:
            # Degenerate box: cv2.imwrite cannot encode an empty image.
            continue

        cropped_img = img[y0:y1, x0:x1]
        cat_dir = os.path.join(annotated_images_path, cat)
        if cat not in next_crop_index:
            next_crop_index[cat] = len(os.listdir(cat_dir))
        annotated_img_name = f"{cat}_{next_crop_index[cat]}.png"
        # Advance the index only on success so crop numbering stays gap-free;
        # the pair generation relies on files 0..count-1 existing.
        if cv2.imwrite(os.path.join(cat_dir, annotated_img_name), cropped_img):
            next_crop_index[cat] += 1

# **Generate Train Matched and Non Matched Pairs**

In [None]:
# Number of crops stored for each training class id (the folder name is the class id).
# `counts` is read as a global by the pair generators defined in this cell.
print("Train Set")
crops_path = train_crops_path
num_plushies = 200
counts = {
    plushie_id: len(os.listdir(os.path.join(crops_path, str(plushie_id))))
    for plushie_id in range(num_plushies)
}
print(counts)

def generate_matched_pairs(count):
    """Sample `count` distinct crop pairs in which both crops show the same plushie.

    Reads the module-level `counts` dict (class id -> number of crops). A pair is
    unordered, so (a, b) and (b, a) count as the same pair.
    NOTE(review): as in the original sampling, a crop may be paired with itself.

    Args:
        count: number of pairs to generate.

    Returns:
        DataFrame with columns "img1" and "img2" holding "<class>/<class>_<n>.png"
        paths relative to the crops folder, indexed 0..count-1.

    Raises:
        ValueError: if `count` exceeds the number of distinct pairs available
            (the rejection sampling below could otherwise never terminate).
    """
    # Classes without crops cannot be sampled: random.randrange(0) raises.
    candidates = [plushie for plushie, n_crops in counts.items() if n_crops > 0]
    # Unordered pairs with repetition within one class: n * (n + 1) / 2.
    capacity = sum(counts[plushie] * (counts[plushie] + 1) // 2 for plushie in candidates)
    if count > capacity:
        raise ValueError(
            f"Requested {count} matched pairs but only {capacity} distinct pairs exist"
        )

    seen = set()  # (class, low index, high index): O(1) duplicate check
    rows = []     # collected first; growing a DataFrame row by row is quadratic

    for _ in tqdm(range(count)):
        while True:
            plushie = random.choice(candidates)
            num1, num2 = random.randrange(counts[plushie]), random.randrange(counts[plushie])
            key = (plushie, min(num1, num2), max(num1, num2))
            if key not in seen:
                break

        seen.add(key)
        rows.append([f"{plushie}/{plushie}_{num1}.png", f"{plushie}/{plushie}_{num2}.png"])

    return pd.DataFrame(rows, columns = ["img1", "img2"])
        

def generate_non_matched_pairs(count):
    """Sample `count` distinct crop pairs whose two crops show different plushies.

    Reads the module-level `counts` dict (class id -> number of crops). Pairs are
    de-duplicated as ordered (img1, img2) tuples, matching the original check.

    Args:
        count: number of pairs to generate.

    Returns:
        DataFrame with columns "img1" and "img2" holding "<class>/<class>_<n>.png"
        paths relative to the crops folder, indexed 0..count-1.

    Raises:
        ValueError: if `count` exceeds the number of distinct cross-class pairs
            (this includes having fewer than two non-empty classes).
    """
    # Classes without crops cannot be sampled: random.randrange(0) raises.
    candidates = [plushie for plushie, n_crops in counts.items() if n_crops > 0]
    total_crops = sum(counts[plushie] for plushie in candidates)
    # Ordered pairs of crops minus the ordered pairs that stay inside one class.
    capacity = total_crops * total_crops - sum(counts[plushie] ** 2 for plushie in candidates)
    if count > capacity:
        raise ValueError(
            f"Requested {count} non-matched pairs but only {capacity} distinct pairs exist"
        )

    seen = set()  # O(1) duplicate check instead of filtering the DataFrame each time
    rows = []     # collected first; growing a DataFrame row by row is quadratic

    for _ in tqdm(range(count)):
        while True:
            # random.sample returns two different classes, uniformly over ordered pairs.
            plushie1, plushie2 = random.sample(candidates, 2)
            num1, num2 = random.randrange(counts[plushie1]), random.randrange(counts[plushie2])
            pair = (f"{plushie1}/{plushie1}_{num1}.png", f"{plushie2}/{plushie2}_{num2}.png")
            if pair not in seen:
                break

        seen.add(pair)
        rows.append(list(pair))

    return pd.DataFrame(rows, columns = ["img1", "img2"])


# Sample the training pairs and persist each set as a CSV next to the other artefacts.
matched_csv_path = ROOT + "/matched_pairs_train.csv"
matched_df = generate_matched_pairs(30000)
matched_df.to_csv(matched_csv_path, index=False)

non_matched_csv_path = ROOT + "/non_matched_pairs_train.csv"
non_matched_df = generate_non_matched_pairs(30000)
non_matched_df.to_csv(non_matched_csv_path, index=False)


# **Generate Validation Matched and Non Matched Pairs**

In [None]:
# Rebuild the global `counts` for the validation classes; the pair generators read
# it at call time, so the calls below sample from the validation crops.
print("Validation Set")
crops_path = val_crops_path
num_plushies = 10
counts = {
    plushie_id: len(os.listdir(os.path.join(crops_path, str(plushie_id))))
    for plushie_id in range(num_plushies)
}
print(counts)

# Sample the validation pairs and persist each set as a CSV.
matched_csv_path = ROOT + "/matched_pairs_val.csv"
matched_df = generate_matched_pairs(6000)
matched_df.to_csv(matched_csv_path, index=False)

non_matched_csv_path = ROOT + "/non_matched_pairs_val.csv"
non_matched_df = generate_non_matched_pairs(6000)
non_matched_df.to_csv(non_matched_csv_path, index=False)

In [None]:
import numpy as np
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader, TensorDataset
from tqdm import tqdm
from torch.optim.lr_scheduler import LinearLR
from torch import nn, optim

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def load_data(train_ds_path, test_ds_path):
    """Load the cached pair-feature tensors and build their match labels.

    The first half of each dataset is labelled 1 (matched) and the second half 0
    (non-matched); this mirrors save_dataset_as_tensor, which writes the matched
    pairs first.

    Args:
        train_ds_path: path of the saved training tensor.
        test_ds_path: path of the saved validation tensor.

    Returns:
        (train_ds, train_label, val_ds, val_label), all placed on the module-level
        `device`; each label tensor has shape (N, 1).
    """
    # map_location places the data on `device` directly while deserialising.
    train_ds = torch.load(train_ds_path, map_location=device)
    val_ds = torch.load(test_ds_path, map_location=device)

    # 1 - Matched, 0 - Non-Matched
    # One slice assignment instead of a Python loop issuing a tiny write per row.
    train_label = torch.zeros(size = (train_ds.shape[0], 1), device=device)
    train_label[: train_ds.shape[0] // 2] = 1

    val_label = torch.zeros(size = (val_ds.shape[0], 1), device=device)
    val_label[: val_ds.shape[0] // 2] = 1

    print(f"X Train: {train_ds.shape}")
    print(f"X Val: {val_ds.shape}")
    print(f"y Train: {train_label.shape}")
    print(f"y Val: {val_label.shape}")

    return train_ds, train_label, val_ds, val_label

In [None]:
# Reference: https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch
class EarlyStopper_Checkpoint():
    """Track the best validation metric and signal when training should stop.

    Args:
        patience: number of non-improving calls to `early_stop` tolerated before
            it returns True.
        save_path: stored on the instance for the caller; not used by this class.
        min_delta: tolerance band around the best value; a metric that is worse
            than the best by no more than `min_delta` does not count against
            `patience`.
        metric: "val_loss" (lower is better) or "val_acc" (higher is better).

    Raises:
        ValueError: if `metric` is not one of the two supported names.
    """
    def __init__(self, patience=1, save_path = None, min_delta=0, metric = "val_loss"):
        if metric not in ("val_loss", "val_acc"):
            raise ValueError(f"Unsupported metric {metric!r}; expected 'val_loss' or 'val_acc'")
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.save_path = save_path
        self.metric = metric
        if metric == "val_loss":
            self.min_validation_loss = float("inf")
        else:
            self.max_validation_accuracy = 0

    def early_stop(self, metric):
        """Record `metric` and return True once patience is exhausted."""
        if self.check_metric(metric):
            # New best value: remember it and restart the patience counter.
            if self.metric == "val_loss":
                self.min_validation_loss = metric
            else:
                self.max_validation_accuracy = metric
            self.counter = 0
            return False

        if self.metric == "val_loss":
            is_worse = metric > (self.min_validation_loss + self.min_delta)
        else:
            # Higher is better, so the tolerance band lies BELOW the best value
            # (the previous "+ min_delta" counted every non-improving epoch).
            is_worse = metric < (self.max_validation_accuracy - self.min_delta)

        if is_worse:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

    def check_metric(self, metric):
        """Return True when `metric` strictly improves on the best value seen so far."""
        if self.metric == "val_loss":
            return bool(metric < self.min_validation_loss)
        return bool(metric > self.max_validation_accuracy)
    
class NN_Classifier(nn.Module):
    """Two-layer MLP head: dropout -> linear -> GELU -> dropout -> linear.

    Args:
        input_size: width of the input feature vector.
        hidden_size: width of the hidden layer.
        output_size: number of output logits.
    """
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Layer order fixes the state_dict keys (indices 1 and 4 hold the weights).
        layers = [
            nn.Dropout(p = 0.5),
            nn.Linear(input_size, hidden_size),
            nn.GELU(),
            nn.Dropout(p = 0.5),
            nn.Linear(hidden_size, output_size),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw logits for `x` (no sigmoid is applied here)."""
        return self.linear_relu_stack(x)

def ema_loss(cur_loss, prev_loss):
    """Blend the newest loss into the running value with weights 0.1 (new) and 0.9 (old)."""
    return 0.9 * prev_loss + 0.1 * cur_loss


def signed_sqrt(x1, x2):
    """Element-wise square root of |x1 * x2| that keeps the sign of the product."""
    product = x1 * x2
    magnitude = torch.sqrt(torch.abs(product))
    return torch.sign(product) * magnitude
    
def combine_function(x1, x2):
    """Build the pair representation fed to the classifier.

    Concatenates six element-wise combinations of the two feature tensors along
    the last dimension, so the output is six times as wide as each input.
    """
    pair_features = [
        x1 + x2,
        x1 - x2,
        x2 -x1,
        x1**2 + x2**2,
        x1*x2,
        signed_sqrt(x1,x2),
    ]
    return torch.cat(pair_features, dim = -1)


# **Convert images into ResNet/SENet/ViT features and save them**

In [None]:
from PIL import Image
from torchvision import transforms
import cv2
from transformers import  AutoImageProcessor, ResNetModel, ViTImageProcessor, ViTModel
import torch
import urllib
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

class BGR2RGB:
    """Callable transform that converts an image from BGR to RGB channel order."""
    def __call__(self, image):
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return rgb_image
    
class SquarePad:
    """Callable transform that pads an image with zeros to a centred square."""
    def __call__(self, image):
        height, width = image.shape[:2]
        side = max(height, width)
        # Split the missing rows/columns evenly; any odd leftover pixel goes to
        # the bottom/right edge.
        p_top = (side - height) // 2
        p_left = (side - width) // 2
        p_bottom = side - (height + p_top)
        p_right = side - (width + p_left)
        return cv2.copyMakeBorder(image, p_top, p_bottom, p_left, p_right, cv2.BORDER_CONSTANT, None, value = 0)

def load_feature_extractor(model_name):
    """Return a pretrained backbone in eval mode on `device` plus its preprocessing.

    Args:
        model_name: "resnet", "vit" or "senet".

    Returns:
        (model, preprocess) where `preprocess` is a torchvision transform pipeline
        for "resnet", a ViTImageProcessor for "vit" and a timm transform for "senet".

    Raises:
        ValueError: if `model_name` is not one of the supported names (previously
            the function fell through and returned None).
    """
    if model_name == "resnet":
        model = ResNetModel.from_pretrained("microsoft/resnet-50").eval().to(device)
        # Applied to a BGR image as read by cv2: RGB -> square pad -> tensor -> 224x224 -> normalise.
        transform = transforms.Compose([BGR2RGB(),
            SquarePad(),
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        return model, transform

    elif model_name == "vit":
        processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
        model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').eval().to(device)
        return model, processor

    elif model_name == "senet":
        # num_classes = 0 is passed so the model is created without a classification layer.
        model = timm.create_model('seresnet152d', pretrained=True,num_classes = 0).eval().to(device)
        config = resolve_data_config({}, model=model)
        transform = create_transform(**config)
        return model, transform

    raise ValueError(f"Unknown model_name {model_name!r}; expected 'resnet', 'vit' or 'senet'")

In [None]:
# Rebinds the global `model`/`processor` to the ResNet-50 backbone and its transform.
model, processor = load_feature_extractor("resnet")

In [None]:
resnet_train_path = ROOT + "/REID_Data/resnet_train"

# exist_ok tolerates a re-run without swallowing every other error like the bare except did.
os.makedirs(resnet_train_path, exist_ok=True)

for plushie in tqdm(os.listdir(train_crops_path)):
    os.makedirs(resnet_train_path + f"/{plushie}", exist_ok=True)
    for img_name in os.listdir(train_crops_path + f"/{plushie}"):
        input_image = cv2.imread(train_crops_path + f"/{plushie}/{img_name}")
        # The backbone lives on `device`, so its input has to be moved there as well.
        input_ = processor(input_image).reshape(1, 3, 224, 224).to(device)
        with torch.no_grad():
            output_ = model(input_, )['pooler_output'].flatten()
        # Save on the CPU so the feature files can be loaded on a machine without a GPU.
        torch.save(output_.cpu(), resnet_train_path + f"/{plushie}/{img_name[:-4]}.pt")
        

In [None]:
resnet_val_path = ROOT + "/REID_Data/resnet_val"

# exist_ok tolerates a re-run without swallowing every other error like the bare except did.
os.makedirs(resnet_val_path, exist_ok=True)

for plushie in tqdm(os.listdir(val_crops_path)):
    os.makedirs(resnet_val_path + f"/{plushie}", exist_ok=True)
    for img_name in os.listdir(val_crops_path + f"/{plushie}"):
        input_image = cv2.imread(val_crops_path + f"/{plushie}/{img_name}")
        # The backbone lives on `device`, so its input has to be moved there as well.
        input_ = processor(input_image).reshape(1, 3, 224, 224).to(device)
        with torch.no_grad():
            output_ = model(input_, )['pooler_output'].flatten()
        # Save on the CPU so the feature files can be loaded on a machine without a GPU.
        torch.save(output_.cpu(), resnet_val_path + f"/{plushie}/{img_name[:-4]}.pt")
   

In [None]:
# Rebinds the global `model`/`processor` to the ViT backbone and its image processor.
model, processor = load_feature_extractor("vit")

In [None]:
vit_train_path = ROOT + "/REID_Data/vit_train"

# exist_ok tolerates a re-run without swallowing every other error like the bare except did.
os.makedirs(vit_train_path, exist_ok=True)

for plushie in tqdm(os.listdir(train_crops_path)):
    os.makedirs(vit_train_path + f"/{plushie}", exist_ok=True)
    for img_name in os.listdir(train_crops_path + f"/{plushie}"):
        input_image = cv2.imread(train_crops_path + f"/{plushie}/{img_name}")
        pixel_values = processor(BGR2RGB()(input_image), return_tensors = "pt")['pixel_values'][0]
        # The backbone lives on `device`, so its input has to be moved there as well.
        input_ = pixel_values.reshape(1, 3, 224, 224).to(device)
        with torch.no_grad():
            output_ = model(input_, )['pooler_output'].flatten()
        # Save on the CPU so the feature files can be loaded on a machine without a GPU.
        torch.save(output_.cpu(), vit_train_path + f"/{plushie}/{img_name[:-4]}.pt")
        

In [None]:
vit_val_path = ROOT + "/REID_Data/vit_val"

# exist_ok tolerates a re-run without swallowing every other error like the bare except did.
os.makedirs(vit_val_path, exist_ok=True)

for plushie in tqdm(os.listdir(val_crops_path)):
    os.makedirs(vit_val_path + f"/{plushie}", exist_ok=True)
    for img_name in os.listdir(val_crops_path + f"/{plushie}"):
        input_image = cv2.imread(val_crops_path + f"/{plushie}/{img_name}")
        pixel_values = processor(BGR2RGB()(input_image), return_tensors = "pt")['pixel_values'][0]
        # The backbone lives on `device`, so its input has to be moved there as well.
        input_ = pixel_values.reshape(1, 3, 224, 224).to(device)
        with torch.no_grad():
            output_ = model(input_, )['pooler_output'].flatten()
        # Save on the CPU so the feature files can be loaded on a machine without a GPU.
        torch.save(output_.cpu(), vit_val_path + f"/{plushie}/{img_name[:-4]}.pt")
   
    

In [None]:
# Rebinds the global `model`/`processor` to the SE-ResNet backbone and its timm transform.
model, processor = load_feature_extractor("senet")

In [None]:
senet_train_path = ROOT + "/REID_Data/senet_train"

# exist_ok tolerates a re-run without swallowing every other error like the bare except did.
os.makedirs(senet_train_path, exist_ok=True)

for plushie in tqdm(os.listdir(train_crops_path)):
    os.makedirs(senet_train_path + f"/{plushie}", exist_ok=True)
    for img_name in os.listdir(train_crops_path + f"/{plushie}"):
        input_image = Image.open(train_crops_path + f"/{plushie}/{img_name}").convert('RGB')
        # The backbone lives on `device`, so its input has to be moved there as well.
        input_ = processor(input_image).unsqueeze(0).to(device)
        with torch.no_grad():
            output_ = model(input_).flatten()
        # Save on the CPU so the feature files can be loaded on a machine without a GPU.
        torch.save(output_.cpu(), senet_train_path + f"/{plushie}/{img_name[:-4]}.pt")
        

In [None]:
senet_val_path = ROOT + "/REID_Data/senet_val"

# exist_ok tolerates a re-run without swallowing every other error like the bare except did.
os.makedirs(senet_val_path, exist_ok=True)

for plushie in tqdm(os.listdir(val_crops_path)):
    os.makedirs(senet_val_path + f"/{plushie}", exist_ok=True)
    for img_name in os.listdir(val_crops_path + f"/{plushie}"):
        input_image = Image.open(val_crops_path + f"/{plushie}/{img_name}").convert('RGB')
        # The backbone lives on `device`, so its input has to be moved there as well.
        input_ = processor(input_image).unsqueeze(0).to(device)
        with torch.no_grad():
            output_ = model(input_).flatten()
        # Save on the CPU so the feature files can be loaded on a machine without a GPU.
        torch.save(output_.cpu(), senet_val_path + f"/{plushie}/{img_name[:-4]}.pt")
        

# **Generate Train and Validation Dataset Tensor**
The per-image features of every pair are stacked into a single tensor file per backbone, so that training later loads one file instead of thousands of small ones.

In [None]:
# Reload the sampled pair lists from disk (matched pairs first, then non-matched).
train_csv_matched, train_csv_non_matched = (
    pd.read_csv(ROOT + csv_name)
    for csv_name in ("/matched_pairs_train.csv", "/non_matched_pairs_train.csv")
)

val_csv_matched, val_csv_non_matched = (
    pd.read_csv(ROOT + csv_name)
    for csv_name in ("/matched_pairs_val.csv", "/non_matched_pairs_val.csv")
)

def save_dataset_as_tensor(csv_matched, csv_non_matched, path, total_size, dim):
    """Stack the saved per-image features of every pair into one tensor.

    Matched pairs fill rows [0, len(csv_matched)), non-matched pairs the rows
    after them.

    Args:
        csv_matched: DataFrame with "img1"/"img2" image paths of matched pairs.
        csv_non_matched: DataFrame with "img1"/"img2" image paths of non-matched pairs.
        path: folder holding one ".pt" feature file per image, laid out like the
            image paths with the 4-character extension replaced.
        total_size: number of rows of the result; must equal the total pair count.
        dim: length of one feature vector.

    Returns:
        CPU tensor of shape (total_size, 2, dim).

    Raises:
        ValueError: if `total_size` differs from the number of pairs. A larger
            value would leave uninitialised torch.empty rows in the dataset, a
            smaller one would fail later with an opaque index error.
    """
    expected_size = len(csv_matched) + len(csv_non_matched)
    if total_size != expected_size:
        raise ValueError(
            f"total_size={total_size} does not match the {expected_size} pairs in the CSVs"
        )

    dataset = torch.empty(size = (total_size, 2, dim))

    for offset, pairs in ((0, csv_matched), (len(csv_matched), csv_non_matched)):
        for i in tqdm(range(0, len(pairs))):
            # `dataset` is a CPU tensor, so load straight to the CPU rather than
            # going through the accelerator and copying back.
            img1 = torch.load(path + f"/{pairs['img1'][i][:-4]}.pt", map_location="cpu").reshape(-1)
            img2 = torch.load(path + f"/{pairs['img2'][i][:-4]}.pt", map_location="cpu").reshape(-1)
            dataset[offset + i] = torch.stack([img1, img2], dim = 0)

    return dataset

In [None]:
# Row counts come from the pair CSVs rather than hard-coded 60000 / 12000, so the
# tensors stay correct when the number of sampled pairs changes. 2048 is the
# ResNet feature width.
path = ROOT + "/REID_Data/resnet_train"

dataset = save_dataset_as_tensor(train_csv_matched, train_csv_non_matched, path,
                                 len(train_csv_matched) + len(train_csv_non_matched), 2048)
torch.save(dataset, ROOT + "/train_ds_resnet.pt")

path = ROOT + "/REID_Data/resnet_val"

dataset = save_dataset_as_tensor(val_csv_matched, val_csv_non_matched, path,
                                 len(val_csv_matched) + len(val_csv_non_matched), 2048)
torch.save(dataset, ROOT + "/val_ds_resnet.pt")

In [None]:
# Row counts come from the pair CSVs rather than hard-coded 60000 / 12000, so the
# tensors stay correct when the number of sampled pairs changes. 2048 is the
# SENet feature width.
path = ROOT + "/REID_Data/senet_train"

dataset = save_dataset_as_tensor(train_csv_matched, train_csv_non_matched, path,
                                 len(train_csv_matched) + len(train_csv_non_matched), 2048)
torch.save(dataset, ROOT + "/train_ds_senet.pt")

path = ROOT + "/REID_Data/senet_val"

dataset = save_dataset_as_tensor(val_csv_matched, val_csv_non_matched, path,
                                 len(val_csv_matched) + len(val_csv_non_matched), 2048)
torch.save(dataset, ROOT + "/val_ds_senet.pt")

In [None]:
# Row counts come from the pair CSVs rather than hard-coded 60000 / 12000, so the
# tensors stay correct when the number of sampled pairs changes. 768 is the
# ViT feature width.
path = ROOT + "/REID_Data/vit_train"

dataset = save_dataset_as_tensor(train_csv_matched, train_csv_non_matched, path,
                                 len(train_csv_matched) + len(train_csv_non_matched), 768)
torch.save(dataset, ROOT + "/train_ds_vit.pt")

path = ROOT + "/REID_Data/vit_val"

dataset = save_dataset_as_tensor(val_csv_matched, val_csv_non_matched, path,
                                 len(val_csv_matched) + len(val_csv_non_matched), 768)
torch.save(dataset, ROOT + "/val_ds_vit.pt")

# **Training Neural Network Classifier on ResNet Features**

In [None]:
# Load the cached ResNet pair features together with their match labels.
X_train, y_train, X_val, y_val = load_data(ROOT + "/train_ds_resnet.pt", ROOT + "/val_ds_resnet.pt")

batch_size = 256

dataloader_train = DataLoader(TensorDataset(X_train, y_train.float()),
                                          batch_size=batch_size,
                                          shuffle=True)

dataloader_val =  DataLoader(TensorDataset(X_val, y_val.float()),
                              batch_size=batch_size,
                              shuffle=True)

# combine_function concatenates six feature-wide blocks, hence 2048 * 6 inputs.
model = NN_Classifier(2048 * 6, 2048, 1).to(device)

checkpoint_name = f"resnet"

learning_rate = 1e-4
no_of_epoch = 100

optimizer = optim.Adam(model.parameters(), lr = learning_rate, weight_decay=3e-4)
scheduler = LinearLR(optimizer, start_factor=1, end_factor = 0.2, total_iters=50)
start_epoch = 0
early_stopper = EarlyStopper_Checkpoint(patience = 5,
                                        save_path =  ROOT + f"/{checkpoint_name}_best.pt",
                                        metric = "val_acc")
bce_loss_func = nn.BCEWithLogitsLoss(reduction='mean')

for epoch in range(start_epoch, start_epoch + no_of_epoch):
    train_loss = 0
    correct = 0

    # Training
    model.train()
    for train_data, train_label in tqdm(dataloader_train):
        train_data = combine_function(train_data[:,0,:], train_data[:,1,:])
        optimizer.zero_grad()
        output= model(train_data)
        loss = bce_loss_func(output, train_label)
        loss.backward()
        optimizer.step()
        # .item() keeps the running average a plain float; accumulating the loss
        # tensor itself would chain every batch's autograd history together.
        train_loss = ema_loss(loss.item(), train_loss)

        y_pred = nn.Sigmoid()(output.detach())

        correct += ((y_pred>0.5).float() == train_label).float().sum()

    train_accuracy = correct / len(dataloader_train.dataset) * 100
    # Rolling checkpoint, overwritten after every epoch.
    checkpoint = {"epoch": epoch,
                  "model_state_dict": model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'scheduler_state_dict': scheduler.state_dict()}
    torch.save(checkpoint, ROOT + f"/{checkpoint_name}.pt")

    print(f"Saving checkpoint to {checkpoint_name}.pt")

    correct = 0
    val_loss = 0

    # Validation
    model.eval()
    with torch.no_grad():
        for val_data, val_label in tqdm(dataloader_val):
            val_data = combine_function(val_data[:,0,:], val_data[:,1,:])
            output= model(val_data)
            loss = bce_loss_func(output, val_label)
            y_pred = nn.Sigmoid()(output)
            val_loss = ema_loss(loss, val_loss)
            correct += ((y_pred>0.5).float() == val_label).float().sum()

        val_accuracy = correct / len(dataloader_val.dataset) * 100
        print(f"Epoch {epoch}: Train Loss: {train_loss: .5f}, Training Accuracy: {train_accuracy:.2f}%, Val Loss: {val_loss: .5f}, Val Accuracy = {val_accuracy:.2f}%")

    scheduler.step()
    # Save a separate "best" checkpoint whenever validation accuracy improves.
    if early_stopper.check_metric(val_accuracy):
        checkpoint = {"epoch": epoch,
                      "val_loss": val_loss,
                      "val_accuracy": val_accuracy,
                      "model_state_dict": model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      'scheduler_state_dict': scheduler.state_dict()}

        torch.save(checkpoint, ROOT + f"/{checkpoint_name}_best.pt")
        print(f"Saving checkpoint to {checkpoint_name}_best.pt")


    if early_stopper.early_stop(val_accuracy):
        # The monitored metric is validation accuracy, so the message says so.
        print("Stopped early due to no improvement in validation accuracy")
        break


# **Training Neural Network Classifier on SENet Features**

In [None]:
# Load the cached SENet pair features together with their match labels.
X_train, y_train, X_val, y_val = load_data(ROOT + "/train_ds_senet.pt", ROOT + "/val_ds_senet.pt")

batch_size = 256

dataloader_train = DataLoader(TensorDataset(X_train, y_train.float()),
                                          batch_size=batch_size,
                                          shuffle=True)

dataloader_val =  DataLoader(TensorDataset(X_val, y_val.float()),
                              batch_size=batch_size,
                              shuffle=True)

# combine_function concatenates six feature-wide blocks, hence 2048 * 6 inputs.
model = NN_Classifier(2048 * 6, 2048, 1).to(device)

checkpoint_name = f"senet"

learning_rate = 1e-4
no_of_epoch = 100

optimizer = optim.Adam(model.parameters(), lr = learning_rate, weight_decay=3e-4)
scheduler = LinearLR(optimizer, start_factor=1, end_factor = 0.2, total_iters=50)
start_epoch = 0
early_stopper = EarlyStopper_Checkpoint(patience = 5,
                                        save_path =  ROOT + f"/{checkpoint_name}_best.pt",
                                        metric = "val_acc")
bce_loss_func = nn.BCEWithLogitsLoss(reduction='mean')

for epoch in range(start_epoch, start_epoch + no_of_epoch):
    train_loss = 0
    correct = 0

    # Training
    model.train()
    for train_data, train_label in tqdm(dataloader_train):
        train_data = combine_function(train_data[:,0,:], train_data[:,1,:])
        optimizer.zero_grad()
        output= model(train_data)
        loss = bce_loss_func(output, train_label)
        loss.backward()
        optimizer.step()
        # .item() keeps the running average a plain float; accumulating the loss
        # tensor itself would chain every batch's autograd history together.
        train_loss = ema_loss(loss.item(), train_loss)

        y_pred = nn.Sigmoid()(output.detach())

        correct += ((y_pred>0.5).float() == train_label).float().sum()

    train_accuracy = correct / len(dataloader_train.dataset) * 100
    # Rolling checkpoint, overwritten after every epoch.
    checkpoint = {"epoch": epoch,
                  "model_state_dict": model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'scheduler_state_dict': scheduler.state_dict()}
    torch.save(checkpoint, ROOT + f"/{checkpoint_name}.pt")

    print(f"Saving checkpoint to {checkpoint_name}.pt")

    correct = 0
    val_loss = 0

    # Validation
    model.eval()
    with torch.no_grad():
        for val_data, val_label in tqdm(dataloader_val):
            val_data = combine_function(val_data[:,0,:], val_data[:,1,:])
            output= model(val_data)
            loss = bce_loss_func(output, val_label)
            y_pred = nn.Sigmoid()(output)
            val_loss = ema_loss(loss, val_loss)
            correct += ((y_pred>0.5).float() == val_label).float().sum()

        val_accuracy = correct / len(dataloader_val.dataset) * 100
        print(f"Epoch {epoch}: Train Loss: {train_loss: .5f}, Training Accuracy: {train_accuracy:.2f}%, Val Loss: {val_loss: .5f}, Val Accuracy = {val_accuracy:.2f}%")

    scheduler.step()
    # Save a separate "best" checkpoint whenever validation accuracy improves.
    if early_stopper.check_metric(val_accuracy):
        checkpoint = {"epoch": epoch,
                      "val_loss": val_loss,
                      "val_accuracy": val_accuracy,
                      "model_state_dict": model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      'scheduler_state_dict': scheduler.state_dict()}

        torch.save(checkpoint, ROOT + f"/{checkpoint_name}_best.pt")
        print(f"Saving checkpoint to {checkpoint_name}_best.pt")


    if early_stopper.early_stop(val_accuracy):
        # The monitored metric is validation accuracy, so the message says so.
        print("Stopped early due to no improvement in validation accuracy")
        break


# **Training Neural Network Classifier on Vision Transformer Features**

In [None]:
# Load the cached ViT pair features together with their match labels.
X_train, y_train, X_val, y_val = load_data(ROOT + "/train_ds_vit.pt", ROOT + "/val_ds_vit.pt")

batch_size = 256

dataloader_train = DataLoader(TensorDataset(X_train, y_train.float()),
                                          batch_size=batch_size,
                                          shuffle=True)

dataloader_val =  DataLoader(TensorDataset(X_val, y_val.float()),
                              batch_size=batch_size,
                              shuffle=True)

# combine_function concatenates six feature-wide blocks, hence 768 * 6 inputs.
model = NN_Classifier(768 * 6, 2048, 1).to(device)

checkpoint_name = f"vit"

learning_rate = 1e-4
no_of_epoch = 100

optimizer = optim.Adam(model.parameters(), lr = learning_rate, weight_decay=3e-4)
scheduler = LinearLR(optimizer, start_factor=1, end_factor = 0.2, total_iters=50)
start_epoch = 0
early_stopper = EarlyStopper_Checkpoint(patience = 5,
                                        save_path =  ROOT + f"/{checkpoint_name}_best.pt",
                                        metric = "val_acc")
bce_loss_func = nn.BCEWithLogitsLoss(reduction='mean')

for epoch in range(start_epoch, start_epoch + no_of_epoch):
    train_loss = 0
    correct = 0

    # Training
    model.train()
    for train_data, train_label in tqdm(dataloader_train):
        train_data = combine_function(train_data[:,0,:], train_data[:,1,:])
        optimizer.zero_grad()
        output= model(train_data)
        loss = bce_loss_func(output, train_label)
        loss.backward()
        optimizer.step()
        # .item() keeps the running average a plain float; accumulating the loss
        # tensor itself would chain every batch's autograd history together.
        train_loss = ema_loss(loss.item(), train_loss)

        y_pred = nn.Sigmoid()(output.detach())

        correct += ((y_pred>0.5).float() == train_label).float().sum()

    train_accuracy = correct / len(dataloader_train.dataset) * 100
    # Rolling checkpoint, overwritten after every epoch.
    checkpoint = {"epoch": epoch,
                  "model_state_dict": model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'scheduler_state_dict': scheduler.state_dict()}
    torch.save(checkpoint, ROOT + f"/{checkpoint_name}.pt")

    print(f"Saving checkpoint to {checkpoint_name}.pt")

    correct = 0
    val_loss = 0

    # Validation
    model.eval()
    with torch.no_grad():
        for val_data, val_label in tqdm(dataloader_val):
            val_data = combine_function(val_data[:,0,:], val_data[:,1,:])
            output= model(val_data)
            loss = bce_loss_func(output, val_label)
            y_pred = nn.Sigmoid()(output)
            val_loss = ema_loss(loss, val_loss)
            correct += ((y_pred>0.5).float() == val_label).float().sum()

        val_accuracy = correct / len(dataloader_val.dataset) * 100
        print(f"Epoch {epoch}: Train Loss: {train_loss: .5f}, Training Accuracy: {train_accuracy:.2f}%, Val Loss: {val_loss: .5f}, Val Accuracy = {val_accuracy:.2f}%")

    scheduler.step()
    # Save a separate "best" checkpoint whenever validation accuracy improves.
    if early_stopper.check_metric(val_accuracy):
        checkpoint = {"epoch": epoch,
                      "val_loss": val_loss,
                      "val_accuracy": val_accuracy,
                      "model_state_dict": model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      'scheduler_state_dict': scheduler.state_dict()}

        torch.save(checkpoint, ROOT + f"/{checkpoint_name}_best.pt")
        print(f"Saving checkpoint to {checkpoint_name}_best.pt")


    if early_stopper.early_stop(val_accuracy):
        # The monitored metric is validation accuracy, so the message says so.
        print("Stopped early due to no improvement in validation accuracy")
        break


# **Training Neural Network Classifier on the Combined ResNet, SENet, and Vision Transformer Features**

In [None]:
# Load the three cached feature sets; the labels are taken from the first load only.
X_train_1, y_train, X_val_1, y_val = load_data(ROOT + "/train_ds_resnet.pt", ROOT + "/val_ds_resnet.pt")
X_train_2, _, X_val_2, _ = load_data(ROOT + "/train_ds_senet.pt", ROOT + "/val_ds_senet.pt")
X_train_3, _, X_val_3, _ = load_data(ROOT + "/train_ds_vit.pt", ROOT + "/val_ds_vit.pt")

# Concatenate the per-backbone features along the feature axis: 2048 + 2048 + 768 = 4864.
X_train = torch.concat([X_train_1, X_train_2, X_train_3], dim = -1)
X_val = torch.concat([X_val_1, X_val_2, X_val_3], dim = -1)

batch_size = 256

dataloader_train = DataLoader(TensorDataset(X_train, y_train.float()),
                                          batch_size=batch_size,
                                          shuffle=True)

dataloader_val =  DataLoader(TensorDataset(X_val, y_val.float()),
                              batch_size=batch_size,
                              shuffle=True)

# combine_function concatenates six feature-wide blocks, hence 4864 * 6 inputs.
model = NN_Classifier(4864 * 6, 1024, 1).to(device)

checkpoint_name = f"combined"

learning_rate = 1e-4
no_of_epoch = 100

optimizer = optim.Adam(model.parameters(), lr = learning_rate, weight_decay=3e-4)
scheduler = LinearLR(optimizer, start_factor=1, end_factor = 0.2, total_iters=50)
start_epoch = 0
early_stopper = EarlyStopper_Checkpoint(patience = 5,
                                        save_path =  ROOT + f"/{checkpoint_name}_best.pt",
                                        metric = "val_acc")
bce_loss_func = nn.BCEWithLogitsLoss(reduction='mean')

for epoch in range(start_epoch, start_epoch + no_of_epoch):
    train_loss = 0
    correct = 0

    # Training
    model.train()
    for train_data, train_label in tqdm(dataloader_train):
        train_data = combine_function(train_data[:,0,:], train_data[:,1,:])
        optimizer.zero_grad()
        output= model(train_data)
        loss = bce_loss_func(output, train_label)
        loss.backward()
        optimizer.step()
        # .item() keeps the running average a plain float; accumulating the loss
        # tensor itself would chain every batch's autograd history together.
        train_loss = ema_loss(loss.item(), train_loss)

        y_pred = nn.Sigmoid()(output.detach())

        correct += ((y_pred>0.5).float() == train_label).float().sum()

    train_accuracy = correct / len(dataloader_train.dataset) * 100
    # Rolling checkpoint, overwritten after every epoch.
    checkpoint = {"epoch": epoch,
                  "model_state_dict": model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'scheduler_state_dict': scheduler.state_dict()}
    torch.save(checkpoint, ROOT + f"/{checkpoint_name}.pt")

    print(f"Saving checkpoint to {checkpoint_name}.pt")

    correct = 0
    val_loss = 0

    # Validation
    model.eval()
    with torch.no_grad():
        for val_data, val_label in tqdm(dataloader_val):
            val_data = combine_function(val_data[:,0,:], val_data[:,1,:])
            output= model(val_data)
            loss = bce_loss_func(output, val_label)
            y_pred = nn.Sigmoid()(output)
            val_loss = ema_loss(loss, val_loss)
            correct += ((y_pred>0.5).float() == val_label).float().sum()

        val_accuracy = correct / len(dataloader_val.dataset) * 100
        print(f"Epoch {epoch}: Train Loss: {train_loss: .5f}, Training Accuracy: {train_accuracy:.2f}%, Val Loss: {val_loss: .5f}, Val Accuracy = {val_accuracy:.2f}%")

    scheduler.step()
    # Save a separate "best" checkpoint whenever validation accuracy improves.
    if early_stopper.check_metric(val_accuracy):
        checkpoint = {"epoch": epoch,
                      "val_loss": val_loss,
                      "val_accuracy": val_accuracy,
                      "model_state_dict": model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      'scheduler_state_dict': scheduler.state_dict()}

        torch.save(checkpoint, ROOT + f"/{checkpoint_name}_best.pt")
        print(f"Saving checkpoint to {checkpoint_name}_best.pt")


    if early_stopper.early_stop(val_accuracy):
        # The monitored metric is validation accuracy, so the message says so.
        print("Stopped early due to no improvement in validation accuracy")
        break
