In [None]:
import os
import pickle
import pandas as pd
from pathlib import Path
from pku_autonomous_driving import io, util, dataset, resnet, centernet, training, graphics, transform, const

import importlib

# Re-execute each project module in place so edits to the package's .py files take effect
# without restarting the kernel. The order is kept identical to the original explicit calls.
for _project_module in (io, util, dataset, resnet, centernet, training, graphics, transform, const):
    importlib.reload(_project_module)

In [None]:
import torch

# Use the first CUDA device when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
import torchvision
from pku_autonomous_driving.transform import CropBottomHalf, CropFar, PadByMean, Resize, Normalize, DropPointsAtOutOfScreen, CreateMask, CreateRegr, ToCHWOrder
from pku_autonomous_driving.const import IMG_WIDTH, IMG_HEIGHT, MODEL_SCALE


def _target_steps():
    """Return fresh instances of the steps shared by every pipeline.

    Both pipelines end identically: drop points outside the screen, build the mask and
    regression targets, then reorder to CHW. Building new instances per call keeps the two
    pipelines from sharing transform objects.
    """
    return [
        DropPointsAtOutOfScreen(IMG_WIDTH, IMG_HEIGHT),
        CreateMask(IMG_WIDTH, IMG_HEIGHT, MODEL_SCALE),
        CreateRegr(IMG_WIDTH, IMG_HEIGHT, MODEL_SCALE),
        ToCHWOrder(),
    ]


# 'NEAR': crop the bottom half, pad, and resize to the model input size before normalizing.
near_transform = torchvision.transforms.Compose([
    CropBottomHalf(),
    PadByMean(),
    Resize(IMG_WIDTH, IMG_HEIGHT),
    Normalize(),
    *_target_steps(),
])

# 'FAR': a CropFar window of the model input size, then normalize (no pad/resize step).
far_transform = torchvision.transforms.Compose([
    CropFar(IMG_WIDTH, IMG_HEIGHT),
    Normalize(),
    *_target_steps(),
])

transforms = {
    'NEAR': near_transform,
    'FAR': far_transform
}

# Pipeline selected via the TRANSFORM_TYPE environment variable (default 'NEAR').
TRANSFORM_TYPE = os.environ.get("TRANSFORM_TYPE", "NEAR")
if TRANSFORM_TYPE not in transforms:
    raise ValueError(
        f"TRANSFORM_TYPE must be one of {sorted(transforms)}, got {TRANSFORM_TYPE!r}"
    )
train_transform = transforms[TRANSFORM_TYPE]

In [None]:
from pku_autonomous_driving.dataset import CarDataset, create_data_loader

# Quick-iteration subset size for both splits (default 3, matching the previous hard-coded
# debug slice). Set N_SAMPLES to 0 or a negative number to train/evaluate on the full splits.
N_SAMPLES = int(os.environ.get("N_SAMPLES", 3))

train, dev = io.load_train_data()
if N_SAMPLES > 0:
    train = train[:N_SAMPLES]
    dev = dev[:N_SAMPLES]

# The dev split reuses train_transform; no separate evaluation pipeline is defined here.
train_dataset = CarDataset(train, transform=train_transform)
dev_dataset = CarDataset(dev, transform=train_transform)

train_loader = create_data_loader(train_dataset, batch_size=1)
dev_loader = create_data_loader(dev_dataset, batch_size=1)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Sanity-check one transformed training example (input, mask target, one regression channel).
data = train_loader.dataset[0]
img, mask, regr = data["img"], data["mask"], data["regr"]

plt.figure(figsize=(16,16))
plt.title('Input image')
# rollaxis moves axis 0 to the end: channels-first (as produced by ToCHWOrder) -> channels-last for imshow.
plt.imshow(np.rollaxis(img, 0, 3))
plt.show()

plt.figure(figsize=(16,16))
plt.title('Ground truth mask')
plt.imshow(mask)
plt.show()

plt.figure(figsize=(16,16))
# NOTE(review): meaning of regression channel -2 is defined by CreateRegr — TODO label it precisely.
plt.title('Regression target (channel -2)')
plt.imshow(regr[-2])
plt.show()

In [None]:
# Backbone without pretrained weights, wrapped by the CenterNet-style head.
# NOTE(review): 8 is CentResnet's second argument — presumably the number of output channels; TODO confirm.
base_model = resnet.resnext50_32x4d(pretrained=False)
model = centernet.CentResnet(base_model, 8)

setup_kwargs = {"model": model, "device": device}
if 'INITIAL_WEIGHTS' in os.environ:
    setup_kwargs["path"] = Path(os.environ["INITIAL_WEIGHTS"])
util.setup_model(**setup_kwargs)
# Show only the configuration; printing setup_kwargs directly would dump the full model repr.
print({key: value for key, value in setup_kwargs.items() if key != "model"})

In [None]:
import math

n_epochs = int(os.environ.get("N_EPOCHS", 6))

