Final

In [None]:
# Install required packages
!pip install timm albumentations
!pip install faiss-cpu  # Use this on Kaggle instead of faiss-gpu
!pip install -q timm albumentations
!pip install -U albumentations

In [None]:
# Import required libraries
import os
import random
import numpy as np
import pandas as pd
import csv
import matplotlib.pyplot as plt
from tqdm import tqdm
import faiss
from PIL import Image
from matplotlib.patches import Rectangle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import models, transforms
from torch.cuda.amp import autocast, GradScaler

# Seed Python, NumPy and PyTorch with one value so sampling and weight
# initialisation repeat across runs; CUDA RNGs are seeded only when a GPU exists.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Competition data locations on Kaggle.
TRAIN_DIR = '/kaggle/input/tammathon-task-1/train/train'
VAL_DIR = '/kaggle/input/tammathon-task-1/val/val'
TEST_DIR = '/kaggle/input/tammathon-task-1/test/test'

# Report which of the expected data directories are present.
for _split_name, _split_dir in (("Train", TRAIN_DIR), ("Validation", VAL_DIR), ("Test", TEST_DIR)):
    print(f"{_split_name} directory exists: {os.path.exists(_split_dir)}")

In [None]:
# List all cat-ID directories. Sorted because os.listdir order is arbitrary:
# without it, random.sample() below would pick different cats on different
# machines/filesystems even under the fixed seed.
cat_dirs = sorted(d for d in os.listdir(TRAIN_DIR) if os.path.isdir(os.path.join(TRAIN_DIR, d)))
total_cats = len(cat_dirs)
print(f"Total number of cat classes: {total_cats}")

# Sample 10,000 cat IDs for initial training instead of using all
sample_size = 10000  # Use 10,000 cats instead of all ~110K cats
sampled_cat_ids = random.sample(cat_dirs, min(sample_size, total_cats))
print(f"Using {len(sampled_cat_ids)} cat classes for training (sampled from {total_cats} total)")

# Display up to 2 images from each of up to 5 random cats (fewer if fewer cats exist).
num_vis_cats = min(5, len(sampled_cat_ids))
visualization_cats = random.sample(sampled_cat_ids, num_vis_cats)
plt.figure(figsize=(15, 10))
for i, cat_id in enumerate(visualization_cats):
    cat_path = os.path.join(TRAIN_DIR, cat_id)
    for j, img_name in enumerate(sorted(os.listdir(cat_path))[:2]):
        # copy() loads the pixels, so the file handle can be closed right away.
        with Image.open(os.path.join(cat_path, img_name)) as raw:
            img = raw.copy()

        plt.subplot(num_vis_cats, 2, i * 2 + j + 1)
        plt.imshow(img)
        plt.title(f"Cat ID: {cat_id}")
        plt.axis('off')

plt.tight_layout()
plt.show()

# Count images per cat (sampling a subset for statistics to save time)
print("Calculating image statistics (sampling 5000 cats for efficiency)...")
stat_sample = random.sample(sampled_cat_ids, min(5000, len(sampled_cat_ids)))
images_per_cat = {
    cat_id: len(os.listdir(os.path.join(TRAIN_DIR, cat_id)))
    for cat_id in tqdm(stat_sample)
}

if images_per_cat:
    image_counts = list(images_per_cat.values())

    # Display distribution
    plt.figure(figsize=(10, 5))
    plt.hist(image_counts, bins=10)
    plt.title('Distribution of Images per Cat (Sample of 5000 cats)')
    plt.xlabel('Number of Images')
    plt.ylabel('Number of Cats')
    plt.show()

    print(f"Average images per cat: {np.mean(image_counts):.2f}")
    print(f"Min images per cat: {min(image_counts)}")
    print(f"Max images per cat: {max(image_counts)}")
else:
    # min()/max() would raise on an empty sample.
    print("No cats available for image statistics.")

In [None]:
# Standard ImageNet channel mean/std, consistent with the ImageNet-pretrained
# backbone. Normalize is stateless, so one instance serves both pipelines.
_imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: fixed 224x224 resize followed by light geometric and
# photometric augmentation, then tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    _imagenet_normalize,
])

# Evaluation pipeline: the same resize and normalisation with no randomness.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    _imagenet_normalize,
])

# Define function to find images by cat_id
def find_image_by_cat_id(cat_id, directory=TRAIN_DIR):
    """Return the path of one image belonging to ``cat_id``, or None.

    Looks in ``<directory>/<cat_id>`` and returns the alphabetically first
    regular file there. ``os.listdir`` order is arbitrary, so sorting keeps the
    choice stable across runs. Prints a message and returns None when the
    folder does not exist or contains no files.
    """
    cat_path = os.path.join(directory, cat_id)
    if os.path.isdir(cat_path):
        img_files = sorted(
            name for name in os.listdir(cat_path)
            if os.path.isfile(os.path.join(cat_path, name))
        )
        if img_files:
            return os.path.join(cat_path, img_files[0])

    print(f"No image found for cat_id {cat_id}")
    return None

