# Imports

In [1]:
import argparse
import csv
import os
from statistics import mean

import torch
from torchvision.models import efficientnet_v2_m
from torchvision.models import EfficientNet_V2_M_Weights
from torchvision.models import vit_b_16
from torchvision.models import ViT_B_16_Weights


from model.wide_res_net_cifar import WideResNet_cifar
from model.wide_res_net_fashionmnist import WideResNet_fashionmnist
# from model.wide_res_net_food101 import WideResNet_food101
from model.PyramidNet import PyramidNet

from model.smooth_cross_entropy import smooth_crossentropy

from data_cifar100.cifar import Cifar100
from data_cifar100_224.cifar import Cifar100_224
from data_cifar10.cifar import Cifar10
from data_fashionmnist.fashionmnist import fashionmnist
from data_food101.food101 import Food101

from utility.log import Log
from utility.initialize import initialize
from utility.step_lr import StepLR
from utility.bypass_bn import enable_running_stats, disable_running_stats


# Training Script

In [2]:
# Number of target classes for each supported `data` value; used to size every model's output layer.
_NUM_CLASSES = {
    "Cifar100": 100,
    "Cifar100_224": 100,
    "Cifar10": 10,
    "fashionmnist": 10,
    "food101": 101,
}

# Supported values of `experiment`.
_EXPERIMENTS = ("WideResNet_cifar100", "WideResNet_cifar10", "PyramidNet", "EfficientNetV2_m", "vit_b_16")

# Experiments whose LR scheduler is a torch.optim.lr_scheduler (advanced with .step() once per epoch);
# all other experiments use the project StepLR, which is called with the epoch index.
_TORCH_SCHEDULER_EXPERIMENTS = ("EfficientNetV2_m", "vit_b_16")


def _build_training_setup(experiment, data, num_classes, device):
    """Create the model, SGD optimizer, LR scheduler and loop sizes for `experiment`.

    Args:
        experiment: one of ``_EXPERIMENTS``.
        data: dataset name; only used to choose the PyramidNet stem.
        num_classes: size of the model's output layer.
        device: torch device the model is moved to.

    Returns:
        ``(model, optimizer, scheduler, batch_size, epochs)``.

    Raises:
        ValueError: if `experiment` is not supported.
    """
    if experiment in ("WideResNet_cifar100", "WideResNet_cifar10"):
        batch_size, epochs = 256, 400
        learning_rate, momentum, weight_decay = 0.1, 0.9, 0.0005
        depth, width_factor, dropout = 28, 10, 0
        model = WideResNet_cifar(depth, width_factor, dropout, in_channels=3, labels=num_classes).to(device)
    elif experiment == "PyramidNet":
        batch_size, epochs = 64, 400
        learning_rate, momentum, weight_decay = 0.1, 0.9, 0.0005
        # Per the original notes, food101 (and the CIFAR sets) use the "cifar100" stem; fashionmnist has its own.
        pyramid_dataset = "fashionmnist" if data == "fashionmnist" else "cifar100"
        # NOTE(review): the original defined `bottleneck = True` but never passed it to PyramidNet, so the
        # network is built with PyramidNet's default block type. Confirm whether bottleneck blocks were intended.
        model = PyramidNet(dataset=pyramid_dataset, depth=272, alpha=200, num_classes=num_classes).to(device)
    elif experiment == "EfficientNetV2_m":
        # Overriding `dropout` needs torchvision >= 0.17, or the fix from
        # https://github.com/pytorch/vision/commit/5785e2b05cdffeb39678914b8308a260e7e757db
        batch_size, epochs = 32, 100
        learning_rate, momentum, weight_decay = 0.001, 0.9, 0
        model = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1, dropout=0)
        # Replace the pre-trained ImageNet head with one sized for the training data.
        model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes)
        model = model.to(device)
    elif experiment == "vit_b_16":
        batch_size, epochs = 32, 200
        learning_rate, momentum, weight_decay = 0.01, 0.9, 0
        # Built for 224x224 inputs, so pair it with a 224-pixel dataset such as "Cifar100_224".
        model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1, image_size=224)
        # Replace the pre-trained ImageNet head with one sized for the training data.
        model.heads.head = torch.nn.Linear(model.heads.head.in_features, num_classes)
        model = model.to(device)
    else:
        raise ValueError(f"Unknown experiment {experiment!r}; expected one of {_EXPERIMENTS}")

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    if experiment in _TORCH_SCHEDULER_EXPERIMENTS:
        # Cosine decay over the whole run: T_max=epochs is only correct if stepped once per epoch.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)
    else:
        scheduler = StepLR(optimizer, learning_rate, epochs)
    return model, optimizer, scheduler, batch_size, epochs


