In [1]:
import torch
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CIFAR100
from typing import Optional, Callable
import os
import timm
import numpy as np
import pandas as pd
from torchvision.transforms import v2
from torch.backends import cudnn
from torch import GradScaler
from torch import optim
from tqdm import tqdm
import datetime

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_default_device():
    """Return the preferred torch device: CUDA if available, else MPS, else CPU."""
    if torch.cuda.is_available():
        backend = 'cuda'
    elif torch.backends.mps.is_available():
        backend = 'mps'
    else:
        backend = 'cpu'
    return torch.device(backend)

get_default_device()

device(type='cuda')

In [3]:
# Pick the device at runtime (via the helper defined above) instead of
# hard-coding CUDA, so the notebook also starts on machines without an NVIDIA GPU.
device = get_default_device()
cudnn.benchmark = True
# Pinned host memory and mixed precision are only switched on for CUDA;
# half precision is slower on CPU.
pin_memory = device.type == 'cuda'
enable_half = device.type == 'cuda'
# GradScaler takes the device type as a string; with enabled=False it is a no-op.
scaler = GradScaler(device.type, enabled=enable_half)

In [4]:
class SimpleCachedDataset(Dataset):
    """Dataset wrapper with optional in-memory caching and a per-access transform.

    When ``cache`` is true the wrapped dataset is iterated once in the
    constructor and its samples are stored in a tuple; otherwise the dataset is
    kept as given. ``runtime_transforms`` (if not ``None``) is applied to the
    image on every ``__getitem__`` call, while the label is returned unchanged.
    """

    def __init__(self, dataset: Dataset, runtime_transforms: Optional[v2.Transform], cache: bool):
        source = tuple(sample for sample in dataset) if cache else dataset
        self.dataset = source
        self.runtime_transforms = runtime_transforms

    def __len__(self):
        """Number of samples in the (possibly cached) dataset."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return ``(image, label)`` for index ``i``, transforming the image if configured."""
        image, label = self.dataset[i]
        transform = self.runtime_transforms
        if transform is not None:
            image = transform(image)
        return image, label

In [5]:
class CIFAR100_noisy_fine(Dataset):
    """
    See https://github.com/UCSC-REAL/cifar-10-100n, https://www.noisylabels.com/ and `Learning with Noisy Labels
    Revisited: A Study Using Real-World Human Annotations`.

    Wraps ``CIFAR100`` and, for the training split, swaps its labels for the
    ``noisy_label`` array stored in ``<root>/CIFAR-100-noisy.npz``. The whole
    underlying dataset is iterated once in the constructor and the resulting
    samples are kept in memory.
    """

    def __init__(
        self, root: str, train: bool, transform: Optional[Callable], download: bool
    ):
        """
        Args:
            root: Directory passed to ``CIFAR100``; for ``train=True`` it must also
                contain ``CIFAR-100-noisy.npz``.
            train: Select the training split (noisy labels) or the test split.
            transform: Forwarded to ``CIFAR100``.
            download: Forwarded to ``CIFAR100``.

        Raises:
            FileNotFoundError: If ``train`` is true and the noisy-label file is missing.
            RuntimeError: If the stored clean labels differ from the CIFAR100 targets,
                or the noisy labels do not have the same shape as the clean labels.
        """
        cifar100 = CIFAR100(
            root=root, train=train, transform=transform, download=download
        )
        data, targets = tuple(zip(*cifar100))

        if train:
            noisy_label_file = os.path.join(root, "CIFAR-100-noisy.npz")
            if not os.path.isfile(noisy_label_file):
                raise FileNotFoundError(
                    f"{type(self).__name__} need {noisy_label_file} to be used!"
                )

            # np.load on an .npz keeps the archive open until it is closed; read both
            # arrays inside a context manager so the file handle is released.
            with np.load(noisy_label_file) as noise_file:
                clean_labels = noise_file["clean_label"]
                noisy_labels = noise_file["noisy_label"]

            if not np.array_equal(clean_labels, targets):
                raise RuntimeError("Clean labels do not match!")
            # Guard against a label array that cannot be indexed in step with the data.
            if np.shape(noisy_labels) != np.shape(clean_labels):
                raise RuntimeError("Noisy labels do not match the number of samples!")

            targets = noisy_labels

        self.data = data
        self.targets = targets

    def __len__(self):
        """Number of samples."""
        return len(self.targets)

    def __getitem__(self, i: int):
        """Return the ``(sample, label)`` pair at index ``i``."""
        return self.data[i], self.targets[i]

