In [None]:
!pip install torch torchvision transformers
!pip install torchmetrics

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import ViTModel, ViTConfig
import pandas as pd
import os
import numpy as np
import gdown
import zipfile
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score,roc_curve

# Default dataset root; in the cells visible here it is only used as the
# default value of LoadDataset's `base_dir` argument.
BASE_DIR = "Data_V9_ViT"
# AMP gradient scaler for CUDA. Nothing in the cells shown here uses it
# (presumably left over from the training notebook — confirm before removing).
scaler = torch.amp.GradScaler('cuda')
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

In [None]:
# Google Drive id of the texture archive.
# NOTE(review): "file_id" is a placeholder — replace it with the real id
# before running this cell.
file_id = "file_id"
gdown.download(f"https://drive.google.com/uc?id={file_id}", "texture_iam.zip", quiet=False)
# Extract relative to the current working directory. collect_scores is later
# called with the "texture_iam" root, so the archive presumably contains a
# top-level folder of that name — TODO confirm.
with zipfile.ZipFile("texture_iam.zip", 'r') as zip_ref:
    zip_ref.extractall("")

In [None]:
class LoadDataset(Dataset):
    """Pair-list dataset backed by a parquet file.

    Each row supplies a ``label`` and two sample identifiers (``sample_1`` and
    ``sample_2``). No image data is read here: both image slots of every item
    are filled with the same zero placeholder, and the identifiers are returned
    as strings for the caller to resolve.

    Args:
        parquet_path: Path of the parquet file holding the pair table.
        base_dir: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, parquet_path, base_dir=BASE_DIR):
        super().__init__()
        self.df = pd.read_parquet(parquet_path)

    def __len__(self):
        """Return the number of rows (pairs) in the table."""
        return len(self.df.index)

    def __getitem__(self, idx):
        """Return ``(dummy, dummy, label, sample_1, sample_2)`` for row ``idx``.

        ``label`` is a float32 tensor of shape ``(1,)``; ``dummy`` is ``zeros(1)``.
        """
        record = self.df.iloc[idx]
        target = torch.tensor(float(record["label"]), dtype=torch.float32).unsqueeze(0)
        placeholder = torch.zeros(1)
        return placeholder, placeholder, target, str(record["sample_1"]), str(record["sample_2"])

In [None]:
# Pretrained checkpoint used as the shared backbone of the Siamese model.
vit_name = "google/vit-base-patch16-384"
config = ViTConfig.from_pretrained(vit_name)
vit_backbone = ViTModel.from_pretrained(vit_name, config=config).to(device)

# Freeze the embedding module and the first six encoder blocks; every other
# parameter of the backbone is left untouched.
frozen_modules = [vit_backbone.embeddings, *vit_backbone.encoder.layer[:6]]
for module in frozen_modules:
    for param in module.parameters():
        param.requires_grad = False


In [None]:
class ViTEmbedder(nn.Module):
    """Thin wrapper that maps an image batch to the wrapped model's ``pooler_output``."""

    def __init__(self, vit_model):
        super().__init__()
        # Stored under the attribute name ``vit``.
        self.vit = vit_model

    def forward(self, x):
        """Pass ``x`` as ``pixel_values`` and return the ``pooler_output`` attribute."""
        return self.vit(pixel_values=x).pooler_output

# Wrap the backbone so downstream modules receive its pooled output directly.
vit_embedder = ViTEmbedder(vit_backbone).to(device)

