# Fine-tuning ConvNeXtV2 for Emoji Vendor Classification

Uses deterministic data augmentation during training and test-time augmentation (TTA) at inference to improve classification accuracy.

## Install and Import

In [None]:
# Install required packages
%pip install -q kagglehub transformers torch torchvision pillow datasets accelerate pandas scikit-learn

import kagglehub
import os
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import json
from transformers import AutoImageProcessor, AutoModelForImageClassification
from transformers.modeling_outputs import ImageClassifierOutput
import torch.nn as nn
from datasets import Dataset as HFDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import numpy as np
from tqdm import tqdm
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import pandas as pd
import random
import hashlib

# GPU Setup: choose the compute device once; later cells move the model and
# every batch onto `device`.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
print(f"Using device: {device}")

if not _cuda_ok:
    print("WARNING: CUDA not available. Training will be slow on CPU.")
else:
    _props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"GPU Memory: {_props.total_memory / 1024**3:.2f} GB")
    # Return memory held by PyTorch's caching allocator before loading anything.
    torch.cuda.empty_cache()

## Deterministic Augmentation System

In [None]:
# Unified Deterministic Augmentation System
# Same augmentation methods used for both training and TTA inference
# Deterministic based on image hash/index for reproducibility

import hashlib

class DeterministicAugmentation:
    """
    Unified deterministic augmentation system for both training and TTA.

    ``apply_training_augmentation`` (one augmented view per training sample)
    and ``get_augmentations`` (several views for test-time augmentation) are
    built from the same primitive transforms. Every random choice is derived
    from a per-sample seed, so the same image/index always yields the same
    output.

    Per-sample seeds are MD5-based and salted with ``seed``. The builtin
    ``hash()`` is salted per interpreter process for strings (PYTHONHASHSEED),
    so seeds built on it are not reproducible across runs. Random parameters
    are drawn from a local ``numpy.random.RandomState`` rather than by
    reseeding the global NumPy RNG, so other ``np.random`` users are not
    affected.

    Args:
        image_size: Side length in pixels of every square output image.
        seed: Salt mixed into every per-sample seed; changing it changes all
            augmentation choices.
    """

    def __init__(self, image_size=224, seed=42):
        self.image_size = image_size
        self.seed = seed

        # Augmentation parameters shared by training and TTA.
        self.rotation_angles = [-10, -5, 5, 10]  # degrees
        self.crop_ratios = [0.75, 0.85, 0.9, 0.95]  # fraction of the shorter side
        self.color_jitter_params = {
            'brightness': 0.3,
            'contrast': 0.3,
            'saturation': 0.3,
            'hue': 0.1,
        }
        self.translate_range = (0.1, 0.1)  # max shift as a fraction of (width, height)
        self.blur_sigma = (0.1, 0.5)  # (min, max) Gaussian sigma

    # ------------------------------------------------------------------ seeds

    def _stable_hash(self, payload):
        """Return a run-independent 32-bit int hash of ``payload`` (bytes), salted by ``self.seed``."""
        digest = hashlib.md5(f"{self.seed}:".encode("utf-8") + payload).hexdigest()
        return int(digest[:8], 16)

    def _get_deterministic_seed(self, image_or_hash):
        """Return a deterministic 32-bit seed for an image, hash string, or index.

        Strings and ints are hashed through their ``str()`` form (checked first,
        so no image library is touched for them); PIL images are hashed by pixel
        content; anything else falls back to its ``str()`` form.
        """
        if isinstance(image_or_hash, (str, int)):
            return self._stable_hash(str(image_or_hash).encode("utf-8"))
        if isinstance(image_or_hash, Image.Image):
            return self._stable_hash(image_or_hash.tobytes())
        return self._stable_hash(str(image_or_hash).encode("utf-8"))

    @staticmethod
    def _rng(seed_val):
        """Return a private RNG producing the stream ``np.random.seed(seed_val % 2**32)`` would."""
        return np.random.RandomState(seed_val % (2**32))

    # ------------------------------------------------------------- primitives

    def _resize(self, image):
        """Resize ``image`` to ``image_size`` x ``image_size`` with bilinear filtering."""
        return image.resize((self.image_size, self.image_size), Image.BILINEAR)

    def horizontal_flip(self, image, apply=True):
        """Mirror ``image`` left-right when ``apply`` is true; otherwise return it unchanged."""
        return F.hflip(image) if apply else image

    def rotation(self, image, angle):
        """Rotate ``image`` by ``angle`` degrees."""
        return F.rotate(image, angle)

    def center_crop(self, image, crop_ratio=0.9):
        """Square center crop whose side is ``crop_ratio`` times the shorter image side."""
        w, h = image.size
        crop_size = int(min(w, h) * crop_ratio)
        return F.center_crop(image, [crop_size, crop_size])

    def corner_crop(self, image, crop_ratio=0.9, position='tl'):
        """Square corner crop at 'tl', 'tr', 'bl' or 'br'; unknown positions return ``image`` unchanged."""
        w, h = image.size
        crop_size = int(min(w, h) * crop_ratio)

        if position == 'tl':  # Top-left
            return F.crop(image, 0, 0, crop_size, crop_size)
        elif position == 'tr':  # Top-right
            return F.crop(image, 0, w - crop_size, crop_size, crop_size)
        elif position == 'bl':  # Bottom-left
            return F.crop(image, h - crop_size, 0, crop_size, crop_size)
        elif position == 'br':  # Bottom-right
            return F.crop(image, h - crop_size, w - crop_size, crop_size, crop_size)
        return image

    def resized_crop(self, image, crop_ratio=0.85):
        """Deterministic stand-in for RandomResizedCrop: center crop, then resize to the output size."""
        return self._resize(self.center_crop(image, crop_ratio=crop_ratio))

    def _jitter_factors(self, seed_val):
        """Return (brightness, contrast, saturation, hue) factors drawn deterministically from ``seed_val``."""
        rng = self._rng(seed_val)
        params = self.color_jitter_params
        # Draw order matters for reproducibility: brightness, contrast, saturation, hue.
        brightness = 1.0 + rng.uniform(-params['brightness'], params['brightness'])
        contrast = 1.0 + rng.uniform(-params['contrast'], params['contrast'])
        saturation = 1.0 + rng.uniform(-params['saturation'], params['saturation'])
        hue = rng.uniform(-params['hue'], params['hue'])
        return brightness, contrast, saturation, hue

    def color_jitter(self, image, seed_val):
        """Apply brightness/contrast/saturation/hue jitter chosen deterministically from ``seed_val``."""
        brightness, contrast, saturation, hue = self._jitter_factors(seed_val)
        img = F.adjust_brightness(image, brightness)
        img = F.adjust_contrast(img, contrast)
        img = F.adjust_saturation(img, saturation)
        img = F.adjust_hue(img, hue)
        return img

    def affine_transform(self, image, seed_val):
        """Translate ``image`` by a shift (within ``translate_range``) chosen from ``seed_val``."""
        rng = self._rng(seed_val)
        translate_x = rng.uniform(-self.translate_range[0], self.translate_range[0])
        translate_y = rng.uniform(-self.translate_range[1], self.translate_range[1])
        return F.affine(image, angle=0, translate=(translate_x * image.width, translate_y * image.height),
                        scale=1.0, shear=0.0)

    def gaussian_blur(self, image, seed_val):
        """3x3 Gaussian blur with sigma (within ``blur_sigma``) chosen from ``seed_val``."""
        sigma = self._rng(seed_val).uniform(self.blur_sigma[0], self.blur_sigma[1])
        return F.gaussian_blur(image, kernel_size=3, sigma=[sigma, sigma])

    # --------------------------------------------------------------- pipelines

    def get_augmentations(self, image, num_augmentations=10, seed_source=None):
        """
        Generate up to ``num_augmentations`` deterministic views of ``image`` for TTA.

        Views are produced in a fixed order and truncated to the requested count:
        original, horizontal flip, 4 rotations, 4 corner crops, center crop,
        resized crop, color jitter, affine shift, Gaussian blur (15 at most).
        Every view is ``image_size`` x ``image_size``.

        Args:
            image: PIL Image.
            num_augmentations: Number of views wanted; values <= 0 yield [].
            seed_source: Seed for the random views (jitter/affine/blur). An int
                is used as-is; any other value (e.g. a hash string) is hashed
                first; None hashes the image content.

        Returns:
            List of PIL Images, at most ``num_augmentations`` long.
        """
        if num_augmentations <= 0:
            return []

        if seed_source is None:
            seed_val = self._get_deterministic_seed(image)
        elif isinstance(seed_source, int):
            seed_val = seed_source
        else:
            # Non-int sources must become ints before the `% 2**32` / `+ 1` below.
            seed_val = self._get_deterministic_seed(seed_source)

        augmentations = [
            self._resize(image),
            self._resize(self.horizontal_flip(image, apply=True)),
        ]

        def remaining():
            # Clamp at 0 so slices below never see a negative (wrap-around) bound.
            return max(0, num_augmentations - len(augmentations))

        for angle in self.rotation_angles[:min(4, remaining())]:
            augmentations.append(self._resize(self.rotation(image, angle)))

        for corner in ['tl', 'tr', 'bl', 'br'][:min(4, remaining())]:
            augmentations.append(self._resize(self.corner_crop(image, crop_ratio=0.9, position=corner)))

        if remaining():
            augmentations.append(self._resize(self.center_crop(image, crop_ratio=0.9)))
        if remaining():
            augmentations.append(self.resized_crop(image, crop_ratio=0.85))
        if remaining():
            augmentations.append(self._resize(self.color_jitter(image, seed_val)))
        if remaining():
            augmentations.append(self._resize(self.affine_transform(image, seed_val + 1)))
        if remaining():
            augmentations.append(self._resize(self.gaussian_blur(image, seed_val + 2)))

        return augmentations[:num_augmentations]

    def apply_training_augmentation(self, image, index=None):
        """
        Apply one deterministic training augmentation to ``image``.

        Pipeline (every choice derived from the per-sample seed): horizontal
        flip when the seed is even, rotation by one of ``rotation_angles``,
        square center crop at one of ``crop_ratios``, color jitter, affine
        shift on roughly half the seeds, Gaussian blur on roughly a fifth,
        then resize to ``image_size`` x ``image_size``.

        NOTE: when ``index`` is given the seed depends only on it, so a given
        sample receives the same augmentation in every epoch.

        Args:
            image: PIL Image.
            index: Optional sample index used as the seed source; when None the
                image content is hashed instead.

        Returns:
            The augmented PIL Image.
        """
        seed_val = self._get_deterministic_seed(image if index is None else index)

        if seed_val % 2 == 0:
            image = self.horizontal_flip(image, apply=True)

        angle = self.rotation_angles[(seed_val // 2) % len(self.rotation_angles)]
        image = self.rotation(image, angle)

        crop_ratio = self.crop_ratios[(seed_val // 10) % len(self.crop_ratios)]
        image = self.center_crop(image, crop_ratio=crop_ratio)

        image = self.color_jitter(image, seed_val)

        if (seed_val // 3) % 2 == 0:
            image = self.affine_transform(image, seed_val + 1)

        if (seed_val // 5) % 5 == 0:
            image = self.gaussian_blur(image, seed_val + 2)

        return self._resize(image)

# Shared augmentation instance: EmojiDataset falls back to this global for
# training-time augmentation, and the TTA cell aliases it as `tta_aug`.
augmentation_system = DeterministicAugmentation(image_size=224, seed=42)

# Legacy class names, kept so cells that still reference them keep working.
TTAAugmentation = EnhancedTTAAugmentation = DeterministicAugmentation

# No torchvision pipeline for training: images go through
# augmentation_system.apply_training_augmentation instead.
train_transform = None

print(
    "Unified deterministic augmentation system defined!",
    "Same augmentation methods used for both training and TTA inference.",
    sep="\n",
)


## Load ConvNeXtV2 Model

In [None]:
# Load the pretrained ConvNeXtV2 checkpoint together with its image processor.
from transformers import AutoImageProcessor, AutoModelForImageClassification

_checkpoint = "facebook/convnextv2-tiny-22k-224"
processor = AutoImageProcessor.from_pretrained(_checkpoint)
# nn.Module.to() returns the module itself, so load-and-move can be chained.
base_model = AutoModelForImageClassification.from_pretrained(_checkpoint).to(device)

print(f"ConvNeXtV2 model loaded and moved to {device}")
print(f"Model config: {base_model.config}")

## Define Vendor Classes

In [None]:
# Vendor labels; a vendor's integer class id is its position in this list.
VENDOR_CLASSES = [
    "Apple", "DoCoMo", "Facebook", "Gmail", "Google", "JoyPixels",
    "KDDI", "Samsung", "SoftBank", "Twitter", "Windows",
]

# Bidirectional lookups between class id and vendor name.
IDX_TO_VENDOR = dict(enumerate(VENDOR_CLASSES))
VENDOR_TO_IDX = {name: i for i, name in IDX_TO_VENDOR.items()}

print(f"Number of vendor classes: {len(VENDOR_CLASSES)}")
print("Vendor classes:", VENDOR_CLASSES)

## Dataset Class

In [None]:
class EmojiDataset(Dataset):
    """Map-style dataset of (image, vendor label) pairs with optional deterministic augmentation.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of integer class ids aligned with ``image_paths``.
        processor: Callable invoked as ``processor(image, return_tensors="pt")``;
            it must return a mapping whose ``'pixel_values'`` tensor has a
            leading batch dimension, which is squeezed away.
        use_augmentation: If True, apply training-time augmentation per sample.
        augmentation_system: Object exposing
            ``apply_training_augmentation(image, index=...)``. When None, the
            module-level ``augmentation_system`` (if any) is used.
    """

    def __init__(self, image_paths, labels, processor, use_augmentation=False, augmentation_system=None):
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor
        self.use_augmentation = use_augmentation
        if augmentation_system is None:
            augmentation_system = globals().get('augmentation_system', None)
        self.augmentation_system = augmentation_system

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, optionally augment, and preprocess sample ``idx``.

        Returns:
            dict with ``'pixel_values'`` (processor output, batch dim squeezed)
            and ``'labels'`` (0-d ``torch.long`` tensor).
        """
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            # The context manager guarantees the underlying file handle is closed.
            with Image.open(image_path) as raw:
                image = raw.convert('RGB')
        except Exception as e:
            # Deliberate best-effort: an unreadable file becomes a blank white
            # image so one bad sample does not abort the whole epoch.
            print(f"Error loading image {image_path}: {e}")
            image = Image.new('RGB', (224, 224), color='white')

        # Use the instance's augmentation system so an explicitly passed one is honored.
        if self.use_augmentation and self.augmentation_system is not None:
            image = self.augmentation_system.apply_training_augmentation(image, index=idx)

        inputs = self.processor(image, return_tensors="pt")
        pixel_values = inputs['pixel_values'].squeeze(0)

        return {
            'pixel_values': pixel_values,
            'labels': torch.tensor(label, dtype=torch.long),
        }

## Model Definition

In [None]:
class ConvNeXtV2ForEmojiClassification(nn.Module):
    """ConvNeXtV2 backbone with a small MLP head for vendor classification.

    Wraps a Hugging Face image-classification model (the module-level
    ``base_model`` by default) and classifies from the pooled features of its
    ``convnextv2`` backbone sub-module. The wrapped model's own classification
    head is not used.

    Args:
        num_labels: Number of output classes (defaults to ``len(VENDOR_CLASSES)``).
        hidden_size: Width of the pooled feature vector. When None it is read
            from the backbone config (``hidden_sizes[-1]``, else
            ``hidden_size``), falling back to 768.
        backbone: Model to wrap; defaults to the module-level ``base_model``.
            It must expose a ``convnextv2`` sub-module and a ``config``.
    """

    def __init__(self, num_labels=len(VENDOR_CLASSES), hidden_size=None, backbone=None):
        super().__init__()
        self.base_model = base_model if backbone is None else backbone
        self.num_labels = num_labels

        if hidden_size is None:
            config = self.base_model.config
            # ConvNeXtV2 configs list per-stage widths; the last stage feeds the head.
            if hasattr(config, 'hidden_sizes'):
                hidden_size = config.hidden_sizes[-1]
            elif hasattr(config, 'hidden_size'):
                hidden_size = config.hidden_size
            else:
                hidden_size = 768  # default for the tiny variant

        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.LayerNorm(hidden_size // 2),
            nn.Dropout(0.2),
            nn.Linear(hidden_size // 2, num_labels),
        )

    @staticmethod
    def _pool(backbone_out):
        """Reduce the backbone output to a (batch, channels) feature matrix.

        The backbone returns a model-output object rather than a tensor.
        ``pooler_output`` is preferred; otherwise ``last_hidden_state`` (or the
        first tuple element) is globally average-pooled when it is 4-D.
        """
        if torch.is_tensor(backbone_out):
            features = backbone_out
        elif getattr(backbone_out, 'pooler_output', None) is not None:
            return backbone_out.pooler_output.flatten(1)
        elif getattr(backbone_out, 'last_hidden_state', None) is not None:
            features = backbone_out.last_hidden_state
        else:
            features = backbone_out[0]  # tuple output (return_dict=False)

        if features.dim() == 4:
            # Global average pooling: (B, C, H, W) -> (B, C)
            return features.mean(dim=[2, 3])
        return features

    def forward(self, pixel_values, labels=None):
        """Classify a batch of preprocessed images.

        Args:
            pixel_values: Image batch as produced by the image processor.
            labels: Optional class ids; when given, a label-smoothed (0.1)
                cross-entropy loss is computed.

        Returns:
            ``ImageClassifierOutput`` with ``logits`` and ``loss`` (None when
            ``labels`` is None).
        """
        backbone_out = self.base_model.convnextv2(pixel_values=pixel_values)
        pooled_output = self._pool(backbone_out)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(label_smoothing=0.1)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return ImageClassifierOutput(loss=loss, logits=logits)

## Training Functions

In [None]:
def train_epoch(model, train_loader, optimizer, device, scaler=None):
    """Train ``model`` for one pass over ``train_loader``.

    Each batch is a dict with ``'pixel_values'`` and ``'labels'`` tensors. The
    model is called as ``model(pixel_values=..., labels=...)`` and must return
    an object exposing ``.loss`` and ``.logits``.

    Mixed precision is used on CUDA only:
    - bfloat16 autocast when the GPU supports it;
    - float16 autocast only when a ``scaler`` is supplied, since float16
      gradients can underflow without loss scaling;
    - plain float32 otherwise.

    The bare ``torch.cuda.amp.autocast(enabled=...)`` always meant float16,
    even without a scaler.

    Args:
        model: Model to optimize; switched to training mode.
        train_loader: Iterable of batches.
        optimizer: Optimizer over ``model``'s parameters.
        device: ``torch.device`` or device string to run on.
        scaler: Optional ``GradScaler`` used for loss scaling.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``; both 0.0 if the loader is empty.
    """
    import contextlib

    device = torch.device(device)
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    on_cuda = device.type == 'cuda'
    if on_cuda and torch.cuda.is_bf16_supported():
        amp_dtype = torch.bfloat16
    elif on_cuda and scaler is not None:
        amp_dtype = torch.float16
    else:
        amp_dtype = None  # full precision

    for batch in tqdm(train_loader, desc="Training"):
        pixel_values = batch['pixel_values'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)
        optimizer.zero_grad()

        autocast_ctx = (
            torch.autocast(device_type='cuda', dtype=amp_dtype)
            if amp_dtype is not None
            else contextlib.nullcontext()
        )
        with autocast_ctx:
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = 100 * correct / total if total else 0.0
    return mean_loss, accuracy

def validate(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Each batch is a dict with ``'pixel_values'`` and ``'labels'`` tensors. The
    model is called as ``model(pixel_values=..., labels=...)`` and must return
    an object exposing ``.loss`` and ``.logits``.

    On CUDA the forward pass runs under autocast: bfloat16 when the GPU
    supports it, otherwise float16. There is no backward pass here, so
    float16 needs no loss scaling.

    Args:
        model: Model to evaluate; switched to eval mode.
        val_loader: Iterable of batches.
        device: ``torch.device`` or device string to run on.

    Returns:
        ``(mean_batch_loss, accuracy_percent, all_preds, all_labels)``. The last
        two are per-sample lists of predicted and true class ids. Loss and
        accuracy are 0.0 if the loader is empty.
    """
    import contextlib

    device = torch.device(device)
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    all_preds = []
    all_labels = []

    amp_dtype = None
    if device.type == 'cuda':
        amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            pixel_values = batch['pixel_values'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)

            autocast_ctx = (
                torch.autocast(device_type='cuda', dtype=amp_dtype)
                if amp_dtype is not None
                else contextlib.nullcontext()
            )
            with autocast_ctx:
                outputs = model(pixel_values=pixel_values, labels=labels)
                loss = outputs.loss

            total_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs.logits, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = 100 * correct / total if total else 0.0
    return mean_loss, accuracy, all_preds, all_labels

## TTA Inference Functions (for final tests only)

In [None]:
def predict_with_weighted_tta(model, image, processor, tta_aug, num_augmentations=10, device='cuda'):
    """Classify ``image`` by averaging logits over deterministic TTA views.

    Each view from ``tta_aug.get_augmentations`` is preprocessed and scored
    separately. The first view gets weight 2.0 and every other view 1.0; the
    weighted mean of the logits is then softmaxed.

    Returns:
        ``(averaged_logits, predicted_class, probabilities)``.
    """
    model.eval()
    views = tta_aug.get_augmentations(image, num_augmentations=num_augmentations)

    view_logits = []
    with torch.no_grad():
        for view in views:
            batch = processor(view, return_tensors="pt")['pixel_values'].to(device)
            view_logits.append(model(pixel_values=batch).logits)

    # The first view (plain resize) counts double in the average.
    view_weights = [2.0 if k == 0 else 1.0 for k in range(len(view_logits))]
    weight_tensor = torch.tensor(view_weights, device=device).view(-1, 1, 1)

    stacked = torch.stack(view_logits)
    averaged_logits = (stacked * weight_tensor).sum(dim=0) / weight_tensor.sum()
    probabilities = averaged_logits.softmax(dim=-1)
    predicted_class = averaged_logits.argmax(dim=-1)
    return averaged_logits, predicted_class, probabilities

def predict_with_tta(model, image, processor, tta_aug, num_augmentations=10, device='cuda'):
    """Compatibility alias that delegates to ``predict_with_weighted_tta``.

    Returns the same ``(averaged_logits, predicted_class, probabilities)`` triple.
    """
    return predict_with_weighted_tta(
        model,
        image,
        processor,
        tta_aug,
        num_augmentations=num_augmentations,
        device=device,
    )

# TTA reuses the training augmentation object, so both share one set of transforms.
tta_aug = augmentation_system
print(
    "TTA inference functions defined!",
    "Using unified deterministic augmentation system for TTA (same as training).",
    sep="\n",
)

## Dataset Preparation Functions

In [None]:
def prepare_dataset(dataset_path, vendor_classes=None):
    """Collect image paths and integer vendor labels under ``dataset_path``.

    Two layouts are supported:

    1. One sub-directory per vendor (``<root>/<Vendor>/*.png``). The label is the
       vendor's index in ``vendor_classes``.
    2. Fallback, used only when layout 1 yields no images. Every image found
       recursively is labelled with the first vendor whose lower-cased name
       occurs in the file name or in the parent directory path.

    Extensions (.png/.jpg/.jpeg) are matched case-insensitively and each file
    is listed once. Separate ``*.png`` / ``*.PNG`` globs list every file twice on
    case-insensitive filesystems, which leaks duplicates across a train/test
    split. Results are sorted, so the order — and any seeded split built from
    it — does not depend on filesystem listing order.

    Args:
        dataset_path: Root directory (str or Path).
        vendor_classes: Ordered vendor names; defaults to ``VENDOR_CLASSES``.

    Returns:
        ``(image_paths, labels)``: parallel lists of path strings and int labels.
    """
    if vendor_classes is None:
        vendor_classes = VENDOR_CLASSES
    vendor_to_idx = {vendor: idx for idx, vendor in enumerate(vendor_classes)}
    image_extensions = {'.png', '.jpg', '.jpeg'}
    dataset_path = Path(dataset_path)

    def _is_image(path):
        return path.is_file() and path.suffix.lower() in image_extensions

    image_paths = []
    labels = []

    for vendor in vendor_classes:
        vendor_dir = dataset_path / vendor
        if vendor_dir.is_dir():
            for img_path in sorted(p for p in vendor_dir.iterdir() if _is_image(p)):
                image_paths.append(str(img_path))
                labels.append(vendor_to_idx[vendor])

    if not image_paths:
        for img_path in sorted(p for p in dataset_path.rglob('*') if _is_image(p)):
            filename = img_path.name.lower()
            parent = str(img_path.parent).lower()
            for vendor in vendor_classes:
                key = vendor.lower()
                if key in filename or key in parent:
                    image_paths.append(str(img_path))
                    labels.append(vendor_to_idx[vendor])
                    break

    return image_paths, labels

def prepare_dataset_from_csv(train_dir, csv_path, vendor_classes=None):
    """Load image paths and vendor labels listed in a CSV file.

    The CSV needs ``Id`` and ``Label`` columns. For each row, the image
    ``<Id zero-padded to 5 digits><ext>`` is looked up in ``train_dir``
    (.png/.jpg/.jpeg, lower or upper case).

    CSV labels are mapped to vendor indices case-insensitively, trying in
    order:
    1. the explicit aliases (messenger/whatsapp -> Facebook, mozilla -> Google);
    2. an exact name match;
    3. a substring match in either direction.

    Rows whose label is missing, blank or unmappable are skipped. Missing or
    blank labels are normalized to '' up front; they used to crash the
    substring match with a TypeError on NaN.

    Args:
        train_dir: Directory containing the images.
        csv_path: Path to the CSV file.
        vendor_classes: Ordered vendor names; defaults to ``VENDOR_CLASSES``.

    Returns:
        ``(image_paths, labels)``: parallel lists of path strings and int labels.
        Both are empty if the CSV or the directory is missing.
    """
    if vendor_classes is None:
        vendor_classes = VENDOR_CLASSES
    image_paths = []
    labels = []
    train_dir = Path(train_dir)
    csv_path = Path(csv_path)

    if not csv_path.exists() or not train_dir.exists():
        print(f"WARNING: CSV or train directory not found")
        return image_paths, labels

    df = pd.read_csv(csv_path)
    lowered = [vendor.lower() for vendor in vendor_classes]
    explicit_mapping = {'messenger': 'facebook', 'whatsapp': 'facebook', 'mozilla': 'google'}

    def _normalize(raw):
        # NaN/None and blank labels become '' and are treated as unmappable.
        if pd.isna(raw):
            return ''
        return str(raw).strip().lower()

    def _match(label):
        # Returns the vendor index for a normalized label, or None.
        if not label:
            return None  # an empty string would otherwise substring-match every vendor
        alias = explicit_mapping.get(label)
        if alias is not None and alias in lowered:
            return lowered.index(alias)
        if label in lowered:
            return lowered.index(label)
        for idx, name in enumerate(lowered):
            if label in name or name in label:
                return idx
        return None

    label_mapping = {}
    for csv_label in dict.fromkeys(_normalize(v) for v in df['Label']):
        idx = _match(csv_label)
        if idx is not None:
            label_mapping[csv_label] = idx
        elif csv_label:
            print(f"WARNING: Label '{csv_label}' not found, skipping")

    extensions = ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG']
    skipped_count = 0
    # zip over columns (not iterrows) keeps each column's own dtype, so an
    # integer Id is never upcast to float ("1.0") by a mixed-type row.
    for raw_id, raw_label in zip(df['Id'], df['Label']):
        image_id = str(raw_id).zfill(5)
        label_idx = label_mapping.get(_normalize(raw_label))
        if label_idx is None:
            skipped_count += 1
            continue

        candidates = (train_dir / f"{image_id}{ext}" for ext in extensions)
        image_path = next((p for p in candidates if p.exists()), None)
        if image_path is None:
            print(f"WARNING: Image not found for ID {image_id}")
            continue

        image_paths.append(str(image_path))
        labels.append(label_idx)

    if skipped_count > 0:
        print(f"WARNING: Skipped {skipped_count} images due to unmapped labels")

    print(f"Loaded {len(image_paths)} images with {len(set(labels))} unique classes")
    if labels:
        label_counts = np.bincount(labels)
        print(f"Label distribution: {label_counts}")

    return image_paths, labels

## Phase 1: First Dataset - Download and Prepare

In [None]:
# Phase 1: download the first Kaggle dataset and label its images by vendor.
path = kagglehub.dataset_download("subinium/emojiimage-dataset")
print(f"Path to dataset files: {path}")

image_paths, labels = prepare_dataset(path)

print(f"\nFound {len(image_paths)} images")
if labels:
    print(f"Labels distribution: {np.bincount(labels)}")

## Phase 1: Split First Dataset and Create Data Loaders

In [None]:
# Hold out a stratified 20% of the first dataset (fixed seed) for evaluation,
# then wrap both splits in DataLoaders.
if not image_paths:
    print("ERROR: No images found.")
else:
    train_paths, test_paths, train_labels, test_labels = train_test_split(
        image_paths, labels, test_size=0.2, random_state=42, stratify=labels
    )

    train_dataset = EmojiDataset(train_paths, train_labels, processor, use_augmentation=True)
    test_dataset = EmojiDataset(test_paths, test_labels, processor, use_augmentation=False)

    on_gpu = torch.cuda.is_available()
    batch_size = 16 if on_gpu else 8

    # Settings shared by both loaders; only shuffling differs between them.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=4 if on_gpu else 2,
        pin_memory=on_gpu,
        persistent_workers=on_gpu,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    print(f"Train samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}")
    print(f"Batch size: {batch_size}")
    print("Training augmentation: ENABLED")

## Phase 1: Training Setup

In [None]:
# NOTE(review): this trains `base_model` (the checkpoint loaded via
# AutoModelForImageClassification) directly, using its own classification head.
# The ConvNeXtV2ForEmojiClassification wrapper defined earlier, with its
# len(VENDOR_CLASSES)-way head, is not used here. Confirm this is intended —
# e.g. that the checkpoint's head is appropriate for the vendor labels being trained.
classification_model = base_model

In [None]:
# Training parameters
num_epochs = 15
learning_rate = 1e-5

optimizer = torch.optim.AdamW(
    classification_model.parameters(), lr=learning_rate, weight_decay=0.01
)

from torch.optim.lr_scheduler import ReduceLROnPlateau

# mode='max': the training loop steps this scheduler with test accuracy, so the
# LR is halved once accuracy has not improved for `patience` epochs.
scheduler = ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=3, verbose=True, min_lr=1e-7, cooldown=2
)

# Mixed-precision setup.
# train_epoch runs its forward pass under CUDA autocast, which may be float16.
# Float16 without loss scaling lets small gradients underflow to zero. The
# previous dtype check only ever chose bfloat16 or float32, so its float16
# branch (the only one creating a GradScaler) was unreachable. Create a scaler
# whenever CUDA is in use; it is harmless if training actually runs in bfloat16.
scaler = None
if torch.cuda.is_available():
    model_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    scaler = torch.cuda.amp.GradScaler()
    print("Mixed precision training: Enabled (autocast with GradScaler)")
else:
    print("Mixed precision training: Disabled (CPU)")

print("Training setup complete!")

## Phase 1: Train on First Dataset

In [None]:
# Phase 1 training loop: train on the train split and evaluate on the held-out
# split after every epoch. The best weights are kept in 'best_model_phase1.pt',
# and training stops after `early_stopping_patience` epochs without improvement.
# NOTE: the held-out split also drives LR scheduling and checkpoint selection,
# so the "test" accuracy printed here is really a validation score.
if len(image_paths) > 0:
    print("Starting training on first dataset...")
    best_val_acc = 0
    early_stopping_patience = 5
    early_stopping_counter = 0
    best_epoch = 0
    # Predictions of the best epoch. The final report must describe the saved
    # checkpoint, not the last (possibly worse) epoch before early stopping.
    best_val_preds, best_val_labels = None, None

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 50)

        train_loss, train_acc = train_epoch(classification_model, train_loader, optimizer, device, scaler)
        val_loss, val_acc, val_preds, val_labels = validate(classification_model, test_loader, device)

        scheduler.step(val_acc)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Test Loss: {val_loss:.4f}, Test Acc: {val_acc:.2f}%")
        # Log the LR explicitly instead of relying on the scheduler's deprecated
        # `verbose` output.
        print(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")

        if torch.cuda.is_available():
            print(f"GPU Memory: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB / {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch + 1
            early_stopping_counter = 0
            best_val_preds, best_val_labels = val_preds, val_labels
            torch.save(classification_model.state_dict(), 'best_model_phase1.pt')
            print(f"✓ Saved best model with test accuracy: {best_val_acc:.2f}%")
        else:
            early_stopping_counter += 1
            print(f"No improvement. Patience: {early_stopping_counter}/{early_stopping_patience}")

        if early_stopping_counter >= early_stopping_patience:
            print("\nEarly stopping triggered!")
            break

    print(f"\nPhase 1 training completed! Best test accuracy: {best_val_acc:.2f}% at epoch {best_epoch}")

    if best_val_preds is None:
        # No epoch beat 0% accuracy, so nothing was saved; fall back to the last epoch.
        best_val_preds, best_val_labels = val_preds, val_labels

    # Print detailed scores for the best (saved) epoch
    print("\n" + "="*50)
    print("Phase 1 Final Scores")
    print("="*50)
    print(f"Best Test Accuracy: {best_val_acc:.2f}%")
    print(f"\nClassification Report (epoch {best_epoch}):")
    # Pin the label set so target_names always lines up, even if a class is
    # missing from both the labels and the predictions.
    print(classification_report(
        best_val_labels,
        best_val_preds,
        labels=list(range(len(VENDOR_CLASSES))),
        target_names=VENDOR_CLASSES,
        zero_division=0,
    ))
else:
    print("ERROR: Cannot train without data.")

## Phase 1: Final Training on Combined Train + Test Splits

In [None]:
# Phase 1 final pass: continue training the best Phase 1 checkpoint on the train
# and held-out splits together so every labelled image is used. The held-out
# split was already used for checkpoint selection above and is not evaluated
# again. With no held-out data left, the scheduler follows training accuracy and
# the final weights are saved unconditionally.
# `len(image_paths) > 0` is checked first so that an empty dataset (where the
# split cell never defined `test_paths`) does not raise a NameError.
if len(image_paths) > 0 and len(test_paths) > 0:
    print("\n" + "="*50)
    print("Final Training on Combined Dataset (Train + Test)")
    print("="*50)

    # Load best model from phase 1
    if os.path.exists('best_model_phase1.pt'):
        classification_model.load_state_dict(torch.load('best_model_phase1.pt', map_location=device))
        print("Loaded best model from phase 1")

    # list(...) guarantees `+` concatenates; if the splits were NumPy arrays,
    # `+` would add them element-wise instead.
    combined_train_paths = list(train_paths) + list(test_paths)
    combined_train_labels = list(train_labels) + list(test_labels)

    print(f"Combined dataset: {len(combined_train_paths)} samples")
    print(f"  - Original train: {len(train_paths)} samples")
    print(f"  - Original test: {len(test_paths)} samples")
    print(f"Combined label distribution: {np.bincount(combined_train_labels)}")

    # Create combined dataset with augmentation
    combined_train_dataset = EmojiDataset(
        combined_train_paths, combined_train_labels, processor, use_augmentation=True
    )

    combined_train_loader = DataLoader(
        combined_train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=4 if torch.cuda.is_available() else 2,
        pin_memory=torch.cuda.is_available()
    )

    # Final training parameters (lower learning rate for fine-tuning)
    final_epochs = 5
    final_lr = 5e-6

    final_optimizer = torch.optim.AdamW(
        classification_model.parameters(), lr=final_lr, weight_decay=0.01
    )

    # `verbose` omitted (deprecated in recent PyTorch); the LR is printed below.
    final_scheduler = ReduceLROnPlateau(
        final_optimizer, mode='max', factor=0.5, patience=2, min_lr=1e-7, cooldown=1
    )

    print(f"\nTraining on combined dataset for {final_epochs} epochs...")

    for epoch in range(final_epochs):
        print(f"\nFinal Training Epoch {epoch + 1}/{final_epochs}")
        print("-" * 50)

        train_loss, train_acc = train_epoch(
            classification_model, combined_train_loader, final_optimizer, device, scaler
        )

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        final_scheduler.step(train_acc)
        print(f"Learning rate: {final_optimizer.param_groups[0]['lr']:.2e}")

    # Save final model
    torch.save(classification_model.state_dict(), 'best_model_phase1_final.pt')
    print("\n✓ Final training completed! Saved best_model_phase1_final.pt")
else:
    print("No test split available for final training.")

## Phase 2: Second Dataset - Load and Prepare

In [None]:
# Phase 2 data: the labelled training images of the second dataset, with labels
# taken from train_labels.csv in the Drive folder below.
dataset_base = Path("/content/drive/MyDrive/vision/2-computer-vision-2025-b-sc-aidams-final-proj")
train_dir = dataset_base / "train"
csv_path = dataset_base / "train_labels.csv"

# Empty defaults: the Phase 2 cells below skip training when nothing is loaded.
second_dataset_paths = []
second_dataset_labels = []

if not train_dir.exists():
    print(f"ERROR: Train directory not found at {train_dir}")
    print("Please ensure the dataset is in the correct location.")
elif not csv_path.exists():
    # Check explicitly rather than reporting the CSV as found and letting the
    # loader fail later.
    print(f"ERROR: Label CSV not found at {csv_path}")
    print("Please ensure the dataset is in the correct location.")
else:
    print(f"Found train directory at: {train_dir}")
    print(f"Found CSV file at: {csv_path}")

    second_dataset_paths, second_dataset_labels = prepare_dataset_from_csv(train_dir, csv_path)
    print(f"\nFound {len(second_dataset_paths)} labeled images from second dataset")

    if len(second_dataset_paths) > 0:
        print(f"Label distribution: {np.bincount(second_dataset_labels)}")

## Phase 2: Split Second Dataset and Re-fine-tune

In [None]:
# Phase 2: adapt the Phase 1 model to the second dataset. Training starts from
# the best available Phase 1 weights. When there are more than 100 labelled
# images, 10% are held out for validation, and the weights with the best
# validation accuracy are kept in 'best_model_phase2.pt'.

# Defaults for names the following cells read. When this cell skips
# re-fine-tuning they then see "no validation split" instead of raising NameError.
second_train_paths, second_val_paths = [], []
second_train_labels, second_val_labels = [], []
second_val_loader = None

if len(second_dataset_paths) > 0:
    print("\n" + "="*50)
    print("Re-fine-tuning on Second Dataset")
    print("="*50)

    if os.path.exists('best_model_phase1_final.pt'):
        classification_model.load_state_dict(torch.load('best_model_phase1_final.pt', map_location=device))
        print("Loaded best final model from phase 1 (trained on combined data)")
    elif os.path.exists('best_model_phase1.pt'):
        classification_model.load_state_dict(torch.load('best_model_phase1.pt', map_location=device))
        print("Loaded best model from phase 1")

    if len(second_dataset_paths) > 100:
        # Stratified splitting needs at least two samples of every class.
        label_counts = np.unique(second_dataset_labels, return_counts=True)[1]
        min_class_count = label_counts.min()
        can_stratify = min_class_count >= 2

        if can_stratify:
            try:
                second_train_paths, second_val_paths, second_train_labels, second_val_labels = train_test_split(
                    second_dataset_paths, second_dataset_labels, test_size=0.1, random_state=42, stratify=second_dataset_labels
                )
                print(f"✓ Split with STRATIFIED sampling: {len(second_train_paths)} train, {len(second_val_paths)} validation")
            except ValueError as e:
                print(f"⚠️ Stratification failed: {e}")
                second_train_paths, second_val_paths, second_train_labels, second_val_labels = train_test_split(
                    second_dataset_paths, second_dataset_labels, test_size=0.1, random_state=42
                )
                print(f"Split without stratification: {len(second_train_paths)} train, {len(second_val_paths)} validation")
        else:
            print(f"⚠️ Warning: Cannot stratify (min class count: {min_class_count} < 2)")
            second_train_paths, second_val_paths, second_train_labels, second_val_labels = train_test_split(
                second_dataset_paths, second_dataset_labels, test_size=0.1, random_state=42
            )

        print(f"\nTrain label distribution: {np.bincount(second_train_labels)}")
        print(f"Validation label distribution: {np.bincount(second_val_labels)}")
    else:
        second_train_paths, second_val_paths = second_dataset_paths, []
        second_train_labels, second_val_labels = second_dataset_labels, []
        print("Dataset too small for validation split, using all for training")

    second_train_dataset = EmojiDataset(second_train_paths, second_train_labels, processor, use_augmentation=True)

    if len(second_val_paths) > 0:
        second_val_dataset = EmojiDataset(second_val_paths, second_val_labels, processor, use_augmentation=False)
        second_val_loader = DataLoader(
            second_val_dataset, batch_size=batch_size, shuffle=False,
            num_workers=4 if torch.cuda.is_available() else 2, pin_memory=torch.cuda.is_available()
        )

    second_train_loader = DataLoader(
        second_train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=4 if torch.cuda.is_available() else 2, pin_memory=torch.cuda.is_available()
    )

    print(f"Re-fine-tuning samples: {len(second_train_dataset)}")
    if second_val_loader is not None:
        print(f"Validation samples: {len(second_val_dataset)}")

    refinetune_epochs = 7
    refinetune_lr = 1e-5
    refinetune_optimizer = torch.optim.AdamW(classification_model.parameters(), lr=refinetune_lr, weight_decay=0.01)
    # `verbose` omitted (deprecated in recent PyTorch); the LR is printed each epoch.
    refinetune_scheduler = ReduceLROnPlateau(refinetune_optimizer, mode='max', factor=0.5, patience=2, min_lr=1e-7, cooldown=1)

    best_refinetune_acc = 0
    refinetune_early_stopping_patience = 3
    refinetune_early_stopping_counter = 0
    best_refinetune_epoch = 0
    # Predictions of the best epoch, so the report describes the saved checkpoint.
    best_refinetune_preds, best_refinetune_labels = None, None

    for epoch in range(refinetune_epochs):
        print(f"\nRe-fine-tuning Epoch {epoch + 1}/{refinetune_epochs}")
        print("-" * 50)

        train_loss, train_acc = train_epoch(classification_model, second_train_loader, refinetune_optimizer, device, scaler)

        if second_val_loader is not None:
            val_loss, val_acc, val_preds, val_labels = validate(classification_model, second_val_loader, device)
            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            refinetune_scheduler.step(val_acc)
            print(f"Learning rate: {refinetune_optimizer.param_groups[0]['lr']:.2e}")

            if val_acc > best_refinetune_acc:
                best_refinetune_acc = val_acc
                best_refinetune_epoch = epoch + 1
                refinetune_early_stopping_counter = 0
                best_refinetune_preds, best_refinetune_labels = val_preds, val_labels
                torch.save(classification_model.state_dict(), 'best_model_phase2.pt')
                print(f"✓ Saved best re-fine-tuned model with validation accuracy: {best_refinetune_acc:.2f}%")
            else:
                refinetune_early_stopping_counter += 1
                print(f"No improvement. Patience: {refinetune_early_stopping_counter}/{refinetune_early_stopping_patience}")

            if refinetune_early_stopping_counter >= refinetune_early_stopping_patience:
                print("\nEarly stopping triggered!")
                break
        else:
            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            refinetune_scheduler.step(train_acc)
            print(f"Learning rate: {refinetune_optimizer.param_groups[0]['lr']:.2e}")
            # No validation data to select on: keep the most recent weights.
            torch.save(classification_model.state_dict(), 'best_model_phase2.pt')
            best_refinetune_epoch = epoch + 1

    if second_val_loader is not None:
        print(f"\nPhase 2 re-fine-tuning completed! Best validation accuracy: {best_refinetune_acc:.2f}% at epoch {best_refinetune_epoch}")

        if best_refinetune_preds is not None:
            # Expose the best epoch's predictions under the names the
            # confusion-matrix cell below reads.
            val_preds, val_labels = best_refinetune_preds, best_refinetune_labels

        print("\n" + "="*50)
        print("Phase 2 Final Scores")
        print("="*50)
        print(f"Best Validation Accuracy: {best_refinetune_acc:.2f}%")
        print(f"\nClassification Report (epoch {best_refinetune_epoch}):")
        print(classification_report(
            val_labels,
            val_preds,
            labels=list(range(len(VENDOR_CLASSES))),
            target_names=VENDOR_CLASSES,
            zero_division=0,
        ))
    else:
        print(f"\nPhase 2 re-fine-tuning completed! No validation split; saved the weights from epoch {best_refinetune_epoch}")
else:
    print("Skipping re-fine-tuning (no second dataset found)")

In [None]:
# Confusion matrix of the Phase 2 validation predictions (rows = true class,
# columns = predicted class). Skipped when Phase 2 had no validation split.
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# globals().get avoids a NameError when the Phase 2 cell skipped re-fine-tuning
# without ever defining `second_val_loader`.
if globals().get("second_val_loader") is not None:
    # Pin the label set so the matrix is always K x K and lines up with the
    # K tick labels, even if some class is absent from this validation split.
    cm = confusion_matrix(val_labels, val_preds, labels=list(range(len(VENDOR_CLASSES))))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=VENDOR_CLASSES, yticklabels=VENDOR_CLASSES)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()


## Phase 2: Final Training on Combined Train+Validation Data

In [None]:
# Phase 2 final pass: continue training the best Phase 2 checkpoint on the train
# and validation splits together. The validation split already served for
# checkpoint selection and is not evaluated again. The scheduler follows
# training accuracy, and the weights are saved unconditionally as
# 'best_model_final.pt', which the prediction cell tries first.
# globals().get(...) keeps this cell from raising NameError when the previous
# cell skipped Phase 2 without defining `second_val_paths`.
if len(globals().get("second_val_paths", [])) > 0:
    print("\n" + "="*50)
    print("Final Training on Combined Second Dataset (Train + Validation)")
    print("="*50)

    # Load best model from phase 2
    if os.path.exists('best_model_phase2.pt'):
        classification_model.load_state_dict(torch.load('best_model_phase2.pt', map_location=device))
        print("Loaded best model from phase 2")

    # list(...) guarantees `+` concatenates; if the splits were NumPy arrays,
    # `+` would add them element-wise instead.
    combined_second_train_paths = list(second_train_paths) + list(second_val_paths)
    combined_second_train_labels = list(second_train_labels) + list(second_val_labels)

    print(f"Combined second dataset: {len(combined_second_train_paths)} samples")
    print(f"  - Original train: {len(second_train_paths)} samples")
    print(f"  - Original validation: {len(second_val_paths)} samples")
    print(f"Combined label distribution: {np.bincount(combined_second_train_labels)}")

    # Create combined dataset with augmentation
    combined_second_train_dataset = EmojiDataset(
        combined_second_train_paths, combined_second_train_labels, processor, use_augmentation=True
    )

    combined_second_train_loader = DataLoader(
        combined_second_train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=4 if torch.cuda.is_available() else 2, pin_memory=torch.cuda.is_available()
    )

    # Final training parameters
    final_second_epochs = 3
    final_second_lr = 5e-6

    final_second_optimizer = torch.optim.AdamW(
        classification_model.parameters(), lr=final_second_lr, weight_decay=0.01
    )

    # `verbose` omitted (deprecated in recent PyTorch); the LR is printed below.
    final_second_scheduler = ReduceLROnPlateau(
        final_second_optimizer, mode='max', factor=0.5, patience=2, min_lr=1e-7, cooldown=1
    )

    print(f"\nTraining on combined second dataset for {final_second_epochs} epochs...")

    for epoch in range(final_second_epochs):
        print(f"\nFinal Training Epoch {epoch + 1}/{final_second_epochs}")
        print("-" * 50)

        train_loss, train_acc = train_epoch(
            classification_model, combined_second_train_loader, final_second_optimizer, device, scaler
        )

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        final_second_scheduler.step(train_acc)
        print(f"Learning rate: {final_second_optimizer.param_groups[0]['lr']:.2e}")

    # Save final model
    torch.save(classification_model.state_dict(), 'best_model_final.pt')
    print("\n✓ Final training completed! Saved best_model_final.pt")
else:
    print("No validation split available for final training.")

## Generate Predictions on Original Test Data

In [None]:
# Inference on the unlabelled test images of the second dataset. Loads the most
# refined checkpoint available, predicts every image with test-time augmentation
# (predict_with_tta), and writes a submission CSV with columns Id,Label.
import traceback
from collections import Counter

test_dataset_path = dataset_base / "test"

if test_dataset_path.exists():
    print("\n" + "="*50)
    print("Generating Predictions on Original Test Data")
    print("="*50)

    # Checkpoints in order of preference (most training first); the first one
    # present on disk is loaded.
    checkpoint_candidates = [
        ('best_model_final.pt', "Using best final model (trained on combined data)"),
        ('best_model_phase2.pt', "Using best phase 2 model"),
        ('best_model_phase1_final.pt', "Using best phase 1 final model"),
        ('best_model_phase1.pt', "Using best phase 1 model"),
    ]
    for checkpoint_file, checkpoint_msg in checkpoint_candidates:
        if os.path.exists(checkpoint_file):
            classification_model.load_state_dict(torch.load(checkpoint_file, map_location=device))
            print(checkpoint_msg)
            break
    else:
        print("WARNING: No saved checkpoint found; predicting with the current in-memory weights")

    classification_model.eval()

    # One recursive pass with a case-insensitive suffix check. Globbing each
    # case variant separately can list the same file twice on case-insensitive
    # file systems, and it misses mixed-case suffixes such as ".Png".
    image_extensions = {'.png', '.jpg', '.jpeg'}
    test_image_paths = sorted(
        str(p) for p in test_dataset_path.rglob("*")
        if p.is_file() and p.suffix.lower() in image_extensions
    )

    print(f"Found {len(test_image_paths)} test images")

    predictions = []
    image_ids = []

    print(f"Processing {len(test_image_paths)} test images with TTA...")

    with torch.no_grad():
        for image_path in tqdm(test_image_paths, desc="Generating predictions"):
            try:
                image_id = Path(image_path).stem
                image = Image.open(image_path).convert('RGB')

                _, predicted_class, probabilities = predict_with_tta(
                    classification_model, image, processor, tta_aug,
                    num_augmentations=10, device=device
                )

                predicted_idx = predicted_class.item()
                if predicted_idx >= len(VENDOR_CLASSES):
                    print(f"WARNING: Invalid prediction index {predicted_idx}, using first class")
                    predicted_idx = 0

                predicted_label = IDX_TO_VENDOR[predicted_idx]
                predictions.append(predicted_label)
                image_ids.append(image_id)

            except Exception as e:
                # Best effort: still emit a row for this image (first class as a
                # placeholder) so every test Id appears in the submission.
                print(f"Error processing {image_path}: {e}")
                traceback.print_exc()
                predictions.append(VENDOR_CLASSES[0])
                image_ids.append(Path(image_path).stem)

    # Write through pandas so fields are quoted as needed; an Id containing a
    # comma can then no longer shift the columns.
    predictions_file = "predictions.csv"
    pd.DataFrame({
        "Id": [str(img_id).strip() for img_id in image_ids],
        "Label": predictions,
    }).to_csv(predictions_file, index=False)

    print(f"\nPredictions saved to {predictions_file}")
    print(f"Total predictions: {len(predictions)}")
    n_unique_ids = len(set(image_ids))
    print(f"Unique image IDs: {n_unique_ids}")
    if n_unique_ids != len(image_ids):
        # Ids are file stems found recursively, so the same file name in two
        # sub-folders (or with two extensions) collides.
        print("WARNING: Duplicate image IDs found in the submission")

    label_counts = Counter(predictions)
    print(f"\nPrediction distribution:")
    for label, count in sorted(label_counts.items()):
        percentage = 100 * count / len(predictions)
        print(f"  {label}: {count} ({percentage:.1f}%)")

    verification_df = pd.read_csv(predictions_file)
    print(f"\nVerification - Loaded {len(verification_df)} rows from {predictions_file}")
    print(f"Columns: {verification_df.columns.tolist()}")
    print(f"\nSample predictions:")
    print(verification_df.head(10))

    print("\n✓ Predictions generation completed!")
else:
    print(f"Test directory not found at {test_dataset_path}")
    print("Cannot generate predictions.")