In [None]:
!pip install torch torchvision transformers
!pip install torchmetrics

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import ViTModel, ViTConfig
import pandas as pd
import os
import numpy as np
import gdown
import zipfile
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, roc_curve
from sklearn.metrics import confusion_matrix

BASE_DIR = "Data_V9_ViT"
scaler = torch.amp.GradScaler('cuda')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

In [None]:
# Fetch the dataset archive and unpack it next to the notebook.
# NOTE: "file_id" is a placeholder; replace it with the real Google Drive id.
file_id = "file_id"
dataset_zip = "Data_V9_ViT.zip"

# Only download when the archive is not on disk yet, so that re-running the
# notebook does not fetch it again.
if not os.path.exists(dataset_zip):
    gdown.download(f"https://drive.google.com/uc?id={file_id}", dataset_zip, quiet=False)

# An empty target path extracts relative to the current working directory.
with zipfile.ZipFile(dataset_zip, 'r') as zip_ref:
    zip_ref.extractall("")

## ViT Model

In [None]:
class LoadDataset(Dataset):
    """Pair dataset backed by a parquet index and per-sample ``.pt`` tensor files.

    Every row of the parquet file must provide the columns ``sample_1``,
    ``sample_2`` and ``label``. The two sample columns hold paths relative to
    ``base_dir`` *without* the ``.pt`` extension.
    """

    def __init__(self, parquet_path, base_dir=BASE_DIR):
        """Read the pair index.

        Args:
            parquet_path: Path of the parquet file that lists the pairs.
            base_dir: Directory the sample paths are relative to; trailing
                "/" characters are stripped.
        """
        super().__init__()
        self.df = pd.read_parquet(parquet_path)
        self.base_dir = base_dir.rstrip("/")

    def __len__(self):
        """Number of pairs (rows of the parquet index)."""
        return len(self.df)

    def _load_tensor(self, rel_path):
        """Load ``<base_dir>/<rel_path>.pt``.

        Raises:
            ValueError: If the file cannot be loaded for any reason. The
                original exception is attached as ``__cause__``.
        """
        path = os.path.join(self.base_dir, rel_path + ".pt")
        try:
            return torch.load(path)
        except Exception as e:
            # Chain explicitly so the root cause stays visible in the traceback.
            raise ValueError(f"Could not load tensor at {path}: {e}") from e

    def __getitem__(self, idx):
        """Return ``(img1, img2, label, sample_1, sample_2)`` for row ``idx``.

        ``img1`` / ``img2`` are whatever objects were stored in the two
        ``.pt`` files, ``label`` is a float32 tensor of shape ``(1,)`` and the
        last two items are the raw ``sample_1`` / ``sample_2`` column values.

        Raises:
            ValueError: If either tensor file cannot be loaded.
        """
        row = self.df.iloc[idx]

        img1_tensor = self._load_tensor(row["sample_1"])
        img2_tensor = self._load_tensor(row["sample_2"])

        label = float(row["label"])
        label_tensor = torch.tensor(label, dtype=torch.float32).unsqueeze(0)

        return img1_tensor, img2_tensor, label_tensor, row["sample_1"], row["sample_2"]

In [None]:
# Load the pretrained ViT encoder and move it to the selected device.
vit_name = "google/vit-base-patch16-384"
config = ViTConfig.from_pretrained(vit_name)
vit_backbone = ViTModel.from_pretrained(vit_name, config=config).to(device)

# Freeze the embedding module and the first six encoder layers; every other
# parameter keeps its current requires_grad setting.
frozen_modules = [vit_backbone.embeddings, *vit_backbone.encoder.layer[:6]]
for frozen_module in frozen_modules:
    for param in frozen_module.parameters():
        param.requires_grad = False

