In [None]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor, Normalize, Compose, Resize, ConvertImageDtype
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score
from pathlib import Path

# Report library versions so the recorded outputs can be tied to an environment.
for module in np, pd, torch:
    print(module.__name__, module.__version__)

# Use the first CUDA device when one is usable, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
# get_device_name() requires a working CUDA runtime, so only query it when the
# selected device is a GPU; calling it unconditionally fails on CPU-only hosts.
if device.type == "cuda":
    print(torch.cuda.get_device_name(device))

numpy 2.0.2
pandas 2.2.3
torch 2.6.0+cu126
cuda:0
NVIDIA GeForce GTX 1050 Ti


In [None]:
# Root folder that contains one sub-directory per dataset split.
DATA_DIR = Path("./archive")

class MonkeyDataset(datasets.ImageFolder):
    """Image-folder dataset for one split stored under ``DATA_DIR``.

    Args:
        mode: Which split to load. ``"train"`` reads ``DATA_DIR / "training"``
            and ``"valid"`` reads ``DATA_DIR / "validation"``.
        transform: Optional callable forwarded to ``ImageFolder`` as its
            ``transform`` argument.

    Raises:
        ValueError: If ``mode`` is not one of the supported split names.
    """

    # Maps the public split name to its directory name under DATA_DIR.
    _SPLIT_DIRS = {"train": "training", "valid": "validation"}

    def __init__(self, mode, transform = None):
        # Fail early with a clear message instead of hitting an unbound `root`.
        if mode not in self._SPLIT_DIRS:
            raise ValueError(
                f"mode must be one of {sorted(self._SPLIT_DIRS)}, got {mode!r}"
            )
        root = DATA_DIR / self._SPLIT_DIRS[mode]

        super().__init__(root, transform)
        # Convenience views over `samples`: the (path, class_index) pairs and
        # the class indices alone.
        # NOTE(review): ImageFolder presumably populates these already — confirm
        # before removing.
        self.imgs = self.samples
        self.targets = [s[1] for s in self.samples]

In [9]:
class EarlyStopCallBack:
    """Patience counter for a metric where larger values are better.

    Call the instance with the latest metric value. A value counts as an
    improvement only when it exceeds the best value seen so far by more than
    ``min_delta``; every non-improving call increments ``counter`` and every
    improving call resets it to zero.

    Attributes are public and read by callers:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: Margin the metric must exceed the best value by.
        counter: Current run of non-improving calls.
        best_metric: Best value recorded so far (starts at -1).
    """

    def __init__(self, patience = 5, min_delta = 0.01):
        self.patience, self.min_delta = patience, min_delta
        self.counter, self.best_metric = 0, -1

    def __call__(self, metric):
        """Record ``metric`` and update the patience counter."""
        improved = metric > self.best_metric + self.min_delta
        if not improved:
            self.counter += 1
            return

        # New best: remember it and start counting from scratch.
        self.best_metric = metric
        self.counter = 0

    @property
    def early_stop(self):
        """True once ``counter`` has reached ``patience``."""
        return self.patience <= self.counter