def _build_dataset(data, batch_size, threads):
    """Instantiate the project dataset wrapper named by `data` (a key of ``_NUM_CLASSES``)."""
    dataset_classes = {
        "Cifar100": Cifar100,
        "Cifar100_224": Cifar100_224,
        "Cifar10": Cifar10,
        "fashionmnist": fashionmnist,
        "food101": Food101,
    }
    return dataset_classes[data](batch_size, threads)


def _append_csv_row(path, row):
    """Append a single row to the CSV file at `path` (opened with newline='' as the csv module requires)."""
    with open(path, "a", newline="") as file:
        csv.writer(file).writerow(row)


def sgd_train(iteration=0, experiment="PyramidNet", data="Cifar100", gpu="cuda:3", threads=36,
              label_smoothing=0, seed=42, result_root="/home/tratchatorn/SAM/sam/example"):
    """Train `experiment` on `data` with plain SGD and log per-epoch metrics to a CSV.

    Results go to ``<result_root>/result_repeat_<experiment>_<data>/testSGD_<iteration>.csv``.
    The final model and optimizer state dicts are saved next to the CSV as
    ``testSGD_<iteration>_model.pth`` and ``testSGD_<iteration>_opt.pth``.

    Notes on choosing `experiment`:
    - WideResNet: pick the variant that matches the data; the output layer is sized from `data`.
    - PyramidNet: uses the CIFAR-style stem, or the fashionmnist stem when ``data == "fashionmnist"``.

    Args:
        iteration: repeat index; only used in the output file names.
        experiment: one of "WideResNet_cifar100", "WideResNet_cifar10", "PyramidNet",
            "EfficientNetV2_m", "vit_b_16".
        data: one of "Cifar100", "Cifar100_224", "Cifar10", "fashionmnist", "food101".
        gpu: torch device string.
        threads: worker count passed to the dataset wrapper.
        label_smoothing: smoothing factor for ``smooth_crossentropy``.
        seed: seed passed to ``initialize``.
            NOTE(review): every `iteration` uses the same seed unless this is varied.
        result_root: directory under which the per-experiment result directory is created.

    Raises:
        ValueError: if `experiment` or `data` is not supported. This is checked before any file is touched.
    """
    # Validate up front so a typo fails immediately instead of after creating result files.
    if experiment not in _EXPERIMENTS:
        raise ValueError(f"Unknown experiment {experiment!r}; expected one of {_EXPERIMENTS}")
    if data not in _NUM_CLASSES:
        raise ValueError(f"Unknown data {data!r}; expected one of {tuple(_NUM_CLASSES)}")
    num_classes = _NUM_CLASSES[data]

    result_dir = os.path.join(result_root, "result_repeat_" + experiment + "_" + data)
    os.makedirs(result_dir, exist_ok=True)
    file_name = f"testSGD_{iteration}"
    optimizer_save_path = os.path.join(result_dir, file_name + "_opt.pth")
    model_save_path = os.path.join(result_dir, file_name + "_model.pth")
    csv_path = os.path.join(result_dir, file_name + ".csv")

    # Only write the header for a new/empty file so re-runs do not interleave header rows with data rows.
    if not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0:
        _append_csv_row(csv_path, ["epoch", "lr", "train_loss", "train_acc", "val_loss", "val_acc"])

    initialize(seed=seed)
    device = torch.device(gpu)
    log = Log(log_each=10)

    model, optimizer, scheduler, batch_size, epochs = _build_training_setup(experiment, data, num_classes, device)
    dataset = _build_dataset(data, batch_size, threads)
    uses_torch_scheduler = experiment in _TORCH_SCHEDULER_EXPERIMENTS

    for epoch in range(epochs):
        model.train()
        log.train(len_dataset=len(dataset.train))
        if not uses_torch_scheduler:
            # StepLR maps the epoch index to a learning rate; set it before the epoch's first batch.
            scheduler(epoch)

        for batch in dataset.train:
            inputs, targets = (b.to(device) for b in batch)

            # Kept from the SAM training loop this script derives from. disable_running_stats is never
            # called here, so with plain SGD this presumably leaves BatchNorm untouched (NOTE(review): confirm).
            enable_running_stats(model)
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets, smoothing=label_smoothing)
            loss.mean().backward()
            optimizer.step()
            optimizer.zero_grad()

            with torch.no_grad():
                correct = torch.argmax(predictions, 1) == targets
                current_lr = optimizer.param_groups[0]["lr"] if uses_torch_scheduler else scheduler.lr()
                log(model, loss.cpu(), correct.cpu(), current_lr)

        if uses_torch_scheduler:
            # One cosine step per epoch. Stepping per batch (as before) reached eta_min after only
            # `epochs` batches and then oscillated for the rest of training.
            scheduler.step()

        model.eval()
        train_loss, train_acc, lr = log.output_0()
        log.eval(len_dataset=len(dataset.test))

        with torch.no_grad():
            for batch in dataset.test:
                inputs, targets = (b.to(device) for b in batch)

                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets, smoothing=label_smoothing)
                correct = torch.argmax(predictions, 1) == targets
                log(model, loss.cpu(), correct.cpu())

        val_loss, val_acc = log.output_1()
        log.flush()

        _append_csv_row(csv_path, [epoch, lr, train_loss, train_acc, val_loss, val_acc])

    # Persist the trained weights and optimizer state; these paths were previously computed but never used.
    torch.save(model.state_dict(), model_save_path)
    torch.save(optimizer.state_dict(), optimizer_save_path)


