In [8]:
import torch
import torchvision
from torch import nn
import torch.optim as optim
from model import MFNN1, MFNN2, MLP, FCNN
import utils 
from torch.optim.swa_utils import AveragedModel
import numpy as np
#import pandas as pd
import sys
import time
import random
from dataclasses import dataclass
from sam import SAM
from nsam import NSAM
from pyhessian import hessian
import matplotlib.pyplot as plt
from tqdm import tqdm

In [9]:
# Experiment settings for this run. Fields not passed here keep the defaults
# declared by utils.configs.
config = utils.configs(
    model='FCNN',
    opt_type='SWA',
    dataset='FashionMNIST',
    scheduler=1,
    lr=0.1,
    bs=600,
    rho=10,
    epochs=100,
)

# `path` is the file that main() appends its per-epoch log lines to.
root = utils.make_root(config)
path = utils.get_path(root)

# Display the resolved configuration as the cell output.
config

configs(model='FCNN', opt_type='SWA', scheduler=1, dataset='FashionMNIST', lr=0.1, bs=600, weight_decay=0.0, momentum=0.0, rho=10, epochs=100)

In [10]:
# Build the training and test data loaders from the experiment configuration.
# main() iterates over both of these module-level names directly.
loader_train, loader_test = utils.make_data_loader(config)

In [11]:
# Build a small, fixed subset of the training set for the Hessian computation.
# NOTE(review): the original comment said 1000 samples, but data_num below is
# 100 -- confirm which size is intended.

import torchvision
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader
from torchvision.datasets import MNIST, CIFAR10, FashionMNIST

