## Setup environment

In [2]:
import torch
import numpy as np
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
import torch.nn
import torch.nn.functional as F
import argparse
import os


# from models.multi_stage_sequenceencoder import multistageSTARSequentialEncoder, multistageLSTMSequentialEncoder
# from models.networkConvRef import model_2DConv
from eval import evaluate_fieldwise

## Import required project modules

In [3]:
from TransformerEncoder import TransformerEncoder
from dataset import Dataset
from utils.scheduled_optimizer import ScheduledOptim
from utils.classmetric import ClassMetric

## Input Parameters

In [7]:
# Run configuration: every value a reader may want to tune lives in this cell.
# NOTE(review): the paths below are absolute paths on one machine; adjust them
# (or derive them from a configurable base directory) before running elsewhere.
data_path   = "/Users/stevenzhu/Downloads/ZueriCrop.hdf5"   # path to dataset
gt_path     = '/Users/stevenzhu/Downloads/labels.csv'       # gt file path
batchsize   = 4                                             # batch size
workers     = 8                                             # passed to DataLoader as num_workers
epochs      = 30                                            # epochs to train
lr          = 0.001                                         # learning rate
# NOTE(review): a function called `snapshot` is defined in a later cell and
# rebinds this name; re-running this cell afterwards hides that function.
snapshot    = None                                          # load weights from snapshot
checkpoint_dir = '/Users/stevenzhu/Downloads/'              # directory to save checkpoints
name        = "debug"                                       # run name, prefix of the checkpoint file names
weight_decay= 0.0001                                        # weight_decay
hidden      = 64                                            # hidden dim (used as d_word_vec and d_model)
layer       = 6                                             # num layer
lrSC        = 2                                             # lrScheduler (the scheduler code is currently commented out)
dropout     = 0.5                                           # dropout passed to the TransformerEncoder
stage       = 3                                             # num stage
clip        = 5                                             # grad clip
seed        = 0                                             # random seed
fold        = 1                                             # 5 fold
cell        = "star"                                        # Cell type: main building block
input_dim   = 4                                             # Input channel size
apply_cm    = False                                         # apply cloud masking
n_layers    = 6                                             # number of layers
d_inner     = hidden*4                                      # d_inner of the TransformerEncoder
n_heads     = 8                                             # number of attention heads; d_k = d_v = hidden // n_heads
fold_num    = None                                          # fold passed to Dataset
test_every_n_epochs = 1                                     # get test every n epochs

## main

In [None]:
# Training script: builds the datasets, the TransformerEncoder, the optimizer and
# runs the train / evaluate / checkpoint loop. All settings come from the
# "Input Parameters" cell (the old commented-out `main(...)` signature that
# duplicated them has been removed).
#
# NOTE(review): train_epoch, test_epoch and snapshot are defined in cells further
# down this notebook, so those cells have to be executed before this one.

# Apply the configured seed so that weight init and shuffling are repeatable.
torch.manual_seed(seed)
np.random.seed(seed)

# checkpoint_dir may be None (checkpointing disabled, see the loop below).
if checkpoint_dir is not None:
    os.makedirs(checkpoint_dir, exist_ok=True)

# Checkpoint file prefix: use `name` when the notebook defines it, otherwise
# fall back to 'debug' (the default of the former main() signature).
run_name = globals().get("name", "debug")

# Prepare dataset

traindataset    = Dataset(data_path, 0., 'train', False, fold_num, gt_path, num_channel=input_dim, apply_cloud_masking=apply_cm)
testdataset     = Dataset(data_path, 0., 'test', True, fold_num, gt_path, num_channel=input_dim, apply_cloud_masking=apply_cm)

# number of classes
nclasses     = traindataset.n_classes
len_max_seq  = traindataset.max_obs

# Loss weight: class 0 gets zero weight
LOSS_WEIGHT     = torch.ones(nclasses)
LOSS_WEIGHT[0]  = 0

# Class stage mapping
s1_2_s3 = traindataset.l1_2_g
s2_2_s3 = traindataset.l2_2_g

# load dataset
traindataloader = torch.utils.data.DataLoader(traindataset, batch_size=batchsize, shuffle=True, num_workers=workers)
# No shuffling for evaluation: the metrics do not depend on the order, and not
# drawing from the RNG here keeps the training order independent of how often
# the model is evaluated.
testdataloader = torch.utils.data.DataLoader(testdataset, batch_size=batchsize, shuffle=False, num_workers=workers)

