In [1]:
import torch
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CIFAR100
from typing import Optional, Callable
import os
import timm
import numpy as np
import pandas as pd
from torchvision.transforms import v2
from torch.backends import cudnn
from torch import GradScaler
from torch import optim
from tqdm import tqdm
# import wandb
from datetime import datetime

# ============ Configuration ============
config = {
    "dataset": "cifar100_noisy",
    "model": "resnet18",
    "pretrained": "imagenet",
    "epochs": 100,
    "batch_size": 128,
    "lr": 0.001,
    "momentum": 0.9,
    "weight_decay": 0.0005,
    "nesterov": True,
    "label_smoothing": 0.15,
    "optimizer": "adamw",
    "scheduler": "cosine",
    "cosine_eta_min": 1e-6,
    "early_stop_patience": 15,
    "early_stop_mode": "max",
    "early_stop_min_delta": 0.0,
    "device": "cuda",
    "mixed_precision": True,
    "wandb_project": "cifar100-noisy-competition",
    "upscale_size": 224,  # Upscale images from 32x32 to 224x224
    "aug_alpha": 0.5,          # Optimal alpha for Beta distribution
    "cutmix_prob": 1.0,        # Apply every batch (standard for strong regularization)
    "switch_epoch": 50         # 0-49: MixUp, 50-100: CutMix
}

device = torch.device(config["device"])
print(f"Using device: {device}")
cudnn.benchmark = True
pin_memory = True
enable_half = config["mixed_precision"]  # Disable for CPU, it is slower!
scaler = GradScaler(device, enabled=enable_half)

class SimpleCachedDataset(Dataset):
    """Keep every item of a source dataset in memory as an immutable tuple.

    The source is iterated exactly once, at construction time. Afterwards
    lookups are plain tuple indexing; no runtime transforms are applied.
    """

    def __init__(self, dataset):
        # Materialise the whole dataset up front.
        self.data = tuple(dataset)

    def __getitem__(self, i):
        return self.data[i]

    def __len__(self):
        return len(self.data)

class PreprocessedDataset(Dataset):
    """
    Apply a transform to every image once and serve the stored results.

    The constructor walks `dataset` a single time (showing a progress bar),
    keeps `transform(img)` for each sample next to its untouched target, and
    from then on returns `(transformed, target)` pairs by index. Whatever the
    transform returns is what gets cached, so only deterministic transforms
    belong here; random augmentations have to be applied at runtime instead.
    """
    def __init__(self, dataset, transform):
        print(f"Preprocessing {len(dataset)} images (this happens once)...")

        cached_images, cached_targets = [], []
        progress = tqdm(dataset, desc="Caching", leave=False)
        for image, label in progress:
            cached_images.append(transform(image))
            cached_targets.append(label)

        self.data = cached_images
        self.targets = cached_targets
        print(f"Cached {len(self.data)} preprocessed images")

    def __getitem__(self, i):
        return self.data[i], self.targets[i]

    def __len__(self):
        return len(self.data)

class AugmentationWrapper(Dataset):
    """Serve samples from a wrapped dataset, passing each image through an optional transform.

    `runtime_transforms` may be None, in which case samples are returned
    exactly as the wrapped dataset produces them. Targets are never modified.
    """

    def __init__(self, preprocessed_dataset, runtime_transforms):
        self.dataset = preprocessed_dataset
        self.runtime_transforms = runtime_transforms

    def __getitem__(self, i):
        sample, label = self.dataset[i]
        augment = self.runtime_transforms
        if augment is None:
            return sample, label
        # The transform runs on every access, so random transforms differ per call.
        return augment(sample), label

    def __len__(self):
        return len(self.dataset)

