In [1]:
%reload_ext autoreload
%autoreload 2

import os
import cv2
import time
import torch
import numpy as np
import matplotlib.pyplot as plt

from model import LinkNet34

In [2]:
# Training hyper-parameters and augmentation settings consumed by the cells below.
epochs = 500  # number of passes over train_loader in the training loop
batch = 8  # batch_size handed to the DataLoader
lr = 1e-4  # learning rate given to the Adam optimizer
model_path = './data/models/resnet34_003.pth'  # final weights; the best-epoch checkpoint is saved next to it with a '_cpt' suffix
gamma = 0.35  # AugmentColor: gamma exponent is drawn from [1-gamma, 1+gamma]
brightness = 2.0  # AugmentColor: brightness factor is drawn from [1/brightness, brightness]
colors = 0.25  # AugmentColor: per-channel factor is drawn from [1-colors, 1+colors]
bce_coef = 10  # argument passed to BCEDiceLoss (presumably the weight of the BCE term -- confirm in loss.py)

In [3]:
import torchvision.transforms as transforms

# Seed every RNG the notebook relies on: NumPy drives the augmentation draws,
# torch drives the ordering produced by DataLoader(shuffle=True).
np.random.seed(123)
torch.manual_seed(123)

def img_size(image: np.ndarray):
    """Return the dimensions of ``image`` as a ``(width, height)`` tuple."""
    rows, cols = image.shape[0], image.shape[1]
    return (cols, rows)

def gauss_noise(img, sigma_squared):
    """Add zero-mean Gaussian noise to an 8-bit image.

    Args:
        img: NumPy image array of any shape; values are treated as 0..255
            intensities.
        sigma_squared: spread of the noise. Despite its name the value is
            handed to ``np.random.normal`` as the standard deviation, which
            is what the ``GaussNoise`` wrapper (it passes ``sigma``) expects.

    Returns:
        A new ``np.uint8`` array with the same shape as ``img``.
    """
    # Centre the noise on 0 so it does not shift the mean brightness, and
    # match the input shape instead of assuming three channels.
    noise = np.random.normal(0.0, sigma_squared, img.shape)
    noisy = img.astype(np.float64) + noise
    # Round and clip before the cast: a bare astype(uint8) wraps values that
    # fall outside 0..255 instead of saturating them.
    return np.clip(np.rint(noisy), 0, 255).astype(np.uint8)

class GaussNoise:
    """Callable transform that runs ``gauss_noise`` with a fixed sigma."""

    def __init__(self, sigma):
        # Stored untouched and forwarded to gauss_noise on every call.
        self.sigma = sigma

    def __call__(self, img):
        noisy = gauss_noise(img, self.sigma)
        return noisy

class AugmentColor(object):
    """Random gamma / brightness / per-channel colour jitter for image tensors.

    Each of the three perturbations is applied independently with
    probability 0.5, then the result is clamped to ``[0, 1]``.

    Args:
        gamma: the gamma exponent is drawn uniformly from
            ``[1 - gamma, 1 + gamma]``.
        brightness: the brightness factor is drawn uniformly from
            ``[1 / brightness, brightness]``.
        colors: each channel factor is drawn uniformly from
            ``[1 - colors, 1 + colors]``.
    """

    def __init__(self, gamma, brightness, colors):
        self.gamma = gamma
        self.brightness = brightness
        self.colors = colors

    @staticmethod
    def _coin_flip():
        """Return True with probability 0.5, using NumPy's global RNG."""
        return np.random.uniform(0, 1, 1)[0] > 0.5

    def __call__(self, img):
        """Return an augmented version of ``img``.

        Args:
            img: tensor indexed as ``(channels, height, width)``; the colour
                shift scales along dim 0.

        Returns:
            A tensor on the same device as ``img`` with values in ``[0, 1]``.
            ``img`` itself is never modified.
        """
        # Stay on the caller's device/dtype instead of forcing CUDA, so the
        # transform also works on CPU-only machines.
        if not img.is_floating_point():
            img = img.to(torch.float32)

        if self._coin_flip():
            # randomly shift gamma
            random_gamma = np.random.uniform(1 - self.gamma, 1 + self.gamma, 1)[0]
            img = img ** float(random_gamma)

        if self._coin_flip():
            # randomly shift brightness
            random_brightness = np.random.uniform(1 / self.brightness, self.brightness, 1)[0]
            img = img * float(random_brightness)

        if self._coin_flip():
            # randomly shift color: one factor per channel, broadcast over H and W
            random_colors = np.random.uniform(1 - self.colors, 1 + self.colors, img.shape[0])
            scale = torch.as_tensor(random_colors, dtype=img.dtype, device=img.device)
            # Out-of-place multiply: an in-place `*=` would write into the
            # caller's tensor whenever the earlier branches were skipped.
            img = img * scale.view(-1, 1, 1)

        # saturate
        return torch.clamp(img, 0, 1)

