In [None]:
!pip install torch torchvision transformers
!pip install torchmetrics

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTModel, ViTConfig
import pandas as pd
import os
import numpy as np
from PIL import Image
import gdown
import zipfile
import matplotlib.pyplot as plt
from torchmetrics import AUROC
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

BASE_DIR = "Data_V9_ViT"

scaler = torch.amp.GradScaler('cuda')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

In [None]:
# Fetch the texture archive once; skip the download when the zip is already on disk so
# Restart & Run All does not hit Google Drive every time.
ARCHIVE_PATH = "texture_uni.zip"
file_id = "file_id"  # NOTE(review): looks like a redacted placeholder — set the real Drive id
if not os.path.exists(ARCHIVE_PATH):
    gdown.download(f"https://drive.google.com/uc?id={file_id}", ARCHIVE_PATH, quiet=False)
if not os.path.exists(ARCHIVE_PATH):
    raise RuntimeError(f"Download of {ARCHIVE_PATH} failed (file id: {file_id})")

# Extract into the current working directory.
with zipfile.ZipFile(ARCHIVE_PATH, "r") as zip_ref:
    zip_ref.extractall(".")

In [4]:
class LoadDataset(Dataset):
    """Pair dataset backed by a parquet index of pre-saved ``.pt`` tensors.

    Each parquet row must provide ``sample_1`` and ``sample_2`` (paths relative to
    ``base_dir``, without the ``.pt`` suffix) and a numeric ``label``.

    ``__getitem__`` returns ``(img1, img2, label, sample_1, sample_2)`` where ``label``
    is a float32 tensor of shape ``(1,)`` and the last two items are the raw parquet paths.
    """

    def __init__(self, parquet_path, base_dir=BASE_DIR):
        super().__init__()
        self.df = pd.read_parquet(parquet_path)
        # Trailing slashes are stripped; os.path.join adds the separator itself.
        self.base_dir = base_dir.rstrip("/")

    def __len__(self):
        return len(self.df)

    def _pt_path(self, rel_path: str) -> str:
        """Path of the ``.pt`` file for a parquet entry."""
        return os.path.join(self.base_dir, rel_path + ".pt")

    @staticmethod
    def _load_pt(path: str) -> torch.Tensor:
        """Load one saved tensor; any failure is re-raised as ValueError with the cause chained."""
        try:
            # weights_only=True: the data comes from a downloaded archive, so refuse to
            # unpickle arbitrary objects (only tensors / plain containers are allowed).
            return torch.load(path, weights_only=True)
        except Exception as e:
            raise ValueError(f"Could not load tensor at {path}: {e}") from e

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        img1_tensor = self._load_pt(self._pt_path(row["sample_1"]))
        img2_tensor = self._load_pt(self._pt_path(row["sample_2"]))

        label_tensor = torch.tensor([float(row["label"])], dtype=torch.float32)

        return img1_tensor, img2_tensor, label_tensor, row["sample_1"], row["sample_2"]

In [None]:
vit_name = "google/vit-base-patch16-384"
config = ViTConfig.from_pretrained(vit_name)
vit_backbone = ViTModel.from_pretrained(vit_name, config=config).to(device)

# Freeze the embedding module and encoder blocks 0-5: requires_grad_(False) clears
# requires_grad on every parameter each of these sub-modules owns. All other
# parameters are left untouched.
vit_backbone.embeddings.requires_grad_(False)
for block_idx in range(6):
    vit_backbone.encoder.layer[block_idx].requires_grad_(False)


In [None]:
class ViTEmbedder(nn.Module):
    """Wrap a Hugging Face ViT model so a pixel batch maps straight to its pooled output."""

    def __init__(self, vit_model):
        super().__init__()
        # Attribute name is part of the state_dict keys ("...vit.*"); keep it stable.
        self.vit = vit_model

    def forward(self, x):
        """Run the wrapped model on ``pixel_values=x`` and return its ``pooler_output``."""
        return self.vit(pixel_values=x).pooler_output

# Embedder around the (partially frozen) backbone, placed on the notebook device.
vit_embedder = ViTEmbedder(vit_model=vit_backbone).to(device=device)

In [None]:
class ProjectionHead(nn.Module):
    """MLP projection head producing L2-normalised embeddings.

    Every width ``h`` in ``hidden_dims`` contributes a ``Linear -> LayerNorm -> GELU``
    block; the final activation is L2-normalised along dim 1, so the output width is
    ``hidden_dims[-1]``.

    Args:
        in_dim: feature width of the input.
        hidden_dims: output widths of successive blocks (any iterable of ints).
        dropout: dropout probability applied after every GELU. It is active only in
            training mode; ``eval()`` output is unaffected.
    """

    def __init__(self, in_dim=768, hidden_dims=(1024, 512, 256), dropout=0.1):
        super().__init__()
        layers = []
        prev = in_dim
        for h in hidden_dims:
            layers += [
                nn.Linear(prev, h),
                nn.LayerNorm(h),
                nn.GELU(),
            ]
            prev = h
        # The Sequential layout is deliberately unchanged (no Dropout modules inside it)
        # so the ``net.<idx>.*`` state_dict keys of existing checkpoints still match.
        self.net = nn.Sequential(*layers)
        # Previously the ``dropout`` argument was silently ignored. nn.Dropout has no
        # parameters or buffers, so it adds no state_dict entries.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        z = x
        for layer in self.net:
            z = layer(z)
            if isinstance(layer, nn.GELU):
                z = self.dropout(z)
        return F.normalize(z, p=2, dim=1)

# The head's input width must equal the backbone's hidden size (768 for this ViT-Base
# checkpoint); read it from the loaded config instead of duplicating the number.
proj_head = ProjectionHead(
    in_dim=vit_backbone.config.hidden_size,
    hidden_dims=[512, 256, 128],
    dropout=0.1,
).to(device)

In [8]:
class EuclideanDistance(nn.Module):
    """Row-wise L2 distance between two batches of embeddings.

    The squared distance is floored at ``eps`` before the square root, which keeps the
    gradient of ``sqrt`` finite when the two inputs coincide. For 2-D inputs of shape
    ``(batch, dim)`` the result has shape ``(batch, 1)``.
    """

    def __init__(self, eps: float = 1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        squared_dist = torch.sum((x - y) ** 2, dim=1, keepdim=True)
        return squared_dist.clamp(min=self.eps).sqrt()

In [9]:
class SiameseViT(nn.Module):
    """Siamese network: one shared embedder + projection head, scored by Euclidean distance."""

    def __init__(self, embedder, head):
        super().__init__()
        self.embedder = embedder
        self.head = head
        self.distance = EuclideanDistance()

    def forward_once(self, img):
        """Embed a batch of images with the shared branch (embedder, then head)."""
        return self.head(self.embedder(img))

    def forward(self, img_a, img_b):
        """Return the distance between the embeddings of ``img_a`` and ``img_b``."""
        # Arguments evaluate left to right, so img_a is embedded before img_b.
        return self.distance(self.forward_once(img_a), self.forward_once(img_b))

# Full Siamese model sharing the embedder/head modules built above.
siamese_model = SiameseViT(embedder=vit_embedder, head=proj_head).to(device=device)

In [None]:
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss on precomputed pair distances.

    ``label == 0`` marks a similar pair, penalised by ``dist**2``; ``label == 1`` marks a
    dissimilar pair, penalised by ``max(0, margin - dist)**2``. Returns the batch mean.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, dist, label):
        # Reshape labels to a column so they broadcast against dist
        # (presumably (batch, 1), as produced by EuclideanDistance).
        y = label.view(-1, 1)
        similar_term = (1.0 - y) * dist.pow(2)
        dissimilar_term = y * (self.margin - dist).clamp(min=0.0).pow(2)
        return (similar_term + dissimilar_term).mean()
# ContrastiveLoss owns no parameters; moving it to the device is kept for uniformity.
criterion = ContrastiveLoss(1.0).to(device=device)


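## Two-Speed Distance Calculation

Download the test pairs and the trained weights. Then, for every test pair, embed the `_N` and `_F` folders of both samples with the Siamese ViT, averaging over the `.pt` tensors in each folder, and record all six pairwise Euclidean distances.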


In [None]:
# Fetch the test-pair table once; skip the download when it is already on disk so
# Restart & Run All does not hit Google Drive every time.
TEST_PARQUET = "uni_test.parquet"
file_id = "file_id"  # NOTE(review): looks like a redacted placeholder — set the real Drive id
if not os.path.exists(TEST_PARQUET):
    gdown.download(f"https://drive.google.com/uc?id={file_id}", TEST_PARQUET, quiet=False)
if not os.path.exists(TEST_PARQUET):
    raise RuntimeError(f"Download of {TEST_PARQUET} failed (file id: {file_id})")

In [None]:
# Fetch the trained checkpoint once (cached on disk across re-runs).
MODEL_PATH = "best_model.pt"
file_id = "1ASQuT-LU5H1f_ydTrytQ-f5k8l-C7-yw"
if not os.path.exists(MODEL_PATH):
    gdown.download(f"https://drive.google.com/uc?id={file_id}", MODEL_PATH, quiet=False)
if not os.path.exists(MODEL_PATH):
    raise RuntimeError(f"Download of {MODEL_PATH} failed (file id: {file_id})")

# Rebuild the wrapper around the same embedder/head modules, then load the trained weights.
# weights_only=True: the checkpoint is an externally downloaded pickle; restricting
# unpickling to tensors and plain containers avoids arbitrary code execution.
siamese_model = SiameseViT(vit_embedder, proj_head).to(device)
siamese_model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))

In [None]:
def load_tensor(folder_path, model=None):
    """Return the mean embedding of every ``.pt`` tensor stored in ``folder_path``.

    Despite its name, this does more than load. Each tensor file (in sorted path order)
    is loaded, given a leading batch dimension, and embedded with ``model.forward_once``.
    The per-file embeddings are then averaged.

    Args:
        folder_path: directory scanned (non-recursively) for files ending in ``.pt``.
        model: ``nn.Module`` exposing ``forward_once``; defaults to the notebook-level
            ``siamese_model``. Tensors are loaded onto the device of its first
            parameter, falling back to the notebook-level ``device``.

    Returns:
        Mean of the per-file embeddings (presumably ``(1, emb_dim)`` for SiameseViT).

    Raises:
        FileNotFoundError: if ``folder_path`` contains no ``.pt`` files. Previously this
            surfaced as an opaque ``torch.stack`` error on an empty list.

    Gradient tracking is left to the caller (wrap in ``torch.no_grad()`` for inference).
    """
    if model is None:
        model = siamese_model
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    tensor_paths = sorted(
        os.path.join(folder_path, f)
        for f in os.listdir(folder_path)
        if f.endswith(".pt")
    )
    if not tensor_paths:
        raise FileNotFoundError(f"No .pt files found in {folder_path!r}")

    embeddings = []
    for path in tensor_paths:
        # weights_only=True: these files come from a downloaded archive.
        tensor = torch.load(path, map_location=target_device, weights_only=True)
        embeddings.append(model.forward_once(tensor.unsqueeze(0)))

    return torch.stack(embeddings).mean(dim=0)


In [None]:
# Pairs of embedding keys whose distance is recorded; the order fixes the d_* column order.
DISTANCE_PAIRS = [
    ("1N", "1F"),
    ("2N", "2F"),
    ("1N", "2N"),
    ("1N", "2F"),
    ("1F", "2N"),
    ("1F", "2F"),
]
# Embedding key -> parquet column holding that folder path (also the output column order).
SAMPLE_COLUMNS = {
    "1N": "sample_1_N",
    "1F": "sample_1_F",
    "2N": "sample_2_N",
    "2F": "sample_2_F",
}

df = pd.read_parquet("uni_test.parquet")
rows = []

# The same folder can appear in several test pairs. Cache its averaged embedding so
# each folder goes through the ViT only once per run of this cell.
embedding_cache = {}


def cached_embedding(folder_path):
    """Return load_tensor(folder_path), computing it at most once per folder."""
    if folder_path not in embedding_cache:
        embedding_cache[folder_path] = load_tensor(folder_path)
    return embedding_cache[folder_path]


siamese_model.eval()

with torch.no_grad():
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Computing ViT embeddings"):
        embeddings = {key: cached_embedding(row[col]) for key, col in SAMPLE_COLUMNS.items()}

        record = {
            f"d_{a}_{b}": torch.norm(embeddings[a] - embeddings[b]).item()
            for a, b in DISTANCE_PAIRS
        }
        # Drop the "texture_uni/" component from the stored folder paths.
        record.update(
            {col: row[col].replace("texture_uni/", "") for col in SAMPLE_COLUMNS.values()}
        )
        record["label"] = row["label"]
        rows.append(record)

dist_df = pd.DataFrame(rows)
assert len(dist_df) == len(df), "expected one distance row per test pair"
dist_df.to_parquet("test_dissimilarity_dataset_vit.parquet", index=False)
dist_df.to_csv("test_dissimilarity_dataset_vit.csv", index=False)
