In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from collections import OrderedDict
from collections import namedtuple
from itertools import product

from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

from IPython.display import display, clear_output
import pandas as pd
import time
import json

torch.set_printoptions(linewidth=120)
torch.set_grad_enabled(True)


<torch.autograd.grad_mode.set_grad_enabled at 0x20c33787850>

## 添加RunManager
### 提取tensorboard的调用 允许添加额外的功能

In [2]:
class RunManager():
    """Collect, log and persist per-epoch training metrics over a series of runs.

    Intended call order, per run:
    ``begin_run`` -> (``begin_epoch`` -> ``track_loss`` / ``track_num_correct``
    per batch -> ``end_epoch``) for every epoch -> ``end_run``.
    ``save`` writes everything recorded so far to disk.
    """

    def __init__(self):
        # Per-epoch accumulators; reset by begin_epoch().
        self.epoch_count = 0
        self.epoch_loss = 0
        self.epoch_num_correct = 0
        self.epoch_start_time = None

        # Per-run state; run_data gathers one record per epoch across ALL runs.
        self.run_params = None
        self.run_count = 0
        self.run_data = []
        self.run_start_time = None

        # Objects bound for the duration of a run by begin_run().
        self.network = None
        self.loader = None
        self.tb = None

    def begin_run(self, run, network, loader):
        """Start a new run: bind its objects and open a TensorBoard writer.

        Args:
            run: Hyper-parameter record for this run. It is embedded in the
                writer's comment and later read through ``run._asdict()``.
            network: Model being trained; its graph is logged to TensorBoard.
            loader: Data loader; its first batch is logged as an image grid
                and used as the example input for the graph.
        """
        self.run_start_time = time.time()

        self.run_params = run
        self.run_count += 1

        self.network = network
        self.loader = loader
        self.tb = SummaryWriter(comment=f'--{run}')

        # One batch is pulled only to give TensorBoard sample images and a
        # graph-tracing input.
        images, labels = next(iter(self.loader))
        grid = torchvision.utils.make_grid(images)

        self.tb.add_image('images', grid)
        self.tb.add_graph(self.network, images)

    def end_run(self):
        """Close the TensorBoard writer and reset the epoch counter."""
        self.tb.close()
        self.epoch_count = 0  # get ready for the next run

    def begin_epoch(self):
        """Start a new epoch: bump the epoch counter and zero the accumulators."""
        self.epoch_start_time = time.time()

        self.epoch_count += 1
        self.epoch_loss = 0
        self.epoch_num_correct = 0

    def end_epoch(self):
        """Finish the epoch: log to TensorBoard, record results, show the table."""
        # Epoch duration, and total time elapsed since the run started.
        epoch_duration = time.time() - self.epoch_start_time
        run_duration = time.time() - self.run_start_time

        # Epoch loss and accuracy. Both are normalised by the dataset size,
        # i.e. this presumes every sample was tracked once during the epoch.
        loss = self.epoch_loss / len(self.loader.dataset)
        accuracy = self.epoch_num_correct / len(self.loader.dataset)

        # Log this epoch's metrics; epoch_count is the TensorBoard step.
        self.tb.add_scalar('loss', loss, self.epoch_count)
        self.tb.add_scalar('Accuracy', accuracy, self.epoch_count)

        self._log_parameter_histograms()

        results = OrderedDict()
        results["run"] = self.run_count
        results["epoch"] = self.epoch_count
        results["loss"] = loss
        results["accuracy"] = accuracy
        results["epoch_duration"] = epoch_duration
        results["run_duration"] = run_duration

        # Append the run's hyper-parameters so every row is self-describing.
        for k, v in self.run_params._asdict().items():
            results[k] = v

        self.run_data.append(results)

        # Refresh the live results table in the notebook output.
        clear_output(wait=True)
        display(self._results_frame())

    def _log_parameter_histograms(self):
        """Log a histogram of each parameter and, when present, its gradient."""
        for name, param in self.network.named_parameters():
            self.tb.add_histogram(name, param, self.epoch_count)
            # A parameter that is frozen or did not take part in the backward
            # pass has grad == None, which add_histogram cannot handle.
            if param.grad is not None:
                self.tb.add_histogram(f'{name}.grad', param.grad, self.epoch_count)

    def _results_frame(self):
        """Return all recorded epoch results as a DataFrame (one row per epoch)."""
        return pd.DataFrame(self.run_data)

    def track_loss(self, loss, batch):
        """Add one batch's loss to the epoch total.

        The scalar ``loss`` is scaled by the batch size (``batch[0].shape[0]``);
        this recovers the summed loss when ``loss`` is a per-batch mean.
        """
        self.epoch_loss += loss.item() * batch[0].shape[0]

    def track_num_correct(self, preds, labels):
        """Add the number of correct predictions in this batch to the epoch total."""
        self.epoch_num_correct += self._get_num_correct(preds, labels)

    @torch.no_grad()
    def _get_num_correct(self, preds, labels):
        """Count rows of ``preds`` whose argmax over dim 1 equals ``labels``."""
        return preds.argmax(dim=1).eq(labels).sum().item()

    def save(self, fileName):
        """Write all recorded results to ``{fileName}.csv`` and ``{fileName}.json``."""
        self._results_frame().to_csv(f'{fileName}.csv')

        with open(f'{fileName}.json', 'w', encoding='utf-8') as f:
            json.dump(self.run_data, f, ensure_ascii=False, indent=4)

