# Changes:
1. Use the 1st version UNET
2. Use focal loss function
3. Best Val Mean IOU so far: 0.7804 at epoch 209

## Install required packages if running on google colab

In [None]:
try:
    import torch
except:
    !pip install torch torchvision
    !pip install imageio
    !git clone https://github.com/allen-q/salt_oil.git
    !git clone https://github.com/allen-q/salt_net.git

In [None]:
cd salt_oil/

## Import required libs

In [None]:
import zipfile
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from skimage import io, transform
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as ply
import os
import imageio
from PIL import Image
import glob
import matplotlib.pyplot as plt
import time
import math
import datetime as dt
import pytz
import pickle
from salt_func_lib import *
from torchvision import transforms, utils
from skimage import io, transform
import datetime as dt
import sys
from optparse import OptionParser
import torch.backends.cudnn as cudnn
from torch import optim
from torch.optim import lr_scheduler
import copy
from io import BytesIO
% matplotlib inline
% load_ext autoreload
% autoreload 2

## Load Unet Modules

In [None]:
from pytorch_unet.eval import eval_net
from pytorch_unet.unet import UNet
from pytorch_unet.unet.unet_parts import *
from pytorch_unet.utils import get_ids, split_ids, split_train_val, get_imgs_and_masks, batch

In [None]:
## Setup data type based on whether GPU is enabled

In [None]:
# Use the CUDA float tensor type when a GPU is visible, the CPU one otherwise.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

In [None]:
print(f'Data Type set to: {dtype}')

## Create Global Variables

In [None]:
def init_global_variables():
    """Bind the module-level name ``log`` to the logger named 'SaltNet'."""
    global log
    logger_name = 'SaltNet'
    log = get_logger(logger_name)

In [None]:
init_global_variables()

## Prepare Data

### Load train and test data from npy files, or from raw images if the npy files do not exist.

In [None]:
np_train_all, np_train_all_mask, X_test, misc_data = load_all_data()

### Remove black images

In [None]:
#black_img_ids = (np_train_all.max((1,2,3))==0)

#np_train_all = np_train_all[~black_img_ids]
#np_train_all_mask = np_train_all_mask[~black_img_ids]

In [None]:
np_train_all.shape

### Remove images with all black masks

In [None]:
#black_mask_ids = (np_train_all_mask.max((1,2,3))==0)
#np_train_all = np_train_all[~black_mask_ids]
#np_train_all_mask = np_train_all_mask[~black_mask_ids]

### Train Val data split

In [None]:
np_train_all = np.clip(np_train_all/255, 0, 1)
X_test = np.clip(X_test/255, 0, 1)

In [None]:
X_train_ids, X_val_ids = (
    train_test_split(np.arange(len(np_train_all)), 
                     test_size=0.20, 
                     random_state=42)
)

In [None]:
# Slice images and masks with the positional indices produced by the split.
X_train, X_val = np_train_all[X_train_ids], np_train_all[X_val_ids]
y_train, y_val = np_train_all_mask[X_train_ids], np_train_all_mask[X_val_ids]

# Depth rows are looked up by image id, so positional indices are first
# translated into ids before reindexing the depth table.
_depth_table = misc_data['df_train_all_depth']
_train_all_ids = np.array(misc_data['np_train_all_ids'])
depth_train = _depth_table.reindex(_train_all_ids[X_train_ids])
depth_val = _depth_table.reindex(_train_all_ids[X_val_ids])
# NOTE(review): test depths are also read from 'df_train_all_depth'; confirm
# that this table contains the test ids, otherwise the reindex yields NaN rows.
depth_test = _depth_table.reindex(np.array(misc_data['np_test_ids']))

# Mean taken over axes 0, 1 and 2 of the training images.
X_train_mean_img = X_train.mean((0, 1, 2)).astype(np.float32)

# Bundle handed to train_model through its other_data argument.
all_data = dict(
    X_train=X_train,
    X_val=X_val,
    y_train=y_train,
    y_val=y_val,
    X_test=X_test,
    X_train_mean_img=X_train_mean_img,
)

In [None]:
X_train.shape

### Create a Train Dataloader

In [None]:
# Augmentations applied to the training samples only.
composed_tsfm = transforms.Compose([
    Rescale(scale='random', max_scale=2),
    RandomCrop(101),
    Flip(orient='random'),
])

data_params = dict(batch_size=32, shuffle=True, drop_last=False)

_train_dataset = SaltDataset(X_train, y_train, depth_train, X_train_mean_img,
                             out_size=128, out_ch=1, transform=composed_tsfm)
