# Import Libraries

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# List every directory under the Kaggle input mount, plus each file that is not a JPEG.
# The extension is tested with endswith(): comparing a name to the literal '*.jpg'
# performs no globbing, so it would let every image path through and flood the output.
for dirname, _, filenames in os.walk('/kaggle/input'):
    print(dirname)
    for filename in filenames:
        if not filename.lower().endswith('.jpg'):
            print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip uninstall -y tensorflow

In [None]:
import torch
# Turn on autograd anomaly detection for the whole session.
torch.autograd.set_detect_anomaly(True)

print('pytorch version:', torch.__version__)
print("GPU available:", torch.cuda.device_count())

# Pick the compute target in one place: first CUDA device when present, CPU otherwise.
if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    device = torch.device("cuda")
else:
    device_name = "cpu"
    device = torch.device("cpu")

print("device name:", device_name)
print("device:", device)

In [None]:
# For data augmentation
import torchvision
from torchvision import transforms, datasets
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2

In [None]:
import transformers

# For Tokenizers
from transformers import ViTImageProcessor, ViTConfig

# For Model
from transformers import ViTModel, ViTForImageClassification

# For GPU
from transformers import set_seed
from torch.optim import AdamW
from accelerate import Accelerator, notebook_launcher

# For Dataset
from torch.utils.data import Dataset, DataLoader

# For Loss calculation
import torch.nn.functional as F
from torch.nn import CosineEmbeddingLoss, TripletMarginLoss, MSELoss

# For Display
from tqdm.notebook import tqdm

In [None]:
import random
import cv2
import pandas as pd
import numpy as np
from PIL import Image
from itertools import combinations, product

import matplotlib.pyplot as plt
# Default size (width, height in inches) for every figure drawn in this notebook.
plt.rcParams.update({'figure.figsize': [8,5]})

In [None]:
from sklearn.model_selection import train_test_split

# Declare Global Constants

In [None]:
# Where checkpoints are written, and where the CelebA CSVs / aligned images are read from.
output_dir = '/kaggle/working'
data_dir = '/kaggle/input/celeba-dataset'
image_dir = f'{data_dir}/img_align_celeba/img_align_celeba/'

In [None]:
# Run-wide settings, declared together and then echoed to the notebook log.
SEED = 42
MODEL_TRANSFORMER = 'google/vit-base-patch16-224'  # Hugging Face checkpoint id
BATCH_SIZE = 8
CLIP_SIZE = 224  # image side length reported below

print('BATCH_SIZE =',BATCH_SIZE)
print('Image Dimension =', CLIP_SIZE,'X', CLIP_SIZE)

In [None]:
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU and every CUDA device).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

In [None]:
# Fix all random number generators before any data sampling or model creation.
seed_everything(seed=SEED)

# Load Dataset

In [None]:
# Identity annotations: header-less, whitespace-separated rows read as (image, identity).
# The separator is a raw string so the regex `\s+` is not parsed as an invalid string escape.
identity_df = pd.read_csv('/kaggle/input/identity-celeba/identity_CelebA.txt', sep=r'\s+',  header=None, names=["image", "identity"])
identity_df.head(3)

In [None]:
# Per-image attribute table from the CelebA dataset directory.
attributes_df = pd.read_csv(f'{data_dir}/list_attr_celeba.csv')
attributes_df.head(3)

In [None]:
# Evaluation-partition table (which split each image belongs to).
partition_df = pd.read_csv(f'{data_dir}/list_eval_partition.csv')
partition_df.head(3)

In [None]:
# Landmark table for the aligned images.
landmarks_df = pd.read_csv(f'{data_dir}/list_landmarks_align_celeba.csv')
landmarks_df.head(3)

In [None]:
# Bounding-box table.
bbox_df = pd.read_csv(f'{data_dir}/list_bbox_celeba.csv')
bbox_df.head(3)

# Select desired data

In [None]:
# Index both tables by image file name (duplicates raise), then align them column-wise
# so every image row carries both its identity and its partition.
identity_by_image = identity_df.set_index(keys=['image'], verify_integrity=True)
partition_by_image = partition_df.set_index(keys=['image_id'], verify_integrity=True)
combined_df = pd.concat(
    [identity_by_image, partition_by_image],
    axis=1,
    verify_integrity=True,
).copy()
combined_df

In [None]:
# How many images fall into each partition value.
partition_counts = combined_df['partition'].value_counts()
partition_counts

In [None]:
# Split on the partition flag: 0, 1 and 2 feed the train, validation and test frames respectively.
partition_flag = combined_df['partition']

train_df = combined_df.loc[partition_flag == 0].copy()
print('Train dataset shape:',train_df.shape)

val_df = combined_df.loc[partition_flag == 1].copy()
print('Validation dataset shape:',val_df.shape)