class CIFAR100_noisy_fine(Dataset):
    """
    CIFAR-100 whose training labels are replaced by noisy ones loaded from disk.

    See https://github.com/UCSC-REAL/cifar-10-100n, https://www.noisylabels.com/ and `Learning with Noisy Labels
    Revisited: A Study Using Real-World Human Annotations`.

    Args:
        root: Directory handed to torchvision's CIFAR100; for the training
            split it must also contain ``CIFAR-100-noisy.npz``.
        train: Select the training split (noisy labels) or the test split
            (labels as provided by CIFAR100).
        transform: Optional callable applied to the image on every
            ``__getitem__`` call. ``None`` returns the image unchanged.
        download: Forwarded to CIFAR100.

    Raises:
        FileNotFoundError: ``train`` is true and the noisy-label file is missing.
        RuntimeError: the ``clean_label`` array in the file does not match the
            labels of the loaded CIFAR100 split.
    """

    def __init__(
        self, root: str, train: bool, transform: Optional[Callable], download: bool
    ):
        cifar100 = CIFAR100(
            root=root, train=train, transform=None, download=download
        )
        # Iterating the dataset yields (image, label) pairs; split them into two tuples.
        data, targets = tuple(zip(*cifar100))

        if train:
            noisy_label_file = os.path.join(root, "CIFAR-100-noisy.npz")
            if not os.path.isfile(noisy_label_file):
                raise FileNotFoundError(
                    f"{type(self).__name__} need {noisy_label_file} to be used!"
                )

            # Use the archive as a context manager so its file handle is closed.
            with np.load(noisy_label_file) as noise_file:
                # Sanity check: the file must describe the same samples in the same order.
                if not np.array_equal(noise_file["clean_label"], targets):
                    raise RuntimeError("Clean labels do not match!")
                targets = noise_file["noisy_label"]

        self.data = data
        self.targets = targets
        # Previously accepted but ignored; now honoured when not None.
        self.transform = transform

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, i: int):
        img = self.data[i]
        if self.transform is not None:
            img = self.transform(img)
        return img, self.targets[i]


class EarlyStopping:
    """Signal when a monitored metric has gone `patience` checks without improving.

    Call the instance once per epoch with the metric and the epoch index.
    With mode 'max' an improvement is a score above `best_score + min_delta`;
    with any other mode it is a score below `best_score - min_delta`.
    """
    def __init__(self, patience=10, min_delta=0.0, mode='max'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0          # consecutive calls without improvement
        self.best_score = None    # None until the first call
        self.best_epoch = 0
        self.early_stop = False   # latched: never reset once set

    def __call__(self, score, epoch):
        """Record `score` for `epoch`; return True once stopping is warranted."""
        # The very first observation only seeds the baseline.
        if self.best_score is None:
            self.best_score, self.best_epoch = score, epoch
            return False

        if self.mode == 'max':
            # Higher is better (e.g. accuracy)
            improved = score > self.best_score + self.min_delta
        else:
            # Lower is better (e.g. loss)
            improved = score < self.best_score - self.min_delta

        if improved:
            self.best_score, self.best_epoch = score, epoch
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop


# === PREPROCESSING (applied once and cached) ===
# Deterministic steps only: convert to an image tensor, then resize to
# config["upscale_size"] (224 with the current config).
preprocess_transforms = v2.Compose([
    v2.ToImage(),
    v2.Resize(config["upscale_size"]),
])


def _normalization_steps():
    """Return fresh float-conversion + normalisation steps shared by both pipelines."""
    # These mean/std values are the commonly used ImageNet channel statistics
    # (the backbone is loaded with ImageNet-pretrained weights), not CIFAR-100 ones.
    return [
        v2.ToDtype(torch.float32, scale=True),  # Convert to float [0,1]
        v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]


# === RUNTIME AUGMENTATION (applied at each epoch) ===
# Random geometric/colour transforms run first, then normalisation, and
# RandomErasing last on the normalised float tensor.
train_runtime_transforms = v2.Compose([
    v2.RandomCrop(config["upscale_size"], padding=4),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    v2.RandomRotation(15),
    *_normalization_steps(),
    v2.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3)),
])

# Test set: only float conversion + normalisation (resize was done during preprocessing)
test_runtime_transforms = v2.Compose(_normalization_steps())

# Load raw datasets (both splits live under the same Kaggle input directory)
print("Loading datasets...")
_DATA_ROOT = '/kaggle/input/fii-atnn-2025-project-noisy-cifar-100/fii-atnn-2024-project-noisy-cifar-100'
train_set_raw = CIFAR100_noisy_fine(_DATA_ROOT, download=False, train=True, transform=None)
test_set_raw = CIFAR100_noisy_fine(_DATA_ROOT, download=False, train=False, transform=None)