In [6]:
# Statistics handed to v2.Normalize below.
mean = (0.507, 0.4865, 0.4409)
sd = (0.2673, 0.2564, 0.2761)

# Alternative statistics, kept for reference:
# mean = (0.5071, 0.4867, 0.4408)
# sd = (0.2675, 0.2565, 0.2761)

# mean = (0.4914, 0.4822, 0.4465)
# sd = (0.2023, 0.1994, 0.2010)

data_root = './fii-atnn-2024-project-noisy-cifar-100/fii-atnn-2024-project-noisy-cifar-100'

# Conversion applied once, while the datasets are being materialised.
basic_transforms = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])

# The test images are additionally normalised at that point.
test_transforms = v2.Compose([basic_transforms, v2.Normalize(mean, sd, inplace=True)])

# Augmentation + normalisation applied to training images on every access.
runtime_transforms = v2.Compose([
    v2.RandomCrop(size=32, padding=4),
    v2.RandomHorizontalFlip(0.5),
    v2.Normalize(mean, sd, inplace=True),
])

train_set = SimpleCachedDataset(
    CIFAR100_noisy_fine(data_root, train=True, transform=basic_transforms, download=False),
    runtime_transforms,
    True,
)
test_set = SimpleCachedDataset(
    CIFAR100_noisy_fine(data_root, train=False, transform=test_transforms, download=False),
    None,
    True,
)

train_loader = DataLoader(train_set, shuffle=True, batch_size=50, pin_memory=pin_memory)
test_loader = DataLoader(test_set, batch_size=500, pin_memory=pin_memory)


In [7]:
# Epoch index at which the training loop unfreezes the whole network.
UNFREEZE_EPOCH = 20

# CONVNEXT PRETRAINED - 1
# Pretrained ConvNeXt-Base with its head swapped for pooling + flatten + a
# 100-way linear layer.
model = timm.create_model("convnext_base", pretrained=True)
model.head = nn.Sequential(
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(model.num_features, 100),
)

# RESNET PRETRAINED - 2
# model = timm.create_model("resnext50_32x4d", pretrained=True)
# model.fc = nn.Linear(2048, 100)

# Freeze every parameter, then re-enable gradients for the new head only.
model.requires_grad_(False)
model.head.requires_grad_(True)

model = model.to(device)
# model = torch.jit.script(model)  # does not work for this model
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# The optimizer is given all parameters (frozen ones included), so layers that
# are unfrozen later are already registered with it.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)


In [8]:
def train():
    """Run one training epoch over ``train_loader`` and return the accuracy in percent.

    Works on the module-level ``model``, ``criterion``, ``optimizer``, ``scaler``,
    ``device`` and ``enable_half``. The accuracy is computed from the outputs of
    the training forward passes, i.e. before each batch's optimizer step.
    """
    model.train()
    hits, seen = 0, 0

    for batch_inputs, batch_targets in train_loader:
        batch_inputs = batch_inputs.to(device, non_blocking=True)
        batch_targets = batch_targets.to(device, non_blocking=True)

        with torch.autocast(device.type, enabled=enable_half):
            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_targets)

        # Scaled backward pass, optimizer step, scale update, then clear gradients.
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        hits += (logits.argmax(1) == batch_targets).sum().item()
        seen += batch_targets.size(0)

    return 100.0 * hits / seen

In [9]:
warm_up_epochs = 25  # epochs trained on every sample before filtering starts
loss_threshold = 2  # samples whose loss is not below this value are dropped
dynamic_threshold_decay = 0.995  # factor applied to the threshold after each filtering epoch

