In [15]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm.auto import tqdm

# Record library versions alongside the results for reproducibility.
for lib in (np, pd, torch):
    print(lib.__name__, lib.__version__)

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

numpy 2.0.2
pandas 2.2.3
torch 2.6.0+cu126
cuda:0


In [16]:
# Per-channel standardization constants. FashionMNIST is grayscale, so there is one channel.
# NOTE(review): presumably precomputed from the FashionMNIST training images — verify.
FASHION_MNIST_MEAN = [0.2856]
FASHION_MNIST_STD = [0.3202]

# The same pipeline is used for every split: convert to a tensor, then standardize.
data_transform = Compose([
    ToTensor(),
    Normalize(FASHION_MNIST_MEAN, FASHION_MNIST_STD),
])

full_train_ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=data_transform,
)

test_ds = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=data_transform,
)

# Hold out 5,000 of the 60,000 training images for validation; the fixed seed keeps the split stable.
train_ds, valid_ds = random_split(
    full_train_ds, [55000, 5000], generator=torch.Generator().manual_seed(42)
)

In [17]:
# Per-channel standardization. FashionMNIST is grayscale, so mean and std each hold a
# single value. NOTE(review): `transforms` is not referenced by any later cell in this
# notebook, so on its own it has no effect on the data the model sees.
transforms = nn.Sequential(
    Normalize(mean=[0.2856], std=[0.3202]),
)

In [18]:
batch_size = 32

# Settings shared by every split; only the training split is shuffled.
loader_kwargs = dict(batch_size=batch_size, num_workers=4)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_ds, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

In [19]:
class CNN(nn.Module):
    """VGG-style CNN classifier for single-channel 28x28 images (FashionMNIST here).

    Per-sample shapes (C, H, W):
        [1, 28, 28]  -conv1, conv2, pool-> [32, 14, 14]
        [32, 14, 14] -conv3, conv4, pool-> [64, 7, 7]
        [64, 7, 7]   -conv5, conv6, pool-> [128, 3, 3]
        flatten -> 1152 -> fc1 -> 128 -> fc2 -> 10 logits

    Every conv is 3x3 with padding 1, so it preserves H and W. Each 2x2 max-pool
    halves H and W (rounding down), which is why fc1 expects 128 * 3 * 3 inputs.

    Args:
        activation: ``"relu"`` (default) or ``"selu"``. It is applied after every
            conv layer and after ``fc1``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    # Supported activation names -> functional implementations.
    _ACTIVATIONS = {"relu": F.relu, "selu": F.selu}

    def __init__(self, activation="relu"):
        super().__init__()
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {activation!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            )
        self.activation = self._ACTIVATIONS[activation]
        self.conv1 = nn.Conv2d(in_channels=1,   out_channels=32,  kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32,  out_channels=32,  kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32,  out_channels=64,  kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=64,  out_channels=64,  kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels=64,  out_channels=128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(128 * 3 * 3, 128)
        self.fc2 = nn.Linear(128, 10)
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform weights and zero biases for every Linear and Conv2d layer."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Map a batch of shape (N, 1, 28, 28) to class logits of shape (N, 10)."""
        act = self.activation
        x = self.pool(act(self.conv2(act(self.conv1(x)))))  # -> (N, 32, 14, 14)
        x = self.pool(act(self.conv4(act(self.conv3(x)))))  # -> (N, 64, 7, 7)
        x = self.pool(act(self.conv6(act(self.conv5(x)))))  # -> (N, 128, 3, 3)
        x = self.flatten(x)                                  # -> (N, 1152)
        return self.fc2(act(self.fc1(x)))                    # -> (N, 10)