# Dataset class for cat faces
class CatFaceDataset(Dataset):
    """Images stored as ``<directory>/<cat_id>/<file>``, labelled by folder name.

    Items are ``(image, idx, cat_id)``, where ``image`` is the loaded picture
    after ``transform`` (if any) and ``idx`` is the item's own index.

    Args:
        directory: Root folder containing one sub-folder per cat.
        cat_ids: Sub-folder names to include, in the given order; when None,
            every sub-folder of ``directory`` is used, sorted by name.
        transform: Optional callable applied to each PIL image.
    """

    def __init__(self, directory, cat_ids=None, transform=None):
        self.directory = directory
        self.transform = transform
        # (cat_id, image_path) pairs; other code reads this list directly.
        self.samples = []

        # If cat_ids not provided, use all cat IDs in the directory
        if cat_ids is None:
            cat_ids = sorted(
                d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))
            )

        for cat_id in cat_ids:
            cat_path = os.path.join(directory, cat_id)
            if os.path.isdir(cat_path):
                # Sorted so sample order, and any seeded sampling over it, is reproducible.
                for img_name in sorted(os.listdir(cat_path)):
                    self.samples.append((cat_id, os.path.join(cat_path, img_name)))

        num_cats = len({cat_id for cat_id, _ in self.samples})
        print(f"Loaded {len(self.samples)} images for {num_cats} cats")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        cat_id, img_path = self.samples[idx]

        try:
            # The context manager closes the file; convert() returns an independent image.
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')
        except Exception:
            # Best effort: substitute a blank image for unreadable files, but no
            # longer swallow KeyboardInterrupt/SystemExit like a bare except did.
            img = Image.new('RGB', (224, 224), color='gray')

        if self.transform:
            img = self.transform(img)

        # Return image, its index, and cat_id
        return img, idx, cat_id