In [None]:
class ViTEmbedder(nn.Module):
    """Thin wrapper that exposes a ViT model's pooled output as the embedding."""

    def __init__(self, vit_model):
        """Store ``vit_model``; it is called with ``pixel_values=`` in ``forward``."""
        super().__init__()
        self.vit = vit_model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its ``pooler_output``."""
        return self.vit(pixel_values=x).pooler_output

# Wrap the (partially frozen) backbone and place the wrapper on the selected device.
vit_embedder = ViTEmbedder(vit_backbone).to(device)

In [None]:
class ProjectionHead(nn.Module):
    """MLP that maps an embedding to an L2-normalised vector.

    For every width in ``hidden_dims`` the network stacks
    ``Linear -> LayerNorm -> GELU``; the result of the last stage is
    normalised to unit L2 norm along ``dim=1``.

    Args:
        in_dim: Size of the input features.
        hidden_dims: Output width of each stage; the last entry is the size of
            the returned embedding.
        dropout: Accepted for interface compatibility but NOT used: no dropout
            layer is created. Adding one inside ``self.net`` would shift the
            ``nn.Sequential`` indices and change the state-dict keys.
    """

    def __init__(self, in_dim=768, hidden_dims=[1024, 512, 256], dropout=0.1):
        super().__init__()
        widths = [in_dim, *hidden_dims]
        stages = []
        # Walk consecutive (input width, output width) pairs.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend((nn.Linear(fan_in, fan_out), nn.LayerNorm(fan_out), nn.GELU()))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the unit-norm projection of ``x`` (normalised over ``dim=1``)."""
        return F.normalize(self.net(x), p=2, dim=1)

# Projection head used by the Siamese model: 768 -> 512 -> 256 -> 128, so the
# embeddings it returns are 128-dimensional.
proj_head = ProjectionHead(
    in_dim=768,
    hidden_dims=[512, 256, 128],
    dropout=0.1  # accepted by ProjectionHead but not applied by it
).to(device)

