In [54]:
import os
import random
from PIL import Image
from collections import defaultdict
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.neighbors import NearestNeighbors
import pickle


In [55]:

# ========== CONFIG ==========
DATA_DIR = 'tammathon-task-1\\train\\train\\'
BATCH_SIZE = 32
NUM_EPOCHS = 10
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
N = 100 # how many folders to use (for testing purposes you can set N = 100)



In [56]:
# ========== STEP 1: LOAD & SPLIT DATA ==========
def load_image_paths(data_dir):
    dirs = os.listdir(data_dir)
    label_names = sorted(dirs)[:N]
    print(len(label_names))
    
    label_to_idx = {label: idx for idx, label in enumerate(label_names)}
    idx_to_label = {idx: label for label, idx in label_to_idx.items()}

    train_data = []
    val_data = []

    for i, label_name in enumerate(label_names):
        label_path = os.path.join(data_dir, label_name)
        image_paths = [os.path.join(label_path, f) for f in sorted(os.listdir(label_path), reverse=False) if f.endswith('.png')]
        
        if len(image_paths) >= 3:
            train_data.extend((img, label_to_idx[label_name]) for img in image_paths[1:])
            val_data.append((image_paths[0], label_to_idx[label_name]))
        else:
            train_data.extend((img, label_to_idx[label_name]) for img in image_paths[:])
        
        if not i % 10000:
            print(i)

    return train_data, val_data, label_to_idx, idx_to_label


train_data, val_data, label_to_idx, idx_to_label = load_image_paths(DATA_DIR)
num_classes = len(label_to_idx)


100
0


In [57]:
# ========== STEP 2: DEFINE DATASET ==========

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

class CatDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

train_dataset = CatDataset(train_data, transform=transform)
val_dataset = CatDataset(val_data, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)


In [58]:
# ========== STEP 3: MODEL WITH EMBEDDINGS ==========

class EmbeddingModel(nn.Module):
    def __init__(self, embedding_dim=512):
        super().__init__()
        self.backbone = models.resnet18(pretrained=True)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # remove classification head
        self.embedding = nn.Linear(in_features, embedding_dim)

    def forward(self, x):
        x = self.backbone(x)
        x = self.embedding(x)
        return F.normalize(x, p=2, dim=1)  # important for cosine similarity

model = EmbeddingModel().to(DEVICE)






In [59]:
# ========== STEP 4: ARC FACE LOSS ==========

class ArcFaceLoss(nn.Module):
    def __init__(self, embedding_dim, num_classes, scale=30.0, margin=0.50):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_classes, embedding_dim))
        nn.init.xavier_uniform_(self.weight)
        self.scale = scale
        self.margin = margin
        self.num_classes = num_classes

    def forward(self, embeddings, labels):
        embeddings = F.normalize(embeddings, p=2, dim=1)
        W = F.normalize(self.weight, p=2, dim=1)

        cosine = F.linear(embeddings, W)
        theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
        target_logits = torch.cos(theta + self.margin)

        one_hot = F.one_hot(labels, num_classes=self.num_classes).float().to(DEVICE)
        output = cosine * (1 - one_hot) + target_logits * one_hot
        return F.cross_entropy(self.scale * output, labels)

loss_fn = ArcFaceLoss(embedding_dim=512, num_classes=num_classes).to(DEVICE)
optimizer = optim.Adam(list(model.parameters()) + list(loss_fn.parameters()), lr=1e-4)

In [61]:
# ========== STEP 5: TRAINING LOOP ==========

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        embeddings = model(images)
        loss = loss_fn(embeddings, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        
    result_vars = {
        'model': model,
        'running_loss': running_loss,
        'optimizer': optimizer,
        }
        
    with open(f'model_03_result_vars_epoch{epoch+1}.pkl', 'wb') as f:
        pickle.dump(result_vars, f)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {running_loss / len(train_loader):.4f}")


Epoch 1/10, Loss: 17.9793
Epoch 2/10, Loss: 14.7424
Epoch 3/10, Loss: 11.0001
Epoch 4/10, Loss: 6.7598
Epoch 5/10, Loss: 3.6038
Epoch 6/10, Loss: 1.4773
Epoch 7/10, Loss: 0.5082
Epoch 8/10, Loss: 0.1639
Epoch 9/10, Loss: 0.0783
Epoch 10/10, Loss: 0.0489


In [62]:
# ========== STEP 6: KNN VALIDATION ==========

model.eval()

# Get embeddings for train
train_embeddings = []
train_labels = []
with torch.no_grad():
    for images, labels in DataLoader(train_dataset, batch_size=BATCH_SIZE):
        images = images.to(DEVICE)
        emb = model(images)
        train_embeddings.append(emb.cpu())
        train_labels.extend(labels)

train_embeddings = torch.cat(train_embeddings).numpy()
train_labels = [idx_to_label[i.item()] for i in train_labels]

# Fit KNN (cosine = 1 - similarity)
knn = NearestNeighbors(n_neighbors=10, metric='cosine', algorithm='brute')
knn.fit(train_embeddings)

# Validation with enforced 3 unique nearest neighbors
top3_predictions = []
true_labels_str = []

with torch.no_grad():
    for image, label in val_loader:
        image = image.to(DEVICE)
        emb = model(image).cpu().numpy()
        distances, indices = knn.kneighbors(emb, n_neighbors=20)

        seen = set()
        preds = []
        for idx in indices[0]:
            lbl = train_labels[idx]
            if lbl not in seen:
                preds.append(lbl)
                seen.add(lbl)
            if len(preds) == 3:
                break
        
        top3_predictions.append(preds)
        true_labels_str.append(idx_to_label[label.item()])

# Accuracy
correct = sum(t in p for t, p in zip(true_labels_str, top3_predictions))
print(f"Top-3 Accuracy: {correct / len(val_data):.4f}")


Top-3 Accuracy: 0.9500