train = FashionMNIST(
    'FashionMNIST',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# The tensors are read straight from the dataset object, so the ToTensor
# transform above is not applied to them (pixel values presumably stay in the
# raw 0-255 range -- TODO confirm this matches what loader_train feeds the model).
X_train = train.data.type(torch.float32)
t_train = train.targets
ds_train = TensorDataset(X_train, t_train)

# Fixed seed so the same indices are drawn on every run.
random.seed(0)
data_num = 100
subset = random.sample(range(len(X_train)), data_num)

sample_ds_hess = torch.utils.data.Subset(ds_train, subset)
sample_sampler_hess = torch.utils.data.RandomSampler(sample_ds_hess)
# batch_size == data_num, so the loader yields the whole subset as one batch.
hess_loader = DataLoader(
    sample_ds_hess,
    sampler=sample_sampler_hess,
    batch_size=data_num,
)

In [12]:
# Hessian trace computation
def cal_hess(model):
    """Estimate the Hessian trace of the loss for ``model`` on the Hessian subset.

    Reads the module-level ``hess_loader`` and ``loss_fn``. For every batch the
    loader yields, PyHessian's ``hessian(...).trace()`` returns a list of trace
    estimates; their mean is added to the running total.

    Args:
        model: the network to analyse (a plain module or an ``AveragedModel``).

    Returns:
        The sum over batches of the mean trace estimate. With the loader
        defined in this notebook there is a single batch, so this is simply
        the mean estimate for that batch.
    """
    # Run PyHessian on the GPU only when the model actually lives there.
    # Hard-coding cuda=True sends the batch to CUDA while the weights may
    # still be on the CPU, which fails with a device mismatch.
    use_cuda = any(p.is_cuda for p in model.parameters())

    hess_trace = 0.0
    for X, t in hess_loader:
        # Same input layout main() uses for every forward pass.
        X = X.reshape([-1, 1, 28, 28])
        estimates = hessian(model, loss_fn, data=(X, t), cuda=use_cuda).trace()
        hess_trace = hess_trace + np.mean(estimates)
    return hess_trace

In [13]:
def main(model, opt, loss_fn, scheduler, config):
    """Train ``model`` for ``config.epochs`` epochs and evaluate after each one.

    Besides its arguments, this function reads the module-level names
    ``loader_train``, ``loader_test``, ``path`` (log file, opened in append
    mode) and ``averaged`` (the ``AveragedModel`` used when
    ``config.opt_type == 'SWA'``).

    Args:
        model: network being trained.
        opt: optimizer; for 'SAM'/'NSAM' it must provide ``first_step`` and
            ``second_step``, otherwise ``step``.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        scheduler: learning-rate scheduler, stepped once per epoch.
        config: run configuration; ``opt_type`` and ``epochs`` are read here.

    Returns:
        ``hess_list``. It stays empty while the Hessian-trace evaluation below
        is disabled.

    Raises:
        ValueError: if ``config.opt_type`` is not one of the handled values.
    """
    opt_type = config.opt_type
    if opt_type not in ('SGD', 'SWA', 'SAM', 'NSAM'):
        raise ValueError(f'unsupported opt_type: {opt_type!r}')

    # Weight averaging is applied only on epochs strictly after this index.
    swa_start = 50
    hess_list = []

    for epoch in range(config.epochs):
        with open(path, 'a') as f:
            f.write(f'EPOCH: {epoch}\n')
        print(f'EPOCH: {epoch}')

        start = time.time()

        # start train loop----------------------
        train_loss = []
        total_train = 0
        correct_train = 0

        model.train()
        for X, t in loader_train:
            # -1 lets a final, smaller batch through instead of failing on
            # a hard-coded batch size.
            X = X.reshape([-1, 1, 28, 28])
            y = model(X)
            opt.zero_grad()

            loss = loss_fn(y, t)
            loss.backward()

            if opt_type in ('SAM', 'NSAM'):
                # The first step perturbs the weights along the gradient just
                # computed, so the backward pass above must come before it.
                opt.first_step(zero_grad=True)
                # Second forward/backward pass at the perturbed weights.
                loss_fn(model(X), t).backward()
                opt.second_step(zero_grad=True)
            else:
                opt.step()
                if opt_type == 'SWA' and epoch > swa_start:
                    averaged.update_parameters(model)

            pred = y.argmax(1)
            train_loss.append(loss.tolist())

            total_train += t.shape[0]
            correct_train += (pred == t).sum().item()

            # Sum of the per-parameter gradient norms. Parameters without a
            # gradient (e.g. after the optimizer cleared them) are skipped.
            norm = 0
            for param in model.parameters():
                if param.grad is not None:
                    norm += torch.norm(param.grad)

            print(norm)

        scheduler.step()

        end = time.time()

        log = f'train loss: {np.mean(train_loss):.3f}, accuracy: {correct_train/max(total_train, 1):.3f}'
        with open(path, 'a') as f:
            f.write(log + f' train_time: {end - start:.5f}' + '\n')
        print(log)
        # end train loop------------------------------------

        # start test loop-----------------------------------
        test_loss = []
        total_test = 0
        correct_test = 0

        # `averaged` is only updated once epoch > swa_start. Before that it
        # still holds its initial copy of the weights, so evaluate the live
        # model until averaging has actually begun.
        use_averaged = opt_type == 'SWA' and epoch > swa_start
        eval_model = averaged if use_averaged else model

        model.eval()
        # The averaged wrapper is a separate module, so model.eval() does not
        # switch it to eval mode.
        eval_model.eval()

        hess_tr = 0.0
        # Hessian-trace tracking is currently disabled. To enable it:
        #     hess_tr = cal_hess(eval_model)
        #     hess_list.append(hess_tr)

        # No gradients are needed for evaluation.
        with torch.no_grad():
            for X, t in loader_test:
                X = X.reshape([-1, 1, 28, 28])
                y = eval_model(X)
                loss = loss_fn(y, t)

                pred = y.argmax(1)
                test_loss.append(loss.tolist())

                total_test += t.shape[0]
                correct_test += (pred == t).sum().item()

        log = f'test loss: {np.mean(test_loss):.3f}, accuracy: {correct_test/max(total_test, 1):.3f}, hessian trace: {hess_tr:.3f}'
        with open(path, 'a') as f:
            f.write(log + '\n')
        print(log)
        # end test loop--------------------------------------

        print(f'time: {end - start:.5f}')

    return hess_list

In [14]:
# Build the network and its SWA running average.
model = FCNN()
averaged = AveragedModel(model)

# SAM/NSAM are given a device argument. Use the device the model's parameters
# live on, so the optimizer and the model always agree.
# NOTE(review): assumes SAM/NSAM accept a torch.device here -- confirm.
device = next(model.parameters()).device

base_optimizer = optim.SGD

if config.opt_type in ('SGD', 'SWA'):
    # SWA trains with plain SGD; the averaging itself happens inside main().
    opt = optim.SGD(model.parameters(), lr=config.lr)
elif config.opt_type == 'SAM':
    opt = SAM(model.parameters(), base_optimizer, lr=config.lr, rho=config.rho, device=device)
elif config.opt_type == 'NSAM':
    opt = NSAM(model.parameters(), base_optimizer, lr=config.lr, rho=config.rho, device=device)
else:
    # Fail here rather than with a NameError on `opt` further down.
    raise ValueError(f'unsupported opt_type: {config.opt_type!r}')

loss_fn = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.LambdaLR(opt, lr_lambda=utils.func2, verbose=True)
hess_list = main(model, opt, loss_fn, scheduler, config)

Adjusting learning rate of group 0 to 1.0000e-01.
EPOCH: 0
tensor(0.1929)
tensor(0.9221)
tensor(0.1965)
tensor(0.0845)
tensor(0.0530)
tensor(0.0652)
tensor(0.0749)
tensor(0.1074)
tensor(0.1256)
tensor(0.0799)
tensor(0.1940)
tensor(0.0623)
tensor(0.0676)
tensor(0.0669)
tensor(0.1286)
tensor(0.0531)
tensor(0.0729)
tensor(0.1184)
tensor(0.0536)
tensor(0.0790)
tensor(0.0722)
tensor(0.0902)
tensor(0.0677)
tensor(0.0888)
tensor(0.1197)
tensor(0.1265)
tensor(0.1146)
tensor(0.1022)
tensor(0.0689)


KeyboardInterrupt: 