# Using TensorBoard with PyTorch

TensorBoard 是一个可视化工具,是一个前端界面,从一个文件中读取数据,然后显示.它使我们能够跟踪和可视化度量标准,比如loss和accuracy,还能可视化网络图等.


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

torch.set_printoptions(linewidth=120)
torch.set_grad_enabled(True)

from torch.utils.tensorboard import SummaryWriter

In [2]:
def get_num_correct(preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

In [18]:
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features = 10)
    
    def forward(self,t):
        
        t = F.relu(self.conv1(t))
        t = F.max_pool2d(t, kernel_size = 2, stride = 2)
        
        t = F.relu(self.conv2(t))
        t = F.max_pool2d(t, kernel_size = 2, stride = 2)
        
        t = t.reshape(-1,12 * 4 *4)
        t = F.relu(self.fc1(t))
        
        t = F.relu(self.fc2(t))
        
        t = self.out(t)
        
        return t

In [6]:
train_set = torchvision.datasets.FashionMNIST(
    root = './data',
    train=True,
    download=True,
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
)

## 开始使用torchboard

In [20]:
tb = SummaryWriter()

network = Network()
images, labels = next(iter(train_loader))
grid = torchvision.utils.make_grid(images)

p = network(images)
tb.add_image('images', grid)
tb.add_graph(network, images)
tb.close()

In [26]:
batch_size = 100
lr = 0.01

network = Network()

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
optimizer = optim.Adam(network.parameters(), lr=lr)

comment = f' batch_size = {batch_size} lr = {lr}' #注释字符串


tb = SummaryWriter(comment = comment)
tb.add_image('images', grid)
tb.add_graph(network, images)

