In [None]:
import numpy as np
import pandas as pd
import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
from tqdm import tqdm
from backbones import get_model
import torch.nn.functional as F

# Load pre-trained model.
# The checkpoint is deserialised onto the CPU first; the backbone weights are
# read from its "state_dict_backbone" entry.
data = torch.load("/l/users/sarim.hashmi/Thesis/hackathon/PetFace/outputs/cat/arcface_with_cutout/model_last.pt", map_location=torch.device('cpu'))

# Place the backbone on the GPU only when one is present. An unconditional
# .cuda() would crash on CPU-only hosts before any later device selection runs.
backbone = get_model("r50", dropout=0.0, fp16=True, num_features=768)
backbone = backbone.to("cuda" if torch.cuda.is_available() else "cpu")
backbone.load_state_dict(data["state_dict_backbone"])

class CSVDataset(Dataset):
    """Image dataset backed by one or more CSV files with a ``filename`` column.

    Labelled mode (``is_test=False``): every CSV must also have a ``label``
    column. The distinct labels across all CSVs are sorted and re-encoded as
    contiguous indices ``0..C-1``. ``classes[i]`` is the original label for
    index ``i`` and ``class_to_idx`` is the inverse mapping. Samples hold the
    encoded index.

    Test mode (``is_test=True``): the ``label`` column is not read. Each
    sample's "label" is its ``filename`` string, so predictions can be matched
    back to files. ``classes`` and ``class_to_idx`` are empty.

    Args:
        csv_files: iterable of CSV paths, read in order.
        root_dir: directory joined in front of every ``filename``.
        transform: optional callable applied to the loaded RGB PIL image.
        is_test: selects test mode (see above).
        val_root: when a sample path contains a ``/val/`` component,
            everything up to and including its first occurrence is replaced
            by this directory at load time (the validation images live
            elsewhere). Pass ``None`` to disable the remapping.
    """

    def __init__(self, csv_files, root_dir, transform=None, is_test=False,
                 val_root="/l/users/sarim.hashmi/Thesis/hackathon/val/val"):
        self.transform = transform
        self.is_test = is_test
        self.val_root = val_root

        raw = []
        for csv_file in csv_files:
            df = pd.read_csv(csv_file)
            # Column-wise access is far faster than DataFrame.iterrows().
            filenames = df["filename"].tolist()
            labels = filenames if is_test else df["label"].tolist()
            raw.extend(
                (os.path.join(root_dir, name), label)
                for name, label in zip(filenames, labels)
            )

        if is_test:
            self.samples = raw
            self.class_to_idx = {}
            self.classes = []
        else:
            self.classes = sorted({label for _, label in raw})
            self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
            self.samples = [(path, self.class_to_idx[label]) for path, label in raw]

    def __len__(self):
        return len(self.samples)

    def _resolve_path(self, path):
        """Redirect a path containing ``/val/`` into ``self.val_root``."""
        if self.val_root is not None and "/val/" in path:
            # Split on the first occurrence only, so any later ".../val/..."
            # component of the relative part is preserved.
            return os.path.join(self.val_root, path.split("/val/", 1)[1])
        return path

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        # The context manager closes the file handle deterministically.
        # convert() returns an independent image, so it is safe after close.
        with Image.open(self._resolve_path(path)) as im:
            img = im.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label


def get_dataloaders(train_csv, val_csv, root_dir=".", transform=None,
                    seed=42, train_ratio=0.8, batch_size=32, num_workers=4):
    """Build train/validation loaders from the union of two labelled CSVs.

    Both CSVs are pooled into a single CSVDataset, which is then split at
    random into ``int(train_ratio * N)`` training samples and the remainder
    for validation. The split is reproducible through ``seed``.

    Returns:
        (train_loader, val_loader, class_names). ``class_names`` is the pooled
        dataset's ``classes`` list; position ``i`` holds the original label
        that was encoded as index ``i``.
    """
    print("Building dataset...")
    pooled = CSVDataset(csv_files=[train_csv, val_csv], root_dir=root_dir,
                        transform=transform, is_test=False)

    n_total = len(pooled)
    n_train = int(train_ratio * n_total)
    n_val = n_total - n_train
    print(f"Splitting dataset: {n_train} train, {n_val} validation samples")

    rng = torch.Generator().manual_seed(seed)
    train_subset, val_subset = random_split(pooled, [n_train, n_val], generator=rng)

    print("Creating data loaders...")
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
    train_loader = DataLoader(train_subset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_subset, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, pooled.classes


def get_test_loader(test_csv, root_dir=".", transform=None, batch_size=32, num_workers=4):
    """Build an unshuffled loader over an unlabelled test CSV.

    The dataset is created in test mode, so each item's second element is its
    ``filename`` rather than a class label. Order follows the CSV because the
    loader does not shuffle.

    Returns:
        test_loader
    """
    print("Building test dataset...")
    dataset = CSVDataset(csv_files=[test_csv], root_dir=root_dir,
                         transform=transform, is_test=True)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False,
                      num_workers=num_workers)


