In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from IPython.display import display, clear_output
import pandas as pd
import time
import json

from itertools import product
from collections import namedtuple
from collections import OrderedDict

In [2]:
train_set = torchvision.datasets.FashionMNIST(
    root = './data/FashionMNIST',
    train = True,
    download = True,
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
)

In [3]:
class Network(nn.Module): # 输入尺寸28*28
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12,kernel_size=5)
        
        self.fc1 = nn.Linear(in_features = 12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)
        
    def forward(self, t):
        #(1) hidden conv1 layer
        t = F.relu(self.conv1(t))
        t = F.max_pool2d(t, kernel_size = 2, stride = 2)
        
        #(2) hidden conv2 layer
        t = F.relu(self.conv2(t))
        t = F.max_pool2d(t, kernel_size = 2, stride = 2)
        
        #(3) hidden linear layer
        t = F.relu(self.fc1(t.reshape(-1, 12 * 4 * 4)))
        
        #(4) hidden linear layer
        t = F.relu(self.fc2(t))
        
        #(5) output layer
        t = self.out(t)
        t = F.softmax(t, dim = 1)
        return t

In [4]:
class RunBuilder():
    @staticmethod
    def get_runs(params): # params是OrderedDic，为了显示结果按照大小顺序排序
        Run = namedtuple('Run', params.keys())
        runs = []
        # 这个迭代的作用就是把生成的具名元组们全都填到列表里
        # 每个具名元组里的元素们都代表着一套超参数
        for v in product(*params.values()):
            runs.append(Run(*v))
        return runs

In [5]:
class RunManager():
    def __init__(self):
        self.epoch_count = 0
        self.epoch_loss = 0
        self.epoch_num_correct = 0
        self.epoch_start_time = 0
        
        #若干个epoch为1次run
        self.run_params = None
        self.run_count = 0
        self.run_data = []
        self.run_start_time = None
        
        self.network = None
        self.loader = None
        self.tb = None
        
    def begin_run(self, run, network, loader):
        self.run_start_time = time.time()
        
        self.run_params = run
        self.run_count += 1
        
        self.network = network
        self.loader = loader
        self.tb = SummaryWriter(comment=f'-{run}')
        
        images, labels = next(iter(self.loader))
        grid = torchvision.utils.make_grid(images)
        
        self.tb.add_image('images', grid)
        self.tb.add_graph(self.network, images)
    
    def end_run(self):
        self.tb.close()
        self.epoch_count = 0 # ready for the next run
        
    def begin_epoch(self):
        self.epoch_start_time = time.time()
        self.epoch_count += 1
        self.epoch_loss = 0
        self.epoch_num_correct = 0
        
    def end_epoch(self):
        epoch_duration = time.time() - self.epoch_start_time
        run_duration = time.time() - self.run_start_time
        
        loss = self.epoch_loss 
        accuracy = self.epoch_num_correct / len(self.loader.dataset)
        
        self.tb.add_scalar('Loss', loss, self.epoch_count)
        self.tb.add_scalar('Accuracy', accuracy, self.epoch_count)
        
        for name, param in self.network.named_parameters():
            self.tb.add_histogram(name, param, self.epoch_count)
            self.tb.add_histogram(f'{name}.grad', param.grad, self.epoch_count)
            
        results = OrderedDict()
        results["run"] = self.run_count
        results["epoch"] = self.epoch_count
        results['loss'] = loss
        results["accuracy"] = accuracy
        results['epoch duration'] = epoch_duration
        results['run duration'] = run_duration
        
        for k,v in self.run_params._asdict().items(): results[k] = v # _asdict()是namedtuple的方法
        self.run_data.append(results)

        df = pd.DataFrame.from_dict(self.run_data, orient='columns')
        
        # 下两行使Notebook可以更新结果
        clear_output(wait = True)
        display(df)
    
    def track_loss(self, loss):
        self.epoch_loss += loss.item() 
    
    def track_num_correct(self, preds, labels):
        self.epoch_num_correct += self._get_num_correct(preds, labels)
    
    @torch.no_grad()
    def _get_num_correct(self, preds, labels):
        return preds.argmax(dim = 1).eq(labels).sum().item()
        
    # run_data在保存为csv和json格式的文件
    def save(self, fileName):
        pd.DataFrame.from_dict(
            self.run_data
            , orient='columns'
        ).to_csv(f'{fileName}.csv')
 
        with open(f'{fileName}.json', 'w', encoding='utf-8') as f:
            json.dump(self.run_data, f, ensure_ascii=False, indent=4)

In [6]:
params = OrderedDict(
    lr = [.001,.01],
    batch_size = [100,1000],
    shuffle = [True, False]
)

m = RunManager()
for run in RunBuilder.get_runs(params):
    
    network = Network()
    loader = DataLoader(train_set, batch_size = run.batch_size, shuffle = run.shuffle)
    optimizer = optim.Adam(network.parameters(), lr = run.lr)
    
    m.begin_run(run, network, loader)
    for epoch in range(5):
        m.begin_epoch()
        for batch in loader:
            
            images, labels = batch
            preds = network(images)
            loss = F.cross_entropy(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            m.track_loss(loss)
            m.track_num_correct(preds, labels)
        m.end_epoch()
    m.end_run()
m.save('results')

Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle
0,1,1,1074.169397,0.683767,29.664574,32.868892,0.001,100,True
1,1,2,1001.376605,0.795,32.005572,65.097719,0.001,100,True
2,1,3,982.812772,0.825417,39.82456,105.161789,0.001,100,True
3,1,4,972.449978,0.84205,39.645549,145.030713,0.001,100,True
4,1,5,967.436489,0.849883,43.922722,189.246735,0.001,100,True
5,2,1,1074.290648,0.679567,38.594449,39.03955,0.001,100,False
6,2,2,1016.744846,0.76865,38.831541,78.107139,0.001,100,False
7,2,3,997.725079,0.800217,38.472852,116.872038,0.001,100,False
8,2,4,987.624631,0.816733,39.899454,157.063381,0.001,100,False
9,2,5,980.979384,0.82725,45.369592,202.749279,0.001,100,False


## 输出的csv与json文件可在同一目录下找到