In [None]:
# Mini-batch size shared by the training and test DataLoaders below.
BATCH_SIZE = 16
# Channels per input patch: channels 0-2 are the colour image and channel 3 is
# the boundary map (as indexed in the training visualisation).
# NOTE(review): not referenced anywhere in this notebook — UNet(4, ...) hardcodes 4.
N_CHANNELS = 4

In [None]:
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import torch
from network_definitions.u_net import UNet
from network_definitions.fcn import FCN32s as FCN
from network_definitions.simple_network import SimpleNet
from network_definitions.pyramid_network import PyramidNet
from torchvision.models.segmentation import fcn_resnet101 as FCN_Res101

# Dataset Import

In [None]:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import pickle

# Silence every warning category for the rest of the kernel session.
# NOTE(review): a blanket filter also hides torch/skimage deprecation warnings.
import warnings
warnings.simplefilter("ignore")

# Matplotlib interactive mode: figures render without blocking the cell.
plt.ion()

class PatchesDatasetTrain(Dataset):
    """Training split of the boundary-refinement patch dataset.

    Sample ``i`` is read from ``<root_dir>/inp/<i>.png`` (network input patch)
    and ``<root_dir>/gt/<i>.png`` (ground-truth boundary map) for ``i`` in
    ``[0, length)``.

    Args:
        root_dir: dataset root directory.
        transform: optional callable applied to each sample dict.
        length: number of samples in the split (default: 20000).
    """

    def __init__(self, root_dir, transform=None, length=20000):
        self.root_dir = root_dir
        self.transform = transform
        self.length = length

    def _image_path(self, subdir, idx):
        """Path of the PNG for sample ``idx`` inside ``subdir`` ("inp" or "gt")."""
        return os.path.join(self.root_dir, subdir, str(idx) + ".png")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``{'name': idx, 'inp': image, 'gt': mask}``, transformed if set."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        inp = io.imread(self._image_path("inp", idx))
        # NOTE(review): for a colour PNG, skimage's as_gray returns floats in
        # [0, 1], which ToTensor's /255 would shrink further — confirm the GT
        # files are single-channel 8-bit images.
        gt = io.imread(self._image_path("gt", idx), as_gray=True)

        sample = {'name': idx, 'inp': inp, 'gt': gt}

        if self.transform:
            sample = self.transform(sample)

        return sample
    
class PatchesDatasetTest(Dataset):
    """Held-out split of the boundary-refinement patch dataset.

    Dataset index ``idx`` maps to files ``<root_dir>/inp/<idx + offset>.png``
    and ``<root_dir>/gt/<idx + offset>.png``, i.e. by default the files that
    follow the 20000 training patches. The returned ``'name'`` is the
    dataset-relative ``idx``, not the file number.

    Args:
        root_dir: dataset root directory.
        transform: optional callable applied to each sample dict.
        offset: file number of dataset index 0 (default: 20000).
        length: number of samples in the split (default: 4629).
    """

    def __init__(self, root_dir, transform=None, offset=20000, length=4629):
        self.root_dir = root_dir
        self.transform = transform
        self.offset = offset
        self.length = length

    def _image_path(self, subdir, idx):
        """Path of the PNG for dataset index ``idx`` inside ``subdir``."""
        return os.path.join(self.root_dir, subdir, str(idx + self.offset) + ".png")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``{'name': idx, 'inp': image, 'gt': mask}``, transformed if set."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        inp = io.imread(self._image_path("inp", idx))
        gt = io.imread(self._image_path("gt", idx), as_gray=True)

        sample = {'name': idx, 'inp': inp, 'gt': gt}

        if self.transform:
            sample = self.transform(sample)

        return sample
        
    
from skimage.transform import resize
from torchvision import transforms, utils
    
class ToTensor(object):
    """Turn a numpy sample dict into a dict of torch tensors.

    ``inp`` moves from H x W x C (numpy/skimage layout) to C x H x W (torch
    layout), and ``gt`` from H x W to 1 x H x W. Both are divided by 255, so
    uint8 images end up as float64 values in [0, 1]. ``name`` passes through
    unchanged.
    """

    def __call__(self, sample):
        image_chw = sample["inp"].transpose(2, 0, 1) / 255
        # Append a channel axis, then move it to the front: H x W -> 1 x H x W.
        mask_chw = sample["gt"][:, :, np.newaxis].transpose(2, 0, 1) / 255
        return dict(
            name=sample["name"],
            inp=torch.from_numpy(image_chw),
            gt=torch.from_numpy(mask_chw),
        )

