In [None]:
%reload_ext autoreload
%autoreload 2

import os
import re
import warnings

import cv2
import numpy as np

import argus
from argus import Model
from argus import load_model
from argus.engine import State
from argus.callbacks import MonitorCheckpoint, EarlyStopping, LoggingToFile

from src.utils import rle_decode

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, DataLoader

from src.models.unet_flex import UNetFlexProb
from src.losses import ShipLoss
from src.metrics import ShipIOUT
from src.utils import  filename_without_ext
from src.transforms import ProbOutputTransform, test_transforms, train_transforms
from src.dataset import ShipDataset, ShipDatasetFolds
from src.lr_scheduler import ReduceLROnPlateau
from src.utils import get_best_model_path

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Input geometry and on-disk file extensions of the dataset.
IMG_SIZE = (256, 256)
IMG_EXT, TRG_EXT = '.jpg', '.png'

# Optimisation settings.
LR = 1e-5
BATCH_SIZE = 32
EPOCHS = 250  # upper bound; the EarlyStopping callback may end training sooner
TRAIN_SPLIT = 0.8  # fraction of ids used for training, the remainder is validation

# Destination for checkpoints and the training log (see the callbacks cell).
save_path = '../data/models/linknet34_004b'

In [None]:
imgs_dir = '../data/datasets/ships_small/train_small/images/'
trgs_dir = '../data/datasets/ships_small/train_small/targets/'

# Ids are taken from the targets directory, which only holds non-empty samples.
# os.listdir returns entries in arbitrary, filesystem-dependent order. Sort them
# so the first-N train/val split built from this list is reproducible.
imgs = sorted(os.listdir(trgs_dir))
img_ids = [filename_without_ext(img) for img in imgs]

print("Images:", len(img_ids))

In [None]:
# Probability forwarded to train_transforms as skip_empty_prob. Judging by the name,
# it controls how often empty samples are skipped; TODO confirm in src.transforms.
SKIP_EMPTY_PROB = 0.9

# Both pipelines are unpacked with ** into the dataset constructors, so each is a
# mapping of keyword arguments.
# NOTE(review): sigma_g=10 is an undocumented magic value; record what it controls.
train_trns = train_transforms(size=IMG_SIZE, skip_empty_prob=SKIP_EMPTY_PROB, sigma_g=10)
# Pipeline used for every validation dataset in this notebook.
val_trns = test_transforms(size=IMG_SIZE)

In [None]:
def get_data_loaders(batch_size, ids, train_split=None, num_workers=8):
    """Build train/validation DataLoaders from a list of image ids.

    The first ``round(len(ids) * train_split)`` ids go to training and the rest
    to validation. The order of ``ids`` is kept, so pass a deterministically
    ordered (or pre-shuffled) list to get a reproducible split.

    Args:
        batch_size: batch size for both loaders.
        ids: image ids passed to ``ShipDataset`` (resolved against the global
            ``imgs_dir`` / ``trgs_dir``).
        train_split: fraction of ``ids`` used for training. ``None`` means the
            global ``TRAIN_SPLIT``, read at call time.
        num_workers: worker processes for each DataLoader.

    Returns:
        ``(train_loader, val_loader)``. Only the training loader shuffles.
    """
    if train_split is None:
        train_split = TRAIN_SPLIT
    n_train = round(len(ids) * train_split)
    train_ids, val_ids = ids[:n_train], ids[n_train:]

    train_dataset = ShipDataset(train_ids, imgs_dir=imgs_dir, trgs_dir=trgs_dir, masks=True, **train_trns)
    val_dataset = ShipDataset(val_ids, imgs_dir=imgs_dir, trgs_dir=trgs_dir, masks=True, **val_trns)
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=num_workers)
    return train_loader, val_loader

# Loaders over the ids listed above; the split fraction comes from TRAIN_SPLIT.
train_loader, val_loader = get_data_loaders(ids=img_ids, batch_size=BATCH_SIZE)

In [None]:
def show_img(img):
    """Draw ``img`` on a fresh 200-dpi figure and display it immediately."""
    _, axis = plt.subplots(dpi=200)
    axis.imshow(img)
    plt.show()