for epoch in range(10):
    total_loss = 0
    total_correct = 0
    for batch in train_loader:
        images, labels = batch
        preds = network(images)
        loss = F.cross_entropy(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_correct += get_num_correct(preds, labels)
        total_loss += loss.item() * batch_size
    
    tb.add_scalar("Loss", total_loss, epoch)
    tb.add_scalar('Number Correct', total_correct, epoch)
    tb.add_scalar('Accuracy', total_correct/len(train_set), epoch)
    
    tb.add_histogram('conv1.bias', network.conv1.bias, epoch)
    tb.add_histogram('conv1.weight',network.conv1.weight, epoch)
    tb.add_histogram('conv1.weight.grad', network.conv1.weight.grad, epoch)
                  
    
    print('epoch:', epoch, 'total correct:', total_correct, 'total loss:',total_loss, 'accuracy:', total_correct/len(train_set))

                  
tb.close()

epoch: 0 total correct: 53460 total loss: 176.29366613924503 accuracy: 0.891
epoch: 1 total correct: 53647 total loss: 174.04086027294397 accuracy: 0.8941166666666667
epoch: 2 total correct: 53689 total loss: 171.92572508752346 accuracy: 0.8948166666666667
epoch: 3 total correct: 53701 total loss: 171.8174800425768 accuracy: 0.8950166666666667
epoch: 4 total correct: 53728 total loss: 170.1576620489359 accuracy: 0.8954666666666666
epoch: 5 total correct: 53842 total loss: 169.22006750106812 accuracy: 0.8973666666666666
epoch: 6 total correct: 53788 total loss: 169.22142618894577 accuracy: 0.8964666666666666
epoch: 7 total correct: 53645 total loss: 174.30822359025478 accuracy: 0.8940833333333333
epoch: 8 total correct: 53974 total loss: 165.42745631933212 accuracy: 0.8995666666666666
epoch: 9 total correct: 53964 total loss: 164.84228514879942 accuracy: 0.8994


# RunBuilder类

In [54]:
from collections import OrderedDict
from collections import namedtuple
from itertools import product
import pandas as pd
import time

In [34]:
#构建运行的参数集
class RunBuilder():
    @staticmethod
    def get_runs(params):

        Run = namedtuple('Run', params.keys())

        runs = []
        for v in product(*params.values()):
            runs.append(Run(*v))

        return runs


In [35]:
params = OrderedDict(
    lr = [.01, .001]
    ,batch_size = [1000, 10000]
)


In [43]:
for run in RunBuilder.get_runs(params):
    comment = f'{run}'
    print(comment)


Run(lr=0.01, batch_size=1000)
Run(lr=0.01, batch_size=10000)
Run(lr=0.001, batch_size=1000)
Run(lr=0.001, batch_size=10000)


# 构建RunManager类

In [68]:
class RunManager():
    def __init__(self):
        
        self.epoch_count = 0
        self.epoch_loss = 0
        self.epoch_num_correct = 0
        self.epoch_start_time = None
        
        self.run_params = None
        self.run_count = 0
        self.run_data = []
        self.run_start_time = None
        
        self.network = None
        self.loader = None
        self.tb = None
        
    def begin_run(self, run, network, loader):
        self.run_start_time = time.time()
        
        self.run_params = run
        self.run_count += 1
        
        self.network = network
        self.loader = loader
        self.tb = SummaryWriter(comment=f'-{run}')
        
        images, labels = next(iter(self.loader))
        grid = torchvision.utils.make_grid(images)
        
        self.tb.add_image('images', grid)
        self.tb.add_graph(self.network, images)
    
    def end_run(self):
        self.tb.close()
        self.epoch_count = 0
        
    def end_epoch(self):
        self.tb.close()
        self.epoch_count = 0
        
        
    
    def begin_epoch(self):
        self.epoch_start_time = time.time()
        
        self.epoch_count += 1
        self.epoch_loss = 0
        self.epoch_num_correct = 0
        
    def end_epoch(self):
        epoch_duration = time.time() - self.epoch_start_time
        run_duration = time.time() - self.run_start_time
        
        loss = self.epoch_loss / len(self.loader.dataset)
        accuracy = self.epoch_num_correct / len(self.loader.dataset)
        
        self.tb.add_scalar('Loss', self.epoch_count)
        self.tb.add_scalar('Accuracy', accuracy, self.epoch_count)
        
        for name, param in self.network.named_parameters():
            self.tb.add_histogram(name,param, self.epoch_count)
            self.tb.add_histogram(f'{name}.grad', param.grad, self.epoch_count)
            
        result = OrderedDict()
        
        results = OrderedDict()
        results['runs'] = self.run_count
        results['epoch'] = self.epoch_count
        results['loss'] = loss
        results['accuracy'] = accuracy
        results['epoch duration'] = epoch_duration
        results['run duration'] = run_duration
        for k,v in self.run_params._asdict().items() : results[k] = v
        self.run_data.append(results)
        df = pd.DataFrame.from_dict(self.run_data, orient='columns')
        
        #clear_output(wait = True)
        display(df)
        
    def track_loss(self, loss):
        self.epoch_loss += loss.item() * self.loader.batch_size
        
    def track_num_correct(self, preds, labels):
        self.epoch_num_correct += self._get_num_correct(preds, labels)
            
    @torch.no_grad()
    def _get_num_correct(self, preds, labels):
        return preds.argmax(dim = 1).eq(labels).sum().item()
        
    def save(self, filename):
            
        pd.DataFrame.from_dict(
            self.run_data,
            orient = 'columns',
        ).to_csv(f'{filename}.csv')
            
        with open(f'{fileName}.son', 'w', encoding = 'utf-8') as f:
            json.dump(self.run_data, f, ensure_ascii = False, indent = 4)
            

In [70]:
parms = OrderedDict(
    lr = [.01],
    batch_size = [1000, 2000]
)

m = RunManager()
for run in RunBuilder.get_runs(params):
    
    network = Network()
    loader = torch.utils.data.DataLoader(train_set, batch_size = run.batch_size)
    optimizer = optim.Adam(network.parameters(), lr = run.lr)
    
    m.begin_run(run, network, loader)
    
    for epoch in range(5):
        m.begin_epoch()
        for batch in loader:
            
            images, labels = batch
            preds = network(images)
            loss = F.cross_entropy(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            m.track_loss(loss)
            m.track_num_correct(preds, labels)
            
        m.end_epoch()
    m.end_run()
m.close()
m.save('results')
            

Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000
6,2,2,1.848424,0.316583,5.481817,15.384301,0.01,10000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000
6,2,2,1.848424,0.316583,5.481817,15.384301,0.01,10000
7,2,3,1.420566,0.47215,5.696931,21.131911,0.01,10000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000
6,2,2,1.848424,0.316583,5.481817,15.384301,0.01,10000
7,2,3,1.420566,0.47215,5.696931,21.131911,0.01,10000
8,2,4,1.131756,0.576083,5.696314,26.87729,0.01,10000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000
6,2,2,1.848424,0.316583,5.481817,15.384301,0.01,10000
7,2,3,1.420566,0.47215,5.696931,21.131911,0.01,10000
8,2,4,1.131756,0.576083,5.696314,26.87729,0.01,10000
9,2,5,1.001418,0.620767,5.451691,32.378489,0.01,10000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000
6,2,2,1.848424,0.316583,5.481817,15.384301,0.01,10000
7,2,3,1.420566,0.47215,5.696931,21.131911,0.01,10000
8,2,4,1.131756,0.576083,5.696314,26.87729,0.01,10000
9,2,5,1.001418,0.620767,5.451691,32.378489,0.01,10000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000
6,2,2,1.848424,0.316583,5.481817,15.384301,0.01,10000
7,2,3,1.420566,0.47215,5.696931,21.131911,0.01,10000
8,2,4,1.131756,0.576083,5.696314,26.87729,0.01,10000
9,2,5,1.001418,0.620767,5.451691,32.378489,0.01,10000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000
6,2,2,1.848424,0.316583,5.481817,15.384301,0.01,10000
7,2,3,1.420566,0.47215,5.696931,21.131911,0.01,10000
8,2,4,1.131756,0.576083,5.696314,26.87729,0.01,10000
9,2,5,1.001418,0.620767,5.451691,32.378489,0.01,10000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000
6,2,2,1.848424,0.316583,5.481817,15.384301,0.01,10000
7,2,3,1.420566,0.47215,5.696931,21.131911,0.01,10000
8,2,4,1.131756,0.576083,5.696314,26.87729,0.01,10000
9,2,5,1.001418,0.620767,5.451691,32.378489,0.01,10000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000
6,2,2,1.848424,0.316583,5.481817,15.384301,0.01,10000
7,2,3,1.420566,0.47215,5.696931,21.131911,0.01,10000
8,2,4,1.131756,0.576083,5.696314,26.87729,0.01,10000
9,2,5,1.001418,0.620767,5.451691,32.378489,0.01,10000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000
6,2,2,1.848424,0.316583,5.481817,15.384301,0.01,10000
7,2,3,1.420566,0.47215,5.696931,21.131911,0.01,10000
8,2,4,1.131756,0.576083,5.696314,26.87729,0.01,10000
9,2,5,1.001418,0.620767,5.451691,32.378489,0.01,10000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000
6,2,2,1.848424,0.316583,5.481817,15.384301,0.01,10000
7,2,3,1.420566,0.47215,5.696931,21.131911,0.01,10000
8,2,4,1.131756,0.576083,5.696314,26.87729,0.01,10000
9,2,5,1.001418,0.620767,5.451691,32.378489,0.01,10000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000
6,2,2,1.848424,0.316583,5.481817,15.384301,0.01,10000
7,2,3,1.420566,0.47215,5.696931,21.131911,0.01,10000
8,2,4,1.131756,0.576083,5.696314,26.87729,0.01,10000
9,2,5,1.001418,0.620767,5.451691,32.378489,0.01,10000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000
6,2,2,1.848424,0.316583,5.481817,15.384301,0.01,10000
7,2,3,1.420566,0.47215,5.696931,21.131911,0.01,10000
8,2,4,1.131756,0.576083,5.696314,26.87729,0.01,10000
9,2,5,1.001418,0.620767,5.451691,32.378489,0.01,10000


Unnamed: 0,runs,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size
0,1,1,0.885079,0.673167,5.582882,6.052915,0.01,1000
1,1,2,0.491405,0.816133,5.572856,11.675041,0.01,1000
2,1,3,0.408394,0.849033,5.580238,17.303495,0.01,1000
3,1,4,0.364318,0.866733,5.550674,22.904758,0.01,1000
4,1,5,0.337685,0.8763,5.597674,28.553536,0.01,1000
5,2,1,2.070852,0.233983,5.743619,9.850224,0.01,10000
6,2,2,1.848424,0.316583,5.481817,15.384301,0.01,10000
7,2,3,1.420566,0.47215,5.696931,21.131911,0.01,10000
8,2,4,1.131756,0.576083,5.696314,26.87729,0.01,10000
9,2,5,1.001418,0.620767,5.451691,32.378489,0.01,10000


AttributeError: 'RunManager' object has no attribute 'close'