In [None]:
class ProjectionHead(nn.Module):
    """MLP that projects backbone features to an L2-normalised embedding.

    Every entry of ``hidden_dims`` adds one ``Linear -> LayerNorm -> GELU``
    stage, so the output width equals the last entry (or ``in_dim`` when
    ``hidden_dims`` is empty).

    Args:
        in_dim: Width of the incoming feature vectors.
        hidden_dims: Output width of each stage, in order.
        dropout: Accepted for backward compatibility but NOT used — no dropout
            layer is created. Do not insert one into ``self.net`` without
            re-exporting checkpoints: it would shift the ``nn.Sequential``
            indices and therefore the ``state_dict`` keys.
    """

    def __init__(self, in_dim=768, hidden_dims=(1024, 512, 256), dropout=0.1):
        super().__init__()
        # The default is a tuple rather than a list: a mutable default would
        # be shared by every instance built without an explicit `hidden_dims`.
        widths = [in_dim, *hidden_dims]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.GELU(),
            ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return embeddings scaled to unit L2 norm along ``dim=1``.

        Assumes ``x`` is ``(batch, in_dim)`` — the normalisation axis is fixed to 1.
        """
        return F.normalize(self.net(x), p=2, dim=1)

# Head used in this notebook: 768 -> 512 -> 256 -> 128, i.e. 128-d embeddings.
# NOTE: ProjectionHead accepts `dropout` but never creates a dropout layer,
# so the value passed here has no effect.
proj_head = ProjectionHead(
    in_dim=768,
    hidden_dims=[512, 256, 128],
    dropout=0.1
).to(device)

In [None]:
class EuclideanDistance(nn.Module):
    """Row-wise Euclidean distance with a floor on the squared distance.

    The squared distance is clamped to at least ``eps`` before the square
    root is taken, so the result is never smaller than ``sqrt(eps)``
    (presumably to keep the square-root gradient finite for identical inputs).
    """

    def __init__(self, eps: float = 1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the distance reduced over ``dim=1``, keeping that dim with size 1."""
        squared = ((x - y) ** 2).sum(dim=1, keepdim=True)
        return squared.clamp(min=self.eps).sqrt()

In [None]:
class SiameseViT(nn.Module):
    """Siamese network: one shared embedder + head applied to both inputs.

    ``forward`` returns the distance between the two resulting embeddings as
    computed by an ``EuclideanDistance`` instance.
    """

    def __init__(self, embedder, head):
        super().__init__()
        self.embedder = embedder
        self.head = head
        self.distance = EuclideanDistance()

    def forward_once(self, img):
        """Embed one batch: embedder first, then the projection head."""
        return self.head(self.embedder(img))

    def forward(self, img_a, img_b):
        """Embed both batches with the shared weights and return their distance."""
        first = self.forward_once(img_a)
        second = self.forward_once(img_b)
        return self.distance(first, second)

# Assemble the Siamese model from the shared embedder and projection head.
siamese_model = SiameseViT(vit_embedder, proj_head).to(device)

In [None]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pair distances.

    With ``d`` the distance and ``y`` the label, the per-pair loss is
    ``(1 - y) * d**2 + y * max(margin - d, 0)**2``; the result is the mean over
    all elements. Hence ``y = 0`` penalises any distance, while ``y = 1``
    penalises distances that fall short of ``margin``.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, dist, label):
        """Return the scalar mean loss; ``label`` is reshaped to a column first."""
        y = label.view(-1, 1)
        attract = (1.0 - y) * dist.pow(2)
        repel = y * (self.margin - dist).clamp(min=0.0).pow(2)
        return (attract + repel).mean()

# Contrastive loss with margin 1.0; not referenced by the evaluation cells visible here.
criterion = ContrastiveLoss(margin=1.0).to(device)

## Averaging Section


In [None]:
def load_and_average_textures(folder_path, model, device):
    """Embed every ``.pt`` tensor in a folder and return the mean embedding.

    Files are processed one at a time in sorted name order. A leading batch
    dimension is added to each loaded tensor before it is passed to
    ``model.forward_once``, and removed again from the resulting embedding.

    Args:
        folder_path: Directory containing the ``.pt`` files.
        model: Object exposing ``forward_once(tensor)``. This function does not
            switch it to eval mode; the caller is responsible for that.
        device: Device onto which tensors are loaded and embedded.

    Returns:
        The element-wise mean of the per-file embeddings.

    Raises:
        RuntimeError: If ``folder_path`` contains no ``.pt`` file.
    """
    texture_files = sorted(f for f in os.listdir(folder_path) if f.endswith(".pt"))
    if not texture_files:
        # Report the offending path instead of torch.stack's generic
        # "expects a non-empty TensorList" error.
        raise RuntimeError(f"No .pt texture files found in {folder_path!r}")

    embeddings = []
    with torch.no_grad():
        for file_name in texture_files:
            # SECURITY: torch.load unpickles the file — only use texture
            # archives from a trusted source.
            # map_location lets tensors that were saved on a GPU be read on a
            # CPU-only machine.
            tensor = torch.load(os.path.join(folder_path, file_name), map_location=device)
            emb = model.forward_once(tensor.unsqueeze(0).to(device))
            embeddings.append(emb.squeeze(0))
    return torch.stack(embeddings).mean(dim=0)