def show_img_tensor(tensor):
    """Display a channel-first image tensor with ``show_img``.

    The tensor is detached and copied to the CPU first, so this also works on
    model outputs (tensors that require grad or live on the GPU). Plain
    ``.numpy()`` raises for those.

    The channel axis is moved last and its order is reversed before plotting.
    Presumably this converts cv2's BGR layout to the RGB matplotlib expects;
    TODO confirm against the dataset loader.
    """
    img = tensor.detach().cpu().numpy()
    img = np.moveaxis(img, 0, -1)[:, :, ::-1]
    show_img(img)


def show_in_cols(masks_list, n_col=3):
    """Plot a sequence of 2-D arrays as a grid of panels, ``n_col`` per row.

    Panels are filled row by row. Leftover cells in the last row stay blank.
    Each row gets 6 inches of height, so a single row keeps the original 18x6
    figure and multi-row grids are no longer squashed.

    ``squeeze=False`` makes ``plt.subplots`` always return a 2-D array of axes.
    Without it, a single row or column comes back 1-D and ``ax[row][col]``
    indexing fails.

    Args:
        masks_list: sequence of arrays accepted by ``Axes.imshow``.
        n_col: number of panels per row.
    """
    n_masks = len(masks_list)
    if n_masks == 0:
        # Nothing to draw, and a zero-row subplot grid is not valid.
        return
    n_row = (n_masks + n_col - 1) // n_col  # ceiling division

    _, axes = plt.subplots(n_row, n_col, figsize=(18, 6 * n_row), squeeze=False)
    for panel, mask in zip(axes.flat, masks_list):
        panel.imshow(mask)
    # Hide ticks and frames on every cell, including unused trailing ones.
    for panel in axes.flat:
        panel.axis('off')


def show_trg_tensor(tensor):
    """Show each slice along the first axis of a 3-D target tensor as its own panel.

    The tensor is detached and copied to the CPU first, so GPU or grad-tracking
    tensors (e.g. model predictions) can be displayed too.
    """
    masks = tensor.detach().cpu().numpy()
    # Iterating an ndarray yields its sub-arrays along axis 0, i.e. masks[i, :, :].
    show_in_cols(list(masks))

In [None]:
# Smoke test: run one full pass over the training loader so every sample goes
# through the dataset pipeline once. Then summarise the distinct batch shapes
# instead of printing one line per batch.
batch_shapes = {}
for img, trg in train_loader:
    key = (tuple(img.shape), tuple(trg.shape))
    batch_shapes[key] = batch_shapes.get(key, 0) + 1

for (img_shape, trg_shape), count in batch_shapes.items():
    print(f"{count} batch(es): img {img_shape}, trg {trg_shape}")

In [None]:
n_images_to_draw = 2

# Plot a few samples from the first training batch. A batch can be smaller than
# n_images_to_draw (e.g. the last one), so clamp the count.
img, trg = next(iter(train_loader))
for i in range(min(n_images_to_draw, img.shape[0])):
    img_i = img[i, ...]
    trg_i = trg[i, ...]
    # Target shape and the number of non-zero pixels in its first channel.
    # NOTE(review): the meaning of channel 0 depends on ShipDataset's target layout.
    print(trg_i.shape, np.count_nonzero(trg_i[0, ...].detach().numpy()))
    show_img_tensor(img_i)
    show_trg_tensor(trg_i)

In [None]:
class ShipMetaModel(Model):
    """argus Model that maps the component names used in ``params`` to classes.

    Each attribute is a registry from the string given in ``params`` (e.g.
    ``'UNetFlexProb'``) to the class argus looks up for that role.
    """
    nn_module = dict(UNetFlexProb=UNetFlexProb)
    loss = dict(ShipLoss=ShipLoss)
    prediction_transform = dict(ProbOutputTransform=ProbOutputTransform)

In [None]:
from src.models.resnet_blocks import resnet34

params = {
    'nn_module': ('UNetFlexProb', dict(
        num_classes=5,
        num_channels=3,
        blocks=resnet34,
        final='sigmoid',
        skip_dropout=True,
        dropout_2d=0.2,
        is_deconv=True,
        pretrain='resnet34',
        pretrain_layers=[True] * 5,
    )),
    'loss': ('ShipLoss', dict(
        fb_weight=0.25,  # Need tuning!
        fb_beta=1,
        bce_weight=0.25,
        prob_weight=0.25,
        mse_weight=1.0,
    )),
    'prediction_transform': ('ProbOutputTransform', dict(
        segm_thresh=0.5,
        prob_thresh=0.5,
    )),
    'optimizer': ('Adam', dict(lr=LR)),
    'device': 'cuda',
}