test_df = combined_df.loc[partition_flag == 2].copy()
print('Test dataset shape:',test_df.shape)

# Create Dataset

In [None]:
def createContrastPairs(data_df):
    """Build a shuffled table of same-identity and different-identity image pairs.

    For every identity, all 2-combinations of its images become positive pairs
    (similarity 1). Each identity is also matched with one randomly chosen other
    identity, and every cross pair between the two becomes a negative candidate
    (similarity -1). Negatives are then randomly down-sampled to the number of
    positives.

    Args:
        data_df: DataFrame indexed by image name with an 'identity' column.

    Returns:
        DataFrame with columns 'image1', 'image2', 'similarity', rows shuffled.
        If fewer negative candidates than positives exist (for example when only
        one identity is present), all available negatives are kept.
    """
    # Group image names by identity in a single pass (first-appearance order),
    # instead of re-filtering the whole frame once per identity.
    images_by_class = {}
    for image, identity in zip(data_df.index, data_df['identity']):
        images_by_class.setdefault(identity, []).append(image)
    classes = list(images_by_class)

    positive_list = []
    negative_list = []
    for position, cls1 in enumerate(tqdm(classes)):
        images_class1 = images_by_class[cls1]

        # Every unordered pair inside the identity is a positive example.
        for img1, img2 in combinations(images_class1, 2):
            positive_list.append([img1, img2, 1])

        # A contrasting identity only exists when there are at least two.
        if len(classes) < 2:
            continue

        # Draw a uniformly random *other* identity without touching `classes`:
        # pick among len-1 slots and step over the current position.
        other = random.randrange(len(classes) - 1)
        if other >= position:
            other += 1
        images_class2 = images_by_class[classes[other]]

        # Every cross pair between the two identities is a negative candidate.
        for img1, img2 in product(images_class1, images_class2):
            negative_list.append([img1, img2, -1])

    # Balance the classes; never request more negatives than were generated.
    keep = min(len(negative_list), len(positive_list))
    negative_list = random.sample(negative_list, keep)

    pair_df = pd.DataFrame(positive_list + negative_list,
                           columns=['image1', 'image2', 'similarity'])
    # Shuffle so positives and negatives are interleaved.
    return pair_df.sample(frac=1)

In [None]:
# Build the pair table for each split, reporting its size as soon as it is ready.
train_pair_df = createContrastPairs(data_df=train_df)
print('Train Dataset Pair shape:', train_pair_df.shape)

val_pair_df = createContrastPairs(data_df=val_df)
print('Validation Dataset Pair shape:', val_pair_df.shape)

test_pair_df = createContrastPairs(data_df=test_df)
print('Test Dataset Pair shape:', test_pair_df.shape)

In [None]:
class CustomDataset(Dataset):
    """Dataset of image pairs that loads and preprocesses images lazily.

    Each item is a dict with 'pixel_values_1', 'pixel_values_2' (the processor
    output for the two images of the pair) and 'labels' (the pair's label).
    """

    def __init__(self, image_num_1, image_num_2, labels):
        """Store the pair description and create the image processor.

        Args:
            image_num_1: Sequence of file paths for the first image of each pair.
            image_num_2: Sequence of file paths for the second image of each pair.
            labels: Indexable collection with one label per pair.
        """
        self.image_num_1 = image_num_1  # Store paths instead of images
        self.image_num_2 = image_num_2
        self.labels = labels
        # NOTE(review): attn_implementation / torch_dtype look like model-loading options
        # rather than image-processor settings - confirm they have any effect here.
        self.transform_dataset = ViTImageProcessor.from_pretrained(MODEL_TRANSFORMER, attn_implementation="sdpa", torch_dtype=torch.float16)

    def __len__(self):
        """Return the number of pairs."""
        return len(self.image_num_1)

    def _load_pixel_values(self, img_path):
        """Read one image from disk, convert it to RGB and run the processor on it."""
        # The context manager closes the file handle as soon as the RGB copy exists.
        with Image.open(img_path) as raw_image:
            rgb_image = raw_image.convert("RGB")
        return self.transform_dataset(rgb_image)['pixel_values'][0]

    def __getitem__(self, idx):
        """Load both images of pair `idx` on demand and return them with the label."""
        return {
            'pixel_values_1': self._load_pixel_values(self.image_num_1[idx]),
            'pixel_values_2': self._load_pixel_values(self.image_num_2[idx]),
            'labels': self.labels[idx],
        }

In [None]:
# Turn the training pair table into file paths and float labels, then batch it (reshuffled every epoch).
train_paths_1 = [os.path.join(image_dir, name) for name in train_pair_df['image1'].to_list()]
train_paths_2 = [os.path.join(image_dir, name) for name in train_pair_df['image2'].to_list()]
train_labels = torch.tensor(train_pair_df['similarity'].values, dtype=torch.float32)

