## Step 1: Import Libraries & SOTA Configuration

In [1]:
import os
import ssl

# SECURITY: swapping in an unverified default HTTPS context disables TLS
# certificate verification for *every* HTTPS request made by this process,
# including the pretrained-weight downloads triggered by create_model().
# NOTE(review): this looks like a workaround for a broken local certificate
# store (e.g. macOS python.org builds before running "Install Certificates");
# prefer fixing the store. Set VERIFY_SSL=1 in the environment to keep
# certificate verification enabled.
if os.environ.get('VERIFY_SSL', '0') != '1':
    ssl._create_default_https_context = ssl._create_unverified_context

import os, copy, random, gc
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm

# ==========================================
# Reproducibility: one seed drives Python, NumPy and PyTorch RNGs
# ==========================================
SEED = 42


def _seed_everything(seed):
    """Seed the `random`, NumPy and PyTorch generators.

    On CUDA machines it additionally seeds every GPU and asks cuDNN for
    deterministic kernels; nothing extra is done for MPS or CPU.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


_seed_everything(SEED)

# ==========================================
# Device Setup: Apple Metal (MPS) first, then CUDA, else CPU
# ==========================================
if torch.backends.mps.is_available():
    device, _device_banner = torch.device('mps'), 'Using Apple MPS (GPU) üöÄ'
elif torch.cuda.is_available():
    device, _device_banner = torch.device('cuda'), f'Using CUDA: {torch.cuda.get_device_name(0)}'
else:
    device, _device_banner = torch.device('cpu'), 'Using CPU'
print(_device_banner)

# ==========================================
# Paths: Kaggle's mounted competition input when present, otherwise ./data
# ==========================================
_ON_KAGGLE = os.path.exists('/kaggle/input')
BASE_DIR = (
    '/kaggle/input/cs-460-muffin-vs-chihuahua-classification-challenge'
    if _ON_KAGGLE
    else './data'
)
TRAIN_DIR = BASE_DIR + '/train'             # read with ImageFolder: one subfolder per class
TEST_DIR = BASE_DIR + '/kaggle_test_final'  # flat folder of unlabelled images

# ==========================================
# Hyperparameters
# ==========================================
# --- Input & batching ---
IMG_SIZE = 384         # square side length produced by every transform pipeline
                       # NOTE(review): torchvision's Swin-V2-B ImageNet weights were trained on 256px crops; confirm 384 is intended
BATCH_SIZE = 8         # per-step batch, kept small because both backbones are large
GRAD_ACCUM = 4         # optimizer step every 4 batches -> effective batch of 32
NUM_WORKERS = 0        # load data in the main process (chosen for MPS stability)

# --- Schedule ---
PHASE1_EPOCHS = 5      # head-only warmup while the backbone is frozen
PHASE2_EPOCHS = 35     # upper bound for full fine-tuning (early stopping may end it sooner)
PATIENCE = 8           # phase-2 epochs without EMA val-acc improvement before stopping

# --- Optimisation (AdamW) ---
PHASE1_LR = 1e-3
PHASE2_LR = 5e-5       # much lower once the pretrained backbone is unfrozen
WEIGHT_DECAY = 0.05
GRAD_CLIP = 1.0        # max global gradient norm applied before each optimizer step

# --- Regularisation ---
LABEL_SMOOTHING = 0.1
MIXUP_ALPHA = 0.2      # Beta(alpha, alpha) parameter for the Mixup coefficient
CUTMIX_ALPHA = 1.0     # Beta(alpha, alpha) parameter for the CutMix coefficient
MIX_PROB = 0.5         # probability that a phase-2 batch is mixed at all
DROP_PATH_RATE = 0.2   # NOTE(review): not referenced by create_model; torchvision's built-in stochastic-depth rate applies

# --- Cross-validation ---
N_FOLDS = 5            # number of stratified folds
FOLD_TO_TRAIN = 0      # only this fold's train/val split is used

print('‚öôÔ∏è SOTA Setup Complete!')

Using Apple MPS (GPU) üöÄ
‚öôÔ∏è SOTA Setup Complete!


## Step 2: Advanced Data Augmentation & Mixup/CutMix Engine
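SOTA pipelines rely heavily on data mixing (Mixup/CutMix), which blends pairs of training images and their labels, to discourage the model from memorizing individual examples.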

In [2]:
# ImageNet channel statistics; the pretrained torchvision weights expect inputs
# normalised with these.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Both pipelines first squash the image to a square 32px larger than IMG_SIZE, then crop.
_PRE_CROP_SIZE = (IMG_SIZE + 32, IMG_SIZE + 32)

# ── Train: hand-picked geometric + photometric augmentation (not a learned AutoAugment policy) ──
train_transforms = transforms.Compose([
    transforms.Resize(_PRE_CROP_SIZE),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)),   # crop covering 70-100% of the area
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),                        # angle drawn from [-15, +15] degrees
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.15)),          # operates on the normalised tensor
])

# ── Val: deterministic resize + center crop, no augmentation ──
val_transforms = transforms.Compose([
    transforms.Resize(_PRE_CROP_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])

# ── Custom Crop Transform (for TTA corner crops) ──
class CropTransform:
    """Callable that always cuts the same window out of a PIL image.

    The window is ``height`` x ``width`` pixels with its top-left corner at
    (``top``, ``left``); it is used to build the fixed corner-crop TTA views.
    """

    def __init__(self, top, left, height, width):
        self.top = top
        self.left = left
        self.height = height
        self.width = width

    def __call__(self, img):
        window = (self.top, self.left, self.height, self.width)
        return transforms.functional.crop(img, *window)

# ── TTA (Saccadic Vision 12-Passes) ──
# Every view except #1 squashes the image to a square with Resize((s, s)), so the
# aspect ratio is not preserved. Views 2, 5, 6, 7 and 12 resize straight to _SZ
# (whole image, no crop); views 3, 4 and 8-11 crop _SZ out of a _LG resize.
# View 7 (ColorJitter) draws random factors on every call, so it is not deterministic.
_SZ = IMG_SIZE                # side length of every final view
_LG = int(IMG_SIZE * 1.15)    # ~15% larger resize; a _SZ crop from it acts as a close-up


def _tta_pipeline(resize_side, *pil_steps):
    """Square-resize to ``resize_side``, apply ``pil_steps``, then ToTensor + Normalize."""
    return transforms.Compose([
        transforms.Resize((resize_side, resize_side)),
        *pil_steps,
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])


_margin = _LG - _SZ  # offset of the far edge used by the corner crops
_corner_offsets = [(0, 0), (0, _margin), (_margin, 0), (_margin, _margin)]  # (top, left): TL, TR, BL, BR

tta_transforms_list = [
    val_transforms,                                                             # 1: validation view
    _tta_pipeline(_SZ, transforms.RandomHorizontalFlip(p=1.0)),                 # 2: mirrored
    _tta_pipeline(_LG, transforms.CenterCrop(_SZ)),                             # 3: close-up
    _tta_pipeline(_LG, transforms.CenterCrop(_SZ),
                  transforms.RandomHorizontalFlip(p=1.0)),                      # 4: close-up, mirrored
    _tta_pipeline(_SZ, transforms.RandomRotation((10, 10))),                    # 5: fixed +10 degrees
    _tta_pipeline(_SZ, transforms.RandomRotation((-10, -10))),                  # 6: fixed -10 degrees
    _tta_pipeline(_SZ, transforms.ColorJitter(brightness=0.2, contrast=0.2)),   # 7: random colour shift
    *(_tta_pipeline(_LG, CropTransform(top, left, _SZ, _SZ))
      for top, left in _corner_offsets),                                        # 8-11: corner crops
    _tta_pipeline(_SZ, transforms.GaussianBlur(3, 0.5)),                        # 12: light blur
]

# ── Mixup / CutMix Engine ──
def rand_bbox(size, lam):
    """Sample a random CutMix box covering roughly ``1 - lam`` of the image area.

    Args:
        size: an NCHW shape ``(N, C, H, W)``, e.g. ``x.size()``.
        lam: mixing coefficient in [0, 1]; both box sides scale with ``sqrt(1 - lam)``.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``. The x-coordinates index the width axis (dim 3)
        and the y-coordinates index the height axis (dim 2), matching the slice
        ``x[:, :, bby1:bby2, bbx1:bbx2]``. Coordinates are clipped to the image, so
        the box can be smaller than requested; callers should recompute ``lam``
        from the actual area.
    """
    # NCHW layout: dim 2 is height, dim 3 is width. Reading them the other way
    # round samples the box against the wrong axes on non-square inputs.
    H, W = size[2], size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Box centre, uniform over the image.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def apply_mixup_cutmix(x, y):
    """Mix a batch with a shuffled copy of itself, via Mixup or CutMix chosen at random.

    Args:
        x: image batch indexed as ``x[:, :, rows, cols]``.
        y: labels aligned with ``x``'s first dimension.

    Returns:
        ``(mixed_x, y, y_shuffled, lam)``. The intended loss is
        ``lam * CE(out, y) + (1 - lam) * CE(out, y_shuffled)``.
    """
    use_mixup = random.choice(['mixup', 'cutmix']) == 'mixup'
    alpha = MIXUP_ALPHA if use_mixup else CUTMIX_ALPHA

    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0)).to(x.device)

    if use_mixup:
        return lam * x + (1 - lam) * x[perm], y, y[perm], lam

    # CutMix: paste a rectangle from the shuffled batch over the original.
    left, top, right, bottom = rand_bbox(x.size(), lam)
    mixed_x = x.clone()
    mixed_x[:, :, top:bottom, left:right] = x[perm, :, top:bottom, left:right]
    # The box may have been clipped at the border, so derive lam from the real pasted area.
    height, width = x.size(-2), x.size(-1)
    lam = 1 - ((right - left) * (bottom - top) / (width * height))
    return mixed_x, y, y[perm], lam

print('SOTA Augmentation & Mixup Engine Ready!')

SOTA Augmentation & Mixup Engine Ready!


## Step 3: Stratified K-Fold Data Loading

In [3]:
class TransformSubset(Dataset):
    """View over ``dataset`` restricted to ``indices``, with ``transform`` applied per item.

    Lets the train and validation folds share one underlying dataset while using
    different augmentation pipelines. A falsy ``transform`` returns images unchanged.
    """

    def __init__(self, dataset, indices, transform):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        source_idx = self.indices[idx]
        img, label = self.dataset[source_idx]
        return (self.transform(img) if self.transform else img), label

# One untransformed ImageFolder backs both splits; TransformSubset applies each
# split's pipeline when an item is fetched.
full_dataset = datasets.ImageFolder(root=TRAIN_DIR, transform=None)
classes = full_dataset.classes
class_to_idx = full_dataset.class_to_idx
idx_to_class = {idx: name for name, idx in class_to_idx.items()}
print(f'Classes: {classes} | Total Images: {len(full_dataset)}')

# Stratified K-fold keeps each split's class proportions close to the full set's.
# The splitter only looks at the labels, so a dummy feature array is passed.
skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
folds = list(skf.split(np.zeros(len(full_dataset)), full_dataset.targets))
train_idx, val_idx = folds[FOLD_TO_TRAIN]

train_dataset, val_dataset = (
    TransformSubset(full_dataset, train_idx, train_transforms),
    TransformSubset(full_dataset, val_idx, val_transforms),
)

_loader_kwargs = {'batch_size': BATCH_SIZE, 'num_workers': NUM_WORKERS}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
print(f'Fold {FOLD_TO_TRAIN}: Train={len(train_dataset)} | Val={len(val_dataset)}')

Classes: ['chihuahua', 'muffin'] | Total Images: 4733
Fold 0: Train=3786 | Val=947


## Step 4: SOTA Hybrid Ensemble Architecture (ViT + CNN)
- **Model A**: Swin-V2-Base (State-of-the-Art Vision Transformer)
- **Model B**: ConvNeXt-Base (State-of-the-Art CNN)

In [4]:
class ModelEMA:
    """Exponential moving average (EMA) of a model's weights.

    Holds a frozen, eval-mode deep copy of ``model``. After each ``update`` call,
    every floating-point parameter/buffer equals
    ``decay * ema + (1 - decay) * live``. Non-floating-point entries (e.g. integer
    index buffers, such as the relative-position indices in torchvision's Swin)
    are copied verbatim. Averaging them is meaningless, and sending them through
    float arithmetic and back can truncate or round their values.

    Args:
        model: the live model to shadow.
        decay: smoothing factor in [0, 1]; higher values make the EMA move more slowly.
    """

    def __init__(self, model, decay=0.999):
        self.model = copy.deepcopy(model).eval()
        self.decay = decay
        for param in self.model.parameters():
            param.requires_grad = False

    def update(self, new_model):
        """Blend ``new_model``'s current state into the EMA copy, in place."""
        with torch.no_grad():
            # state_dict() tensors share storage with the module, so in-place ops
            # below modify self.model directly. Both models come from the same
            # deepcopy, so the entries line up one-to-one.
            ema_state = self.model.state_dict()
            live_state = new_model.state_dict()
            for ema_v, model_v in zip(ema_state.values(), live_state.values()):
                if ema_v.dtype.is_floating_point:
                    ema_v.mul_(self.decay).add_(model_v, alpha=1.0 - self.decay)
                else:
                    ema_v.copy_(model_v)

def create_model(model_name='swin_v2_b', num_classes=2, freeze_backbone=True):
    """Build an ImageNet-1K-pretrained torchvision backbone with a fresh linear head.

    Args:
        model_name: ``'swin_v2_b'`` or ``'convnext_base'``.
        num_classes: output size of the new classification layer.
        freeze_backbone: if True, set ``requires_grad=False`` on ``model.features``.
            Layers outside ``features`` (the new head and any norm layers
            before it) remain trainable.

    Returns:
        The model, moved to the global ``device``.

    Raises:
        ValueError: if ``model_name`` is not recognised.
    """
    # NOTE: torchvision's built-in stochastic-depth rate is used for both models;
    # the DROP_PATH_RATE config value is not applied here.
    if model_name == 'swin_v2_b':
        print("Loading Swin-V2-Base (ViT)...")
        model = models.swin_v2_b(weights=models.Swin_V2_B_Weights.IMAGENET1K_V1)
        model.head = nn.Linear(model.head.in_features, num_classes)
    elif model_name == 'convnext_base':
        print("Loading ConvNeXt-Base (CNN)...")
        model = models.convnext_base(weights=models.ConvNeXt_Base_Weights.IMAGENET1K_V1)
        # classifier[2] is the final Linear layer that gets replaced.
        model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_classes)
    else:
        raise ValueError("Unknown model")

    if freeze_backbone:
        for p in model.features.parameters():
            p.requires_grad = False

    return model.to(device)

# Unlabelled test images, read from a flat directory (no class subfolders).
class TestDataset(Dataset):
    """Yield ``(image, filename)`` for every .jpg/.jpeg/.png file in ``test_dir``.

    Filenames are sorted, so repeated passes (e.g. one per TTA view) enumerate
    the images in the same order and their predictions can be summed row-wise.

    Args:
        test_dir: directory that contains the image files directly.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, test_dir, transform=None):
        self.test_dir, self.transform = test_dir, transform
        self.image_files = sorted(
            f for f in os.listdir(test_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        name = self.image_files[idx]
        # The context manager closes the file handle deterministically;
        # convert() produces a new in-memory image before the file is closed.
        with Image.open(os.path.join(self.test_dir, name)) as src:
            img = src.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, name

## Step 5: Advanced Training & EMA Loop

In [5]:
def _train_one_epoch(model, ema_model, optimizer, criterion, use_mix, desc):
    """Run one epoch over ``train_loader`` with gradient accumulation and EMA updates.

    Args:
        model: the live network being optimised (switched to train mode here).
        ema_model: ``ModelEMA`` updated after every optimizer step.
        optimizer: optimizer over ``model``'s trainable parameters.
        criterion: loss taking ``(logits, integer labels)``.
        use_mix: if True, each batch is Mixup/CutMix-ed with probability ``MIX_PROB``.
        desc: tqdm progress-bar label.

    Returns:
        Mean per-sample training loss for the epoch.
    """
    model.train()
    n_batches = len(train_loader)
    # Size of the final accumulation window (a full GRAD_ACCUM when n_batches
    # divides evenly). Batches in that window are scaled by its real size, so
    # the last optimizer step also sees an averaged gradient.
    tail = n_batches % GRAD_ACCUM or GRAD_ACCUM
    loss_sum, total = 0.0, 0
    optimizer.zero_grad()
    for i, (images, labels) in enumerate(tqdm(train_loader, leave=False, desc=desc)):
        images, labels = images.to(device), labels.to(device)

        # `use_mix` is tested first so phase 1 never consumes a `random` draw here.
        if use_mix and random.random() < MIX_PROB:
            mixed, y_a, y_b, lam = apply_mixup_cutmix(images, labels)
            out = model(mixed)
            loss = lam * criterion(out, y_a) + (1 - lam) * criterion(out, y_b)
        else:
            out = model(images)
            loss = criterion(out, labels)

        window = GRAD_ACCUM if i < n_batches - tail else tail
        (loss / window).backward()

        if (i + 1) % GRAD_ACCUM == 0 or (i + 1) == n_batches:
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            optimizer.step()
            optimizer.zero_grad()
            ema_model.update(model)

        loss_sum += loss.item() * images.size(0)
        total += images.size(0)
    return loss_sum / max(total, 1)


def _evaluate(model, criterion):
    """Score ``model`` on ``val_loader`` without gradients.

    Returns:
        ``(mean per-sample loss, accuracy in percent)``.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            out = model(images)
            loss_sum += criterion(out, labels).item() * images.size(0)
            correct += (out.argmax(dim=1) == labels).sum().item()
            total += images.size(0)
    total = max(total, 1)
    return loss_sum / total, 100.0 * correct / total


def train_and_evaluate(model, model_name, save_path):
    """Two-phase fine-tuning with an EMA shadow; checkpoints the best EMA weights.

    Phase 1 trains only the currently-trainable parameters (just the head when the
    model came from ``create_model(..., freeze_backbone=True)``) for
    ``PHASE1_EPOCHS``. Phase 2 unfreezes everything and trains for up to
    ``PHASE2_EPOCHS`` with Mixup/CutMix and cosine LR decay. It stops early once
    EMA validation accuracy has not improved for ``PATIENCE`` epochs.

    Args:
        model: network already on ``device``; trained in place.
        model_name: label used in the log output.
        save_path: file the best EMA ``state_dict`` is written to via ``torch.save``.

    Returns:
        Best EMA validation accuracy (percent) observed in phase 2.

    Note:
        The caller's own reference keeps ``model`` alive after this returns;
        delete it there to actually release its memory.
    """
    print(f"\n{'='*60}\nüöÄ Starting Training Pipeline for: {model_name}\n{'='*60}")

    criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)
    # With decay 0.999 the EMA averages over roughly the last 1/(1-0.999) = 1000
    # optimizer steps, so early EMA accuracy lags the live model.
    ema_model = ModelEMA(model)
    best_val_acc = 0.0
    saved_any = False

    # ── Phase 1: Warmup Head ──
    opt1 = optim.AdamW((p for p in model.parameters() if p.requires_grad),
                       lr=PHASE1_LR, weight_decay=WEIGHT_DECAY)
    print("\n--- Phase 1: Warming up classification head ---")
    for ep in range(1, PHASE1_EPOCHS + 1):
        t_loss = _train_one_epoch(model, ema_model, opt1, criterion,
                                  use_mix=False, desc=f"P1 - Ep {ep}")
        v_loss, v_acc = _evaluate(ema_model.model, criterion)
        print(f"  [Head-Only Epoch {ep}] Train Loss: {t_loss:.4f} | EMA Val Loss: {v_loss:.4f} | EMA Val Acc: {v_acc:.2f}%")

    # ── Phase 2: Full Backbone Fine-tuning with Mixup/CutMix ──
    for p in model.parameters():
        p.requires_grad = True
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n--- Phase 2: Full Fine-tuning ({trainable:,} params) ---")

    opt2 = optim.AdamW(model.parameters(), lr=PHASE2_LR, weight_decay=WEIGHT_DECAY)
    # Cosine decay over the full phase-2 budget, stepped once per epoch.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(opt2, T_max=PHASE2_EPOCHS, eta_min=1e-6)

    no_improve = 0
    for ep in range(1, PHASE2_EPOCHS + 1):
        t_loss = _train_one_epoch(model, ema_model, opt2, criterion,
                                  use_mix=True, desc=f"P2 - Ep {ep}")
        scheduler.step()
        _, v_acc = _evaluate(ema_model.model, criterion)

        improved = v_acc > best_val_acc
        flag = " ‚≠ê BEST (Saving)" if improved else ""
        print(f"  [Epoch {ep:02d}/{PHASE2_EPOCHS}] Train Loss: {t_loss:.4f} | EMA Val Acc: {v_acc:.2f}% {flag}")

        if improved:
            best_val_acc = v_acc
            # Written straight to disk; no in-memory copy is kept.
            torch.save(ema_model.model.state_dict(), save_path)
            saved_any = True
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= PATIENCE:
                print(f"üõë Early stopping triggered after {ep} epochs.")
                break

    # Guarantee a checkpoint exists for the inference step even if phase 2 never
    # improved (e.g. PHASE2_EPOCHS == 0).
    if not saved_any:
        torch.save(ema_model.model.state_dict(), save_path)

    # Free what this function owns, then return cached blocks on CUDA or MPS.
    del ema_model, opt1, opt2, scheduler, criterion
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available() and hasattr(torch, 'mps'):
        torch.mps.empty_cache()

    return best_val_acc

## Step 6: Train Model A — Swin-V2-Base (Vision Transformer)

In [6]:
model_a = create_model('swin_v2_b', freeze_backbone=True)
acc_a = train_and_evaluate(model_a, 'Swin-V2-Base', 'best_swin_v2.pth')

# train_and_evaluate cannot free the network while this cell still holds
# `model_a`. The best EMA weights are already on disk (the inference step
# reloads them), so drop the in-memory copy before the next backbone is built.
del model_a
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
elif torch.backends.mps.is_available() and hasattr(torch, 'mps'):
    torch.mps.empty_cache()

Loading Swin-V2-Base (ViT)...

üöÄ Starting Training Pipeline for: Swin-V2-Base

--- Phase 1: Warming up classification head ---


                                                            

  [Head-Only Epoch 1] Train Loss: 0.3331 | EMA Val Acc: 46.36%


                                                            

  [Head-Only Epoch 2] Train Loss: 0.2452 | EMA Val Acc: 86.80%


                                                            

  [Head-Only Epoch 3] Train Loss: 0.2376 | EMA Val Acc: 96.30%


                                                            

  [Head-Only Epoch 4] Train Loss: 0.2345 | EMA Val Acc: 98.20%


                                                            

  [Head-Only Epoch 5] Train Loss: 0.2322 | EMA Val Acc: 99.05%

--- Phase 2: Full Fine-tuning (86,907,898 params) ---


                                                            

  [Epoch 01/35] Train Loss: 0.2812 | EMA Val Acc: 99.37%  ‚≠ê BEST (Saving)


                                                              

  [Epoch 02/35] Train Loss: 0.2654 | EMA Val Acc: 99.58%  ‚≠ê BEST (Saving)


                                                            

  [Epoch 03/35] Train Loss: 0.2698 | EMA Val Acc: 99.58% 


                                                            

  [Epoch 04/35] Train Loss: 0.2636 | EMA Val Acc: 99.79%  ‚≠ê BEST (Saving)


                                                            

  [Epoch 05/35] Train Loss: 0.2621 | EMA Val Acc: 99.79% 


                                                            

  [Epoch 06/35] Train Loss: 0.2537 | EMA Val Acc: 99.79% 


                                                            

  [Epoch 07/35] Train Loss: 0.2645 | EMA Val Acc: 99.79% 


                                                            

  [Epoch 08/35] Train Loss: 0.2669 | EMA Val Acc: 99.79% 


                                                            

  [Epoch 09/35] Train Loss: 0.2555 | EMA Val Acc: 99.79% 


                                                             

  [Epoch 10/35] Train Loss: 0.2578 | EMA Val Acc: 99.79% 


                                                             

  [Epoch 11/35] Train Loss: 0.2608 | EMA Val Acc: 99.79% 


                                                             

  [Epoch 12/35] Train Loss: 0.2646 | EMA Val Acc: 99.79% 
üõë Early stopping triggered after 12 epochs.


## Step 7: Train Model B — ConvNeXt-Base (CNN)

In [7]:
model_b = create_model('convnext_base', freeze_backbone=True)
acc_b = train_and_evaluate(model_b, 'ConvNeXt-Base', 'best_convnext_sota.pth')

# train_and_evaluate cannot free the network while this cell still holds
# `model_b`. The best EMA weights are already on disk (the inference step
# reloads them), so release the in-memory copy now.
del model_b
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
elif torch.backends.mps.is_available() and hasattr(torch, 'mps'):
    torch.mps.empty_cache()

print(f"\nüèÜ Final Single Model Accuracies -> Swin-V2: {acc_a:.2f}% | ConvNeXt: {acc_b:.2f}%")

Loading ConvNeXt-Base (CNN)...

üöÄ Starting Training Pipeline for: ConvNeXt-Base

--- Phase 1: Warming up classification head ---


                                                            

  [Head-Only Epoch 1] Train Loss: 0.2687 | EMA Val Acc: 92.29%


                                                            

  [Head-Only Epoch 2] Train Loss: 0.2294 | EMA Val Acc: 97.99%


                                                            

  [Head-Only Epoch 3] Train Loss: 0.2243 | EMA Val Acc: 98.94%


                                                            

  [Head-Only Epoch 4] Train Loss: 0.2260 | EMA Val Acc: 99.26%


                                                            

  [Head-Only Epoch 5] Train Loss: 0.2229 | EMA Val Acc: 99.37%

--- Phase 2: Full Fine-tuning (87,568,514 params) ---


                                                            

  [Epoch 01/35] Train Loss: 0.2835 | EMA Val Acc: 99.37%  ‚≠ê BEST (Saving)


                                                            

  [Epoch 02/35] Train Loss: 0.2771 | EMA Val Acc: 99.47%  ‚≠ê BEST (Saving)


                                                            

  [Epoch 03/35] Train Loss: 0.2667 | EMA Val Acc: 99.47% 


                                                            

  [Epoch 04/35] Train Loss: 0.2708 | EMA Val Acc: 99.47% 


                                                            

  [Epoch 05/35] Train Loss: 0.2719 | EMA Val Acc: 99.58%  ‚≠ê BEST (Saving)


                                                            

  [Epoch 06/35] Train Loss: 0.2650 | EMA Val Acc: 99.68%  ‚≠ê BEST (Saving)


                                                            

  [Epoch 07/35] Train Loss: 0.2679 | EMA Val Acc: 99.68% 


                                                            

  [Epoch 08/35] Train Loss: 0.2653 | EMA Val Acc: 99.79%  ‚≠ê BEST (Saving)


                                                            

  [Epoch 09/35] Train Loss: 0.2617 | EMA Val Acc: 99.79% 


                                                             

  [Epoch 10/35] Train Loss: 0.2565 | EMA Val Acc: 99.79% 


                                                             

  [Epoch 11/35] Train Loss: 0.2569 | EMA Val Acc: 99.79% 


                                                             

  [Epoch 12/35] Train Loss: 0.2620 | EMA Val Acc: 99.79% 


                                                             

  [Epoch 13/35] Train Loss: 0.2656 | EMA Val Acc: 99.79% 


                                                             

  [Epoch 14/35] Train Loss: 0.2640 | EMA Val Acc: 99.79% 


                                                             

  [Epoch 15/35] Train Loss: 0.2696 | EMA Val Acc: 99.79% 


                                                             

  [Epoch 16/35] Train Loss: 0.2556 | EMA Val Acc: 99.79% 
üõë Early stopping triggered after 16 epochs.

üèÜ Final Single Model Accuracies -> Swin-V2: 99.79% | ConvNeXt: 99.79%


## Step 8: Multi-Scale Saccadic TTA Inference (ViT + CNN Ensemble)

In [8]:
print('\n=========================================================')
print('üåç Generating Final SOTA Ensemble Predictions (2 Models √ó 12 TTA)')
print('=========================================================')

# Reload the best EMA checkpoints. map_location places every tensor on the
# current `device`, so a checkpoint written on one backend (e.g. MPS) still
# loads on a machine that only has CUDA or a CPU.
model_a = create_model('swin_v2_b', freeze_backbone=False)
model_a.load_state_dict(torch.load('best_swin_v2.pth', map_location=device, weights_only=True))
model_a.eval()

model_b = create_model('convnext_base', freeze_backbone=False)
model_b.load_state_dict(torch.load('best_convnext_sota.pth', map_location=device, weights_only=True))
model_b.eval()

# Display names sit next to their models instead of being inferred from list position.
models_list = [model_a, model_b]
model_names = ['Swin-V2', 'ConvNeXt']
softmax = nn.Softmax(dim=1)

# List the test directory once. TestDataset sorts filenames and the loader never
# shuffles, so every pass yields rows in this order; only the transform changes.
td = TestDataset(TEST_DIR)
if len(td) == 0:
    raise FileNotFoundError(f'No test images found in {TEST_DIR}')
all_filenames = list(td.image_files)
all_probs = None

total_passes = len(models_list) * len(tta_transforms_list)
pass_n = 0

for mdl, md_name in zip(models_list, model_names):
    for t_idx, tta_tf in enumerate(tta_transforms_list):
        pass_n += 1
        print(f'  üîç TTA Pass {pass_n}/{total_passes} ([{md_name}] Aug {t_idx+1})')
        td.transform = tta_tf
        tl = DataLoader(td, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

        with torch.no_grad():
            pass_probs = np.concatenate(
                [softmax(mdl(imgs.to(device))).cpu().numpy()
                 for imgs, _ in tqdm(tl, desc='Inference', leave=False)],
                axis=0,
            )
        all_probs = pass_probs if all_probs is None else all_probs + pass_probs

# Mean class probabilities over every (model, TTA view) pass.
all_probs /= total_passes
pred_indices = np.argmax(all_probs, axis=1)
predictions  = [idx_to_class[i] for i in pred_indices]
print(f'\n‚úÖ Done! High-confidence predictions generated for {len(predictions)} images.')


üåç Generating Final SOTA Ensemble Predictions (2 Models √ó 12 TTA)
Loading Swin-V2-Base (ViT)...
Loading ConvNeXt-Base (CNN)...
  üîç TTA Pass 1/24 ([Swin-V2] Aug 1)


                                                            

  üîç TTA Pass 2/24 ([Swin-V2] Aug 2)


                                                            

  üîç TTA Pass 3/24 ([Swin-V2] Aug 3)


                                                            

  üîç TTA Pass 4/24 ([Swin-V2] Aug 4)


                                                            

  üîç TTA Pass 5/24 ([Swin-V2] Aug 5)


                                                            

  üîç TTA Pass 6/24 ([Swin-V2] Aug 6)


                                                            

  üîç TTA Pass 7/24 ([Swin-V2] Aug 7)


                                                            

  üîç TTA Pass 8/24 ([Swin-V2] Aug 8)


                                                            

  üîç TTA Pass 9/24 ([Swin-V2] Aug 9)


                                                            

  üîç TTA Pass 10/24 ([Swin-V2] Aug 10)


                                                            

  üîç TTA Pass 11/24 ([Swin-V2] Aug 11)


                                                            

  üîç TTA Pass 12/24 ([Swin-V2] Aug 12)


                                                            

  üîç TTA Pass 13/24 ([ConvNeXt] Aug 1)


                                                            

  üîç TTA Pass 14/24 ([ConvNeXt] Aug 2)


                                                            

  üîç TTA Pass 15/24 ([ConvNeXt] Aug 3)


                                                            

  üîç TTA Pass 16/24 ([ConvNeXt] Aug 4)


                                                            

  üîç TTA Pass 17/24 ([ConvNeXt] Aug 5)


                                                            

  üîç TTA Pass 18/24 ([ConvNeXt] Aug 6)


                                                            

  üîç TTA Pass 19/24 ([ConvNeXt] Aug 7)


                                                            

  üîç TTA Pass 20/24 ([ConvNeXt] Aug 8)


                                                            

  üîç TTA Pass 21/24 ([ConvNeXt] Aug 9)


                                                            

  üîç TTA Pass 22/24 ([ConvNeXt] Aug 10)


                                                            

  üîç TTA Pass 23/24 ([ConvNeXt] Aug 11)


                                                            

  üîç TTA Pass 24/24 ([ConvNeXt] Aug 12)


                                                            


‚úÖ Done! High-confidence predictions generated for 1138 images.




## Step 9: Save Ensembled Submission

In [9]:
submission_df = pd.DataFrame({'ID': all_filenames, 'Predict': predictions})
submission_df.to_csv('submission_sota.csv', index=False)
print('\nüíæ submission_sota.csv saved!')
print("\nPrediction distribution:")
print(submission_df['Predict'].value_counts())

# Cleanup huge models from RAM. The inference cell's `models_list` and its loop
# variable `mdl` also reference the models, so they must be dropped too, or
# gc cannot reclaim anything.
del model_a, model_b, models_list, mdl
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
elif torch.backends.mps.is_available() and hasattr(torch, 'mps'):
    torch.mps.empty_cache()


üíæ submission_sota.csv saved!

Prediction distribution:
Predict
chihuahua    641
muffin       497
Name: count, dtype: int64


40