In [None]:
!pip install torch torchvision transformers
!pip install torchmetrics

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTModel, ViTConfig
import pandas as pd
import os
import numpy as np
from PIL import Image
import gdown
import zipfile
import matplotlib.pyplot as plt
from torchmetrics import AUROC
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

BASE_DIR = "Data_V9_ViT"

scaler = torch.amp.GradScaler('cuda')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

In [None]:
# NOTE(review): "file_id" is a placeholder — replace it with the Google Drive id of the dataset zip.
file_id = "file_id"
DATA_ZIP = "Data_V9_ViT.zip"

# Skip the (large) download / extraction when its result is already on disk, so the notebook
# can be re-run from the top without re-fetching the data.
# Assumes the archive's top-level folder is BASE_DIR — TODO confirm against the zip layout.
if not os.path.isdir(BASE_DIR):
    if not os.path.exists(DATA_ZIP):
        gdown.download(f"https://drive.google.com/uc?id={file_id}", DATA_ZIP, quiet=False)
    with zipfile.ZipFile(DATA_ZIP, 'r') as zip_ref:
        # "." is the current working directory (same destination as extractall("")).
        zip_ref.extractall(".")

In [None]:
class LoadDataset(Dataset):
    """Pair dataset backed by a parquet index of pre-computed ``.pt`` tensors.

    Each parquet row must provide the columns ``sample_1`` and ``sample_2`` (paths of tensor
    files relative to ``base_dir``, without the ``.pt`` suffix) and ``label``.

    ``__getitem__`` returns ``(img1, img2, label, sample_1, sample_2)``, where ``label`` is a
    float32 tensor of shape ``(1,)``.
    """

    def __init__(self, parquet_path, base_dir=BASE_DIR):
        super().__init__()
        self.df = pd.read_parquet(parquet_path)
        self.base_dir = base_dir.rstrip("/")

    def __len__(self):
        return len(self.df)

    def _pt_path(self, rel_path: str) -> str:
        """Path of the ``.pt`` file for one parquet entry."""
        return os.path.join(self.base_dir, rel_path + ".pt")

    def _load_tensor(self, rel_path: str) -> torch.Tensor:
        """Load one stored tensor onto the CPU, wrapping any failure in a ValueError.

        ``map_location="cpu"``: tensors saved from a GPU would otherwise be restored onto
        CUDA inside DataLoader worker processes, and ``pin_memory`` only accepts CPU tensors.
        ``weights_only=True``: the files come from a downloaded archive, so refuse arbitrary
        pickled objects (full unpickling can execute code).
        """
        path = self._pt_path(rel_path)
        try:
            return torch.load(path, map_location="cpu", weights_only=True)
        except Exception as e:
            # Chain the original exception so the real cause stays in the traceback.
            raise ValueError(f"Could not load tensor at {path}: {e}") from e

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        img1_tensor = self._load_tensor(row["sample_1"])
        img2_tensor = self._load_tensor(row["sample_2"])

        # Shape (1,) so the default collate yields a (batch, 1) label tensor.
        label_tensor = torch.tensor([float(row["label"])], dtype=torch.float32)

        return img1_tensor, img2_tensor, label_tensor, row["sample_1"], row["sample_2"]

In [None]:
BATCH_SIZE = 128
VAL_BATCH_SIZE = 1024

# Options shared by both loaders: 2 worker processes and pinned host memory.
_loader_opts = dict(num_workers=2, pin_memory=True)

train_dataset = LoadDataset(parquet_path=f"{BASE_DIR}/train.parquet")
val_dataset = LoadDataset(parquet_path=f"{BASE_DIR}/val.parquet")

# Training pairs are reshuffled every epoch; the validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False, **_loader_opts)

In [None]:
vit_name = "google/vit-base-patch16-384"
config = ViTConfig.from_pretrained(vit_name)
vit_backbone = ViTModel.from_pretrained(vit_name, config=config).to(device)

# Partial fine-tuning: freeze the embedding module and the first N_FROZEN_LAYERS encoder
# blocks. Every other backbone parameter keeps its requires_grad setting.
N_FROZEN_LAYERS = 6
frozen_modules = [vit_backbone.embeddings] + [
    vit_backbone.encoder.layer[idx] for idx in range(N_FROZEN_LAYERS)
]
for module in frozen_modules:
    module.requires_grad_(False)


In [None]:
class ViTEmbedder(nn.Module):
    """Thin wrapper mapping a batch of images to the backbone's pooled embedding.

    The wrapped model is called as ``vit_model(pixel_values=x)`` and its ``pooler_output``
    is returned unchanged.
    """

    def __init__(self, vit_model):
        super().__init__()
        # Attribute name is part of the state_dict keys ("vit.") — keep it stable.
        self.vit = vit_model

    def forward(self, x):
        return self.vit(pixel_values=x).pooler_output

# Wrap the (partially frozen) backbone; .to(device) keeps every submodule on the selected device.
vit_embedder = ViTEmbedder(vit_backbone).to(device)

In [None]:
class ProjectionHead(nn.Module):
    """MLP projecting backbone embeddings to L2-normalised vectors.

    Each entry of ``hidden_dims`` adds a ``Linear -> LayerNorm -> GELU`` block; the output of
    the last block is L2-normalised along dim 1, so rows have unit norm.

    Args:
        in_dim: size of the input features.
        hidden_dims: output sizes of the successive blocks; the last one is the embedding size.
        dropout: dropout probability applied after every block except the last one (active only
            in training mode; ``eval()`` disables it).
    """

    def __init__(self, in_dim=768, hidden_dims=(1024, 512, 256), dropout=0.1):
        super().__init__()
        layers = []
        prev = in_dim
        for h in hidden_dims:
            layers += [
                nn.Linear(prev, h),
                nn.LayerNorm(h),
                nn.GELU(),
            ]
            prev = h
        self.net = nn.Sequential(*layers)
        self.num_blocks = len(layers) // 3
        # Kept outside self.net on purpose: nn.Dropout has no parameters, but inserting it into
        # the Sequential would shift the layer indices in the state_dict keys (net.3.weight, ...)
        # and break loading of existing checkpoints.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        z = x
        blocks_done = 0
        for layer in self.net:
            z = layer(z)
            if isinstance(layer, nn.GELU):
                blocks_done += 1
                # No dropout on the final block: zeroing coordinates right before the L2
                # normalisation would distort the embedding the distance is computed on.
                if blocks_done < self.num_blocks:
                    z = self.dropout(z)
        return F.normalize(z, p=2, dim=1)

# 768 = hidden size of the ViT-Base backbone above; the head ends in 128-d unit vectors.
proj_head = ProjectionHead(in_dim=768, hidden_dims=[512, 256, 128], dropout=0.1)
proj_head = proj_head.to(device)

In [None]:
class EuclideanDistance(nn.Module):
    """Row-wise L2 distance between two batches of embeddings.

    For inputs of shape ``(batch, dim)`` the result has shape ``(batch, 1)``. The squared
    distance is floored at ``eps`` before the square root, so the output is never below
    ``sqrt(eps)`` and the gradient of ``sqrt`` stays finite for identical embeddings.
    """

    def __init__(self, eps: float = 1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        squared = torch.sum((x - y) ** 2, dim=1, keepdim=True)
        return squared.clamp(min=self.eps).sqrt()

In [None]:
class SiameseViT(nn.Module):
    """Siamese model: one shared embedder + projection head applied to both inputs,
    followed by the Euclidean distance between the two projected embeddings."""

    def __init__(self, embedder, head):
        super().__init__()
        # These attribute names are the state_dict key prefixes — keep them stable.
        self.embedder = embedder
        self.head = head
        self.distance = EuclideanDistance()

    def forward_once(self, img):
        """Embed one batch of images: backbone embedding, then projection head."""
        return self.head(self.embedder(img))

    def forward(self, img_a, img_b):
        """Distance between the embeddings of ``img_a`` and ``img_b`` (shared weights)."""
        emb_a, emb_b = self.forward_once(img_a), self.forward_once(img_b)
        return self.distance(emb_a, emb_b)

# Assemble the Siamese model from the shared embedder and projection head built in earlier cells.
siamese_model = SiameseViT(vit_embedder, proj_head).to(device)

In [None]:
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss on pair distances.

    ``label == 0`` marks a matching pair, whose squared distance is penalised.
    ``label == 1`` marks a non-matching pair, penalised only while its distance is below
    ``margin``. Returns the batch mean of
    ``(1 - y) * d**2 + y * max(margin - d, 0)**2``.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, dist, label):
        # Reshape labels into a column; assumes dist is (batch, 1) — TODO confirm for new callers.
        y = label.view(-1, 1)
        attract = (1.0 - y) * dist.pow(2)
        repel = y * (self.margin - dist).clamp(min=0.0).pow(2)
        return (attract + repel).mean()

# The loss holds no parameters (only `margin`), so .to(device) is effectively a no-op.
criterion = ContrastiveLoss(margin=1.0).to(device)


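## Evaluation

Download the trained checkpoint, load it into the Siamese model, and report test-set metrics at a few fixed distance thresholds.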

In [None]:
# Trained Siamese-model state_dict, hosted on Google Drive.
file_id = "1ASQuT-LU5H1f_ydTrytQ-f5k8l-C7-yw"
CHECKPOINT_PATH = "best_model.pt"

# Re-use a previously downloaded copy so re-running the notebook does not re-fetch it.
if not os.path.exists(CHECKPOINT_PATH):
    gdown.download(f"https://drive.google.com/uc?id={file_id}", CHECKPOINT_PATH, quiet=False)

In [None]:
siamese_model = SiameseViT(vit_embedder, proj_head).to(device)
# weights_only=True: the checkpoint is a downloaded file and a plain state_dict needs no
# arbitrary unpickling (full pickle loading can execute code embedded in the file).
siamese_model.load_state_dict(
    torch.load("best_model.pt", map_location=device, weights_only=True)
)

<All keys matched successfully>

In [None]:
def plot_histogram(dists, labels, set, out_path):
    """Overlay distance histograms for same-writer vs. different-writer pairs and save it.

    Args:
        dists: predicted distances, one per pair (array-like; presumably 1-D).
        labels: pair labels aligned with ``dists``; 0 = same writer, 1 = different writer.
        set: figure title. The parameter name is kept for backward compatibility; it only
            shadows the ``set`` builtin inside this function.
        out_path: file the figure is written to (300 dpi).
    """
    dists = np.asarray(dists)
    labels = np.asarray(labels)
    same_writer = dists[labels == 0]
    diff_writer = dists[labels == 1]

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(same_writer, bins=50, alpha=0.6, label="Same Writer")
    ax.hist(diff_writer, bins=50, alpha=0.6, label="Different Writer")
    ax.set_title(f"{set}")
    ax.set_xlabel("Euclidean distance")
    ax.set_ylabel("Count")
    ax.grid(True)
    # The series labels above are only shown through a legend; without it the two
    # overlapping histograms cannot be told apart.
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
BATCH_SIZE_VAL = 512

test_dataset = LoadDataset(
    parquet_path=BASE_DIR + "/uni_test.parquet"
)

# Evaluation loader: a fixed order (shuffle=False) makes the collected scores and sample
# pairs reproducible across runs; the metrics themselves do not depend on order.
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE_VAL,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

# Rebuild the validation loader with the evaluation batch size and worker count.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE_VAL,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

In [None]:
# Sanity checks: the model sits on the expected device and the loader is configured as intended.
siamese_model.to(device)
model_device = next(siamese_model.parameters()).device
print(model_device)

print(f"Loader num workers: {val_loader.num_workers}")
print(f"Batch size: {val_loader.batch_size}")

In [None]:
import numpy as np

def collect_scores(model, loader, device):
    """Run ``model`` over every pair in ``loader`` and gather scores, labels and sample ids.

    Args:
        model: Siamese model called as ``model(img1, img2)``; its output is flattened to one
            score per pair.
        loader: iterable of ``(img1, img2, label, name1, name2)`` batches, as produced by
            ``LoadDataset`` through a ``DataLoader``.
        device: device the images are moved to before the forward pass.

    Returns:
        ``(scores, labels, pairs)``: a numpy array of scores, an int numpy array of labels of
        the same length, and a list of ``(name1, name2)`` tuples in the same order. Empty
        arrays and an empty list when ``loader`` yields no batches.

    Note: switches ``model`` to eval mode and leaves it there.
    """
    model.eval()
    batch_scores = []
    batch_labels = []
    all_pairs = []
    with torch.no_grad():
        for img1, img2, label, name1, name2 in tqdm(loader, desc="Collecting Scores"):
            img1 = img1.to(device, non_blocking=True)
            img2 = img2.to(device, non_blocking=True)

            scores = model(img1, img2)
            # Bring results to the CPU per batch: accumulated outputs stay off the GPU and
            # labels never need to visit the device at all.
            batch_scores.append(scores.view(-1).cpu())
            batch_labels.append(label.view(-1).cpu())

            if isinstance(name1, str):
                all_pairs.append((name1, name2))
            else:
                all_pairs.extend(zip(name1, name2))

    if not batch_scores:
        # torch.cat([]) raises; return well-typed empty results instead.
        return np.empty(0, dtype=np.float32), np.empty(0, dtype=int), all_pairs

    all_scores = torch.cat(batch_scores).numpy()
    all_labels = torch.cat(batch_labels).numpy().astype(int)
    return all_scores, all_labels, all_pairs

In [None]:
# Decision thresholds on the predicted distance: a pair is predicted 1 ("different writer")
# when its distance >= threshold. NOTE(review): presumably chosen on the validation set
# ("eer" = equal-error-rate point, "f1" = best F1; meaning of "bf" not documented) — confirm.
thresholds = dict(
    eer=0.7371158,
    f1=0.72520995,
    bf=0.7593953,
)

In [None]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, roc_curve

test_scores, test_labels, sample_pairs = collect_scores(siamese_model, test_loader, device)
assert len(test_scores) == len(test_labels) == len(sample_pairs)

# Threshold-independent metrics, computed once instead of once per threshold.
# Scores are distances and label 1 = different writer, so a larger score means "positive".
auc = roc_auc_score(test_labels, test_scores)
fpr, tpr, thresholdss = roc_curve(test_labels, test_scores)
fnr = 1 - tpr
# EER: the ROC point where the false-positive and false-negative rates are closest.
eer_threshold_idx = np.argmin(np.abs(fpr - fnr))
eer = (fpr[eer_threshold_idx] + fnr[eer_threshold_idx]) / 2
eer_threshold = thresholdss[eer_threshold_idx]

# Optional: distance histogram for the test set.
# plot_histogram(test_scores, test_labels, '[Internal Dataset] Predicted Distances Histogram', 'test_hist.png')

for name, thr in thresholds.items():
    preds = (test_scores >= thr).astype(int)
    acc = accuracy_score(test_labels, preds)
    # zero_division=0 is the value sklearn falls back to anyway, minus the warning.
    prec, rec, f1, _ = precision_recall_fscore_support(
        test_labels, preds, average="binary", zero_division=0
    )
    print(f"\nThreshold used     : {name}")
    print(f"Threshold value    : {thr:.3f}")
    print(f"ROC-AUC            : {auc:.4f}")
    print(f"Accuracy           : {acc:.4f}")
    print(f"Precision          : {prec:.4f}")
    print(f"Recall             : {rec:.4f}")
    print(f"F1-score           : {f1:.4f}")
    print(f"EER                : {eer:.4f}")