def train_loss_filtering(epoch):
    """Train for one epoch, dropping high-loss samples once warm-up is over.

    For ``epoch < warm_up_epochs`` every batch is used as is. Afterwards each
    batch is first scored without gradients; only samples whose per-sample
    cross-entropy is below the global ``loss_threshold`` are used for the
    update, and batches with no such sample are skipped. After a filtering
    epoch ``loss_threshold`` is multiplied by ``dynamic_threshold_decay``.

    Works on the module-level ``model``, ``criterion``, ``optimizer``,
    ``scaler``, ``device``, ``enable_half`` and ``train_loader``.

    Args:
        epoch: Zero-based index of the current epoch.

    Returns:
        Accuracy in percent over the samples that were actually trained on,
        or ``0.0`` if every batch was filtered out.
    """
    global loss_threshold
    model.train()
    correct = 0
    total = 0
    filtering = epoch >= warm_up_epochs

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device, non_blocking=True), targets.to(
            device, non_blocking=True
        )

        if filtering:
            # Selection pass: only the per-sample losses are needed, so do not
            # build an autograd graph that would be thrown away immediately.
            with torch.no_grad():
                with torch.autocast(device.type, enabled=enable_half):
                    probe_outputs = model(inputs)
                sample_losses = torch.nn.functional.cross_entropy(
                    probe_outputs, targets, reduction='none'
                )
            mask = sample_losses < loss_threshold
            if not mask.any():
                continue

            inputs = inputs[mask]
            targets = targets[mask]

        # The single gradient-carrying forward pass, on the (possibly filtered) batch.
        with torch.autocast(device.type, enabled=enable_half):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        predicted = outputs.argmax(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    if filtering:
        loss_threshold *= dynamic_threshold_decay

    # Every batch may have been skipped; avoid dividing by zero in that case.
    if total == 0:
        return 0.0
    return 100.0 * correct / total


In [10]:
@torch.inference_mode()
def val():
    """Evaluate the module-level ``model`` on ``test_loader``.

    Returns:
        A ``(accuracy, avg_loss)`` tuple: accuracy in percent over all samples,
        and the mean of the per-batch ``criterion`` values.
    """
    model.eval()
    hits = 0
    seen = 0
    loss_sum = 0.0

    for batch_inputs, batch_targets in test_loader:
        batch_inputs = batch_inputs.to(device, non_blocking=True)
        batch_targets = batch_targets.to(device, non_blocking=True)

        with torch.autocast(device.type, enabled=enable_half):
            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_targets)

        loss_sum += batch_loss.item()
        seen += batch_targets.size(0)
        hits += (logits.argmax(1) == batch_targets).sum().item()

    # Mean over batches, not over samples.
    avg_loss = loss_sum / len(test_loader)
    accuracy = 100.0 * hits / seen
    return accuracy, avg_loss


In [11]:
@torch.inference_mode()
def inference():
    """Predict a class index for every sample in ``test_loader``.

    Returns:
        A flat list of predicted labels (Python ints) in loader order. The
        targets yielded by the loader are ignored.
    """
    model.eval()
    predictions = []

    for batch_inputs, _ in test_loader:
        batch_inputs = batch_inputs.to(device, non_blocking=True)
        with torch.autocast(device.type, enabled=enable_half):
            logits = model(batch_inputs)
        predictions += logits.argmax(1).tolist()

    return predictions