train_ds = CustomDataset(image_num_1=train_paths_1, image_num_2=train_paths_2, labels=train_labels)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Same construction for the validation pairs; the loader keeps the table order (no shuffling).
val_paths_1 = [os.path.join(image_dir, name) for name in val_pair_df['image1'].to_list()]
val_paths_2 = [os.path.join(image_dir, name) for name in val_pair_df['image2'].to_list()]
val_labels = torch.tensor(val_pair_df['similarity'].values, dtype=torch.float32)

val_ds = CustomDataset(image_num_1=val_paths_1, image_num_2=val_paths_2, labels=val_labels)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE)

In [None]:
# Test loader (currently disabled).
#test_ds = CustomDataset(
#        image_num_1=[os.path.join(image_dir, img) for img in test_pair_df.image1.to_list()],
#        image_num_2=[os.path.join(image_dir, img) for img in test_pair_df.image2.to_list()],
#        labels=torch.tensor(test_pair_df.similarity.values, dtype=torch.int64)
#        )
#test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE)

# Testing the Dataloader

In [None]:
# Display helper for image tensors
def imgshow(img):
    """Undo the normalisation of a channel-first image tensor and show it with matplotlib.

    Args:
        img: Tensor laid out as (channels, height, width); values are mapped
            through ``x / 2 + 0.5`` before plotting.
    """
    unnormalized = img / 2 + 0.5
    # matplotlib expects channel-last arrays, so move the channel axis to the end.
    channel_last = unnormalized.numpy().transpose((1, 2, 0))
    plt.imshow(channel_last)
    plt.show()

In [None]:
# Testing the dataset: pull one batch, report the shapes of both image tensors and
# the labels, then display the first pair together with its label.
batch = next(iter(train_dl))
print(batch['pixel_values_1'].shape, batch['pixel_values_2'].shape, batch['labels'].shape)
imgshow(torchvision.utils.make_grid(batch['pixel_values_1'][0]))
imgshow(torchvision.utils.make_grid(batch['pixel_values_2'][0]))
print(batch['labels'][0])

# Load Model

In [None]:
# Release cached GPU memory that PyTorch is holding but not using, before the model is created.
torch.cuda.empty_cache()

In [None]:
# Settings read by ParallelViTNetwork and by the training loop.
hyperparameters = dict(
    # Optimisation
    learning_rate=1e-4,
    num_epochs=20,  # upper bound on the number of epochs
    seed=SEED,
    patience=10,  # early stopping: epochs without validation improvement

    # ViT configuration overrides
    num_hidden_layers=2,
    num_attention_heads=1,

    # Dropout
    hidden_dropout_prob=0.2,
    attention_probs_dropout_prob=0.2,

    # Checkpoint destinations
    output_dir_pt=f"{output_dir}/vit_celebA_gpu_pt_1.pt",
    output_dir_transformer=f"{output_dir}/vit_celebA_gpu_pt_1",
)


In [None]:
class ParallelViTNetwork(torch.nn.Module):
    """Scores an image pair with two frozen ViT encoders and one trainable linear head.

    Each image goes through its own encoder; the first-token embeddings of the two
    encoders are concatenated, projected to a single value and squashed by Tanh.
    """

    def __init__(self):
        """Load both encoders from MODEL_TRANSFORMER, freeze them and build the head."""
        super().__init__()
        self.config = ViTConfig.from_pretrained(
            MODEL_TRANSFORMER,
            return_dict=True,
            num_hidden_layers=hyperparameters['num_hidden_layers'],
            num_attention_heads=hyperparameters['num_attention_heads'],
            hidden_dropout_prob=hyperparameters['hidden_dropout_prob'],
            attention_probs_dropout_prob=hyperparameters['attention_probs_dropout_prob'],
        )
        self.embedding_model_1 = ViTModel.from_pretrained(MODEL_TRANSFORMER, config=self.config)
        self.embedding_model_2 = ViTModel.from_pretrained(MODEL_TRANSFORMER, config=self.config)

        # Encoders are feature extractors only: no gradients for any of their weights.
        for encoder in (self.embedding_model_1, self.embedding_model_2):
            encoder.requires_grad_(False)

        embedding_width = self.embedding_model_1.config.hidden_size
        # The head sees both embeddings side by side and emits one score per pair.
        self.final_layer = torch.nn.Linear(embedding_width * 2, 1)
        # Tanh keeps the score within [-1, 1].
        self.activation = torch.nn.Tanh()

    def forward(self, image_1, image_2):
        """Return one Tanh-bounded score per pair, shaped (batch, 1).

        Args:
            image_1: Batch of pixel values fed to the first encoder.
            image_2: Batch of pixel values fed to the second encoder.
        """
        branches = (
            (self.embedding_model_1, image_1),
            (self.embedding_model_2, image_2),
        )
        # Keep only the token at sequence position 0 of each encoder's last hidden state.
        first_token_embeddings = [
            encoder(pixels).last_hidden_state[:, 0, :]
            for encoder, pixels in branches
        ]
        joined = torch.cat(first_token_embeddings, dim=1)
        return self.activation(self.final_layer(joined))