In [None]:
class ResBlock(nn.Module):
    """Two-convolution residual block computing ``relu(F(x) + shortcut(x))``.

    ``F`` is conv1 -> BN -> ReLU -> conv2 -> BN. Each normalisation site owns
    its own ``BatchNorm2d`` so that learned scale/shift and running statistics
    are not shared between layers that see different activations.

    The shortcut is the identity when the block changes neither the channel
    count nor the spatial size. Otherwise it is a 1x1 convolution followed by
    batch norm, strided by ``stride1 * stride2`` so that it matches the
    combined downsampling of the two main-path convolutions. (With the
    3x3 / padding-1 convolutions used in this notebook the resulting shapes
    line up exactly.)

    Note: parameter names differ from a block that used a single shared
    ``bn`` module, so state dicts saved from such a block will not load.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions.
        kernel_size: Kernel size of both main-path convolutions.
        stride1: Stride of the first convolution (int or pair).
        stride2: Stride of the second convolution (int or pair).
        padding: Padding of both main-path convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride1, stride2, padding):
        super().__init__()
        self.activation = F.relu
        self.conv1 = nn.Conv2d(in_channels,  out_channels, kernel_size, stride1, padding)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride2, padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Total downsampling factor of the main path, per spatial dimension.
        s1, s2 = self._pair(stride1), self._pair(stride2)
        shortcut_stride = (s1[0] * s2[0], s1[1] * s2[1])

        # A projection is needed whenever the identity would not have the same
        # shape as the main-path output: different channels OR different size.
        if in_channels != out_channels or shortcut_stride != (1, 1):
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=shortcut_stride, padding=0),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = None

    @staticmethod
    def _pair(value):
        # Normalise an int-or-pair stride to a (height, width) tuple.
        return (value, value) if isinstance(value, int) else tuple(value)

    def forward(self, x):
        act = self.activation
        fx = act(self.bn1(self.conv1(x)))
        fx = self.bn2(self.conv2(fx))

        shortcut = x if self.downsample is None else self.downsample(x)

        return act(fx + shortcut)

class ResNet(nn.Module):
    """Residual image classifier built from four stages of two ``ResBlock``s.

    Pipeline: 7x7 stem convolution -> batch norm -> ReLU -> 3x3 max-pool ->
    conv2_x ... conv5_x -> global average pool -> flatten -> linear head.
    The stem is normalised and activated like every other convolution in the
    network.

    The shape comments below assume 224x224 inputs; the adaptive average pool
    reduces whatever spatial size reaches it to 1x1 before the linear head.

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)  # (224 + 2*3 - 7)//2 + 1 ==> 112
        self.bn1 = nn.BatchNorm2d(64)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)    # (112 + 2*1 - 3)//2 + 1 ==> 56
        self.conv2_x = nn.Sequential(
            ResBlock(in_channels=64,  out_channels=64,  kernel_size=3, stride1=1, stride2=1, padding=1), # (56 + 2*1 - 3)//1 + 1 ==> 56, (56 + 2*1 - 3)//1 + 1 ==> 56
            ResBlock(in_channels=64,  out_channels=64,  kernel_size=3, stride1=1, stride2=1, padding=1)  # (56 + 2*1 - 3)//1 + 1 ==> 56, (56 + 2*1 - 3)//1 + 1 ==> 56
        )
        self.conv3_x = nn.Sequential(
            ResBlock(in_channels=64,  out_channels=128, kernel_size=3, stride1=2, stride2=1, padding=1), # (56 + 2*1 - 3)//2 + 1 ==> 28, (28 + 2*1 - 3)//1 + 1 ==> 28
            ResBlock(in_channels=128, out_channels=128, kernel_size=3, stride1=1, stride2=1, padding=1)  # (28 + 2*1 - 3)//1 + 1 ==> 28, (28 + 2*1 - 3)//1 + 1 ==> 28
        )
        self.conv4_x = nn.Sequential(
            ResBlock(in_channels=128, out_channels=256, kernel_size=3, stride1=2, stride2=1, padding=1), # (28 + 2*1 - 3)//2 + 1 ==> 14, (14 + 2*1 - 3)//1 + 1 ==> 14
            ResBlock(in_channels=256, out_channels=256, kernel_size=3, stride1=1, stride2=1, padding=1)  # (14 + 2*1 - 3)//1 + 1 ==> 14, (14 + 2*1 - 3)//1 + 1 ==> 14
        )
        self.conv5_x = nn.Sequential(
            ResBlock(in_channels=256, out_channels=512, kernel_size=3, stride1=2, stride2=1, padding=1), # (14 + 2*1 - 3)//2 + 1 ==> 7, (7 + 2*1 - 3)//1 + 1 ==> 7
            ResBlock(in_channels=512, out_channels=512, kernel_size=3, stride1=1, stride2=1, padding=1)  # ( 7 + 2*1 - 3)//1 + 1 ==> 7, (7 + 2*1 - 3)//1 + 1 ==> 7
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(in_features=512, out_features=num_classes)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``."""
        # Stem: convolution, normalisation, non-linearity, then pooling.
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.max_pool(x)

        # Residual stages.
        x = self.conv2_x(x)
        x = self.conv3_x(x)
        x = self.conv4_x(x)
        x = self.conv5_x(x)

        # Head: global pooling -> (batch, 512) -> logits.
        x = self.avg_pool(x)
        x = self.fc(self.flatten(x))

        return x

In [None]:
@torch.no_grad()
def evaluating(model, valid_dataloader, loss_function):
    """Compute the mean loss and accuracy of ``model`` over a dataloader.

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so callers may invoke
    this from inside a training loop without managing the mode themselves.

    The loss is averaged per sample: each batch loss is weighted by the batch
    size, so a smaller final batch does not skew the result. This assumes
    ``loss_function`` returns the mean over the batch (the default reduction
    of ``nn.CrossEntropyLoss``).

    Args:
        model: Module mapping an image batch to class logits.
        valid_dataloader: Iterable yielding ``(imgs, labels)`` batches.
        loss_function: Callable ``(logits, labels) -> scalar tensor``.

    Returns:
        Tuple ``(mean_loss, accuracy)`` of Python floats.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    # Run on whatever device the model lives on; fall back to the notebook-level
    # `device` only for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()

    loss_sum = 0.0
    n_correct = 0
    n_samples = 0
    try:
        for imgs, labels in valid_dataloader:
            imgs = imgs.to(target_device)
            labels = labels.to(target_device)
            logits = model(imgs)
            n_batch = labels.shape[0]

            # Weight the batch-mean loss by the batch size (see docstring).
            loss_sum += loss_function(logits, labels).item() * n_batch
            n_correct += (logits.argmax(dim=-1) == labels).sum().item()
            n_samples += n_batch
    finally:
        # Restore the caller's mode even if evaluation raised.
        model.train(was_training)

    if n_samples == 0:
        raise ValueError("valid_dataloader yielded no samples to evaluate")

    return loss_sum / n_samples, n_correct / n_samples

27.5

In [None]:
def training(
        model,
        train_dataloader,
        valid_dataloader,
        epoch,
        loss_function,
        optimizer,
        early_stop_cb=None,
        eval_step=500
):
    """Train ``model`` and record per-step training and periodic validation metrics.

    Every optimisation step appends ``{"loss", "acc", "step"}`` for the current
    batch to the ``"train"`` record. Whenever ``global_step % eval_step == 0``
    (which includes step 0, i.e. right after the first update) the model is
    evaluated on ``valid_dataloader`` and the result is appended to the
    ``"valid"`` record. If ``early_stop_cb`` is given, it is called with the
    validation accuracy after each evaluation and training returns as soon as
    its ``early_stop`` property is true.

    Args:
        model: Module to optimise; batches are moved to the notebook-level
            ``device`` before the forward pass.
        train_dataloader: Iterable of ``(imgs, labels)`` training batches.
        valid_dataloader: Iterable of ``(imgs, labels)`` validation batches.
        epoch: Number of passes over ``train_dataloader``.
        loss_function: Callable ``(logits, labels) -> scalar tensor``.
        optimizer: Optimizer bound to ``model``'s parameters.
        early_stop_cb: Optional callable taking the validation accuracy and
            exposing a boolean ``early_stop`` attribute.
        eval_step: Interval, in optimisation steps, between validations.

    Returns:
        ``{"train": [...], "valid": [...]}`` with the records described above.
    """
    model.train()
    record_dict = {"train":[], "valid":[]}
    global_step = 0
    with tqdm(total=epoch*len(train_dataloader)) as pbar:
        for epoch_id in range(epoch):
            for imgs, labels in train_dataloader:
                imgs = imgs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                logits = model(imgs)
                loss = loss_function(logits, labels)
                loss.backward()
                optimizer.step()

                # Batch metrics use the logits computed before this step's update.
                train_loss = loss.item()
                pred = logits.argmax(dim=-1)
                acc = accuracy_score(labels.cpu().numpy(), pred.cpu().numpy())

                record_dict["train"].append({"loss":train_loss, "acc":acc, "step":global_step})

                if global_step % eval_step == 0:
                    # Evaluate in eval mode, then return to train mode.
                    model.eval()
                    valid_loss, valid_acc = evaluating(model, valid_dataloader, loss_function)
                    model.train()

                    record_dict["valid"].append({"loss":valid_loss, "acc":valid_acc, "step":global_step})

                    if early_stop_cb is not None:
                        early_stop_cb(valid_acc)
                        if early_stop_cb.early_stop:
                            # The first number is the epoch index, so label it as such.
                            print(f"Early Stop at epoch:{epoch_id} / global_step:{global_step}")
                            return record_dict

                global_step += 1
                pbar.update(1)
                pbar.set_postfix({"epoch_id":epoch_id})

    return record_dict

In [None]:
def plot_learning_curve(record_dict, sample_step=500):
    """Plot training and validation curves, one subplot per recorded metric.

    Args:
        record_dict: Mapping with ``"train"`` and ``"valid"`` lists of records.
            Each record is a dict holding a ``"step"`` key plus one key per
            metric; the metric keys of the training records decide which
            subplots are drawn.
        sample_step: Keep every ``sample_step``-th training record to thin the
            training curve. Validation records are plotted unthinned.
    """
    train_df = pd.DataFrame(record_dict["train"]).set_index("step").iloc[::sample_step]

    # An empty record list yields a frame without a "step" column, which
    # set_index would reject; treat that case as "no validation curve".
    valid_records = record_dict["valid"]
    valid_df = pd.DataFrame(valid_records).set_index("step") if valid_records else None

    fig_num = len(train_df.columns)
    # squeeze=False keeps `axs` two-dimensional, so indexing also works when
    # there is only a single metric.
    fig, axs = plt.subplots(1, fig_num, figsize=(5*fig_num, 5), squeeze=False)
    for ax, item in zip(axs[0], train_df.columns):
        ax.plot(train_df.index, train_df[item], label=f"train_{item}")
        if valid_df is not None and item in valid_df.columns:
            ax.plot(valid_df.index, valid_df[item], label=f"valid_{item}")
        ax.grid()
        ax.legend()
        ax.set_xlabel("step")
        ax.set_ylabel(item)

    plt.show()

In [None]:
# Target spatial size for every image, and the number of images per batch.
img_h = img_w = 224
batch_size = 32

# Per-image preprocessing, applied in list order: resize, convert to a tensor,
# normalise each channel, then cast to float.
# NOTE(review): the mean/std values look like the commonly used ImageNet
# channel statistics — confirm they suit this dataset.
transformer = Compose(
    [
        Resize((img_h, img_w)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ConvertImageDtype(torch.float),
    ]
)

# Both splits share the same preprocessing.
train_dataset, valid_dataset = (
    MonkeyDataset(mode=split, transform=transformer) for split in ("train", "valid")
)

# Only the training loader shuffles; validation keeps the dataset order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)

torch.Size([32, 3, 224, 224]) torch.Size([32])


In [None]:
# Seed PyTorch's global RNG before the model is built so that weight
# initialisation and the shuffled batch order repeat across
# Restart-and-Run-All. This does not by itself guarantee bit-identical results
# on GPU.
SEED = 42
torch.manual_seed(SEED)

model = ResNet().to(device)

loss_function = nn.CrossEntropyLoss()

# Plain SGD with a fixed learning rate.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# Default patience / min_delta; fed with the validation accuracy by training().
early_stop_cb = EarlyStopCallBack()

epoch = 20

# eval_step is left at training()'s default.
record = training(model, train_dataloader, valid_dataloader, epoch, loss_function, optimizer, early_stop_cb)

tensor([[ 0.0517, -0.0888, -0.0508, -0.0006, -0.0503, -0.0805,  0.0388, -0.0194,
          0.0283, -0.0229]], grad_fn=<AddmmBackward0>)

In [None]:
# Show the recorded curves, then score the final weights on the validation set.
plot_learning_curve(record_dict=record)

model.eval()  # put the model in eval mode before the final validation pass
loss, acc = evaluating(
    model=model,
    valid_dataloader=valid_dataloader,
    loss_function=loss_function,
)
print(f"loss: {loss:.4f}, acc: {acc:.4f}")