In [None]:
# Training split (file indices 0..19999 of data/cityscapes_patches) as tensors.
trainset = PatchesDatasetTrain("data/cityscapes_patches",
                               transform=transforms.Compose([ToTensor()]))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=6)



In [None]:
# Sanity check: number of training samples, as reported by PatchesDatasetTrain.__len__.
len(trainset)

# Training

In [None]:
from torch.utils.tensorboard import SummaryWriter
import cv2

def _show_training_sample(im_seg, output, im_res):
    """Plot the first sample of a batch next to its prediction and ground truth.

    Panels: colour patch (input channels 0-2), boundary input (channel 3),
    raw prediction, rounded prediction, and ground truth.
    """
    input_ = im_seg[0].detach().cpu().numpy().transpose((1, 2, 0))
    output_ = output[0].detach().cpu().numpy().transpose((1, 2, 0))
    gt_output_ = im_res[0].detach().cpu().numpy().transpose((1, 2, 0)).squeeze(axis=2)

    fig, ax = plt.subplots(nrows=1, ncols=5, figsize=(15, 15))
    ax = ax.flat

    ax[0].set_title("Original Image Patch")
    # NOTE(review): skimage.io.imread returns RGB(A) order, so BGR2RGB swaps
    # red/blue here unless the patches were stored with channels pre-swapped.
    ax[0].imshow(cv2.cvtColor(input_[:, :, 0:3], cv2.COLOR_BGR2RGB))

    ax[1].set_title("Boundary Input")
    ax[1].imshow(input_[:, :, 3], cmap='gray')

    ax[2].set_title("Boundary Output")
    ax[2].imshow(output_, cmap='gray')

    ax[3].set_title("Boundary Output Thresholded")
    ax[3].imshow(np.around(output_), cmap='gray')

    ax[4].set_title("Ground Truth")
    ax[4].imshow(gt_output_, cmap='gray')

    fig.tight_layout()
    plt.show()
    # pyplot keeps every figure alive until closed; this runs repeatedly.
    plt.close(fig)

    print("Max Value: ", output_.max(), " Min Value: ", output_.min())


def train(net, trainloader, criterion, optimizer, save_path, tensorboard_path,
          checkpoint=None, num_epochs=1000, log_every=100, save_every=10):
    """Train ``net``, logging to TensorBoard and checkpointing periodically.

    Args:
        net: model to train; batches are moved to the device of its parameters.
        trainloader: iterable of dicts with "inp" and "gt" tensors.
        criterion: loss taking (output, float target).
        optimizer: optimizer over ``net``'s parameters.
        save_path: directory prefix (with trailing "/") for ``epoch_<N>.pt`` files.
        tensorboard_path: SummaryWriter log directory.
        checkpoint: optional path to a checkpoint written by this function;
            training resumes at the epoch after the one stored in it.
        num_epochs: exclusive upper bound on the 0-based epoch index.
        log_every: print the mean loss and plot a sample every this many batches.
        save_every: save a checkpoint every this many epochs.

    Checkpoints store ``'epoch'`` (0-based index of the last completed epoch),
    model/optimizer state dicts, and ``'loss'`` (that epoch's mean batch loss).
    The TensorBoard scalar ``'Loss'`` is also the per-epoch mean batch loss.
    """
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    start_epoch = 0

    if checkpoint is not None:
        state = torch.load(checkpoint, map_location=device)
        net.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        # 'epoch' is the last *completed* epoch; continue with the next one.
        start_epoch = state['epoch'] + 1

    # Always train mode: the net may have been left in eval mode by another cell.
    net.train()
    writer = SummaryWriter(log_dir=tensorboard_path)

    try:
        for epoch in range(start_epoch, num_epochs):
            running_loss = 0.0
            epoch_loss_sum = 0.0
            n_batches = 0

            for i, data in enumerate(trainloader):
                im_seg = data["inp"].to(device, dtype=torch.float)
                # Cast through long (as before): GT values below 1 truncate to 0,
                # so only fully-on pixels count as boundary targets.
                im_res = data["gt"].to(device, dtype=torch.long).float()

                optimizer.zero_grad()
                output = net(im_seg)
                loss = criterion(output.float(), im_res)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                running_loss += batch_loss
                epoch_loss_sum += batch_loss
                n_batches += 1

                if i % log_every == log_every - 1:
                    print('[%d, %5d] loss: %.6f' %
                          (epoch + 1, i + 1, running_loss / log_every))
                    running_loss = 0.0
                    _show_training_sample(im_seg, output, im_res)

            epoch_loss = epoch_loss_sum / n_batches if n_batches else float('nan')
            writer.add_scalar('Loss', epoch_loss, epoch)

            if epoch % save_every == save_every - 1:
                torch.save({
                        'epoch': epoch,
                        'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': epoch_loss,
                        }, save_path + "epoch_" + str(epoch + 1) + ".pt")
    finally:
        writer.close()

    print('Finished Training')