# Hold the raw (image, label) pairs in memory
train_set_cached = SimpleCachedDataset(train_set_raw)
test_set_cached = SimpleCachedDataset(test_set_raw)

# Run the deterministic preprocessing once per split and keep the results
print("\n[TRAIN SET]")
train_set_preprocessed = PreprocessedDataset(train_set_cached, preprocess_transforms)
print("\n[TEST SET]")
test_set_preprocessed = PreprocessedDataset(test_set_cached, preprocess_transforms)

# Runtime transforms are applied on every access:
# random augmentation for train, conversion + normalisation only for test.
train_set = AugmentationWrapper(train_set_preprocessed, train_runtime_transforms)
test_set = AugmentationWrapper(test_set_preprocessed, test_runtime_transforms)

print(f"\nTrain set ready: {len(train_set)} samples (with runtime augmentation)")
print(f"Test set ready: {len(test_set)} samples (fully cached)\n")

# Worker settings shared by both loaders; only the training loader shuffles.
_loader_kwargs = dict(pin_memory=pin_memory, num_workers=2, persistent_workers=True)
train_loader = DataLoader(
    train_set,
    batch_size=config["batch_size"],
    shuffle=True,
    **_loader_kwargs,
)
test_loader = DataLoader(test_set, batch_size=500, **_loader_kwargs)

# Build the backbone named in the config with pretrained weights and a 100-class head
print(f"Loading model: {config['model']} (pretrained on {config['pretrained']})")
model = timm.create_model(config["model"], pretrained=True, num_classes=100).to(device)

# Parameter census: (element count, trainable?) for every parameter tensor
_param_info = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in _param_info)
trainable_params = sum(count for count, trainable in _param_info if trainable)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}\n")

# Label smoothing softens the targets, which helps with noisy labels
criterion = nn.CrossEntropyLoss(label_smoothing=config["label_smoothing"])

# Optimizer selection (name is matched case-insensitively)
_optimizer_name = config["optimizer"].lower()
if _optimizer_name == "sgd":
    optimizer = optim.SGD(
        model.parameters(),
        lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        nesterov=config["nesterov"],
        fused=True,
    )
    print(f"Optimizer: SGD (lr={config['lr']}, momentum={config['momentum']}, weight_decay={config['weight_decay']}, nesterov={config['nesterov']})")
elif _optimizer_name == "adamw":
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config["lr"],
        weight_decay=config["weight_decay"],
        fused=True,
    )
    print(f"Optimizer: AdamW (lr={config['lr']}, weight_decay={config['weight_decay']})")
else:
    raise ValueError(f"Unknown optimizer: {config['optimizer']}. Supported: 'sgd', 'adamw'")

# Learning-rate schedule (name is matched exactly); any other value disables scheduling
_scheduler_name = config["scheduler"]
if _scheduler_name == "cosine":
    _eta_min = config.get("cosine_eta_min", 1e-6)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=config["epochs"],
        eta_min=_eta_min,
    )
    print(f"Scheduler: CosineAnnealingLR (T_max={config['epochs']}, eta_min={_eta_min})")
elif _scheduler_name == "steplr":
    _step_size = config.get("step_size", 30)
    _gamma = config.get("gamma", 0.1)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=_step_size, gamma=_gamma)
    print(f"Scheduler: StepLR (step_size={_step_size}, gamma={_gamma})")
else:
    scheduler = None
    print("Scheduler: None")

