In [1]:
#from basic_fcn import *
#import transfer_fcn 
#import customfcn1
#import customfcn2
#import voc
import torch.nn as nn
import time
import torch
import gc
import os
import util
#import torchvision.transforms as standard_transforms
import torchvision.transforms.functional as TF
import numpy as np
import unet
from mri_imgmask import *

#MODE = ['lr', 'weight', 'augment', 'unet']
MODE = ['lr', 'weight', 'unet', 'init']
"""
None: baseline
'lr': 4a (lr schedule)
'augment': 4b (data augment)
'weight': 4c (weight)
'custom1': 5a-1 (custom1)
'custom2': 5a-2 (custom2)
'transfer': 5b (transfer)
'unet': 5c (unet)
"""

"\nNone: baseline\n'lr': 4a (lr schedule)\n'augment': 4b (data augment)\n'weight': 4c (weight)\n'custom1': 5a-1 (custom1)\n'custom2': 5a-2 (custom2)\n'transfer': 5b (transfer)\n'unet': 5c (unet)\n"

In [2]:
class MaskToTensor(object):
    def __call__(self, img):
        return torch.from_numpy(np.array(img, dtype=np.int32)).long()


def init_weights(m):
    if 'transfer' in MODE:
        if isinstance(m, nn.ConvTranspose2d):
            torch.nn.init.xavier_uniform_(m.weight.data)
            torch.nn.init.normal_(m.bias.data) #xavier not applicable for biases
    else:
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            torch.nn.init.xavier_uniform_(m.weight.data)
            torch.nn.init.normal_(m.bias.data) #xavier not applicable for biases

def getClassWeights(dataset, n_class):
    cum_counts = torch.zeros(n_class)
    for iter, (inputs, labels) in enumerate(dataset):
        # inputs: 1x3x256x256, labels: 1x256x256 => 256x256
        labels = torch.squeeze(labels)
        vals, counts = labels.unique(return_counts = True)
        for v, c in zip(vals, counts):
            cum_counts[v.item()] += c.item()
            
        print(f"Cumulative counts at iter {iter}: {cum_counts}")
            
    totalPixels = torch.sum(cum_counts)
    classWeights = 1 - (cum_counts / totalPixels)
    print(f"Class weights: {classWeights}")
    return classWeights

In [3]:
def train_transform(image, mask):
    image = TF.to_tensor(image)
    mask = torch.from_numpy(np.array(mask, dtype=np.int32)).long()
    image = TF.normalize(image, mean=mean_std[0], std=mean_std[1])
    
    # mask value from 255 to 1
    mask[mask==255] = 1

    if 'augment' in MODE:
        images = list(TF.ten_crop(image, 128))
        masks = list(TF.ten_crop(mask, 128))
        for i in range(10):
            angles = [30, 60]
            for angle in angles:
                msk = masks[i].unsqueeze(0)
                img = TF.rotate(images[i], angle)
                msk = TF.rotate(msk, angle)
                msk = msk.squeeze(0)
                images.append(img)
                masks.append(msk)
                
        image = torch.stack([img for img in images])
        mask = torch.stack([msk for msk in masks])
        
    return image, mask

def valtest_transform(image, mask):
    image = TF.to_tensor(image)
    mask = torch.from_numpy(np.array(mask, dtype=np.int32)).long()
    image = TF.normalize(image, mean=mean_std[0], std=mean_std[1])

    # mask value from 255 to 1
    mask[mask==255] = 1
    
    return image, mask

def sample_transform(image, mask):
    image = torch.from_numpy(np.array(image, dtype=np.int32)).long()
    mask = torch.from_numpy(np.array(mask, dtype=np.int32)).long()
    
    return image, mask

In [5]:
# ----------------  TODO: move to config file later ------------------------#
root = os.path.join('archive','lgg-mri-segmentation','kaggle_3m')
epochs = 20
n_class = 2
learning_rate = 0.01
early_stop_tolerance = 8
model_save_path = 'model.pth'
mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

if 'init' in MODE:  # need it only once
    make_trainvaltestCSV(root)
if 'mini' in MODE:
    make_trainvaltestCSV_mini(root, 100)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # determine which device to use (cuda or cpu)
# ---------------------------------------------------------------------------#

def get_model():
    model = unet.UNet(n_class=n_class)
    model.apply(init_weights)
    
    if 'transfer' in MODE:
        params_to_update = []
        for name,param in model.named_parameters():
            if param.requires_grad == True:
                params_to_update.append(param)
        optimizer = torch.optim.Adam(params_to_update, lr = learning_rate, weight_decay=0.001)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
    
    if 'lr' in MODE:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    
    if 'weight' in MODE:
        train_loader_no_shuffle = get_dataloader(root, 'train', transforms=train_transform, batch_size=1, shuffle=False)
        classWeights = getClassWeights(train_loader_no_shuffle, n_class).to(device)
        criterion = nn.CrossEntropyLoss(weight=classWeights) 
    else:
        criterion = nn.CrossEntropyLoss() 
    model = model.to(device)
    model_dict = {'model':model, 'criterion':criterion, 'optimizer':optimizer}
    if 'lr' in MODE:
        model_dict['scheduler'] = scheduler
    return model_dict

Total Data: 1980, Train - 1584, Valid - 198, Test - 198


