In [22]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

torch.set_printoptions(linewidth=250)  # allow 250-character lines so printed tensors wrap less
torch.set_grad_enabled(mode=True)      # explicitly (re-)enable autograd for the whole notebook

import time
from torch.utils.tensorboard import SummaryWriter
from collections import OrderedDict
from collections import namedtuple
from itertools import product
import pandas as pd
import json
from torch.utils.data import DataLoader
from IPython.display import clear_output
from IPython.display import display


import math
import matplotlib.pyplot as plt

# Raw FashionMNIST training split, converted to tensors only (no normalization).
# Use the same root as the other dataset cells ('./data'): torchvision stores the
# files under <root>/FashionMNIST, so the previous root './data/FashionMNIST'
# caused a second download into ./data/FashionMNIST/FashionMNIST.
train_set = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

#### **Using BatchNorm**

In [4]:
torch.manual_seed(50)  # seed immediately before construction so the initial weights are reproducible

# Baseline CNN without batch normalization.
network1 = nn.Sequential(
    # stage 1: conv 5x5 (1 -> 6 channels), ReLU, 2x2 max-pool
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # stage 2: conv 5x5 (6 -> 12 channels), ReLU, 2x2 max-pool
    nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # classifier head; 12 * 4 * 4 assumes 28x28 inputs (28 -> 24 -> 12 -> 8 -> 4)
    nn.Flatten(start_dim=1),
    nn.Linear(in_features=12 * 4 * 4, out_features=120),
    nn.ReLU(),
    nn.Linear(in_features=120, out_features=60),
    nn.ReLU(),
    nn.Linear(in_features=60, out_features=10),
)

In [5]:
torch.manual_seed(50)  # seed immediately before construction so the initial weights are reproducible

# Same CNN as the baseline, plus two batch-normalization layers.
network2 = nn.Sequential(
    # stage 1: conv 5x5 (1 -> 6 channels), ReLU, 2x2 max-pool, then BatchNorm over the 6 channels
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.BatchNorm2d(6),
    # stage 2: conv 5x5 (6 -> 12 channels), ReLU, 2x2 max-pool
    nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # classifier head; 12 * 4 * 4 assumes 28x28 inputs (28 -> 24 -> 12 -> 8 -> 4)
    nn.Flatten(start_dim=1),
    nn.Linear(in_features=12 * 4 * 4, out_features=120),
    nn.ReLU(),
    nn.BatchNorm1d(120),  # normalizes the 120 hidden features
    nn.Linear(in_features=120, out_features=60),
    nn.ReLU(),
    nn.Linear(in_features=60, out_features=10),
)

In [6]:
# FashionMNIST training split, converted to tensors only (no normalization).
# This is the 'not_normal' variant; the dataset mean/std below are computed from it.
train_set = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data\FashionMNIST\raw\train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./data\FashionMNIST\raw\train-images-idx3-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw



In [7]:
# Load the whole training set as a single batch so the statistics are exact
# over every pixel of every image.
loader = DataLoader(train_set, batch_size=len(train_set), num_workers=1)
data = next(iter(loader))
mean, std = data[0].mean(), data[0].std()
mean, std

(tensor(0.2860), tensor(0.3530))

In [8]:
# Same training split, additionally standardized with the dataset-wide mean/std
# computed above, so pixel values end up with roughly zero mean and unit std.
train_set_normal = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]),
)

In [9]:
# Dataset variants, looked up by name through the `trainset` run parameter.
trainsets = dict(
    not_normal=train_set,
    normal=train_set_normal,
)

In [12]:
# Network variants, looked up by name through the `network` run parameter.
networks = dict(
    not_batch_normal=network1,
    batch_normal=network2,
)

In [None]:
# comment=f'-{run}'

