In [None]:
%reload_ext autoreload
%autoreload 2

import os
import cv2
import time
import torch
import numpy as np
import matplotlib.pyplot as plt

from model import LinkNet34

In [None]:
# --- training schedule ---
epochs = 500
batch = 8
lr = 1e-4

# Final weights go here; the best-validation checkpoint reuses this name with a '_cpt' suffix.
model_path = './data/models/resnet34_test.pth'

# --- colour-augmentation strengths (passed to AugmentColor) ---
gamma = 0.35
brightness = 2.0
colors = 0.25

# --- loss weights ---
# NOTE(review): these are not read by the loss cell, which hard-codes its own weights.
bce_w = 0
car_w = 1

# --- capture directories; each one holds CameraRGB/ and CameraSeg/ subfolders ---
train_dirs = [
    'data/train/',
    'data/dataset/',
    'data/carla-capture-20180528/',
]
val_dirs = ['data/carla-capture-20181305/']

In [None]:
import torchvision.transforms as transforms

# Seed both RNGs used by this notebook so runs are repeatable:
# NumPy drives the colour/noise augmentation, torch drives DataLoader shuffling.
np.random.seed(123)
torch.manual_seed(123)

def img_size(image: np.ndarray):
    """Return ``(width, height)`` of an image array indexed as rows x cols[ x channels]."""
    rows, cols = image.shape[0], image.shape[1]
    return cols, rows

def gauss_noise(img, sigma_squared):
    """Add zero-mean Gaussian noise to a uint8 image.

    Args:
        img: image array with values in [0, 255] (any shape, e.g. HxWx3 from cv2.imread).
        sigma_squared: standard deviation of the noise in pixel units. Despite its name, it
            is used as the ``scale`` (std dev) of ``np.random.normal``. The name is kept for
            backward compatibility.

    Returns:
        uint8 array of the same shape as ``img``: the noisy image, rounded and clipped to
        [0, 255].
    """
    # The noise must be zero-mean. The previous loc=-sigma_squared darkened every image.
    # Drawing with img.shape also supports grayscale / non-3-channel inputs.
    noise = np.random.normal(0.0, sigma_squared, img.shape)
    noisy = img.astype(np.float64) + noise
    # Clip before casting: out-of-range floats cast to uint8 wrap around (e.g. -3 -> 253).
    return np.clip(np.rint(noisy), 0, 255).astype(np.uint8)

class GaussNoise(object):
    """Callable transform that applies ``gauss_noise`` with a fixed noise level."""

    def __init__(self, sigma):
        # Forwarded unchanged as the second argument of gauss_noise.
        self.sigma = sigma

    def __call__(self, img):
        noisy = gauss_noise(img, self.sigma)
        return noisy

class AugmentColor(object):
    """Random gamma / brightness / per-channel colour jitter for a channel-first image tensor.

    Each of the three perturbations is applied independently with probability 0.5 (decided
    with NumPy's global RNG). The result is clamped to [0, 1]. The random factors are created
    on the input's device with the input's dtype, so this works on CPU and on any GPU. The
    input tensor is never modified in place.
    """

    def __init__(self, gamma, brightness, colors):
        # gamma: exponent drawn from U(1 - gamma, 1 + gamma)
        # brightness: multiplier drawn from U(1 / brightness, brightness)
        # colors: three per-channel multipliers drawn from U(1 - colors, 1 + colors)
        self.gamma = gamma
        self.brightness = brightness
        self.colors = colors

    def _rand(self, low, high, size, like):
        """Draw ``size`` uniforms with NumPy's RNG as a tensor with ``like``'s dtype/device."""
        values = np.random.uniform(low, high, size)
        return torch.as_tensor(values, dtype=like.dtype, device=like.device)

    def __call__(self, img):
        """Return an augmented copy of ``img`` (channels on dim -3, expected to be 3)."""
        p = np.random.uniform(0, 1, 1)
        if p > 0.5:
            # randomly shift gamma
            img = img ** self._rand(1 - self.gamma, 1 + self.gamma, 1, img)

        p = np.random.uniform(0, 1, 1)
        if p > 0.5:
            # randomly shift brightness
            img = img * self._rand(1 / self.brightness, self.brightness, 1, img)

        p = np.random.uniform(0, 1, 1)
        if p > 0.5:
            # randomly shift color: one multiplier per channel, broadcast over H and W.
            # Out-of-place multiply so the caller's tensor is never mutated.
            channel_scale = self._rand(1 - self.colors, 1 + self.colors, 3, img)
            img = img * channel_scale.view(3, 1, 1)

        # saturate
        return torch.clamp(img, 0, 1)