In [None]:
class EuclideanDistance(nn.Module):
    """Row-wise Euclidean distance between two batches of vectors."""

    def __init__(self, eps: float = 1e-7):
        """Store ``eps``, the lower bound applied to the squared distance."""
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``sqrt(max(sum((x - y)**2, dim=1), eps))`` with ``dim=1`` kept."""
        squared_norm = torch.sum((x - y) ** 2, dim=1, keepdim=True)
        # The clamp keeps the sqrt argument at least eps, so identical rows
        # yield sqrt(eps) rather than exactly 0.
        return squared_norm.clamp(min=self.eps).sqrt()

In [None]:
class SiameseViT(nn.Module):
    """Siamese network: one shared embedder + head, compared by Euclidean distance."""

    def __init__(self, embedder, head):
        """
        Args:
            embedder: Module that maps an image batch to feature vectors.
            head: Module that maps those feature vectors to the final embedding.
        """
        super().__init__()
        self.embedder = embedder
        self.head = head
        self.distance = EuclideanDistance()

    def forward_once(self, img):
        """Embed a single batch: ``head(embedder(img))``."""
        return self.head(self.embedder(img))

    def forward(self, img_a, img_b):
        """Embed both batches with the shared weights and return their distance."""
        left = self.forward_once(img_a)
        right = self.forward_once(img_b)
        return self.distance(left, right)

# Assemble the Siamese network from the shared embedder and projection head.
siamese_model = SiameseViT(vit_embedder, proj_head).to(device)

In [None]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss on pair distances.

    With this formulation ``label == 0`` pairs are pulled together
    (``dist**2``) and ``label == 1`` pairs are pushed apart until they are at
    least ``margin`` away (``max(margin - dist, 0)**2``).
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, dist, label):
        """Return the mean loss over all pairs.

        ``label`` is reshaped to a column ``(-1, 1)`` before it is combined
        with ``dist``.
        """
        target = label.view(-1, 1)
        hinge = torch.clamp(self.margin - dist, min=0.0)
        per_pair = (1.0 - target) * dist.pow(2) + target * hinge.pow(2)
        return per_pair.mean()

# Contrastive loss for the Siamese model.
# NOTE: the name `criterion` is rebound to nn.BCEWithLogitsLoss in the PairNet section below.
criterion = ContrastiveLoss(margin=1.0).to(device)


## Two-Speed Distance Calculation


In [None]:
# Pair lists ("uni_<split>.parquet") and their Google Drive ids.
# NOTE: "file_id" is a placeholder; replace it with the real id of each file.
# Un-comment the train / val entries to fetch those pair lists as well.
uni_parquet_file_ids = {
    "uni_test.parquet": "file_id",
    # "uni_train.parquet": "file_id",
    # "uni_val.parquet": "file_id",
}

for parquet_name, file_id in uni_parquet_file_ids.items():
    # Skip files that are already on disk so re-running the cell is cheap.
    if os.path.exists(parquet_name):
        continue
    gdown.download(f"https://drive.google.com/uc?id={file_id}", parquet_name, quiet=False)

In [None]:
# Fetch the trained Siamese checkpoint (skipped when it is already on disk)
# and load its weights into a SiameseViT built from the embedder and head.
file_id = "1ASQuT-LU5H1f_ydTrytQ-f5k8l-C7-yw"
if not os.path.exists("best_model.pt"):
    gdown.download(f"https://drive.google.com/uc?id={file_id}", "best_model.pt", quiet=False)

siamese_model = SiameseViT(vit_embedder, proj_head).to(device)
# The checkpoint comes from the network: weights_only=True restricts
# unpickling to tensors and plain containers, which is all a state dict needs.
siamese_model.load_state_dict(
    torch.load("best_model.pt", map_location=device, weights_only=True)
)

In [None]:
def load_tensor(rel_path, base_dir="", target_device=None):
    """Load ``<base_dir>/<rel_path>.pt`` and return it with a leading batch axis.

    Args:
        rel_path: Path of the tensor file without the ``.pt`` extension.
        base_dir: Optional directory prepended to ``rel_path``. The default
            ``""`` leaves ``rel_path`` unchanged.
        target_device: Device the tensor is loaded onto. ``None`` (default)
            uses the notebook-level ``device``.

    Returns:
        The stored tensor with ``unsqueeze(0)`` applied, e.g. ``(C, H, W)``
        becomes ``(1, C, H, W)``, placed on the target device.
    """
    if target_device is None:
        target_device = device
    path = os.path.join(base_dir, rel_path + ".pt")
    # map_location lets tensors that were saved on a GPU load on a CPU-only host.
    return torch.load(path, map_location=target_device).unsqueeze(0).to(target_device)

In [None]:
# Embedding slots of one row; slot "1N" is read from column "sample_1_N", etc.
SAMPLE_SLOTS = ("1N", "1F", "2N", "2F")
# Slot pairs whose embedding distance becomes the feature column "d_<a>_<b>".
DISTANCE_PAIRS = (
    ("1N", "1F"), ("2N", "2F"), ("1N", "2N"),
    ("1N", "2F"), ("1F", "2N"), ("1F", "2F"),
)


def compute_dissimilarity_dataset(split):
    """Turn ``uni_<split>.parquet`` into a table of pairwise embedding distances.

    Each input row names four tensors (columns ``sample_1_N``, ``sample_1_F``,
    ``sample_2_N``, ``sample_2_F``). Every tensor is embedded with
    ``siamese_model.forward_once`` and the L2 norm of the difference is
    stored for each pair in ``DISTANCE_PAIRS``.

    The result is written to ``<split>_dissimilarity_dataset_vit.parquet`` and
    ``<split>_dissimilarity_dataset_vit.csv``.

    Args:
        split: Split name, e.g. ``"test"``.

    Returns:
        The distance DataFrame: six ``d_*`` columns, the four sample base
        names and ``label``.
    """
    pairs_df = pd.read_parquet(f"uni_{split}.parquet")
    sample_columns = {slot: f"sample_{slot[0]}_{slot[1]}" for slot in SAMPLE_SLOTS}
    records = []

    siamese_model.eval()  # Important: eval mode for inference

    with torch.no_grad():
        for _, row in tqdm(pairs_df.iterrows(), total=len(pairs_df), desc="Computing ViT embeddings"):
            embeddings = {
                slot: siamese_model.forward_once(load_tensor(row[column]))
                for slot, column in sample_columns.items()
            }

            # Compute distances
            record = {
                f"d_{a}_{b}": torch.norm(embeddings[a] - embeddings[b]).item()
                for a, b in DISTANCE_PAIRS
            }
            # Keep only the file name of each sample for later inspection.
            for column in sample_columns.values():
                record[column] = os.path.basename(row[column])
            record["label"] = row["label"]
            records.append(record)

    # Passing the column list keeps the schema stable even for an empty split.
    column_order = (
        [f"d_{a}_{b}" for a, b in DISTANCE_PAIRS]
        + list(sample_columns.values())
        + ["label"]
    )
    result = pd.DataFrame(records, columns=column_order)
    result.to_parquet(f"{split}_dissimilarity_dataset_vit.parquet", index=False)
    result.to_csv(f"{split}_dissimilarity_dataset_vit.csv", index=False)
    return result


# The PairNet section reads train / val / test dissimilarity files, so build
# the train and val ones too whenever their pair lists have been downloaded.
for split in ("train", "val"):
    if os.path.exists(f"uni_{split}.parquet"):
        compute_dissimilarity_dataset(split)
    else:
        print(f"uni_{split}.parquet not found; skipping the {split} split.")

# The test split is always required.
dist_df = compute_dissimilarity_dataset("test")

## PairNet Model

In [None]:
# Distance features, in the order PairNet receives them.
feature_columns = ["d_1N_1F", "d_2N_2F", "d_1N_2N", "d_1N_2F", "d_1F_2N", "d_1F_2F"]

# Dissimilarity tables for the three splits.
df_train = pd.read_parquet("train_dissimilarity_dataset_vit.parquet")
df_test = pd.read_parquet("test_dissimilarity_dataset_vit.parquet")
df_val = pd.read_parquet("val_dissimilarity_dataset_vit.parquet")

# NumPy views of the features and labels.
X_train, y_train = df_train[feature_columns].values, df_train["label"].values
X_test, y_test = df_test[feature_columns].values, df_test["label"].values
X_val, y_val = df_val[feature_columns].values, df_val["label"].values

# float32 tensors; labels become column vectors of shape (N, 1).
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)
y_val_tensor = torch.tensor(y_val, dtype=torch.float32).unsqueeze(1)

# Only the training split is mini-batched; validation and test are scored in one pass.
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

class PairNet(nn.Module):
    """Small MLP (6 -> 4 -> 2 -> 1) over the six distance features.

    ``forward`` returns raw logits; no sigmoid is applied inside the network.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(6, 4)
        self.fc2 = nn.Linear(4, 2)
        self.out = nn.Linear(2, 1)

    def forward(self, x):
        """Map a ``(N, 6)`` batch to ``(N, 1)`` logits."""
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        return self.out(hidden)

