# unet training

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch, torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from tqdm import tqdm
from unet import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# cifar10
batch_size = 64
num_train = 10000

# Convert every image to a tensor, then normalise each of the three channels
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Settings shared by both splits / both loaders.
cifar_kwargs = dict(root='./data', download=True, transform=transform)
loader_kwargs = dict(batch_size=batch_size, num_workers=2)

# Training split, restricted to its first `num_train` samples and reshuffled
# on every pass.
trainset = torch.utils.data.Subset(
    torchvision.datasets.CIFAR10(train=True, **cifar_kwargs),
    range(num_train),
)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)

# Full test split, served in a fixed order.
testset = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

# Human-readable names for the ten label indices.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# config
# Use the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# UNet is presumably provided by the star-import of the local `unet` module.
# It is built for 3-channel inputs and 2 output channels, then moved to `device`.
model = UNet(in_channels=3, out_channels=2)
model = model.to(device)

# Cross-entropy loss, optimised with Adam at a learning rate of 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Layer-by-layer report for a single 3 x 960 x 540 input.
summary_input_shape = (3, 1920 // 2, 1080 // 2)
summary(model, summary_input_shape)

cuda:0
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 960, 540]           1,728
       BatchNorm2d-2         [-1, 64, 960, 540]             128
              ReLU-3         [-1, 64, 960, 540]               0
            Conv2d-4         [-1, 64, 960, 540]          36,864
       BatchNorm2d-5         [-1, 64, 960, 540]             128
              ReLU-6         [-1, 64, 960, 540]               0
      DoubleConv2d-7         [-1, 64, 960, 540]               0
         MaxPool2d-8         [-1, 64, 480, 270]               0
            Conv2d-9        [-1, 128, 480, 270]          73,728
      BatchNorm2d-10        [-1, 128, 480, 270]             256
             ReLU-11        [-1, 128, 480, 270]               0
           Conv2d-12        [-1, 128, 480, 270]         147,456
      BatchNorm2d-13        [-1, 128, 480, 270]             256
             ReLU-14        [-1,

In [4]:
# training loop
# Runs `num_epochs` passes over `trainloader`, taking one optimizer step per
# batch and showing the per-batch loss and accuracy on a tqdm progress bar.
#
# NOTE(review): the recorded run of this cell stops at `criterion(outputs, target)`
# with "only batches of spatial targets supported (3D tensors) but got targets
# of size: [64]". The loader yields one class index per image, while the loss
# is being handed spatial (per-pixel) model outputs. This loop needs either
# per-pixel targets or a model that emits one logit vector per image —
# TODO decide which before relying on this cell.
num_epochs = 20

# The model summary lists BatchNorm2d layers, so make sure the model is in
# training mode even if an evaluation cell was run before this one.
model.train()

for epoch in range(num_epochs):
    with tqdm(trainloader, unit="batch") as tepoch:
        # The label is constant for the whole epoch, so set it once.
        tepoch.set_description(f"Epoch {epoch}")

        for data, target in tepoch:
            data, target = data.to(device), target.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            # (the output already lives on the model's device; no extra .to())
            outputs = model(data)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()

            # calculating accuracy
            # argmax over the class axis WITHOUT keepdim, so the prediction
            # tensor has the same shape as `target` and `==` compares
            # element-wise instead of broadcasting.
            predictions = outputs.argmax(dim=1)
            correct = (predictions == target).sum().item()
            # Divide by the number of targets actually scored: the last batch
            # can be smaller than `batch_size`, and dense targets hold more
            # than one label per sample.
            accuracy = correct / target.numel()

            # print statistics
            tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)

print('Finished Training')

Epoch 0:   0%|          | 0/157 [00:00<?, ?batch/s]


RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: : [64]

In [None]:
# saving
# PATH = './cifar_net.pth'
# torch.save(model.state_dict(), PATH)

In [None]:
# testing
# Pull the first batch of the test split. The built-in next() is used because
# the iterator's own `.next()` method no longer exists in recent PyTorch
# releases.
dataiter = iter(testloader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)

# print images
# The loader's transform normalises every channel with mean 0.5 / std 0.5.
# Undo that so pixel values are back in [0, 1], the range imshow expects for
# floating-point RGB data. The clamp only absorbs float rounding.
grid_cols = 8  # same as make_grid's default row width
images_unnormalized = (images.cpu() * 0.5 + 0.5).clamp(0, 1)
grid = torchvision.utils.make_grid(images_unnormalized, nrow=grid_cols)

plt.figure(figsize=(20,20))
plt.imshow(grid.permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for matplotlib

# Print one label per displayed image, wrapped like the image grid so that
# row/column positions line up with the picture.
label_names = [classes[idx] for idx in labels.tolist()]
print('GroundTruth: ')
for row_start in range(0, len(label_names), grid_cols):
    row = label_names[row_start:row_start + grid_cols]
    print(' '.join(f'{name:5s}' for name in row))