In [20]:
class EarlyStopCallback:
    """Signal early stopping when a higher-is-better metric stops improving.

    Call the instance once per new metric value. After ``patience`` consecutive
    calls that fail to beat the best value by more than ``min_delta``,
    ``early_stop`` becomes True.

    Args:
        patience: number of consecutive non-improving observations tolerated.
        min_delta: minimum increase over the best value that counts as an improvement.
    """

    def __init__(self, patience=5, min_delta=0.01):
        self.patience = patience
        self.min_delta = min_delta
        # Start at -inf (not -1) so the first observation always becomes the best,
        # even for metrics that can fall below -1 (e.g. a negated loss).
        self.best_metric = float("-inf")
        self.counter = 0

    def __call__(self, metric):
        """Record one observation of the monitored metric and update the patience counter."""
        if metric > self.best_metric + self.min_delta:
            self.best_metric = metric
            self.counter = 0
        else:
            self.counter += 1

    @property
    def early_stop(self):
        """True once ``patience`` consecutive observations have failed to improve."""
        return self.counter >= self.patience

In [21]:
@torch.no_grad()
def evaluating(model, dataloader, loss_function):
    """Compute the per-sample mean loss and the accuracy of ``model`` over ``dataloader``.

    Runs under ``torch.no_grad()``. It does not change the model's train/eval mode;
    the caller is expected to call ``model.eval()`` beforehand when that matters.
    Batches are moved to the notebook-level ``device``.

    Args:
        model: callable mapping a batch of inputs to logits. Predictions are the argmax
            over the last dimension (assumed to be the class dimension).
        dataloader: iterable of ``(inputs, labels)`` batches.
        loss_function: criterion called as ``loss_function(logits, labels)``. It is
            assumed to return the batch *mean* loss, as ``nn.CrossEntropyLoss()``
            does by default.

    Returns:
        ``(mean_loss, accuracy)`` as floats. Both are averaged over samples rather
        than over batches, so a smaller final batch is not over-weighted.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for datas, labels in dataloader:
        datas = datas.to(device)
        labels = labels.to(device)
        logits = model(datas)
        num = labels.size(0)
        # Scale the batch-mean loss by batch size to recover an exact per-sample mean.
        total_loss += loss_function(logits, labels).item() * num
        total_correct += (logits.argmax(dim=-1) == labels).sum().item()
        total_samples += num

    if total_samples == 0:
        raise ValueError("evaluating() received an empty dataloader")

    return total_loss / total_samples, total_correct / total_samples


In [25]:
def training_base(
        model,
        train_dataloader,
        valid_dataloader,
        loss_function,
        optimizer,
        epoch,
        early_stor_cb = None,
        eval_step = 500
):
    """Train ``model`` for ``epoch`` epochs, evaluating on the validation set periodically.

    Args:
        model: module to train. Batches are moved to the notebook-level ``device``.
        train_dataloader: iterable of ``(inputs, labels)`` training batches.
        valid_dataloader: iterable of ``(inputs, labels)`` batches passed to ``evaluating``.
        loss_function: criterion called as ``loss_function(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: number of passes over ``train_dataloader``.
        early_stor_cb: optional callable exposing an ``early_stop`` property. It is fed
            the validation accuracy after each evaluation; when it signals a stop,
            training returns immediately.
        eval_step: evaluate every ``eval_step`` optimizer steps, starting at step 0.

    Returns:
        dict with ``"train"`` (one record per step) and ``"valid"`` (one record per
        evaluation). Every record has ``"loss"``, ``"acc"`` and ``"step"`` keys.
    """
    history = {"train": [], "valid": []}
    step = 0
    model.train()
    total_steps = epoch * len(train_dataloader)

    with tqdm(total=total_steps) as progress:
        for epoch_idx in range(epoch):
            for inputs, targets in train_dataloader:
                inputs = inputs.to(device)
                targets = targets.to(device)

                # Forward, backward, update.
                optimizer.zero_grad()
                outputs = model(inputs)
                batch_loss = loss_function(outputs, targets)
                batch_loss.backward()
                optimizer.step()

                batch_acc = accuracy_score(
                    targets.cpu().numpy(), outputs.argmax(axis=-1).cpu().numpy()
                )
                history["train"].append(
                    {"loss": batch_loss.cpu().item(), "acc": batch_acc, "step": step}
                )

                if step % eval_step == 0:
                    model.eval()
                    val_loss, val_acc = evaluating(model, valid_dataloader, loss_function)
                    history["valid"].append({"loss": val_loss, "acc": val_acc, "step": step})
                    model.train()

                    if early_stor_cb is not None:
                        early_stor_cb(val_acc)
                        if early_stor_cb.early_stop:
                            print(f"Early stop at epoch : {epoch_idx}, global_step : {step}")
                            return history

                step += 1
                progress.update(1)
                progress.set_postfix({"epoch": epoch_idx})

    return history

In [23]:
def plot_learning_curve(record_dict, sample_step=500, tick_step=5000):
    """Plot training vs. validation curves for every metric in ``record_dict``.

    Args:
        record_dict: dict with ``"train"`` and ``"valid"`` lists of records. Each record
            is a dict with a ``"step"`` key plus one key per metric (e.g. ``"loss"``, ``"acc"``).
        sample_step: plot only every ``sample_step``-th training record. Training
            records are logged per step, so they are numerous and noisy.
        tick_step: spacing, in steps, between x-axis ticks. Labels are shown in thousands.
    """
    train_df = pd.DataFrame(record_dict["train"]).set_index("step").iloc[::sample_step]
    valid_df = pd.DataFrame(record_dict["valid"]).set_index("step")

    # Span ticks over the last step plotted by either curve, inclusive of that step.
    last_step = int(max(train_df.index.max(), valid_df.index.max()))
    ticks = range(0, last_step + 1, tick_step)
    tick_labels = [f"{t / 1000:g}k" for t in ticks]

    fig_num = len(train_df.columns)
    # squeeze=False keeps `axs` 2-D even with a single metric, so indexing is uniform.
    fig, axs = plt.subplots(1, fig_num, figsize=(5 * fig_num, 5), squeeze=False)
    for ax, item in zip(axs[0], train_df.columns):
        ax.plot(train_df.index, train_df[item], label=f"train_{item}")
        ax.plot(valid_df.index, valid_df[item], label=f"valid_{item}")
        ax.grid()
        ax.legend()
        ax.set_xticks(ticks)
        ax.set_xticklabels(tick_labels)
        ax.set_xlabel("step")
        ax.set_title(item)

    plt.show()

In [None]:
# Seed torch's global RNG so weight init and DataLoader shuffling are repeatable
# (some GPU kernels may still be nondeterministic).
torch.manual_seed(42)

# Move the model to the device *before* constructing the optimizer over its parameters.
model = CNN("relu").to(device)
early_stop_cb = EarlyStopCallback()
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
record = training_base(
    model,
    train_loader,
    valid_loader,
    loss_function,
    optimizer,
    epoch=20,
    early_stor_cb=early_stop_cb,
    eval_step=1000,
)
plot_learning_curve(record, sample_step=500)

In [28]:
record["train"][:10]  # first ten per-step training records

[{'loss': 2.3093442916870117, 'acc': 0.0625, 'step': 0},
 {'loss': 2.3058431148529053, 'acc': 0.0625, 'step': 1},
 {'loss': 2.3048465251922607, 'acc': 0.03125, 'step': 2},
 {'loss': 2.3020904064178467, 'acc': 0.125, 'step': 3},
 {'loss': 2.2993860244750977, 'acc': 0.03125, 'step': 4},
 {'loss': 2.3038411140441895, 'acc': 0.125, 'step': 5},
 {'loss': 2.303762435913086, 'acc': 0.15625, 'step': 6},
 {'loss': 2.305441379547119, 'acc': 0.09375, 'step': 7},
 {'loss': 2.3050379753112793, 'acc': 0.09375, 'step': 8},
 {'loss': 2.3088085651397705, 'acc': 0.09375, 'step': 9}]

: 