# === CUTMIX HELPER FUNCTION ===
def rand_bbox(size, lam):
    """Sample a random box for CutMix.

    Args:
        size: Indexable batch shape; only size[2] and size[3] are read.
        lam: Mixing weight; the box side is sqrt(1 - lam) times each extent.

    Returns:
        (bbx1, bby1, bbx2, bby2). The bbx pair is bounded by size[2] and the
        bby pair by size[3]. The box is centred on a uniformly drawn point and
        clipped to the extents, so its area can be smaller than requested.
    """
    extent_x, extent_y = size[2], size[3]

    # Half side lengths of the requested box along each axis
    side_ratio = np.sqrt(1. - lam)
    half_w = int(extent_x * side_ratio) // 2
    half_h = int(extent_y * side_ratio) // 2

    # Uniformly drawn centre (x first, then y)
    center_x = np.random.randint(extent_x)
    center_y = np.random.randint(extent_y)

    # Clip each corner to stay inside the extents
    bbx1 = np.clip(center_x - half_w, 0, extent_x)
    bby1 = np.clip(center_y - half_h, 0, extent_y)
    bbx2 = np.clip(center_x + half_w, 0, extent_x)
    bby2 = np.clip(center_y + half_h, 0, extent_y)
    return bbx1, bby1, bbx2, bby2