# model
model =  TransformerEncoder(in_channels=input_dim, len_max_seq=len_max_seq,
        d_word_vec=hidden, d_model=hidden, d_inner=d_inner,
        n_layers=n_layers, n_head=n_heads, d_k=hidden//n_heads, d_v=hidden//n_heads,
        dropout=dropout, nclasses=nclasses)


pytorch_total_params = sum(p.numel() for p in model.parameters())
print("initialized {} model ({} parameters)".format(model, pytorch_total_params))

#loss
# NOTE(review): this weighted criterion is handed to train_epoch, which (as
# written in this notebook) computes F.nll_loss itself and does not use it.
loss = torch.nn.CrossEntropyLoss(weight=LOSS_WEIGHT)



# CUDA
print('CUDA available: ', torch.cuda.is_available())
if torch.cuda.is_available():
    model = torch.nn.DataParallel(model).cuda()
    loss = loss.cuda()



# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                weight_decay=weight_decay)

#optimizer = ScheduledOptim(
#        optim.Adam(
#            filter(lambda x: x.requires_grad, model.parameters()),
#            betas=(0.9, 0.98), eps=1e-09, weight_decay= weight_decay),
#        model.d_model, args.warmup)

## Learning rate scheduler
# if lrS == 1:
#     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1, last_epoch=-1)
# elif lrS == 2:
#     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1, last_epoch=-1)
# elif lrS == 3:
#     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5, last_epoch=-1)
# else:
#     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5, last_epoch=-1)




start_epoch = 0
best_test_acc = 0
#if snapshot is not None:
#    checkpoint = torch.load(snapshot)
#    network.load_state_dict(checkpoint['network_state_dict'])
#    network_gt.load_state_dict(checkpoint['network_gt_state_dict'])
#    optimizer.load_state_dict(checkpoint['optimizerA_state_dict'])

## fit model
for epoch in range(start_epoch, epochs):

    print("\nEpoch {}".format(epoch))

    train_epoch(traindataloader, model, optimizer, loss)

    # call LR scheduler
    #lr_scheduler.step()

    # evaluate model (epochs 0 and 1 are skipped, then every test_every_n_epochs)
    if epoch > 1 and epoch % test_every_n_epochs == 0:
        print("\n Eval on test set")
        stats = test_epoch(model, testdataloader, nclasses)
        print(stats, epoch)

        # test_epoch stores the overall accuracy under the "accuracy" key
        test_acc = stats["accuracy"]

        # keep a checkpoint only when the test accuracy improves
        if checkpoint_dir is not None and test_acc > best_test_acc:
            checkpoint_name = os.path.join(checkpoint_dir, run_name + '_epoch_' + str(epoch) + "_model.pth")
            print('Model saved! Best val acc:', test_acc)
            best_test_acc = test_acc
            snapshot(checkpoint_name, optimizer, epoch)


In [8]:
def train_epoch(dataloader, model, optimizer, loss):
    """Run one optimisation pass over `dataloader`.

    For every batch the gradients are reset, the model is evaluated on the
    inputs with axes 1 and 2 swapped, the negative log-likelihood between the
    returned log-probabilities and `target[:, 0]` is back-propagated and the
    optimizer takes one step.

    Args:
        dataloader: iterable yielding `(inputs, target)` tensor pairs.
            Assumes `target` has at least two dimensions, since `target[:, 0]`
            is used as the class index (TODO confirm against the Dataset).
        model: network whose forward pass returns a 4-tuple
            `(logprobabilities, deltas, pts, budget)`; only the first entry is
            used here.
        optimizer: optimizer exposing `zero_grad()` and `step()`.
        loss: criterion object handed in by the caller.
            NOTE(review): it is currently NOT used; the loss is always
            `F.nll_loss`. Confirm whether the weighted criterion was intended.

    Returns:
        None. The model parameters are updated in place.
    """
    model.train()

    # Move every batch to wherever the model lives (CPU, or the GPU chosen by
    # the caller) instead of guessing from torch.cuda.is_available().
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    for iteration, data in enumerate(dataloader):
        optimizer.zero_grad()

        inputs, target = data
        inputs = inputs.to(device)
        target = target.to(device)

        # call the module (not .forward) so that registered hooks are honoured
        logprobabilities, deltas, pts, budget = model(inputs.transpose(1, 2))

        # separate local name: do not overwrite the `loss` argument
        batch_loss = F.nll_loss(logprobabilities, target[:, 0])
        batch_loss.backward()
        optimizer.step()