# Triplet batch sampler for triplet learning
class TripletBatchSampler(Sampler):
    """Batch sampler yielding random (anchor, positive, negative) triplets.

    Each batch is a list of dataset indices laid out in three equal blocks:
    all anchors, then all positives, then all negatives. That is the layout
    ``train_epoch``/``validate`` assume when they split a batch into thirds.
    An anchor and its positive share a cat_id; the negative has a different one.

    Args:
        dataset: Dataset whose items carry a cat_id. If it exposes a
            ``samples`` list of ``(cat_id, path)`` pairs (as CatFaceDataset and
            ValidationDataset do), labels are read from that list without
            loading any image. Otherwise every item is fetched once and
            ``item[2]`` is taken as the cat_id. NOTE: a ``samples`` list with a
            different tuple order (e.g. torchvision ImageFolder) is not supported.
        batch_size: Indices per batch; ``batch_size // 3`` triplets (at least 1)
            are packed into each batch.
    """

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.triplets_per_batch = max(1, batch_size // 3)
        # About one triplet per three images, i.e. roughly one pass over the data.
        self.num_triplets = len(dataset) // 3

        # Group indices by cat_id
        self.cat_id_to_indices = {}
        for idx, cat_id in enumerate(self._labels(dataset)):
            self.cat_id_to_indices.setdefault(cat_id, []).append(idx)

        # Anchors need a second image of the same cat to act as the positive.
        self.valid_cat_ids = [cat_id for cat_id, indices in self.cat_id_to_indices.items()
                              if len(indices) >= 2]
        # Any cat, including one with a single image, can supply a negative.
        self.all_cat_ids = list(self.cat_id_to_indices)

    @staticmethod
    def _labels(dataset):
        """Return the cat_id of every item in index order, avoiding image loads when possible."""
        samples = getattr(dataset, 'samples', None)
        if samples is not None:
            return [cat_id for cat_id, _ in samples]
        return [item[2] for item in dataset]

    def _can_sample(self):
        """True when at least one positive pair and one other cat exist."""
        return bool(self.valid_cat_ids) and len(self.all_cat_ids) >= 2

    def _sample_triplet(self):
        """Draw one (anchor_idx, positive_idx, negative_idx) triplet."""
        anchor_cat_id = random.choice(self.valid_cat_ids)
        anchor_idx, positive_idx = random.sample(self.cat_id_to_indices[anchor_cat_id], 2)

        # Rejection-sample a different cat: at most 2 draws expected, instead of
        # rebuilding the list of all other cats for every triplet.
        negative_cat_id = anchor_cat_id
        while negative_cat_id == anchor_cat_id:
            negative_cat_id = random.choice(self.all_cat_ids)
        negative_idx = random.choice(self.cat_id_to_indices[negative_cat_id])

        return anchor_idx, positive_idx, negative_idx

    def __iter__(self):
        if not self._can_sample():
            return
        triplets = [self._sample_triplet() for _ in range(self.num_triplets)]
        for start in range(0, len(triplets), self.triplets_per_batch):
            anchors, positives, negatives = zip(*triplets[start:start + self.triplets_per_batch])
            # Block layout [anchors | positives | negatives], not interleaved.
            yield [*anchors, *positives, *negatives]

    def __len__(self):
        # Must agree with the number of batches __iter__ yields.
        if not self._can_sample():
            return 0
        return -(-self.num_triplets // self.triplets_per_batch)

# Create datasets
print("Creating training dataset...")
train_dataset = CatFaceDataset(TRAIN_DIR, cat_ids=sampled_cat_ids, transform=train_transform)

# Split the same cats into train and validation datasets
print("Creating validation dataset from same cat IDs...")
# cat_id -> image paths relative to TRAIN_DIR. Sorted so that random.sample()
# below is reproducible under the fixed seed regardless of listing order.
cat_to_images = {}
for cat_id in sampled_cat_ids:
    cat_path = os.path.join(TRAIN_DIR, cat_id)
    if os.path.isdir(cat_path):
        cat_to_images[cat_id] = [os.path.join(cat_id, img) for img in sorted(os.listdir(cat_path))]

# Hold out ~20% of each cat's images, but at least 2 so held-out images can form
# anchor/positive pairs (the validation triplet sampler ignores cats with fewer).
# Only cats with >= 4 images are split, so at least 2 images per cat also stay
# available for training triplets.
val_samples = []
for cat_id, images in cat_to_images.items():
    if len(images) >= 4:
        val_size = max(2, int(len(images) * 0.2))
        for img_path in random.sample(images, val_size):
            val_samples.append((cat_id, os.path.join(TRAIN_DIR, img_path)))

# Remove the held-out images from training. Previously they stayed in
# train_dataset, so "validation" scored images the model had trained on.
# Paths are built identically here and in CatFaceDataset, so they compare equal.
val_paths = {img_path for _, img_path in val_samples}
train_dataset.samples = [s for s in train_dataset.samples if s[1] not in val_paths]
print(f"Held out {len(val_paths)} images for validation; {len(train_dataset.samples)} training images remain")

# Create validation dataset with the sampled images
class ValidationDataset(Dataset):
    """Dataset over an explicit list of ``(cat_id, image_path)`` samples.

    Items are ``(image, idx, cat_id)``, mirroring CatFaceDataset, so both can
    be used with the same loaders and samplers.

    Args:
        samples: List of ``(cat_id, image_path)`` pairs, exposed as ``self.samples``.
        transform: Optional callable applied to each PIL image.
    """

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        cat_id, img_path = self.samples[idx]
        try:
            # The context manager closes the file; convert() returns an independent image.
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')
        except Exception:
            # Best effort: substitute a blank image for unreadable files, without
            # swallowing KeyboardInterrupt/SystemExit like the former bare except.
            img = Image.new('RGB', (224, 224), color='gray')

        if self.transform:
            img = self.transform(img)

        return img, idx, cat_id

# Wrap the held-out samples with the deterministic evaluation transform.
val_dataset = ValidationDataset(val_samples, transform=val_transform)
print(f"Created validation dataset with {len(val_dataset)} images from {len({cat_id for cat_id, _ in val_samples})} cats")

# Model setup
class TripletNetwork(nn.Module):
    """Image embedding model: EfficientNet-B0 features -> Linear -> BatchNorm1d -> L2 norm.

    Args:
        embedding_dim: Size of the output embedding.
        pretrained: Load ImageNet (IMAGENET1K_V1) weights into the backbone when True.
    """

    def __init__(self, embedding_dim=512, pretrained=True):
        super().__init__()

        # Uses the torchvision `weights=` API rather than the deprecated `pretrained=` flag.
        backbone_weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.efficientnet_b0(weights=backbone_weights)

        # Drop the ImageNet classifier head so the backbone outputs raw features.
        feature_dim = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()

        # Projection from backbone features to the embedding space.
        self.embedding = nn.Sequential(
            nn.Linear(feature_dim, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
        )

    def forward(self, x):
        # Unit-length embeddings, so squared L2 distance and cosine similarity agree in ranking.
        projected = self.embedding(self.backbone(x))
        return nn.functional.normalize(projected, p=2, dim=1)

# Triplet loss
class TripletLoss(nn.Module):
    """Hinge triplet loss over squared Euclidean distances.

    ``forward(anchor, positive, negative)`` returns a 3-tuple:
    the batch mean of ``max(d(a, p) - d(a, n) + margin, 0)``, the mean
    anchor-positive distance and the mean anchor-negative distance, where ``d``
    is the squared difference summed over dim 1.
    """

    def __init__(self, margin=0.2):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        def squared_distance(a, b):
            return torch.sum((a - b) ** 2, dim=1)

        d_pos = squared_distance(anchor, positive)
        d_neg = squared_distance(anchor, negative)
        hinge = (d_pos - d_neg + self.margin).clamp(min=0.0)
        return hinge.mean(), d_pos.mean(), d_neg.mean()

# Functions for training and evaluation
def extract_embeddings_triplet(model, dataloader, device):
    """Embed every image in ``dataloader`` with ``model`` in eval mode.

    Batches are expected as ``(images, idx, cat_ids)``. If each element of
    ``cat_ids`` is itself a tuple/list, element ``[0]`` is taken as the cat_id
    and element ``[2]`` as the filename; otherwise ``cat_ids`` doubles as the
    filenames (as for test images).

    Returns:
        (embeddings, cat_ids, filenames): a float32 array with one row per
        image (shape ``(0, 0)`` for an empty loader), plus two lists aligned
        with its rows.
    """
    model.eval()
    embeddings_list = []
    cat_ids_list = []
    filenames_list = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extracting embeddings"):
            images = batch[0]
            cat_ids = batch[2]

            # Guard on emptiness before peeking at cat_ids[0].
            if isinstance(cat_ids, (list, tuple)) and cat_ids and isinstance(cat_ids[0], (tuple, list)):
                filenames = [cid[2] for cid in cat_ids]
                cat_ids = [cid[0] for cid in cat_ids]
            else:
                filenames = cat_ids  # For test images, cat_id is the filename

            images = images.to(device)

            with torch.amp.autocast(device_type='cuda', enabled=torch.cuda.is_available()):
                embeddings = model(images)

            # Under autocast the output may be float16; return a uniform float32 array.
            embeddings_list.append(embeddings.float().cpu().numpy())
            cat_ids_list.extend(cat_ids)
            filenames_list.extend(filenames)

    if not embeddings_list:
        # np.vstack raises on an empty list.
        return np.empty((0, 0), dtype=np.float32), cat_ids_list, filenames_list

    return np.vstack(embeddings_list), cat_ids_list, filenames_list

# Build the embedding network (ImageNet-pretrained backbone) on the chosen device.
embedding_dim = 512
model = TripletNetwork(embedding_dim=embedding_dim, pretrained=True).to(device)

# Shared loader settings and the triplet criterion. The optimizer and LR
# scheduler are created later, in the training cell.
batch_size = 32
num_workers = 2
triplet_loss = TripletLoss(margin=0.2)

In [None]:
# Import necessary modules
import time
from torch.amp import GradScaler

# Training function with progress updates
def train_epoch(model, train_loader, optimizer, criterion, device, scaler=None):
    """Run one training epoch over triplet batches.

    Each batch is ``(images, ...)``, with ``images`` laid out as
    [anchors | positives | negatives] in equal thirds. Images beyond the last
    complete triplet are ignored, and batches with fewer than 3 images are skipped.

    Args:
        model: Embedding network, called once per batch on all of its images.
        train_loader: Iterable of batches as described above.
        optimizer: Optimizer over the model's trainable parameters.
        criterion: ``criterion(anchor, positive, negative)`` returning
            ``(loss, mean_pos_dist, mean_neg_dist)``.
        device: Device the images are moved to.
        scaler: Optional GradScaler; when given, backward/step go through it.

    Returns:
        ``(epoch_loss, avg_pos_dist, avg_neg_dist)``, each averaged per triplet
        (0.0 if no triplets were processed).
    """
    model.train()
    running_loss = 0.0
    pos_dist_sum = 0.0
    neg_dist_sum = 0.0
    count = 0

    print(f"  Starting training... ({len(train_loader)} batches)")
    start_time = time.time()

    for batch in tqdm(train_loader, desc="Training"):
        images = batch[0]
        n_triplets = len(images) // 3  # Each triplet has 3 images
        if n_triplets == 0:
            continue

        optimizer.zero_grad()

        # Forward pass with mixed precision
        with torch.amp.autocast(device_type='cuda', enabled=torch.cuda.is_available()):
            # Embed all three roles in one forward pass, then split. This is
            # cheaper, and a trailing batch with a single triplet no longer feeds
            # a batch of 1 to the BatchNorm1d embedding head in train mode (which raises).
            embeddings = model(images[:3 * n_triplets].to(device))
            anchor_embed, positive_embed, negative_embed = embeddings.split(n_triplets)

            loss, pos_dist, neg_dist = criterion(anchor_embed, positive_embed, negative_embed)

        # Backward pass with gradient scaling
        if scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Criterion returns batch means; weight by triplet count for an epoch average.
        running_loss += loss.item() * n_triplets
        pos_dist_sum += pos_dist.item() * n_triplets
        neg_dist_sum += neg_dist.item() * n_triplets
        count += n_triplets

    # Calculate averages with safe division
    epoch_loss = running_loss / max(1, count)
    avg_pos_dist = pos_dist_sum / max(1, count)
    avg_neg_dist = neg_dist_sum / max(1, count)

    elapsed_time = time.time() - start_time
    print(f"  Training completed in {elapsed_time:.1f} seconds")

    return epoch_loss, avg_pos_dist, avg_neg_dist

# Validation function with progress updates
def validate(model, val_loader, criterion, device):
    """Evaluate triplet loss and distance separation without gradient tracking.

    Batches follow the same [anchors | positives | negatives] layout as in
    ``train_epoch``; batches with fewer than 3 images are skipped with a warning.

    Returns:
        ``(val_loss, separation, avg_pos_dist, avg_neg_dist)``, where
        ``separation = avg_neg_dist - avg_pos_dist`` (larger is better). All four
        are 0.0 when the loader is empty or yields no usable batch.
    """
    model.eval()
    running_loss = 0.0
    pos_dist_sum = 0.0
    neg_dist_sum = 0.0
    count = 0

    # Check if validation loader has batches
    if len(val_loader) == 0:
        print("  Warning: Validation loader is empty! Returning default values.")
        return 0.0, 0.0, 0.0, 0.0

    print(f"  Starting validation... ({len(val_loader)} batches)")
    start_time = time.time()

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            images = batch[0]
            n_triplets = len(images) // 3  # Each triplet has 3 images
            if n_triplets == 0:
                print(f"  Warning: Batch size too small ({len(images)}), skipping.")
                continue

            # Forward pass with mixed precision; one pass over all roles, then split.
            with torch.amp.autocast(device_type='cuda', enabled=torch.cuda.is_available()):
                embeddings = model(images[:3 * n_triplets].to(device))
                anchor_embed, positive_embed, negative_embed = embeddings.split(n_triplets)

                loss, pos_dist, neg_dist = criterion(anchor_embed, positive_embed, negative_embed)

            # Update statistics (criterion returns batch means)
            running_loss += loss.item() * n_triplets
            pos_dist_sum += pos_dist.item() * n_triplets
            neg_dist_sum += neg_dist.item() * n_triplets
            count += n_triplets

    if count == 0:
        print("  Warning: No valid batches in validation! Using default values.")
        val_loss = 0.0
        avg_pos_dist = 0.0
        avg_neg_dist = 0.0
        separation = 0.0
    else:
        val_loss = running_loss / count
        avg_pos_dist = pos_dist_sum / count
        avg_neg_dist = neg_dist_sum / count
        separation = avg_neg_dist - avg_pos_dist

    elapsed_time = time.time() - start_time
    print(f"  Validation completed in {elapsed_time:.1f} seconds")

    return val_loss, separation, avg_pos_dist, avg_neg_dist

# Create dataloaders with triplet sampling
print("Creating data loaders...")
train_sampler = TripletBatchSampler(train_dataset, batch_size*3)  # *3 because we need triplets
train_loader = DataLoader(train_dataset, batch_sampler=train_sampler, num_workers=num_workers)
print(f"Created training loader with {len(train_loader)} batches")

# Check if validation dataset has enough images per cat for triplet creation
print("Checking validation dataset...")
# Count images per cat from the (cat_id, path) sample list. Iterating val_dataset
# itself decodes every image, and its items are (image, idx, cat_id): unpacking
# them as `cat_id, _, _` bound the image tensor, so every "cat" had count 1 and
# the triplet validation path below could never be taken.
val_cats = {}
for cat_id, _ in val_dataset.samples:
    val_cats[cat_id] = val_cats.get(cat_id, 0) + 1

# Filter cats that have at least 2 images (needed for positive pairs)
valid_val_cats = [cat_id for cat_id, n_images in val_cats.items() if n_images >= 2]
print(f"Found {len(valid_val_cats)} cats with at least 2 images (from {len(val_cats)} total cats)")

if len(valid_val_cats) < 2:
    print("WARNING: Not enough cats with sufficient images for validation. Using a subset of training data.")
    # Build a triplet-capable subset (~1000 images from cats with >= 2 training
    # images) and evaluate it with val_transform. A plain RandomSampler would not
    # yield [anchors | positives | negatives] batches, making the metrics meaningless.
    images_by_cat = {}
    for cat_id, img_path in train_dataset.samples:
        images_by_cat.setdefault(cat_id, []).append((cat_id, img_path))
    candidate_groups = [group for group in images_by_cat.values() if len(group) >= 2]
    random.shuffle(candidate_groups)

    fallback_samples = []
    for group in candidate_groups:
        if len(fallback_samples) >= 1000:
            break
        fallback_samples.extend(group)

    fallback_val_dataset = ValidationDataset(fallback_samples, transform=val_transform)
    val_sampler = TripletBatchSampler(fallback_val_dataset, batch_size*3)
    val_loader = DataLoader(fallback_val_dataset, batch_sampler=val_sampler, num_workers=num_workers)
else:
    val_sampler = TripletBatchSampler(val_dataset, batch_size*3)
    val_loader = DataLoader(val_dataset, batch_sampler=val_sampler, num_workers=num_workers)

print(f"Created validation loader with {len(val_loader)} batches")

# Training parameters
print("Setting up training parameters...")
initial_lr = 0.0001
fine_tuning_lr = 0.00002  # 5x smaller than initial_lr, for fine-tuning the whole network
scaler = GradScaler() if torch.cuda.is_available() else None
num_epochs_phase1 = 3
num_epochs_phase2 = 4


def _run_training_phase(num_epochs, optimizer, scheduler, checkpoint_path):
    """Train for ``num_epochs`` epochs, checkpointing whenever validation separation improves.

    Uses the notebook-level ``model``, ``train_loader``, ``val_loader``,
    ``triplet_loss``, ``device`` and ``scaler``. ``scheduler`` is stepped with
    the validation separation each epoch.

    Returns:
        The best separation seen. It stays ``-inf`` (and nothing is saved) if no
        epoch produced a comparable value, e.g. when every separation is NaN.
    """
    best = -float('inf')
    for epoch in range(1, num_epochs + 1):
        print("\n" + "-"*50)
        print(f"Epoch {epoch}/{num_epochs}")
        print("-"*50)

        train_loss, _, _ = train_epoch(model, train_loader, optimizer, triplet_loss, device, scaler)
        val_loss, separation, val_pos_dist, val_neg_dist = validate(model, val_loader, triplet_loss, device)

        print("\nResults:")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Separation: {separation:.4f}")
        print(f"  Avg Pos Dist: {val_pos_dist:.4f}, Avg Neg Dist: {val_neg_dist:.4f}")
        print(f"  Current LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Update learning rate (the scheduler maximises separation)
        old_lr = optimizer.param_groups[0]['lr']
        scheduler.step(separation)
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr != old_lr:
            print(f"  Learning rate decreased: {old_lr:.6f} → {new_lr:.6f}")

        # Save best model
        if separation > best:
            best = separation
            torch.save(model.state_dict(), checkpoint_path)
            print(f"  New best model saved! Separation: {best:.4f}")
    return best


# Phase 1: Train with frozen backbone
print("\n" + "="*50)
print("Phase 1: Training with frozen backbone")
print("="*50)

print("Model backbone: EfficientNet-B0")

print("Freezing backbone parameters...")
for param in model.backbone.parameters():
    param.requires_grad = False
# NOTE(review): train_epoch() calls model.train(), so the frozen backbone's
# BatchNorm running statistics still update during Phase 1 — confirm intended.

# Only the embedding head is trainable in this phase.
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=initial_lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=1)

# Show GPU memory stats before training
if torch.cuda.is_available():
    print(f"GPU memory allocated: {torch.cuda.memory_allocated(0)/1e9:.2f} GB")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved(0)/1e9:.2f} GB")

best_separation = _run_training_phase(num_epochs_phase1, optimizer, scheduler, 'best_model_phase1.pth')
print(f"Phase 1 completed with best separation: {best_separation:.4f}")

# Phase 2: Fine-tune the entire model with proper settings
print("\n" + "="*50)
print("Phase 2: Fine-tuning the entire model with reduced learning rate")
print("="*50)

# Load best model from Phase 1. A checkpoint was written this run only if some
# epoch improved on -inf; otherwise keep the current weights instead of crashing
# (or silently loading a stale file from an earlier run).
print("Loading best model from Phase 1...")
if best_separation > -float('inf'):
    model.load_state_dict(torch.load('best_model_phase1.pth', weights_only=True))
else:
    print("  No Phase 1 checkpoint was saved in this run; continuing from the current weights.")

print("Unfreezing backbone parameters...")
for param in model.backbone.parameters():
    param.requires_grad = True

# Create new optimizer with reduced learning rate for Phase 2
print(f"Setting fine-tuning learning rate to {fine_tuning_lr:.6f} (from {initial_lr:.6f})")
optimizer = optim.Adam(model.parameters(), lr=fine_tuning_lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=1)

best_separation_phase2 = _run_training_phase(num_epochs_phase2, optimizer, scheduler, 'best_model_phase2.pth')

print("\n" + "="*50)
print("Training complete!")
print(f"Phase 1 best separation: {best_separation:.4f}")
print(f"Phase 2 best separation: {best_separation_phase2:.4f}")
print(f"Overall best model: {'Phase 1' if best_separation > best_separation_phase2 else 'Phase 2'}")
print("="*50)

In [None]:
# Load the best checkpoint. When both phases saved one, prefer the phase with the
# higher recorded validation separation, matching the training cell's "Overall
# best model" summary (ties go to Phase 2). If those scores are not in memory
# (e.g. a fresh kernel), Phase 2 is preferred, as before.
_phase_checkpoints = [
    (globals().get('best_separation_phase2', -float('inf')), 'best_model_phase2.pth', "Loaded best model from phase 2"),
    (globals().get('best_separation', -float('inf')), 'best_model_phase1.pth', "Loaded best model from phase 1"),
]
_available_checkpoints = [c for c in _phase_checkpoints if os.path.exists(c[1])]
if _available_checkpoints:
    # max() returns the first of equal maxima, so Phase 2 wins ties.
    _, _ckpt_path, _load_message = max(_available_checkpoints, key=lambda c: c[0])
    model.load_state_dict(torch.load(_ckpt_path, weights_only=True))
    print(_load_message)
else:
    print("No saved model found, using the last trained model")

# Gallery: every image of the sampled training cats, in a fixed (unshuffled)
# order, so row i of the embeddings corresponds to all_train_dataset.samples[i].
all_train_dataset = CatFaceDataset(TRAIN_DIR, cat_ids=sampled_cat_ids, transform=val_transform)
all_train_loader = DataLoader(all_train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Extract embeddings for all training images
print("Extracting embeddings for all training images...")
train_embeddings, train_cat_ids, _ = extract_embeddings_triplet(model, all_train_loader, device)
print(f"Extracted training embeddings shape: {train_embeddings.shape}")

# FAISS requires float32; after L2 normalisation, inner product equals cosine similarity.
train_embeddings = train_embeddings.astype(np.float32)
faiss.normalize_L2(train_embeddings)

# Create FAISS index (exact inner-product search)
print("Creating FAISS index...")
index = faiss.IndexFlatIP(train_embeddings.shape[1])
index.add(train_embeddings)
print("FAISS index created")

In [None]:
# Re-declare the official validation split location so this cell can run standalone.
VAL_DIR = '/kaggle/input/tammathon-task-1/val/val'
print("Validation directory: " + VAL_DIR)
print(f"Does validation directory exist? {os.path.exists(VAL_DIR)}")

# Create a dataset for the official validation images
class OfficialValDataset(Dataset):
    """Labelled images under ``<directory>/<cat_id>/<file>`` with .jpg/.jpeg/.png extensions.

    Items are ``(image, idx, cat_id)``. ``self.image_files`` and
    ``self.cat_ids`` are parallel lists in a deterministic (sorted) order.

    Args:
        directory: Root folder containing one sub-folder per cat.
        transform: Optional callable applied to each PIL image.
    """

    def __init__(self, directory, transform=None):
        self.directory = directory
        self.transform = transform
        self.image_files = []
        self.cat_ids = []

        # Sorted so index order, and anything keyed on it, is reproducible.
        for cat_id in sorted(os.listdir(directory)):
            cat_dir = os.path.join(directory, cat_id)
            if not os.path.isdir(cat_dir):
                continue
            for img_file in sorted(os.listdir(cat_dir)):
                # Case-insensitive extension filter skips stray non-image files.
                if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.image_files.append(os.path.join(cat_dir, img_file))
                    self.cat_ids.append(cat_id)

        print(f"Found {len(self.image_files)} validation images across {len(set(self.cat_ids))} cat classes")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        cat_id = self.cat_ids[idx]

        try:
            # The context manager closes the file; convert() returns an independent image.
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            img = Image.new('RGB', (224, 224), color='gray')

        if self.transform:
            img = self.transform(img)

        return img, idx, cat_id

# Evaluate retrieval on the official validation split: embed each validation
# image, look up its nearest training embeddings in the FAISS index, report
# top-1/top-3 accuracy by cat_id, and plot a few queries with their matches.
official_val_dataset = OfficialValDataset(VAL_DIR, transform=val_transform) if os.path.exists(VAL_DIR) else None

if official_val_dataset is None:
    print("Validation directory not found. Skipping validation accuracy calculation.")
elif len(official_val_dataset) == 0:
    # Guards the accuracy division and the embedding extraction below.
    print("No validation images found. Skipping validation accuracy calculation.")
else:
    official_val_loader = DataLoader(official_val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Extract embeddings for official validation images
    print("Extracting embeddings for official validation images...")
    val_embeddings, val_true_cat_ids, _ = extract_embeddings_triplet(model, official_val_loader, device)
    print(f"Extracted validation embeddings shape: {val_embeddings.shape}")

    # FAISS needs float32; normalise so inner product equals cosine similarity.
    val_embeddings = val_embeddings.astype(np.float32)
    faiss.normalize_L2(val_embeddings)

    # One k=5 search serves both the accuracy (first 3 columns) and the plots
    # below. FAISS pads with index -1 when fewer than k vectors are indexed;
    # those slots are skipped rather than wrapping around to train_cat_ids[-1].
    print("Finding top matches for validation images...")
    val_scores, val_neighbors = index.search(val_embeddings, k=5)

    correct_top1 = 0
    correct_top3 = 0
    for true_cat_id, neighbors in zip(val_true_cat_ids, val_neighbors):
        top3_cat_ids = [train_cat_ids[n] for n in neighbors[:3] if n >= 0]
        if top3_cat_ids and top3_cat_ids[0] == true_cat_id:
            correct_top1 += 1
        if true_cat_id in top3_cat_ids:
            correct_top3 += 1

    total_val = len(val_true_cat_ids)
    top1_accuracy = correct_top1 / total_val
    top3_accuracy = correct_top3 / total_val

    print(f"Official Validation Results:")
    print(f"Top-1 Accuracy: {top1_accuracy:.4f} ({correct_top1}/{total_val})")
    print(f"Top-3 Accuracy: {top3_accuracy:.4f} ({correct_top3}/{total_val})")

    # Visualize validation results for random examples
    num_val_examples = 5
    indices = random.sample(range(len(val_embeddings)), min(num_val_examples, len(val_embeddings)))

    plt.figure(figsize=(20, 4 * num_val_examples))

    for plot_idx, idx in enumerate(indices):
        true_cat_id = val_true_cat_ids[idx]

        # Display query image
        plt.subplot(num_val_examples, 6, plot_idx*6 + 1)
        with Image.open(official_val_dataset.image_files[idx]) as raw:
            query_img = raw.convert('RGB')
        plt.imshow(query_img)
        plt.title(f"Query: {true_cat_id}")
        plt.axis('off')

        # Display the top 5 retrieved training images
        for j, (neighbor, similarity) in enumerate(zip(val_neighbors[idx], val_scores[idx])):
            plt.subplot(num_val_examples, 6, plot_idx*6 + j + 2)
            if neighbor < 0:
                plt.axis('off')
                continue

            pred_cat_id = train_cat_ids[neighbor]
            is_correct = (pred_cat_id == true_cat_id)
            # Show the retrieved image itself rather than an arbitrary image of the
            # predicted cat. Rows of the index line up with all_train_dataset.samples
            # because all_train_loader was built with shuffle=False.
            pred_path = all_train_dataset.samples[neighbor][1]

            try:
                with Image.open(pred_path) as raw:
                    pred_img = raw.convert('RGB')
                plt.imshow(pred_img)

                # Add colored border to indicate correct/incorrect
                border_color = 'green' if is_correct else 'red'
                plt.gca().add_patch(Rectangle((0, 0), 1, 1, fill=False, edgecolor=border_color, lw=5,
                                              transform=plt.gca().transAxes))
            except Exception:
                plt.text(0.5, 0.5, "Error loading image", ha='center', va='center', transform=plt.gca().transAxes)

            plt.title(f"Pred {j+1}: {pred_cat_id}\nScore: {similarity:.3f}\n{'✓' if is_correct else '✗'}")
            plt.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import os
import torch
import faiss
import numpy as np
from tqdm import tqdm
import random

# Load the training CSV and keep only rows whose integer label is below 10,000.
# NOTE(review): this "first 10,000 labels" filter is independent of the randomly
# sampled `sampled_cat_ids` used for training earlier — confirm that is intended.
csv_path = '/kaggle/input/tammathon-task-1/train.csv'
print(f"Loading CSV from: {csv_path}")

df = pd.read_csv(csv_path)
print(f"CSV loaded with {len(df)} total entries")

# Labels are compared numerically, so coerce the column to int first.
df['label'] = df['label'].astype(int)
filtered_df = df.loc[df['label'].lt(10000)]
print(f"Filtered to {len(filtered_df)} entries with labels < 10000")

# Persist the filtered rows so a CSV-backed dataset can read them.
filtered_csv_path = '/kaggle/working/filtered_train.csv'
filtered_df.to_csv(filtered_csv_path, index=False)

# Create a dataset for the filtered CSV images
class CSVImageDataset(Dataset):
    """Image dataset backed by a CSV manifest with ``filename`` and ``label`` columns.

    Each row's ``filename`` is resolved against ``base_dir`` and then
    ``base_dir/train``; rows whose file does not exist are dropped, and rows
    that cannot be processed at all are reported and skipped.

    Args:
        csv_file: path to the CSV manifest.
        base_dir: directory used to resolve relative filenames. Absolute
            filenames are effectively used as-is, because ``os.path.join``
            discards earlier components when a later one is absolute.
        transform: optional callable applied to the loaded PIL image.

    Items are ``(image, position_in_dataset, label_as_str)``.
    """

    def __init__(self, csv_file, base_dir='/kaggle/input/tammathon-task-1', transform=None):
        self.df = pd.read_csv(csv_file)
        self.base_dir = base_dir
        self.transform = transform

        # Parallel lists describing only the rows whose image exists on disk.
        self.image_paths = []
        self.labels = []
        self.filenames = []

        for idx, row in self.df.iterrows():
            try:
                filename = row['filename']
                label = row['label']
                img_path = self._resolve_path(filename)
            except Exception as e:
                # Malformed row (e.g. missing column, non-string filename): report and skip.
                print(f"Error processing row {idx}: {e}")
                continue

            if img_path is not None:
                self.image_paths.append(img_path)
                self.labels.append(label)
                self.filenames.append(filename)

        print(f"Found {len(self.image_paths)} valid images out of {len(self.df)} entries")

    def _resolve_path(self, filename):
        """Return the first existing candidate path for *filename*, or None."""
        # os.path.join returns `filename` unchanged when it is absolute, so these
        # two candidates also cover absolute filenames.
        for candidate in (
            os.path.join(self.base_dir, filename),
            os.path.join(self.base_dir, 'train', filename),
        ):
            if os.path.exists(candidate):
                return candidate
        return None

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            # `with` closes the underlying file; convert() returns an independent image.
            with Image.open(img_path) as src:
                img = src.convert('RGB')
        except Exception as e:
            # Best-effort: substitute a grey placeholder rather than aborting the whole pass.
            print(f"Error loading image {img_path}: {e}")
            img = Image.new('RGB', (224, 224), color='gray')

        if self.transform:
            img = self.transform(img)

        # (image, position in this dataset, label as string)
        return img, idx, str(label)

# Wrap the filtered manifest in a dataset and a non-shuffled loader: the
# evaluation below indexes csv_dataset.image_paths by embedding position, so
# batch order must follow dataset order.
csv_dataset = CSVImageDataset(
    filtered_csv_path,
    transform=val_transform,
)
csv_loader = DataLoader(
    csv_dataset,
    batch_size=batch_size,
    num_workers=2,
    shuffle=False,
)

# Load trained weights for evaluation. The Phase 2 checkpoint is used whenever
# its file exists; otherwise fall back to Phase 1. Only file existence is
# checked here — not which phase scored better.
print("Loading best model...")
best_model_path = 'best_model_phase2.pth'
if not os.path.exists(best_model_path):
    best_model_path = 'best_model_phase1.pth'
print(f"Using checkpoint: {best_model_path}")

# map_location lets a checkpoint saved on GPU load in a CPU-only session (and vice versa).
model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
model.eval()

# Embed every filtered CSV image with the loaded model.
print("Extracting embeddings for CSV images (first 10,000 classes only)...")
csv_embeddings, csv_indices, csv_labels = extract_embeddings_triplet(model, csv_loader, device)
print(f"Extracted {len(csv_embeddings)} embeddings")

# FAISS works on float32. After in-place L2 normalisation every row has unit
# norm, so inner product in IndexFlatIP equals cosine similarity.
csv_embeddings = csv_embeddings.astype(np.float32)
faiss.normalize_L2(csv_embeddings)

print("Creating FAISS index...")
dimension = csv_embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)
index.add(csv_embeddings)

# Leave-one-out top-1 retrieval check: every image whose label occurs at least
# twice is used as a query, and its prediction is the label of its nearest
# indexed neighbour other than itself.
# NOTE(review): assumes csv_embeddings / csv_labels rows are in loader order
# (the loader is not shuffled), so row i corresponds to
# csv_dataset.image_paths[i]; csv_indices is not used to verify this.
print("Evaluating model predictions...")
correct_predictions = 0
total_predictions = 0
correct_examples = []

# Group embedding positions by label.
label_to_indices = {}
for i, label in enumerate(csv_labels):
    label_to_indices.setdefault(label, []).append(i)

# Only labels with two or more images can produce a non-self match.
query_positions = [
    i for i, label in enumerate(csv_labels) if len(label_to_indices[label]) > 1
]

if query_positions:
    # One batched FAISS search instead of one call per image. k=2 because the
    # query itself is normally its own nearest neighbour.
    D, I = index.search(csv_embeddings[query_positions], 2)

    for row, idx in enumerate(tqdm(query_positions, desc="Evaluating")):
        query_label = csv_labels[idx]

        # Skip self: if the first hit is the query, take the second one.
        col = 1 if I[row][0] == idx else 0
        top_match_idx = I[row][col]
        predicted_label = csv_labels[top_match_idx]

        if predicted_label == query_label:
            correct_predictions += 1
            correct_examples.append({
                'query_path': csv_dataset.image_paths[idx],
                'query_label': query_label,
                'pred_path': csv_dataset.image_paths[top_match_idx],
                'pred_label': predicted_label,
                'score': D[row][col],
            })

        total_predictions += 1

# Top-1 accuracy over the evaluated queries; 0 when no query was evaluated
# (avoids division by zero).
accuracy = (correct_predictions / total_predictions) if total_predictions else 0

summary_lines = [
    "\nModel Performance Overview (First 10,000 Classes Only):",
    f"Total predictions: {total_predictions}",
    f"Correct predictions: {correct_predictions}",
    f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)",
]
for line in summary_lines:
    print(line)

# Display up to 10 correct predictions as query | prediction pairs, drawn at
# random from the 100 highest-scoring correct matches.


def _show_image_or_error(ax, path, title):
    """Draw the image at *path* on *ax* with *title*; show an error note if it cannot be loaded."""
    try:
        with Image.open(path) as src:
            ax.imshow(src.convert('RGB'))
        ax.set_title(title)
    except Exception:
        # Best-effort visualisation: one unreadable file should not abort the figure.
        ax.text(0.5, 0.5, "Error loading image", ha='center', va='center', transform=ax.transAxes)
    ax.axis('off')


if correct_examples:
    # Highest similarity first.
    correct_examples.sort(key=lambda x: x['score'], reverse=True)

    # Random picks from the top 100 for some diversity.
    top_examples = correct_examples[:100]
    examples_to_show = random.sample(top_examples, min(10, len(top_examples)))

    # One row per example, so there are no blank rows when fewer than 10 exist;
    # squeeze=False keeps `axes` 2-D even for a single row.
    n_rows = len(examples_to_show)
    fig, axes = plt.subplots(n_rows, 2, figsize=(12, max(2.5 * n_rows, 5)), squeeze=False)
    fig.suptitle('Examples of Correct Predictions (Query: Prediction)', fontsize=16)

    for i, example in enumerate(examples_to_show):
        _show_image_or_error(
            axes[i, 0],
            example['query_path'],
            f"Query: Label {example['query_label']}",
        )
        _show_image_or_error(
            axes[i, 1],
            example['pred_path'],
            f"Prediction: Label {example['pred_label']}\nScore: {example['score']:.3f}",
        )

    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()
else:
    print("No correct predictions to display.")

print("Evaluation complete!")