# sgd_train(experiment, data)

In [3]:
# Launch repeat #4 of the SGD baseline. The loop previously kept here,
# `for i in range(4): sgd_train(iteration=i)`, ran repeats 0-3; set
# REPEAT_INDEX to rerun a single repeat.
REPEAT_INDEX = 4
sgd_train(iteration=REPEAT_INDEX)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
┏━━━━━━━━━━━━━━┳━━━━━━━╸T╺╸R╺╸A╺╸I╺╸N╺━━━━━━━┳━━━━━━━╸S╺╸T╺╸A╺╸T╺╸S╺━━━━━━━┳━━━━━━━╸V╺╸A╺╸L╺╸I╺╸D╺━━━━━━━┓
┃              ┃              ╷              ┃              ╷              ┃              ╷              ┃
┃       epoch  ┃        loss  │    accuracy  ┃        l.r.  │     elapsed  ┃        loss  │    accuracy  ┃
┠──────────────╂──────────────┼──────────────╂──────────────┼──────────────╂──────────────┼──────────────┨
┃           0  ┃      3.9331  │      8.60 %  ┃   1.000e-01  │   05:34 min  ┃┈███████████████████████████┈┨      3.5939  │     13.66 %  ┃
      3.5939  │     13.66 %  ┃
┃           1  ┃      3.3878  │     17.18 %  ┃   1.000e-01  │   05:35 min  ┃┈███████████████████████████┈┨      3.1556  │     21.06 %  ┃
      3.1556  │     21.06 %  ┃
┃           2  ┃      2.9473  │     25.25 %  ┃   1.000e-01  │   05:35 min  ┃┈███████████████████████████┈┨      2.8580  │ 