In [None]:
# Cell 1: Install necessary libraries
!pip install timm faiss-cpu albumentations pytorch-metric-learning

In [None]:
# Cell 2: Import libraries
import os
import cv2
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

# Image processing
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Models
import timm

# For hard triplet mining
from pytorch_metric_learning import miners, losses

# For similarity search
import faiss

# Set warnings
import warnings
warnings.filterwarnings('ignore')

# Set random seed for reproducibility
def seed_everything(seed=42):
    """Seed every random number generator the notebook uses.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every visible CUDA
    device) and configures cuDNN for deterministic behaviour.

    Args:
        seed (int): Seed value applied to all generators.
    """
    random.seed(seed)
    # Exported for child processes; the hash seed of the already-running
    # interpreter is fixed at start-up and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one (a no-op when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run; keep it off.
    torch.backends.cudnn.benchmark = False

# Fix all seeds (default seed) before anything stochastic runs
seed_everything()

# Use the GPU when one is visible, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Report every visible GPU; nothing is printed on a CPU-only machine
if torch.cuda.is_available():
    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs: {gpu_count}")
    for i, gpu_name in enumerate(map(torch.cuda.get_device_name, range(gpu_count))):
        print(f"GPU {i}: {gpu_name}")

In [None]:
# Cell 3: Configuration with debug info
# Paths
# Every existence check is printed so a wrong dataset mount is visible at once.
DATA_DIR = '/kaggle/input/tammathon-task-1'
print(f"DATA_DIR exists: {os.path.exists(DATA_DIR)}")

TRAIN_IMG_DIR = f"{DATA_DIR}/train"
print(f"TRAIN_IMG_DIR exists: {os.path.exists(TRAIN_IMG_DIR)}")

VAL_IMG_DIR = f"{DATA_DIR}/val"
print(f"VAL_IMG_DIR exists: {os.path.exists(VAL_IMG_DIR)}")

# Directory for files written by the notebook; created if missing
OUTPUT_DIR = './model_outputs'
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"OUTPUT_DIR created: {os.path.exists(OUTPUT_DIR)}")

# Model parameters
IMG_SIZE = 256      # square image side in pixels (256×256 chosen for better accuracy)
EMB_DIM = 512       # length of each embedding vector
BATCH_SIZE = 64     # samples per batch
NUM_WORKERS = 2     # DataLoader worker processes
EPOCHS = 15         # number of training epochs
LR = 3e-4           # optimizer learning rate
MARGIN = 0.3        # margin used by the triplet loss

# Load CSV files
# Loading is best-effort: a failure is reported and the corresponding DataFrame
# is simply left undefined.
try:
    train_csv_path = os.path.join(DATA_DIR, 'train.csv')
    print(f"Trying to load train CSV from: {train_csv_path}")
    print(f"File exists: {os.path.exists(train_csv_path)}")
    train_df = pd.read_csv(train_csv_path)
    print(f"Training samples: {len(train_df)}")
    print(f"Train DataFrame columns: {train_df.columns.tolist()}")
    print(train_df.head(2))

    # Expose the image path under 'path' as well, but keep 'filename':
    # later cells (sample display, the first CatFaceDataset) still read
    # 'filename', so renaming the column in place would break them.
    if 'filename' in train_df.columns and 'path' not in train_df.columns:
        train_df['path'] = train_df['filename']

except Exception as e:
    print(f"Error loading train CSV: {e}")

try:
    val_csv_path = os.path.join(DATA_DIR, 'val.csv')
    print(f"Trying to load val CSV from: {val_csv_path}")
    print(f"File exists: {os.path.exists(val_csv_path)}")
    val_df = pd.read_csv(val_csv_path)
    print(f"Validation samples: {len(val_df)}")
    print(f"Val DataFrame columns: {val_df.columns.tolist()}")
    print(val_df.head(2))

    # Same aliasing as for the training frame: add 'path', keep 'filename'
    if 'filename' in val_df.columns and 'path' not in val_df.columns:
        val_df['path'] = val_df['filename']

except Exception as e:
    print(f"Error loading val CSV: {e}")

In [None]:
# Cell 4: Finding and Displaying Sample Images

# Let's properly explore the directory structure and find images
def find_sample_images(base_dir, num_samples=3, max_folders=5):
    """Collect up to ``num_samples`` image paths from per-cat sub-folders.

    If ``base_dir`` contains a ``train`` sub-directory, the cat folders are
    looked up inside it; otherwise the sub-folders of ``base_dir`` itself are
    treated as cat folders. Folders and files are visited in sorted order so
    the result is reproducible.

    Args:
        base_dir (str): Directory to explore.
        num_samples (int): Maximum number of image paths to return.
        max_folders (int): Maximum number of cat folders to inspect.

    Returns:
        list[str]: At most ``num_samples`` paths to .png/.jpg/.jpeg files
        (extension matched case-insensitively); empty if none were found.
    """
    image_exts = ('.png', '.jpg', '.jpeg')
    image_paths = []

    print(f"Exploring directory structure in {base_dir}...")
    # List top-level contents
    contents = sorted(os.listdir(base_dir))
    print(f"Top-level contents: {contents[:5]}")

    # Prefer a nested 'train' directory when present, else scan base_dir itself
    nested_train = os.path.join(base_dir, 'train')
    if os.path.isdir(nested_train):
        search_root = nested_train
        print(f"Found train subdirectory: {search_root}")
    else:
        search_root = base_dir

    # Only directories can be cat folders; stray files are ignored up front so
    # they do not use up the folder budget
    cat_folders = [
        entry for entry in sorted(os.listdir(search_root))
        if os.path.isdir(os.path.join(search_root, entry))
    ]
    print(f"Cat folders: {cat_folders[:10]}")

    for cat_folder in cat_folders[:max_folders]:
        remaining = num_samples - len(image_paths)
        if remaining <= 0:
            break

        cat_path = os.path.join(search_root, cat_folder)
        print(f"Looking in {cat_path}...")
        files = sorted(f for f in os.listdir(cat_path) if f.lower().endswith(image_exts))
        if not files:
            continue

        print(f"Found {len(files)} images, e.g., {files[:3]}")
        # Take only as many as are still needed so the cap is never exceeded
        image_paths.extend(os.path.join(cat_path, f) for f in files[:remaining])

    return image_paths

