# Balancing an Imbalanced Image Dataset via Augmentation

**Goal:** Ensure each of 7 classes (`bus`, `car`, `cat`, `dog`, `cricket`, `football`, `product`) ends up with **800** images.

---

## Concept Overview

1. **Resize Originals**  
   - Standardize all images to **256 × 256** pixels.

2. **Downsample If ≥ 800**  
   - Randomly pick 800 originals and resize.

3. **Upsample If < 800**  
   - Copy and resize all existing images.
   - Generate additional ones via simple augmentations:
     - **Horizontal flip** (50% chance)  
     - **Rotation** (±20°)  
     - **Color jitter** (±20% brightness/contrast/saturation)

---

## Result Snapshot

| Class     | Original Count → Final Count |
|:----------|:----------------------------:|
| bus       | 542 → 800                    |
| car       | 469 → 800                    |
| cat       | 500 → 800                    |
| dog       | 542 → 800                    |
| cricket   | 90  →  800                    |
| football  | 100 →  800                    |
| product   | 800  →  800                    |

_All classes are now perfectly balanced for model training._


In [None]:
import os
import random
from PIL import Image
from torchvision import transforms

# Class sub-folders to balance; each is looked up directly under `root_dir`.
categories = ['bus', 'car', 'cat', 'dog', 'cricket', 'football', 'product']

# Directory that holds one sub-folder of original training images per class.
root_dir = "/kaggle/input/openaimer-data/OpenAImer2025_Image_Classification/OpenAImer/train"

# File suffixes that get_image_paths() accepts (it lower-cases names first).
valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

# Plain 256x256 resize, used for images that are copied without augmentation.
resize_transform = transforms.Resize(size=(256, 256))

# Augmentation steps in application order; the final resize gives augmented
# images the same 256x256 size as the resized originals.
_augmentation_steps = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.Resize(size=(256, 256)),
]
aug_transform = transforms.Compose(_augmentation_steps)

# Function to process and save an image.
def process_and_save(img_path, save_path, transform):
    """Load an image, force it to RGB, apply *transform* and write the result.

    Args:
        img_path: Path of the image file to read.
        save_path: Path the transformed image is written to.
        transform: Callable that takes a PIL image and returns an object
            with a ``save(path)`` method.

    Any exception raised while opening, transforming or saving is caught
    and reported with ``print``; the function then returns ``None`` without
    re-raising, so one bad file does not stop a batch run.
    """
    try:
        with Image.open(img_path) as source:
            # Non-RGB modes are converted so all three colour channels exist.
            rgb_image = source if source.mode == 'RGB' else source.convert('RGB')
            transform(rgb_image).save(save_path)
    except Exception as err:
        print(f"Error processing {img_path}: {err}")

# Function to get list of image file paths given a directory.
def get_image_paths(directory):
    """Return the paths of entries in *directory* that look like images.

    An entry qualifies when its lower-cased name ends with one of the
    suffixes in the module-level ``valid_extensions`` tuple. Paths are
    returned in the order ``os.listdir`` yields them.
    """
    image_paths = []
    for entry in os.listdir(directory):
        if entry.lower().endswith(valid_extensions):
            image_paths.append(os.path.join(directory, entry))
    return image_paths
out_dir = "/kaggle/working"
final_count = 800  # Desired number of images per category

# Balance every category: write `final_count` images to "<out_dir>/<cat>-augment".
for cat in categories:
    src_dir = os.path.join(root_dir, cat)
    dst_dir = os.path.join(out_dir, cat + '-augment')
    os.makedirs(dst_dir, exist_ok=True)

    # Get list of original image paths.
    original_paths = get_image_paths(src_dir)
    num_originals = len(original_paths)
    print(f"Category '{cat}': {num_originals} original images found.")

    # random.choice() below raises IndexError on an empty list, so a category
    # without any source image is reported and skipped instead of crashing.
    if num_originals == 0:
        print(f"Warning: no source images in {src_dir}; skipping category '{cat}'.")
        continue

    if num_originals >= final_count:
        # Too many originals: keep a random subset, resized only.
        chosen = random.sample(original_paths, final_count)
        name_prefix = cat
    else:
        # Too few originals: keep all of them, resized only.
        chosen = original_paths
        name_prefix = f"{cat}_orig"

    # process_and_save() swallows errors, so count a file only when it was
    # actually written; failed copies are then topped up by augmentation.
    saved_count = 0
    for idx, img_path in enumerate(chosen, start=1):
        save_path = os.path.join(dst_dir, f"{name_prefix}_{idx:04d}.jpg")
        process_and_save(img_path, save_path, resize_transform)
        if os.path.isfile(save_path):
            saved_count += 1

    # Fill the remainder with augmented copies of randomly chosen originals.
    aug_needed = final_count - saved_count
    if aug_needed > 0:
        print(f"Generating {aug_needed} augmented images for category '{cat}'.")
        for _ in range(aug_needed):
            src_img_path = random.choice(original_paths)
            # The running count keeps augmented file names unique.
            save_name = f"{cat}_aug_{saved_count + 1:04d}.jpg"
            save_path = os.path.join(dst_dir, save_name)
            process_and_save(src_img_path, save_path, aug_transform)
            if os.path.isfile(save_path):
                saved_count += 1

    # Verify that exactly `final_count` images are now in the destination.
    final_images = get_image_paths(dst_dir)
    print(f"Finished category '{cat}': {len(final_images)} images in {dst_dir}.")
    if len(final_images) != final_count:
        print(f"Warning: Expected {final_count} images but found {len(final_images)} in {dst_dir}.")

Category 'bus': 542 original images found.
Generating 258 augmented images for category 'bus'.
Finished category 'bus': 800 images in /kaggle/working/bus-augment.
Category 'car': 469 original images found.
Generating 331 augmented images for category 'car'.
Finished category 'car': 800 images in /kaggle/working/car-augment.
Category 'cat': 500 original images found.
Generating 300 augmented images for category 'cat'.
Finished category 'cat': 800 images in /kaggle/working/cat-augment.
Category 'dog': 542 original images found.
Generating 258 augmented images for category 'dog'.
Finished category 'dog': 800 images in /kaggle/working/dog-augment.
Category 'cricket': 90 original images found.
Generating 710 augmented images for category 'cricket'.
Finished category 'cricket': 800 images in /kaggle/working/cricket-augment.
Category 'football': 100 original images found.
Generating 700 augmented images for category 'football'.
Finished category 'football': 800 images in /kaggle/working/footb

# **Embedding Generation Through Enhanced Supervised Contrastive Learning**

This Python code implements an advanced supervised contrastive learning framework using PyTorch and PyTorch Lightning for image feature extraction. The primary goal is to learn embeddings where same‑class images are **pulled together** and different‑class images are **pushed apart**, with extra emphasis on under‑represented classes and hard negatives.

---

## Key Components

### 1. **`EnhancedSupConLoss`**  
Custom supervised contrastive loss with:
- **Temperature scaling** ($\tau$) and **base temperature** ($\tau_0$).  
- **Class re‑weighting** to up‑weight rare classes:  
  $$w_c = \frac{\max_{c'} N_{c'}}{N_c},$$  
  then scale each anchor’s loss by $w_{y_i}$.  
- **Margin** $m$ on negatives:  
  $$\tilde s_{i,a} = \frac{z_i \cdot z_a}{\tau} - m,\quad a\in\mathcal{N}(i).$$  
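- **Hard negative mining**: keep only the top‑$k$ hardest (most similar) negatives per anchor,  
  $$k = \min\Bigl(|\mathcal{N}(i)|,\ \max\bigl(1,\lfloor r\cdot M\rfloor\bigr)\Bigr),$$  
  where $r$ is the `hard_mining_ratio` and $M$ is the number of anchors in the batch (batch size × number of views in the default `contrast_mode='all'`). Mining is applied only when $0 < r < 1$.  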

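**Core loss** for anchor $i$ (in the default `contrast_mode='all'`, every view of every image serves as an anchor):  
$$
\ell_i = -\frac{\tau}{\tau_0}\,\frac{1}{|\mathcal{P}(i)|}\sum_{p\in\mathcal{P}(i)}
\log\frac{\exp\!\bigl(z_i\cdot z_p/\tau\bigr)}
{\sum_{p'\in\mathcal{P}(i)}\exp\!\bigl(z_i\cdot z_{p'}/\tau\bigr)+\sum_{a\in\text{hardNeg}(i)}\exp\!\bigl(\tilde s_{i,a}\bigr)}
\quad,\quad
L = \frac{1}{N}\sum_i w_{y_i}\,\ell_i.
$$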

---

### 2. **`AdvancedSupConDataset`**  
Handles data loading and computes per‑class weights:
- **Class weights** for sampler and loss:  
  $$w_c = \frac{\max_{c'} N_{c'}}{N_c}.$$  
- **Strong augmentation pipeline** (for each of $n_{\mathrm{views}}$):
  - Random resized crop, color jitter, grayscale, Gaussian blur  
  - Flips, rotation, affine transforms  
  - Normalization to ImageNet stats  
- Generates **multiple “views”** per image for contrastive pairs.

---

### 3. **`EnhancedEncoder`**  
Feature extractor with:
- **Pretrained CNN backbone** (ResNet‑50/101 or EfficientNet‑B3). We used EfficientNet‑B3, as it proved to be a better feature extractor than ResNet for our task.  
- Optional **attention mechanism**:
  $$a = \sigma(W_2\,\mathrm{ReLU}(W_1\,f))\,,\quad f\in\mathbb{R}^{d},$$  
  then scale features by $a$.  
- **Projection head** (MLP) mapping to $d_{\mathrm{proj}}$, followed by $\ell_2$ normalization.

---

### 4. **`AdvancedSupConModule`**  
PyTorch Lightning module tying together encoder and loss:
- **`forward`** handles multi‑view batching:  
  reshape $[B,n_v,C,H,W]\to[B\,n_v,C,H,W]\to[B,n_v,d_{\mathrm{proj}}]$.  
- **`training_step`** computes $L$ via `EnhancedSupConLoss`, logs `train_loss`.  
- Optimizer: **AdamW** with **CosineAnnealingWarmRestarts**.

---

### 5. **`AdvancedSupConDataModule`**  
Lightning data module for train/val split:
- Splits dataset by ratio, seeds for reproducibility.
- **WeightedRandomSampler** on training set using $\{w_c\}$ to ensure balanced batches.
- Standard DataLoader for validation (no shuffling).

---

### 6. **`train_contrastive_model`**  
Orchestrator function:
- Instantiates `AdvancedSupConDataModule` & `AdvancedSupConModule`.
- Configures logging (CSVLogger), callbacks (checkpointing, early stopping, LR monitoring).
- Runs `Trainer.fit(...)` on GPU/CPU, saves best & final checkpoints.

---

By combining **temperature‑scaled contrastive loss**, **class re‑weighting**, **margin‑augmented hard negative mining**, and **strong data augmentations**, this framework yields highly discriminative, class‑balanced embeddings ready for downstream tasks.  


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
from torchvision import transforms, models
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
import numpy as np
from collections import Counter

# --- Enhanced Supervised Contrastive Loss with Class Reweighting and Hard Negative Mining ---
class EnhancedSupConLoss(nn.Module):
    """Supervised contrastive loss with margin, hard-negative mining and class weights.

    Args:
        temperature: Divisor applied to the anchor/contrast dot products.
        base_temperature: The final loss is scaled by
            ``temperature / base_temperature``.
        contrast_mode: ``'all'`` uses every view of every sample as an anchor;
            ``'one'`` uses only the first view of each sample as an anchor.
            In both modes the contrast set contains all views.
        hard_mining_ratio: When strictly between 0 and 1, only the
            ``max(1, int(num_anchors * ratio))`` most similar negatives of
            each anchor (capped by the negatives available) stay in the
            denominator. Any other value keeps all negatives.
        margin: When positive, subtracted from the scaled similarity of every
            negative pair.
    """

    def __init__(self, temperature=0.05, base_temperature=0.07, contrast_mode='all',
                 hard_mining_ratio=0.35, margin=0.2):
        super().__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.contrast_mode = contrast_mode
        self.hard_mining_ratio = hard_mining_ratio  # Ratio of hard negatives to mine
        self.margin = margin  # Margin to push negatives further

    def forward(self, features, labels=None, mask=None, class_weights=None):
        """Compute the loss for one batch.

        Args:
            features: hidden vectors of shape [bsz, n_views, ...]; trailing
                dimensions are flattened into a single feature dimension.
            labels: ground truth of shape [bsz]. Optional.
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric. A mask that
                already has the expanded [num_anchors, bsz * n_views] shape is
                used as is. When given, it takes precedence over ``labels``
                for defining positives.
            class_weights: dictionary mapping class indices to weights. Used
                only when ``labels`` is given; classes missing from the
                dictionary get weight 1.0.
        Returns:
            A loss scalar: the mean over anchors of the (optionally
            class-weighted) per-anchor loss. With neither ``labels`` nor
            ``mask`` the only positives of a sample are its other views.
        Raises:
            ValueError: for fewer than 3 feature dimensions, a label count or
                mask shape that does not match the batch, or an unknown
                ``contrast_mode``.
        """
        if features.dim() < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')

        device = features.device
        bsz, n_views = features.shape[0], features.shape[1]
        # Keep the view axis intact and flatten everything after it.
        features = features.reshape(bsz, n_views, -1)

        if labels is not None:
            labels = labels.reshape(-1).to(device)
            if labels.shape[0] != bsz:
                raise ValueError('Num of labels does not match num of features')

        # Contrast set: all views, ordered sample-major (s0v0, s0v1, s1v0, ...).
        contrast_feature = features.reshape(bsz * n_views, -1)
        if self.contrast_mode == 'one':
            # Use first view as anchor
            anchor_feature = features[:, 0]
            anchor_labels = labels
            views_per_anchor_sample = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_labels = None if labels is None else labels.repeat_interleave(n_views)
            views_per_anchor_sample = n_views
        else:
            raise ValueError('Unknown contrast mode: {}'.format(self.contrast_mode))

        n_anchor = bsz * views_per_anchor_sample
        n_contrast = bsz * n_views

        # Positive-pair mask at sample level, then expanded to view level.
        if mask is not None:
            pos_mask = mask.float().to(device)
        elif labels is not None:
            pos_mask = torch.eq(labels.unsqueeze(1), labels.unsqueeze(0)).float()
        else:
            pos_mask = torch.eye(bsz, dtype=torch.float32, device=device)
        if tuple(pos_mask.shape) != (n_anchor, n_contrast):
            pos_mask = pos_mask.repeat_interleave(views_per_anchor_sample, dim=0)
            pos_mask = pos_mask.repeat_interleave(n_views, dim=1)
        if tuple(pos_mask.shape) != (n_anchor, n_contrast):
            raise ValueError('`mask` must have shape [bsz, bsz] or '
                             '[num_anchors, bsz * n_views]')

        # Column holding each anchor's own embedding inside the contrast set.
        rows = torch.arange(n_anchor, device=device)
        cols = (rows // views_per_anchor_sample) * n_views + rows % views_per_anchor_sample
        self_mask = torch.zeros(n_anchor, n_contrast, dtype=torch.float32, device=device)
        self_mask[rows, cols] = 1.0
        not_self = 1.0 - self_mask

        # An anchor is never its own positive nor its own negative.
        pos_mask = pos_mask * not_self
        neg_mask = (pos_mask <= 0).float() * not_self

        # Scaled similarities, with the margin applied to negative pairs only.
        logits = torch.matmul(anchor_feature, contrast_feature.T) / self.temperature
        if self.margin > 0:
            logits = logits - self.margin * neg_mask

        # For numerical stability (a per-row shift leaves log_prob unchanged).
        logits = logits - logits.max(dim=1, keepdim=True).values.detach()

        # Entries that take part in the softmax denominator.
        denom_mask = not_self
        if 0 < self.hard_mining_ratio < 1.0:
            k = max(int(n_anchor * self.hard_mining_ratio), 1)  # At least one negative
            k = min(k, n_contrast)
            # Rank negatives only; everything else is pushed to -inf.
            neg_only = logits.detach().masked_fill(neg_mask <= 0, float('-inf'))
            top_idx = neg_only.topk(k, dim=1).indices
            # Multiplying by neg_mask drops -inf picks in rows with < k negatives.
            hard_mask = torch.zeros_like(neg_mask).scatter_(1, top_idx, 1.0) * neg_mask
            denom_mask = ((pos_mask + hard_mask) > 0).float()

        # Compute log_prob
        exp_logits = torch.exp(logits) * denom_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)

        # Mean log-likelihood over the positives of each anchor; anchors
        # without positives contribute zero.
        mean_log_prob_pos = (pos_mask * log_prob).sum(1) / (pos_mask.sum(1) + 1e-12)
        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos

        # Scale each anchor's loss by its class weight. The weight must be
        # applied after the per-anchor mean: inside the mean it would appear
        # in both numerator and denominator and cancel out.
        if class_weights is not None and anchor_labels is not None:
            weight_values = torch.tensor(
                [class_weights.get(cls_idx, 1.0) for cls_idx in anchor_labels.tolist()],
                dtype=loss.dtype, device=device)
            loss = loss * weight_values

        return loss.mean()

# --- Advanced Data Augmentation Strategy ---
class AdvancedSupConDataset(Dataset):
    """Image-folder dataset that returns several augmented views per image.

    ``root_dir`` must contain one sub-directory per class. Class indices follow
    the sorted sub-directory names, and files inside each class are also
    visited in sorted order so that sample indices (and therefore any seeded
    split built on them) do not depend on the file system's listing order.

    Args:
        root_dir: Directory holding the class sub-directories.
        transform: Optional base transform applied before augmentation;
            defaults to resize to 256x256 followed by a 224 centre crop.
        n_views: Number of independently augmented views returned per image.
        strong_augment: Selects the strong pipeline (crop, colour jitter,
            grayscale, blur, flip, rotation, translation) over the lighter
            one (crop, flip, colour jitter).

    Raises:
        ValueError: if no image file is found under ``root_dir``.
    """

    # File suffixes (matched case-insensitively) that are treated as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    def __init__(self, root_dir, transform=None, n_views=2, strong_augment=True):
        self.root_dir = root_dir
        self.classes, self.samples = self._scan_directory(root_dir)
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        # Store targets separately for sampling
        self.targets = [class_idx for _, class_idx in self.samples]

        # Fail with a clear message instead of an opaque max() error below.
        if not self.samples:
            raise ValueError(f"No image files found under '{root_dir}'.")

        # Count class occurrences for class distribution
        self.class_counts = Counter(self.targets)
        print(f"Class distribution: {self.class_counts}")

        # Class weights: the largest class gets 1.0, rarer classes more.
        max_count = max(self.class_counts.values())
        self.class_weights = {cls: max_count / count for cls, count in self.class_counts.items()}

        # Base transform
        self.transform = transform or transforms.Compose([
            transforms.Resize((256, 256)),  # Larger size for better detail
            transforms.CenterCrop(224),     # Standard size for ResNet
        ])

        # ImageNet normalization
        imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                            std=[0.229, 0.224, 0.225])

        # Strong augmentation strategy for contrastive learning
        if strong_augment:
            augmentation_pipeline = [
                transforms.RandomResizedCrop(size=224, scale=(0.2, 1.0)),
                transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.RandomApply([transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))], p=0.5),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(degrees=15),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
                transforms.ToTensor(),
                imagenet_norm
            ]
        else:
            # Simpler augmentation for less distortion
            augmentation_pipeline = [
                transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
                transforms.ToTensor(),
                imagenet_norm
            ]

        self.augmentations = transforms.Compose(augmentation_pipeline)
        self.n_views = n_views

    @classmethod
    def _scan_directory(cls, root_dir):
        """Return ``(classes, samples)`` found under *root_dir*.

        ``classes`` is the sorted list of sub-directory names; ``samples`` is
        a list of ``(path, class_index)`` tuples ordered by class and then by
        file name. Entries without an image suffix are ignored.
        """
        classes = sorted(d for d in os.listdir(root_dir)
                         if os.path.isdir(os.path.join(root_dir, d)))
        samples = []
        for class_idx, cls_name in enumerate(classes):
            cls_dir = os.path.join(root_dir, cls_name)
            for fname in sorted(os.listdir(cls_dir)):
                if fname.lower().endswith(cls._IMAGE_EXTENSIONS):
                    samples.append((os.path.join(cls_dir, fname), class_idx))
        return classes, samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(views, label)``; ``views`` stacks ``n_views`` augmented tensors."""
        path, label = self.samples[index]
        # The context manager closes the file handle once the pixels have
        # been copied by convert(), instead of leaving that to the GC.
        with Image.open(path) as raw_image:
            image = raw_image.convert('RGB')

        # Apply base transform
        image = self.transform(image)

        # Generate n_views different augmented versions
        views = torch.stack([self.augmentations(image) for _ in range(self.n_views)], dim=0)

        return views, label

# --- Enhanced Feature Extractor Network ---
class EnhancedEncoder(nn.Module):
    """Pretrained backbone + optional channel gate + MLP projection head.

    Args:
        backbone: One of ``'resnet50'``, ``'resnet101'`` or
            ``'efficientnet_b3'``; anything else raises ``ValueError``.
        proj_dim: Size of the output embedding.
        dropout: Dropout probability used inside the projection head.
        use_attention: When true, backbone features are multiplied by a
            learned sigmoid gate before projection.
    """

    def __init__(self, backbone='resnet50', proj_dim=512, dropout=0.3, use_attention=True):
        super().__init__()

        # Width of the pooled feature vector produced by each backbone.
        backbone_dims = {'resnet50': 2048, 'resnet101': 2048, 'efficientnet_b3': 1536}
        if backbone not in backbone_dims:
            raise ValueError(f"Unsupported backbone: {backbone}")
        feat_dim = backbone_dims[backbone]

        # The torchvision constructor has the same name as the backbone key.
        base_model = getattr(models, backbone)(pretrained=True)

        # Keep every child module except the last one (the classification head).
        trunk = list(base_model.children())[:-1]
        self.encoder = nn.Sequential(*trunk)

        # Optional squeeze-and-expand gate over the flattened features.
        self.use_attention = use_attention
        if use_attention:
            bottleneck = feat_dim // 8
            self.attention = nn.Sequential(
                nn.Flatten(),
                nn.Linear(feat_dim, bottleneck),
                nn.ReLU(inplace=True),
                nn.Linear(bottleneck, feat_dim),
                nn.Sigmoid(),
            )

        # Three-layer MLP head: feat_dim -> feat_dim -> feat_dim // 2 -> proj_dim.
        hidden = feat_dim // 2
        self.projector = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feat_dim, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(feat_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden, proj_dim),
            nn.BatchNorm1d(proj_dim),
        )

    def forward(self, x):
        """Return L2-normalised projections, one row per input sample."""
        feats = self.encoder(x)

        if self.use_attention:
            # Gate the flattened features, then restore the encoder's shape.
            gate = self.attention(feats)
            flat = feats.view(feats.size(0), -1)
            feats = (flat * gate).view_as(feats)

        return F.normalize(self.projector(feats), dim=1)

# --- Lightning Module for Advanced SupCon ---
class AdvancedSupConModule(pl.LightningModule):
    """Lightning module that trains an ``EnhancedEncoder`` with ``EnhancedSupConLoss``.

    The constructor arguments are stored via ``save_hyperparameters()`` and
    forwarded to the encoder (``encoder_type``, ``proj_dim``,
    ``use_attention``), the loss (``temperature``, ``hard_mining_ratio``,
    ``margin``) and the optimizer (``learning_rate``, ``weight_decay``).
    ``dataset_class_weights`` starts as ``None`` and may be assigned from
    outside before training; it is passed to the loss on every step.
    """

    def __init__(self, encoder_type='resnet50', proj_dim=512, learning_rate=3e-4,
                 weight_decay=1e-4, temperature=0.07, use_attention=True,
                 hard_mining_ratio=0.5, margin=0.1):
        super().__init__()
        self.save_hyperparameters()

        # Feature extractor with projection head.
        self.encoder = EnhancedEncoder(backbone=encoder_type, proj_dim=proj_dim,
                                       dropout=0.3, use_attention=use_attention)

        # Contrastive loss; class weights are supplied per call.
        self.criterion = EnhancedSupConLoss(temperature=temperature,
                                            hard_mining_ratio=hard_mining_ratio,
                                            margin=margin)

        self.train_losses = []
        self.dataset_class_weights = None

    def forward(self, x):
        """Embed ``[B, C, H, W]`` or multi-view ``[B, n_views, C, H, W]`` input."""
        if x.dim() == 4:
            # Single-view input: one embedding per image.
            return self.encoder(x)

        # Fold the view axis into the batch axis, embed, then unfold again.
        batch_size, n_views, channels, height, width = x.shape
        embeddings = self.encoder(x.view(-1, channels, height, width))
        return embeddings.view(batch_size, n_views, -1)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the contrastive loss for one batch."""
        views, labels = batch  # views shape: [B, n_views, C, H, W]
        loss = self.criterion(self(views), labels,
                              class_weights=self.dataset_class_weights)

        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        # Kept for the per-epoch average printed in on_train_epoch_end().
        self.train_losses.append(loss.detach())
        return loss

    def validation_step(self, batch, batch_idx):
        """Embed the first view of each validation sample; no loss is computed."""
        views, labels = batch
        # A list is indexed by position, a tensor along its view axis.
        first_view = views[0] if isinstance(views, list) else views[:, 0]
        embeddings = self.encoder(first_view)

        # Only records that validation ran.
        self.log('val_step', batch_idx)
        return embeddings

    def on_train_epoch_end(self):
        """Log and print the mean of the step losses gathered this epoch."""
        if not self.train_losses:
            return
        avg_loss = torch.stack(self.train_losses).mean()
        self.log('epoch_loss', avg_loss)
        print(f"Epoch {self.current_epoch} --> Average Loss: {avg_loss:.4f}")
        self.train_losses.clear()

    def configure_optimizers(self):
        """Return AdamW with a per-epoch cosine warm-restart schedule."""
        hp = self.hparams
        optimizer = torch.optim.AdamW(self.parameters(), lr=hp.learning_rate,
                                      weight_decay=hp.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=2, eta_min=1e-6)

        scheduler_config = {
            "scheduler": scheduler,
            "interval": "epoch",
            "monitor": "train_loss",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

# --- Advanced Data Module ---
class AdvancedSupConDataModule(pl.LightningDataModule):
    """Data module that builds one ``AdvancedSupConDataset`` and splits it.

    Args:
        data_dir: Root directory handed to ``AdvancedSupConDataset``.
        batch_size: Batch size of both loaders.
        num_workers: Worker processes of both loaders.
        strong_augment: Forwarded to the dataset.
        val_split: Fraction of samples placed in the validation subset.
    """

    def __init__(self, data_dir, batch_size=32, num_workers=4, strong_augment=True, val_split=0.1):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.strong_augment = strong_augment
        self.val_split = val_split
        # Filled in by setup(); None means setup() has not run yet.
        self.full_dataset = None

    def setup(self, stage=None):
        """Build the dataset and the train/validation subsets once.

        Repeated calls are no-ops: this method is called explicitly by
        ``train_contrastive_model`` and again when the trainer starts
        fitting, which previously scanned the image folders twice.
        """
        if self.full_dataset is not None:
            return

        # Create full dataset
        self.full_dataset = AdvancedSupConDataset(
            root_dir=self.data_dir,
            strong_augment=self.strong_augment
        )

        # Get dataset size and calculate split indices
        dataset_size = len(self.full_dataset)
        val_size = int(self.val_split * dataset_size)
        train_size = dataset_size - val_size

        # Shuffle with a private generator: same permutation as seeding the
        # global NumPy RNG with 42, but without resetting global state.
        indices = list(range(dataset_size))
        np.random.RandomState(42).shuffle(indices)

        train_indices = indices[:train_size]
        val_indices = indices[train_size:]

        # Both subsets share the same underlying dataset (and its augmentations).
        self.train_dataset = Subset(self.full_dataset, train_indices)
        self.val_dataset = Subset(self.full_dataset, val_indices)

        # We need to keep track of targets for the sampler
        self.train_targets = [self.full_dataset.targets[i] for i in train_indices]

        # Store class weights
        self.class_weights = self.full_dataset.class_weights

        print(f"Train size: {train_size}, Validation size: {val_size}")

    def train_dataloader(self):
        """Return the training loader, sampling with per-class weights."""
        # One weight per training sample, looked up from its class.
        weights = [self.class_weights[target] for target in self.train_targets]

        sampler = WeightedRandomSampler(
            weights=weights,
            num_samples=len(self.train_targets),
            replacement=True
        )

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sampler=sampler,
            pin_memory=True
        )

    def val_dataloader(self):
        """Return the validation loader (fixed order, no sampler)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

# --- Main Training Function ---
def train_contrastive_model(train_dir, batch_size=16, max_epochs=400,
                          encoder_type='resnet50', proj_dim=512,
                          learning_rate=3e-4, use_attention=True):
    """Train an ``AdvancedSupConModule`` on the images under *train_dir*.

    Args:
        train_dir: Root directory with one sub-folder per class.
        batch_size: Batch size for both data loaders.
        max_epochs: Upper bound on training epochs, on GPU and CPU alike.
        encoder_type: Backbone name passed to the model; also used as the
            suffix of the best-checkpoint file name.
        proj_dim: Embedding size of the projection head.
        learning_rate: Learning rate passed to the model.
        use_attention: Whether the encoder uses its attention gate.

    Returns:
        The trained ``AdvancedSupConModule``.

    Side effects: writes CSV logs under ``logs/advanced_supcon_run``, keeps
    the checkpoint with the lowest ``train_loss`` and saves a final
    checkpoint in the working directory.
    """
    # Initialize data module
    data_module = AdvancedSupConDataModule(
        data_dir=train_dir,
        batch_size=batch_size,
        num_workers=4,
        strong_augment=True,
        val_split=0.1
    )

    # Set up explicitly: the class weights are needed before fit() starts.
    data_module.setup()

    # Initialize model
    model = AdvancedSupConModule(
        encoder_type=encoder_type,
        proj_dim=proj_dim,
        learning_rate=learning_rate,
        weight_decay=1e-4,
        temperature=0.07,
        use_attention=use_attention,
        hard_mining_ratio=0.5,
        margin=0.1
    )

    # Pass dataset class weights to model
    model.dataset_class_weights = data_module.class_weights

    # Set up logging
    csv_logger = CSVLogger("logs/", name="advanced_supcon_run")

    # Best checkpoint, named after the backbone that was actually trained.
    checkpoint_callback_best = ModelCheckpoint(
        monitor='train_loss',
        mode='min',
        save_top_k=1,
        filename=f'best-{{epoch:02d}}-{{train_loss:.4f}}-{encoder_type}'
    )

    early_stopping = EarlyStopping(
        monitor='train_loss',
        patience=30,
        mode='min'
    )

    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # One trainer configuration for both devices; `max_epochs` is honoured on
    # GPU as well (it used to be fixed at 20 there).
    accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        devices=1,
        logger=csv_logger,
        callbacks=[checkpoint_callback_best, early_stopping, lr_monitor],
        gradient_clip_val=1.0,  # Clip gradients to avoid exploding gradients
    )

    # Train the model
    trainer.fit(model, datamodule=data_module)

    # Save final model.
    # NOTE(review): the file name says "resnet101" whatever `encoder_type` is.
    # It is left unchanged because later code may load this exact path;
    # confirm before renaming.
    trainer.save_checkpoint("advanced_supcon_final_resnet101.ckpt")

    return model

# --- Example Usage ---
if __name__ == '__main__':
    # Root of the balanced, pre-resized training images.
    train_dir = '/kaggle/input/augmented-256/augmented_dataset'

    # Settings for this run, gathered in one place.
    run_config = dict(
        batch_size=16,  # Reduced batch size to avoid CUDA OOM
        max_epochs=20,
        encoder_type='efficientnet_b3',
        proj_dim=1024,
        learning_rate=3e-4,
        use_attention=True,
    )
    model = train_contrastive_model(train_dir=train_dir, **run_config)

Class distribution: Counter({0: 800, 1: 800, 2: 800, 3: 800, 4: 800, 5: 800, 6: 800})
Train size: 5040, Validation size: 560


Downloading: "https://download.pytorch.org/models/efficientnet_b3_rwightman-b3899882.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b3_rwightman-b3899882.pth
100%|██████████| 47.2M/47.2M [00:00<00:00, 181MB/s]


Class distribution: Counter({0: 800, 1: 800, 2: 800, 3: 800, 4: 800, 5: 800, 6: 800})
Train size: 5040, Validation size: 560


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0 --> Average Loss: 3.6431


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1 --> Average Loss: 2.5069


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2 --> Average Loss: 2.2393


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3 --> Average Loss: 2.1468


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4 --> Average Loss: 2.0106


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5 --> Average Loss: 2.0187


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6 --> Average Loss: 1.8905


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7 --> Average Loss: 1.8807


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8 --> Average Loss: 1.8496


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9 --> Average Loss: 1.8547


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10 --> Average Loss: 1.9926


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11 --> Average Loss: 2.0534


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12 --> Average Loss: 1.9950


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13 --> Average Loss: 1.9334


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 14 --> Average Loss: 1.9680


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 15 --> Average Loss: 1.9157


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 16 --> Average Loss: 1.8918


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 17 --> Average Loss: 1.8897


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 18 --> Average Loss: 1.8563


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 19 --> Average Loss: 1.8472


# Embedding Classification Using MLP and Focal Loss - Key Mathematical Concepts

This code implements an image classification pipeline using PyTorch Lightning that leverages features from a pre-trained Supervised Contrastive (SupCon) model. The approach employs transfer learning: utilizing powerful representations learned during SupCon pre-training for downstream classification tasks.

## Key Mathematical Components

### 1. Focal Loss for Class Imbalance
The `FocalLoss` class implements a specialized loss function designed to address class imbalance:

- **Mathematical formulation**:
  ```
  FL(pt) = -αt * (1-pt)^γ * log(pt)
  ```
  - Where pt is the model's probability estimate for the true class
  - γ (gamma) is the focusing parameter (the `FocalLoss` constructor defaults to `gamma=1.0`; pass `gamma=2.0` for the commonly used setting)
  - αt are class weights inversely proportional to frequencies

- The focusing parameter down-weights easy examples, forcing the model to concentrate on difficult cases
- Class-specific weights compensate for imbalanced distributions in the dataset

### 2. Transfer Learning Architecture

The `LinearClassifierFromSupCon` utilizes a frozen pre-trained model with a trainable MLP:

- **Feature extraction**: Pre-trained SupCon model maps images to embeddings
  - These embeddings capture semantic information learned during contrastive training
  - Weights are frozen to preserve learned representations

- **MLP classifier head**:
  ```
  Input (proj_dim) → Linear → BN → ReLU → Dropout →
  Linear → BN → ReLU → Dropout → Linear → Output (num_classes)
  ```
  - Dimensions: proj_dim → 256 → 128 → num_classes
  - BatchNorm layers normalize activations, stabilizing training
  - Dropout (0.3) provides regularization to prevent overfitting

### 3. Evaluation Metrics

- **Macro F1-Score**: Primary performance metric
  ```
  F1 = 2 * (precision * recall) / (precision + recall)
  ```
  - Calculated per-class then averaged evenly across all classes
  - Appropriate for imbalanced datasets as it gives equal importance to minority classes

### 4. Data Handling and Training

- **Stratified sampling**: Maintains class distribution between training and validation sets
- **Data normalization**: Images normalized using ImageNet statistics
  ```
  mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
  ```
- **Optimization**: AdamW optimizer combines Adam with decoupled weight decay
- **Training monitoring**: Early stopping based on validation F1-score prevents overfitting

## Implementation Structure

The code maintains a clean separation between:
- Data preparation (`ClassificationDataset`, `ClassifierDataModule`)
- Model architecture (`LinearClassifierFromSupCon`)
- Training logic (PyTorch Lightning Trainer)

This modular design allows for easy experimentation with different models, datasets, and hyperparameters while maintaining the core mathematical principles behind the classification approach.

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torchmetrics.classification import MulticlassF1Score
from sklearn.model_selection import train_test_split

# --- Focal Loss for class imbalance ---
class FocalLoss(nn.Module):
    """Multi-class focal loss with optional per-class weighting.

    For every sample the loss is
    ``alpha[target] * (1 - p_t) ** gamma * (-log p_t)``, where ``p_t`` is the
    softmax probability assigned to the true class.

    Args:
        alpha: Per-class weights indexed by class id. May be a 1-D tensor or
            any sequence of numbers (list/tuple). When ``None`` the weights
            are the normalised inverse frequencies of seven classes with 800
            samples each, i.e. a uniform ``1/7`` per class.
        gamma: Focusing exponent. Larger values shrink the contribution of
            samples that are already classified with high confidence.
        reduction: ``'mean'`` or ``'sum'`` reduce over the batch; any other
            value returns the unreduced per-sample loss of shape ``[N]``.
    """

    def __init__(self, alpha=None, gamma=1.0, reduction='mean'):
        super().__init__()
        if alpha is None:
            frequencies = torch.tensor(
                [800, 800, 800, 800, 800, 800, 800], dtype=torch.float
            )
            alpha = 1.0 / frequencies
            alpha = alpha / alpha.sum()
        elif not torch.is_tensor(alpha):
            # register_buffer only accepts tensors; convert plain sequences so
            # callers can pass e.g. alpha=[0.25, 0.75].
            alpha = torch.as_tensor(alpha, dtype=torch.float)
        # A buffer (not a parameter): follows .to(device) but is never trained.
        self.register_buffer('alpha', alpha)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss for logits ``inputs`` [N, C] and integer class
        ids ``targets`` [N]."""
        # One log-softmax pass gives both the cross-entropy (-log p_t) and p_t,
        # avoiding a second softmax and an explicit one-hot matrix.
        log_probs = nn.functional.log_softmax(inputs, dim=1)
        log_p_t = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        ce_loss = -log_p_t
        p_t = log_p_t.exp()

        alpha_t = self.alpha[targets]
        focal = (1 - p_t) ** self.gamma
        loss = alpha_t * focal * ce_loss

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

# --- Image Classification Dataset ---
class ClassificationDataset(Dataset):
    """Image dataset laid out as one sub-directory per class.

    Every immediate sub-directory of ``root_dir`` is a class; class ids are
    assigned by the sorted order of the directory names. Files whose name ends
    in one of ``IMAGE_EXTENSIONS`` (case-insensitive) become samples.

    Args:
        root_dir: Directory containing the per-class sub-directories.
        transform: Callable applied to each RGB ``PIL.Image``. When ``None``,
            a Resize(256) -> CenterCrop(224) -> ToTensor -> Normalize pipeline
            is used.

    Attributes set on the instance:
        classes: Sorted list of class directory names.
        class_to_idx: Mapping from class name to integer label.
        samples: List of ``(path, label)`` tuples, ordered by class and then
            by file name.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.classes = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        self.samples = []
        for cls in self.classes:
            cls_dir = os.path.join(root_dir, cls)
            label = self.class_to_idx[cls]
            # os.listdir order is arbitrary; sort so that sample indices (and
            # any seeded split computed from them) are reproducible.
            for fname in sorted(os.listdir(cls_dir)):
                if fname.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(cls_dir, fname), label))

        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
            ])
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for sample ``idx``."""
        path, label = self.samples[idx]
        # convert() returns a new in-memory image, so the file can be closed
        # as soon as the conversion is done.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        img = self.transform(img)
        return img, label

# --- DataModule for classification ---
class ClassifierDataModule(pl.LightningDataModule):
    """LightningDataModule producing a stratified train/validation split.

    Args:
        data_dir: Root directory passed to ``ClassificationDataset``.
        batch_size: Batch size used by every DataLoader.
        num_workers: Worker process count used by every DataLoader.
        train_split: Value forwarded as ``train_size`` to
            ``train_test_split`` (0.8 keeps 80% of the samples for training).
    """

    def __init__(self, data_dir, batch_size=32, num_workers=4, train_split=0.8):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_split = train_split
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
        ])

    def setup(self, stage=None):
        """Index the dataset and create ``train_ds`` / ``val_ds`` subsets."""
        full = ClassificationDataset(self.data_dir, transform=self.transform)
        # Read labels from the (path, label) index directly. Going through
        # full[i] would open and transform every image just to get its label.
        targets = [label for _, label in full.samples]
        idx = list(range(len(full)))
        train_idx, val_idx = train_test_split(
            idx, train_size=self.train_split,
            stratify=targets, random_state=42
        )
        self.train_ds = Subset(full, train_idx)
        self.val_ds   = Subset(full, val_idx)

    def _loader(self, dataset, shuffle):
        """Build a DataLoader with this module's batch size and worker count."""
        return DataLoader(
            dataset, batch_size=self.batch_size,
            shuffle=shuffle, num_workers=self.num_workers
        )

    def train_dataloader(self):
        return self._loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        # NOTE: there is no separate held-out split here; testing re-uses the
        # validation subset.
        return self._loader(self.val_ds, shuffle=False)

# --- Classifier using frozen SupCon model ---
class LinearClassifierFromSupCon(pl.LightningModule):
    """MLP classification head trained on top of a frozen SupCon encoder.

    Args:
        supcon_model: Pre-trained module used as a fixed feature extractor.
            It must expose ``hparams.proj_dim`` and its forward pass must
            return embeddings of shape ``[B, proj_dim]``.
        num_classes: Number of output classes.
        learning_rate: Learning rate for the AdamW optimiser.
        focal_gamma: ``gamma`` passed to ``FocalLoss``.

    All constructor arguments except ``supcon_model`` are stored in
    ``self.hparams``. Only the parameters of ``self.classifier`` are optimised.
    """

    def __init__(self, supcon_model: pl.LightningModule, num_classes=7,
                 learning_rate=1e-3, focal_gamma=2.0):
        super().__init__()
        self.save_hyperparameters(ignore=['supcon_model'])
        # Freeze SupCon
        self.supcon_model = supcon_model
        self.supcon_model.eval()
        for p in self.supcon_model.parameters():
            p.requires_grad = False
        # Classifier head: proj_dim -> 256 -> 128 -> num_classes
        self.classifier = nn.Sequential(
            nn.Linear(self.supcon_model.hparams.proj_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
        self.criterion = FocalLoss(gamma=focal_gamma)
        # A single metric object is shared by validation and test; it is reset
        # at the end of each of those epochs.
        self.f1 = MulticlassF1Score(num_classes=num_classes, average='macro')
        # Per-step training losses, averaged and cleared once per epoch.
        self.train_losses = []

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping the frozen encoder in eval mode.

        ``nn.Module.train`` recurses into every child, so a plain ``.train()``
        on this module (issued by the Trainer during fitting) would put the
        encoder's BatchNorm/Dropout layers back into training behaviour even
        though its weights are frozen. Re-applying ``eval()`` here keeps the
        extracted features deterministic and its running statistics untouched.
        """
        super().train(mode)
        self.supcon_model.eval()
        return self

    def forward(self, x):
        """Map images ``x`` [B, C, H, W] to class logits [B, num_classes]."""
        # The encoder is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            emb = self.supcon_model(x)
        # emb: [B, proj_dim]
        logits = self.classifier(emb)
        return logits

    def training_step(self, batch, batch_idx):
        imgs, labels = batch
        logits = self(imgs)
        loss = self.criterion(logits, labels)
        self.log('train_loss', loss, prog_bar=True)
        self.train_losses.append(loss.detach())  # Store the detached loss
        return loss

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        logits = self(imgs)
        loss = self.criterion(logits, labels)
        preds = torch.argmax(logits, dim=1)
        self.f1.update(preds, labels)
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        if self.train_losses:  # Ensure there are losses to average
            avg_loss = torch.stack(self.train_losses).mean()
            self.log('train_loss_epoch', avg_loss, on_epoch=True, prog_bar=True)
            self.print(f"Epoch {self.current_epoch} --> Train Loss: {avg_loss:.4f}")
            self.train_losses.clear()  # Reset for next epoch

    def on_validation_epoch_end(self):
        # The printed value is the macro F1 over the validation epoch.
        f1 = self.f1.compute()
        self.print(f"Epoch {self.current_epoch} --> {f1:.4f}")
        self.log('val_f1', f1, prog_bar=True)
        self.f1.reset()

    def test_step(self, batch, batch_idx):
        imgs, labels = batch
        logits = self(imgs)
        loss = self.criterion(logits, labels)
        preds = torch.argmax(logits, dim=1)
        self.f1.update(preds, labels)
        self.log('test_loss', loss, prog_bar=True)
        return loss

    def on_test_epoch_end(self):
        f1 = self.f1.compute()
        self.log('test_f1', f1, prog_bar=True)
        self.f1.reset()

    def configure_optimizers(self):
        # Only the classifier head is handed to the optimiser.
        opt = torch.optim.AdamW(self.classifier.parameters(), lr=self.hparams.learning_rate)
        return opt

# --- Entrypoint ---

def main_classification(checkpoint_path, data_dir):
    """Fit the classifier head on a frozen SupCon encoder, then run the test loop.

    Args:
        checkpoint_path: Path of the SupCon checkpoint to load.
        data_dir: Root directory of the class-per-folder image dataset.

    Training runs for at most 50 epochs. The checkpoint callback and early
    stopping (patience 10) both track ``val_f1`` in ``max`` mode, and metrics
    are written by a CSV logger under ``logs/classifier_run``.
    """
    # Frozen feature extractor, data pipeline and trainable head.
    encoder = AdvancedSupConModule.load_from_checkpoint(checkpoint_path)
    datamodule = ClassifierDataModule(data_dir)
    classifier = LinearClassifierFromSupCon(supcon_model=encoder)

    # Logging and callbacks.
    run_logger = CSVLogger(save_dir="logs/", name="classifier_run")
    keep_best = ModelCheckpoint(monitor='val_f1', mode='max', save_last=True)
    stop_early = EarlyStopping(monitor='val_f1', mode='max', patience=10)

    if torch.cuda.is_available():
        accelerator = 'gpu'
    else:
        accelerator = 'cpu'

    trainer = pl.Trainer(
        max_epochs=50,
        accelerator=accelerator,
        logger=run_logger,
        callbacks=[keep_best, stop_early],
    )
    trainer.fit(classifier, datamodule)
    trainer.test(classifier, datamodule)


# Run training only when this cell/script is executed directly.
if __name__ == '__main__':
    main_classification(
        data_dir='/kaggle/input/augmented-256/augmented_dataset',
        checkpoint_path='/kaggle/input/advanced-supcon-final-resnet101-2/advanced_supcon_final_resnet101 (2).ckpt',
    )



Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Epoch 0 --> 0.0419


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0 --> 1.0000
Epoch 0 --> Train Loss: 0.0091


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1 --> 1.0000
Epoch 1 --> Train Loss: 0.0015


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2 --> 1.0000
Epoch 2 --> Train Loss: 0.0006


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3 --> 1.0000
Epoch 3 --> Train Loss: 0.0009


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4 --> 1.0000
Epoch 4 --> Train Loss: 0.0012


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5 --> 1.0000
Epoch 5 --> Train Loss: 0.0007


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6 --> 1.0000
Epoch 6 --> Train Loss: 0.0010


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7 --> 1.0000
Epoch 7 --> Train Loss: 0.0005


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8 --> 1.0000
Epoch 8 --> Train Loss: 0.0006


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9 --> 1.0000
Epoch 9 --> Train Loss: 0.0006


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10 --> 0.9991
Epoch 10 --> Train Loss: 0.0008


Testing: |          | 0/? [00:00<?, ?it/s]

# Inference Pipeline for Embedding Classification

This code implements an inference pipeline for the previously defined embedding classification model. The pipeline handles loading test images, running them through the trained classifier, and generating predictions in a CSV format suitable for submission.

## Key Components

### 1. Test Dataset Definition

The `TestDatasetForClassifier` class:
- Loads images from a specified directory
- Extracts image IDs from filenames using regex
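- Applies the following default preprocessing to every test image (note that it resizes directly to 128 × 128, whereas the classifier training pipeline above uses `Resize((256, 256))` followed by `CenterCrop(224)`):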
  ```python
  transforms.Compose([
      transforms.Resize((128, 128)),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
  ])
  ```
- Returns tuples of (transformed_image, image_id) for inference

### 2. DataLoader Creation

The `get_test_dataloader` function:
- Creates a PyTorch DataLoader for efficient batch processing
- Uses the `TestDatasetForClassifier` with specified parameters
- Returns a DataLoader with configurable batch size and worker count

### 3. Inference Logic

The `run_inference_linear_classifier` function:
- Sets the model to evaluation mode
- Processes images in batches for efficiency
- Performs forward pass through the model to get logits
- Converts prediction indices to human-readable labels
- Returns list of (image_id, predicted_label) pairs

### 4. Main Execution Flow

The `main` function orchestrates the entire process:
1. Defines paths to test data and model checkpoints
2. Loads the trained classifier model from checkpoint
3. Creates the test DataLoader
4. Runs inference on all test images
5. Saves predictions to a CSV file in the required format

## Usage

The pipeline uses a pre-defined set of class labels:
```python
LABELS = ['bus', 'car', 'cat', 'cricket', 'dog', 'football', 'product']
```

This inference pipeline complements the training pipeline by providing a streamlined way to generate predictions from the trained model on new, unseen data.

In [None]:
import os
import re
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pytorch_lightning as pl

# Class names indexed by predicted class id. The order has to match the
# class-to-index mapping used when the classifier was trained:
# ClassificationDataset assigns ids by sorting the class folder names, which
# for these seven names gives the alphabetical order below.
# NOTE(review): assumes the training folders are named exactly like this — confirm.
LABELS = ['bus', 'car', 'cat', 'cricket', 'dog', 'football', 'product']

# ------------------------
# 1. Test Dataset Definition
# ------------------------
class TestDatasetForClassifier(Dataset):
    """Unlabelled image dataset for inference over a flat directory.

    Each item is ``(transformed_image, image_id)``, where ``image_id`` is the
    trailing run of digits in the file name (before the extension), or the
    whole file name when it contains no such digits.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    # Trailing number, optionally followed by a file extension.
    _ID_PATTERN = re.compile(r'(\d+)(?:\.\w+)?$')

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Directory containing test images.
            transform (callable, optional): Image transform to apply. When
                omitted, images are resized to 128x128, converted to tensors
                and normalised with ImageNet statistics.
        """
        self.root_dir = root_dir
        # List only image files. Sorted so that the order of the produced
        # predictions does not depend on the filesystem's listing order.
        self.filenames = sorted(
            f for f in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, f))
            and f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        if transform is None:
            # NOTE(review): the classifier's training pipeline uses
            # Resize((256, 256)) + CenterCrop(224); this default feeds 128x128
            # inputs instead. Confirm which resolution the model expects and
            # pass a matching `transform` if needed.
            transform = transforms.Compose([
                transforms.Resize((128, 128)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        self.transform = transform

    def _extract_id(self, filename):
        """Return the trailing number of ``filename`` as a string, or
        ``filename`` itself if there is none."""
        match = TestDatasetForClassifier._ID_PATTERN.search(filename)
        return match.group(1) if match else filename

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        img_path = os.path.join(self.root_dir, filename)
        # convert() yields an independent in-memory image, so the underlying
        # file handle can be released immediately.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')
        image = self.transform(image)
        image_id = self._extract_id(filename)
        return image, image_id

# ------------------------
# 2. Create Test DataLoader
# ------------------------
def get_test_dataloader(test_dir, batch_size=32, num_workers=4):
    """Return an unshuffled DataLoader over the images found in ``test_dir``.

    The underlying ``TestDatasetForClassifier`` is created with its default
    transform; ``batch_size`` and ``num_workers`` are forwarded to the loader.
    """
    return DataLoader(
        TestDatasetForClassifier(root_dir=test_dir),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )

# ------------------------
# 3. Inference Function
# ------------------------
def run_inference_linear_classifier(model, test_loader, device):
    """Predict a label name for every image yielded by ``test_loader``.

    Args:
        model: Classifier mapping an image batch [B, C, H, W] to logits.
        test_loader: Iterable of ``(images, ids)`` batches.
        device: Device the model and each image batch are moved to.

    Returns:
        List of ``(image_id, predicted_label)`` tuples in loader order, where
        the label is looked up in ``LABELS`` by the arg-max class index.
    """
    model.eval()
    model.to(device)

    predictions = []
    with torch.no_grad():
        for batch_images, batch_ids in test_loader:
            scores = model(batch_images.to(device))
            class_indices = scores.argmax(dim=1).cpu().numpy()
            predictions.extend(
                (sample_id, LABELS[class_idx])
                for sample_id, class_idx in zip(batch_ids, class_indices)
            )
    return predictions

# ------------------------
# 4. Main Inference Routine
# ------------------------
def main():
    """Run inference on the test folder and write the predictions to a CSV.

    Loads the SupCon encoder and the trained classifier from their
    checkpoints, predicts a label for every test image, saves an
    ``id,label`` CSV and prints a preview plus the predicted class counts.
    """
    # Input / output locations.
    test_dir = '/kaggle/input/openaimer-data/OpenAImer2025_Image_Classification/OpenAImer/test'
    classifier_checkpoint = '/kaggle/input/new-classifier-3/epoch10_classifier(3).ckpt'
    supcon_checkpoint = '/kaggle/input/advanced-supcon-final-resnet101-2/advanced_supcon_final_resnet101 (2).ckpt'
    output_csv_path = '/kaggle/working/test_predictions.csv'

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The classifier checkpoint does not carry the encoder as a constructor
    # argument, so the encoder is loaded first and passed in explicitly.
    encoder = AdvancedSupConModule.load_from_checkpoint(supcon_checkpoint)
    model = LinearClassifierFromSupCon.load_from_checkpoint(
        classifier_checkpoint,
        supcon_model=encoder,
    )
    model.freeze()

    test_loader = get_test_dataloader(test_dir, batch_size=32, num_workers=4)
    predictions = run_inference_linear_classifier(model, test_loader, device)

    # Persist the results.
    df = pd.DataFrame(predictions, columns=['id', 'label'])
    df.to_csv(output_csv_path, index=False)
    print(f"Predictions saved to {output_csv_path}")

    # Display a preview of the predictions (first 10 rows).
    print("\nPreview of predictions:")
    print(df.head(10))

    # Show class distribution.
    print("\nClass distribution in predictions:")
    print(df['label'].value_counts())


if __name__ == '__main__':
    main()



Predictions saved to /kaggle/working/test_predictions.csv

Preview of predictions:
       id    label
0   63392      cat
1   60822      cat
2  287801      bus
3  385827  product
4  175511      dog
5  249279  product
6  225936      bus
7  450136  product
8  410832      dog
9  371055      car

Class distribution in predictions:
label
product     505
bus         268
dog         263
car         141
cat         139
football     59
cricket      26
Name: count, dtype: int64



## Model checkpoints and augmented datasets are accessible at the specified paths:
- Since the models were already trained, we load the same checkpoints that were used to produce our submissions.
- Classifier checkpoint: "https://www.kaggle.com/datasets/alcidesthegreat/new-classifier-3/data"
- SupCon model checkpoint: "https://www.kaggle.com/datasets/alcidesthegreat/advanced-supcon-final-resnet101-2/data"
- Augmented dataset: "https://www.kaggle.com/datasets/alcidesthegreat/augmented-256"
- Original data: "https://www.kaggle.com/datasets/anubhabbhattacharya7/openaimer-data"

