## Training a FashionMNIST model using PyTorch

#### Import packages

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms  # Help transform the data
import numpy as np
import pandas as pd

# For the run manager
import json
from collections import OrderedDict
from collections import namedtuple
from itertools import product
import time
from torch.utils.tensorboard import SummaryWriter


# Report the installed torch and torchvision builds, one per line.
print(torch.__version__, torchvision.__version__, sep='\n')

# Report whether CUDA is usable and choose the default device name from that.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

2.1.0.dev20230618+cu121
0.16.0.dev20230619+cu121
True
cuda


#### Define Network

In [3]:
class Network(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    The first linear layer takes 12*4*4 features per sample, which is what the
    two (5x5 conv, 2x2 pool) stages produce for a 1x28x28 input. The final
    layer emits 10 raw scores per sample (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=5)

        # Fully connected classifier head.
        self.fc1 = nn.Linear(12 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 60)
        self.out = nn.Linear(60, 10)

    def forward(self, t):
        """Return the 10 output scores for each sample in the batch ``t``."""
        # Stage 1: conv -> ReLU -> 2x2 max-pool.
        features = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2, stride=2)
        # Stage 2: conv -> ReLU -> 2x2 max-pool.
        features = F.max_pool2d(F.relu(self.conv2(features)), kernel_size=2, stride=2)

        # Collapse channel and spatial dims into one feature vector per sample.
        flat = features.reshape(-1, 12 * 4 * 4)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.out(hidden)

#### Define some useful functions

In [4]:
def get_num_correct(preds, labels):
    """Count how many rows of ``preds`` have their arg-max equal to ``labels``.

    ``preds`` is reduced along dim 1 to a predicted class index per row, which
    is compared element-wise with ``labels``; the number of matches is
    returned as a plain Python number.
    """
    predicted = torch.argmax(preds, dim=1)
    matches = predicted == labels
    return matches.sum().item()

#### Define run builder and manager

In [5]:
class RunBuilder():
    """Expand a mapping of hyperparameter value lists into concrete runs."""

    @staticmethod
    def get_runs(params):
        """Return one ``Run`` namedtuple per combination of parameter values.

        ``params`` maps each parameter name to the list of values to try. The
        namedtuple fields follow the key order of ``params`` and the
        combinations are produced in ``itertools.product`` order (the last
        parameter varies fastest).
        """
        Run = namedtuple('Run', params.keys())
        return [Run(*combo) for combo in product(*params.values())]

In [6]:
# Small demonstration grid: two learning rates crossed with two batch sizes.
params = OrderedDict([
    ('lr', [.01, .001]),
    ('batch_size', [1000, 10000]),
])

In [7]:
# Expand the parameter grid into the list of Run tuples; the bare name on the
# last line makes the notebook echo that list.
runs = RunBuilder.get_runs(params)
runs

[Run(lr=0.01, batch_size=1000),
 Run(lr=0.01, batch_size=10000),
 Run(lr=0.001, batch_size=1000),
 Run(lr=0.001, batch_size=10000)]

In [8]:
# Show that each Run exposes its hyperparameters as named attributes.
for run in runs:
    print(run, run.lr, run.batch_size)

Run(lr=0.01, batch_size=1000) 0.01 1000
Run(lr=0.01, batch_size=10000) 0.01 10000
Run(lr=0.001, batch_size=1000) 0.001 1000
Run(lr=0.001, batch_size=10000) 0.001 10000


In [9]:
# Preview the '-Run(...)' suffix string built from each run's repr.
for run in RunBuilder.get_runs(params):
    comment = f'-{run}'
    print(comment)

-Run(lr=0.01, batch_size=1000)
-Run(lr=0.01, batch_size=10000)
-Run(lr=0.001, batch_size=1000)
-Run(lr=0.001, batch_size=10000)