# Fresh PairNet trained with binary cross-entropy on logits (PairNet.forward
# returns logits, so the sigmoid lives inside the loss).
# NOTE: `criterion` rebinds the name used earlier for the ContrastiveLoss instance.
model = PairNet()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def find_eer_threshold(y_true, y_probs):
    """Pick the ROC threshold where false-positive and false-negative rates meet.

    Args:
        y_true: Ground-truth binary labels.
        y_probs: Scores for the positive class.

    Returns:
        ``(threshold, eer)``: the ROC threshold at the point where
        ``|FPR - FNR|`` is smallest, and the mean of FPR and FNR at that point.
    """
    false_pos_rate, true_pos_rate, cutoffs = roc_curve(y_true, y_probs)
    false_neg_rate = 1 - true_pos_rate
    # Operating point on the ROC curve where the two error rates are closest.
    closest = np.argmin(np.abs(false_pos_rate - false_neg_rate))
    eer = (false_pos_rate[closest] + false_neg_rate[closest]) / 2
    return cutoffs[closest], eer

# Early-stopping state: the best validation loss seen so far, a snapshot of the
# weights that produced it and the EER threshold measured in that epoch.
best_val_loss = float("inf")
best_model_state = None
best_threshold = 0.5
patience = 40  # epochs without validation improvement before stopping
trigger_times = 0

