In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
!jupyter labextension install @jupyter-widgets/jupyterlab-manage
!pip install -q wandb --upgrade

Config option `kernel_spec_manager_class` not recognized by `InstallLabExtensionApp`.
[33m(Deprecated) Installing extensions with the jupyter labextension install command is now deprecated and will be removed in a future major version of JupyterLab.

Users should manage prebuilt extensions with package managers like pip and conda, and extension authors are encouraged to distribute their extensions as prebuilt packages [0m
[33m[W 2024-03-18 08:47:54.717 LabApp][m Config option `kernel_spec_manager_class` not recognized by `LabApp`.


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
import wandb
import numpy as np

In [4]:
class MLPNet(nn.Module):
    def __init__(self):
        super(MLPNet, self).__init__()
        # CIFAR-10 images are 3x32x32, flatten them to 3072-dimensional vectors
        self.fc1 = nn.Linear(3*32*32, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10) # 10 classes in CIFAR-10
    
    def forward(self, x):
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

mlp_net = MLPNet()

In [5]:
class CNNNet(nn.Module):
    def __init__(self):
        super(CNNNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 5, stride=1) 
        self.bn1 = nn.BatchNorm2d(32)  # batch-norm
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 5)  # add filter num
        self.bn2 = nn.BatchNorm2d(64)  # batch-norm
        self.fc1 = nn.Linear(64 * 5 * 5, 120)  
        self.dropout = nn.Dropout(0.5)  # dropout
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    
    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)  # apply dropout
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

cnn_net = CNNNet()

In [6]:
class EarlyStopping:
    """提前中止训练"""
    def __init__(self, patience=7, verbose=False, delta=0):
        """
        parameters:
            patience (int): 损失提高的迭代轮数，超过这个值后训练会停止
            verbose (bool): 如果为True，打印一条信息指示提前中止
            delta (float): 提升的最小变化，小于这个值被认为是没有提升
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float("inf")
        self.delta = delta

    def __call__(self, val_loss):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0

In [7]:
def load_data(batch_size=16):
    # 仅用于训练集的数据增强变换
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # 随机水平翻转图像
        transforms.RandomRotation(10),      # 随机旋转图像±10度
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # 验证集和测试集的变换，没有包含随机性操作
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    full_trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                 download=True, transform=train_transform)

    num_train = len(full_trainset)
    indices = list(range(num_train))
    split = int(np.floor(0.1 * num_train))  # 10%的数据用作验证集

    np.random.shuffle(indices)

    train_idx, val_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)

    trainloader = torch.utils.data.DataLoader(full_trainset, batch_size=batch_size,
                                              sampler=train_sampler, num_workers=2)
    # 使用 full_trainset 但用 test_transform 进行数据加载
    valloader = torch.utils.data.DataLoader(torchvision.datasets.CIFAR10(root='./data', train=True,
                                                  download=True, transform=test_transform), batch_size=batch_size,
                                             sampler=val_sampler, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=test_transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2)
    return trainloader, valloader, testloader

In [8]:
def train_and_test(net, trainloader, valloader, testloader, device, epochs=50):
    early_stopping = EarlyStopping(patience=15, verbose=True)
    
    optimizer = optim.AdamW(net.parameters()) # use AdamW instead of SGD
    criterion = nn.CrossEntropyLoss()

    net.to(device)

    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        total_train = 0
        correct_train = 0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        train_loss = running_loss / len(trainloader)
        train_accuracy = correct_train / total_train

        # val
        net.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for inputs, labels in valloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        val_loss /= len(valloader)
        val_accuracy = correct_val / total_val

        print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Train Acc: {train_accuracy}, Val Loss: {val_loss}, Val Acc: {val_accuracy}")
        
        wandb.log({'train_loss': train_loss, 
                   'train_accuracy': train_accuracy, 
                   'val_loss': val_loss, 
                   'val_accuracy': val_accuracy})

        # early stop method
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping")
            break
            
    print('Finished Training')

    
    PATH = './cifar_net.pth'
    torch.save(net.state_dict(), PATH)

    # test
    net.eval()  
    correct = 0
    total = 0
    with torch.no_grad(): 
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            # 通过网络运行图像计算输出
            outputs = net(images)
            # 选择能量最高的类作为预测结果
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the network on the 10000 test images: {accuracy:.2f} %')

    wandb.log({"test_accuracy": accuracy})

In [9]:
!nvidia-smi

Mon Mar 18 08:48:22 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0              26W / 250W |      0MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [10]:
def main(model_type):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    trainloader, valloader, testloader = load_data()

    if model_type == 'mlp':
        net = MLPNet()
    elif model_type == 'cnn':
        net = CNNNet()
    else:
        raise ValueError("Unsupported model type. Choose 'mlp' or 'cnn'.")
    wandb.init(project="cifar10_classification")
    train_and_test(net, trainloader, valloader, testloader, device)
    wandb.finish()

In [11]:
main('mlp')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 47758953.32it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Epoch 1, Train Loss: 1.721216379896924, Train Acc: 0.38506666666666667, Val Loss: 1.6231356693532901, Val Acc: 0.4216
Epoch 2, Train Loss: 1.562007608381489, Train Acc: 0.44382222222222223, Val Loss: 1.527741766775759, Val Acc: 0.4608
Epoch 3, Train Loss: 1.4935887031312773, Train Acc: 0.4705111111111111, Val Loss: 1.5109439139929823, Val Acc: 0.4728
Epoch 4, Train Loss: 1.4522265617083792, Train Acc: 0.4835111111111111, Val Loss: 1.5505274616110438, Val Acc: 0.467
EarlyStopping counter: 1 out of 15
Epoch 5, Train Loss: 1.420023116741662, Train Acc: 0.49491111111111113, Val Loss: 1.5002552335635542, Val Acc: 0.4782
Epoch 6, Train Loss: 1.3910513434001572, Train Acc: 0.5092, Val Loss: 1.4816672790545624, Val Acc: 0.4868
Epoch 7, Train Loss: 1.3672942841048772, Train Acc: 0.5143111111111112, Val Loss: 1.4701765684274057, Val Acc: 0.4928
Epoch 8, Train Loss: 1.3508400804115877, Train Acc: 0.5221555555555556, Val Loss: 1.5267417034782922, Val Acc: 0.4728
EarlyStopping counter: 1 out of 15


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_accuracy,▁
train_accuracy,▁▃▄▄▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇█████████████
train_loss,█▆▅▅▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▄▄▅▅▆▆▇▆▆▇▇▆▇▇▇▇▇▇█▆▇█▇▇███▇▇▇▇███████
val_loss,█▅▅▆▄▄▃▂▂▄▂▁▂▃▂▄▄▂▂▂▁▃▂▂▂▃▁▁▃▁▁▃▄▂▃▂▃▃▂▂

0,1
test_accuracy,53.08
train_accuracy,0.61331
train_loss,1.10211
val_accuracy,0.5356
val_loss,1.46185


In [12]:
main('cnn')

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


[34m[1mwandb[0m: Currently logged in as: [33mnagi-ovo[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1, Train Loss: 1.656180173974422, Train Acc: 0.38915555555555553, Val Loss: 1.3559305665972896, Val Acc: 0.5048
Epoch 2, Train Loss: 1.4081231361388142, Train Acc: 0.4911111111111111, Val Loss: 1.2388295765502004, Val Acc: 0.576
Epoch 3, Train Loss: 1.2968489838968678, Train Acc: 0.5330888888888888, Val Loss: 1.1514485920199191, Val Acc: 0.624
Epoch 4, Train Loss: 1.2209865173707006, Train Acc: 0.5658, Val Loss: 1.1388978937182563, Val Acc: 0.631
Epoch 5, Train Loss: 1.158221830994982, Train Acc: 0.5886, Val Loss: 1.0765346467685395, Val Acc: 0.6498
Epoch 6, Train Loss: 1.1209640248767685, Train Acc: 0.6016222222222222, Val Loss: 1.01349548352793, Val Acc: 0.6756
Epoch 7, Train Loss: 1.0848404036969255, Train Acc: 0.6184444444444445, Val Loss: 1.1124926019019592, Val Acc: 0.6284
EarlyStopping counter: 1 out of 15
Epoch 8, Train Loss: 1.0515963946145666, Train Acc: 0.6303777777777778, Val Loss: 1.0045205741263807, Val Acc: 0.682
Epoch 9, Train Loss: 1.0269674665767698, Train Acc: 

VBox(children=(Label(value='0.054 MB of 0.054 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_accuracy,▁
train_accuracy,▁▃▄▄▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█▇██████████████████
train_loss,█▆▅▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▄▄▆▄▆▅▆▇▇▆▇▇▇▇▇▇▇▇▇▇█▇▇███▇▇█▇▇█▇███▇█
val_loss,█▇▆▅▄▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▂▂▂▂▁▂▂▂▂▁▂▂▁▁▁▂▁▂▂

0,1
test_accuracy,75.49
train_accuracy,0.74573
train_loss,0.74025
val_accuracy,0.7554
val_loss,0.80641