In [10]:
class RunManager():
    """Collect per-epoch metrics across several training runs.

    Call order for each run is ``begin_run`` -> (``begin_epoch`` ->
    ``track_loss`` / ``track_num_correct`` per batch -> ``end_epoch``)* ->
    ``end_run``. Every finished epoch appends one row to ``run_data`` and is
    logged to TensorBoard; ``save`` writes all rows to CSV and JSON.
    """

    def __init__(self):
        # Per-epoch accumulators (reset in begin_epoch).
        self.epoch_count = 0
        self.epoch_loss = 0
        self.epoch_num_correct = 0
        self.epoch_start_time = None

        # Per-run state; run_data accumulates rows across all runs.
        self.run_params = None
        self.run_count = 0
        self.run_data = []
        self.run_start_time = None

        # Objects supplied to / created by begin_run.
        self.network = None
        self.loader = None
        self.tb = None

    def begin_run(self, run, network, loader):
        """Start a new run and log one batch of images plus the model graph.

        Args:
            run: namedtuple of hyperparameters for this run. Its repr is used
                as the SummaryWriter comment; an optional ``device`` field
                selects where the sample batch is moved for graph tracing
                (defaults to ``'cpu'``).
            network: the model whose parameters are logged each epoch.
            loader: the data loader used for this run; its ``dataset`` length
                and ``batch_size`` are used to normalise the metrics.
        """
        self.run_start_time = time.time()

        self.run_params = run
        self.run_count += 1

        self.network = network
        self.loader = loader
        self.tb = SummaryWriter(comment=f'-{run}')

        # Only the images of the first batch are needed for the logging below.
        images, _ = next(iter(self.loader))
        grid = torchvision.utils.make_grid(images)

        self.tb.add_image('images', grid)
        self.tb.add_graph(self.network, images.to(getattr(run, 'device', 'cpu')))

    def end_run(self):
        """Close the TensorBoard writer and reset the epoch counter."""
        self.tb.close()
        self.epoch_count = 0

    def begin_epoch(self):
        """Start timing a new epoch and reset the per-epoch accumulators."""
        self.epoch_start_time = time.time()

        self.epoch_count += 1
        self.epoch_loss = 0
        self.epoch_num_correct = 0

    def end_epoch(self):
        """Finalise the epoch: log metrics, record a result row and show the table."""
        epoch_duration = time.time() - self.epoch_start_time
        run_duration = time.time() - self.run_start_time

        # Both metrics are averaged over the number of samples in the dataset.
        loss = self.epoch_loss / len(self.loader.dataset)
        accuracy = self.epoch_num_correct / len(self.loader.dataset)

        self.tb.add_scalar('Loss', loss, self.epoch_count)
        self.tb.add_scalar('Accuracy', accuracy, self.epoch_count)

        for name, param in self.network.named_parameters():
            self.tb.add_histogram(name, param, self.epoch_count)
            # A parameter that took no part in backward has grad None, which
            # add_histogram cannot handle.
            if param.grad is not None:
                self.tb.add_histogram(f'{name}.grad', param.grad, self.epoch_count)

        results = OrderedDict()
        results["run"] = self.run_count
        results["epoch"] = self.epoch_count
        results["loss"] = loss
        results["accuracy"] = accuracy
        results["epoch duration"] = epoch_duration
        results["run duration"] = run_duration
        # Append the run's hyperparameters as extra columns.
        results.update(self.run_params._asdict())
        self.run_data.append(results)
        df = pd.DataFrame.from_dict(self.run_data, orient='columns')

        # `display` is only defined inside IPython/Jupyter; fall back to print
        # so the manager also works in a plain script.
        try:
            show = display
        except NameError:
            show = print
        show(df)

    def track_loss(self, loss, batch_size=None):
        """Add a batch's loss to the epoch total, weighted by its sample count.

        Args:
            loss: scalar loss tensor for the batch (mean over the batch).
            batch_size: number of samples in this batch. Defaults to
                ``self.loader.batch_size``; pass the real size when the last
                batch may be smaller so it is not over-weighted.
        """
        if batch_size is None:
            batch_size = self.loader.batch_size
        self.epoch_loss += loss.item() * batch_size

    def track_num_correct(self, preds, labels):
        """Add the number of correct predictions in this batch to the epoch total."""
        self.epoch_num_correct += self._get_num_correct(preds, labels)

    @torch.no_grad()
    def _get_num_correct(self, preds, labels):
        """Return how many arg-max predictions along dim 1 equal ``labels``."""
        return preds.argmax(dim=1).eq(labels).sum().item()

    def save(self, fileName):
        """Write all recorded rows to ``<fileName>.csv`` and ``<fileName>.json``."""
        pd.DataFrame.from_dict(
            self.run_data,
            orient='columns').to_csv(f'{fileName}.csv')
        # The extension separator must be a dot (it was previously a comma).
        with open(f'{fileName}.json', 'w', encoding='utf-8') as f:
            json.dump(self.run_data, f, ensure_ascii=False, indent=4)