In [15]:
class RunManager():
    """Bookkeeping for a sequence of training runs.

    For every run it opens a TensorBoard ``SummaryWriter``. For every epoch it
    accumulates loss and correct-prediction counts, writes scalars and
    parameter/gradient histograms to TensorBoard, appends one result row to
    ``run_data`` and re-displays the accumulated results table.
    """

    def __init__(self):
        # Per-epoch state (reset in begin_epoch).
        self.epoch_count = 0
        self.epoch_loss = 0
        self.epoch_num_correct = 0
        self.epoch_start_time = 0

        # Per-run state; run_data holds one OrderedDict per finished epoch, across all runs.
        self.run_params = None
        self.run_count = 0
        self.run_data = []
        self.run_start_time = None

        self.network = None
        self.loader = None
        self.tb = None

    def begin_run(self, run, network, loader):
        """Start a run: remember its objects, open a writer, log a sample batch and the graph.

        Args:
            run: the run's parameters. end_epoch calls ``run._asdict()``, so a
                namedtuple (such as the ones RunBuilder.get_runs returns) is expected.
                Its optional ``device`` field decides where the graph-tracing batch goes.
            network: the model being trained.
            loader: the DataLoader feeding the model; its first batch is logged.
        """
        self.run_start_time = time.time()

        self.run_params = run
        self.run_count += 1

        self.network = network
        self.loader = loader
        # Suffix the log directory with the run parameters so runs can be told apart in TensorBoard.
        self.tb = SummaryWriter(comment=f'-{run}')

        images, labels = next(iter(self.loader))
        grid = torchvision.utils.make_grid(images)

        self.tb.add_image('image', grid)
        # Graph tracing runs a forward pass, so the batch must be on the network's device.
        self.tb.add_graph(self.network, images.to(getattr(run, 'device', 'cpu')))

    def end_run(self):
        """Close this run's writer and reset the epoch counter for the next run."""
        self.tb.close()
        self.epoch_count = 0

    def begin_epoch(self):
        """Start timing a new epoch and reset its loss / correct-count accumulators."""
        self.epoch_start_time = time.time()

        self.epoch_count += 1
        self.epoch_loss = 0
        self.epoch_num_correct = 0

    def end_epoch(self):
        """Log the finished epoch to TensorBoard, record a result row and display all rows."""
        epoch_duration = time.time() - self.epoch_start_time
        run_duration = time.time() - self.run_start_time

        # Per-sample averages over the loader's whole dataset.
        loss = self.epoch_loss / len(self.loader.dataset)
        accuracy = self.epoch_num_correct / len(self.loader.dataset)

        self.tb.add_scalar('Loss', loss, self.epoch_count)
        self.tb.add_scalar('Accuracy', accuracy, self.epoch_count)

        for name, param in self.network.named_parameters():
            self.tb.add_histogram(name, param, self.epoch_count)
            # A parameter that received no gradient (frozen, or unused in the forward
            # pass) has grad None, which add_histogram cannot log.
            if param.grad is not None:
                self.tb.add_histogram(f'{name}.grad', param.grad, self.epoch_count)

        results = OrderedDict()
        results["run"] = self.run_count
        results["epoch"] = self.epoch_count
        results["loss"] = loss
        results["accuracy"] = accuracy
        results["epoch duration"] = epoch_duration
        results["run duration"] = run_duration

        # Append every run parameter as its own column.
        for k, v in self.run_params._asdict().items():
            results[k] = v

        self.run_data.append(results)
        df = pd.DataFrame.from_dict(self.run_data, orient='columns')

        # Replace the previous table in the cell output instead of stacking a new one.
        clear_output(wait=True)
        display(df)

    def track_loss(self, loss, batch_size=None):
        """Add one batch's loss, scaled by the batch size, to the epoch total.

        Args:
            loss: a scalar tensor; assumed to be the mean loss over the batch
                (as with F.cross_entropy's default reduction) -- confirm for other losses.
            batch_size: number of samples in that batch. Defaults to
                ``self.loader.batch_size``, which over-counts a final partial batch
                when the dataset size is not a multiple of the batch size; pass the
                actual size (e.g. ``len(labels)``) to get an exact epoch loss.
        """
        if batch_size is None:
            batch_size = self.loader.batch_size
        self.epoch_loss += loss.item() * batch_size

    def track_num_correct(self, preds, labels):
        """Add the number of correct predictions in one batch to the epoch total."""
        self.epoch_num_correct += self._get_num_correct(preds, labels)

    @torch.no_grad()
    def _get_num_correct(self, preds, labels):
        """Count rows where the arg-max over dim 1 of ``preds`` equals ``labels``."""
        return preds.argmax(dim=1).eq(labels).sum().item()

    def save(self, fileName):
        """Write all recorded result rows to ``{fileName}.csv`` and ``{fileName}.json``."""
        pd.DataFrame.from_dict(
            self.run_data,
            orient="columns"
        ).to_csv(f'{fileName}.csv')

        with open(f'{fileName}.json', 'w', encoding='utf-8') as f:
            json.dump(self.run_data, f, ensure_ascii=False, indent=4)
               
        