In [6]:
def train(train_loader, valid_loader, model_dict):
    best_iou_score = 0.0
    best_loss = 100.0
    early_stop_count = 0
    train_loss_per_epoch = []
    train_iou_per_epoch = []
    train_acc_per_epoch = []
    
    valid_loss_per_epoch = []
    valid_iou_per_epoch = []
    valid_acc_per_epoch = []
    
    model = model_dict['model']
    criterion = model_dict['criterion']
    optimizer = model_dict['optimizer']
    if 'lr' in MODE:
        scheduler = model_dict['scheduler']
    
    for epoch in range(epochs):
        ts = time.time()
        losses = []
        mean_iou_scores = []
        accuracy = []
        for iter, (inputs, labels) in enumerate(train_loader):
            # reset optimizer gradients
            optimizer.zero_grad()

            # both inputs and labels have to reside in the same device as the model's
            inputs =  inputs.to(device)# transfer the input to the same device as the model's
            labels =  labels.to(device) # transfer the labels to the same device as the model's
            
            if 'augment' in MODE:
                # due to crop transform
                b, ncrop, c, h, w = inputs.size()
                inputs = inputs.view(-1, c, h, w)
                b, ncrop, h, w = labels.size()
                labels = labels.view(-1, h, w)
            
            outputs = model(inputs) # Compute outputs. we will not need to transfer the output, it will be automatically in the same device as the model's!

            loss = criterion(outputs, labels)  # calculate loss

            with torch.no_grad():
                losses.append(loss.item())
                _, pred = torch.max(outputs, dim=1)
                acc = util.pixel_acc(pred, labels)
                accuracy.append(acc)
                iou_score = util.iou(pred, labels)
                mean_iou_scores.append(iou_score)
                
            # backpropagate
            loss.backward()

            # update the weights
            optimizer.step()

            if iter % 10 == 0:
                print("epoch{}, iter{}, loss: {}".format(epoch, iter, loss.item()))

        if 'lr' in MODE:
            print(f'Learning rate at epoch {epoch}: {scheduler.get_last_lr()[0]:0.9f}')  # changes every epoch
            # lr scheduler
            scheduler.step()           
                    
        with torch.no_grad():
            train_loss_at_epoch = np.mean(losses)
            train_iou_at_epoch = np.mean(mean_iou_scores)
            train_acc_at_epoch = np.mean(accuracy)

            train_loss_per_epoch.append(train_loss_at_epoch)
            train_iou_per_epoch.append(train_iou_at_epoch)
            train_acc_per_epoch.append(train_acc_at_epoch)

            print("Finishing epoch {}, time elapsed {}".format(epoch, time.time() - ts))

            valid_loss_at_epoch, valid_iou_at_epoch, valid_acc_at_epoch = val(valid_loader, model_dict, epoch)
            valid_loss_per_epoch.append(valid_loss_at_epoch)
            valid_iou_per_epoch.append(valid_iou_at_epoch)
            valid_acc_per_epoch.append(valid_acc_at_epoch)

            if valid_iou_at_epoch > best_iou_score:
                best_iou_score = valid_iou_at_epoch
                # save the best model
            if valid_loss_at_epoch < best_loss:
                print(f"Valid Loss {valid_loss_at_epoch} < Best Loss {best_loss}. (Valid IOU {valid_iou_at_epoch}) Saving Model...")
                best_loss = valid_loss_at_epoch
                early_stop_count = 0
                torch.save(model.state_dict(), model_save_path)
            else:
                early_stop_count += 1
                if early_stop_count > early_stop_tolerance:
                    print("Early Stopping...")
                    break
    model.load_state_dict(torch.load(model_save_path))
            
    return best_iou_score, train_loss_per_epoch, train_iou_per_epoch, train_acc_per_epoch, valid_loss_per_epoch, valid_iou_per_epoch, valid_acc_per_epoch
    

def val(valid_loader, model_dict, epoch):
    model = model_dict['model']
    criterion = model_dict['criterion']

    model.eval() # Put in eval mode (disables batchnorm/dropout) !
    
    losses = []
    mean_iou_scores = []
    accuracy = []

    with torch.no_grad(): # we don't need to calculate the gradient in the validation/testing
        for iter, (input, label) in enumerate(valid_loader):
            input = input.to(device)
            label = label.to(device)
            
            output = model(input)
            loss = criterion(output, label)
            losses.append(loss.item())
            _, pred = torch.max(output, dim=1)
            acc = util.pixel_acc(pred, label)
            accuracy.append(acc)
            iou_score = util.iou(pred, label)
            mean_iou_scores.append(iou_score)
        loss_at_epoch = np.mean(losses)
        iou_at_epoch = np.mean(mean_iou_scores)
        acc_at_epoch = np.mean(accuracy)

    print(f"Valid Loss at epoch: {epoch} is {loss_at_epoch}")
    print(f"Valid IoU at epoch: {epoch} is {iou_at_epoch}")
    print(f"Valid Pixel acc at epoch: {epoch} is {acc_at_epoch}")

    model.train() #TURNING THE TRAIN MODE BACK ON TO ENABLE BATCHNORM/DROPOUT!!

    return loss_at_epoch, iou_at_epoch, acc_at_epoch