In [None]:
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
OPTIMIZER = "SGD"    # "SGD" or "Adam"
LOSS = "BCELoss"     # "BCELoss" or "CrossEntropyLoss"
# Optional path to an epoch_<N>.pt file to resume from; None trains from scratch.
RESUME_CHECKPOINT = None

print("Starting training on network: UNet ")

net = UNet(4,1)
net = net.to(device).float()

# NOTE(review): BCELoss needs probabilities in [0, 1], so UNet presumably ends in
# a sigmoid; CrossEntropyLoss would instead need >= 2 output channels of logits
# and integer class targets — confirm before switching.
if LOSS == "BCELoss":
    criterion = nn.BCELoss()
elif LOSS == "CrossEntropyLoss":
    criterion = nn.CrossEntropyLoss()
else:
    raise ValueError("Unsupported LOSS: " + repr(LOSS))

if OPTIMIZER == "SGD":
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
elif OPTIMIZER == "Adam":
    optimizer = optim.Adam(net.parameters(), lr=0.001)
else:
    raise ValueError("Unsupported OPTIMIZER: " + repr(OPTIMIZER))

checkpoint_path = "work_dirs/unet_boundary_refinement/"
tensorboard_path = checkpoint_path+"tb/"
# Also creates checkpoint_path, where train() writes epoch_<N>.pt files.
os.makedirs(tensorboard_path,exist_ok=True)

train(net, trainloader, criterion, optimizer, checkpoint_path, tensorboard_path,
      checkpoint=RESUME_CHECKPOINT)

In [None]:
from torchinfo import summary

# Layer summary of the configuration trained and evaluated in this notebook: UNet(4, 1).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(4, 1).float().to(device)

# NOTE(review): 64x64 is assumed to be the patch size — confirm against the data.
summary(model, (1, 4, 64, 64))

In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
# Log directory of the UNet run trained above (train() writes to checkpoint_path + "tb/").
TB_LOG_DIR = "work_dirs/unet_boundary_refinement/tb/"
event_acc = EventAccumulator(TB_LOG_DIR)
event_acc.Reload()
# Show all tags in the log file
print(event_acc.Tags())

# Unpack the 'Loss' scalar series into wall-clock times, steps (train() logs the
# epoch index as the step) and values.
w_times, step_nums, vals = zip(*event_acc.Scalars('Loss'))

# Testing

In [None]:
# Held-out split (file indices 20000..24628). No shuffling: every checkpoint must
# be evaluated on identical batches so the results are comparable and reproducible.
testset = PatchesDatasetTest("data/cityscapes_patches",
                             transform=transforms.Compose([ToTensor()]))
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=6)