# Resume from a pickled history DataFrame when INITIAL_HISTORY is set; otherwise start fresh.
# Previously a bare `except:` hid a NameError (math was never imported), so resuming never worked.
# NOTE(review): pickle.load can execute arbitrary code — only point INITIAL_HISTORY at trusted files.
initial_history_path = os.environ.get("INITIAL_HISTORY")
if initial_history_path:
    with Path(initial_history_path).open('rb') as fp:
        history = pickle.load(fp)
    # The index is presumably epoch-valued (ceil suggests fractional, per-batch entries);
    # resume at the next whole epoch. An empty history starts from epoch 0.
    beg_epoch = math.ceil(history.index[-1]) if len(history.index) else 0
else:
    history = pd.DataFrame()
    beg_epoch = 0
end_epoch = beg_epoch + n_epochs

In [None]:
%%time
from torch import optim
from torch.optim import lr_scheduler
import pandas as pd
import pickle


optimizer = optim.AdamW(model.parameters(), lr=0.001)
#optimizer =  RAdam(model.parameters(), lr = 0.001)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=max(n_epochs, 10) * len(train_loader) // 3, gamma=0.1)

best_dev_loss = np.inf
for epoch in range(beg_epoch, end_epoch):
    training.clean_up()
    training.train(model, optimizer, exp_lr_scheduler, train_loader, epoch, device, history)
    training.evaluate(model, dev_loader, epoch, device, history)
    
    with open('./history.pickle', 'wb') as fp:
        pickle.dump(history , fp)
        
    cur_dev_loss = history['dev_loss'].dropna().iloc[-1]
    if cur_dev_loss < best_dev_loss:
        torch.save(model.state_dict(), './resnext50.pth')
        best_dev_loss = cur_dev_loss
    torch.save(model.state_dict(), f'./resnext50_{epoch}.pth')

In [None]:
train_loss_series = history['train_loss']
train_loss_series.plot();

In [None]:
# Only rows where every column is populated (presumably the per-epoch rows that carry dev_loss — TODO confirm).
complete_rows = history.dropna()
# regr_loss is drawn scaled up, presumably to put it on a comparable visual scale to the others.
REGR_LOSS_PLOT_SCALE = 30

plt.plot(complete_rows.index, complete_rows['mask_loss'], label='mask loss')
plt.plot(complete_rows.index, REGR_LOSS_PLOT_SCALE * complete_rows['regr_loss'],
         label=f'regr loss (x{REGR_LOSS_PLOT_SCALE})')
plt.plot(complete_rows.index, complete_rows['dev_loss'], label='dev loss')
plt.title('Loss components')
# Without legend() the labels passed above are never rendered.
plt.legend()
plt.show()

In [None]:
complete_history = history.dropna()
series = complete_history['dev_loss']
plt.scatter(series.index, series);

In [None]:
# Compare the model's raw output on one training example against its ground-truth mask.
data = train_loader.dataset[0]
img, mask, regr = data["img"], data["mask"], data["regr"]

plt.figure(figsize=(16,16))
plt.title('Input image')
plt.imshow(np.rollaxis(img, 0, 3))
plt.show()

plt.figure(figsize=(16,16))
plt.title('Ground truth mask')
plt.imshow(mask)
plt.show()

# Inference only: eval mode so normalization/dropout layers (if any) use inference behavior,
# and no_grad so no autograd graph is built for this forward pass.
model.eval()
with torch.no_grad():
    output = model(torch.tensor(img[None]).to(device))
# Channel 0 of the single batch item — presumably the mask/heatmap logits; TODO confirm against centernet.
logits = output[0, 0].cpu().numpy()

plt.figure(figsize=(16,16))
plt.title('Model predictions')
plt.imshow(logits)
plt.show()

plt.figure(figsize=(16,16))
plt.title('Model predictions thresholded')
plt.imshow(logits > 0)
plt.show()

In [None]:
import gc
gc.collect()

N_EXAMPLES = 4
# The dev split may hold fewer items than N_EXAMPLES (it can be truncated upstream);
# bounding by its length avoids an IndexError on dev_loader.dataset[idx].
n_shown = min(N_EXAMPLES, len(dev_loader.dataset))

model.eval()
for idx in range(n_shown):
    data = dev_loader.dataset[idx]
    img, mask, regr = data["img"], data["mask"], data["regr"]
    with torch.no_grad():
        output = model(torch.tensor(img[None]).to(device)).cpu().numpy()

    coords_pred = util.extract_coords(data, output[0])
    coords_true = util.extract_coords(data)

    # Draw on the original image loaded from disk, not on the transformed model input.
    raw_img = io.load_image(dev_loader.dataset.dataset[idx].image_id)
    fig, axes = plt.subplots(1, 2, figsize=(30,30))
    axes[0].set_title('Ground truth')
    axes[0].imshow(graphics.draw_coords(raw_img, coords_true))
    axes[1].set_title('Prediction')
    axes[1].imshow(graphics.draw_coords(raw_img, coords_pred))
    plt.show()
    # Release the figure so repeated iterations do not accumulate open figures.
    plt.close(fig)