In [None]:
# Now we train the model
def training_function():
    """Train ParallelViTNetwork with early stopping and checkpoint the best epoch.

    Reads the notebook-level `hyperparameters`, `train_dl` and `val_dl`. After each
    epoch the validation loss drives both the ReduceLROnPlateau scheduler and the
    early-stopping counter. Whenever the validation loss improves, the model is
    written to `hyperparameters['output_dir_pt']` (pickled module) and its weights
    to the directory `hyperparameters['output_dir_transformer']`.

    Returns:
        None.
    """
    # Initialize accelerator
    accelerator = Accelerator()

    # The seed needs to be set before the model is instantiated, because the
    # final dense layer is randomly initialised.
    set_seed(hyperparameters["seed"])

    # Instantiate the model; Accelerate handles the device placement.
    embedding_model = ParallelViTNetwork()

    # Loss function
    criterion = MSELoss() #CosineEmbeddingLoss(margin=0.25)

    # Instantiate optimizer
    optimizer = AdamW(embedding_model.parameters(), lr=hyperparameters["learning_rate"])

    # Halve the learning rate when the validation loss plateaus. The deprecated
    # `verbose` flag is not passed; the learning rate is printed every epoch below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2, min_lr=0.00001
    )

    # prepare() returns the objects in the same order they were passed in.
    (accelerated_model, accelerated_criterion, accelerated_optimizer,
     accelerated_train_dl, accelerated_val_dl) = accelerator.prepare(
        embedding_model, criterion, optimizer, train_dl, val_dl
    )

    epochs_no_improve = 0
    min_val_loss = float("inf")

    for epoch in range(hyperparameters["num_epochs"]):
        # Only the main process shows a progress bar, to avoid one bar per process.
        progress_bar = tqdm(range(len(accelerated_train_dl)), disable=not accelerator.is_main_process)
        progress_bar.set_description(f"Epoch: {epoch}")

        # ---- training ----
        accelerated_model.train()
        training_loss = []
        for batch in accelerated_train_dl:
            train_output = accelerated_model(batch['pixel_values_1'], batch['pixel_values_2'])

            # The model emits (batch, 1) while labels are (batch,). Drop the trailing
            # axis so MSELoss compares element-wise instead of broadcasting to (batch, batch).
            train_loss = accelerated_criterion(train_output.squeeze(-1), batch['labels'])

            accelerator.backward(train_loss)
            accelerated_optimizer.step()
            accelerated_optimizer.zero_grad()

            # Gather the per-process losses so every process sees all of them.
            training_loss.append(accelerator.gather(train_loss.detach()[None]))
            progress_bar.set_postfix({'loss': train_loss.item()})
            progress_bar.update(1)
        progress_bar.close()

        # Mean over every gathered value (all batches, all processes).
        training_loss_final = torch.cat(training_loss).mean().item()
        # accelerator.print only prints on the main process.
        current_lr = [group['lr'] for group in optimizer.param_groups]
        accelerator.print(f"epoch {epoch}: learning rate:", current_lr)
        accelerator.print(f"epoch {epoch}: training loss:", training_loss_final)

        # ---- validation ----
        accelerated_model.eval()
        validation_loss = []
        for batch in accelerated_val_dl:
            with torch.no_grad():
                val_output = accelerated_model(batch['pixel_values_1'], batch['pixel_values_2'])
                val_loss = accelerated_criterion(val_output.squeeze(-1), batch['labels'])
            validation_loss.append(accelerator.gather(val_loss[None]))

        validation_loss_final = torch.cat(validation_loss).mean().item()
        accelerator.print(f"epoch {epoch}: validation loss:", validation_loss_final)

        # Step the scheduler on the monitored metric.
        scheduler.step(validation_loss_final)

        # ---- checkpointing / early stopping ----
        if validation_loss_final < min_val_loss:
            epochs_no_improve = 0
            min_val_loss = validation_loss_final

            accelerator.wait_for_everyone()
            # Pickle the plain module (architecture and weights), not the Accelerate wrapper.
            accelerator.save(accelerator.unwrap_model(accelerated_model), hyperparameters['output_dir_pt'])
            # ParallelViTNetwork is a plain nn.Module and has no save_pretrained();
            # Accelerate's save_model writes its weights to the target directory instead.
            accelerator.save_model(accelerated_model, hyperparameters['output_dir_transformer'])
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= hyperparameters["patience"]:
                accelerator.print("Early stopping!")
                break

In [None]:
# Start training directly in the notebook process.
training_function()