def display_images(image_paths):
    """Show the given image files in a grid that is at most three columns wide.

    Each panel is titled with its 1-based position and the name of the folder
    that contains the file. Unreadable files are reported and skipped.

    Args:
        image_paths (list[str]): Paths of the images to display.
    """
    total = len(image_paths)
    if not total:
        print("No images to display")
        return

    # Grid shape: up to 3 columns, as many rows as needed (ceiling division)
    n_cols = 3 if total > 3 else total
    n_rows = -(-total // n_cols)

    plt.figure(figsize=(n_cols * 4, n_rows * 4))
    for slot, path in enumerate(image_paths, start=1):
        try:
            bgr = cv2.imread(path)
            if bgr is None:
                print(f"Could not read image at {path}")
                continue
            # OpenCV loads BGR; matplotlib expects RGB
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

            plt.subplot(n_rows, n_cols, slot)
            plt.imshow(rgb)
            plt.title(f"Cat Image {slot}\n{os.path.basename(os.path.dirname(path))}")
            plt.axis('off')
        except Exception as e:
            print(f"Error displaying image {path}: {e}")

    plt.tight_layout()
    plt.show()

# Fix the file paths based on the directory structure we discovered
def construct_correct_path(filename, train_dir, data_dir):
    """Resolve a CSV filename of the form ``train/...`` to an existing file.

    Two candidate locations are tried in order:

    1. ``data_dir/<filename>`` (single ``train`` directory);
    2. ``train_dir/<filename>`` (an extra ``train`` directory nested inside
       ``train_dir``).

    Args:
        filename (str): Relative path taken from the CSV.
        train_dir (str): Training image directory.
        data_dir (str): Dataset root directory.

    Returns:
        str | None: The first candidate that exists on disk, or ``None`` when
        the filename does not start with ``train/`` or no candidate exists.
    """
    if not filename.startswith('train/'):
        return None

    # Try the direct path first (with single 'train')
    direct_path = os.path.join(data_dir, filename)
    if os.path.exists(direct_path):
        return direct_path

    # Joining the whole relative path (instead of indexing split components)
    # yields train_dir/train/<cat>/<image> for the usual three-part names and
    # cannot fail on shorter or deeper paths.
    nested_path = os.path.join(train_dir, filename)
    if os.path.exists(nested_path):
        return nested_path

    return None

# Find some sample images from the dataset
try:
    # Find sample images by exploring directory structure
    sample_paths = find_sample_images(TRAIN_IMG_DIR)

    if sample_paths:
        print(f"\nFound {len(sample_paths)} sample images:")
        for path in sample_paths:
            print(f" - {path}")

        # Display the images
        display_images(sample_paths)
    else:
        print("\nCould not find sample images through directory exploration.")

        # Try another approach using the CSV filenames
        print("\nTrying to construct paths from CSV filenames...")
        # The configuration cell may have renamed 'filename' to 'path';
        # read whichever of the two columns is present.
        name_col = 'filename' if 'filename' in train_df.columns else 'path'
        sample_paths = []
        # head(5) is positional, so this also works with a non-default index
        for filename in train_df[name_col].head(5):
            correct_path = construct_correct_path(filename, TRAIN_IMG_DIR, DATA_DIR)

            if correct_path and os.path.exists(correct_path):
                print(f"Found valid path for {filename}: {correct_path}")
                sample_paths.append(correct_path)
            else:
                # Try one more approach - direct concatenation with the train dir
                possible_path = os.path.join(TRAIN_IMG_DIR, filename)
                if os.path.exists(possible_path):
                    print(f"Found via direct concatenation: {possible_path}")
                    sample_paths.append(possible_path)

        if sample_paths:
            display_images(sample_paths)
        else:
            print("Could not find any valid image paths.")

            # Last resort - print out some directories to help debug
            print("\nListing some directories to help debug:")
            data_dir_contents = os.listdir(DATA_DIR)
            print(f"DATA_DIR contents: {data_dir_contents}")

            train_dir_contents = os.listdir(TRAIN_IMG_DIR)
            print(f"TRAIN_IMG_DIR contents: {train_dir_contents[:10]}")

except Exception as e:
    print(f"Error during path exploration: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# Cell 5 (Revised): Custom Dataset Class with Fixed Path Construction

class CatFaceDataset(Dataset):
    """Image dataset backed by a DataFrame with a path column and a 'label' column.

    Args:
        df (DataFrame): One row per image. The relative image path is read
            from the 'filename' column, or from 'path' when 'filename' is
            absent; the cat identity is read from 'label'.
        img_dir (str): Dataset root the relative paths are resolved against.
        transform (callable, optional): Albumentations-style transform; it is
            called as ``transform(image=img)['image']``.

    Attributes:
        unique_cats: Distinct values of the 'label' column.
        cat_to_indices (dict): Maps each label to the list of *positional* row
            indices (usable with ``dataset[i]`` / ``df.iloc[i]``) of its images.
    """

    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

        # The configuration cell may have renamed 'filename' to 'path'; accept
        # either. Defaults to 'filename' so a frame with neither column fails
        # on first access exactly as before.
        self.path_column = next(
            (col for col in ('filename', 'path') if col in df.columns), 'filename'
        )

        # Get unique cat IDs and map them to row positions (for sampling).
        # GroupBy.indices yields positions, matching the iloc-based lookups
        # below even when df does not have a default 0..n-1 index.
        self.unique_cats = df['label'].unique()
        self.cat_to_indices = {
            cat_id: positions.tolist()
            for cat_id, positions in df.groupby('label').indices.items()
        }

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image at row position ``idx``.

        Returns:
            tuple: ``(img, label)``. ``img`` is the RGB image after the
            optional transform; a black ``IMG_SIZE`` x ``IMG_SIZE`` image is
            substituted when the file cannot be read.
        """
        row = self.df.iloc[idx]
        img_name = row[self.path_column]
        label = row['label']

        # Names like "train/092947/00.png" live under
        # <img_dir>/train/train/<cat_id>/<image>; anything else is joined as-is.
        parts = img_name.split('/')
        if img_name.startswith('train/') and len(parts) == 3:
            _, cat_id, img_file = parts
            img_path = os.path.join(self.img_dir, 'train', 'train', cat_id, img_file)
        else:
            img_path = os.path.join(self.img_dir, img_name)

        # Print path for debugging (first few rows only, to avoid flooding output)
        if idx < 5:
            print(f"Loading image from: {img_path}")
            print(f"File exists: {os.path.exists(img_path)}")

        # Load and process image
        img = cv2.imread(img_path)

        if img is None:
            print(f"Failed to load image at {img_path}")
            # Create a dummy black image as fallback
            img = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
        else:
            # OpenCV loads BGR; convert to RGB
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform:
            img = self.transform(image=img)['image']

        return img, label

    def get_label(self, idx):
        """Get the label for the row at position ``idx``."""
        return self.df.iloc[idx]['label']

    def get_pos_neg_items(self, idx):
        """Pick a positive (same cat) and a negative (different cat) row for an anchor.

        Args:
            idx (int): Row position of the anchor.

        Returns:
            tuple[int, int]: ``(pos_idx, neg_idx)`` row positions. ``pos_idx``
            equals ``idx`` when the anchor's cat has no other image.

        Raises:
            ValueError: If the dataset contains a single cat, so no negative
                sample exists.
        """
        anchor_label = self.get_label(idx)

        # Positive: another image of the same cat, or the anchor itself when
        # it is the cat's only image
        other_indices = [i for i in self.cat_to_indices[anchor_label] if i != idx]
        pos_idx = random.choice(other_indices) if other_indices else idx

        # Negative: a random image of a randomly chosen different cat
        other_cats = [cat for cat in self.unique_cats if cat != anchor_label]
        if not other_cats:
            raise ValueError("Cannot sample a negative: the dataset contains only one cat")
        neg_cat = random.choice(other_cats)
        neg_idx = random.choice(self.cat_to_indices[neg_cat])

        return pos_idx, neg_idx


# Let's test our dataset class with a more focused approach
def test_dataset_focused(df, img_dir, transform=None, num_samples=3):
    """Smoke-test ``CatFaceDataset`` by loading and plotting a few samples.

    Rows are tried in order from position 0 (at most 20, and never past the
    end of the dataset) until ``num_samples`` items load without raising. The
    loaded images are shown in a single row; if none load, the first few path
    values of ``df`` are printed to help debugging.

    Args:
        df (DataFrame): Frame passed to ``CatFaceDataset``.
        img_dir (str): Image root passed to ``CatFaceDataset``.
        transform (callable, optional): Transform passed to ``CatFaceDataset``.
        num_samples (int): Number of images to display.
    """
    dataset = CatFaceDataset(df, img_dir, transform)
    print(f"Dataset size: {len(dataset)}")
    print(f"Number of unique cats: {len(dataset.unique_cats)}")

    # Try sequential indices rather than random ones; never index past the end
    valid_images = []
    max_attempts = min(20, len(dataset))

    for idx in range(max_attempts):
        if len(valid_images) >= num_samples:
            break
        try:
            img, label = dataset[idx]

            # Tensors are accepted as-is; arrays must be non-empty
            if isinstance(img, torch.Tensor) or (img is not None and img.size > 0):
                valid_images.append((img, label, idx))
        except Exception as e:
            print(f"Error with index {idx}: {e}")

    # Display valid images
    if valid_images:
        plt.figure(figsize=(15, 5))
        for i, (img, label, idx) in enumerate(valid_images):
            if isinstance(img, torch.Tensor):
                # Convert tensor (C, H, W) to an (H, W, C) array for display
                img = img.permute(1, 2, 0).numpy()
                # Undo per-channel normalization with these std/mean values
                # NOTE(review): assumes the tensor was normalized with exactly
                # these statistics — confirm against the transform in use
                img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
                img = np.clip(img, 0, 1)

            plt.subplot(1, num_samples, i + 1)
            plt.imshow(img)
            plt.title(f"Cat ID: {label}\nIndex: {idx}")
            plt.axis('off')

        plt.tight_layout()
        plt.show()
    else:
        print("Could not find any valid images after multiple attempts.")

        # Debug: print a few path values. The path column is 'filename' in the
        # raw CSV but may have been renamed to 'path' by the configuration cell.
        print("\nSample filenames from DataFrame:")
        name_col = next((col for col in ('filename', 'path') if col in df.columns), None)
        if name_col is None:
            print(f"No 'filename' or 'path' column; columns are: {df.columns.tolist()}")
        else:
            for i, filename in enumerate(df[name_col].head(5)):
                print(f"{i}: {filename}")

# Smoke-test the dataset class on the training split
try:
    # Resize only, without normalization, so the displayed colours stay original
    test_transform = A.Compose([A.Resize(height=IMG_SIZE, width=IMG_SIZE)])

    print("Testing training dataset with fixed path construction:")
    test_dataset_focused(df=train_df, img_dir=DATA_DIR, transform=test_transform)
except Exception as e:
    # Report the failure with a full traceback instead of aborting the cell
    print(f"Error testing dataset: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# Cell 6.1: Image Transformations
from torchvision import transforms

# Define image transformations
# NOTE(review): both pipelines resize to 224x224 while IMG_SIZE in the
# configuration cell is 256 — confirm which resolution is intended.
# Per-channel statistics shared by both pipelines (presumably ImageNet's).
_norm_stats = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training: resize, then random flip / rotation / colour jitter, then tensor + normalize
train_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(**_norm_stats),
])

# Validation: no random augmentation, only resize, tensor conversion and normalize
val_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(**_norm_stats),
])
print("Image transformations defined!")

In [None]:
# Cell 6.2: Data Loaders

from torch.utils.data import DataLoader

# Create full data loaders
def create_data_loaders(train_df, val_df, img_dir, batch_size, num_workers):
    """Build the training and validation datasets and their DataLoaders.

    The datasets use the module-level ``train_transform`` and
    ``val_transform``. Only the training loader shuffles.

    Args:
        train_df (DataFrame): Rows of the training split.
        val_df (DataFrame): Rows of the validation split.
        img_dir (str): Image root handed to both datasets.
        batch_size (int): Batch size of both loaders.
        num_workers (int): Worker processes of both loaders.

    Returns:
        tuple: ``(train_loader, val_loader, train_dataset, val_dataset)``.
    """
    train_dataset = CatFaceDataset(df=train_df, img_dir=img_dir, transform=train_transform)
    val_dataset = CatFaceDataset(df=val_df, img_dir=img_dir, transform=val_transform)

    # Settings common to both loaders; they differ only in shuffling
    shared_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_kwargs)

    return train_loader, val_loader, train_dataset, val_dataset

# Build datasets and loaders for both splits from the configured constants
print("Creating data loaders...")
train_loader, val_loader, train_dataset, val_dataset = create_data_loaders(
    train_df=train_df,
    val_df=val_df,
    img_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# Report how many batches each loader yields per pass
print(f"Train loader: {len(train_loader)} batches of size {BATCH_SIZE}")
print(f"Val loader: {len(val_loader)} batches of size {BATCH_SIZE}")

In [None]:
# Cell 6.3: Fixed Model Architecture and Dataset Issues

import os
import torch.nn as nn
import torch.nn.functional as F
import timm
import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader

# Define the dataset class with proper handling of `path_column`
class CatFaceDataset(Dataset):
    """Image dataset that auto-detects the DataFrame column holding image paths.

    Loading is best-effort: a missing file or any error while loading yields
    a blank ``(3, 224, 224)`` tensor with label ``0`` instead of raising.
    """

    def __init__(self, df, img_dir, transform=None, is_train=True):
        """
        Args:
            df (DataFrame): Dataframe containing image paths and labels.
            img_dir (str): Root directory for image files.
            transform (callable, optional): Transform applied to each image.
                It receives a PIL RGB image, which is what torchvision
                pipelines containing ``ToTensor()`` expect.
            is_train (bool): Whether this is a training dataset or validation dataset.

        Raises:
            ValueError: If no path column can be identified in ``df``.
        """
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.is_train = is_train

        # Find the appropriate column for image paths
        path_columns = ['image_path', 'path', 'filepath', 'img_path', 'filename']
        self.path_column = None
        for col in path_columns:
            if col in df.columns:
                self.path_column = col
                break

        # If no column is found, use the first string column as a fallback
        if not self.path_column:
            for col in df.columns:
                if df[col].dtype == object:
                    self.path_column = col
                    break

        if not self.path_column:
            raise ValueError("No valid path column found in the dataframe!")

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return len(self.df)

    def _resolve_path(self, rel_path):
        """Return an existing file path for ``rel_path``, or ``None``.

        Tries ``<img_dir>/<rel_path>`` first, then the layout in which the
        split directory is nested twice (``<img_dir>/<split>/<split>/...``,
        e.g. ``train/train/<cat>/<image>``).
        """
        direct = os.path.join(self.img_dir, rel_path)
        if os.path.exists(direct):
            return direct

        if '/' in rel_path:
            split = rel_path.split('/', 1)[0]
            nested = os.path.join(self.img_dir, split, rel_path)
            if os.path.exists(nested):
                return nested

        return None

    def __getitem__(self, idx):
        """Load the image at row position ``idx``.

        Returns:
            tuple: ``(img, label)``. With a transform, ``img`` is whatever the
            transform returns for the PIL image; without one it is a float
            tensor of shape (3, H, W) scaled to [0, 1]. On a missing file or
            any loading error the result is ``(zeros(3, 224, 224), 0)``.
        """
        try:
            row = self.df.iloc[idx]

            # Get the label
            label = int(row['label'])

            # Get the relative image path
            rel_path = str(row[self.path_column])

            # Locate the file on disk
            full_path = self._resolve_path(rel_path)
            if full_path is None:
                print(f"Warning: File not found at {os.path.join(self.img_dir, rel_path)}")
                # Return a default blank tensor
                return torch.zeros((3, 224, 224)), 0

            # Read the image using OpenCV
            img = cv2.imread(full_path)

            # If the image is None (unreadable or corrupted), create a blank image
            if img is None:
                print(f"Warning: Could not read image at {full_path}")
                img = np.zeros((224, 224, 3), dtype=np.uint8)
            else:
                # Convert BGR to RGB
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            if self.transform:
                # Hand the transform a PIL image. Passing an already converted
                # tensor made ToTensor() raise TypeError, which the handler
                # below turned into a blank sample for every single item.
                from torchvision.transforms.functional import to_pil_image
                img = self.transform(to_pil_image(img))
            else:
                # No transform: (H, W, C) uint8 -> (C, H, W) float in [0, 1]
                img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0

            return img, label

        except Exception as e:
            print(f"Error loading image at index {idx}: {e}")
            # Return a default blank tensor
            return torch.zeros((3, 224, 224)), 0


# Define the embedding model architecture
class CatEmbeddingModel(nn.Module):
    """timm backbone followed by a linear projection to L2-normalized embeddings.

    Args:
        embedding_dim (int): Size of the output embedding vectors.
        backbone_name (str): Name of the timm model used as feature extractor.
        pretrained (bool): Whether timm should load pretrained backbone
            weights; pass ``False`` when they cannot be downloaded.
    """

    def __init__(self, embedding_dim=512, backbone_name='efficientnet_b0', pretrained=True):
        super().__init__()

        # num_classes=0 makes timm build the backbone without a classification
        # head, so it returns pooled features for any architecture, not only
        # those exposing a `.classifier` attribute.
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0)
        in_features = self.backbone.num_features

        # Add embedding projector to get desired embedding size
        self.embedding_projector = nn.Sequential(
            nn.Linear(in_features, embedding_dim),
            nn.BatchNorm1d(embedding_dim)
        )

    def forward(self, x):
        """Map a batch of images to unit-length embeddings.

        Args:
            x (Tensor): Batch of input images accepted by the backbone.

        Returns:
            Tensor: Embeddings with ``embedding_dim`` columns, L2-normalized
            along dim 1.
        """
        # Extract features from the backbone
        features = self.backbone(x)

        # Project to embedding space
        embeddings = self.embedding_projector(features)

        # L2 normalize the embeddings (important for cosine similarity)
        return F.normalize(embeddings, p=2, dim=1)


# Build the embedding model and move it to the selected device
print("Initializing model...")
model = CatEmbeddingModel(embedding_dim=EMB_DIM).to(device)
print(f"Model created with embedding dimension: {EMB_DIM}")

# Re-create both datasets with the dataset class defined in this cell
train_dataset = CatFaceDataset(df=train_df, img_dir=DATA_DIR, transform=train_transform, is_train=True)
val_dataset = CatFaceDataset(df=val_df, img_dir=DATA_DIR, transform=val_transform, is_train=False)

# Load the first item of each dataset as a quick sanity check
for _split_name, _split_dataset in (("Train", train_dataset), ("Validation", val_dataset)):
    try:
        img, label = _split_dataset[0]
        print(f"{_split_name} dataset: Successfully loaded image with shape {img.shape}, label {label}")
    except Exception as e:
        print(f"{_split_name} dataset test failed: {e}")

# Re-create the loaders on top of the new datasets; only training shuffles
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print(f"Train loader created with {len(train_loader)} batches.")
print(f"Validation loader created with {len(val_loader)} batches.")

In [None]:
# Cell 6.4: Optimizer and Loss

# Adam over all model parameters
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

# Triplet criterion; its forward takes (anchor, positive, negative) embeddings
triplet_loss = nn.TripletMarginLoss(margin=MARGIN)

# Halves the learning rate when the metric passed to scheduler.step(...)
# stops decreasing (mode='min') for longer than `patience` steps
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.5,
    patience=2,
    verbose=True,
)

print("Optimizer, loss, and scheduler initialized!")

In [None]:
# # Cell 7 (disabled draft — every line of this cell is commented out, so nothing here runs): Complete Training Loop with Triplet Loss

# import torch.nn.functional as F
# from torch.cuda.amp import autocast, GradScaler
# import time
# import faiss
# import matplotlib.pyplot as plt
# import copy

# # Define triplet mining function
# def get_hard_triplets(embeddings, labels, margin=0.2):
#     """Mine hard triplets from embeddings"""
#     # Move tensors to CPU for computation
#     embeddings = embeddings.detach().cpu()
#     labels = labels.detach().cpu()
    
#     # Calculate pairwise distances
#     distances = torch.cdist(embeddings, embeddings)
    
#     # Create masks for positive and negative pairs
#     same_label_mask = labels.unsqueeze(0) == labels.unsqueeze(1)
#     diff_label_mask = ~same_label_mask
    
#     # Exclude self comparisons
#     identity_mask = torch.eye(len(labels), dtype=bool)
#     valid_pos_mask = same_label_mask & ~identity_mask
    
#     # Lists for triplets
#     anchors = []
#     positives = []
#     negatives = []
    
#     # For each anchor
#     for anchor_idx in range(len(labels)):
#         # Skip if no positives
#         if not valid_pos_mask[anchor_idx].any():
#             continue
        
#         # Get hardest positive
#         pos_dists = distances[anchor_idx].clone()
#         pos_dists[~valid_pos_mask[anchor_idx]] = float('inf')
#         pos_idx = torch.argmin(pos_dists)
        
#         # Get hardest negative
#         neg_dists = distances[anchor_idx].clone()
#         neg_dists[same_label_mask[anchor_idx]] = float('inf')
#         neg_idx = torch.argmin(neg_dists)
        
#         # Append to lists
#         anchors.append(anchor_idx)
#         positives.append(pos_idx)
#         negatives.append(neg_idx)
    
#     return [torch.tensor(anchors), torch.tensor(positives), torch.tensor(negatives)]

# # Define triplet loss function
# def triplet_loss(embeddings, labels, triplet_indices):
#     """Compute triplet loss with margin"""
#     anchors, positives, negatives = triplet_indices
    
#     # No triplets case
#     if len(anchors) == 0:
#         return torch.tensor(0.0, device=embeddings.device, requires_grad=True)
    
#     # Get embeddings for each point in triplets
#     anchor_embs = embeddings[anchors]
#     positive_embs = embeddings[positives]
#     negative_embs = embeddings[negatives]
    
#     # Compute distances
#     pos_distances = F.pairwise_distance(anchor_embs, positive_embs)
#     neg_distances = F.pairwise_distance(anchor_embs, negative_embs)
    
#     # Apply margin and compute loss
#     losses = F.relu(pos_distances - neg_distances + MARGIN)
    
#     return losses.mean()

# # Training function
# def train_epoch(model, train_loader, optimizer, epoch, device, scaler):
#     """Train for one epoch"""
#     model.train()
#     running_loss = 0.0
#     processed_batches = 0
    
#     # Use tqdm for a progress bar
#     progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    
#     for batch_idx, (images, labels) in enumerate(progress_bar):
#         try:
#             # Move data to device
#             images = images.to(device)
#             labels = labels.to(device)
            
#             # Zero the gradients
#             optimizer.zero_grad()
            
#             # Forward pass with mixed precision
#             with autocast():
#                 # Get embeddings
#                 embeddings = model(images)
                
#                 # Get hard triplets
#                 triplet_indices = get_hard_triplets(embeddings, labels, margin=MARGIN)
                
#                 # If no triplets found, continue to next batch
#                 if len(triplet_indices[0]) == 0:
#                     progress_bar.set_postfix({'status': 'no triplets'})
#                     continue
                
#                 # Compute triplet loss
#                 loss = triplet_loss(embeddings, labels, triplet_indices)
            
#             # Backward pass with gradient scaling
#             scaler.scale(loss).backward()
#             scaler.step(optimizer)
#             scaler.update()
            
#             # Update running loss
#             running_loss += loss.item()
#             processed_batches += 1
            
#             # Update progress bar
#             progress_bar.set_postfix({
#                 'loss': f"{loss.item():.4f}",
#                 'avg_loss': f"{running_loss / processed_batches:.4f}",
#                 'triplets': f"{len(triplet_indices[0])}"
#             })
                
#         except Exception as e:
#             print(f"Error in batch {batch_idx}: {str(e)}")
#             continue
    
#     # Calculate average loss for the epoch
#     epoch_loss = running_loss / max(processed_batches, 1)
#     return epoch_loss


# def evaluate(model, val_loader, device):
#     """Evaluate the model on validation data"""
#     model.eval()
#     embeddings_list = []
#     labels_list = []
    
#     with torch.no_grad():
#         for images, labels in tqdm(val_loader, desc="Evaluating"):
#             # Move data to device
#             images = images.to(device)
            
#             # Forward pass
#             embeddings = model(images)
            
#             # Store embeddings and labels
#             embeddings_list.append(embeddings.cpu())
#             labels_list.append(labels)
    
#     # Concatenate all embeddings and labels
#     all_embeddings = torch.cat(embeddings_list, dim=0)
#     all_labels = torch.cat(labels_list, dim=0)
    
#     # Calculate metrics
#     return compute_retrieval_metrics(all_embeddings, all_labels)


# def compute_retrieval_metrics(embeddings, labels, k_values=[1, 5]):
#     """Compute Top-K retrieval metrics using FAISS"""
#     # Convert to numpy for FAISS
#     embeddings_np = embeddings.numpy()
#     labels_np = labels.numpy()
    
#     # Create FAISS index
#     d = embeddings_np.shape[1]  # Embedding dimension
#     index = faiss.IndexFlatL2(d)  # L2 distance for similarity
#     index.add(embeddings_np)
    
#     # Search for nearest neighbors
#     max_k = max(k_values)
#     _, indices = index.search(embeddings_np, max_k + 1)  # +1 because first result is self
    
#     # Calculate Top-K accuracy
#     metrics = {}
#     for k in k_values:
#         # For each query, check if any of the top-k retrieved items have the same label
#         # Start from 1 to exclude self
#         correct = 0
#         for i, idx_list in enumerate(indices):
#             query_label = labels_np[i]
#             retrieved_labels = labels_np[idx_list[1:k+1]]  # Exclude self
#             if query_label in retrieved_labels:
#                 correct += 1
        
#         accuracy = correct / len(labels_np)
#         metrics[f'top_{k}_accuracy'] = accuracy
    
#     return metrics


# # Main training loop
# def train_model(model, train_loader, val_loader, optimizer, scheduler, device, num_epochs, output_dir, 
#                 resume_path=None):
#     """Full training loop with validation and checkpointing"""
#     best_top1 = 0.0
#     best_model_wts = None
#     scaler = GradScaler()  # For mixed precision training
    
#     # Create output directory if it doesn't exist
#     os.makedirs(output_dir, exist_ok=True)
    
#     # For resuming training from checkpoint
#     start_epoch = 0
#     if resume_path and os.path.exists(resume_path):
#         print(f"Loading checkpoint from {resume_path}")
#         checkpoint = torch.load(resume_path)
#         model.load_state_dict(checkpoint['model_state_dict'])
#         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#         if 'scheduler_state_dict' in checkpoint:
#             scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
#         start_epoch = checkpoint['epoch']
#         best_top1 = checkpoint.get('best_top1', 0.0)
#         print(f"Resuming from epoch {start_epoch} with Top-1 accuracy: {best_top1:.4f}")
    
#     # Training history
#     history = {
#         'train_loss': [],
#         'val_top1': [],
#         'val_top5': [],
#         'lr': []
#     }
    
#     print(f"Starting training at {time.strftime('%Y-%m-%d %H:%M:%S')}")
#     print(f"Training for {num_epochs} epochs")
    
#     # Training loop
#     for epoch in range(start_epoch, num_epochs):
#         # Get current learning rate
#         current_lr = optimizer.param_groups[0]['lr']
#         print(f"\nEpoch {epoch+1}/{num_epochs} - Learning Rate: {current_lr:.6f}")
        
#         # Train for one epoch
#         epoch_loss = train_epoch(model, train_loader, optimizer, epoch, device, scaler)
#         print(f"Training loss: {epoch_loss:.4f}")
        
#         # Add to history
#         history['train_loss'].append(epoch_loss)
#         history['lr'].append(current_lr)
        
#         # Evaluate on validation data
#         metrics = evaluate(model, val_loader, device)
#         print(f"Validation metrics: {metrics}")
        
#         # Add to history
#         history['val_top1'].append(metrics['top_1_accuracy'])
#         history['val_top5'].append(metrics['top_5_accuracy'])
        
#         # Update learning rate based on training loss
#         scheduler.step(epoch_loss)
        
#         # Save regular checkpoint
#         checkpoint = {
#             'epoch': epoch + 1,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'scheduler_state_dict': scheduler.state_dict(),
#             'best_top1': best_top1,
#             'embedding_dim': EMB_DIM,
#             'history': history
#         }
#         torch.save(checkpoint, os.path.join(output_dir, f'checkpoint_epoch_{epoch+1}.pth'))
        
#         # Save best model
#         top1_accuracy = metrics['top_1_accuracy']
#         if top1_accuracy > best_top1:
#             best_top1 = top1_accuracy
#             best_model_wts = copy.deepcopy(model.state_dict())
#             torch.save(checkpoint, os.path.join(output_dir, 'best_model.pth'))
#             print(f"✅ Saved new best model with Top-1 accuracy: {best_top1:.4f}")
        
#         # Plot and save training progress after each epoch
#         if (epoch + 1) % 2 == 0:  # Every 2 epochs
#             plt.figure(figsize=(15, 5))
            
#             plt.subplot(1, 3, 1)
#             plt.plot(history['train_loss'])
#             plt.title('Training Loss')
#             plt.xlabel('Epoch')
#             plt.ylabel('Loss')
#             plt.grid(True)
            
#             plt.subplot(1, 3, 2)
#             plt.plot(history['val_top1'], label='Top-1')
#             plt.plot(history['val_top5'], label='Top-5')
#             plt.title('Validation Accuracy')
#             plt.xlabel('Epoch')
#             plt.ylabel('Accuracy')
#             plt.legend()
#             plt.grid(True)
            
#             plt.subplot(1, 3, 3)
#             plt.plot(history['lr'])
#             plt.title('Learning Rate')
#             plt.xlabel('Epoch')
#             plt.ylabel('LR')
#             plt.yscale('log')
#             plt.grid(True)
            
#             plt.tight_layout()
#             plt.savefig(os.path.join(output_dir, 'training_progress.png'))
#             plt.close()
    
#     # Load best model weights
#     if best_model_wts is not None:
#         model.load_state_dict(best_model_wts)
        
#     print(f"\nTraining complete! Best Top-1 accuracy: {best_top1:.4f}")
#     print(f"Best model saved to {os.path.join(output_dir, 'best_model.pth')}")
    
#     # Final history plot
#     plt.figure(figsize=(15, 5))
    
#     plt.subplot(1, 3, 1)
#     plt.plot(history['train_loss'])
#     plt.title('Training Loss')
#     plt.xlabel('Epoch')
#     plt.ylabel('Loss')
#     plt.grid(True)
    
#     plt.subplot(1, 3, 2)
#     plt.plot(history['val_top1'], label='Top-1')
#     plt.plot(history['val_top5'], label='Top-5')
#     plt.title('Validation Accuracy')
#     plt.xlabel('Epoch')
#     plt.ylabel('Accuracy')
#     plt.legend()
#     plt.grid(True)
    
#     plt.subplot(1, 3, 3)
#     plt.plot(history['lr'])
#     plt.title('Learning Rate')
#     plt.xlabel('Epoch')
#     plt.ylabel('LR')
#     plt.yscale('log')
#     plt.grid(True)
    
#     plt.tight_layout()
#     plt.savefig(os.path.join(output_dir, 'final_training_history.png'))
    
#     return model, history

# # Start training
# print("Starting training...")
# model, history = train_model(
#     model=model,
#     train_loader=train_loader,
#     val_loader=val_loader,
#     optimizer=optimizer,
#     scheduler=scheduler,
#     device=device,
#     num_epochs=EPOCHS,
#     output_dir=OUTPUT_DIR
# )

# print("Training complete! Final results:")
# print(f"Best Top-1 Accuracy: {max(history['val_top1']):.4f}")
# print(f"Best Top-5 Accuracy: {max(history['val_top5']):.4f}")
# print(f"Total training time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time))}")

In [None]:
# Cell 7: Training Loop and Evaluation

# Training function
def train_epoch(model, train_loader, optimizer, scheduler, epoch, device, scaler):
    """Train for one epoch using mixed precision and mined triplets.

    Args:
        model: Network called as ``model(images)`` to produce embeddings.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer that is zeroed each batch and stepped via ``scaler``.
        scheduler: Not used inside this function; kept so existing callers
            that pass it keep working.
        epoch: Zero-based epoch index, only used for the progress-bar label.
        device: Device the images and labels are moved to.
        scaler: Gradient scaler used to scale the loss and step the optimizer.

    Returns:
        Sum of the batch losses divided by the number of batches iterated.
        Batches skipped because no triplets were returned contribute zero to
        the sum but still count in the denominator. Returns 0.0 when the
        loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    # Counted while iterating so the average does not depend on
    # len(train_loader) and an empty loader cannot cause a division by zero.
    num_batches = 0

    # Use tqdm for a progress bar
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for num_batches, (images, labels) in enumerate(progress_bar, start=1):
        # Move data to device
        images = images.to(device)
        labels = labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass with mixed precision
        with autocast():
            embeddings = model(images)

            # Mine triplets for this batch (helper defined elsewhere in the notebook)
            triplet_indices = get_hard_triplets(embeddings, labels, margin=MARGIN)

            # Nothing to optimise in this batch: skip backward/step entirely
            if len(triplet_indices[0]) == 0:
                continue

            loss = triplet_loss(embeddings, labels, triplet_indices)

        # Backward pass with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once; every .item() call forces a device sync
        batch_loss = loss.item()
        running_loss += batch_loss

        progress_bar.set_postfix({
            'loss': f"{batch_loss:.4f}",
            'avg_loss': f"{running_loss / num_batches:.4f}"
        })

    if num_batches == 0:
        return 0.0

    # Average loss for the epoch (skipped batches count as zero loss)
    return running_loss / num_batches


def evaluate(model, val_loader, device):
    """Embed every validation batch and score the embeddings by retrieval.

    Puts the model in eval mode, runs it over ``val_loader`` without gradient
    tracking, gathers the embeddings (moved to CPU) together with the labels
    as yielded by the loader, and returns whatever
    ``compute_retrieval_metrics`` produces for the concatenated tensors.
    """
    model.eval()
    collected_embeddings, collected_labels = [], []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(val_loader, desc="Evaluating"):
            # Only the images go to the device; labels are kept as yielded
            batch_embeddings = model(batch_images.to(device))

            collected_embeddings.append(batch_embeddings.cpu())
            collected_labels.append(batch_labels)

    # Stitch the per-batch pieces together along the batch dimension and
    # hand them to the retrieval scorer (default k values)
    return compute_retrieval_metrics(
        torch.cat(collected_embeddings, dim=0),
        torch.cat(collected_labels, dim=0),
    )


def compute_retrieval_metrics(embeddings, labels, k_values=(1, 5)):
    """Compute Top-K retrieval accuracy over a set of embeddings using FAISS.

    Every embedding is used as a query against an exact L2 index built from
    the same embeddings. A query counts as correct for a given ``k`` when at
    least one of its ``k`` nearest neighbours, excluding the query itself,
    carries the same label.

    Args:
        embeddings: CPU tensor converted with ``.numpy()``; one row per sample.
        labels: CPU tensor converted with ``.numpy()``; assumed to be 1-D with
            one label per embedding row -- TODO confirm against the dataset.
        k_values: Iterable of ``k`` values to report.

    Returns:
        Dict mapping ``'top_{k}_accuracy'`` to a float in ``[0, 1]``. Every
        value is 0.0 when there are no samples; the dict is empty when
        ``k_values`` is empty.
    """
    k_values = list(k_values)

    # FAISS operates on C-contiguous float32 matrices
    embeddings_np = np.ascontiguousarray(embeddings.numpy(), dtype=np.float32)
    labels_np = labels.numpy()
    num_samples = len(labels_np)

    if num_samples == 0 or not k_values:
        return {f'top_{k}_accuracy': 0.0 for k in k_values}

    # Exact L2 index over all embeddings
    index = faiss.IndexFlatL2(embeddings_np.shape[1])
    index.add(embeddings_np)

    # Ask for one extra neighbour to leave room for the query itself, but never
    # for more results than there are vectors in the index.
    max_k = max(k_values)
    search_k = min(max_k + 1, num_samples)
    _, indices = index.search(embeddings_np, search_k)

    # FAISS marks missing results with -1; indexing labels with -1 would
    # silently read the last label, so those slots are masked out.
    found = indices >= 0
    # Drop the query by index rather than by position: with duplicate
    # embeddings the query is not guaranteed to be the first result.
    is_self = indices == np.arange(num_samples)[:, None]
    candidates = found & ~is_self

    # 1-based rank of each remaining neighbour within its row
    rank = np.cumsum(candidates, axis=1)
    same_label = labels_np[np.where(found, indices, 0)] == labels_np[:, None]

    metrics = {}
    for k in k_values:
        hits = (candidates & same_label & (rank <= k)).any(axis=1)
        metrics[f'top_{k}_accuracy'] = int(hits.sum()) / num_samples

    return metrics


# Main training loop
def train_model(model, train_loader, val_loader, optimizer, scheduler, device, num_epochs, output_dir):
    """Run the train/validate cycle for ``num_epochs`` and checkpoint the best model.

    Each epoch trains via ``train_epoch``, scores the model via ``evaluate``,
    steps the scheduler with the epoch's training loss, and writes
    ``best_model.pth`` into ``output_dir`` whenever the Top-1 accuracy
    strictly exceeds the best value seen so far (starting from 0.0).

    Returns:
        The same ``model`` object, holding the weights from the final epoch.
    """
    best_top1 = 0.0
    grad_scaler = GradScaler()  # shared across epochs for mixed precision

    # Make sure the checkpoint destination exists before training starts
    os.makedirs(output_dir, exist_ok=True)
    checkpoint_path = os.path.join(output_dir, 'best_model.pth')

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # --- train ---
        epoch_loss = train_epoch(
            model, train_loader, optimizer, scheduler, epoch, device, grad_scaler
        )
        print(f"Training loss: {epoch_loss:.4f}")

        # --- validate ---
        metrics = evaluate(model, val_loader, device)
        print(f"Validation metrics: {metrics}")

        # The scheduler is driven by the training loss of this epoch
        # (not by a validation metric)
        scheduler.step(epoch_loss)

        # --- checkpoint only on a strict Top-1 improvement ---
        top1_accuracy = metrics['top_1_accuracy']
        if not (top1_accuracy > best_top1):
            continue

        best_top1 = top1_accuracy
        torch.save(
            {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_top1': best_top1,
                'embedding_dim': EMB_DIM,
            },
            checkpoint_path,
        )
        print(f"Saved new best model with Top-1 accuracy: {best_top1:.4f}")

    return model

# Launch the full training run using the model, data loaders, optimizer and
# scheduler built in the earlier cells.
print("Starting training...")
train_model(
    model, train_loader, val_loader, optimizer, scheduler, device,
    num_epochs=EPOCHS,
    output_dir=OUTPUT_DIR,
)

print("Training complete!")