In [None]:
def collect_scores(model, loader, device, texture_root):
    """Compute one distance score per pair yielded by ``loader``.

    For every pair, both sample folders are embedded with
    ``load_and_average_textures`` and the L2 norm of the difference of the two
    averaged embeddings is recorded as the score.

    Args:
        model: Siamese model; it is switched to eval mode here.
        loader: Iterable yielding ``(_, _, label, name1_list, name2_list)``
            batches; the first two entries are ignored.
        device: Device used for the labels and for embedding the textures.
        texture_root: Root directory that holds the per-sample texture folders.

    Returns:
        ``(scores, labels, pairs)`` where ``scores`` and ``labels`` are 1-D
        numpy arrays (labels cast to ``int``) and ``pairs`` is a list of
        ``(name1, name2)`` tuples in the same order. All three are empty when
        the loader yields nothing.
    """
    model.eval()
    all_scores = []
    all_labels = []
    all_pairs = []
    # The same sample folder can occur in many pairs; embed each folder once
    # and reuse the result instead of re-running the model for every pair.
    embedding_cache = {}

    def embed(folder):
        if folder not in embedding_cache:
            embedding_cache[folder] = load_and_average_textures(folder, model, device)
        return embedding_cache[folder]

    with torch.no_grad():
        for _, _, label, name1_list, name2_list in tqdm(loader, desc="Collecting Scores"):
            label = label.to(device)

            for name1, name2, lbl in zip(name1_list, name2_list, label):
                # Folder layout built here: <texture_root>/<prefix>/<name>,
                # where <prefix> is the part of the name before the first "_".
                # NOTE(review): an earlier comment described the names as
                # "paths like W001/S01_N/W001_S01_N", which does not match
                # this construction — confirm the actual name format.
                folder1 = os.path.join(texture_root, name1.split('_')[0], name1)
                folder2 = os.path.join(texture_root, name2.split('_')[0], name2)

                dist = torch.norm(embed(folder1) - embed(folder2), p=2).unsqueeze(0)

                all_scores.append(dist)
                all_labels.append(lbl.view(-1))
                all_pairs.append((name1, name2))

    if not all_scores:
        # torch.cat rejects an empty list; an empty loader yields empty results.
        return np.empty(0, dtype=np.float32), np.empty(0, dtype=int), all_pairs

    all_scores = torch.cat(all_scores).cpu().numpy()
    all_labels = torch.cat(all_labels).cpu().numpy().astype(int)

    return all_scores, all_labels, all_pairs


## Evaluating Section

In [None]:
# Download the checkpoint that is later loaded into the Siamese model
# and store it as best_model.pt.
file_id = "1ASQuT-LU5H1f_ydTrytQ-f5k8l-C7-yw"
gdown.download(f"https://drive.google.com/uc?id={file_id}", "best_model.pt", quiet=False)

In [None]:
# Download the test pair table (parquet).
# NOTE(review): "file_id" is a placeholder — replace it with the real Google
# Drive id before running this cell.
file_id = "file_id"
gdown.download(f"https://drive.google.com/uc?id={file_id}", "iam_test.parquet", quiet=False)

In [None]:
# Rebuild the model around the same embedder/head objects and restore the
# downloaded weights into it.
# SECURITY NOTE(review): torch.load unpickles best_model.pt, which was fetched
# from the network — only load checkpoints from a trusted source (consider
# passing weights_only=True).
siamese_model = SiameseViT(vit_embedder, proj_head).to(device)
siamese_model.load_state_dict(torch.load("best_model.pt", map_location=device))

