In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory.
for folder, _, names in os.walk('/kaggle/input'):
    full_paths = [os.path.join(folder, name) for name in names]
    for path in full_paths:
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
!jupyter labextension install @jupyter-widgets/jupyterlab-manage
!pip install -q wandb --upgrade

Config option `kernel_spec_manager_class` not recognized by `InstallLabExtensionApp`.
[33m(Deprecated) Installing extensions with the jupyter labextension install command is now deprecated and will be removed in a future major version of JupyterLab.

Users should manage prebuilt extensions with package managers like pip and conda, and extension authors are encouraged to distribute their extensions as prebuilt packages [0m
[33m[W 2024-03-20 15:30:54.039 LabApp][m Config option `kernel_spec_manager_class` not recognized by `LabApp`.


In [3]:
# Do not hard-code the W&B API key in the notebook. Use the value already
# present in the environment and only prompt (without echo) when it is missing.
if not os.environ.get("WANDB_API_KEY"):
    from getpass import getpass

    os.environ["WANDB_API_KEY"] = getpass("W&B API key: ")

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
import wandb

In [5]:
class MLPNet(nn.Module):
    """Fully connected classifier: 3072 -> 512 -> 256 -> 10 with ReLU activations."""

    def __init__(self):
        super().__init__()
        # A 3x32x32 CIFAR-10 image flattened into a single feature vector.
        in_features = 3 * 32 * 32
        self.fc1 = nn.Linear(in_features, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)  # one logit per CIFAR-10 class

    def forward(self, x):
        """Flatten every dimension except the batch one and return raw class logits."""
        hidden = x.flatten(start_dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        # No activation on the last layer: the outputs are unnormalised logits.
        return self.fc3(hidden)


mlp_net = MLPNet()

In [6]:
class CNNNet(nn.Module):
    """Two conv/batch-norm/pool stages followed by a three-layer classifier head."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(64)
        # fc1 expects 64 channels of 5x5 features, which is what the two
        # conv(5)+pool(2) stages produce for a 32x32 input.
        self.fc1 = nn.Linear(64 * 5 * 5, 120)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw class logits for a batch of images."""
        # Each stage: convolution -> batch norm -> ReLU -> 2x2 max pooling.
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.pool(F.relu(norm(conv(x))))
        features = x.flatten(start_dim=1)
        # Dropout is applied once, right after the first dense layer.
        features = self.dropout(F.relu(self.fc1(features)))
        features = F.relu(self.fc2(features))
        return self.fc3(features)


cnn_net = CNNNet()

In [7]:
class CNNNet2(nn.Module):
    """Six-convolution network with GELU + batch norm and a batch-normalised dense head."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(128)
        self.dropout = nn.Dropout(p=0.2)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3)
        self.bn4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3)
        self.bn5 = nn.BatchNorm2d(512)
        self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3)
        self.bn6 = nn.BatchNorm2d(512)
        # fc1 expects 512 channels of 3x3 features after the final pooling step.
        self.fc1 = nn.Linear(512 * 3 * 3, 4096)
        self.bn_fc1 = nn.BatchNorm1d(4096)
        self.fc2 = nn.Linear(4096, 256)
        self.bn_fc2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        """Return raw class logits for a batch of images."""

        def conv_block(conv, norm, tensor):
            # Note the ordering: activation first, batch norm afterwards.
            return norm(F.gelu(conv(tensor)))

        x = conv_block(self.conv1, self.bn1, x)
        x = self.dropout(conv_block(self.conv2, self.bn2, x))
        x = self.pool(conv_block(self.conv3, self.bn3, x))
        x = self.dropout(conv_block(self.conv4, self.bn4, x))
        x = conv_block(self.conv5, self.bn5, x)
        x = self.dropout(self.pool(conv_block(self.conv6, self.bn6, x)))

        features = x.flatten(start_dim=1)
        # The dense head uses linear -> batch norm -> ReLU.
        features = F.relu(self.bn_fc1(self.fc1(features)))
        features = F.relu(self.bn_fc2(self.fc2(features)))
        return self.fc3(features)


cnn_net2 = CNNNet2()

In [8]:
class EarlyStopping:
    """Signal that training should stop once the validation loss stops improving.

    Call the instance once per epoch with the validation loss, then check the
    ``early_stop`` attribute.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Parameters:
            patience (int): number of consecutive calls without improvement
                after which ``early_stop`` is set to True.
            verbose (bool): if True, print the counter every time a call
                brings no improvement.
            delta (float): minimum decrease of the loss that counts as an
                improvement; smaller changes are treated as no improvement.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float("inf")  # lowest loss accepted as "best" so far
        self.delta = delta

    def __call__(self, val_loss):
        """Record one validation loss and update ``counter`` / ``early_stop``."""
        # Higher score is better, so the score is the negated loss.
        score = -val_loss
        # NaN is the only value that differs from itself. A NaN loss must count
        # as "no improvement"; storing it as best_score would make every later
        # comparison False and keep resetting the counter forever.
        is_nan = score != score

        if self.best_score is None and not is_nan:
            self.best_score = score
            self.val_loss_min = val_loss
        elif is_nan or score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.val_loss_min = val_loss
            self.counter = 0

In [9]:
def load_data(batch_size=16, val_fraction=0.1, seed=None, num_workers=2, root='./data'):
    """Build the CIFAR-10 train, validation and test data loaders.

    The official training set is split into a training part and a validation
    part. Training samples go through random augmentation; validation and test
    samples are only converted to tensors and normalised.

    Parameters:
        batch_size (int): batch size used by all three loaders.
        val_fraction (float): share of the training set held out for
            validation; must lie strictly between 0 and 1.
        seed (int | None): seed for the train/validation split. With None the
            split is drawn from NumPy's global random state, so it is not
            reproducible unless the caller has seeded that state.
        num_workers (int): worker processes per loader.
        root (str): directory where the dataset is downloaded and stored.

    Returns:
        tuple: (trainloader, valloader, testloader).

    Raises:
        ValueError: if ``val_fraction`` is not strictly between 0 and 1.
    """
    if not 0 < val_fraction < 1:
        raise ValueError("val_fraction must be strictly between 0 and 1")

    # Deterministic part shared by every split.
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    # Data augmentation, applied to the training split only.
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.RandomRotation(10),      # random rotation within +/-10 degrees
        *to_normalized_tensor,
    ])

    # Validation and test transform: no random operations.
    test_transform = transforms.Compose(to_normalized_tensor)

    full_trainset = torchvision.datasets.CIFAR10(root=root, train=True,
                                                 download=True, transform=train_transform)
    # Same underlying training images, but without augmentation, for validation.
    val_view = torchvision.datasets.CIFAR10(root=root, train=True,
                                            download=True, transform=test_transform)

    num_train = len(full_trainset)
    split = int(np.floor(val_fraction * num_train))  # size of the validation part

    if seed is None:
        indices = list(range(num_train))
        np.random.shuffle(indices)
    else:
        # A dedicated generator keeps the split reproducible without touching
        # the global NumPy random state.
        indices = np.random.default_rng(seed).permutation(num_train).tolist()

    train_idx, val_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)

    trainloader = DataLoader(full_trainset, batch_size=batch_size,
                             sampler=train_sampler, num_workers=num_workers)
    valloader = DataLoader(val_view, batch_size=batch_size,
                           sampler=val_sampler, num_workers=num_workers)

    testset = torchvision.datasets.CIFAR10(root=root, train=False,
                                           download=True, transform=test_transform)
    testloader = DataLoader(testset, batch_size=batch_size,
                            shuffle=False, num_workers=num_workers)
    return trainloader, valloader, testloader

In [10]:
def _evaluate(net, loader, criterion, device):
    """Run ``net`` over ``loader`` without gradients.

    Returns:
        tuple: (mean loss per sample, number of correct predictions, number of samples).
    """
    net.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            batch = labels.size(0)
            # The criterion returns the batch mean, so weight it by the batch
            # size to keep a smaller final batch from being over-represented.
            loss_sum += criterion(outputs, labels).item() * batch
            # The class with the highest logit is the prediction.
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch
    # max(..., 1) avoids a ZeroDivisionError on an empty loader.
    return loss_sum / max(total, 1), correct, total


def train_and_test(net, trainloader, valloader, testloader, device, epochs=20,
                   save_path='./cifar_net.pth'):
    """Train ``net`` with AdamW and early stopping, save it, then evaluate on the test set.

    Per-epoch train/validation loss and accuracy are printed and logged to
    wandb; the final test accuracy (in percent) is printed and logged as well.

    Parameters:
        net: the model to train; it is moved to ``device``.
        trainloader, valloader, testloader: iterables of (inputs, labels) batches.
        device: device on which the model and the batches are placed.
        epochs (int): maximum number of epochs; training ends earlier when the
            validation loss has not improved for 15 consecutive epochs.
        save_path (str): file the final ``state_dict`` is written to.
    """
    early_stopping = EarlyStopping(patience=15, verbose=True)

    optimizer = optim.AdamW(net.parameters())  # use AdamW instead of SGD
    criterion = nn.CrossEntropyLoss()  # default reduction: mean over the batch

    net.to(device)

    for epoch in range(epochs):
        net.train()
        loss_sum, correct_train, total_train = 0.0, 0, 0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch = labels.size(0)
            loss_sum += loss.item() * batch  # batch mean -> batch sum
            correct_train += (outputs.argmax(dim=1) == labels).sum().item()
            total_train += batch

        train_loss = loss_sum / max(total_train, 1)
        train_accuracy = correct_train / max(total_train, 1)

        # Validation
        val_loss, correct_val, total_val = _evaluate(net, valloader, criterion, device)
        val_accuracy = correct_val / max(total_val, 1)

        print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Train Acc: {train_accuracy}, Val Loss: {val_loss}, Val Acc: {val_accuracy}")

        wandb.log({'train_loss': train_loss,
                   'train_accuracy': train_accuracy,
                   'val_loss': val_loss,
                   'val_accuracy': val_accuracy})

        # Early stopping is driven by the validation loss.
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    print('Finished Training')

    # The saved weights are those of the last epoch that ran.
    torch.save(net.state_dict(), save_path)

    # Test
    _, correct, total = _evaluate(net, testloader, criterion, device)
    accuracy = 100 * correct / max(total, 1)
    print(f'Accuracy of the network on the {total} test images: {accuracy:.2f} %')

    wandb.log({"test_accuracy": accuracy})

In [11]:
!nvidia-smi

Wed Mar 20 15:31:23 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0              26W / 250W |      0MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                         

In [12]:
def main(model_type, epochs):
    """Train and evaluate one of the CIFAR-10 models inside a wandb run.

    Parameters:
        model_type (str): 'mlp', 'cnn' or 'cnn2'.
        epochs (int): maximum number of training epochs.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    # Validate first so a typo fails immediately instead of after the dataset
    # has been downloaded and loaded.
    if model_type not in ('mlp', 'cnn', 'cnn2'):
        raise ValueError("Unsupported model type. Choose 'mlp', 'cnn' or 'cnn2'.")
    model_classes = {'mlp': MLPNet, 'cnn': CNNNet, 'cnn2': CNNNet2}

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    trainloader, valloader, testloader = load_data()
    net = model_classes[model_type]()

    # Record the run settings so runs of different models can be told apart.
    wandb.init(project="cifar10_classification",
               config={"model_type": model_type, "epochs": epochs})
    try:
        train_and_test(net, trainloader, valloader, testloader, device, epochs)
    finally:
        # Close the wandb run even when training raises.
        wandb.finish()

In [13]:
main('cnn2', epochs=50)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:15<00:00, 11184212.35it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


[34m[1mwandb[0m: Currently logged in as: [33mnagi-ovo[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.16.4
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20240320_153145-5nymeh7g[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mrural-field-29[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/nagi-ovo/cifar10_classification[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/nagi-ovo/cifar10_classification/runs/5nymeh7g[0m


Epoch 1, Train Loss: 1.5382902543187946, Train Acc: 0.4444888888888889, Val Loss: 1.1656603637023475, Val Acc: 0.5804
Epoch 2, Train Loss: 1.1535790202522718, Train Acc: 0.5936222222222223, Val Loss: 0.959811592873293, Val Acc: 0.6674
Epoch 3, Train Loss: 0.9686891496276838, Train Acc: 0.6616444444444445, Val Loss: 0.7737962764006453, Val Acc: 0.7324
Epoch 4, Train Loss: 0.8464099498266349, Train Acc: 0.7050666666666666, Val Loss: 0.7085788064776137, Val Acc: 0.7536
Epoch 5, Train Loss: 0.7629352746815432, Train Acc: 0.7366888888888888, Val Loss: 0.6506390474474849, Val Acc: 0.7732
Epoch 6, Train Loss: 0.6855441344047076, Train Acc: 0.7633333333333333, Val Loss: 0.5936773314881629, Val Acc: 0.7894
Epoch 7, Train Loss: 0.6315463796904237, Train Acc: 0.7833333333333333, Val Loss: 0.5496586030140852, Val Acc: 0.8098
Epoch 8, Train Loss: 0.5835678021132501, Train Acc: 0.8013777777777777, Val Loss: 0.5275677310201687, Val Acc: 0.8148
Epoch 9, Train Loss: 0.5443927664104472, Train Acc: 0.812

[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:  test_accuracy ▁
[34m[1mwandb[0m: train_accuracy ▁▃▄▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇███████████████████
[34m[1mwandb[0m:     train_loss █▆▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:   val_accuracy ▁▃▅▅▆▆▇▇▇▇▇▇█▇▇█▇███████████████████████
[34m[1mwandb[0m:       val_loss █▆▄▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:  test_accuracy 85.74
[34m[1mwandb[0m: train_accuracy 0.93007
[34m[1mwandb[0m:     train_loss 0.20264
[34m[1mwandb[0m:   val_accuracy 0.861
[34m[1mwandb[0m:       val_loss 0.44778
[34m[1mwandb[0m: 
[34m[1mwandb[0m: 🚀 View run [33mrural-field-29[0m at: [34m[4mhttps://wandb.ai/nagi-ovo/cifar10_classification/runs/5nymeh7g[0m
[34m[1mwandb[0m: Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 

In [14]:
main('mlp', epochs=50)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


[34m[1mwandb[0m: Tracking run with wandb version 0.16.4
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20240320_155836-l2ra0wjr[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mwoven-lion-30[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/nagi-ovo/cifar10_classification[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/nagi-ovo/cifar10_classification/runs/l2ra0wjr[0m


Epoch 1, Train Loss: 1.7165797447717532, Train Acc: 0.38822222222222225, Val Loss: 1.6469947543387977, Val Acc: 0.4198
Epoch 2, Train Loss: 1.5547211299025627, Train Acc: 0.44837777777777776, Val Loss: 1.5794330549697144, Val Acc: 0.436
Epoch 3, Train Loss: 1.4822777047981115, Train Acc: 0.4754, Val Loss: 1.6667548486599908, Val Acc: 0.4172
EarlyStopping counter: 1 out of 15
Epoch 4, Train Loss: 1.4362711231563594, Train Acc: 0.49264444444444444, Val Loss: 1.5407547695568193, Val Acc: 0.4544
Epoch 5, Train Loss: 1.4014781934698832, Train Acc: 0.5041333333333333, Val Loss: 1.580287456322021, Val Acc: 0.4466
EarlyStopping counter: 1 out of 15
Epoch 6, Train Loss: 1.3721132180761113, Train Acc: 0.5110222222222223, Val Loss: 1.5333679116572054, Val Acc: 0.472
Epoch 7, Train Loss: 1.3552488471704829, Train Acc: 0.5182888888888889, Val Loss: 1.4834576915628233, Val Acc: 0.4726
Epoch 8, Train Loss: 1.3396490123447387, Train Acc: 0.5253555555555556, Val Loss: 1.4918890418336033, Val Acc: 0.490

[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:  test_accuracy ▁
[34m[1mwandb[0m: train_accuracy ▁▃▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇█▇█████████
[34m[1mwandb[0m:     train_loss █▆▅▅▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:   val_accuracy ▁▂▁▃▅▅▆▅▅▆▆▆▅▇▆▅▆▆▇▇▇▇▇▆▇▇▇▇█▇▇▆▇▇█▆▇█▇▇
[34m[1mwandb[0m:       val_loss ▇▅█▄▄▂▂▃▁▃▄▃▂▂▃▁▂▂▃▂▄▂▃▃▂▂▂▂▁▃▂▃▃▂▂▄▁▁▄▃
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:  test_accuracy 52.26
[34m[1mwandb[0m: train_accuracy 0.62209
[34m[1mwandb[0m:     train_loss 1.07945
[34m[1mwandb[0m:   val_accuracy 0.5076
[34m[1mwandb[0m:       val_loss 1.50553
[34m[1mwandb[0m: 
[34m[1mwandb[0m: 🚀 View run [33mwoven-lion-30[0m at: [34m[4mhttps://wandb.ai/nagi-ovo/cifar10_classification/runs/l2ra0wjr[0m
[34m[1mwandb[0m: Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 

In [15]:
# main('cnn', epochs=50)  # main() requires the epochs argument