In [1]:
import sys
sys.path.append('../')

import matplotlib.pyplot as plt

In [2]:
import os
import argparse
import yaml
import wandb
import random
import time

import numpy as np
import torch
import torch.nn as nn

from data.dataset import SegmentDataset
from data.transforms import transform
from data.collate import collate
from torch.utils.data import DataLoader
from models.model import SegmentModel
from models.utils.loss import PixelLoss
from metrics.metrics import pixelAccuracy, gatherMetrics
from metrics.pred import predict, getMask
from utils.vis import showPredictions
from utils.decorators import timer
from utils.parameters import *

  import pandas.util.testing as tm


In [5]:

# Experiment version; underscores map to sub-directories, so 'street_v4'
# loads ../configs/street/v4.yml.
version = 'street_v4'
cfg_path = '../configs/{}.yml'.format(version.replace('_', '/'))
# `with` closes the file handle as soon as it is parsed (a bare open() leaks it).
with open(cfg_path) as cfg_file:
    all_configs = yaml.safe_load(cfg_file)


random_seed = int(all_configs['random_seed'])
batch_size = int(all_configs['batch_size'])
num_classes = int(all_configs['num_classes'])
if num_classes == 2:
    # Binary setups choose the label maps for their single foreground feature.
    # For other class counts `index2name` / `color2index` must already exist —
    # presumably provided by `from utils.parameters import *`; confirm.
    ftr = all_configs['ftr']
    ftr_key = ftr.lower()
    if ftr_key == 'street':
        index2name = index2name_street
        color2index = color2index_street
    elif ftr_key == 'building':
        index2name = index2name_building
        color2index = color2index_building
    else:
        raise ValueError("Unknown feature found - {}".format(ftr))

n_epoch = int(all_configs['n_epoch'])
train_annot = all_configs['train_annot']
val_annot = all_configs['val_annot']
n_segment_layers = all_configs['n_segment_layers']
# Keep the configured name separate from the optimizer object built below.
optimizer_name = all_configs['optimizer']
lr = float(all_configs['lr'])
weight_decay = float(all_configs['weight_decay'])
adam_eps = float(all_configs['adam_eps'])
amsgrad = all_configs['amsgrad']
# The "CHCEKPOINT" spelling must match the key in the YAML config; do not
# rename it here without renaming the config key as well.
CHCEKPOINT_DIR = all_configs['CHCEKPOINT_DIR']
ckpt_dir = os.path.join(CHCEKPOINT_DIR, version)
use_augmentation = all_configs['use_augmentation']
# Optional per-class weights forwarded to the loss.
loss_weights = None
if 'loss_weights' in all_configs:
    loss_weights = torch.FloatTensor(all_configs['loss_weights'])

os.makedirs(ckpt_dir, exist_ok=True)


# Seed every RNG in use: Python, NumPy, torch CPU and the current CUDA device.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)


model = SegmentModel(num_features=num_classes, n_layers=n_segment_layers).cuda()
criterion = PixelLoss(num_classes=num_classes, loss_weights=loss_weights)

if optimizer_name == 'adam':
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr, weight_decay=weight_decay, eps=adam_eps, amsgrad=amsgrad
    )
else:
    # Fail fast: otherwise `optimizer` is not an optimizer and the scheduler
    # or train() would break later with a far less obvious error.
    raise ValueError("Unsupported optimizer - {}".format(optimizer_name))

scheduler = None
train_losses, val_losses = [], []
if 'scheduler' in all_configs:
    # LambdaLR scales the initial lr by lr_lambda(epoch): lr * sch_factor ** epoch.
    sch_factor = all_configs['scheduler']
    lr_lambda = lambda epoch: sch_factor**epoch
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# NOTE(review): the training loader is not created with shuffle=True — confirm
# this is intentional (e.g. stable batch indices while debugging).
train_set = SegmentDataset(
    annot='../'+train_annot,
    transform=transform,
    dim=(2048, 2048),
    c2i=color2index
)
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    num_workers=8,
    collate_fn=collate,
)

In [7]:
# Pick `n_batch` random batch indices whose predictions train() keeps for
# visualisation. Candidates are 0 .. len(train_loader) - 2, so the final batch
# index is never chosen (NOTE(review): presumably to skip a partial last batch —
# confirm).
n_batch = 2
pred_fig_indices = [batch_no for batch_no in range(len(train_loader) - 1)]
random.shuffle(pred_fig_indices)
del pred_fig_indices[n_batch:]

In [9]:

@timer
def train(epoch, loader, optimizer, metrics=None):
    """Run one training pass over `loader` and return every mask and raw output.

    Relies on the module-level `model`, `criterion` and `pred_fig_indices`,
    appends each batch loss to the module-level `train_losses` list, and prints
    a single-line text progress bar.

    Args:
        epoch: Epoch number. Not used by the active code path; kept for the
            call signature (the disabled metric/visualisation code below would
            be its consumer).
        loader: Sized iterable yielding `(_, _, image, mask)` batches; the first
            two items are ignored.
        optimizer: Torch optimizer over `model.parameters()`.
        metrics: Optional collection of metric names. Only `'pred'` is checked:
            it keeps the batches whose index is in `pred_fig_indices` for
            visualisation. Defaults to no metrics.

    Returns:
        `(masks, y_preds)`: all masks and all model outputs (detached, on CPU),
        each concatenated along dim 0. The caller applies softmax, so the
        outputs are presumably logits.

    Raises:
        ValueError: if `loader` is empty.
    """
    if metrics is None:
        metrics = []
    n = len(loader)
    if n == 0:
        raise ValueError('train() received an empty loader')

    tot_loss = 0.0
    masks, y_preds = [], []
    save_vis = 'pred' in metrics
    vis_img, vis_mask, vis_y_pred = [], [], []

    model.train()
    for batch_idx, (_, _, image, mask) in enumerate(loader):
        # Clear gradients from the previous batch; without this backward()
        # accumulates into .grad and every step uses the sum of all earlier
        # batches' gradients.
        optimizer.zero_grad()
        y_pred = model(image.cuda())
        image = image.detach().cpu()
        loss = criterion(y_pred, mask.cuda())
        loss.backward()
        optimizer.step()

        # Read the loss once: each .item() forces a device -> host sync.
        loss_val = loss.item()
        tot_loss += loss_val
        train_losses.append(loss_val)

        y_pred = y_pred.detach().cpu()
        y_preds.append(y_pred)
        masks.append(mask)

        if save_vis and batch_idx in pred_fig_indices:
            vis_img.append(image)
            vis_mask.append(mask)
            vis_y_pred.append(y_pred)

        n_arr = (50 * (batch_idx + 1)) // n
        progress = 'Training : [{}>{}] ({}/{}) loss : {:.4f}, avg_loss : {:.4f}'.format(
            '=' * n_arr, '-' * (50 - n_arr), (batch_idx + 1), n, loss_val, tot_loss / (batch_idx + 1))
        print(progress, end='\r')

    print("\n")

    masks = torch.cat(masks, dim=0)
    y_preds = torch.cat(y_preds, dim=0)
    return masks, y_preds
    # Disabled metric / visualisation path; re-enable (and return `logg`
    # instead) once the raw-output debugging is done.
#     logg = {
#         'training_loss': tot_loss/n,
#     }
#     logg_metrics = gatherMetrics(
#         params=(masks, y_preds),
#         metrics=metrics,
#         mode='train',
#         i2n=index2name,
#     )
#     logg.update(logg_metrics)

#     # Visualizations
#     if 'pred' in metrics:
#         vis_img = torch.cat(vis_img, dim=0)
#         vis_mask = torch.cat(vis_mask, dim=0)
#         vis_y_pred = torch.cat(vis_y_pred, dim=0)
#         vis_mask_pred = predict(None, None, use_cache=True, params=(vis_y_pred, False))
#         pred_fig = showPredictions(
#             vis_img, vis_mask, vis_mask_pred, 
#             use_path=False, ret='fig', debug=False, size='auto',
#             getMatch=True,
#         )
#         logg.update({'train_prediction': wandb.Image(pred_fig)})

#     return logg



In [37]:
# One training pass (it calls optimizer.step(), so the model's weights change)
# that also returns every mask and raw output; convert the outputs to
# per-class probabilities over dim 1.
masks, y_preds = train(1, train_loader, optimizer)
y_preds = y_preds.softmax(dim=1)


Time : 33.14039134979248 seconds


In [35]:
# Inspect the first sample: over every pixel whose ground-truth label is 0
# (background), print the mean softmax score of channel 0 and of channel 1.
# Assumes a binary model where channel 0 = background, channel 1 = street —
# TODO confirm against index2name.
mask = masks[0]
y_prob = y_preds[0]

is_bg = mask == 0
y_prob_bg = y_prob[0][is_bg].numpy()
y_prob_street = y_prob[1][is_bg].numpy()

print(np.mean(y_prob_bg))
print(np.mean(y_prob_street))

0.50433445
0.49566555


In [36]:
# Same check restricted to pixels whose ground-truth label is 1 (street):
# mean softmax score of channel 0 and of channel 1 at those pixels.
# Assumes channel 0 = background, channel 1 = street — TODO confirm.
y_prob_bg, y_prob_street = (y_prob[channel][mask == 1].numpy() for channel in (0, 1))

print(np.mean(y_prob_bg), np.mean(y_prob_street), sep='\n')


0.5011159
0.49888414