def modelTest(test_loader, model_dict):
    model = model_dict['model']
    criterion = model_dict['criterion']

    model.eval()  # Put in eval mode (disables batchnorm/dropout) !

    image_outputs = []
    image_labels = []
    
    losses = []
    mean_iou_scores = []
    accuracy = []

    with torch.no_grad():  # we don't need to calculate the gradient in the validation/testing

        for iter, (input, label) in enumerate(test_loader):
            input = input.to(device)
            label = label.to(device)

            output = model(input)
            loss = criterion(output, label)
            losses.append(loss.item())
            _, pred = torch.max(output, dim=1)
            acc = util.pixel_acc(pred, label)
            accuracy.append(acc)
            iou_score = util.iou(pred, label)
            mean_iou_scores.append(iou_score)
            
            image_outputs.extend(output)
            image_labels.extend(label)

    test_loss = np.mean(losses)
    test_iou = np.mean(mean_iou_scores)
    test_acc = np.mean(accuracy)

    model.train()  #TURNING THE TRAIN MODE BACK ON TO ENABLE BATCHNORM/DROPOUT!!

    return test_loss, test_iou, test_acc, image_outputs, image_labels

def visualize_image(pred, true):
    figure, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 10))

    ax[0].imshow(pred)
    ax[1].imshow(true)
    ax[0].set_title("Predicted Mask")
    ax[1].set_title("Original Mask")
    
    figure.tight_layout()
    figure.show()

In [7]:
model_dict = get_model()
train_loader = get_dataloader(root, 'train', transforms=train_transform, batch_size=8, shuffle=True)
valid_loader = get_dataloader(root, 'valid', transforms=valtest_transform, batch_size=8, shuffle=True)
test_loader = get_dataloader(root, 'test', transforms=valtest_transform, batch_size=8, shuffle=True)

val(valid_loader, model_dict, epoch=0)  # show the accuracy before training

train : 3032
Cumulative counts at iter 0: tensor([65536.,     0.])
Cumulative counts at iter 1: tensor([131072.,      0.])
Cumulative counts at iter 2: tensor([196608.,      0.])
Cumulative counts at iter 3: tensor([258261.,   3883.])
Cumulative counts at iter 4: tensor([319348.,   8332.])
Cumulative counts at iter 5: tensor([381779.,  11437.])
Cumulative counts at iter 6: tensor([447028.,  11724.])
Cumulative counts at iter 7: tensor([512564.,  11724.])
Cumulative counts at iter 8: tensor([578100.,  11724.])
Cumulative counts at iter 9: tensor([643636.,  11724.])
Cumulative counts at iter 10: tensor([709172.,  11724.])
Cumulative counts at iter 11: tensor([772620.,  13812.])
Cumulative counts at iter 12: tensor([834151.,  17817.])
Cumulative counts at iter 13: tensor([895386.,  22118.])
Cumulative counts at iter 14: tensor([957943.,  25097.])
Cumulative counts at iter 15: tensor([1023479.,   25097.])
Cumulative counts at iter 16: tensor([1089015.,   25097.])
Cumulative counts at iter 

Cumulative counts at iter 154: tensor([9990461.,  167619.])
Cumulative counts at iter 155: tensor([10055997.,   167619.])
Cumulative counts at iter 156: tensor([10121533.,   167619.])
Cumulative counts at iter 157: tensor([10186045.,   168643.])
Cumulative counts at iter 158: tensor([10251581.,   168643.])
Cumulative counts at iter 159: tensor([10317117.,   168643.])
Cumulative counts at iter 160: tensor([10382449.,   168847.])
Cumulative counts at iter 161: tensor([10447985.,   168847.])
Cumulative counts at iter 162: tensor([10513521.,   168847.])
Cumulative counts at iter 163: tensor([10579057.,   168847.])
Cumulative counts at iter 164: tensor([10644593.,   168847.])
Cumulative counts at iter 165: tensor([10710129.,   168847.])
Cumulative counts at iter 166: tensor([10775665.,   168847.])
Cumulative counts at iter 167: tensor([10841201.,   168847.])
Cumulative counts at iter 168: tensor([10906737.,   168847.])
Cumulative counts at iter 169: tensor([10971552.,   169568.])
Cumulative

Cumulative counts at iter 301: tensor([19530998.,   260876.])
Cumulative counts at iter 302: tensor([19593858.,   263552.])
Cumulative counts at iter 303: tensor([19657060.,   265885.])
Cumulative counts at iter 304: tensor([19722596.,   265885.])
Cumulative counts at iter 305: tensor([19788132.,   265885.])
Cumulative counts at iter 306: tensor([19853668.,   265885.])
Cumulative counts at iter 307: tensor([19919204.,   265885.])
Cumulative counts at iter 308: tensor([19984740.,   265885.])
Cumulative counts at iter 309: tensor([20050276.,   265885.])
Cumulative counts at iter 310: tensor([20115812.,   265885.])
Cumulative counts at iter 311: tensor([20181084.,   266150.])
Cumulative counts at iter 312: tensor([20246620.,   266150.])
Cumulative counts at iter 313: tensor([20312156.,   266150.])
Cumulative counts at iter 314: tensor([20377298.,   266544.])
Cumulative counts at iter 315: tensor([20442834.,   266544.])
Cumulative counts at iter 316: tensor([20508370.,   266544.])
Cumulati

