In [1]:
import sys
# Put the parent directory on the import path; the project-local imports in the
# next cell (data, models, metrics, utils) presumably resolve from there.
PROJECT_ROOT = '../'
sys.path.append(PROJECT_ROOT)

In [2]:
import os
import argparse
import yaml
import wandb
import random
import time

import math
import numpy as np
import torch
import torch.nn as nn

from data.dataset import SegmentDataset
from data.transforms import transform
from data.collate import collate
from torch.utils.data import DataLoader
from models.model import SegmentModel
from models.utils.loss import PixelLoss
from metrics.metrics import pixelAccuracy, gatherMetrics
from metrics.pred import predict, getMask
from utils.vis import showPredictions
from utils.decorators import timer
from utils.parameters import *
from metrics.metrics import bakeWeight
from metrics.utils import conf_operations

from models.unet import UNet

  import pandas.util.testing as tm


In [3]:
# Select the experiment; its name maps onto the config path, e.g.
# 'OF_street_v10' -> '../configs/OF/street/v10.yml'.
version = 'OF_street_v10'
cfg_path = '../configs/{}.yml'.format(version.replace('_', '/').replace('-', '/'))
# Context manager closes the file handle instead of leaking it until GC.
with open(cfg_path) as cfg_file:
    all_configs = yaml.safe_load(cfg_file)


# Seed every RNG source the notebook imports (stdlib, numpy, torch) so weight
# initialisation and any random augmentation are reproducible across runs.
random_seed = int(all_configs['random_seed'])
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

batch_size = int(all_configs['batch_size'])
num_classes = int(all_configs['num_classes'])
# Two classes: choose the label maps for the configured feature.
# NOTE(review): when num_classes != 2 this cell does not set index2name /
# color2index; presumably `from utils.parameters import *` provides them --
# confirm, otherwise the dataset setup below raises NameError.
if num_classes == 2:
    ftr = all_configs['ftr']
    if ftr.lower() == 'street':
        index2name = index2name_street
        color2index = color2index_street
    elif ftr.lower() == 'building':
        index2name = index2name_building
        color2index = color2index_building
    else:
        raise ValueError("Unknown feature found - {}".format(ftr))

n_epoch = int(all_configs['n_epoch'])
# Annotation paths get a '../' prefix, presumably because the config stores them
# relative to the repo root while this notebook runs one directory below it.
train_annot = '../' + all_configs['train_annot']
val_annot = '../' + all_configs['val_annot']

n_segment_layers = all_configs['n_segment_layers']
# Optional keys: None when absent from the config.
tail = all_configs.get('tail')
pretrained = all_configs.get('pretrained')

optimizer = all_configs['optimizer']
# float() guards against PyYAML loading exponent literals without a dot
# (e.g. 1e-4) as strings.
lr = float(all_configs['lr'])
weight_decay = float(all_configs['weight_decay'])
adam_eps = float(all_configs['adam_eps'])
amsgrad = all_configs['amsgrad']

# Keeps the config's 'CHCEKPOINT' spelling: the key lives in existing configs and
# the variable name may be used by later cells.
CHCEKPOINT_DIR = all_configs['CHCEKPOINT_DIR']
ckpt_dir = os.path.join(CHCEKPOINT_DIR, version)

vis_batch = all_configs.get('vis_batch')
metric_batch = all_configs.get('metric_batch')
use_augmentation = all_configs['use_augmentation']

# Optional loss settings forwarded to PixelLoss; `hnm` presumably configures
# hard-negative mining and `loss_weights` per-class weights -- confirm in PixelLoss.
loss_weights, hnm = None, None

if 'hnm' in all_configs:
    hnm = float(all_configs['hnm'])

if 'loss_weights' in all_configs:
    loss_weights = torch.tensor(all_configs['loss_weights'], dtype=torch.float32)

In [4]:
# Network: UNet on 3-channel input. Output channels follow the config's
# num_classes so they always agree with the PixelLoss below (was hard-coded to 2).
model = UNet(n_channels=3, n_classes=num_classes).cuda()

# Alternative backbone, kept for reference:
# model = SegmentModel(num_features=num_classes, n_layers=n_segment_layers).cuda()
criterion = PixelLoss(num_classes=num_classes, loss_weights=loss_weights, hnm=hnm)

# Read the optimizer name straight from the config rather than the `optimizer`
# global: that name is rebound to the optimizer object below. Otherwise,
# re-running this cell would skip construction and keep an optimizer tied to the
# previous model's parameters.
optimizer_name = all_configs['optimizer']
if optimizer_name == 'adam':
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr, weight_decay=weight_decay, eps=adam_eps, amsgrad=amsgrad
    )
else:
    # Fail fast instead of leaving a string that only breaks later in LambdaLR.
    raise ValueError("Unsupported optimizer - {}".format(optimizer_name))

# Optional exponential decay: LambdaLR sets lr = base_lr * sch_factor ** epoch.
scheduler = None
train_losses, val_losses = [], []
if 'scheduler' in all_configs:
    # float() for the same YAML-string reason as lr / weight_decay.
    sch_factor = float(all_configs['scheduler'])
    lr_lambda = lambda epoch: sch_factor**epoch
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

train_set = SegmentDataset(
    annot=train_annot,
    transform=transform,
    dim=(1024, 1024),
    c2i=color2index
)
# NOTE(review): shuffle is left at DataLoader's default (False); confirm this is
# intended for the training split (or that SegmentDataset shuffles itself).
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    num_workers=8,
    collate_fn=collate,
)

In [5]:
# Smoke test: run every training batch through the (untrained) model and print
# the prediction shape. Runs in eval mode under torch.no_grad(), so:
#   - no autograd graph is built for these 1024x1024 forwards;
#   - normalization-layer running statistics (if UNet has BatchNorm) are not updated.
# The model's previous train/eval mode is restored afterwards, even if the loop is
# interrupted. `img`, `mask` and `y_pred` stay bound to the last batch, as before.
device = next(model.parameters()).device  # follow the model instead of hard-coding .cuda()
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for (_, _, img, mask) in train_loader:
            y_pred = model(img.to(device)).cpu()
            print(y_pred.shape)
finally:
    model.train(was_training)

torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
torch.Size([1, 2, 1024, 1024])