In [12]:
def save_checkpoint(epoch, model, optimizer, best_val_acc, timestamp, checkpoint_name="checkpoint", checkpoint_dir="checkpoints"):
    """Write the training state to ``<checkpoint_dir>/<checkpoint_name>_<timestamp>.pth``.

    The directory is created if needed. The saved dict holds the epoch, the
    model and optimizer state dicts and the best validation accuracy. Calling
    this again with the same name and timestamp overwrites the previous file.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    filename = f"{checkpoint_name}_{timestamp}.pth"
    checkpoint_path = os.path.join(checkpoint_dir, filename)
    torch.save(
        dict(
            epoch=epoch,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            best_val_acc=best_val_acc,
        ),
        checkpoint_path,
    )
    print(f"Checkpoint saved at epoch {epoch} to {checkpoint_path}")


def load_checkpoint(checkpoint_path="checkpoint.pth"):
    """Restore the module-level ``model`` and ``optimizer`` from a checkpoint file.

    Args:
        checkpoint_path: File written by ``torch.save`` with the keys ``epoch``,
            ``model_state_dict``, ``optimizer_state_dict`` and ``best_val_acc``.

    Returns:
        A ``(start_epoch, best_val_acc)`` tuple. ``start_epoch`` is the epoch after
        the saved one. If the file does not exist nothing is loaded and
        ``(0, 0.0)`` is returned, so callers can always unpack two values and
        start training from scratch.
    """
    if not os.path.isfile(checkpoint_path):
        print(f"No checkpoint found at {checkpoint_path}")
        # Same arity as the success path (previously a 3-tuple was returned here).
        return 0, 0.0

    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    best_val_acc = state["best_val_acc"]
    start_epoch = state["epoch"] + 1
    print(f"Checkpoint loaded from {checkpoint_path} (epoch {state['epoch']})")
    return start_epoch, best_val_acc

In [13]:
best = 0.0
epochs = list(range(100))
# Computed once, so every checkpoint of this run is written to the same file name.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

with tqdm(epochs) as tbar:
    for epoch in tbar:
        train_acc = train()
        val_acc, avg_loss = val()

        # Checkpoint only when the validation accuracy improves.
        improved = val_acc > best
        if improved:
            best = val_acc
            save_checkpoint(epoch, model, optimizer, best, timestamp, checkpoint_name="convnext")

        # Unfreezing runs after this epoch has been trained, so the whole
        # network is first updated in the following epoch.
        if epoch == UNFREEZE_EPOCH:
            model.requires_grad_(True)

        status = f"Train: {train_acc:.2f}, Val: {val_acc:.2f}, Best: {best:.2f}, Val loss: {avg_loss:.2f}"
        tbar.set_description(status)

Train: 23.56, Val: 32.89, Best: 32.89, Val loss: 8.51:   1%|          | 1/100 [00:28<47:29, 28.79s/it]

Checkpoint saved at epoch 0 to checkpoints\convnext_20250106_031050.pth


Train: 30.15, Val: 33.59, Best: 33.59, Val loss: 8.92:   5%|▌         | 5/100 [02:09<40:33, 25.62s/it]

Checkpoint saved at epoch 4 to checkpoints\convnext_20250106_031050.pth


Train: 31.41, Val: 34.22, Best: 34.22, Val loss: 8.94:   8%|▊         | 8/100 [03:25<38:49, 25.32s/it]

Checkpoint saved at epoch 7 to checkpoints\convnext_20250106_031050.pth


Train: 24.80, Val: 52.99, Best: 52.99, Val loss: 2.32:  22%|██▏       | 22/100 [10:02<50:38, 38.96s/it]

Checkpoint saved at epoch 21 to checkpoints\convnext_20250106_031050.pth


Train: 47.67, Val: 57.14, Best: 57.14, Val loss: 2.18:  23%|██▎       | 23/100 [11:11<1:01:24, 47.85s/it]

Checkpoint saved at epoch 22 to checkpoints\convnext_20250106_031050.pth


Train: 50.81, Val: 60.20, Best: 60.20, Val loss: 2.07:  24%|██▍       | 24/100 [12:19<1:08:14, 53.87s/it]

Checkpoint saved at epoch 23 to checkpoints\convnext_20250106_031050.pth


Train: 54.09, Val: 61.56, Best: 61.56, Val loss: 2.01:  26%|██▌       | 26/100 [14:29<1:13:27, 59.56s/it]

Checkpoint saved at epoch 25 to checkpoints\convnext_20250106_031050.pth


Train: 56.09, Val: 61.80, Best: 61.80, Val loss: 1.99:  28%|██▊       | 28/100 [16:25<1:10:28, 58.73s/it]

Checkpoint saved at epoch 27 to checkpoints\convnext_20250106_031050.pth


Train: 57.15, Val: 61.91, Best: 61.91, Val loss: 1.99:  29%|██▉       | 29/100 [17:23<1:09:14, 58.51s/it]

Checkpoint saved at epoch 28 to checkpoints\convnext_20250106_031050.pth


Train: 59.01, Val: 62.08, Best: 62.08, Val loss: 1.99:  31%|███       | 31/100 [19:16<1:06:06, 57.48s/it]

Checkpoint saved at epoch 30 to checkpoints\convnext_20250106_031050.pth


Train: 77.12, Val: 58.03, Best: 62.08, Val loss: 2.28:  44%|████▍     | 44/100 [32:16<41:05, 44.02s/it]  


KeyboardInterrupt: 

In [16]:
# Disabled export cell: when uncommented, it runs inference() over the test
# loader and writes each predicted label with its row index to ./submission.csv.
# data = {
#     "ID": [],
#     "target": []
# }


# for i, label in enumerate(inference()):
#     data["ID"].append(i)
#     data["target"].append(label)

# df = pd.DataFrame(data)
# df.to_csv("./submission.csv", index=False)