class ToTensor(object):
    """``transforms.ToTensor`` followed by a float32 cast on the target device.

    Args:
        device: where the resulting tensor is placed. ``None`` (the default)
            picks CUDA when it is available and falls back to the CPU
            otherwise, so the notebook no longer crashes without a GPU.
    """

    def __init__(self, device=None):
        self.transform = transforms.ToTensor()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

    def __call__(self, sample):
        """Convert ``sample`` with ``transforms.ToTensor`` and move/cast it."""
        return self.transform(sample).to(device=self.device, dtype=torch.float32)

# Image-side training pipeline: tensor conversion first, colour jitter second.
train_transform = transforms.Compose(
    [
        #GaussNoise(10),
        ToTensor(),
        AugmentColor(gamma=gamma, brightness=brightness, colors=colors),
    ]
)

In [4]:
from torch.utils.data import Dataset, DataLoader

class LyftDataset(Dataset):
    """In-memory segmentation dataset built from two CameraRGB/CameraSeg folders.

    Every file of ``data_dir`` and every third file of ``data_dir2`` is read
    at construction time. Images and targets get 4 mirrored rows added at the
    top and at the bottom; targets are first converted by ``_fix_trg`` into a
    float mask with the channels (background, road, vehicles).

    Args:
        data_dir: folder containing ``CameraRGB`` and ``CameraSeg``; all
            files are used.
        data_dir2: second folder with the same layout; every third file
            (in sorted order) is used.
        img_transform: optional callable applied to the image in
            ``__getitem__``.
        trg_transform: optional callable applied to the target in
            ``__getitem__``.

    Raises:
        ValueError: if the number of images and targets differs.
        IOError: if ``cv2.imread`` cannot decode one of the files.
    """

    # Rows mirrored onto the top and onto the bottom of every image/target.
    _PAD_ROWS = 4

    def __init__(self, data_dir, data_dir2, img_transform=None, trg_transform=None):
        img_paths, trg_paths = self._list_pairs(data_dir)
        extra_imgs, extra_trgs = self._list_pairs(data_dir2, step=3)
        img_paths.extend(extra_imgs)
        trg_paths.extend(extra_trgs)
        print(len(img_paths), len(trg_paths))
        # Pairs are matched purely by position, so a count mismatch would
        # silently associate images with the wrong masks.
        if len(img_paths) != len(trg_paths):
            raise ValueError(
                "found %d images but %d targets" % (len(img_paths), len(trg_paths))
            )

        self.imgs = [self._pad(self._read(path)) for path in img_paths]
        self.trgs = [self._pad(self._fix_trg(self._read(path))) for path in trg_paths]
        self.img_transform = img_transform
        self.trg_transform = trg_transform

    @staticmethod
    def _list_pairs(root, step=1):
        """Return sorted (image paths, target paths) under ``root``, keeping every ``step``-th file."""
        img_dir = os.path.join(root, "CameraRGB")
        trg_dir = os.path.join(root, "CameraSeg")
        img_paths = [os.path.join(img_dir, name) for name in sorted(os.listdir(img_dir))[::step]]
        trg_paths = [os.path.join(trg_dir, name) for name in sorted(os.listdir(trg_dir))[::step]]
        return img_paths, trg_paths

    @staticmethod
    def _read(path):
        """Read ``path`` with ``cv2.imread`` and fail loudly if it cannot be decoded."""
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise IOError("cv2.imread could not read %r" % (path,))
        return image

    @classmethod
    def _pad(cls, array):
        """Mirror ``_PAD_ROWS`` rows onto the top and the bottom of ``array``."""
        return cv2.copyMakeBorder(array, cls._PAD_ROWS, cls._PAD_ROWS, 0, 0, cv2.BORDER_REFLECT)

    @staticmethod
    def _encode_classes(labels):
        """Turn a 2-D label map into an (H, W, 3) float64 mask.

        Channel order is (background, road, vehicles): label 10 goes to the
        vehicles channel, labels 6 and 7 go to the road channel, and
        everything else is background.
        """
        # np.float64 is what the removed ``np.float`` alias resolved to.
        vehicles = (labels == 10).astype(np.float64)
        road = np.isin(labels, (6, 7)).astype(np.float64)
        bg = 1.0 - vehicles - road
        return np.stack([bg, road, vehicles], axis=2)

    def _fix_trg(self, trg):
        """Convert a raw CameraSeg image into the 3-channel training target.

        ``trg`` is modified in place by the flood fill. The labels are read
        from channel index 2 (the red channel of an image loaded by cv2).
        """
        h, w, _ = trg.shape
        # floodFill requires a mask that is 2 pixels larger than the image.
        mask = np.zeros((h + 2, w + 2, 1), dtype=np.uint8)
        # Paint the region connected to the bottom-centre pixel black
        # (presumably the ego vehicle's hood -- confirm against the data).
        cv2.floodFill(trg, mask, (w // 2, h - 1), (0, 0, 0))
        return self._encode_classes(trg[:, :, 2])

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return the ``(image, target)`` pair at ``idx`` after the optional transforms."""
        img = self.imgs[idx]
        trg = self.trgs[idx]
        if self.img_transform is not None:
            img = self.img_transform(img)
        if self.trg_transform is not None:
            trg = self.trg_transform(trg)
        return img, trg

# Load both folders into memory; images go through train_transform, targets
# only through the plain torchvision tensor conversion.
train_dataset = LyftDataset(
    data_dir='./data/',
    data_dir2='./data/dataset/',
    img_transform=train_transform,
    trg_transform=transforms.ToTensor(),
)

2515 2515


In [None]:
#img, trg = train_dataset.__getitem__(80)
def show_img(img):
    """Display ``img`` in a new 300-dpi matplotlib figure."""
    plt.figure(dpi=300)
    plt.imshow(img)
    plt.show()

In [None]:
# Preview ten random augmentations of the first sample: every pass re-enters
# the dataset from its first item and stops right after showing it.
for repeat in range(10):
    for img, trg in train_dataset:
        print(img.shape, trg.shape)
        channels_first = img.cpu().numpy()
        show_img(np.moveaxis(channels_first, 0, -1))
        break

In [5]:
from torch.utils.data import DataLoader
# Shuffled mini-batches of `batch` samples drawn from the in-memory dataset.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch,
    shuffle=True,
)

In [6]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [7]:
from loss import BCEDiceLoss
import torch.optim as optim
# Loss, network and optimizer, with the loss and the network placed on `device`.
loss_function = BCEDiceLoss(bce_coef)
loss_function = loss_function.to(device)

model = LinkNet34(3, 3)
model = model.to(device)

optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [8]:
# Warm-start from the best checkpoint of the previous run.
load_model_path = './data/models/resnet34_002_cpt.pth'
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from a GPU also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
state = torch.load(load_model_path, map_location=device)
model.load_state_dict(state)

In [9]:
# torch.cuda.synchronize() raises without a GPU, so only call it on CUDA.
if device.type == "cuda":
    torch.cuda.synchronize()

# Best-epoch checkpoint lives next to model_path with a '_cpt' suffix before
# the extension (e.g. resnet34_003_cpt.pth).
ckpt_root, ckpt_ext = os.path.splitext(model_path)
checkpoint_path = ckpt_root + '_cpt' + ckpt_ext

losses = []  # summed batch loss of every epoch, in order
best_loss = float('inf')
for epoch in range(epochs):  # loop over the dataset multiple times
    # Make sure layers such as BatchNorm/Dropout are in training mode even if
    # an evaluation cell was run in between.
    model.train()
    running_loss = 0.0
    s_time = time.time()
    for img, trg in train_loader:
        # move the inputs to the training device as float32
        img = img.to(device=device, dtype=torch.float32)
        trg = trg.to(device=device, dtype=torch.float32)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        pred = model(img)
        loss = loss_function(pred, trg)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print("Epoch:", epoch+1, "loss:", running_loss, "time:", time.time()-s_time, "s")
    # Keep the weights of the epoch with the lowest summed training loss.
    if running_loss < best_loss:
        torch.save(model.state_dict(), checkpoint_path)
        best_loss = running_loss
        print("Checkpoint saved")
    losses.append(running_loss)

print('Finished Training')
# The weights of the last epoch are stored separately from the best checkpoint.
torch.save(model.state_dict(), model_path)
print(losses)

Epoch: 1 loss: 43.44235575199127 time: 239.30349159240723 s
Checkpoint saved
Epoch: 2 loss: 26.578725691884756 time: 239.43996047973633 s
Checkpoint saved
Epoch: 3 loss: 23.246523153036833 time: 239.24568557739258 s
Checkpoint saved
Epoch: 4 loss: 19.86813148483634 time: 238.11513781547546 s
Checkpoint saved
Epoch: 5 loss: 18.983006946742535 time: 238.20205163955688 s
Checkpoint saved
Epoch: 6 loss: 17.822439590469003 time: 238.23431539535522 s
Checkpoint saved
Epoch: 7 loss: 16.21480198018253 time: 238.25354933738708 s
Checkpoint saved
Epoch: 8 loss: 15.275330897420645 time: 238.1876220703125 s
Checkpoint saved
Epoch: 9 loss: 15.054051768034697 time: 238.03970050811768 s
Checkpoint saved
Epoch: 10 loss: 14.218648981302977 time: 237.84919786453247 s
Checkpoint saved
Epoch: 11 loss: 13.643306516110897 time: 237.83871960639954 s
Checkpoint saved
Epoch: 12 loss: 13.2335836738348 time: 237.78394293785095 s
Checkpoint saved
Epoch: 13 loss: 12.904990680515766 time: 237.91810822486877 s
Check

Epoch: 116 loss: 9.021850612014532 time: 237.6672487258911 s
Epoch: 117 loss: 7.390742231160402 time: 237.68674659729004 s
Epoch: 118 loss: 6.671435027383268 time: 237.691077709198 s
Epoch: 119 loss: 6.249731785617769 time: 237.59751915931702 s
Epoch: 120 loss: 5.965912719257176 time: 237.75020933151245 s
Checkpoint saved
Epoch: 121 loss: 5.825400850735605 time: 237.67586970329285 s
Checkpoint saved
Epoch: 122 loss: 5.787101532332599 time: 237.65906882286072 s
Checkpoint saved
Epoch: 123 loss: 5.734213276766241 time: 237.69465398788452 s
Checkpoint saved
Epoch: 124 loss: 5.663996180519462 time: 237.74854111671448 s
Checkpoint saved
Epoch: 125 loss: 5.675771647132933 time: 237.64060044288635 s
Epoch: 126 loss: 5.6344291446730494 time: 237.68607759475708 s
Checkpoint saved
Epoch: 127 loss: 5.706817157566547 time: 237.6790907382965 s
Epoch: 128 loss: 5.795093893073499 time: 237.75051069259644 s
Epoch: 129 loss: 5.801473788917065 time: 237.7179775238037 s
Epoch: 130 loss: 5.916758148930967

Epoch: 239 loss: 4.41332864202559 time: 237.78272032737732 s
Checkpoint saved
Epoch: 240 loss: 4.445261163637042 time: 237.7391972541809 s
Epoch: 241 loss: 4.319328165613115 time: 237.78627347946167 s
Checkpoint saved
Epoch: 242 loss: 4.297984664328396 time: 237.83015942573547 s
Checkpoint saved
Epoch: 243 loss: 4.447765822522342 time: 237.85625529289246 s
Epoch: 244 loss: 4.400516036432236 time: 237.82069897651672 s
Epoch: 245 loss: 4.459273950196803 time: 237.7625949382782 s
Epoch: 246 loss: 4.451233308296651 time: 237.81122994422913 s
Epoch: 247 loss: 4.541255367454141 time: 237.68673372268677 s
Epoch: 248 loss: 4.580882400274277 time: 237.75545382499695 s
Epoch: 249 loss: 4.55484358407557 time: 238.19212746620178 s
Epoch: 250 loss: 4.566498324740678 time: 237.8260943889618 s
Epoch: 251 loss: 4.630310460925102 time: 237.91403365135193 s
Epoch: 252 loss: 4.621210774406791 time: 237.90225172042847 s
Epoch: 253 loss: 4.603346002288163 time: 237.78777742385864 s
Epoch: 254 loss: 4.54667

Epoch: 367 loss: 3.76908414112404 time: 238.06935811042786 s
Epoch: 368 loss: 3.8330118032172322 time: 237.93293261528015 s
Epoch: 369 loss: 3.763500540982932 time: 237.95722317695618 s
Epoch: 370 loss: 3.75692012347281 time: 237.90991759300232 s
Epoch: 371 loss: 3.6993210869841278 time: 237.90945315361023 s
Checkpoint saved
Epoch: 372 loss: 3.727281963452697 time: 237.93618369102478 s
Epoch: 373 loss: 3.757761113345623 time: 237.9176561832428 s
Epoch: 374 loss: 3.8357441034168005 time: 238.01154232025146 s
Epoch: 375 loss: 3.834840157534927 time: 237.93429374694824 s
Epoch: 376 loss: 3.72947380784899 time: 237.83316659927368 s
Epoch: 377 loss: 3.6375322197563946 time: 237.92313504219055 s
Checkpoint saved
Epoch: 378 loss: 3.5991888246499 time: 237.86762762069702 s
Checkpoint saved
Epoch: 379 loss: 3.6230952790938318 time: 237.94856810569763 s
Epoch: 380 loss: 3.6794544854201376 time: 237.94164299964905 s
Epoch: 381 loss: 3.6885277181863785 time: 237.95157170295715 s
Epoch: 382 loss: 3

Epoch: 494 loss: 3.2037081569433212 time: 238.06627750396729 s
Epoch: 495 loss: 3.174900853075087 time: 238.00922298431396 s
Epoch: 496 loss: 3.198228213004768 time: 238.017418384552 s
Epoch: 497 loss: 3.197917287237942 time: 237.92114543914795 s
Epoch: 498 loss: 3.1813455750234425 time: 237.9789218902588 s
Epoch: 499 loss: 3.1363280713558197 time: 237.94813179969788 s
Epoch: 500 loss: 3.1740955845452845 time: 237.96405005455017 s
Finished Training
[43.44235575199127, 26.578725691884756, 23.246523153036833, 19.86813148483634, 18.983006946742535, 17.822439590469003, 16.21480198018253, 15.275330897420645, 15.054051768034697, 14.218648981302977, 13.643306516110897, 13.2335836738348, 12.904990680515766, 12.788451502099633, 13.906510373577476, 13.411184709519148, 12.194152789190412, 11.928161235526204, 11.53395333327353, 11.228587793186307, 11.050389908254147, 10.796904236078262, 10.676769684068859, 11.225154586136341, 10.61463712155819, 10.422103626653552, 10.12984485924244, 10.01987105421

In [None]:
# Qualitative check on one training frame: input, target, prediction, abs error.
idx = 650
img_test = cv2.imread('./data/CameraRGB/'+str(idx)+'.png')
# cv2 loads images as BGR while matplotlib expects RGB.
show_img(cv2.cvtColor(img_test, cv2.COLOR_BGR2RGB))
trg_test = train_dataset._fix_trg(cv2.imread('./data/CameraSeg/'+str(idx)+'.png'))
show_img(trg_test)

# Reproduce the training-time preprocessing: 4 mirrored rows at the top and
# bottom (as in LyftDataset), channels first, and values scaled to [0, 1]
# (as transforms.ToTensor does for uint8 input).
pad = 4
img_test = cv2.copyMakeBorder(img_test, pad, pad, 0, 0, cv2.BORDER_REFLECT)
img_test = np.moveaxis(img_test, -1, 0)
img_test = img_test[np.newaxis,:,:,:]
img_test = torch.from_numpy(img_test).to(device=device, dtype=torch.float32) / 255.0

# Evaluate without autograd and in eval mode so the forward pass does not
# touch BatchNorm running statistics; the previous mode is restored afterwards.
was_training = model.training
model.eval()
with torch.no_grad():
    pred = model(img_test)
model.train(was_training)
pred = pred.cpu().numpy()[0,:,:,:]

# Back to height x width x channels, dropping the padded rows so the
# prediction lines up with trg_test.
pred_img = np.moveaxis(pred, 0, -1)[pad:-pad]

show_img(pred_img)
show_img(np.abs(trg_test-pred_img))

In [None]:
%%timeit
img_test = cv2.imread('./data/CameraRGB/'+str(idx)+'.png')
img_test = np.moveaxis(img_test, -1, 0)
img_test = img_test[np.newaxis,:,:,:]
img_test = torch.from_numpy(img_test).type(torch.cuda.FloatTensor)
pred = model(img_test)
pred = pred.cpu().data[0,:,:,:].numpy()
pred_img = np.moveaxis(pred, 0, -1)