train_dataLoader = DataLoader(_train_dataset, **data_params)

_val_dataset = SaltDataset(X_val, y_val, depth_val, X_train_mean_img,
                           out_size=128, out_ch=1)
val_dataLoader = DataLoader(_val_dataset, **data_params)

dataloaders = {'train': train_dataLoader, 'val': val_dataLoader}

# Draw one training batch and verify the shape of each of its four elements.
sample = next(iter(dataloaders['train']))
_batch_size = data_params['batch_size']

assert sample[0].shape == torch.Size([_batch_size, 1, 128, 128])
assert sample[1].shape == torch.Size([_batch_size, 101, 101])
assert sample[2].shape == torch.Size([_batch_size])
assert sample[3].shape == torch.Size([_batch_size])

In [None]:
# Plot the first sample of the first three training batches next to its
# un-augmented source image and mask.
i = 0
for i, (X_batch, y_batch, d_batch, X_id) in enumerate(dataloaders['train'], start=1):
    if i > 3:
        break
    X_orig = X_train[X_id[0]].squeeze()
    # Crop rows/columns 13..113 of the network input and add the training mean back.
    _net_input = X_batch[0, 0].squeeze().cpu().detach().numpy()
    X_tsfm = _net_input[13:114, 13:114] + X_train_mean_img.squeeze()
    y_orig = y_train[X_id[0]].squeeze()
    y_tsfm = y_batch[0].squeeze().cpu().detach().numpy()
    _titles = [f'X Original-{X_id[0]}', 'X Transformed', 'y Original', 'y Transformed']
    plot_img_mask_pred([X_orig, X_tsfm, y_orig, y_tsfm], _titles)

In [None]:
X_train.shape

### Create a Train Dataloader for sanity check

In [None]:
# Tiny loaders (first four samples, no augmentation) for the overfit check.
# This rebinds data_params, train_dataLoader, val_dataLoader and dataloaders,
# so cells run afterwards see these small loaders until they are rebuilt.
data_params = dict(batch_size=2, shuffle=True, drop_last=False)

_sanity_train_set = SaltDataset(X_train[:4], y_train[:4], depth_train[:4],
                                X_train_mean_img, out_size=128, out_ch=1,
                                transform=None)
train_dataLoader = DataLoader(_sanity_train_set, **data_params)

_sanity_val_set = SaltDataset(X_val[:4], y_val[:4], depth_val[:4],
                              X_train_mean_img, out_size=128, out_ch=1)
val_dataLoader = DataLoader(_sanity_val_set, **data_params)

dataloaders = {'train': train_dataLoader, 'val': val_dataLoader}


In [None]:
t = iter(train_dataLoader).__next__()

In [None]:
X_batch, y_batch, d_batch, X_id = t

In [None]:
y_batch.view(y_batch.shape[0],-1).mean(1)==1

In [None]:
torch.tensor([0.4,0.6,1]).type(bool)