for epoch in range(500):
    model.train()
    epoch_loss = 0.0  # sum of the per-batch losses (not averaged)
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        output = model(batch_X)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Validation
    model.eval()
    with torch.no_grad():
        val_output = model(X_val_tensor)
        val_loss = criterion(val_output, y_val_tensor)

        val_probs = torch.sigmoid(val_output).numpy().flatten()
        val_true = y_val_tensor.numpy().flatten()

        # EER-based threshold
        eer_thresh, eer_val = find_eer_threshold(val_true, val_probs)
        val_preds = (val_probs > eer_thresh).astype(int)

        val_acc = accuracy_score(val_true, val_preds)
        val_f1 = f1_score(val_true, val_preds, zero_division=0)
        val_prec = precision_score(val_true, val_preds, zero_division=0)
        val_rec = recall_score(val_true, val_preds, zero_division=0)

    print(f"Epoch {epoch+1}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss.item():.4f}, "
          f"Val Acc: {val_acc:.4f}, F1: {val_f1:.4f}, Precision: {val_prec:.4f}, Recall: {val_rec:.4f}, EER: {eer_val:.4f}")

    if val_loss.item() < best_val_loss:
        best_val_loss = val_loss.item()
        # state_dict() hands back references to the live parameter tensors,
        # which the optimizer keeps updating in place. Clone them, otherwise
        # the "best" snapshot silently turns into the last-epoch weights.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        best_threshold = eer_thresh
        trigger_times = 0
    else:
        trigger_times += 1
        if trigger_times >= patience:
            print("Early stopping triggered.")
            break

# Restore the best snapshot. It stays None only if no epoch ever improved on
# the initial infinite loss (e.g. the validation loss was NaN throughout).
if best_model_state is not None:
    model.load_state_dict(best_model_state)

model.eval()
with torch.no_grad():
    # Re-score validation with the restored weights so the diagnostics below
    # describe the model and threshold that are actually used.
    val_probs = torch.sigmoid(model(X_val_tensor)).numpy().flatten()

    logits = model(X_test_tensor)
    test_probs = torch.sigmoid(logits).numpy().flatten()
    test_preds = (test_probs > best_threshold).astype(int)
    test_true = y_test_tensor.numpy().flatten()

print("Min prob:", val_probs.min(), "Max prob:", val_probs.max(), "Threshold:", best_threshold)

# zero_division=0 matches the validation metrics above and avoids warnings
# when no sample is predicted positive.
test_metrics = {
    "Accuracy": accuracy_score(test_true, test_preds),
    "F1 Score": f1_score(test_true, test_preds, zero_division=0),
    "ROC AUC": roc_auc_score(test_true, test_probs),
    "Precision": precision_score(test_true, test_preds, zero_division=0),
    "Recall": recall_score(test_true, test_preds, zero_division=0),
    "EER Threshold": best_threshold
}

In [None]:
# Persist the trained PairNet together with its decision threshold.
torch.save({
    'model_state_dict': model.state_dict(),
    # The threshold comes from a NumPy array (roc_curve); store it as a plain
    # Python float so the file can also be read with torch.load(weights_only=True).
    'best_threshold': float(best_threshold),
}, "pairnet_model.pth")

In [None]:
# Download the published PairNet checkpoint.
# NOTE(review): this writes to "pairnet_model.pth", the same path the previous
# cell saves the locally trained model to, so that file is overwritten here.
file_id = "1HvTVqpN5PNSdDZkroQPp-8Yjg-tNRVdt"
gdown.download(f"https://drive.google.com/uc?id={file_id}", "pairnet_model.pth", quiet=False)
# SECURITY: weights_only=False runs the full pickle machinery on a file fetched
# from the network; only do this for checkpoints from a trusted source.
checkpoint = torch.load("pairnet_model.pth", weights_only=False)

