<a href="https://www.kaggle.com/code/vitdnghong/cifar-10-resnet?scriptVersionId=226042228" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import os
import wandb

In [None]:
class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1).

    The final 1x1 convolution produces ``out_channels * expansion`` channels.
    ``stride`` is applied by the middle 3x3 convolution. If ``i_downsample``
    is given, it is applied to the shortcut branch before the residual
    addition; otherwise the input is added unchanged.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super().__init__()
        widened = out_channels * self.expansion

        # 1x1 convolution: in_channels -> out_channels.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)

        # 3x3 convolution: the only layer in the block that applies `stride`.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        # 1x1 convolution: out_channels -> out_channels * expansion.
        self.conv3 = nn.Conv2d(out_channels, widened, kernel_size=1, stride=1, padding=0)
        self.batch_norm3 = nn.BatchNorm2d(widened)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        shortcut = x.clone()

        out = self.conv1(x)
        out = self.batch_norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.batch_norm2(out)
        out = self.relu(out)

        out = self.batch_norm3(self.conv3(out))

        # Project the shortcut when a downsample module was supplied.
        if self.i_downsample is not None:
            shortcut = self.i_downsample(shortcut)

        out += shortcut
        return self.relu(out)

In [None]:
class Block(nn.Module):
    """Basic two-convolution residual block (3x3 -> 3x3).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions
            (``expansion`` is 1, so this is also the block's output width).
        i_downsample: Optional module applied to the shortcut branch so that
            it matches the main branch before the residual addition.
        stride: Stride of the first convolution. The second convolution
            always uses stride 1 so the block downsamples exactly once,
            matching a shortcut that is strided by the same amount.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)
        # Stride 1 here: striding both convolutions would shrink the main
        # branch twice and make it incompatible with the shortcut.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        identity = x

        # Each convolution is followed by its own batch-norm layer.
        out = self.relu(self.batch_norm1(self.conv1(x)))
        out = self.batch_norm2(self.conv2(out))

        if self.i_downsample is not None:
            identity = self.i_downsample(identity)

        out = out + identity
        return self.relu(out)

In [None]:
class ResNet(nn.Module):
    """ResNet backbone with a linear classification head.

    Layout: a 7x7 stride-2 convolution with batch norm and ReLU, a 3x3
    stride-2 max-pool, four residual stages with 64/128/256/512 planes (the
    last three built with stride 2), adaptive average pooling to 1x1 and a
    fully connected layer with ``num_classes`` outputs.

    Args:
        ResBlock: Block class used to build the stages. It must expose an
            ``expansion`` attribute and accept
            ``(in_channels, out_channels, i_downsample=..., stride=...)``.
        layer_list: Number of blocks in each of the four stages.
        num_classes: Size of the output of the final linear layer.
        num_channels: Number of channels of the input images.
    """

    def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
        super().__init__()
        # Channel count entering the next stage; advanced by _make_layer.
        self.in_channels = 64

        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (planes, stride) for layer1..layer4, built in order.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(ResBlock, layer_list[idx], planes=planes, stride=stride)
            setattr(self, f"layer{idx + 1}", stage)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * ResBlock.expansion, num_classes)

    def forward(self, x):
        """Run the stem, the four stages, pooling and the classifier."""
        out = self.conv1(x)
        out = self.batch_norm1(out)
        out = self.relu(out)
        out = self.max_pool(out)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        out = self.avgpool(out)
        # Collapse everything but the batch dimension before the linear layer.
        out = out.reshape(out.shape[0], -1)
        return self.fc(out)

    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        Only the first block receives ``stride`` and, when the stride is not 1
        or the channel count changes, a 1x1-conv + batch-norm shortcut
        projection. ``self.in_channels`` is updated to the stage's output
        width.
        """
        out_width = planes * ResBlock.expansion

        shortcut = None
        if not (stride == 1 and self.in_channels == out_width):
            shortcut = nn.Sequential(
                nn.Conv2d(self.in_channels, out_width, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_width),
            )

        stage = [ResBlock(self.in_channels, planes, i_downsample=shortcut, stride=stride)]
        self.in_channels = out_width

        stage.extend(ResBlock(self.in_channels, planes) for _ in range(blocks - 1))
        return nn.Sequential(*stage)