def compute_class_means(model, dataloader, device):
    """
    Computes the mean feature vector for each class in `dataloader`.

    Args:
        model: feature extractor, called as ``model(imgs)`` and switched to
            eval mode. Its output is moved to the CPU, and its first dimension
            must index the batch.
        dataloader: iterable of ``(imgs, labels)`` batches whose labels are
            integer class ids (a tensor, or a sequence convertible to one).
        device: device the images are moved to before the forward pass.

    Returns:
        dict mapping ``int`` label -> mean feature tensor (on the CPU, same
        dtype as the model output). Only labels that actually occur get an
        entry; an empty loader yields an empty dict.
    """
    model.eval()
    sum_feats = {}
    cnt = {}

    print("Computing class means...")
    with torch.no_grad():
        for imgs, labels in tqdm(dataloader, desc="Processing batches"):
            feats = model(imgs.to(device)).cpu()
            labels = torch.as_tensor(labels).cpu()

            # Sum features per label with one vectorised scatter-add per
            # batch, instead of one Python-level tensor op per sample.
            uniq, inverse = torch.unique(labels, return_inverse=True)
            batch_sums = feats.new_zeros((uniq.numel(), *feats.shape[1:]))
            batch_sums.index_add_(0, inverse, feats)
            batch_counts = torch.bincount(inverse, minlength=uniq.numel())

            for lbl, lbl_sum, lbl_count in zip(uniq.tolist(), batch_sums, batch_counts.tolist()):
                lbl = int(lbl)
                if lbl in sum_feats:
                    sum_feats[lbl] += lbl_sum
                    cnt[lbl] += lbl_count
                else:
                    # clone() so the stored running sum does not keep the whole
                    # per-batch buffer alive through a view.
                    sum_feats[lbl] = lbl_sum.clone()
                    cnt[lbl] = lbl_count

    return {lbl: sum_feats[lbl] / cnt[lbl] for lbl in sum_feats}


