In [2]:
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import matplotlib.pyplot as plt
from tqdm import tqdm

from helper import *
from models.enet.model import *

torch.cuda.set_device(0)

In [3]:
mean = [0.28689554, 0.32513303, 0.28389177]
std = [0.18696375, 0.19017339, 0.18720214]
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean = mean, std = std)
            ])
dataset = CityscapesDataset(transform = transform)
dataloader = data.DataLoader(dataset, batch_size = 2, shuffle = True, drop_last = True)

In [3]:
net = ENet(num_classes = 1)
net = net.cuda()

In [4]:
optimizer = torch.optim.Adam(net.parameters(), lr = 1e-4)
criterion = nn.BCEWithLogitsLoss()

def dice_loss(inp, target):
    smooth = 1.

    iflat = inp.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()
    
    return -((2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))

In [5]:
def train(model, train_loader, epoch, loss_function, optimiser, savename):
    model.train()
    loop = tqdm(train_loader)
    loss_min = 100
    for data, target in loop:
        data, target = data.float().cuda(), target.float().cuda()
        
        optimiser.zero_grad()
        prediction = model(data)
        prediction = prediction.squeeze(1)
        
        loss = loss_function(prediction, target) + dice_loss(prediction, target)
        losses.append(loss.item())
        
        loss.backward()
        optimiser.step()
        
        loop.set_description('Epoch {}/{}'.format(epoch + 1, num_epochs))
        loop.set_postfix(loss = loss.item())
        
        if loss.item() < loss_min :
            loss_min = loss.item()
            torch.save(net.state_dict(), savename)

In [6]:
num_epochs = 2
losses = []
for epoch in range(num_epochs) :
    train(model = net, train_loader = dataloader, loss_function = criterion, optimiser = optimizer, epoch = epoch)

Epoch 1/2:  34%|███▍      | 4024/11736 [33:18<1:05:48,  1.95it/s, loss=0.836]

KeyboardInterrupt: 

In [10]:
torch.save(net.state_dict(), 'testing.pt')

In [4]:
net = ENet(num_classes = 1)
net.load_state_dict(torch.load('road_bce_dice.pt', map_location = 'cpu'))

<All keys matched successfully>

In [8]:
net.cuda()
net.eval()
img, mask = next(iter(dataloader))
img = img.float().cuda()
out = net(img)

RuntimeError: CUDA out of memory. Tried to allocate 14.00 MiB (GPU 0; 10.91 GiB total capacity; 9.96 GiB already allocated; 14.31 MiB free; 413.47 MiB cached)

In [None]:
unorm = UnNormalize(mean = mean, std = std)
fig = plt.figure()
plt.subplot(3, 1, 1)
img2 = unorm(img)
img2 = img2.transpose(1, 2).transpose(2, 3).detach().cpu().numpy()
plt.imshow(img2[0])
plt.subplot(3, 1, 2)
plt.imshow(mask[0])
out2 = out.squeeze(1).detach().cpu().numpy()
print(out2.shape)
plt.subplot(3, 1, 3)
plt.imshow(out2[0])
plt.show()