In [11]:
class RunBuilder():
    @staticmethod   # callable on the class itself, no instance required
    def get_runs(params):
        """Expand a mapping of parameter name -> list of values into every combination.

        Returns a list of ``Run`` namedtuples whose fields are the mapping's keys
        (in order), one per element of the Cartesian product of the value lists.
        """
        Run = namedtuple('Run', params.keys())
        return [Run(*combo) for combo in product(*params.values())]

In [17]:
import os
# Without this line the training below would not run on the original machine
# (presumably a duplicate OpenMP runtime clash -- environment-specific workaround).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

NUM_EPOCHS = 20  # epochs per run

params = OrderedDict(
    lr = [.01]
    , batch_size = [1000]
    , num_workers = [1]
    # Fall back to CPU so the notebook still runs on machines without a GPU.
    , device = ['cuda' if torch.cuda.is_available() else 'cpu']
    , trainset = ['normal']
    , network = list(networks.keys())
)

m = RunManager()
for run in RunBuilder.get_runs(params):

    device = torch.device(run.device)
    # NOTE: the network object from `networks` is moved and trained in place, so two
    # runs sharing a network key would continue from the previous run's weights.
    network = networks[run.network].to(device)
    loader = DataLoader(trainsets[run.trainset], batch_size=run.batch_size, num_workers=run.num_workers)
    optimizer = optim.Adam(network.parameters(), lr=run.lr)

    m.begin_run(run, network, loader)
    for _ in range(NUM_EPOCHS):
        m.begin_epoch()
        for batch in loader:

            images = batch[0].to(device)
            labels = batch[1].to(device)
            preds = network(images)
            loss = F.cross_entropy(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            m.track_loss(loss)
            m.track_num_correct(preds, labels)

        m.end_epoch()
    m.end_run()

m.save('results')

Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,num_workers,device,trainset,network
0,1,1,0.50796,0.81685,21.426928,26.960624,0.01,1000,1,cuda,normal,network1
1,1,2,0.360259,0.866917,21.269679,48.739152,0.01,1000,1,cuda,normal,network1
2,1,3,0.326702,0.8782,21.13218,70.358601,0.01,1000,1,cuda,normal,network1
3,1,4,0.304367,0.88595,20.866048,91.753371,0.01,1000,1,cuda,normal,network1
4,1,5,0.295008,0.8896,21.034717,113.163648,0.01,1000,1,cuda,normal,network1
5,1,6,0.282299,0.89465,21.812103,135.511206,0.01,1000,1,cuda,normal,network1
6,1,7,0.266932,0.900167,22.32203,158.338163,0.01,1000,1,cuda,normal,network1
7,1,8,0.259916,0.901667,21.835737,180.586696,0.01,1000,1,cuda,normal,network1
8,1,9,0.248863,0.904717,21.59521,202.676947,0.01,1000,1,cuda,normal,network1
9,1,10,0.249999,0.904967,21.29514,224.463067,0.01,1000,1,cuda,normal,network1


In [23]:
pd.DataFrame(m.run_data).sort_values(by='accuracy', ascending=False)  # every epoch of every run, highest training accuracy first

Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,num_workers,device,trainset,network
39,2,20,0.172305,0.934067,9.858526,217.178738,0.01,1000,1,cuda,normal,network2
38,2,19,0.179563,0.9319,10.089246,207.211976,0.01,1000,1,cuda,normal,network2
36,2,17,0.183833,0.9296,9.907943,187.232933,0.01,1000,1,cuda,normal,network2
35,2,16,0.187987,0.928733,9.820056,176.850324,0.01,1000,1,cuda,normal,network2
37,2,18,0.185107,0.928633,9.582589,197.016214,0.01,1000,1,cuda,normal,network2
34,2,15,0.19145,0.92675,10.768962,166.917853,0.01,1000,1,cuda,normal,network2
33,2,14,0.200101,0.9241,10.204592,155.934592,0.01,1000,1,cuda,normal,network2
32,2,13,0.205803,0.921983,10.22874,145.623721,0.01,1000,1,cuda,normal,network2
31,2,12,0.20828,0.921417,11.203506,135.273883,0.01,1000,1,cuda,normal,network2
19,1,20,0.213486,0.918267,9.838089,432.377741,0.01,1000,1,cuda,normal,network1