#        if isinstance(optimizer,ScheduledOptim):
#            optimizer.step_and_update_lr()
#        else:
#            optimizer.step()





In [9]:
def test_epoch(model, dataloader, nclasses, epoch=None):
    # sets the model to train mode: no dropout is applied
    model.eval()

    # builds a confusion matrix
    #metric_maxvoted = ClassMetric(num_classes=self.nclasses)
    metric = ClassMetric(num_classes=nclasses)
    #metric_all_t = ClassMetric(num_classes=self.nclasses)

    tstops = list()
    predictions = list()
    probas = list()
    ids_list = list()
    labels = list()


    with torch.no_grad():
        for iteration, data in enumerate(dataloader):

            input, target_glob = data

            if torch.cuda.is_available():
                inputs = inputs.cuda()
                target_glob = target_glob.cuda()

            logprobabilities, deltas, pts, budget = model.forward(inputs.transpose(1, 2))

            loss = F.nll_loss(logprobabilities, target_glob[:, 0])

            stats = dict(loss=loss,)

            prediction = model.predict(logprobabilities)
            t_stop = None

            ## enter numpy world
            prediction = prediction.detach().cpu().numpy()
            label = target_glob.mode(1)[0].detach().cpu().numpy()
            if t_stop is not None: t_stop = t_stop.cpu().detach().numpy()
            if pts is not None: pts = pts.detach().cpu().numpy()
            if deltas is not None: deltas = deltas.detach().cpu().numpy()
            if budget is not None: budget = budget.detach().cpu().numpy()

            if t_stop is not None: tstops.append(t_stop)
            predictions.append(prediction)
            labels.append(label)
            probas.append(logprobabilities.exp().detach().cpu().numpy())

            stats = metric.add(stats)

            accuracy_metrics = metric.update_confmat(label,
                                                        prediction)

            stats["accuracy"] = accuracy_metrics["overall_accuracy"]
            stats["mean_accuracy"] = accuracy_metrics["accuracy"].mean()

            #for cl in range(len(accuracy_metrics["accuracy"])):
            #    acc = accuracy_metrics["accuracy"][cl]
            #    stats["class_{}_accuracy".format(cl)] = acc

            stats["mean_recall"] = accuracy_metrics["recall"].mean()
            stats["mean_precision"] = accuracy_metrics["precision"].mean()
            stats["mean_f1"] = accuracy_metrics["f1"].mean()
            stats["kappa"] = accuracy_metrics["kappa"]
            if t_stop is not None:
                earliness = (t_stop.astype(float) / (inputs.shape[1] - 1)).mean()
                stats["earliness"] = metric.update_earliness(earliness)

        stats["confusion_matrix"] = copy.copy(metric.hist)
        stats["targets"] = targets.cpu().numpy()
        stats["inputs"] = inputs.cpu().numpy()
        if deltas is not None: stats["deltas"] = deltas
        if pts is not None: stats["pts"] = pts
        if budget is not None: stats["budget"] = budget




    if t_stop is not None: stats["t_stops"] = np.hstack(tstops)
    stats["predictions"] = np.hstack(predictions) # N
    stats["labels"] = np.hstack(labels) # N
    stats["probas"] = np.vstack(probas) # NxC

    return stats

In [10]:
def snapshot(filename, optimizer, epoch):
    """Save the notebook-level `model` together with the optimizer state.

    The network is taken from the global `model` variable. When that variable
    is a `torch.nn.DataParallel` wrapper, the wrapped network (`model.module`)
    is saved, because the wrapper itself does not expose the network's `save`
    method.

    NOTE(review): the parameter cell also assigns a variable called `snapshot`;
    re-running that cell after this one hides this function.

    Args:
        filename: destination path, forwarded to `save`.
        optimizer: optimizer whose `state_dict()` is stored as
            `optimizer_state_dict`.
        epoch: epoch number stored alongside the weights.
    """
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    net.save(
        filename,
        optimizer_state_dict=optimizer.state_dict(),
        epoch=epoch)