def predict_with_ncm(model, dataloader, class_means, device, metric="euclidean", topk=3):
    """
    Predict labels on a dataset using Nearest Class Mean.

    Each sample's feature vector is compared against every class mean, and
    the ``topk`` best-matching classes are returned, best first.

    Args:
        model: feature extractor called as ``model(img)`` and switched to eval
            mode; its output is moved to the CPU. Assumed to be (B, D), as
            required by ``torch.cdist`` and the matrix product below.
        dataloader: iterable of ``(img, target)`` batches. Targets are passed
            through untouched (e.g. filenames for a test-mode loader).
        class_means: dict label -> mean feature tensor. All tensors must share
            a shape so they can be stacked.
        device: device the images are moved to.
        metric: "euclidean" (smallest L2 distance wins) or "cosine" (largest
            cosine similarity wins).
        topk: number of ranked predictions per sample.

    Returns:
        (top1_preds, topk_preds, targets): lists aligned by sample.
        ``topk_preds[i]`` is a list of ``topk`` keys of ``class_means``, and
        ``top1_preds[i]`` is its first element.

    Raises:
        ValueError: if ``metric`` is unsupported or ``class_means`` is empty.
    """
    # Validate before any forward pass so a typo fails fast, even on an
    # empty loader.
    if metric not in ("euclidean", "cosine"):
        raise ValueError("Unsupported metric")
    if not class_means:
        raise ValueError("class_means is empty; nothing to compare against")

    model.eval()
    top1_preds, topk_preds, targets = [], [], []

    # Stack class means (C, D)
    labels_sorted = sorted(class_means.keys())
    mean_matrix = torch.stack([class_means[lbl] for lbl in labels_sorted])  # (C, D)
    print(f"Class mean matrix shape: {mean_matrix.size()}")
    # Loop-invariant: normalise the class means once rather than once per batch.
    means_norm = F.normalize(mean_matrix, dim=1) if metric == "cosine" else None

    with torch.no_grad():
        for img, target_label in tqdm(dataloader, desc="Predicting"):
            feats = model(img.to(device)).cpu()  # (B, D)

            if metric == "euclidean":
                dists = torch.cdist(feats, mean_matrix, p=2)  # (B, C)
                topk_idx = torch.topk(dists, topk, largest=False).indices  # (B, k)
            else:  # cosine
                sims = F.normalize(feats, dim=1) @ means_norm.T  # (B, C)
                topk_idx = torch.topk(sims, topk, largest=True).indices  # (B, k)

            # Map column indices of the mean matrix back to class_means keys.
            for idx_row, tgt in zip(topk_idx.tolist(), target_label):
                pred_labels = [labels_sorted[j] for j in idx_row]
                top1_preds.append(pred_labels[0])
                topk_preds.append(pred_labels)
                targets.append(tgt)

    return top1_preds, topk_preds, targets


# Root of the hackathon data; every dataset path below is derived from it.
DATA_ROOT = "/l/users/sarim.hashmi/Thesis/hackathon"
BATCH_SIZE = 64
NUM_WORKERS = 8

# Define transforms: tensor conversion plus ImageNet mean/std normalisation,
# with no resizing or augmentation.
# NOTE(review): confirm these normalisation stats and the on-disk image size
# match what the backbone was trained with; a mismatch silently degrades the
# embeddings.
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])

# Create dataloaders (train.csv + val.csv pooled, then split 90/10 with seed 123).
print("Setting up dataloaders...")
train_loader, val_loader, classes = get_dataloaders(
    train_csv=f"{DATA_ROOT}/train.csv",
    val_csv=f"{DATA_ROOT}/val.csv",
    root_dir=f"{DATA_ROOT}/PetFace/data/PetFace/images/",
    transform=transform,
    seed=123,
    train_ratio=0.9,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# Unlabelled test set: items carry their filename in place of a label.
test_loader = get_test_loader(
    test_csv=f"{DATA_ROOT}/test/test.csv",
    root_dir=f"{DATA_ROOT}/test/",
    transform=transform,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = backbone.to(device)

# Mean embedding per (encoded) class over the training split.
class_means = compute_class_means(model, train_loader, device)

# Sanity check: print the norm of the first five class means.
for shown, (lbl, mean_vec) in enumerate(class_means.items()):
    if shown == 5:
        break
    print(f"Class {lbl}: mean feature norm = {mean_vec.norm().item():.3f}")

# Persist the means so a later run can skip the feature pass.
torch.save(class_means, "class_means2.pt")
print("Saved class means to class_means2.pt")

# Reload from disk; keep only this line to reuse a previous run's means.
class_means = torch.load("class_means2.pt")

# Number of ranked guesses written per test image; it drives both the
# prediction call and the output columns, so they cannot drift apart.
TOP_K = 3

# Make predictions
print("Making predictions...")
pred, topk, target = predict_with_ncm(backbone, test_loader, class_means, device,
                                      metric="cosine", topk=TOP_K)

# class_means is keyed by the contiguous indices produced by
# CSVDataset.class_to_idx, so the raw predictions are encoded indices.
# Map each one back to the original CSV label via `classes`.
topk_labels = [[classes[idx] for idx in row] for row in topk]

# Create output dataframe
print("Saving predictions...")
df = pd.DataFrame(topk_labels, columns=[f"label_{i}" for i in range(1, TOP_K + 1)])
df.insert(0, "filename", target)  # test-mode targets are the image filenames

# Save to CSV with index
df.index.name = "#"
df.to_csv("top3_predictions_new.csv", index=True)
print("Done! Predictions saved to top3_predictions_new.csv")