In [None]:
# Number of pairs per batch. Dataset items carry only a label, two sample
# names and zero placeholders, so no image data is batched here.
BATCH_SIZE_VAL = 512

test_dataset = LoadDataset(
    parquet_path="iam_test.parquet"
)

# shuffle=False: evaluation metrics do not depend on the order, and a fixed
# order keeps the returned scores / sample pairs reproducible across runs.
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE_VAL,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

In [None]:
# Distance cut-offs applied to the test scores: a pair is predicted as label 1
# when its distance is >= the threshold.
# NOTE(review): the values are presumably tuned on a validation split
# ('eer' = equal-error-rate point, 'f1' = best F1; the meaning of 'bf' is not
# visible in this notebook) — confirm their provenance.
thresholds = {
    'eer': 0.7371158,
    'f1': 0.72520995,
    'bf': 0.7593953
}

In [None]:
# Score every test pair; texture folders are looked up under "texture_iam".
test_scores, test_labels, sample_pairs = collect_scores(siamese_model, test_loader, device, "texture_iam")

def compute_far_frr(y_true, y_pred):
    """Return ``(FAR, FRR)`` for binary labels, treating label 0 as the "accept" class.

    Label 1 marks a pair that should be rejected and label 0 a pair that should
    be accepted, so:

    * FAR = (true 1 predicted 0) / (all true 1) — wrongly accepted pairs.
    * FRR = (true 0 predicted 1) / (all true 0) — wrongly rejected pairs.

    The counts are taken directly with numpy (same cell assignment as a
    confusion matrix ordered ``labels=[1, 0]``), so the function needs no
    import beyond ``np``.

    Args:
        y_true: Array-like of ground-truth labels in {0, 1}.
        y_pred: Array-like of predicted labels in {0, 1}, same length.

    Returns:
        ``(far, frr)`` as floats; a rate is ``0.0`` when its class is absent.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    should_reject = y_true == 1
    should_accept = y_true == 0
    rejected = y_pred == 1
    accepted = y_pred == 0

    tn = int(np.count_nonzero(should_reject & rejected))
    fp = int(np.count_nonzero(should_reject & accepted))  # false accept
    fn = int(np.count_nonzero(should_accept & rejected))  # false reject
    tp = int(np.count_nonzero(should_accept & accepted))

    far = fp / (fp + tn) if (fp + tn) > 0 else 0.0
    frr = fn / (fn + tp) if (fn + tp) > 0 else 0.0
    return far, frr

# ROC-AUC and EER depend only on the raw scores, not on the cut-off under
# test, so they are computed once here instead of once per threshold.
auc = roc_auc_score(test_labels, test_scores)
# `thresholdss` (double "s") keeps the ROC thresholds from clobbering the
# `thresholds` dict that is iterated below.
fpr, tpr, thresholdss = roc_curve(test_labels, test_scores)
fnr = 1 - tpr
# EER: ROC operating point where FPR and FNR are closest; report their mean.
eer_threshold_idx = np.argmin(np.abs(fpr - fnr))
eer = (fpr[eer_threshold_idx] + fnr[eer_threshold_idx]) / 2
eer_threshold = thresholdss[eer_threshold_idx]

for name, thr in thresholds.items():
    # Distances at or above the cut-off are predicted as label 1.
    preds = (test_scores >= thr).astype(int)
    acc = accuracy_score(test_labels, preds)
    prec, rec, f1, _ = precision_recall_fscore_support(test_labels, preds, average="binary")
    far, frr = compute_far_frr(test_labels, preds)

    print(f"\nThreshold used     : {name}")
    print(f"Threshold value    : {thr:.3f}")
    print(f"ROC-AUC            : {auc:.4f}")
    print(f"Accuracy           : {acc:.4f}")
    print(f"Precision          : {prec:.4f}")
    print(f"Recall             : {rec:.4f}")
    print(f"F1-score           : {f1:.4f}")
    print(f"EER                : {eer:.4f}")
    print(f"FAR                : {far:.4f}")
    print(f"FRR                : {frr:.4f}")