Cumulative counts at iter 437: tensor([28383946.,   320833.])
Cumulative counts at iter 438: tensor([28449482.,   320833.])
Cumulative counts at iter 439: tensor([28515018.,   320833.])
Cumulative counts at iter 440: tensor([28580554.,   320833.])
Cumulative counts at iter 441: tensor([28646090.,   320833.])
Cumulative counts at iter 442: tensor([28708762.,   323697.])
Cumulative counts at iter 443: tensor([28771870.,   326125.])
Cumulative counts at iter 444: tensor([28837406.,   326125.])
Cumulative counts at iter 445: tensor([28902942.,   326125.])
Cumulative counts at iter 446: tensor([28968478.,   326125.])
Cumulative counts at iter 447: tensor([29034014.,   326125.])
Cumulative counts at iter 448: tensor([29098168.,   327506.])
Cumulative counts at iter 449: tensor([29161074.,   330136.])
Cumulative counts at iter 450: tensor([29226610.,   330136.])
Cumulative counts at iter 451: tensor([29292146.,   330136.])
Cumulative counts at iter 452: tensor([29357682.,   330136.])
Cumulati

Cumulative counts at iter 570: tensor([36975296.,   445765.])
Cumulative counts at iter 571: tensor([37040832.,   445765.])
Cumulative counts at iter 572: tensor([37106368.,   445765.])
Cumulative counts at iter 573: tensor([37171904.,   445765.])
Cumulative counts at iter 574: tensor([37237440.,   445765.])
Cumulative counts at iter 575: tensor([37302976.,   445765.])
Cumulative counts at iter 576: tensor([37366772.,   447506.])
Cumulative counts at iter 577: tensor([37428640.,   451173.])
Cumulative counts at iter 578: tensor([37489656.,   455692.])
Cumulative counts at iter 579: tensor([37555192.,   455692.])
Cumulative counts at iter 580: tensor([37620728.,   455692.])
Cumulative counts at iter 581: tensor([37681760.,   460198.])
Cumulative counts at iter 582: tensor([37743888.,   463607.])
Cumulative counts at iter 583: tensor([37806944.,   466086.])
Cumulative counts at iter 584: tensor([37872480.,   466086.])
Cumulative counts at iter 585: tensor([37938016.,   466086.])
Cumulati

Cumulative counts at iter 720: tensor([46690008.,   561440.])
Cumulative counts at iter 721: tensor([46755544.,   561440.])
Cumulative counts at iter 722: tensor([46820172.,   562348.])
Cumulative counts at iter 723: tensor([46885616.,   562441.])
Cumulative counts at iter 724: tensor([46950548.,   563044.])
Cumulative counts at iter 725: tensor([47016084.,   563044.])
Cumulative counts at iter 726: tensor([47081620.,   563044.])
Cumulative counts at iter 727: tensor([47146072.,   564128.])
Cumulative counts at iter 728: tensor([47211608.,   564128.])
Cumulative counts at iter 729: tensor([47277144.,   564128.])
Cumulative counts at iter 730: tensor([47342224.,   564582.])
Cumulative counts at iter 731: tensor([47406800.,   565543.])
Cumulative counts at iter 732: tensor([47472336.,   565543.])
Cumulative counts at iter 733: tensor([47537872.,   565543.])
Cumulative counts at iter 734: tensor([47603408.,   565543.])
Cumulative counts at iter 735: tensor([47668944.,   565543.])
Cumulati

Cumulative counts at iter 864: tensor([56000080.,   688558.])
Cumulative counts at iter 865: tensor([56063136.,   691040.])
Cumulative counts at iter 866: tensor([56128096.,   691615.])
Cumulative counts at iter 867: tensor([56193632.,   691615.])
Cumulative counts at iter 868: tensor([56259168.,   691615.])
Cumulative counts at iter 869: tensor([56319728.,   696593.])
Cumulative counts at iter 870: tensor([56383888.,   697971.])
Cumulative counts at iter 871: tensor([56449424.,   697971.])
Cumulative counts at iter 872: tensor([56514584.,   698347.])
Cumulative counts at iter 873: tensor([56578808.,   699658.])
Cumulative counts at iter 874: tensor([56644344.,   699658.])
Cumulative counts at iter 875: tensor([56704480.,   705060.])
Cumulative counts at iter 876: tensor([56769088.,   705988.])
Cumulative counts at iter 877: tensor([56834624.,   705988.])
Cumulative counts at iter 878: tensor([56895444.,   710704.])
Cumulative counts at iter 879: tensor([56960980.,   710704.])
Cumulati