In [None]:
def test(net, testloader):
    
    aoa_inp,ap_inp,ar_inp,f1_inp,aoa_pred,ap_pred,ar_pred,f1_pred = 0,0,0,0,0,0,0,0
    
    for i, data in enumerate(testloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        im_seg = data["inp"].to(device, dtype=torch.float)
        im_res = data["gt"].to(device, dtype=torch.long)

        # forward + backward + optimize
        output = net(im_seg.float())

        input_ = im_seg.cpu().detach()
        output_ = output.cpu().detach()
        #output_ = torch.argmax(output_,1)
        #print(output_.shape)
        gt_output_ = im_res.cpu().detach()

        pos_input_ = input_.numpy()[0].transpose((1,2,0))
        pos_output_ = output_.numpy()[0].transpose((1,2,0))

        pos_gt_output_ = gt_output_.numpy()[0].transpose((1,2,0)).squeeze(axis=2)

        #print(np.amax(output_))

        """fig, ax = plt.subplots(nrows=1, ncols=5, figsize=(15,15))
        ax=ax.flat

        ax[0].set_title("Original Image Patch")  # set title
        ax[0].imshow(cv2.cvtColor(pos_input_[:,:,0:3], cv2.COLOR_BGR2RGB))

        #ax.append(fig.add_subplot(2, 4, 7))
        ax[1].set_title("Boundary Input")  # set title
        ax[1].imshow(pos_input_[:,:,3],cmap='gray')

        ax[2].set_title("Boundary Output")
        ax[2].imshow(pos_output_,cmap='gray')

        ax[3].set_title("Boundary Output Thresholded")
        ax[3].imshow(np.around(pos_output_),cmap='gray')

        ax[4].set_title("Ground Truth")
        ax[4].imshow(pos_gt_output_,cmap='gray')

        fig.tight_layout()
        plt.show()"""
        
        output_ = np.around(output_)
        
        tp_inp = np.count_nonzero(np.logical_and(gt_output_,input_[:,3,:,:,][:,np.newaxis,:,:]))
        fp_inp = np.count_nonzero(np.logical_and(np.logical_not(gt_output_),input_[:,3,:,:,][:,np.newaxis,:,:]))
        tn_inp = np.count_nonzero(np.logical_and(np.logical_not(gt_output_),np.logical_not(input_[:,3,:,:,][:,np.newaxis,:,:])))
        fn_inp = np.count_nonzero(np.logical_and(gt_output_,np.logical_not(input_[:,3,:,:,][:,np.newaxis,:,:])))
        oa_inp = (tp_inp+tn_inp)/(tp_inp+fp_inp+tn_inp+fn_inp)
        p_inp = tp_inp/(tp_inp+fp_inp)
        r_inp = tp_inp/(tp_inp+fn_inp)
        
        tp_pred = np.count_nonzero(np.logical_and(gt_output_,output_))
        fp_pred = np.count_nonzero(np.logical_and(np.logical_not(gt_output_),output_))
        tn_pred = np.count_nonzero(np.logical_and(np.logical_not(gt_output_),np.logical_not(output_)))
        fn_pred = np.count_nonzero(np.logical_and(gt_output_,np.logical_not(output_)))
        oa_pred = (tp_pred+tn_pred)/(tp_pred+fp_pred+tn_pred+fn_pred)
        p_pred = tp_pred/(tp_pred+fp_pred)
        r_pred = tp_pred/(tp_pred+fn_pred)
        
        aoa_inp += oa_inp
        ap_inp += p_inp
        ar_inp += r_inp
        f1_inp += (2*p_inp*r_inp)/(p_inp+r_inp)
        aoa_pred += oa_pred
        ap_pred += p_pred
        ar_pred += r_pred
        f1_pred += (2*p_pred*r_pred)/(p_pred+r_pred)
    
    aoa_inp /= len(testloader)
    ap_inp /= len(testloader)
    ar_inp /= len(testloader)
    f1_inp /= len(testloader)
    aoa_pred /= len(testloader)
    ap_pred /= len(testloader)
    ar_pred /= len(testloader)
    f1_pred /= len(testloader)
        
    
    print("Accuracy Input: "+str(aoa_inp)+"  Precision Input: "+str(ap_inp)+"  Recall Input: "+str(ar_inp)+"  F1 Input: "+str(f1_inp))
    print("Accuracy Prediction: "+str(aoa_pred)+"  Precision Prediction: "+str(ap_pred)+"  Recall Prediction: "+str(ar_pred)+"  F1 Prediction: "+str(f1_pred))

    
    return oa_pred,ap_pred,ar_pred,f1_pred

In [None]:
results_oa = []
results_p = []
results_r = []
results_f1 = []
evaluated_epochs = []

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = UNet(4,1)
net = net.to(device).float()

CHECKPOINT_DIR = "work_dirs/unet_boundary_refinement/"

# train() saves epoch_<N>.pt every 10 epochs; evaluate each one that exists.
for epoch in range(10, 1001, 10):
    checkpoint_file = CHECKPOINT_DIR + "epoch_" + str(epoch) + ".pt"
    if not os.path.exists(checkpoint_file):
        print("Skipping missing checkpoint: " + checkpoint_file)
        continue

    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    state = torch.load(checkpoint_file, map_location=device)
    net.load_state_dict(state['model_state_dict'])
    net.eval()

    print("Epoch " + str(epoch))
    oa, ap, ar, f1 = test(net, testloader)
    evaluated_epochs.append(epoch)
    results_oa.append(oa)
    results_p.append(ap)
    results_r.append(ar)
    results_f1.append(f1)

# One figure per metric, plotted against the epochs actually evaluated.
metric_series = [
    ("Average Overall Accuracy", results_oa),
    ("Average Precision", results_p),
    ("Average Recall", results_r),
    ("Average F1 Score", results_f1),
]
for label, values in metric_series:
    fig, ax = plt.subplots()
    ax.plot(evaluated_epochs, values, label=label)
    ax.set(title=label, xlabel="Epoch", ylabel=label)
    ax.legend()
    plt.show()
    plt.close(fig)