In [None]:
def train_model(model, dataloaders, criterion1, criterion2, optimizer, scheduler, model_save_name, other_data=None,
                num_epochs=25, print_every=2, save_model_every=None, save_log_every=None, log=None):
    """Train ``model`` and return it with the best-validation weights loaded.

    Every epoch runs a 'train' phase followed by a 'val' phase. The loss of a
    batch is ``criterion1`` on the per-pixel prediction plus ``criterion2`` on
    the whole-image prediction. The whole-image target is 1.0 for samples
    whose mask averages exactly 1 and 0.0 otherwise. Whenever the validation
    mean IOU improves, the weights are remembered and a checkpoint is written
    with ``save_model_state_to_chunks``.

    Args:
        model: Network called as ``model(X_batch)``; it must return the pair
            ``(y_pred, y_mask_pred)``.
        dataloaders: Mapping with 'train' and 'val' entries, each yielding
            ``(X_batch, y_batch, d_batch, X_id)`` tuples.
        criterion1: Loss called with ``(y_pred, y_batch.float())``.
        criterion2: Loss called with ``(y_mask_pred, whole_image_target)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch, after the
            training phase.
        model_save_name: Name forwarded to ``save_model_state_to_chunks`` and
            ``push_model_to_git``.
        other_data: Mapping providing 'X_train', 'X_val', 'y_train', 'y_val'
            and 'X_train_mean_img'; these are used for the diagnostic plots.
        num_epochs: Number of epochs to run (epochs are numbered from 1).
        print_every: Log and plot after every this many training iterations.
        save_model_every: When set, the best checkpoint found since the last
            push is re-saved and pushed every this many epochs and after the
            final epoch.
        save_log_every: When set, ``push_log_to_git`` is called every this
            many epochs.
        log: Logger to use; defaults to ``get_logger('SaltNet')``.

    Returns:
        The same ``model`` object, loaded with the weights that achieved the
        best validation mean IOU.

    Raises:
        KeyError: If ``other_data`` lacks one of the required entries.
    """
    # Resolve defaults at call time: a default evaluated in the signature
    # would run get_logger() once at definition time and share a dict.
    if log is None:
        log = get_logger('SaltNet')
    if other_data is None:
        other_data = {}

    log.info('Start Training...')
    start = time.time()

    if torch.cuda.is_available():
        model.cuda()
    # NOTE(review): batches are fed to the model exactly as the loaders yield
    # them; presumably the dataset already places them on the right device.

    best_model_wts = copy.deepcopy(model.state_dict())
    best_model = None  # pending checkpoint tuple not yet pushed to git
    best_iou = 0.0
    all_losses = []
    iter_count = 0
    X_train = other_data['X_train']
    X_val = other_data['X_val']
    y_train = other_data['y_train']
    y_val = other_data['y_val']
    X_train_mean_img = other_data['X_train_mean_img']

    for epoch in range(1, num_epochs + 1):
        log.info('Epoch {}/{}'.format(epoch, num_epochs))
        log.info('-' * 20)
        if save_log_every is not None and epoch % save_log_every == 0:
            push_log_to_git()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            epoch_loss = []
            epoch_mask_loss = []
            pred_vs_true_epoch = []

            for X_batch, y_batch, d_batch, X_id in dataloaders[phase]:
                optimizer.zero_grad()
                with torch.set_grad_enabled(is_train):
                    y_pred, y_mask_pred = model(X_batch)
                    # Keep a detached copy so the autograd graph of every
                    # batch is not held alive until the end of the epoch.
                    pred_vs_true_epoch.append([y_pred.detach(), y_batch])

                    y_true = y_batch.float()
                    loss_pix = criterion1(y_pred, y_true)
                    # Whole-image target: 1.0 when the mask mean is exactly 1.
                    y_batch2 = (y_true.view(y_true.shape[0], -1).mean(1) == 1).float()
                    loss_mask = criterion2(y_mask_pred, y_batch2)
                    print(f'Loss1: {loss_pix.item()}, Loss2:{loss_mask.item()}')
                    loss = loss_pix + loss_mask
                    all_losses.append(loss.item())
                    epoch_loss.append(loss.item())
                    epoch_mask_loss.append(loss_mask.item())

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()
                        iter_count += 1

                if is_train and iter_count % print_every == 0:
                    iou_batch = calc_mean_iou(y_pred.ge(0.5), y_true)
                    iou_acc = calc_clf_accuracy(y_pred.ge(0.5), y_true)

                    log.info('Batch Loss: {:.4f}, Epoch loss: {:.4f}, Epoch mask loss: {:.4f}, Batch IOU: {:.4f}, Batch Acc: {:.4f} at iter {}, epoch {}, Time: {}'.format(
                        np.mean(all_losses[-print_every:]), np.mean(epoch_loss), np.mean(epoch_mask_loss), iou_batch, iou_acc, iter_count, epoch, timeSince(start))
                    )
                    # Plot the first sample of the batch: source image, network
                    # input (cropped, mean restored), masks and prediction.
                    X_orig = X_train[X_id[0]].squeeze()
                    X_tsfm = X_batch[0, 0].squeeze().cpu().detach().numpy()
                    X_tsfm = transform.resize(X_tsfm, (128, 128), mode='constant', preserve_range=True)
                    X_tsfm = X_tsfm[13:114, 13:114] + X_train_mean_img.squeeze()

                    y_orig = y_train[X_id[0]].squeeze()
                    y_tsfm = y_batch[0].squeeze().cpu().detach().numpy()
                    y_tsfm_pred = y_pred[0].squeeze().gt(0.5)
                    plot_img_mask_pred([X_orig, X_tsfm, y_orig, y_tsfm, y_tsfm_pred],
                                       ['X Original', 'X Transformed', 'y Original', 'y Transformed', 'y Predicted'])

            if is_train:
                # Advance the schedule only after this epoch's optimizer
                # updates; stepping it first skips the initial learning rate
                # on current PyTorch releases.
                scheduler.step()

            y_pred_epoch = torch.cat([pair[0] for pair in pred_vs_true_epoch])
            y_true_epoch = torch.cat([pair[1] for pair in pred_vs_true_epoch])
            mean_iou_epoch = calc_mean_iou(y_pred_epoch.ge(0.5), y_true_epoch.float())
            mean_acc_epoch = calc_clf_accuracy(y_pred_epoch.ge(0.5), y_true_epoch.float())
            log.info('{} Mean IOU: {:.4f}, Mean Acc: {:.4f}, Best Val IOU: {:.4f} at epoch {}'.format(phase, mean_iou_epoch, mean_acc_epoch, best_iou, epoch))

            if phase == 'val' and mean_iou_epoch > best_iou:
                best_iou = mean_iou_epoch
                best_model_wts = copy.deepcopy(model.state_dict())
                stats = {'best_iou': best_iou,
                         'all_losses': all_losses,
                         'iter_count': iter_count}
                # Build the checkpoint arguments once and reuse them for the
                # immediate save and for the later git push.
                best_model = (epoch,
                              copy.deepcopy(model.state_dict()),
                              copy.deepcopy(optimizer.state_dict()),
                              copy.deepcopy(scheduler.state_dict()),
                              stats, model_save_name, '.')
                log.info(save_model_state_to_chunks(*best_model))
                log.info('Best Val Mean IOU so far: {}'.format(best_iou))
                # Visualize 1 val sample and predictions (from the last val batch)
                X_orig = X_val[X_id[0]].squeeze()
                y_orig = y_val[X_id[0]].squeeze()
                y_pred2 = y_pred[0].squeeze().gt(0.5)
                plot_img_mask_pred([X_orig, y_orig, y_pred2],
                                   ['Val X Original', 'Val y Original', 'Val y Predicted'])

        if save_model_every is not None:
            # Epochs are 1-based, so the final epoch is num_epochs itself.
            if epoch % save_model_every == 0 or epoch == num_epochs:
                if best_model is not None:
                    log.info(save_model_state_to_chunks(*best_model))
                    push_model_to_git(ckp_name=model_save_name)
                    best_model = None
                else:
                    log.info("Skip pushing model to git as there's no improvement")

    # load best model weights
    model.load_state_dict(best_model_wts)
    log.info('-' * 20)
    time_elapsed = time.time() - start
    log.info('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    log.info('Best val IOU: {:.4f}'.format(best_iou))

    return model

## Train the model using a small data set to see if it can overfit

In [None]:
saltnet = UNet(n_channels=1, n_classes=1, pred_overall_mask=True)

In [None]:
# Two focal losses: one for the per-pixel mask, one for the whole-image head.
loss_focal = FocalLoss(alpha=0.25, gamma=2)
loss_focal_mask = FocalLoss(alpha=0.25, gamma=3)

optimizer = optim.Adam(saltnet.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

# No checkpoint name is supplied for this run.
model_save_name = None

# Test Run
trained_model = train_model(
    saltnet,
    dataloaders,
    loss_focal,
    loss_focal_mask,
    optimizer,
    scheduler,
    model_save_name,
    other_data=all_data,
    num_epochs=50,
    print_every=2,
    save_model_every=None,
    save_log_every=None,
    log=log,
)

## Train the full model with the full dataset

In [None]:
saltnet = resnet34unet(in_ch=3, bilinear=False, pretrained=False)
saltnet.load_state_dict(torch.load('Unet_Resnet34_scale_1.2_baseline_2018_08_29_10_41_13.ckp'))

In [None]:
# train_model unpacks two outputs from the network (pixel mask and
# whole-image score), so the UNet is built with the second head enabled,
# exactly as in the overfit test.
saltnet = UNet(n_channels=1, n_classes=1, pred_overall_mask=True)

loss_fn_bce = nn.BCELoss()
loss_focal = FocalLoss(alpha=0.25, gamma=2)
# Loss for the whole-image head; same settings as the overfit test.
loss_focal_mask = FocalLoss(alpha=0.25, gamma=3)
optimizer = torch.optim.Adam(saltnet.parameters(), lr=0.01)

scheduler = lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.1)

model_save_name = f'Unet_unet_focal_loss_baseline_{get_current_time_as_fname()}.ckp'

# Full training run.
# The previous call omitted the second criterion and passed `push_every`,
# which train_model does not accept, so it failed with a TypeError.
# NOTE(review): `push_every=5` is mapped to `save_model_every=5` on the
# assumption that it was the old name of that option - confirm.
# NOTE(review): `dataloaders` is whatever was bound last; rebuild the
# full-size loaders first if the sanity-check loader cell was run.
trained_model = train_model(saltnet, dataloaders, loss_focal, loss_focal_mask, optimizer, scheduler, model_save_name,
                            other_data=all_data, num_epochs=250, print_every=50, save_model_every=5, log=log)

In [None]:
push_to_git??

## Load Trained Model

In [None]:
ls

In [None]:
loaded_model = trained_model

In [None]:
'''loaded_model = resnet18unet()
model_file_suffix = "Unet_Resnet34_scale_1.2_baseline_2018_08_29_10_41_13.ckp"
model_state_dict = torch.load(join_files(model_file_suffix, '.', returnFileObject=True, removeChunks=False))
loaded_model.load_state_dict(model_state_dict['model'])'''

In [None]:
loaded_model = resnet34unet(in_ch=3, bilinear=False, pretrained=False)
model_file_suffix = "Unet_Resnet34_scale_1.2_baseline_2018_08_29_10_41_13.ckp"
model_state_dict = torch.load(join_files(model_file_suffix, '.', returnFileObject=True, removeChunks=False))
loaded_model.load_state_dict(model_state_dict['model'])

In [None]:
torch.save(loaded_model.state_dict(), 'Unet_Resnet34_scale_1.2_baseline_2018_08_29_10_41_13.ckp')

In [None]:
    # SECURITY: the push URL previously embedded the account password in plain
    # text. The credential is now taken from the GITHUB_TOKEN environment
    # variable; the password that was committed should be treated as
    # compromised and rotated.
    github_token = os.environ.get('GITHUB_TOKEN')
    if not github_token:
        raise RuntimeError('Set the GITHUB_TOKEN environment variable before pushing to git.')
    log.info('Pushing model state to git.')
    get_ipython().system("git config user.email 'allen.qin.au@gmail.com'")
    get_ipython().system('git add .')
    get_ipython().system('git commit -m "Unet_Resnet34_scale_1.2_baseline_2018_08_29_10_41_13."')
    # NOTE: --all --force overwrites every remote branch; kept from the original cell.
    get_ipython().system(f'git push https://{github_token}@github.com/allen-q/salt_oil.git --all --force')

### Make Predictions on validation set

### Set model to evaluation mode

In [None]:
loaded_model.eval()
assert loaded_model.training == False

In [None]:
# Validation loader for inference, built without a `shuffle` argument.
# NOTE(review): this call passes `img_out_size=128`, whereas the loader cells
# earlier in the notebook pass `out_size=128, out_ch=1` to SaltDataset -
# confirm which keyword the current SaltDataset accepts.
val_dataLoader = DataLoader(SaltDataset(X_val, y_val, depth_val, X_train_mean_img, img_out_size=128), batch_size=16)

In [None]:
if torch.cuda.is_available():
    loaded_model.cuda()

In [None]:
# Run the loaded model over every validation batch without tracking
# gradients, then join the per-batch outputs into one tensor.
y_val_pred = []

with torch.no_grad():
    for X_val_batch, y_val_batch, depth_val_batch, X_val_id_batch in val_dataLoader:
        # NOTE(review): torch.cat below needs each output to be a single
        # tensor; train_model unpacks two outputs from its model - confirm
        # that loaded_model returns one tensor here.
        y_val_pred.append(loaded_model(X_val_batch))
# After the loop X_val_id_batch still holds the ids of the LAST batch only;
# a later cell samples from it.
y_val_pred = torch.cat(y_val_pred)

In [None]:
# Rebinds train_dataLoader to a loader over the whole training set (no
# transform or shuffle argument is passed) and collects the model output
# for every batch.
# NOTE(review): `img_out_size=128` differs from the `out_size=128, out_ch=1`
# keywords used by the loader cells earlier in the notebook - confirm.
train_dataLoader = DataLoader(SaltDataset(X_train, y_train, depth_train, X_train_mean_img, img_out_size=128), batch_size=16)
y_train_pred = []

with torch.no_grad():
    for X_train_batch, y_train_batch, depth_train_batch, X_train_id_batch in train_dataLoader:
        y_train_pred.append(loaded_model(X_train_batch))
# X_train_id_batch keeps the ids of the last batch after the loop; a later
# cell samples from it.
y_train_pred = torch.cat(y_train_pred)

In [None]:
# Show ten randomly chosen validation images with their true and predicted masks.
# X_val_id_batch is the variable left over from the prediction loop, so the
# ids are drawn only from the last validation batch, with replacement.
for i in range(10):
    rand_id = np.random.choice(X_val_id_batch)
    print(f'Image ID: {rand_id}')
    # presumably the ids are positional indices into X_val / y_val / y_val_pred - verify
    val_img = X_val[rand_id]
    val_mask = y_val[rand_id]
    # Predictions are thresholded at 0.5 before display.
    val_mask_pred = y_val_pred.ge(0.5)[rand_id]
    plot_img_mask_pred([val_img, val_mask, val_mask_pred], range(3), img_per_line=3)

In [None]:
for i in range(10):
    rand_id = np.random.choice(X_train_id_batch)
    print(f'Image ID: {rand_id}')
    img = X_train[rand_id]
    mask = y_train[rand_id]
    mask_pred = y_train_pred.ge(0.5)[rand_id]
    plot_img_mask_pred([img, mask, mask_pred], range(3), img_per_line=3)

In [None]:
ZERO_MASK_CUTOFF = 100

In [None]:
y_val_pred_adj = adjust_predictions(ZERO_MASK_CUTOFF, X_val, y_val_pred.gt(0.5), y_val.squeeze())

In [None]:
y_val_pred_adj = adjust_predictions(100, X_val, y_val_pred.gt(0.5), y_val.squeeze())

In [None]:
# Sweep the cutoff from 0 to 2990 in steps of 10 and record the validation
# mean IOU obtained with each value.
results = []
for cut_off in range(0, 3000, 10):
    print(cut_off)
    _adjusted = adjust_predictions(cut_off, X_val, y_val_pred.gt(0.5), y_val.squeeze())
    _score = calc_mean_iou(_adjusted, y_val.squeeze())
    results.append(_score)

In [None]:
ZERO_MASK_CUTOFF = range(0, 3000, 10)[np.argmax(results)]

In [None]:
y_val_pred_adj = adjust_predictions(ZERO_MASK_CUTOFF, X_val, y_val_pred.gt(0.5), y_val.squeeze())

In [None]:
y_train_pred_adj = adjust_predictions(ZERO_MASK_CUTOFF, X_train, y_train_pred.gt(0.5), y_train.squeeze())

## Make predictions on test set

In [None]:
#test_dataLoader = DataLoader(SaltDataset(np_test[:10], None, depth_test, X_train_mean_img), batch_size=4)
# Test loader; an all-zero array shaped like X_test is passed in the mask position.
# NOTE(review): `img_out_size=128` differs from the `out_size=128, out_ch=1`
# keywords used by the loader cells earlier in the notebook - confirm.
test_dataLoader = DataLoader(SaltDataset(X_test, np.zeros_like(X_test), depth_test, X_train_mean_img, img_out_size=128), batch_size=16)

In [None]:
# Collect the model output for every test batch without tracking gradients.
y_test_pred_raw = []
with torch.no_grad():
    for X_test_batch, y_test_batch, depth_test_batch, X_test_id_batch in test_dataLoader:
        # NOTE(review): torch.cat below needs each output to be a single
        # tensor - confirm loaded_model does not return a tuple.
        y_test_pred_raw.append(loaded_model(X_test_batch))
# X_test_id_batch keeps the ids of the last batch after the loop; a later
# cell samples from it.
y_test_pred = torch.cat(y_test_pred_raw)

### Show segmentation masks for a few images

In [None]:
for i in range(10):
    rand_id = np.random.choice(X_test_id_batch)
    print(f'Image ID: {rand_id}')
    img = X_test[rand_id]
    mask_pred = y_test_pred.ge(0.5)[rand_id]
    plot_img_mask_pred([img, mask_pred], range(2))

### Adjust predictions

In [None]:
y_test_pred_adj = adjust_predictions(ZERO_MASK_CUTOFF, X_test, y_test_pred.gt(0.5))

### Encode predictions using RLE(Run Length Encoding) method

In [None]:
y_test_pred_rle = rle_encoder3d(y_test_pred_adj)

In [None]:
y_test_pred_adj.shape

In [None]:
df_test_rle = pd.DataFrame(index=misc_data['np_test_ids'], data=y_test_pred_rle).reset_index()

In [None]:
df_test_rle.columns = ['id', 'rle_mask']

In [None]:
df_test_rle[df_test_rle.rle_mask==''].shape

In [None]:
df_test_rle.shape

In [None]:
df_test_rle.to_csv(f'submission_{get_current_time_as_fname()}.csv', index=False)

In [None]:
push_to_git()