In [None]:
# Rebuild PairNet from the checkpoint on disk.
model = PairNet()
# weights_only=False unpickles arbitrary objects; the file must be trusted.
checkpoint = torch.load("pairnet_model.pth", weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Decision threshold stored alongside the weights.
best_threshold = checkpoint["best_threshold"]

# Score the test set without tracking gradients.
with torch.no_grad():
    logits = model(X_test_tensor)

# Probabilities, thresholded predictions and ground truth as flat NumPy arrays.
test_probs = torch.sigmoid(logits).numpy().flatten()
test_preds = (test_probs > best_threshold).astype(int)
test_true = y_test_tensor.numpy().flatten()

def compute_far_frr(y_true, y_pred):
    """Return ``(FAR, FRR)`` for binary predictions.

    The confusion matrix is built with ``labels=[1, 0]``, i.e. rows and
    columns are ordered label 1 first, then label 0. As a consequence:

    * FAR = share of true-label-1 samples that were predicted 0.
    * FRR = share of true-label-0 samples that were predicted 1.

    Each rate is ``0.0`` when its denominator is zero.
    """
    counts = confusion_matrix(y_true, y_pred, labels=[1, 0])
    # Row 0: true label 1 -> (predicted 1, predicted 0)
    # Row 1: true label 0 -> (predicted 1, predicted 0)
    (tn, fp), (fn, tp) = counts
    far_denominator = fp + tn
    frr_denominator = fn + tp
    far = fp / far_denominator if far_denominator > 0 else 0.0
    frr = fn / frr_denominator if frr_denominator > 0 else 0.0
    return far, frr

# Error rates at the stored threshold.
far, frr = compute_far_frr(test_true, test_preds)

# Collect the test-set summary; keys are inserted in reporting order.
test_metrics = {}
test_metrics["Accuracy"] = accuracy_score(test_true, test_preds)
test_metrics["F1 Score"] = f1_score(test_true, test_preds)
test_metrics["ROC AUC"] = roc_auc_score(test_true, test_probs)
test_metrics["Precision"] = precision_score(test_true, test_preds)
test_metrics["Recall"] = recall_score(test_true, test_preds)
test_metrics["EER Threshold"] = best_threshold
test_metrics.update({"FAR": far, "FRR": frr})

print(test_metrics)

In [None]:
# Columns written to the result files.
columns_to_keep = [
    "sample_1_N", "sample_1_F", "sample_2_N", "sample_2_F",
    "probability", "prediction", "true_label"
]

# Attach the model outputs to the test table (df_test is extended in place).
df_test["probability"] = test_probs
df_test["prediction"] = test_preds
df_test["true_label"] = y_test
df_test_filtered = df_test.loc[:, columns_to_keep]

# Misclassified pairs.
is_wrong = df_test_filtered["prediction"].ne(df_test_filtered["true_label"])
df_wrong = df_test_filtered.loc[is_wrong]
df_wrong.to_csv("wrong_predictions.csv", index=False)
print(f"Saved {len(df_wrong)} wrong predictions.")

# NOTE(review): this keeps every row whose label-1 probability is below 0.7,
# which also includes confident label-0 predictions (probability near 0).
# Confirm that this is the intended meaning of "low confidence".
is_low_conf = df_test_filtered["probability"].lt(0.7)
df_low_conf = df_test_filtered.loc[is_low_conf]
df_low_conf.to_csv("low_confidence_predictions.csv", index=False)
print(f"Saved {len(df_low_conf)} low-confidence predictions (< 0.7).")

# Full result table.
df_test_filtered.to_csv("all_results.csv", index=False)
print(f"Saved {len(df_test_filtered)} all results")
