In [1]:
import torch
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CIFAR100
from typing import Optional, Callable
import os
import timm
import numpy as np
import pandas as pd
from torchvision.transforms import v2
from torch.backends import cudnn
from torch import GradScaler
from torch import optim
from tqdm import tqdm
# import wandb
from datetime import datetime

# ============ Configuration ============
config = {
    "dataset": "cifar100_noisy",
    "model": "resnet18",
    "pretrained": "imagenet",
    "epochs": 100,
    "batch_size": 128,
    "lr": 0.001,
    "momentum": 0.9,
    "weight_decay": 0.01,
    "nesterov": True,
    "label_smoothing": 0.1,
    "optimizer": "adamw",
    "scheduler": "warm_restarts",     
    "warm_restarts_T0": 20,              
    "warm_restarts_mult": 1,             
    "cosine_eta_min": 1e-5,
    "scheduler_step_per_batch": True,
    "early_stop_patience": 15,
    "early_stop_mode": "max",
    "early_stop_min_delta": 0.1,
    "device": "cuda",
    "mixed_precision": True,
    "wandb_project": "cifar100-noisy-competition",
    "upscale_size": 224,  
    "aug_alpha": 0.5,          
    "cutmix_prob": 1.0,        
    "switch_epoch": 25,        
    "warmup_epochs": 5,
    "loss_threshold": 2.5,
    "dynamic_threshold_decay": 0.997
}

device = torch.device(config["device"])
print(f"Using device: {device}")
cudnn.benchmark = True
pin_memory = True
enable_half = config["mixed_precision"]  # Disable for CPU, it is slower!
scaler = GradScaler(device, enabled=enable_half)

class SimpleCachedDataset(Dataset):
    """Materialise every sample of ``dataset`` into an in-memory tuple.

    The wrapped dataset is iterated exactly once at construction time; later
    lookups are plain tuple indexing. No transform is applied at access time.
    """

    def __init__(self, dataset):
        # Eagerly pull all samples so the source dataset is never touched again.
        self.data = tuple(sample for sample in dataset)

    def __len__(self):
        """Return the number of cached samples."""
        return len(self.data)

    def __getitem__(self, i):
        """Return the cached sample at position ``i`` unchanged."""
        return self.data[i]

class PreprocessedDataset(Dataset):
    """
    Apply ``transform`` to every image once and keep the results in memory.

    ``dataset`` must yield ``(image, target)`` pairs and support ``len()``.
    The transformed images end up in ``self.data`` and the untouched targets
    in ``self.targets``, index-aligned. Nothing is recomputed on access, so
    only transforms whose output should be identical every epoch belong here;
    random augmentations have to be applied by a wrapper at access time.
    """
    def __init__(self, dataset, transform):
        print(f"Preprocessing {len(dataset)} images (this happens once)...")
        self.data = []
        self.targets = []

        # Transient progress bar: it is cleared once caching completes.
        progress = tqdm(dataset, desc="Caching", leave=False)
        for image, label in progress:
            self.data.append(transform(image))
            self.targets.append(label)

        print(f"Cached {len(self.data)} preprocessed images")

    def __len__(self):
        """Return the number of cached samples."""
        return len(self.data)

    def __getitem__(self, i):
        """Return the ``(transformed_image, target)`` pair stored at index ``i``."""
        image, label = self.data[i], self.targets[i]
        return image, label