model = ShipMetaModel(params)

# Checkpointing, early stopping and LR decay all monitor 'val_iout';
# the last callback writes the training log next to the checkpoints.
callbacks = [
    MonitorCheckpoint(save_path, monitor='val_iout', max_saves=2, copy_last=True),
    EarlyStopping(monitor='val_iout', patience=60),
    ReduceLROnPlateau(monitor='val_iout', patience=10, factor=0.5, min_lr=1e-8),
    LoggingToFile(os.path.join(save_path, 'log.txt')),
]

In [None]:
# Optional warm start. Set this to a checkpoint directory to replace `model` with
# the checkpoint get_best_model_path picks there; the loaded model fully replaces
# the one built from `params`. A previously used value was
# '../../data/models/linknet18_001/'. NOTE(review): other paths in this notebook
# are relative to '../data', so that one probably needs adjusting.
# None keeps the model built above.
PRETRAIN_PATH = None

if PRETRAIN_PATH is not None:
    model = load_model(get_best_model_path(PRETRAIN_PATH))

In [None]:
# Train. Each epoch is validated on val_loader with the 'iout' metric, which the
# callbacks monitor as 'val_iout'.
model.fit(
    train_loader,
    val_loader=val_loader,
    max_epochs=EPOCHS,
    callbacks=callbacks,
    metrics=['iout'],
)

## Validate

In [None]:
# Restrict this process to GPU 0, with GPUs numbered by PCI bus order.
# CUDA reads these variables only when it initialises, so setting them after
# CUDA has been used in this kernel (e.g. by the training cells) has no effect.
# For them to apply: restart the kernel, run the imports cell, then this cell.
if torch.cuda.is_initialized():
    warnings.warn(
        "CUDA is already initialized in this kernel; CUDA_DEVICE_ORDER / "
        "CUDA_VISIBLE_DEVICES changes will not take effect until the kernel is restarted."
    )
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# Checkpoint directory of the model to evaluate. It uses its own name so the
# training cell's `save_path` is untouched: re-running training after this cell
# will not write checkpoints into this fold's directory.
eval_model_dir = '/workdir/data/models/linknet34_folds_007/fold_0/'
model = load_model(get_best_model_path(eval_model_dir))
# NOTE(review): the training params use prob_thresh=0.5; record why evaluation
# uses the much lower 0.001.
model.prediction_transform.prob_thresh = 0.001
model.nn_module.eval();  # trailing ';' suppresses the large module repr output

In [None]:
# Evaluate the loaded model on the fold(s) in `val_folds`, as defined by the
# file at `folds_path` (presumably a pickled k-fold split; TODO confirm).
# NOTE(review): imgs_dir/trgs_dir still point at the train_small subset;
# confirm the folds file refers to the same images.
folds_path = '/workdir/data/kfolds.pk'
val_folds = [0]
# Distinct names, so this cell does not overwrite the training `val_dataset` /
# `val_loader` that `model.fit` would use if the training cells are re-run.
fold_val_dataset = ShipDatasetFolds(folds_path, val_folds, imgs_dir=imgs_dir, trgs_dir=trgs_dir, masks=True, **val_trns)

fold_val_loader = DataLoader(fold_val_dataset, batch_size=4,
                             shuffle=False, num_workers=16)

model.validate(fold_val_loader, metrics=['iout'])

Recorded validation results per fold:

- `linknet34_folds_006/fold_0/`: val_train_iout 0.864456, val_loss 0.702380567754134, val_iout 0.4743406742794948
- `linknet34_folds_006/fold_1/`: val_train_iout 0.873608, val_loss 0.7951664915043155, val_iout 0.5019373756414277
- `linknet34_folds_006/fold_2/`: val_train_iout 0.860670, val_loss 0.7101081279212463, val_iout 0.49924292177466073
- `linknet34_folds_006/fold_3/`: val_train_iout 0.840669, val_loss 0.5926205536820858, val_iout 0.526805161714539
- `linknet34_folds_006/fold_4/`: val_train_iout 0.827383, val_loss 0.5521928212950905, val_iout 0.5329581093189941

- `linknet34_folds_008/fold_0`: val_loss 0.1019397437081371, val_iout 0.5002803833605218