<a href="https://colab.research.google.com/github/Ryan-Lily/python-learning-notes/blob/master/TensorBoard%E5%8F%AF%E8%A7%86%E5%8C%96.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import sys
sys.path.append('..')

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter

from torchvision.datasets import CIFAR10
from torchvision import transforms as tfs
from datetime import datetime

from utils import resnet

# Apply data augmentation to the training images
def train_tf(x):
    """Run the training-time augmentation pipeline on a single image.

    The steps are applied in order: resize to 120, random horizontal flip,
    random 96-pixel crop, colour jitter (brightness/contrast/hue of 0.5),
    conversion to a tensor, and per-channel normalisation with mean 0.5 and
    std 0.5.

    Args:
        x: The image handed over by the dataset (presumably a PIL image, as
            torchvision datasets provide -- confirm against the dataset used).

    Returns:
        The augmented, normalised image tensor.
    """
    steps = [
        tfs.Resize(120),
        tfs.RandomHorizontalFlip(),
        tfs.RandomCrop(96),
        tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
        tfs.ToTensor(),
        tfs.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return tfs.Compose(steps)(x)

def test_tf(x):
    """Run the deterministic evaluation-time preprocessing on a single image.

    Unlike ``train_tf`` there is no random augmentation: the image is resized
    to 96, converted to a tensor and normalised per channel with mean 0.5 and
    std 0.5.

    Args:
        x: The image handed over by the dataset (presumably a PIL image, as
            torchvision datasets provide -- confirm against the dataset used).

    Returns:
        The preprocessed image tensor.
    """
    steps = [
        tfs.Resize(96),
        tfs.ToTensor(),
        tfs.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return tfs.Compose(steps)(x)


# CIFAR-10 splits stored under ./data (downloaded on first use). The training
# split gets the random augmentation, the test split the deterministic one.
train_set = CIFAR10(
    './data',
    train=True,
    transform=train_tf,
    download=True,
)
test_set = CIFAR10(
    './data',
    train=False,
    transform=test_tf,
    download=True,
)

# Batches of 256; only the training data is shuffled.
train_data = torch.utils.data.DataLoader(
    train_set,
    batch_size=256,
    shuffle=True,
)
test_data = torch.utils.data.DataLoader(
    test_set,
    batch_size=256,
    shuffle=False,
)

# NOTE(review): resnet comes from the project-local utils module; the
# arguments are presumably (input channels, number of classes) -- confirm.
net = resnet(3, 10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=0.1,
    weight_decay=1e-4,
)

In [None]:
# TensorBoard event writer used by the training loop below to log the
# per-epoch 'Loss' and 'Acc' scalars. No log directory is given, so the
# library's default run directory is used (see the SummaryWriter docs).
writer = SummaryWriter()

def get_acc(output, label):
    """Return the fraction of rows in ``output`` whose arg-max equals ``label``.

    Args:
        output: Score tensor; the prediction for each row is the index of the
            maximum along dimension 1.
        label: Tensor of target class indices, one per row of ``output``.

    Returns:
        A Python float: number of correct predictions divided by
        ``output.shape[0]``.
    """
    predicted = output.max(dim=1).indices
    hits = (predicted == label).sum().item()
    batch_size = output.shape[0]
    return hits / batch_size

# Pick the device once; the model and every batch are moved to the same place,
# which replaces the duplicated cuda/cpu branches (and the deprecated
# Variable wrapper) in each loop.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)

prev_time = datetime.now()

for epoch in range(30):
    # ---- training pass ----
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
        im = im.to(device)
        label = label.to(device)

        output = net(im)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += get_acc(output, label)

    # Elapsed time is measured right after the training pass (before
    # validation). total_seconds() is used because timedelta.seconds
    # silently drops whole days.
    cur_time = datetime.now()
    elapsed = int((cur_time - prev_time).total_seconds())
    h, remainder = divmod(elapsed, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)

    # ---- validation pass ----
    valid_loss = 0
    valid_acc = 0
    net = net.eval()
    # No gradients are needed for evaluation; skipping them saves memory.
    with torch.no_grad():
        for im, label in test_data:
            im = im.to(device)
            label = label.to(device)

            output = net(im)
            loss = criterion(output, label)
            valid_loss += loss.item()
            valid_acc += get_acc(output, label)

    # Per-batch sums are averaged over the number of batches of the loader
    # they came from: training over train_data, validation over test_data.
    n_train = len(train_data)
    n_valid = len(test_data)
    avg_train_loss = train_loss / n_train
    avg_train_acc = train_acc / n_train
    avg_valid_loss = valid_loss / n_valid
    avg_valid_acc = valid_acc / n_valid

    epoch_str = "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f" % (
        epoch, avg_train_loss, avg_train_acc, avg_valid_loss, avg_valid_acc)
    prev_time = cur_time

    writer.add_scalars('Loss', {'train': avg_train_loss, 'valid': avg_valid_loss}, epoch)
    writer.add_scalars('Acc', {'train': avg_train_acc, 'valid': avg_valid_acc}, epoch)
    # Separator added so the accuracy value does not run into "Time".
    print(epoch_str + ", " + time_str)

# Push any buffered events to disk so TensorBoard sees the final epochs.
writer.flush()