class AugmentationWrapper(Dataset):
    """Dataset view that runs a transform on the image at every access.

    ``preprocessed_dataset`` must return ``(image, target)`` pairs. When
    ``runtime_transforms`` is ``None`` the pairs are passed through untouched.
    """
    def __init__(self, preprocessed_dataset, runtime_transforms):
        self.dataset = preprocessed_dataset
        self.runtime_transforms = runtime_transforms

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Fetch sample ``i`` and transform its image; the target is returned as-is."""
        image, label = self.dataset[i]
        augment = self.runtime_transforms

        # No transform configured: hand back the stored sample unchanged.
        if augment is None:
            return image, label

        return augment(image), label

class CIFAR100_noisy_fine(Dataset):
    """
    CIFAR-100 whose training labels are replaced by noisy labels.

    See https://github.com/UCSC-REAL/cifar-10-100n, https://www.noisylabels.com/ and `Learning with Noisy Labels
    Revisited: A Study Using Real-World Human Annotations`.

    Args:
        root: Directory holding the CIFAR-100 data and, for the training
            split, the ``CIFAR-100-noisy.npz`` label file.
        train: Select the training split (noisy labels) or the test split
            (original labels).
        transform: Accepted for signature compatibility but currently unused;
            samples are returned exactly as the underlying ``CIFAR100``
            dataset yields them with ``transform=None``.
        download: Forwarded to ``CIFAR100``.

    Raises:
        FileNotFoundError: ``train`` is true and the noisy-label file is missing.
        RuntimeError: The ``clean_label`` array in the file does not match
            the labels of the loaded CIFAR-100 training split.
    """

    def __init__(
        self, root: str, train: bool, transform: Optional[Callable], download: bool
    ):
        cifar100 = CIFAR100(
            root=root, train=train, transform=None, download=download
        )
        # Materialise all (image, label) pairs up front as two parallel tuples.
        data, targets = tuple(zip(*cifar100))

        if train:
            noisy_label_file = os.path.join(root, "CIFAR-100-noisy.npz")
            if not os.path.isfile(noisy_label_file):
                raise FileNotFoundError(
                    f"{type(self).__name__} need {noisy_label_file} to be used!"
                )

            # np.load on an .npz keeps the archive open until it is closed;
            # the context manager releases the file handle deterministically.
            with np.load(noisy_label_file) as noise_file:
                # Sanity check: the file must describe the same samples in the same order.
                if not np.array_equal(noise_file["clean_label"], targets):
                    raise RuntimeError("Clean labels do not match!")
                targets = noise_file["noisy_label"]

        self.data = data
        self.targets = targets

    def __len__(self):
        """Return the number of samples."""
        return len(self.targets)

    def __getitem__(self, i: int):
        """Return the ``(image, label)`` pair at index ``i``."""
        return self.data[i], self.targets[i]


class EarlyStopping:
    """Signal when a monitored score has stopped improving.

    Call the instance once per epoch with the latest score. A score counts as
    an improvement only if it beats the best one seen so far by more than
    ``min_delta`` (higher is better for ``mode='max'``, lower is better for
    any other mode). After ``patience`` consecutive non-improving calls the
    instance latches ``early_stop`` to ``True``.
    """
    def __init__(self, patience=10, min_delta=0.0, mode='max'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best_score = None
        self.best_epoch = 0
        self.counter = 0
        self.early_stop = False

    def __call__(self, score, epoch):
        """Record ``score`` for ``epoch`` and return whether training should stop."""
        # The very first score is the baseline and can never trigger a stop.
        if self.best_score is None:
            self.best_score, self.best_epoch = score, epoch
            return False

        if self.mode == 'max':
            # For accuracy (higher is better)
            improved = score > self.best_score + self.min_delta
        else:
            # For loss (lower is better)
            improved = score < self.best_score - self.min_delta

        if improved:
            self.best_score, self.best_epoch = score, epoch
            self.counter = 0
        else:
            self.counter += 1

        # Once set, the flag stays set for all later calls.
        if self.counter >= self.patience:
            self.early_stop = True

        return self.early_stop


# === PREPROCESSING (applied once and cached) ===
# Only deterministic, spatial transforms - stores uint8 tensors (saves memory!)
preprocess_transforms = v2.Compose([
    v2.ToImage(),
    v2.Resize(config["upscale_size"]),  # Upscale the 32x32 images to config["upscale_size"] (224)
])

# === RUNTIME AUGMENTATION (applied at each epoch) ===
# Random transforms for training; normalization happens before RandomErasing.
train_runtime_transforms = v2.Compose([
    v2.RandomCrop(config["upscale_size"], padding=4),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    v2.RandomRotation(15),
    v2.ToDtype(torch.float32, scale=True),  # Convert to float [0,1]
    v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # ImageNet mean/std
    v2.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3))
])

# Test set: only dtype conversion + normalization needed (resize already cached)
test_runtime_transforms = v2.Compose([
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # ImageNet mean/std
])

# Load raw datasets
print("Loading datasets...")
# Both splits live under the same root; keep the path in one place.
DATA_ROOT = '/kaggle/input/fii-atnn-2025-project-noisy-cifar-100/fii-atnn-2024-project-noisy-cifar-100'
train_set_raw = CIFAR100_noisy_fine(DATA_ROOT, download=False, train=True, transform=None)
test_set_raw = CIFAR100_noisy_fine(DATA_ROOT, download=False, train=False, transform=None)

# Cache the raw samples in memory
train_set_cached = SimpleCachedDataset(train_set_raw)
test_set_cached = SimpleCachedDataset(test_set_raw)

# Preprocess and cache as tensors (done once!)
print("\n[TRAIN SET]")
train_set_preprocessed = PreprocessedDataset(train_set_cached, preprocess_transforms)
print("\n[TEST SET]")
test_set_preprocessed = PreprocessedDataset(test_set_cached, preprocess_transforms)

# Add runtime transforms (random augmentation for train, normalization only for test)
train_set = AugmentationWrapper(train_set_preprocessed, train_runtime_transforms)
test_set = AugmentationWrapper(test_set_preprocessed, test_runtime_transforms)

print(f"\nTrain set ready: {len(train_set)} samples (with runtime augmentation)")
print(f"Test set ready: {len(test_set)} samples (fully cached)\n")

# Only the training loader shuffles; the test loader keeps dataset order so
# predictions can be matched to sample indices.
train_loader = DataLoader(train_set, batch_size=config["batch_size"], shuffle=True, pin_memory=pin_memory, num_workers=2, persistent_workers=True)
test_loader = DataLoader(test_set, batch_size=500, pin_memory=pin_memory, num_workers=2, persistent_workers=True)

# Build the configured timm model with ImageNet-pretrained weights and a
# fresh 100-class head.
print(f"Loading model: {config['model']} (pretrained on {config['pretrained']})")
model = timm.create_model(config["model"], pretrained=True, num_classes=100)
model = model.to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}\n")

# Label smoothing helps with noisy labels
criterion = nn.CrossEntropyLoss(label_smoothing=config["label_smoothing"])

# Create optimizer based on config
optimizer_name = config["optimizer"].lower()
if optimizer_name == "adamw":
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config["lr"],
        weight_decay=config["weight_decay"],
        fused=True
    )
    print(f"Optimizer: AdamW (lr={config['lr']}, weight_decay={config['weight_decay']})")
elif optimizer_name == "sgd":
    optimizer = optim.SGD(
        model.parameters(),
        lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        nesterov=config["nesterov"],
        fused=True
    )
    print(f"Optimizer: SGD (lr={config['lr']}, momentum={config['momentum']}, weight_decay={config['weight_decay']}, nesterov={config['nesterov']})")
else:
    raise ValueError(f"Unknown optimizer: {config['optimizer']}. Supported: 'sgd', 'adamw'")

# Learning rate scheduler.
# Optional settings are resolved once so that the value handed to the
# scheduler and the value reported in the log line cannot diverge (and a
# missing optional key cannot raise KeyError in the print).
sched_eta_min = config.get("cosine_eta_min", 1e-6)
if config["scheduler"] == "steplr":
    sched_step_size = config.get("step_size", 30)
    sched_gamma = config.get("gamma", 0.1)
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=sched_step_size,
        gamma=sched_gamma
    )
    print(f"Scheduler: StepLR (step_size={sched_step_size}, gamma={sched_gamma})")

elif config["scheduler"] == "cosine":
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=config["epochs"],
        eta_min=sched_eta_min
    )
    print(f"Scheduler: CosineAnnealingLR (T_max={config['epochs']}, eta_min={sched_eta_min})")

elif config["scheduler"] == "warm_restarts":
    sched_t_mult = config.get("warm_restarts_mult", 1)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=config["warm_restarts_T0"],
        T_mult=sched_t_mult,
        eta_min=sched_eta_min
    )
    # Highlight the per-batch setting so you know it's active
    step_mode = "Per-Batch" if config.get("scheduler_step_per_batch") else "Per-Epoch"
    print(f"Scheduler: WarmRestarts (T_0={config['warm_restarts_T0']}, T_mult={sched_t_mult}, Step: {step_mode})")

else:
    scheduler = None
    print("Scheduler: None")

# === CUTMIX HELPER FUNCTION ===
def rand_bbox(size, lam):
    """Sample a random box for CutMix.

    ``size`` is indexed at positions 2 and 3 to get the two spatial extents.
    The box is centred on a uniformly drawn point, has side lengths of about
    ``sqrt(1 - lam)`` times each extent, and is clipped to the image, so its
    real area can be smaller than ``1 - lam`` of the image.

    Returns ``(bbx1, bby1, bbx2, bby2)``: the ``bbx`` pair bounds
    ``size[2]`` and the ``bby`` pair bounds ``size[3]``.
    """
    extent_x, extent_y = size[2], size[3]

    # Side lengths scale with sqrt(1 - lam); only half of each is needed below.
    side_ratio = np.sqrt(1. - lam)
    half_x = int(extent_x * side_ratio) // 2
    half_y = int(extent_y * side_ratio) // 2

    # Box centre, drawn uniformly (x first, then y).
    centre_x = np.random.randint(extent_x)
    centre_y = np.random.randint(extent_y)

    lower_x = np.clip(centre_x - half_x, 0, extent_x)
    lower_y = np.clip(centre_y - half_y, 0, extent_y)
    upper_x = np.clip(centre_x + half_x, 0, extent_x)
    upper_y = np.clip(centre_y + half_y, 0, extent_y)

    return lower_x, lower_y, upper_x, upper_y

# Current per-sample loss cutoff for small-loss filtering. train() rebinds
# this global, multiplying it by config["dynamic_threshold_decay"] at the end
# of every epoch once the warm-up epochs are over.
loss_threshold = config["loss_threshold"]

def train(epoch):
    """Run one training epoch with small-loss filtering and MixUp / CutMix.

    Per batch:
      1. From epoch ``config["warmup_epochs"]`` on, a gradient-free forward
         pass computes the per-sample cross-entropy and only samples whose
         loss is below the global ``loss_threshold`` are kept. Batches with
         fewer than two surviving samples are skipped entirely.
      2. The remaining samples are mixed with a random permutation of
         themselves: MixUp before ``config["switch_epoch"]``, CutMix from
         that epoch on.
      3. One optimizer step is taken through the global ``scaler``, and the
         scheduler is stepped per batch when configured.

    After the last batch, ``loss_threshold`` is multiplied by
    ``config["dynamic_threshold_decay"]`` (post-warm-up epochs only).

    Args:
        epoch: Zero-based epoch index.

    Returns:
        ``(epoch_loss, epoch_acc, aug_mode)``. The loss is the mixed
        criterion averaged over the samples trained on. The accuracy (in
        percent) compares predictions on the *mixed* inputs against the
        un-permuted targets only, so it under-reports compared with accuracy
        on clean inputs. ``aug_mode`` is ``"MixUp"`` or ``"CutMix"``. If no
        sample was trained on, the loss is NaN and the accuracy 0.0.
    """
    global loss_threshold

    print(f"\nEpoch {epoch+1}/{config['epochs']}")
    model.train()
    correct = 0
    total = 0
    running_loss = 0.0

    # Number of samples seen before filtering, for the keep-rate log line.
    initial_batch_count = 0

    use_cutmix = epoch >= config["switch_epoch"]
    aug_mode = "CutMix" if use_cutmix else "MixUp"
    filtering_active = epoch >= config["warmup_epochs"]
    step_per_batch = config.get("scheduler_step_per_batch") and scheduler is not None
    num_batches = len(train_loader)

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        initial_batch_count += targets.size(0)

        # SMALL-LOSS FILTERING (only after warm-up)
        if filtering_active:
            # NOTE(review): this selection pass runs while the model is in
            # train mode, so any train-mode layer state (e.g. normalization
            # running statistics) is also affected by it - confirm intended.
            with torch.no_grad():
                with torch.autocast(device.type, enabled=enable_half):
                    raw_outputs = model(inputs)
                    sample_losses = torch.nn.functional.cross_entropy(raw_outputs, targets, reduction='none')
                mask = sample_losses < loss_threshold

            # Mixing needs at least two samples; skip the batch otherwise.
            if mask.sum() < 2:
                continue
            inputs = inputs[mask]
            targets = targets[mask]

        # PHASED AUGMENTATION LOGIC
        rand_index = torch.randperm(inputs.size(0)).to(device)
        target_a = targets
        target_b = targets[rand_index]
        lam = np.random.beta(config["aug_alpha"], config["aug_alpha"])

        if use_cutmix:
            bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
            # The box is clipped to the image, so recompute lam from its real area.
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (inputs.size()[-1] * inputs.size()[-2]))
            inputs[:, :, bbx1:bbx2, bby1:bby2] = inputs[rand_index, :, bbx1:bbx2, bby1:bby2]
        # Plain Python float instead of a NumPy scalar for the tensor math below.
        lam = float(lam)
        if not use_cutmix:
            inputs = lam * inputs + (1 - lam) * inputs[rand_index, :]

        with torch.autocast(device.type, enabled=enable_half):
            outputs = model(inputs)
            loss = lam * criterion(outputs, target_a) + (1 - lam) * criterion(outputs, target_b)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # Update per-batch scheduler if enabled
        if step_per_batch:
            # Fractional epoch gives a smooth schedule within the epoch.
            scheduler.step(epoch + batch_idx / num_batches)

        running_loss += loss.item() * inputs.size(0)
        predicted = outputs.argmax(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    # UPDATE THRESHOLD
    if filtering_active:
        loss_threshold *= config["dynamic_threshold_decay"]

    # LOGGING STATS
    # Guard the divisions: the filter can reject every batch (or the loader
    # can be empty), which previously raised ZeroDivisionError here.
    if total > 0:
        epoch_loss = running_loss / total
        epoch_acc = 100.0 * correct / total
    else:
        print("Warning: no samples were trained on this epoch")
        epoch_loss = float("nan")
        epoch_acc = 0.0
    keep_rate = (total / initial_batch_count) * 100 if initial_batch_count > 0 else 0.0

    print(f"Keep Rate: {keep_rate:.2f}% | Threshold: {loss_threshold:.4f}")

    return epoch_loss, epoch_acc, aug_mode

@torch.inference_mode()
def val():
    """Evaluate the global model on ``test_loader``.

    Returns:
        ``(mean_loss, accuracy)``: the criterion averaged over all samples
        and the top-1 accuracy in percent.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in test_loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        with torch.autocast(device.type, enabled=enable_half):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # The criterion returns a batch mean; weight it by the batch size so
        # the final division yields a per-sample average.
        loss_sum += loss.item() * inputs.size(0)
        n_seen += targets.size(0)
        n_correct += (outputs.argmax(1) == targets).sum().item()

    return loss_sum / n_seen, 100.0 * n_correct / n_seen

@torch.inference_mode()
def inference():
    """Predict a class index for every sample in ``test_loader``.

    Returns:
        A flat list of predicted class indices, in loader order.
    """
    model.eval()

    predictions = []
    for inputs, _ in test_loader:
        inputs = inputs.to(device, non_blocking=True)
        with torch.autocast(device.type, enabled=enable_half):
            logits = model(inputs)

        # argmax over the class dimension, converted to plain Python ints
        predictions.extend(logits.argmax(1).tolist())

    return predictions

# Early stopping watches the validation score passed to it each epoch
early_stopping = EarlyStopping(
    mode=config["early_stop_mode"],
    min_delta=config["early_stop_min_delta"],
    patience=config["early_stop_patience"],
)

# Best validation accuracy seen so far and the (0-based) epoch it occurred in
best, best_epoch = 0.0, 0

# Unique run name (would be handed to WandB; that integration is disabled)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_name = f"cifar100noisy_{config['model']}_{config['optimizer']}_lr{config['lr']}_bs{config['batch_size']}_{timestamp}"

# wandb.init(
#     project=config["wandb_project"],
#     name=run_name,
#     config=config
# )

# Banner summarising the run configuration
print(f"\n{'='*70}")
print(f"Starting Training - {config['epochs']} epochs")
print(f"Model: {config['model']} (pretrained on {config['pretrained']})")
print(f"Optimizer: {config['optimizer'].upper()}, LR: {config['lr']}, Batch Size: {config['batch_size']}")
if config["optimizer"].lower() == "sgd":
    # Momentum settings only apply to SGD
    print(f"Momentum: {config['momentum']}, Nesterov: {config['nesterov']}")
print(f"Weight Decay: {config['weight_decay']}, Label Smoothing: {config['label_smoothing']}")
print(f"Scheduler: {config['scheduler']}")
print(f"Early Stopping: Enabled (patience={config['early_stop_patience']}, mode={config['early_stop_mode']})")
print(f"{'='*70}\n")

# Main epoch loop: train, validate, checkpoint on improvement, maybe stop early.
with tqdm(range(config["epochs"])) as tbar:
    for epoch in tbar:
        # Train (steps the scheduler per batch when configured), then validate
        train_loss, train_acc, mode = train(epoch)
        val_loss, val_acc = val()

        # Learning rate for the log line; a per-epoch scheduler is stepped here
        if scheduler is None:
            current_lr = config["lr"]
        else:
            if not config.get("scheduler_step_per_batch"):
                scheduler.step()
            current_lr = scheduler.get_last_lr()[0]

        # Save a checkpoint whenever validation accuracy strictly improves
        if val_acc > best:
            best, best_epoch = val_acc, epoch
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
            }
            if scheduler is not None:
                checkpoint['scheduler_state_dict'] = scheduler.state_dict()
            torch.save(checkpoint, './best_model.pth')

        # Same status line on the progress bar and as a regular log line
        status = f"Epoch {epoch+1}/{config['epochs']} | Train: {train_acc:.2f}% | Val: {val_acc:.2f}% | Best: {best:.2f}% | LR: {current_lr:.6f}"
        tbar.set_description(status)
        print(status)

        # Keep going unless the early-stopping tracker says otherwise
        if not early_stopping(val_acc, epoch):
            continue

        print(f"\n{'='*60}")
        print(f"Early stopping triggered at epoch {epoch+1}")
        print(f"Best Val Accuracy: {best:.2f}% at epoch {best_epoch+1}")
        print(f"{'='*60}\n")
        break
    

print(f"\n{'='*60}")
print("Training Complete!")
print(f"Best Val Accuracy: {best:.2f}% at epoch {best_epoch+1}")
print("Loading best model for inference...")
print(f"{'='*60}\n")

# Load best model for inference.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all the checkpoint written above contains, and avoids relying on
# the version-dependent default of torch.load. map_location places the
# tensors on the device the model lives on.
checkpoint = torch.load('./best_model.pth', map_location=device, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Best model loaded (Epoch {checkpoint['epoch']+1}, Val Acc: {checkpoint['val_acc']:.2f}%)\n")

# Generate submission: one row per test sample, ID = position in test_loader order
print("Generating predictions...")
predictions = inference()
data = {
    "ID": list(range(len(predictions))),
    "target": predictions
}

df = pd.DataFrame(data)
submission_path = "/kaggle/working/submission.csv"
df.to_csv(submission_path, index=False)

# Log final results to WandB
# wandb.summary["final_best_val_acc"] = best
# wandb.summary["best_epoch"] = best_epoch
# wandb.summary["total_epochs"] = config["epochs"]
# if scheduler is not None:
#     wandb.summary["final_lr"] = scheduler.get_last_lr()[0]
# else:
#     wandb.summary["final_lr"] = config["lr"]
# wandb.summary["early_stopped"] = early_stopping.early_stop
# wandb.summary["epochs_trained"] = best_epoch + 1 if early_stopping.early_stop else config["epochs"]

print(f"\n{'='*60}")
# Report the path that was actually written, not a hard-coded relative one.
print(f"Submission saved to: {submission_path}")
print("Best model saved to: ./best_model.pth")
print(f"Best Val Accuracy: {best:.2f}% (Epoch {best_epoch+1})")
if early_stopping.early_stop:
    print("Training stopped early (patience reached)")
print(f"{'='*60}\n")

# Finish WandB run
# 

Using device: cuda
Loading datasets...

[TRAIN SET]
Preprocessing 50000 images (this happens once)...


                                                                

Cached 50000 preprocessed images

[TEST SET]
Preprocessing 10000 images (this happens once)...


                                                               

Cached 10000 preprocessed images

Train set ready: 50000 samples (with runtime augmentation)
Test set ready: 10000 samples (fully cached)

Loading model: resnet18 (pretrained on imagenet)


model.safetensors:   0%|          | 0.00/46.8M [00:00<?, ?B/s]

Total parameters: 11,227,812
Trainable parameters: 11,227,812

Optimizer: AdamW (lr=0.001, weight_decay=0.01)
Scheduler: WarmRestarts (T_0=20, T_mult=1, Step: Per-Batch)

Starting Training - 100 epochs
Model: resnet18 (pretrained on imagenet)
Optimizer: ADAMW, LR: 0.001, Batch Size: 128
Weight Decay: 0.01, Label Smoothing: 0.1
Scheduler: warm_restarts
Early Stopping: Enabled (patience=15, mode=max)



  0%|          | 0/100 [00:00<?, ?it/s]


Epoch 1/100
Keep Rate: 100.00% | Threshold: 2.5000


Epoch 1/100 | Train: 14.56% | Val: 54.59% | Best: 54.59% | LR: 0.000994:   1%|          | 1/100 [03:57<6:31:21, 237.18s/it]

Epoch 1/100 | Train: 14.56% | Val: 54.59% | Best: 54.59% | LR: 0.000994

Epoch 2/100
Keep Rate: 100.00% | Threshold: 2.5000


Epoch 2/100 | Train: 22.49% | Val: 59.06% | Best: 59.06% | LR: 0.000976:   2%|▏         | 2/100 [07:49<6:22:21, 234.09s/it]

Epoch 2/100 | Train: 22.49% | Val: 59.06% | Best: 59.06% | LR: 0.000976

Epoch 3/100
Keep Rate: 100.00% | Threshold: 2.5000


Epoch 3/100 | Train: 24.90% | Val: 62.86% | Best: 62.86% | LR: 0.000946:   3%|▎         | 3/100 [11:47<6:21:32, 236.00s/it]

Epoch 3/100 | Train: 24.90% | Val: 62.86% | Best: 62.86% | LR: 0.000946

Epoch 4/100
Keep Rate: 100.00% | Threshold: 2.5000


Epoch 4/100 | Train: 25.64% | Val: 63.49% | Best: 63.49% | LR: 0.000906:   4%|▍         | 4/100 [15:44<6:18:02, 236.28s/it]

Epoch 4/100 | Train: 25.64% | Val: 63.49% | Best: 63.49% | LR: 0.000906

Epoch 5/100
Keep Rate: 100.00% | Threshold: 2.5000


Epoch 5/100 | Train: 27.32% | Val: 65.13% | Best: 65.13% | LR: 0.000855:   5%|▌         | 5/100 [19:39<6:13:41, 236.01s/it]

Epoch 5/100 | Train: 27.32% | Val: 65.13% | Best: 65.13% | LR: 0.000855

Epoch 6/100
Keep Rate: 66.33% | Threshold: 2.4925


Epoch 6/100 | Train: 38.28% | Val: 66.43% | Best: 66.43% | LR: 0.000796:   6%|▌         | 6/100 [23:47<6:16:18, 240.20s/it]

Epoch 6/100 | Train: 38.28% | Val: 66.43% | Best: 66.43% | LR: 0.000796

Epoch 7/100
Keep Rate: 66.19% | Threshold: 2.4850


Epoch 7/100 | Train: 35.90% | Val: 66.77% | Best: 66.77% | LR: 0.000730:   7%|▋         | 7/100 [27:46<6:11:32, 239.70s/it]

Epoch 7/100 | Train: 35.90% | Val: 66.77% | Best: 66.77% | LR: 0.000730

Epoch 8/100
Keep Rate: 66.60% | Threshold: 2.4776


Epoch 8/100 | Train: 37.77% | Val: 67.52% | Best: 67.52% | LR: 0.000658:   8%|▊         | 8/100 [31:57<6:12:51, 243.16s/it]

Epoch 8/100 | Train: 37.77% | Val: 67.52% | Best: 67.52% | LR: 0.000658

Epoch 9/100
Keep Rate: 67.18% | Threshold: 2.4701


Epoch 9/100 | Train: 38.28% | Val: 68.52% | Best: 68.52% | LR: 0.000583:   9%|▉         | 9/100 [36:30<6:23:18, 252.73s/it]

Epoch 9/100 | Train: 38.28% | Val: 68.52% | Best: 68.52% | LR: 0.000583

Epoch 10/100
Keep Rate: 67.93% | Threshold: 2.4627


Epoch 10/100 | Train: 37.46% | Val: 69.12% | Best: 69.12% | LR: 0.000505:  10%|█         | 10/100 [41:01<6:27:30, 258.34s/it]

Epoch 10/100 | Train: 37.46% | Val: 69.12% | Best: 69.12% | LR: 0.000505

Epoch 11/100
Keep Rate: 68.33% | Threshold: 2.4553


Epoch 11/100 | Train: 38.79% | Val: 69.78% | Best: 69.78% | LR: 0.000428:  11%|█         | 11/100 [45:14<6:20:43, 256.67s/it]

Epoch 11/100 | Train: 38.79% | Val: 69.78% | Best: 69.78% | LR: 0.000428

Epoch 12/100
Keep Rate: 68.63% | Threshold: 2.4480


Epoch 12/100 | Train: 40.08% | Val: 70.30% | Best: 70.30% | LR: 0.000352:  12%|█▏        | 12/100 [49:15<6:09:29, 251.92s/it]

Epoch 12/100 | Train: 40.08% | Val: 70.30% | Best: 70.30% | LR: 0.000352

Epoch 13/100
Keep Rate: 69.04% | Threshold: 2.4406


Epoch 13/100 | Train: 40.62% | Val: 70.06% | Best: 70.30% | LR: 0.000280:  13%|█▎        | 13/100 [53:13<5:59:11, 247.72s/it]

Epoch 13/100 | Train: 40.62% | Val: 70.06% | Best: 70.30% | LR: 0.000280

Epoch 14/100
Keep Rate: 69.49% | Threshold: 2.4333


Epoch 14/100 | Train: 44.43% | Val: 70.54% | Best: 70.54% | LR: 0.000214:  14%|█▍        | 14/100 [57:24<5:56:10, 248.49s/it]

Epoch 14/100 | Train: 44.43% | Val: 70.54% | Best: 70.54% | LR: 0.000214

Epoch 15/100
Keep Rate: 69.64% | Threshold: 2.4260


Epoch 15/100 | Train: 42.77% | Val: 70.33% | Best: 70.54% | LR: 0.000155:  15%|█▌        | 15/100 [1:01:32<5:52:00, 248.48s/it]

Epoch 15/100 | Train: 42.77% | Val: 70.33% | Best: 70.54% | LR: 0.000155

Epoch 16/100
Keep Rate: 69.91% | Threshold: 2.4187


Epoch 16/100 | Train: 40.77% | Val: 70.78% | Best: 70.78% | LR: 0.000105:  16%|█▌        | 16/100 [1:05:43<5:49:01, 249.30s/it]

Epoch 16/100 | Train: 40.77% | Val: 70.78% | Best: 70.78% | LR: 0.000105

Epoch 17/100
Keep Rate: 70.10% | Threshold: 2.4115


Epoch 17/100 | Train: 42.30% | Val: 71.04% | Best: 71.04% | LR: 0.000064:  17%|█▋        | 17/100 [1:09:53<5:44:52, 249.31s/it]

Epoch 17/100 | Train: 42.30% | Val: 71.04% | Best: 71.04% | LR: 0.000064

Epoch 18/100
Keep Rate: 70.36% | Threshold: 2.4042


Epoch 18/100 | Train: 44.69% | Val: 70.89% | Best: 71.04% | LR: 0.000034:  18%|█▊        | 18/100 [1:13:50<5:35:58, 245.84s/it]

Epoch 18/100 | Train: 44.69% | Val: 70.89% | Best: 71.04% | LR: 0.000034

Epoch 19/100
Keep Rate: 70.35% | Threshold: 2.3970


Epoch 19/100 | Train: 45.53% | Val: 71.13% | Best: 71.13% | LR: 0.000016:  19%|█▉        | 19/100 [1:18:05<5:35:27, 248.49s/it]

Epoch 19/100 | Train: 45.53% | Val: 71.13% | Best: 71.13% | LR: 0.000016

Epoch 20/100
Keep Rate: 70.36% | Threshold: 2.3898


Epoch 20/100 | Train: 42.82% | Val: 71.20% | Best: 71.20% | LR: 0.000010:  20%|██        | 20/100 [1:22:13<5:30:54, 248.18s/it]

Epoch 20/100 | Train: 42.82% | Val: 71.20% | Best: 71.20% | LR: 0.000010

Epoch 21/100
Keep Rate: 67.32% | Threshold: 2.3827


Epoch 21/100 | Train: 39.87% | Val: 67.07% | Best: 71.20% | LR: 0.000994:  21%|██        | 21/100 [1:26:13<5:23:53, 246.00s/it]

Epoch 21/100 | Train: 39.87% | Val: 67.07% | Best: 71.20% | LR: 0.000994

Epoch 22/100
Keep Rate: 66.86% | Threshold: 2.3755


Epoch 22/100 | Train: 40.26% | Val: 66.27% | Best: 71.20% | LR: 0.000976:  22%|██▏       | 22/100 [1:30:14<5:17:45, 244.42s/it]

Epoch 22/100 | Train: 40.26% | Val: 66.27% | Best: 71.20% | LR: 0.000976

Epoch 23/100
Keep Rate: 67.23% | Threshold: 2.3684


Epoch 23/100 | Train: 40.85% | Val: 67.77% | Best: 71.20% | LR: 0.000946:  23%|██▎       | 23/100 [1:34:33<5:19:05, 248.64s/it]

Epoch 23/100 | Train: 40.85% | Val: 67.77% | Best: 71.20% | LR: 0.000946

Epoch 24/100
Keep Rate: 67.55% | Threshold: 2.3613


Epoch 24/100 | Train: 43.14% | Val: 68.15% | Best: 71.20% | LR: 0.000906:  24%|██▍       | 24/100 [1:38:44<5:15:56, 249.43s/it]

Epoch 24/100 | Train: 43.14% | Val: 68.15% | Best: 71.20% | LR: 0.000906

Epoch 25/100
Keep Rate: 67.79% | Threshold: 2.3542


Epoch 25/100 | Train: 42.30% | Val: 67.94% | Best: 71.20% | LR: 0.000855:  25%|██▌       | 25/100 [1:42:57<5:13:18, 250.64s/it]

Epoch 25/100 | Train: 42.30% | Val: 67.94% | Best: 71.20% | LR: 0.000855

Epoch 26/100
Keep Rate: 68.32% | Threshold: 2.3471


Epoch 26/100 | Train: 55.59% | Val: 69.64% | Best: 71.20% | LR: 0.000796:  26%|██▌       | 26/100 [1:46:56<5:04:30, 246.90s/it]

Epoch 26/100 | Train: 55.59% | Val: 69.64% | Best: 71.20% | LR: 0.000796

Epoch 27/100
Keep Rate: 68.56% | Threshold: 2.3401


Epoch 27/100 | Train: 58.16% | Val: 70.42% | Best: 71.20% | LR: 0.000730:  27%|██▋       | 27/100 [1:50:58<4:58:45, 245.56s/it]

Epoch 27/100 | Train: 58.16% | Val: 70.42% | Best: 71.20% | LR: 0.000730

Epoch 28/100
Keep Rate: 69.03% | Threshold: 2.3331


Epoch 28/100 | Train: 56.53% | Val: 70.14% | Best: 71.20% | LR: 0.000658:  28%|██▊       | 28/100 [1:55:18<4:59:52, 249.90s/it]

Epoch 28/100 | Train: 56.53% | Val: 70.14% | Best: 71.20% | LR: 0.000658

Epoch 29/100
Keep Rate: 69.23% | Threshold: 2.3261


Epoch 29/100 | Train: 60.10% | Val: 70.89% | Best: 71.20% | LR: 0.000583:  29%|██▉       | 29/100 [1:59:42<5:00:47, 254.19s/it]

Epoch 29/100 | Train: 60.10% | Val: 70.89% | Best: 71.20% | LR: 0.000583

Epoch 30/100
Keep Rate: 69.50% | Threshold: 2.3191


Epoch 30/100 | Train: 60.75% | Val: 70.96% | Best: 71.20% | LR: 0.000505:  30%|███       | 30/100 [2:04:03<4:58:52, 256.18s/it]

Epoch 30/100 | Train: 60.75% | Val: 70.96% | Best: 71.20% | LR: 0.000505

Epoch 31/100
Keep Rate: 69.80% | Threshold: 2.3121


Epoch 31/100 | Train: 58.17% | Val: 71.32% | Best: 71.32% | LR: 0.000428:  31%|███       | 31/100 [2:08:22<4:55:42, 257.13s/it]

Epoch 31/100 | Train: 58.17% | Val: 71.32% | Best: 71.32% | LR: 0.000428

Epoch 32/100
Keep Rate: 70.01% | Threshold: 2.3052


Epoch 32/100 | Train: 60.71% | Val: 71.47% | Best: 71.47% | LR: 0.000352:  32%|███▏      | 32/100 [2:12:33<4:49:03, 255.05s/it]

Epoch 32/100 | Train: 60.71% | Val: 71.47% | Best: 71.47% | LR: 0.000352

Epoch 33/100
Keep Rate: 70.21% | Threshold: 2.2983


Epoch 33/100 | Train: 62.11% | Val: 71.64% | Best: 71.64% | LR: 0.000280:  33%|███▎      | 33/100 [2:16:39<4:41:54, 252.45s/it]

Epoch 33/100 | Train: 62.11% | Val: 71.64% | Best: 71.64% | LR: 0.000280

Epoch 34/100
Keep Rate: 70.48% | Threshold: 2.2914


Epoch 34/100 | Train: 60.01% | Val: 71.39% | Best: 71.64% | LR: 0.000214:  34%|███▍      | 34/100 [2:20:37<4:33:00, 248.19s/it]

Epoch 34/100 | Train: 60.01% | Val: 71.39% | Best: 71.64% | LR: 0.000214

Epoch 35/100
Keep Rate: 70.79% | Threshold: 2.2845


Epoch 35/100 | Train: 60.80% | Val: 71.94% | Best: 71.94% | LR: 0.000155:  35%|███▌      | 35/100 [2:24:42<4:27:43, 247.14s/it]

Epoch 35/100 | Train: 60.80% | Val: 71.94% | Best: 71.94% | LR: 0.000155

Epoch 36/100
Keep Rate: 70.76% | Threshold: 2.2777


Epoch 36/100 | Train: 64.15% | Val: 71.85% | Best: 71.94% | LR: 0.000105:  36%|███▌      | 36/100 [2:28:36<4:19:28, 243.26s/it]

Epoch 36/100 | Train: 64.15% | Val: 71.85% | Best: 71.94% | LR: 0.000105

Epoch 37/100
Keep Rate: 70.96% | Threshold: 2.2708


Epoch 37/100 | Train: 62.67% | Val: 72.07% | Best: 72.07% | LR: 0.000064:  37%|███▋      | 37/100 [2:32:32<4:13:06, 241.06s/it]

Epoch 37/100 | Train: 62.67% | Val: 72.07% | Best: 72.07% | LR: 0.000064

Epoch 38/100
Keep Rate: 71.02% | Threshold: 2.2640


Epoch 38/100 | Train: 64.32% | Val: 71.96% | Best: 72.07% | LR: 0.000034:  38%|███▊      | 38/100 [2:36:28<4:07:38, 239.66s/it]

Epoch 38/100 | Train: 64.32% | Val: 71.96% | Best: 72.07% | LR: 0.000034

Epoch 39/100
Keep Rate: 71.01% | Threshold: 2.2572


Epoch 39/100 | Train: 62.20% | Val: 72.19% | Best: 72.19% | LR: 0.000016:  39%|███▉      | 39/100 [2:40:32<4:04:56, 240.92s/it]

Epoch 39/100 | Train: 62.20% | Val: 72.19% | Best: 72.19% | LR: 0.000016

Epoch 40/100
Keep Rate: 70.95% | Threshold: 2.2505


Epoch 40/100 | Train: 62.39% | Val: 72.31% | Best: 72.31% | LR: 0.000010:  40%|████      | 40/100 [2:44:51<4:06:15, 246.26s/it]

Epoch 40/100 | Train: 62.39% | Val: 72.31% | Best: 72.31% | LR: 0.000010

Epoch 41/100
Keep Rate: 69.03% | Threshold: 2.2437


Epoch 41/100 | Train: 60.67% | Val: 69.35% | Best: 72.31% | LR: 0.000994:  41%|████      | 41/100 [2:49:11<4:06:19, 250.49s/it]

Epoch 41/100 | Train: 60.67% | Val: 69.35% | Best: 72.31% | LR: 0.000994

Epoch 42/100
Keep Rate: 68.83% | Threshold: 2.2370


Epoch 42/100 | Train: 58.77% | Val: 69.04% | Best: 72.31% | LR: 0.000976:  42%|████▏     | 42/100 [2:53:19<4:01:24, 249.74s/it]

Epoch 42/100 | Train: 58.77% | Val: 69.04% | Best: 72.31% | LR: 0.000976

Epoch 43/100
Keep Rate: 68.87% | Threshold: 2.2303


Epoch 43/100 | Train: 59.51% | Val: 68.86% | Best: 72.31% | LR: 0.000946:  43%|████▎     | 43/100 [2:57:21<3:54:54, 247.27s/it]

Epoch 43/100 | Train: 59.51% | Val: 68.86% | Best: 72.31% | LR: 0.000946

Epoch 44/100
Keep Rate: 68.81% | Threshold: 2.2236


Epoch 44/100 | Train: 61.49% | Val: 70.07% | Best: 72.31% | LR: 0.000906:  44%|████▍     | 44/100 [3:01:17<3:47:44, 244.01s/it]

Epoch 44/100 | Train: 61.49% | Val: 70.07% | Best: 72.31% | LR: 0.000906

Epoch 45/100
Keep Rate: 69.04% | Threshold: 2.2169


Epoch 45/100 | Train: 61.64% | Val: 68.96% | Best: 72.31% | LR: 0.000855:  45%|████▌     | 45/100 [3:05:13<3:41:31, 241.66s/it]

Epoch 45/100 | Train: 61.64% | Val: 68.96% | Best: 72.31% | LR: 0.000855

Epoch 46/100
Keep Rate: 69.32% | Threshold: 2.2103


Epoch 46/100 | Train: 61.53% | Val: 70.23% | Best: 72.31% | LR: 0.000796:  46%|████▌     | 46/100 [3:09:22<3:39:19, 243.70s/it]

Epoch 46/100 | Train: 61.53% | Val: 70.23% | Best: 72.31% | LR: 0.000796

Epoch 47/100
Keep Rate: 69.74% | Threshold: 2.2036


Epoch 47/100 | Train: 62.13% | Val: 70.27% | Best: 72.31% | LR: 0.000730:  47%|████▋     | 47/100 [3:13:32<3:36:52, 245.52s/it]

Epoch 47/100 | Train: 62.13% | Val: 70.27% | Best: 72.31% | LR: 0.000730

Epoch 48/100
Keep Rate: 69.96% | Threshold: 2.1970


Epoch 48/100 | Train: 62.01% | Val: 70.41% | Best: 72.31% | LR: 0.000658:  48%|████▊     | 48/100 [3:17:46<3:35:04, 248.16s/it]

Epoch 48/100 | Train: 62.01% | Val: 70.41% | Best: 72.31% | LR: 0.000658

Epoch 49/100
Keep Rate: 70.20% | Threshold: 2.1904


Epoch 49/100 | Train: 58.99% | Val: 71.10% | Best: 72.31% | LR: 0.000583:  49%|████▉     | 49/100 [3:22:04<3:33:28, 251.14s/it]

Epoch 49/100 | Train: 58.99% | Val: 71.10% | Best: 72.31% | LR: 0.000583

Epoch 50/100
Keep Rate: 70.50% | Threshold: 2.1838


Epoch 50/100 | Train: 61.60% | Val: 70.71% | Best: 72.31% | LR: 0.000505:  50%|█████     | 50/100 [3:26:21<3:30:41, 252.84s/it]

Epoch 50/100 | Train: 61.60% | Val: 70.71% | Best: 72.31% | LR: 0.000505

Epoch 51/100
Keep Rate: 70.79% | Threshold: 2.1773


Epoch 51/100 | Train: 63.77% | Val: 71.11% | Best: 72.31% | LR: 0.000428:  51%|█████     | 51/100 [3:30:35<3:26:47, 253.22s/it]

Epoch 51/100 | Train: 63.77% | Val: 71.11% | Best: 72.31% | LR: 0.000428

Epoch 52/100
Keep Rate: 71.06% | Threshold: 2.1708


Epoch 52/100 | Train: 61.30% | Val: 70.92% | Best: 72.31% | LR: 0.000352:  52%|█████▏    | 52/100 [3:34:38<3:20:10, 250.22s/it]

Epoch 52/100 | Train: 61.30% | Val: 70.92% | Best: 72.31% | LR: 0.000352

Epoch 53/100
Keep Rate: 71.16% | Threshold: 2.1643


Epoch 53/100 | Train: 62.66% | Val: 70.68% | Best: 72.31% | LR: 0.000280:  53%|█████▎    | 53/100 [3:38:47<3:15:36, 249.71s/it]

Epoch 53/100 | Train: 62.66% | Val: 70.68% | Best: 72.31% | LR: 0.000280

Epoch 54/100
Keep Rate: 71.44% | Threshold: 2.1578


Epoch 54/100 | Train: 62.45% | Val: 71.38% | Best: 72.31% | LR: 0.000214:  54%|█████▍    | 54/100 [3:43:06<3:13:35, 252.51s/it]

Epoch 54/100 | Train: 62.45% | Val: 71.38% | Best: 72.31% | LR: 0.000214

Epoch 55/100
Keep Rate: 71.58% | Threshold: 2.1513


Epoch 55/100 | Train: 63.31% | Val: 71.35% | Best: 72.31% | LR: 0.000155:  54%|█████▍    | 54/100 [3:47:24<3:13:43, 252.67s/it]
  checkpoint = torch.load('./best_model.pth')


Epoch 55/100 | Train: 63.31% | Val: 71.35% | Best: 72.31% | LR: 0.000155

Early stopping triggered at epoch 55
Best Val Accuracy: 72.31% at epoch 40


Training Complete!
Best Val Accuracy: 72.31% at epoch 40
Loading best model for inference...

Best model loaded (Epoch 40, Val Acc: 72.31%)

Generating predictions...

Submission saved to: ./submission.csv
Best model saved to: ./best_model.pth
Best Val Accuracy: 72.31% (Epoch 40)
Training stopped early (patience reached)