def train(epoch):
    """Run one training epoch with MixUp or CutMix applied to every batch.

    Epochs before config["switch_epoch"] use MixUp; later epochs use CutMix.
    Relies on the module-level model, train_loader, criterion, optimizer,
    scaler, device and enable_half.

    Returns:
        (mean loss per sample, accuracy in percent, augmentation mode name).
        Accuracy compares predictions on the mixed inputs with the unmixed
        targets, so it is only a rough indicator.
    """
    print(f"\nEpoch {epoch+1}/{config['epochs']}")
    model.train()

    # Phase selection: MixUp first, CutMix from switch_epoch onwards
    use_cutmix = epoch >= config["switch_epoch"]
    aug_mode = "CutMix" if use_cutmix else "MixUp"
    alpha = config["aug_alpha"]

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in train_loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Pair every sample with a randomly chosen partner from the same batch
        perm = torch.randperm(inputs.size(0)).to(device)
        partner_targets = targets[perm]

        # Mixing weight drawn from Beta(alpha, alpha)
        lam = np.random.beta(alpha, alpha)

        if not use_cutmix:
            # === MixUp: blend each image with its partner ===
            inputs = lam * inputs + (1 - lam) * inputs[perm, :]
        else:
            # === CutMix: paste a box from the partner image ===
            bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
            # Recompute lambda from the (possibly clipped) box area actually replaced
            patch_area = (bbx2 - bbx1) * (bby2 - bby1)
            lam = 1 - (patch_area / (inputs.size()[-1] * inputs.size()[-2]))
            inputs[:, :, bbx1:bbx2, bby1:bby2] = inputs[perm, :, bbx1:bbx2, bby1:bby2]

        with torch.autocast(device.type, enabled=enable_half):
            outputs = model(inputs)
            # Loss is weighted between both label sets: lam * loss(a) + (1 - lam) * loss(b)
            loss = lam * criterion(outputs, targets) + (1 - lam) * criterion(outputs, partner_targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Clear gradients here so the next iteration starts from zero
        optimizer.zero_grad()

        batch_size = targets.size(0)
        loss_sum += loss.item() * batch_size
        n_seen += batch_size
        n_correct += outputs.argmax(1).eq(targets).sum().item()

    return loss_sum / n_seen, 100.0 * n_correct / n_seen, aug_mode

@torch.inference_mode()
def val():
    """Evaluate the module-level model on test_loader.

    Returns:
        (mean loss per sample, accuracy in percent).
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for inputs, targets in test_loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        with torch.autocast(device.type, enabled=enable_half):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        batch_size = targets.size(0)
        # criterion returns a batch mean, so weight it by the batch size
        loss_sum += loss.item() * batch_size
        n_seen += batch_size
        n_correct += outputs.argmax(1).eq(targets).sum().item()

    return loss_sum / n_seen, 100.0 * n_correct / n_seen

@torch.inference_mode()
def inference():
    """Predict a class index for every sample in test_loader.

    Returns:
        A flat list of ints in loader order; the labels yielded by the loader
        are ignored.
    """
    model.eval()
    predictions = []

    for batch, _ in test_loader:
        batch = batch.to(device, non_blocking=True)
        with torch.autocast(device.type, enabled=enable_half):
            logits = model(batch)
        predictions += logits.argmax(1).tolist()

    return predictions

# Run identifier (model, optimizer, lr, batch size, start time) for the
# currently disabled WandB logging
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_name = (
    f"cifar100noisy_{config['model']}_{config['optimizer']}"
    f"_lr{config['lr']}_bs{config['batch_size']}_{timestamp}"
)

# wandb.init(
#     project=config["wandb_project"],
#     name=run_name,
#     config=config
# )

# Best validation accuracy seen so far and the (0-based) epoch it occurred in
best, best_epoch = 0.0, 0

# Early stopping is driven entirely by the config
early_stopping = EarlyStopping(
    patience=config["early_stop_patience"],
    min_delta=config["early_stop_min_delta"],
    mode=config["early_stop_mode"],
)

# Startup banner summarising the run
_rule = '=' * 70
print(f"\n{_rule}")
print(f"Starting Training - {config['epochs']} epochs")
print(f"Model: {config['model']} (pretrained on {config['pretrained']})")
print(f"Optimizer: {config['optimizer'].upper()}, LR: {config['lr']}, Batch Size: {config['batch_size']}")
if config["optimizer"].lower() == "sgd":
    # Momentum settings only apply to SGD
    print(f"Momentum: {config['momentum']}, Nesterov: {config['nesterov']}")
print(f"Weight Decay: {config['weight_decay']}, Label Smoothing: {config['label_smoothing']}")
print(f"Scheduler: {config['scheduler']}")
print(f"Early Stopping: Enabled (patience={config['early_stop_patience']}, mode={config['early_stop_mode']})")
print(f"{_rule}\n")

# Main loop: train, validate, step the scheduler, checkpoint on a new best
# validation accuracy, then consult early stopping.
with tqdm(range(config["epochs"])) as tbar:
    for epoch in tbar:
        train_loss, train_acc, mode = train(epoch)
        val_loss, val_acc = val()

        # Advance the schedule; the reported LR is the value after this step
        if scheduler is None:
            current_lr = config["lr"]
        else:
            scheduler.step()
            current_lr = scheduler.get_last_lr()[0]

        if val_acc > best:
            best, best_epoch = val_acc, epoch
            # Persist everything needed to reload or resume from the best epoch
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
            }
            if scheduler is not None:
                checkpoint['scheduler_state_dict'] = scheduler.state_dict()
            torch.save(checkpoint, './best_model.pth')

        # Log to WandB
        # wandb.log({
        #     "epoch": epoch,
        #     "train_loss": train_loss,
        #     "train_acc": train_acc,
        #     "val_loss": val_loss,
        #     "val_acc": val_acc,
        #     "best_val_acc": best,
        #     "lr": current_lr,
        #     "aug_mode": 1 if mode == "CutMix" else 0
        # })

        tbar.set_description(f"Epoch {epoch+1}/{config['epochs']} | Train: {train_acc:.2f}% | Val: {val_acc:.2f}% | Best: {best:.2f}% | LR: {current_lr:.6f}")

        # Stop once validation accuracy has stalled for `patience` epochs
        if not early_stopping(val_acc, epoch):
            continue

        _rule = '=' * 60
        print(f"\n{_rule}")
        print(f"Early stopping triggered at epoch {epoch+1}")
        print(f"Best Val Accuracy: {best:.2f}% at epoch {best_epoch+1}")
        print(f"No improvement for {config['early_stop_patience']} epochs")
        print(f"{_rule}\n")
        break
    

print(f"\n{'='*60}")
print(f"Training Complete!")
print(f"Best Val Accuracy: {best:.2f}% at epoch {best_epoch+1}")
print(f"Loading best model for inference...")
print(f"{'='*60}\n")

# Reload the best checkpoint for inference.
# weights_only=True restricts unpickling to tensors and plain containers
# (everything this script saves) and avoids the torch.load FutureWarning;
# map_location keeps the tensors on the device the model lives on.
best_model_path = './best_model.pth'
if os.path.isfile(best_model_path):
    checkpoint = torch.load(best_model_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Best model loaded (Epoch {checkpoint['epoch']+1}, Val Acc: {checkpoint['val_acc']:.2f}%)\n")
else:
    # No epoch ever improved on the initial best (e.g. zero epochs were run):
    # keep the weights currently in memory instead of crashing.
    print(f"No checkpoint found at {best_model_path}; using current model weights\n")

# Generate submission: one row per test sample, IDs are positions in loader order
print("Generating predictions...")
predictions = inference()
data = {
    "ID": list(range(len(predictions))),
    "target": predictions
}

submission_path = "/kaggle/working/submission.csv"
df = pd.DataFrame(data)
df.to_csv(submission_path, index=False)

# Log final results to WandB
# wandb.summary["final_best_val_acc"] = best
# wandb.summary["best_epoch"] = best_epoch
# wandb.summary["total_epochs"] = config["epochs"]
# if scheduler is not None:
#     wandb.summary["final_lr"] = scheduler.get_last_lr()[0]
# else:
#     wandb.summary["final_lr"] = config["lr"]
# wandb.summary["early_stopped"] = early_stopping.early_stop
# wandb.summary["epochs_trained"] = best_epoch + 1 if early_stopping.early_stop else config["epochs"]

print(f"\n{'='*60}")
# Report the path the file was actually written to
print(f"Submission saved to: {submission_path}")
print(f"Best model saved to: {best_model_path}")
print(f"Best Val Accuracy: {best:.2f}% (Epoch {best_epoch+1})")
if early_stopping.early_stop:
    print(f"Training stopped early (patience reached)")
print(f"{'='*60}\n")

# Finish WandB run
# 

Using device: cuda
Loading datasets...

[TRAIN SET]
Preprocessing 50000 images (this happens once)...


                                                                

Cached 50000 preprocessed images

[TEST SET]
Preprocessing 10000 images (this happens once)...


                                                               

Cached 10000 preprocessed images

Train set ready: 50000 samples (with runtime augmentation)
Test set ready: 10000 samples (fully cached)

Loading model: resnet18 (pretrained on imagenet)


model.safetensors:   0%|          | 0.00/46.8M [00:00<?, ?B/s]

Total parameters: 11,227,812
Trainable parameters: 11,227,812

Optimizer: AdamW (lr=0.001, weight_decay=0.0005)
Scheduler: CosineAnnealingLR (T_max=100, eta_min=1e-06)

Starting Training - 100 epochs
Model: resnet18 (pretrained on imagenet)
Optimizer: ADAMW, LR: 0.001, Batch Size: 128
Weight Decay: 0.0005, Label Smoothing: 0.15
Scheduler: cosine
Early Stopping: Enabled (patience=15, mode=max)



  0%|          | 0/100 [00:00<?, ?it/s]


Epoch 1/100


Epoch 1/100 | Train: 14.96% | Val: 54.82% | Best: 54.82% | LR: 0.001000:   1%|          | 1/100 [03:51<6:22:23, 231.76s/it]


Epoch 2/100


Epoch 2/100 | Train: 22.15% | Val: 62.69% | Best: 62.69% | LR: 0.000999:   2%|▏         | 2/100 [07:35<6:10:25, 226.79s/it]


Epoch 3/100


Epoch 3/100 | Train: 23.67% | Val: 62.56% | Best: 62.69% | LR: 0.000998:   3%|▎         | 3/100 [11:18<6:03:52, 225.08s/it]


Epoch 4/100


Epoch 4/100 | Train: 24.77% | Val: 62.64% | Best: 62.69% | LR: 0.000996:   4%|▍         | 4/100 [14:58<5:56:53, 223.06s/it]


Epoch 5/100


Epoch 5/100 | Train: 25.17% | Val: 66.29% | Best: 66.29% | LR: 0.000994:   5%|▌         | 5/100 [18:42<5:53:52, 223.49s/it]


Epoch 6/100


Epoch 6/100 | Train: 25.82% | Val: 66.41% | Best: 66.41% | LR: 0.000991:   6%|▌         | 6/100 [22:22<5:48:33, 222.49s/it]


Epoch 7/100


Epoch 7/100 | Train: 26.56% | Val: 67.64% | Best: 67.64% | LR: 0.000988:   7%|▋         | 7/100 [26:06<5:45:21, 222.81s/it]


Epoch 8/100


Epoch 8/100 | Train: 27.35% | Val: 65.80% | Best: 67.64% | LR: 0.000984:   8%|▊         | 8/100 [29:47<5:41:01, 222.41s/it]


Epoch 9/100


Epoch 9/100 | Train: 27.91% | Val: 67.97% | Best: 67.97% | LR: 0.000980:   9%|▉         | 9/100 [33:28<5:36:26, 221.83s/it]


Epoch 10/100


Epoch 10/100 | Train: 25.77% | Val: 67.72% | Best: 67.97% | LR: 0.000976:  10%|█         | 10/100 [37:12<5:33:47, 222.53s/it]


Epoch 11/100


Epoch 11/100 | Train: 28.11% | Val: 66.21% | Best: 67.97% | LR: 0.000970:  11%|█         | 11/100 [40:56<5:30:45, 222.98s/it]


Epoch 12/100


Epoch 12/100 | Train: 31.39% | Val: 66.24% | Best: 67.97% | LR: 0.000965:  12%|█▏        | 12/100 [44:37<5:25:56, 222.23s/it]


Epoch 13/100


Epoch 13/100 | Train: 30.44% | Val: 66.43% | Best: 67.97% | LR: 0.000959:  13%|█▎        | 13/100 [48:18<5:21:42, 221.87s/it]


Epoch 14/100


Epoch 14/100 | Train: 31.07% | Val: 67.54% | Best: 67.97% | LR: 0.000952:  14%|█▍        | 14/100 [51:59<5:17:35, 221.57s/it]


Epoch 15/100


Epoch 15/100 | Train: 31.04% | Val: 66.55% | Best: 67.97% | LR: 0.000946:  15%|█▌        | 15/100 [55:39<5:13:20, 221.18s/it]


Epoch 16/100


Epoch 16/100 | Train: 30.36% | Val: 66.02% | Best: 67.97% | LR: 0.000938:  16%|█▌        | 16/100 [59:20<5:09:48, 221.29s/it]


Epoch 17/100


Epoch 17/100 | Train: 31.01% | Val: 65.86% | Best: 67.97% | LR: 0.000930:  17%|█▋        | 17/100 [1:03:03<5:06:32, 221.59s/it]


Epoch 18/100


Epoch 18/100 | Train: 31.44% | Val: 66.53% | Best: 67.97% | LR: 0.000922:  18%|█▊        | 18/100 [1:06:48<5:04:16, 222.64s/it]


Epoch 19/100


Epoch 19/100 | Train: 30.50% | Val: 66.47% | Best: 67.97% | LR: 0.000914:  19%|█▉        | 19/100 [1:10:28<4:59:43, 222.02s/it]


Epoch 20/100


Epoch 20/100 | Train: 34.09% | Val: 65.08% | Best: 67.97% | LR: 0.000905:  20%|██        | 20/100 [1:14:10<4:56:06, 222.08s/it]


Epoch 21/100


Epoch 21/100 | Train: 34.47% | Val: 66.24% | Best: 67.97% | LR: 0.000895:  21%|██        | 21/100 [1:17:58<4:54:42, 223.83s/it]


Epoch 22/100


Epoch 22/100 | Train: 32.06% | Val: 64.97% | Best: 67.97% | LR: 0.000885:  22%|██▏       | 22/100 [1:21:43<4:51:20, 224.10s/it]


Epoch 23/100


Epoch 23/100 | Train: 33.39% | Val: 64.20% | Best: 67.97% | LR: 0.000875:  23%|██▎       | 23/100 [1:25:26<4:47:09, 223.76s/it]


Epoch 24/100


Epoch 24/100 | Train: 34.42% | Val: 65.48% | Best: 67.97% | LR: 0.000865:  23%|██▎       | 23/100 [1:29:06<4:58:20, 232.47s/it]
  checkpoint = torch.load('./best_model.pth')



Early stopping triggered at epoch 24
Best Val Accuracy: 67.97% at epoch 9
No improvement for 15 epochs


Training Complete!
Best Val Accuracy: 67.97% at epoch 9
Loading best model for inference...

Best model loaded (Epoch 9, Val Acc: 67.97%)

Generating predictions...

Submission saved to: ./submission.csv
Best model saved to: ./best_model.pth
Best Val Accuracy: 67.97% (Epoch 9)
Training stopped early (patience reached)