class ToTensor(object):
    """torchvision ``ToTensor`` followed by a cast to float32 on a target device.

    torchvision's ToTensor turns an HWC uint8 ndarray into a CHW float tensor scaled to
    [0, 1]. This wrapper then moves it to ``device``.

    Args:
        device: target device. ``None`` (default) keeps the old behaviour of using CUDA,
            but falls back to the CPU when CUDA is unavailable instead of crashing.

    Note: when the device is CUDA, keep DataLoader ``num_workers=0``. Creating CUDA tensors
    inside worker processes is not supported with the default fork start method.
    """

    def __init__(self, device=None):
        self.transform = transforms.ToTensor()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

    def __call__(self, sample):
        return self.transform(sample).to(device=self.device, dtype=torch.float32)

# Training images: tensor conversion, then random colour augmentation.
# (GaussNoise is available but currently disabled.)
train_steps = [ToTensor(), AugmentColor(gamma, brightness, colors)]
train_transform = transforms.Compose(train_steps)

# Validation images: tensor conversion only, no augmentation.
val_steps = [ToTensor()]
val_transform = transforms.Compose(val_steps)

In [None]:
from torch.utils.data import Dataset, DataLoader, ConcatDataset

class LyftDataset(Dataset):
    """Camera images paired with 3-class segmentation targets from one capture directory.

    Expects ``<data_dir>/CameraRGB`` and ``<data_dir>/CameraSeg``. Files are paired by
    position after sorting each directory listing, so both folders must hold matching
    file names.

    Args:
        data_dir: capture directory.
        img_transform: optional callable applied to the BGR uint8 image from cv2.imread.
        trg_transform: optional callable applied to the HxWx3 float64 one-hot target
            (channels: background, road, vehicles).
        read: if True, every image and target is loaded and preprocessed up front.
            Otherwise files are read on each ``__getitem__``.

    Raises:
        ValueError: if the two folders contain a different number of files.
        FileNotFoundError: if an image cannot be decoded by cv2.
    """

    def __init__(self, data_dir, img_transform=None, trg_transform=None, read=True):
        img_dir = os.path.join(data_dir, "CameraRGB")
        trg_dir = os.path.join(data_dir, "CameraSeg")
        self.img_paths = [os.path.join(img_dir, name) for name in sorted(os.listdir(img_dir))]
        self.trg_paths = [os.path.join(trg_dir, name) for name in sorted(os.listdir(trg_dir))]
        if len(self.img_paths) != len(self.trg_paths):
            raise ValueError(
                f"{data_dir}: {len(self.img_paths)} images but {len(self.trg_paths)} targets")
        self.img_transform = img_transform
        self.trg_transform = trg_transform
        self.read = read
        if read:
            self.imgs = [self._imread(path) for path in self.img_paths]
            self.trgs = [self._fix_trg(self._imread(path)) for path in self.trg_paths]

    @staticmethod
    def _imread(path):
        """cv2.imread that raises instead of silently returning None on failure."""
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"could not read image: {path}")
        return img

    def _fix_trg(self, trg):
        """Turn a raw BGR label image into an HxWx3 float64 one-hot (bg, road, vehicles) map.

        ``trg`` is modified in place by the flood fill.
        """
        h, w, _ = trg.shape
        mask = np.zeros((h + 2, w + 2, 1), dtype=np.uint8)
        # Set the connected same-label region touching the bottom-centre pixel to 0.
        # Presumably this erases the ego vehicle's hood so it is not labelled as a vehicle.
        # TODO confirm.
        cv2.floodFill(trg, mask, (w // 2, h - 1), (0, 0, 0))
        return self._encode_labels(trg)

    @staticmethod
    def _encode_labels(trg):
        """One-hot encode the class ids stored in channel 2 (red, for cv2's BGR order).

        Ids 10 -> vehicles and 6, 7 -> road. Presumably these are the CARLA vehicle,
        road-line and road classes; verify. All other ids are background.
        """
        labels = trg[:, :, 2]
        # np.float64 replaces the removed np.float alias (NumPy >= 1.24).
        vehicles = (labels == 10).astype(np.float64)
        road = (labels == 6).astype(np.float64) + (labels == 7).astype(np.float64)
        bg = np.ones(vehicles.shape) - vehicles - road
        return np.stack([bg, road, vehicles], axis=2)

    def __len__(self):
        return len(self.imgs) if self.read else len(self.img_paths)

    def __getitem__(self, idx):
        if self.read:
            img = self.imgs[idx]
            trg = self.trgs[idx]
        else:
            img = self._imread(self.img_paths[idx])
            trg = self._fix_trg(self._imread(self.trg_paths[idx]))
        if self.img_transform is not None:
            img = self.img_transform(img)
        if self.trg_transform is not None:
            trg = self.trg_transform(trg)
        return img, trg

# One lazily-read LyftDataset per capture directory (read=False), concatenated per split.
# Targets are float arrays, so torchvision's ToTensor only transposes them HWC -> CHW.
train_datasets = [LyftDataset(train_dir, train_transform, transforms.ToTensor(), False)
                  for train_dir in train_dirs]
train_dataset = ConcatDataset(train_datasets)
print("Train imgs:", len(train_dataset))
val_datasets = [LyftDataset(val_dir, val_transform, transforms.ToTensor(), False)
                for val_dir in val_dirs]
val_dataset = ConcatDataset(val_datasets)
print("Val imgs:", len(val_dataset))

In [None]:
def show_img(img):
    """Display ``img`` with matplotlib in a new 300-dpi figure."""
    _, ax = plt.subplots(dpi=300)
    ax.imshow(img)
    plt.show()

In [None]:
# Show the first training sample 10 times to eyeball the random colour augmentation.
# cv2 loads images as BGR, so reverse the channel axis before handing them to matplotlib.
for _ in range(10):
    sample_img, sample_trg = train_dataset[0]
    print(sample_img.shape, sample_trg.shape)
    sample_hwc = np.moveaxis(sample_img.cpu().numpy(), 0, -1)
    show_img(np.ascontiguousarray(sample_hwc[..., ::-1]))

In [None]:
from torch.utils.data import DataLoader

# Only the training split is shuffled; validation order stays fixed.
# The transforms may create CUDA tensors, so keep the default num_workers=0.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch)

In [None]:
# Prefer the first GPU; fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from loss import LyftLoss
import torch.optim as optim
# Training objective.
# NOTE(review): these weights are hard-coded and differ from the config cell, whose
# `bce_w` / `car_w` are never read. Confirm which values are intended.
train_loss = LyftLoss(bce_w=0, car_w=2, other_w=0.5).to(device)
# Validation objective. Its value decides when the best checkpoint is saved in the
# training cell.
val_loss = LyftLoss(bce_w=0, car_w=1, other_w=0).to(device)
# 3 output maps to match the 3 target channels (bg, road, vehicles).
# Presumably the args are (in_channels, num_classes). TODO confirm in model.py.
model = LinkNet34(3, 3).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
# Warm-start from a previously saved state_dict. map_location remaps tensors saved on a GPU
# onto `device`, so this also works on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
load_model_path = './data/models/resnet34_007_cpt1.pth'
state = torch.load(load_model_path, map_location=device)
model.load_state_dict(state)

In [None]:
if torch.cuda.is_available():
    torch.cuda.synchronize()

def val():
    """Return the validation loss of `model` over `val_loader`.

    Per-batch `val_loss` values are summed and divided by the number of validation images
    (not batches). If LyftLoss averages over the batch, the result is scaled by about
    1/batch. That is harmless for checkpoint comparison but worth knowing.

    The model is switched to eval mode for the pass (BatchNorm/Dropout in inference mode),
    and its previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    c_loss = 0.0
    try:
        with torch.no_grad():
            for img, trg in val_loader:
                img = img.to(device=device, dtype=torch.float32)
                trg = trg.to(device=device, dtype=torch.float32)
                pred = model(img)
                c_loss += val_loss(pred, trg).item()
    finally:
        model.train(was_training)
    return c_loss / len(val_dataset)

losses = []
best_loss = val()
print("Start val loss:", best_loss)
# Best-validation checkpoint path: '<model_path stem>_cpt<ext>'.
ckpt_root, ckpt_ext = os.path.splitext(model_path)
checkpoint_path = ckpt_root + '_cpt' + ckpt_ext

for epoch in range(epochs):  # loop over the dataset multiple times
    model.train()  # re-enable training mode (validation may have switched to eval)
    running_loss = 0.0
    s_time = time.time()
    for img, trg in train_loader:
        # move the batch to the training device as float32
        img = img.to(device=device, dtype=torch.float32)
        trg = trg.to(device=device, dtype=torch.float32)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        pred = model(img)
        loss = train_loss(pred, trg)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Same normalisation as val(): summed batch losses / number of images.
    running_loss /= len(train_dataset)
    # Must NOT be named `val`: that would shadow val() and crash on the next epoch.
    epoch_val_loss = val()
    print("Epoch:", epoch+1, "train loss:", running_loss, "val loss:", epoch_val_loss,
          "time:", time.time()-s_time, "s")
    if epoch_val_loss < best_loss:
        torch.save(model.state_dict(), checkpoint_path)
        best_loss = epoch_val_loss
        print("Checkpoint saved")
    losses.append([running_loss, epoch_val_loss])

print('Finished Training')
torch.save(model.state_dict(), model_path)
print(losses)

In [None]:
def show_pred(fname):
    """Run `model` on one image file, display the input and the prediction, return the prediction.

    The image goes through `val_transform`, the same preprocessing as validation (ToTensor
    scales uint8 to [0, 1]), so it matches what the model saw during training. The model
    runs in eval mode without autograd, and its previous train/eval mode is restored.

    Args:
        fname: path to an image readable by cv2.

    Returns:
        HxWxC float32 numpy array of raw model outputs for the image.

    Raises:
        FileNotFoundError: if cv2 cannot read ``fname``.
    """
    img_test = cv2.imread(fname)
    if img_test is None:
        raise FileNotFoundError(f"could not read image: {fname}")
    # cv2 gives BGR. Convert only for display; the model is fed BGR, as in training.
    show_img(cv2.cvtColor(img_test, cv2.COLOR_BGR2RGB))
    batch_test = val_transform(img_test).unsqueeze(0).to(device)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pred = model(batch_test)
    finally:
        model.train(was_training)
    pred_img = np.moveaxis(pred[0].cpu().numpy(), 0, -1)
    show_img(pred_img)
    return pred_img

In [None]:
idx = 473

pred_img = show_pred(f'./data/train/CameraRGB/{idx}.png')
trg_test = train_datasets[0]._fix_trg(cv2.imread(f'./data/train/CameraSeg/{idx}.png'))
show_img(trg_test)
# Rows 4..603 of the prediction are compared against the target.
# Presumably the model output carries extra padding rows. TODO confirm.
pred_rows = slice(4, 604)
show_img(np.abs(trg_test - pred_img[pred_rows, :, :]))

In [None]:
%%timeit
img_test = cv2.imread('./data/CameraRGB/'+str(idx)+'.png')
img_test = np.moveaxis(img_test, -1, 0)
img_test = img_test[np.newaxis,:,:,:]
img_test = torch.from_numpy(img_test).type(torch.cuda.FloatTensor)
pred = model(img_test)
pred = pred.cpu().data[0,:,:,:].numpy()
pred_img = np.moveaxis(pred, 0, -1)