#### Get training set

In [11]:
# FashionMNIST training split, with every image converted to a tensor.
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',  # local directory where the dataset is stored
    train=True,                  # use the training split
    download=True,               # download automatically if not present locally
    transform=transforms.Compose([transforms.ToTensor()]),  # image -> tensor
)

# Wrap the training set in a data loader so the underlying data can be
# accessed and iterated in the format we want.
train_loader = torch.utils.data.DataLoader(train_set)

#### Config

In [12]:
# Hyperparameter grid for the sweep; one run is built per combination.
params = OrderedDict([
    ('lr', [.01]),
    ('batch_size', [1000, 2000]),
    ('shuffle', [True, False]),
    ('device', ['cpu', 'cuda']),
])

# Weight-decay value, kept outside the grid so it is the same for every run.
weight_decay = 1e-5

#### Training

In [13]:
# Train a fresh network for every hyperparameter combination, recording the
# per-epoch metrics through the RunManager, then persist all results.
m = RunManager()
for run in RunBuilder.get_runs(params):

    # A CUDA run cannot execute on a CPU-only machine; skip it instead of
    # aborting the whole sweep part-way through.
    if str(run.device).startswith('cuda') and not torch.cuda.is_available():
        print(f'Skipping {run}: CUDA is not available')
        continue

    device = torch.device(run.device)
    network = Network().to(device)
    loader = torch.utils.data.DataLoader(train_set, batch_size=run.batch_size, shuffle=run.shuffle)
    optimizer = optim.Adam(network.parameters(), lr=run.lr, weight_decay=weight_decay)

    m.begin_run(run, network, loader)
    for epoch in range(5):
        m.begin_epoch()
        for batch in loader:
            # Move the batch to the same device as the network.
            images = batch[0].to(device)
            labels = batch[1].to(device)
            preds = network(images)
            loss = F.cross_entropy(preds, labels)
            # Clear gradients left over from the previous step before backward.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            m.track_loss(loss)
            m.track_num_correct(preds, labels)

        m.end_epoch()
    m.end_run()
# Writes results.csv and the JSON companion (base name was misspelled 'resuls').
m.save('results')

  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda


  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,device
0,1,1,1.032375,0.606383,9.352872,10.009187,0.01,1000,True,cpu
1,1,2,0.571312,0.779267,9.55718,19.624873,0.01,1000,True,cpu
2,1,3,0.487586,0.819033,9.066132,28.738521,0.01,1000,True,cpu
3,1,4,0.439671,0.8403,9.053366,37.83301,0.01,1000,True,cpu
4,1,5,0.393559,0.855267,8.561755,46.440281,0.01,1000,True,cpu
5,2,1,1.059964,0.58835,7.575051,13.200307,0.01,1000,True,cuda
6,2,2,0.583228,0.774933,7.024577,20.29839,0.01,1000,True,cuda
7,2,3,0.492476,0.817433,6.970921,27.325854,0.01,1000,True,cuda
8,2,4,0.445764,0.834167,6.984803,34.378654,0.01,1000,True,cuda
9,2,5,0.407803,0.849933,6.904956,41.347786,0.01,1000,True,cuda