Cumulative counts at iter 997: tensor([64635752.,   769185.])
Cumulative counts at iter 998: tensor([64701288.,   769185.])
Cumulative counts at iter 999: tensor([64766824.,   769185.])
Cumulative counts at iter 1000: tensor([64832360.,   769185.])
Cumulative counts at iter 1001: tensor([64897896.,   769185.])
Cumulative counts at iter 1002: tensor([64962928.,   769690.])
Cumulative counts at iter 1003: tensor([65027564.,   770590.])
Cumulative counts at iter 1004: tensor([65093100.,   770590.])
Cumulative counts at iter 1005: tensor([65158636.,   770590.])
Cumulative counts at iter 1006: tensor([65224172.,   770590.])
Cumulative counts at iter 1007: tensor([65289708.,   770590.])
Cumulative counts at iter 1008: tensor([65355244.,   770590.])
Cumulative counts at iter 1009: tensor([65420780.,   770590.])
Cumulative counts at iter 1010: tensor([65486316.,   770590.])
Cumulative counts at iter 1011: tensor([65551852.,   770590.])
Cumulative counts at iter 1012: tensor([65617008.,   77097

Cumulative counts at iter 1128: tensor([73105984.,   884169.])
Cumulative counts at iter 1129: tensor([73170944.,   884744.])
Cumulative counts at iter 1130: tensor([73234992.,   886231.])
Cumulative counts at iter 1131: tensor([73300528.,   886231.])
Cumulative counts at iter 1132: tensor([73366064.,   886231.])
Cumulative counts at iter 1133: tensor([73431600.,   886231.])
Cumulative counts at iter 1134: tensor([73497136.,   886231.])
Cumulative counts at iter 1135: tensor([73562672.,   886231.])
Cumulative counts at iter 1136: tensor([73624336.,   890105.])
Cumulative counts at iter 1137: tensor([73685496.,   894478.])
Cumulative counts at iter 1138: tensor([73748792.,   896721.])
Cumulative counts at iter 1139: tensor([73814328.,   896721.])
Cumulative counts at iter 1140: tensor([73878560.,   898027.])
Cumulative counts at iter 1141: tensor([73944096.,   898027.])
Cumulative counts at iter 1142: tensor([74009632.,   898027.])
Cumulative counts at iter 1143: tensor([74075168.,   89

Cumulative counts at iter 1273: tensor([82499128.,   993789.])
Cumulative counts at iter 1274: tensor([82564040.,   994412.])
Cumulative counts at iter 1275: tensor([82629576.,   994412.])
Cumulative counts at iter 1276: tensor([82695112.,   994412.])
Cumulative counts at iter 1277: tensor([82759256.,   995803.])
Cumulative counts at iter 1278: tensor([82824792.,   995803.])
Cumulative counts at iter 1279: tensor([82890328.,   995803.])
Cumulative counts at iter 1280: tensor([82955864.,   995803.])
Cumulative counts at iter 1281: tensor([83021400.,   995803.])
Cumulative counts at iter 1282: tensor([83085528.,   997214.])
Cumulative counts at iter 1283: tensor([83147384.,  1000897.])
Cumulative counts at iter 1284: tensor([83209552.,  1004268.])
Cumulative counts at iter 1285: tensor([83275088.,  1004268.])
Cumulative counts at iter 1286: tensor([83339088.,  1005807.])
Cumulative counts at iter 1287: tensor([83404624.,  1005807.])
Cumulative counts at iter 1288: tensor([83469608.,  100

Cumulative counts at iter 1407: tensor([91169760.,  1105008.])
Cumulative counts at iter 1408: tensor([91233032.,  1107272.])
Cumulative counts at iter 1409: tensor([91298568.,  1107272.])
Cumulative counts at iter 1410: tensor([91364104.,  1107272.])
Cumulative counts at iter 1411: tensor([91428376.,  1108536.])
Cumulative counts at iter 1412: tensor([91492488.,  1109958.])
Cumulative counts at iter 1413: tensor([91556408.,  1111577.])
Cumulative counts at iter 1414: tensor([91621456.,  1112061.])
Cumulative counts at iter 1415: tensor([91685800.,  1113255.])
Cumulative counts at iter 1416: tensor([91751336.,  1113255.])
Cumulative counts at iter 1417: tensor([91816872.,  1113255.])
Cumulative counts at iter 1418: tensor([91882408.,  1113255.])
Cumulative counts at iter 1419: tensor([91947944.,  1113255.])
Cumulative counts at iter 1420: tensor([92013480.,  1113255.])
Cumulative counts at iter 1421: tensor([92079016.,  1113255.])
Cumulative counts at iter 1422: tensor([92144552.,  111

Cumulative counts at iter 1540: tensor([99842136.,  1148908.])
Cumulative counts at iter 1541: tensor([99906512.,  1150066.])
Cumulative counts at iter 1542: tensor([99968936.,  1153176.])
Cumulative counts at iter 1543: tensor([1.0003e+08, 1.1532e+06])
Cumulative counts at iter 1544: tensor([1.0010e+08, 1.1532e+06])
Cumulative counts at iter 1545: tensor([1.0017e+08, 1.1532e+06])
Cumulative counts at iter 1546: tensor([1.0023e+08, 1.1550e+06])
Cumulative counts at iter 1547: tensor([1.0029e+08, 1.1550e+06])
Cumulative counts at iter 1548: tensor([1.0036e+08, 1.1581e+06])
Cumulative counts at iter 1549: tensor([1.0042e+08, 1.1587e+06])
Cumulative counts at iter 1550: tensor([1.0049e+08, 1.1587e+06])
Cumulative counts at iter 1551: tensor([1.0055e+08, 1.1587e+06])
Cumulative counts at iter 1552: tensor([1.0062e+08, 1.1587e+06])
Cumulative counts at iter 1553: tensor([1.0068e+08, 1.1587e+06])
Cumulative counts at iter 1554: tensor([1.0075e+08, 1.1614e+06])
Cumulative counts at iter 1555:

Cumulative counts at iter 1673: tensor([1.0844e+08, 1.2692e+06])
Cumulative counts at iter 1674: tensor([1.0850e+08, 1.2692e+06])
Cumulative counts at iter 1675: tensor([1.0857e+08, 1.2708e+06])
Cumulative counts at iter 1676: tensor([1.0863e+08, 1.2722e+06])
Cumulative counts at iter 1677: tensor([1.0870e+08, 1.2722e+06])
Cumulative counts at iter 1678: tensor([1.0876e+08, 1.2722e+06])
Cumulative counts at iter 1679: tensor([1.0883e+08, 1.2722e+06])
Cumulative counts at iter 1680: tensor([1.0889e+08, 1.2722e+06])
Cumulative counts at iter 1681: tensor([1.0896e+08, 1.2722e+06])
Cumulative counts at iter 1682: tensor([1.0902e+08, 1.2722e+06])
Cumulative counts at iter 1683: tensor([1.0909e+08, 1.2722e+06])
Cumulative counts at iter 1684: tensor([1.0916e+08, 1.2723e+06])
Cumulative counts at iter 1685: tensor([1.0922e+08, 1.2723e+06])
Cumulative counts at iter 1686: tensor([1.0929e+08, 1.2723e+06])
Cumulative counts at iter 1687: tensor([1.0935e+08, 1.2723e+06])
Cumulative counts at iter

Cumulative counts at iter 1803: tensor([1.1686e+08, 1.3670e+06])
Cumulative counts at iter 1804: tensor([1.1693e+08, 1.3670e+06])
Cumulative counts at iter 1805: tensor([1.1699e+08, 1.3670e+06])
Cumulative counts at iter 1806: tensor([1.1706e+08, 1.3670e+06])
Cumulative counts at iter 1807: tensor([1.1712e+08, 1.3670e+06])
Cumulative counts at iter 1808: tensor([1.1719e+08, 1.3670e+06])
Cumulative counts at iter 1809: tensor([1.1725e+08, 1.3670e+06])
Cumulative counts at iter 1810: tensor([1.1732e+08, 1.3670e+06])
Cumulative counts at iter 1811: tensor([1.1738e+08, 1.3672e+06])
Cumulative counts at iter 1812: tensor([1.1745e+08, 1.3695e+06])
Cumulative counts at iter 1813: tensor([1.1751e+08, 1.3695e+06])
Cumulative counts at iter 1814: tensor([1.1758e+08, 1.3695e+06])
Cumulative counts at iter 1815: tensor([1.1764e+08, 1.3695e+06])
Cumulative counts at iter 1816: tensor([1.1771e+08, 1.3727e+06])
Cumulative counts at iter 1817: tensor([1.1777e+08, 1.3727e+06])
Cumulative counts at iter

Cumulative counts at iter 1941: tensor([1.2584e+08, 1.4337e+06])
Cumulative counts at iter 1942: tensor([1.2590e+08, 1.4337e+06])
Cumulative counts at iter 1943: tensor([1.2597e+08, 1.4337e+06])
Cumulative counts at iter 1944: tensor([1.2603e+08, 1.4337e+06])
Cumulative counts at iter 1945: tensor([1.2610e+08, 1.4351e+06])
Cumulative counts at iter 1946: tensor([1.2616e+08, 1.4359e+06])
Cumulative counts at iter 1947: tensor([1.2623e+08, 1.4359e+06])
Cumulative counts at iter 1948: tensor([1.2629e+08, 1.4373e+06])
Cumulative counts at iter 1949: tensor([1.2636e+08, 1.4373e+06])
Cumulative counts at iter 1950: tensor([1.2642e+08, 1.4373e+06])
Cumulative counts at iter 1951: tensor([1.2649e+08, 1.4380e+06])
Cumulative counts at iter 1952: tensor([1.2655e+08, 1.4380e+06])
Cumulative counts at iter 1953: tensor([1.2662e+08, 1.4396e+06])
Cumulative counts at iter 1954: tensor([1.2668e+08, 1.4396e+06])
Cumulative counts at iter 1955: tensor([1.2675e+08, 1.4399e+06])
Cumulative counts at iter

Cumulative counts at iter 2084: tensor([1.3513e+08, 1.5105e+06])
Cumulative counts at iter 2085: tensor([1.3520e+08, 1.5105e+06])
Cumulative counts at iter 2086: tensor([1.3526e+08, 1.5119e+06])
Cumulative counts at iter 2087: tensor([1.3533e+08, 1.5134e+06])
Cumulative counts at iter 2088: tensor([1.3539e+08, 1.5134e+06])
Cumulative counts at iter 2089: tensor([1.3546e+08, 1.5134e+06])
Cumulative counts at iter 2090: tensor([1.3552e+08, 1.5140e+06])
Cumulative counts at iter 2091: tensor([1.3559e+08, 1.5152e+06])
Cumulative counts at iter 2092: tensor([1.3565e+08, 1.5159e+06])
Cumulative counts at iter 2093: tensor([1.3572e+08, 1.5161e+06])
Cumulative counts at iter 2094: tensor([1.3578e+08, 1.5161e+06])
Cumulative counts at iter 2095: tensor([1.3585e+08, 1.5161e+06])
Cumulative counts at iter 2096: tensor([1.3591e+08, 1.5161e+06])
Cumulative counts at iter 2097: tensor([1.3598e+08, 1.5161e+06])
Cumulative counts at iter 2098: tensor([1.3604e+08, 1.5171e+06])
Cumulative counts at iter

Cumulative counts at iter 2216: tensor([1.4372e+08, 1.5713e+06])
Cumulative counts at iter 2217: tensor([1.4379e+08, 1.5713e+06])
Cumulative counts at iter 2218: tensor([1.4385e+08, 1.5713e+06])
Cumulative counts at iter 2219: tensor([1.4392e+08, 1.5717e+06])
Cumulative counts at iter 2220: tensor([1.4398e+08, 1.5717e+06])
Cumulative counts at iter 2221: tensor([1.4405e+08, 1.5717e+06])
Cumulative counts at iter 2222: tensor([1.4411e+08, 1.5729e+06])
Cumulative counts at iter 2223: tensor([1.4418e+08, 1.5729e+06])
Cumulative counts at iter 2224: tensor([1.4424e+08, 1.5729e+06])
Cumulative counts at iter 2225: tensor([1.4431e+08, 1.5729e+06])
Cumulative counts at iter 2226: tensor([1.4438e+08, 1.5729e+06])
Cumulative counts at iter 2227: tensor([1.4444e+08, 1.5729e+06])
Cumulative counts at iter 2228: tensor([1.4450e+08, 1.5748e+06])
Cumulative counts at iter 2229: tensor([1.4457e+08, 1.5760e+06])
Cumulative counts at iter 2230: tensor([1.4463e+08, 1.5760e+06])
Cumulative counts at iter

Cumulative counts at iter 2351: tensor([1.5249e+08, 1.6549e+06])
Cumulative counts at iter 2352: tensor([1.5255e+08, 1.6549e+06])
Cumulative counts at iter 2353: tensor([1.5262e+08, 1.6564e+06])
Cumulative counts at iter 2354: tensor([1.5268e+08, 1.6574e+06])
Cumulative counts at iter 2355: tensor([1.5275e+08, 1.6574e+06])
Cumulative counts at iter 2356: tensor([1.5281e+08, 1.6574e+06])
Cumulative counts at iter 2357: tensor([1.5288e+08, 1.6576e+06])
Cumulative counts at iter 2358: tensor([1.5294e+08, 1.6618e+06])
Cumulative counts at iter 2359: tensor([1.5300e+08, 1.6666e+06])
Cumulative counts at iter 2360: tensor([1.5306e+08, 1.6666e+06])
Cumulative counts at iter 2361: tensor([1.5313e+08, 1.6666e+06])
Cumulative counts at iter 2362: tensor([1.5320e+08, 1.6666e+06])
Cumulative counts at iter 2363: tensor([1.5326e+08, 1.6670e+06])
Cumulative counts at iter 2364: tensor([1.5333e+08, 1.6670e+06])
Cumulative counts at iter 2365: tensor([1.5339e+08, 1.6694e+06])
Cumulative counts at iter

Cumulative counts at iter 2479: tensor([1.6082e+08, 1.7140e+06])
Cumulative counts at iter 2480: tensor([1.6088e+08, 1.7140e+06])
Cumulative counts at iter 2481: tensor([1.6095e+08, 1.7140e+06])
Cumulative counts at iter 2482: tensor([1.6101e+08, 1.7140e+06])
Cumulative counts at iter 2483: tensor([1.6108e+08, 1.7143e+06])
Cumulative counts at iter 2484: tensor([1.6114e+08, 1.7147e+06])
Cumulative counts at iter 2485: tensor([1.6121e+08, 1.7172e+06])
Cumulative counts at iter 2486: tensor([1.6127e+08, 1.7194e+06])
Cumulative counts at iter 2487: tensor([1.6133e+08, 1.7194e+06])
Cumulative counts at iter 2488: tensor([1.6140e+08, 1.7194e+06])
Cumulative counts at iter 2489: tensor([1.6146e+08, 1.7209e+06])
Cumulative counts at iter 2490: tensor([1.6153e+08, 1.7209e+06])
Cumulative counts at iter 2491: tensor([1.6159e+08, 1.7209e+06])
Cumulative counts at iter 2492: tensor([1.6166e+08, 1.7209e+06])
Cumulative counts at iter 2493: tensor([1.6173e+08, 1.7209e+06])
Cumulative counts at iter

Cumulative counts at iter 2609: tensor([1.6925e+08, 1.8018e+06])
Cumulative counts at iter 2610: tensor([1.6931e+08, 1.8018e+06])
Cumulative counts at iter 2611: tensor([1.6938e+08, 1.8036e+06])
Cumulative counts at iter 2612: tensor([1.6944e+08, 1.8051e+06])
Cumulative counts at iter 2613: tensor([1.6951e+08, 1.8051e+06])
Cumulative counts at iter 2614: tensor([1.6957e+08, 1.8051e+06])
Cumulative counts at iter 2615: tensor([1.6964e+08, 1.8053e+06])
Cumulative counts at iter 2616: tensor([1.6970e+08, 1.8057e+06])
Cumulative counts at iter 2617: tensor([1.6977e+08, 1.8057e+06])
Cumulative counts at iter 2618: tensor([1.6983e+08, 1.8057e+06])
Cumulative counts at iter 2619: tensor([1.6990e+08, 1.8066e+06])
Cumulative counts at iter 2620: tensor([1.6996e+08, 1.8072e+06])
Cumulative counts at iter 2621: tensor([1.7003e+08, 1.8072e+06])
Cumulative counts at iter 2622: tensor([1.7009e+08, 1.8072e+06])
Cumulative counts at iter 2623: tensor([1.7016e+08, 1.8073e+06])
Cumulative counts at iter

Cumulative counts at iter 2748: tensor([1.7822e+08, 1.9372e+06])
Cumulative counts at iter 2749: tensor([1.7829e+08, 1.9372e+06])
Cumulative counts at iter 2750: tensor([1.7835e+08, 1.9372e+06])
Cumulative counts at iter 2751: tensor([1.7842e+08, 1.9386e+06])
Cumulative counts at iter 2752: tensor([1.7848e+08, 1.9434e+06])
Cumulative counts at iter 2753: tensor([1.7854e+08, 1.9479e+06])
Cumulative counts at iter 2754: tensor([1.7860e+08, 1.9479e+06])
Cumulative counts at iter 2755: tensor([1.7867e+08, 1.9479e+06])
Cumulative counts at iter 2756: tensor([1.7873e+08, 1.9479e+06])
Cumulative counts at iter 2757: tensor([1.7880e+08, 1.9479e+06])
Cumulative counts at iter 2758: tensor([1.7887e+08, 1.9483e+06])
Cumulative counts at iter 2759: tensor([1.7893e+08, 1.9483e+06])
Cumulative counts at iter 2760: tensor([1.7900e+08, 1.9483e+06])
Cumulative counts at iter 2761: tensor([1.7906e+08, 1.9483e+06])
Cumulative counts at iter 2762: tensor([1.7913e+08, 1.9483e+06])
Cumulative counts at iter

Cumulative counts at iter 2878: tensor([1.8666e+08, 2.0215e+06])
Cumulative counts at iter 2879: tensor([1.8672e+08, 2.0215e+06])
Cumulative counts at iter 2880: tensor([1.8679e+08, 2.0215e+06])
Cumulative counts at iter 2881: tensor([1.8685e+08, 2.0215e+06])
Cumulative counts at iter 2882: tensor([1.8692e+08, 2.0215e+06])
Cumulative counts at iter 2883: tensor([1.8698e+08, 2.0215e+06])
Cumulative counts at iter 2884: tensor([1.8705e+08, 2.0215e+06])
Cumulative counts at iter 2885: tensor([1.8712e+08, 2.0215e+06])
Cumulative counts at iter 2886: tensor([1.8718e+08, 2.0215e+06])
Cumulative counts at iter 2887: tensor([1.8725e+08, 2.0215e+06])
Cumulative counts at iter 2888: tensor([1.8731e+08, 2.0215e+06])
Cumulative counts at iter 2889: tensor([1.8738e+08, 2.0215e+06])
Cumulative counts at iter 2890: tensor([1.8744e+08, 2.0215e+06])
Cumulative counts at iter 2891: tensor([1.8751e+08, 2.0215e+06])
Cumulative counts at iter 2892: tensor([1.8757e+08, 2.0215e+06])
Cumulative counts at iter

Cumulative counts at iter 3007: tensor([1.9507e+08, 2.0663e+06])
Cumulative counts at iter 3008: tensor([1.9513e+08, 2.0663e+06])
Cumulative counts at iter 3009: tensor([1.9520e+08, 2.0663e+06])
Cumulative counts at iter 3010: tensor([1.9526e+08, 2.0663e+06])
Cumulative counts at iter 3011: tensor([1.9533e+08, 2.0663e+06])
Cumulative counts at iter 3012: tensor([1.9539e+08, 2.0663e+06])
Cumulative counts at iter 3013: tensor([1.9546e+08, 2.0663e+06])
Cumulative counts at iter 3014: tensor([1.9552e+08, 2.0663e+06])
Cumulative counts at iter 3015: tensor([1.9559e+08, 2.0667e+06])
Cumulative counts at iter 3016: tensor([1.9565e+08, 2.0678e+06])
Cumulative counts at iter 3017: tensor([1.9572e+08, 2.0678e+06])
Cumulative counts at iter 3018: tensor([1.9579e+08, 2.0678e+06])
Cumulative counts at iter 3019: tensor([1.9585e+08, 2.0678e+06])
Cumulative counts at iter 3020: tensor([1.9592e+08, 2.0678e+06])
Cumulative counts at iter 3021: tensor([1.9598e+08, 2.0678e+06])
Cumulative counts at iter

(0.660316060509598, 0.4909172950532917, 0.9802868874150709)

In [None]:
best_iou_score, train_loss_per_epoch, train_iou_per_epoch, train_acc_per_epoch, valid_loss_per_epoch, valid_iou_per_epoch, valid_acc_per_epoch = train(train_loader, valid_loader, model_dict)

In [None]:
print(f"Best IoU score: {best_iou_score}")
util.plot_train_valid(train_loss_per_epoch, valid_loss_per_epoch, name='Loss')
util.plot_train_valid(train_acc_per_epoch, valid_acc_per_epoch, name='Accuracy')
util.plot_train_valid(train_iou_per_epoch, valid_iou_per_epoch, name='Intersection over Union')

In [None]:
test_loss, test_iou, test_acc, image_outputs, image_labels = modelTest(test_loader, model_dict)
print(f"Test Loss is {test_loss}")
print(f"Test IoU is {test_iou}")
print(f"Test Pixel acc is {test_acc}")

In [None]:
image_index = 4
_, pred_mask = torch.max(pred[image_index].cpu().data, 0)
pred_mask = pred_mask.numpy()
true_mask = true[image_index].cpu().data.numpy()
# visualize_image(pred_mask, true_mask)
visualize_image(pred[image_index].cpu(), true[image_index].cpu())

In [None]:
"""
# ------ GET SAMPLE IMAGE FOR REPORT -------
test_sample_dataset = voc.VOC('test', transforms=sample_transform)
test_sample_loader = DataLoader(dataset=test_sample_dataset, batch_size=1, shuffle=False)
model.eval()
# untransformed original image
orig_inp, _ = next(iter(test_sample_loader))

# transformed image for input to network
inp, label = next(iter(test_loader))
inp = inp.to(device)
label = label.to(device)
output = model(inp)
_, pred = torch.max(output, dim=1)

util.save_sample(np.array(orig_inp[0].cpu(), dtype=np.uint8), label[0].cpu(), pred[0].cpu())
model.train()
# -------------------------------------------
"""

# housekeeping
gc.collect()
torch.cuda.empty_cache()