In [None]:
def ResNet50(num_classes, channels=3):
    """Build a ResNet of Bottleneck blocks with stage depths [3, 4, 6, 3].

    Args:
        num_classes: Number of outputs of the classification layer.
        channels: Number of input image channels.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(
        ResBlock=Bottleneck,
        layer_list=stage_depths,
        num_classes=num_classes,
        num_channels=channels,
    )

In [None]:
# Training pipeline: random horizontal flip and a random 32x32 crop taken after
# 4-pixel padding, followed by tensor conversion and per-channel normalization
# with mean 0.5 and std 0.5.
transform_train = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(size=32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Evaluation pipeline: no augmentation, same tensor conversion and normalization.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
# Build the CIFAR-10 training and test datasets (downloading them if needed),
# each paired with its own transform pipeline. The training set is created
# first, then the test set.
traindata, testset = (
    torchvision.datasets.CIFAR10(
        root='./kaggle/working/',
        train=is_train,
        download=True,
        transform=pipeline,
    )
    for is_train, pipeline in ((True, transform_train), (False, transform_test))
)

In [None]:
# Hold out 20% of the official training split for validation.
trainset_size = int(len(traindata) * 0.8)
validset_size = len(traindata) - trainset_size

# Seeded generator so the split is the same on every run.
seed = torch.Generator().manual_seed(42)
trainset, validset = data.random_split(traindata, [trainset_size, validset_size], generator=seed)

# `random_split` returns index views over `traindata`, so validation samples
# would otherwise pass through the random flip/crop of `transform_train` and the
# validation loss (which drives the scheduler, checkpointing and early stopping)
# would be noisy. Re-point the same validation indices at a second view of the
# training images that uses the deterministic evaluation transform.
validset = data.Subset(
    torchvision.datasets.CIFAR10(
        root=traindata.root,
        train=True,
        download=False,  # files are already on disk from building `traindata`
        transform=transform_test,
    ),
    validset.indices,
)

In [None]:
batch_size = 64

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
val_loader = DataLoader(
    validset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)
test_loader = DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Display names for the ten labels, indexed by integer class id.
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Take the first batch from the training loader for a visual sanity check.
dataiter = iter(train_loader)
images, labels = next(dataiter)

# Map pixel values back for display: x / 2 + 0.5.
images = images / 2 + 0.5

row, col = 4, 4
fig, axes = plt.subplots(row, col, figsize=(9, 9))

# Fill the grid in row-major order with the first row * col samples.
for pos, ax in enumerate(axes.flat):
    # Move channels last (C, H, W) -> (H, W, C) for imshow.
    picture = images[pos].permute(1, 2, 0).numpy()
    ax.imshow(picture)
    ax.set(xticks=[], yticks=[])
    ax.set_title(classes[labels[pos].item()])

In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Bare expression so the notebook displays the selected device.
device

In [None]:
# 10-class ResNet-50, placed on the selected device.
net = ResNet50(num_classes=10)
net = net.to(device)

criterion = nn.CrossEntropyLoss()

# SGD with momentum and L2 weight decay.
optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=0.0001,
)

# Multiplies the learning rate by 0.1 when the value passed to
# `scheduler.step(...)` stops improving (patience of 5).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.1,
    patience=5,
)

In [None]:
from kaggle_secrets import UserSecretsClient

# Client used to read secrets; presumably those attached to this Kaggle
# notebook — confirm against the kaggle_secrets documentation.
user_secrets = UserSecretsClient()

# Look up the secret named "wandb_api_key" so the key is never hard-coded here.
my_secret = user_secrets.get_secret("wandb_api_key") 

# Log in to Weights & Biases with that key.
wandb.login(key=my_secret)

In [None]:
# Start a Weights & Biases run in the "ResNet_CIFAR-10" project.
wandb.init(project="ResNet_CIFAR-10")

In [None]:
def save_checkpoint(model, optimizer, scheduler, epoch, loss, path="best_model.pth"):
    """Save a training checkpoint under ``/kaggle/working/``.

    The checkpoint is a dict holding the epoch, the state dicts of the model,
    optimizer and scheduler, and the loss value.

    Args:
        model: Object whose ``state_dict()`` is stored as ``model_state_dict``.
        optimizer: Object whose ``state_dict()`` is stored as
            ``optimizer_state_dict``.
        scheduler: Object whose ``state_dict()`` is stored as
            ``scheduler_state_dict``.
        epoch: Epoch value stored as-is.
        loss: Loss value stored as-is.
        path: File name appended to ``/kaggle/working/``.
    """
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        loss=loss,
    )
    destination = f"/kaggle/working/{path}"
    torch.save(checkpoint, destination)

In [None]:
class EarlyStopping:
    """Flag when a monitored loss has stopped improving.

    ``check`` compares each new loss with the best one seen so far. A strictly
    lower loss resets the counter; anything else increments it. Once the
    counter reaches ``patience``, ``early_stop`` is set to True (and is never
    reset by this class).
    """

    def __init__(self, patience=20):
        self.patience = patience
        self.counter = 0
        self.best_loss = float("inf")
        self.early_stop = False

    def check(self, val_loss):
        """Record ``val_loss`` and update ``counter`` / ``early_stop``."""
        improved = val_loss < self.best_loss
        if improved:
            self.best_loss, self.counter = val_loss, 0
            return

        self.counter += 1
        if self.counter < self.patience:
            return

        print("⏹️ Early Stopping Triggered!")
        self.early_stop = True

In [None]:
EPOCHS = 200

# Record the hyper-parameters on the active W&B run. `update` writes into the
# run's config object; assigning a new dict to `wandb.config` would only rebind
# the module attribute and leave the run's config untouched.
wandb.config.update({
    "epochs": EPOCHS,
    "batch_size": batch_size,
    "learning_rate": optimizer.param_groups[0]['lr']
})

# Stop after 20 consecutive epochs without a new best validation loss.
early_stopper = EarlyStopping(patience = 20)
# Best validation loss seen so far; used to decide when to checkpoint.
best_val_loss = float('inf')

In [None]:
# Main training loop: one training pass and one validation pass per epoch,
# followed by logging, LR scheduling, checkpointing and the early-stopping check.
for epoch in range(EPOCHS):
    net.train()  # switch to training mode
    total_train_loss = 0
    correct_train = 0
    total_train = 0

    # Training pass
    for inputs, labels in train_loader:
        # Move the batch to the device selected earlier rather than a
        # hard-coded 'cuda', so the loop also runs on CPU-only machines.
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Accumulate loss and accuracy
        total_train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)  # index of the highest logit = predicted label
        correct_train += (predicted == labels).sum().item()
        total_train += labels.size(0)  # number of samples in this batch

    # Mean of the per-batch losses, and accuracy over all samples seen.
    avg_train_loss = total_train_loss / len(train_loader)
    train_accuracy = correct_train / total_train * 100  # accuracy in %

    # Validation pass
    net.eval()
    total_val_loss = 0
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)

            # Accumulate loss and accuracy
            total_val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_val += (predicted == labels).sum().item()
            total_val += labels.size(0)

    avg_val_loss = total_val_loss / len(val_loader)
    val_accuracy = correct_val / total_val * 100  # accuracy in %

    # Print loss and accuracy for this epoch
    print(f"📉 Epoch [{epoch+1}/{EPOCHS}] - Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}% | "
          f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")

    # Log the epoch metrics to wandb
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": avg_train_loss,
        "train_accuracy": train_accuracy,
        "val_loss": avg_val_loss,
        "val_accuracy": val_accuracy
    })

    # Let the scheduler react to the validation loss
    scheduler.step(avg_val_loss)

    # Save a checkpoint whenever the validation loss reaches a new best
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        save_checkpoint(net, optimizer, scheduler, epoch, avg_val_loss)

    # Early-stopping check
    early_stopper.check(avg_val_loss)
    if early_stopper.early_stop:
        print("🚀 Training stopped early!")
        break  # stop training when validation loss no longer improves

In [None]:
# Accuracy of the current `net` weights on the held-out test set.
correct = 0
total = 0

# Set eval mode explicitly so the result does not depend on which cell ran last.
net.eval()
with torch.no_grad():
    # Unpack the batch directly: naming the loop variable `data` would overwrite
    # the `torch.utils.data as data` alias used by the split cell.
    for images, labels in test_loader:
        # Use the selected device instead of a hard-coded 'cuda'.
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)

        # Under no_grad the outputs carry no graph, so `.data` is unnecessary.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Test Accuracy: ', 100*(correct/total), '%')

In [None]:
# Mark the Weights & Biases run as finished.
wandb.finish()