In [1]:
import torch
import torchvision.datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from torch.nn import Conv2d, MaxPool2d, Linear, Flatten
from torch.utils.tensorboard import SummaryWriter
from torch.nn import L1Loss

In [2]:
dataset = torchvision.datasets.CIFAR10('CIFAR', False, transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64)

In [3]:
 class CIFARModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_maxp = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
        )
        self.flatten = Flatten()
        self.mlp = nn.Sequential(
            Linear(1024, 64),
            Linear(64, 10)
        )
    
    def forward(self, x):
        x = self.conv_maxp(x)
        x = self.flatten(x)
        x = self.mlp(x)
        return x

model = CIFARModel()
model

CIFARModel(
  (conv_maxp): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (mlp): Sequential(
    (0): Linear(in_features=1024, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=10, bias=True)
  )
)

In [None]:
# debug model
input = torch.ones((64, 3, 32, 32))
output = model(input)
input, output

In [None]:
#visualize model
writer = SummaryWriter('logs')
writer.add_graph(model, torch.ones((64, 3, 32, 32)))
writer.close()

In [13]:
from tqdm import tqdm

loss_CE = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

writer = SummaryWriter('logs')

for epoch in tqdm(range(5)):  
    for step, data in enumerate(dataloader):
        imgs, targets = data
        outputs = model(imgs)
        loss = loss_CE(outputs, targets)
        writer.add_scalar('CIFAR_loss_CE', loss, step)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    

100%|██████████| 5/5 [00:41<00:00,  8.23s/it]


In [14]:
 for step, data in enumerate(dataloader):
    if step > 1:
        break
    imgs, targets = data
    outputs = model(imgs)
    loss = loss_CE(outputs, targets)
    print('out:', torch.argmax(outputs, 1))
    print('tar: ', targets)
    print('loss: ', loss)

out: tensor([3, 1, 1, 0, 6, 6, 1, 6, 3, 1, 0, 9, 3, 7, 9, 0, 5, 2, 8, 6, 7, 0, 2, 9,
        4, 2, 4, 0, 9, 6, 6, 5, 4, 5, 9, 3, 4, 1, 9, 5, 8, 6, 5, 6, 0, 9, 3, 7,
        4, 6, 9, 8, 2, 3, 8, 8, 7, 8, 2, 3, 7, 3, 6, 3])
tar:  tensor([3, 8, 8, 0, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 5, 7, 8, 6, 7, 0, 4, 9,
        5, 2, 4, 0, 9, 6, 6, 5, 4, 5, 9, 2, 4, 1, 9, 5, 4, 6, 5, 6, 0, 9, 3, 9,
        7, 6, 9, 8, 0, 3, 8, 8, 7, 7, 4, 6, 7, 3, 6, 3])
loss:  tensor(0.6522, grad_fn=<NllLossBackward0>)
out: tensor([6, 6, 1, 0, 3, 7, 0, 6, 8, 8, 9, 2, 9, 3, 3, 8, 8, 1, 1, 7, 2, 2, 2, 8,
        8, 9, 0, 3, 8, 6, 4, 6, 6, 2, 0, 7, 4, 5, 6, 3, 1, 1, 2, 6, 7, 7, 4, 0,
        6, 2, 1, 3, 0, 4, 3, 7, 8, 3, 1, 2, 8, 2, 0, 5])
tar:  tensor([6, 2, 1, 2, 3, 7, 2, 6, 8, 8, 0, 2, 9, 3, 3, 8, 8, 1, 1, 7, 2, 5, 2, 7,
        8, 9, 0, 3, 8, 6, 4, 6, 6, 0, 0, 7, 4, 5, 6, 3, 1, 1, 3, 6, 8, 7, 4, 0,
        6, 2, 1, 3, 0, 4, 2, 7, 8, 3, 1, 2, 8, 0, 8, 3])
loss:  tensor(0.5986, grad_fn=<NllLossBackward0>)