In [3]:
class RunBuilder():
    """Expand a hyper-parameter grid into individual run configurations."""

    @staticmethod
    def get_runs(params):
        """Return one ``Run`` namedtuple per combination of parameter values.

        Args:
            params: Mapping of parameter name -> iterable of candidate values.
                The keys become the namedtuple's fields, in mapping order.

        Returns:
            list: ``Run`` instances covering the Cartesian product of all
            value lists, in ``itertools.product`` order.
        """
        Run = namedtuple('Run', params.keys())
        return [Run(*combo) for combo in product(*params.values())]

In [4]:
class Network(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    Takes single-channel images and returns 10 raw scores per image (no
    softmax is applied). The flatten size of 12*4*4 matches 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 6 -> 12 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=5)

        # Classifier head operating on the flattened 12*4*4 feature vector.
        self.fc1 = nn.Linear(12 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 60)
        self.out = nn.Linear(60, 10)

    def forward(self, t):
        """Run the forward pass and return the output-layer scores for ``t``."""
        # Stage 1: conv -> ReLU -> 2x2 max-pool.
        stage1 = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2, stride=2)

        # Stage 2: conv -> ReLU -> 2x2 max-pool.
        stage2 = F.max_pool2d(F.relu(self.conv2(stage1)), kernel_size=2, stride=2)

        # Flatten to rows of 12*4*4 features for the linear layers.
        flat = stage2.reshape(-1, 12 * 4 * 4)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))

        return self.out(hidden)


In [5]:
# FashionMNIST training split stored under ./data/ (download=True);
# each image is converted to a tensor by the transform pipeline.
train_set = torchvision.datasets.FashionMNIST(
    root='./data/',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [6]:
# Number of passes over the training set per run.
NUM_EPOCHS = 5

# Hyper-parameter grid: one run is executed for every combination of values.
params = OrderedDict(
    lr=[.01],
    batch_size=[100, 1000, 10000],
    # shuffle=[True, False],
    num_workers=[0, 1, 2, 4],
)

m = RunManager()

for run in RunBuilder.get_runs(params):
    # Fresh model, loader and optimizer for every configuration.
    network = Network()
    loader = DataLoader(train_set, batch_size=run.batch_size, num_workers=run.num_workers)
    optimizer = optim.Adam(network.parameters(), lr=run.lr)

    m.begin_run(run, network, loader)
    try:
        for epoch in range(NUM_EPOCHS):
            m.begin_epoch()
            for batch in loader:
                images, labels = batch

                preds = network(images)
                loss = F.cross_entropy(preds, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                m.track_loss(loss, batch)
                m.track_num_correct(preds, labels)
            m.end_epoch()
    finally:
        # Always close the TensorBoard writer, even if training raises or the
        # cell is interrupted part-way through a run.
        m.end_run()

# Persist the results of all runs to results.csv / results.json.
m.save('results')
    


Unnamed: 0,run,epoch,loss,accuracy,epoch_duration,run_duration,lr,batch_size,num_workers
0,1,1,0.563614,0.78785,20.120712,23.118753,0.01,100,0
1,1,2,0.378374,0.860317,17.183592,40.577524,0.01,100,0
2,1,3,0.347882,0.870767,16.404134,57.090365,0.01,100,0
3,1,4,0.329626,0.877183,16.585684,73.811202,0.01,100,0
4,1,5,0.321449,0.8805,17.131477,91.069356,0.01,100,0
5,2,1,0.541784,0.795433,11.24418,14.209212,0.01,100,1
6,2,2,0.37364,0.862033,11.261182,25.580102,0.01,100,1
7,2,3,0.345288,0.872417,10.949377,36.635213,0.01,100,1
8,2,4,0.33039,0.87875,10.703329,47.484153,0.01,100,1
9,2,5,0.319582,0.882317,10.250116,57.829017,0.01,100,1
