In [None]:
# Mini-batch size used by the training DataLoader.
BATCH_SIZE = 1
# Channels of the network input: the image plus the five ensemble predictions
# stacked by EnsembleDataset (presumably 3 colour channels + 5 masks — confirm).
N_CHANNELS = 8

In [None]:
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import torch
from network_definitions.u_net import UNet
from network_definitions.u_net2 import UNet2
from network_definitions.fcn import FCN32s as FCN
from network_definitions.simple_network import SimpleNet
from network_definitions.pyramid_network import PyramidNet
from torchvision.models.segmentation import fcn_resnet101 as FCN_Res101

# Dataset Import

In [None]:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import pickle
import cv2

# Ignore warnings
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session
# (including library deprecation warnings); consider narrowing the filter.
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode: figures drawn during training should not block execution

class EnsembleDataset(Dataset):
    """Ensemble dataset: an input image stacked with five network predictions.

    Sample ``idx`` is read from ``<root_dir>/<subdir>/<idx>.png`` for the
    subdirectories ``img``, ``net1`` ... ``net5`` and ``gt``. The image and the
    five predictions are stacked along the channel axis into ``inp``; ``gt`` is
    the target mask. Both are divided by 255.

    Args:
        root_dir: dataset root directory.
        inc_img: stored on the instance but not read by this class (the image
            is always stacked into ``inp``).
        transform: optional callable applied to the sample dict.
        length: number of samples, i.e. valid indices are ``0 .. length - 1``.
            Defaults to the previously hard-coded 10533.
    """

    # Subdirectories holding the per-network predictions, in channel order.
    _NET_DIRS = ("net1", "net2", "net3", "net4", "net5")

    def __init__(self, root_dir, inc_img=False, transform=None, length=10533):
        self.root_dir = root_dir
        self.inc_img = inc_img
        self.transform = transform
        self.length = length

    def __len__(self):
        return self.length

    def _path(self, subdir, idx):
        """Path of sample ``idx`` inside ``subdir``."""
        return os.path.join(self.root_dir, subdir, str(idx) + ".png")

    def _read_mask(self, subdir, idx):
        """Read a single-channel PNG and return it as an (H, W, 1) array."""
        path = self._path(subdir, idx)
        # skimage's io.imread takes ``as_gray`` as its 2nd positional argument,
        # so no cv2 flag may be passed here: cv2.IMREAD_UNCHANGED (-1) would be
        # interpreted as as_gray=True.
        mask = io.imread(path)
        if mask.ndim != 2:
            raise ValueError("Expected a single-channel mask at %s, got shape %s"
                             % (path, mask.shape))
        return mask[:, :, np.newaxis]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self._path("img", idx)
        # IMREAD_UNCHANGED keeps the file's own channels; colour images come
        # back in OpenCV's B, G, R channel order.
        inp = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if inp is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError("Could not read image: " + img_path)

        preds = [self._read_mask(net_dir, idx) for net_dir in self._NET_DIRS]
        res = np.dstack([inp] + preds) / 255
        gt = self._read_mask("gt", idx) / 255

        sample = {'name': idx, 'inp': res, 'gt': gt}

        if self.transform:
            sample = self.transform(sample)

        return sample
    
from skimage.transform import resize
from torchvision import transforms, utils
    
class Resize(object):
    """Resize a sample's ``inp`` and ``gt`` arrays to ``size`` x ``size`` pixels.

    Args:
        size: output height and width.
        n_channels: number of channels ``inp`` must have (``gt`` must have 1).

    Raises:
        ValueError: if ``inp`` does not have ``n_channels`` channels or ``gt``
            is not single-channel.
    """

    def __init__(self, size, n_channels):
        self.size = size
        self.n_channels = n_channels

    def __call__(self, sample):
        name, inp, gt = sample["name"], sample["inp"], sample["gt"]

        # skimage's resize treats the channel axis like any spatial axis: a
        # channel-count mismatch would silently interpolate *between* channels
        # (mixing the image and the ensemble masks) instead of failing.
        if inp.shape[-1] != self.n_channels:
            raise ValueError("Resize expected %d input channels, got %d"
                             % (self.n_channels, inp.shape[-1]))
        if gt.shape[-1] != 1:
            raise ValueError("Resize expected a single-channel gt, got %d channels"
                             % gt.shape[-1])

        return {"name": name,
                "inp": resize(inp, (self.size, self.size, self.n_channels), preserve_range=True),
                "gt": resize(gt, (self.size, self.size, 1), preserve_range=True)}

class ToTensor(object):
    """Turn the ndarrays of a sample into torch tensors.

    ``inp`` and ``gt`` are moved from numpy's H x W x C layout to torch's
    C x H x W layout; ``name`` is passed through unchanged.
    """

    def __call__(self, sample):
        name = sample["name"]
        inp = sample["inp"]
        gt = sample["gt"]

        # H x W x C  ->  C x H x W
        channels_first = (2, 0, 1)
        converted = {"name": name}
        for key, array in (("inp", inp), ("gt", gt)):
            converted[key] = torch.from_numpy(array.transpose(channels_first))
        return converted

In [None]:
trainset = EnsembleDataset(
    root_dir='data/coco_bitwise_or_reduced_ensemble_results',
    inc_img=True,
    transform=transforms.Compose([
        Resize(512, N_CHANNELS),  # inp / gt -> 512 x 512
        ToTensor(),               # H x W x C ndarrays -> C x H x W tensors
    ]),
)
trainloader = DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=6,
)

In [None]:
# Sanity check: number of samples the DataLoader iterates over per epoch.
len(trainset)

# Training

In [None]:
from torch.utils.tensorboard import SummaryWriter

#PATH = "work_dirs/simplenet_1/"


def _show_training_sample(im_seg, output, im_res):
    """Plot the first sample of a batch next to its prediction and target.

    Panels: the colour image (channels 0-2 of ``im_seg``), the five ensemble
    inputs (channels 3-7), the raw prediction, the rounded prediction and the
    ground truth. Tensors are expected as (batch, channel, H, W).
    """
    input_ = im_seg.detach().cpu().numpy()[0].transpose((1, 2, 0))
    # First (for UNet2(N_CHANNELS, 1) the only) output channel as a 2-D image.
    output_ = output.detach().cpu().numpy()[0].transpose((1, 2, 0))[:, :, 0]
    gt_output_ = im_res.detach().cpu().numpy()[0].transpose((1, 2, 0)).squeeze(axis=2)

    fig, axes = plt.subplots(nrows=1, ncols=9, figsize=(15, 15))
    axes = axes.flat

    axes[0].set_title("Original Image")
    # EnsembleDataset loads the image with cv2, i.e. in B, G, R order; reverse
    # it so matplotlib (which expects R, G, B) shows the true colours.
    axes[0].imshow(input_[:, :, 2::-1])

    for k in range(5):
        axes[k + 1].set_title("Input " + str(k + 1))
        axes[k + 1].imshow(input_[:, :, k + 3], cmap='gray', vmin=0, vmax=1)

    axes[6].set_title("Output")
    axes[6].imshow(output_, cmap='gray', vmin=0, vmax=1)

    axes[7].set_title("Output Rounded")
    axes[7].imshow(np.around(output_), cmap='gray', vmin=0, vmax=1)

    axes[8].set_title("Ground Truth")
    axes[8].imshow(gt_output_, cmap='gray', vmin=0, vmax=1)

    fig.tight_layout()
    plt.show()
    # Release the figure; otherwise every logged batch keeps a figure alive.
    plt.close(fig)


def train(net, trainloader, criterion, optimizer, save_path, tensorboard_path, checkpoint=None,
          num_epochs=100, log_interval=2000, save_every=2, device=None):
    """Train ``net``, logging to TensorBoard and saving periodic checkpoints.

    Args:
        net: model to train.
        trainloader: iterable of dicts with ``"inp"`` (network input) and
            ``"gt"`` (target) tensors.
        criterion: loss, called as ``criterion(output, target)``.
        optimizer: optimizer over ``net``'s parameters.
        save_path: directory where ``epoch_<N>.pt`` checkpoints are written.
        tensorboard_path: log directory for the per-epoch mean ``'Loss'`` scalar.
        checkpoint: optional path of a checkpoint written by this function;
            training resumes with the epoch after the one it stores.
        num_epochs: total number of epochs (including already completed ones
            when resuming).
        log_interval: print the mean loss and plot a sample every this many
            mini-batches.
        save_every: write a checkpoint every this many epochs.
        device: device for the batches; defaults to the device of ``net``'s
            first parameter.

    Checkpoints are dicts with ``epoch`` (0-based index of the completed
    epoch), ``model_state_dict``, ``optimizer_state_dict`` and ``loss`` (the
    epoch's mean loss as a float).
    """
    if device is None:
        device = next(net.parameters()).device

    start_epoch = 0
    if checkpoint is not None:
        state = torch.load(checkpoint, map_location=device)
        net.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        # 'epoch' is the 0-based index of the last *completed* epoch (it is saved
        # as epoch_<epoch + 1>.pt), so continue with the next one instead of
        # training it again.
        start_epoch = state['epoch'] + 1

    # Always switch to training mode (affects e.g. dropout), not only on resume.
    net.train()

    writer = SummaryWriter(log_dir=tensorboard_path)
    try:
        for epoch in range(start_epoch, num_epochs):
            running_loss = 0.0  # reset every log_interval mini-batches
            epoch_loss = 0.0    # accumulated over the whole epoch
            n_batches = 0

            for i, data in enumerate(trainloader):
                im_seg = data["inp"].to(device, dtype=torch.float)
                im_res = data["gt"].to(device, dtype=torch.float)

                optimizer.zero_grad()
                output = net(im_seg)
                loss = criterion(output.float(), im_res)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                running_loss += batch_loss
                epoch_loss += batch_loss
                n_batches += 1

                if i % log_interval == log_interval - 1:
                    # running_loss covers exactly log_interval mini-batches.
                    print('[%d, %5d] loss: %.6f' %
                          (epoch + 1, i + 1, running_loss / log_interval))
                    running_loss = 0.0
                    _show_training_sample(im_seg, output, im_res)

            # Log the epoch's mean loss rather than the last mini-batch's loss.
            mean_loss = epoch_loss / n_batches if n_batches else float("nan")
            writer.add_scalar('Loss', mean_loss, epoch)

            if epoch % save_every == save_every - 1:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': mean_loss,
                }, os.path.join(save_path, "epoch_" + str(epoch + 1) + ".pt"))
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()

    print('Finished Training')

In [None]:
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
OPTIMIZER = "SGD"
ACTIVATION = "lrelu"
LOSS = "BCELoss"

# Only used to build the run-directory name below ([""] appends a single "_").
layers = [""]

print("Starting training on network ", layers)

# NOTE: ACTIVATION is not passed to UNet2; it only appears in the run-directory name.
net = UNet2(N_CHANNELS, 1)
net = net.to(device).float()

# Build loss and optimiser from the config strings. An unknown name fails loudly
# instead of leaving `criterion`/`optimizer` undefined or silently reusing a stale
# object (bound to an older `net`) from an earlier run of this cell.
if LOSS == "BCELoss":
    criterion = nn.BCELoss()
elif LOSS == "CrossEntropyLoss":
    # NOTE(review): with the single output channel of UNet2(N_CHANNELS, 1) a
    # softmax over one class is constant, so this loss looks degenerate here.
    criterion = nn.CrossEntropyLoss()
else:
    raise ValueError("Unsupported LOSS: %r" % LOSS)

if OPTIMIZER == "SGD":
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
elif OPTIMIZER == "Adam":
    optimizer = optim.Adam(net.parameters(), lr=0.001)
else:
    raise ValueError("Unsupported OPTIMIZER: %r" % OPTIMIZER)

checkpoint_path = "work_dirs/unet2_bitwise_or_img_ensemble_reduced_do_40"
for layer in layers:
    checkpoint_path += "_" + str(layer)
checkpoint_path += "/" + OPTIMIZER + "_" + ACTIVATION + "_" + LOSS + "/"
tensorboard_path = checkpoint_path + "tb/"
# tensorboard_path lies inside checkpoint_path, so this creates both directories.
os.makedirs(tensorboard_path, exist_ok=True)

# To resume, pass checkpoint=<path of an epoch_N.pt file written by train()>.
train(net, trainloader, criterion, optimizer, checkpoint_path, tensorboard_path)

In [None]:
from torchinfo import summary

# Layer-by-layer summary of a SimpleNet for an 8-channel 572 x 572 input.
# NOTE(review): this inspects SimpleNet, not the UNet2 trained above.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = (
    SimpleNet(
        [(3, 8, 16), (3, 16, 32), (3, 32, 64), (3, 64, 32), (3, 32, 16), (3, 16, 2)],
        activation="lrelu",
    )
    .float()
    .to(device)
)

summary(model, input_size=(1, 8, 572, 572))

In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
# Load a run's TensorBoard event files to read back its logged scalars.
# NOTE(review): this points at an older SimpleNet run, not at the
# `tensorboard_path` written by the training cell above — update if needed.
event_acc = EventAccumulator('work_dirs/simplenet_1_1_1/sigmoid_BCELoss/tb')
event_acc.Reload()
# Show all tags in the log file
print(event_acc.Tags())

# Split the 'Loss' scalar series (the tag train() writes, one point per epoch)
# into wall-clock times, step numbers and values. Unpacking raises ValueError
# if the run has no 'Loss' events.
w_times, step_nums, vals = zip(*event_acc.Scalars('Loss'))

# Training Sample Preview

In [None]:
# Visual sanity check: show the colour image (channels 0-2 of 'inp') of the
# first training sample after the dataset transforms.
for i in range(1):
    data = trainset[i]

    # EnsembleDataset samples are dicts with the keys 'name', 'inp' and 'gt'.
    im_seg = data['inp']
    im_res = data['gt']

    res = im_seg[0:3, :, :].numpy().transpose((1, 2, 0))

    fig = plt.figure()
    # Channels 0-2 were read by cv2 in B, G, R order; flip to R, G, B for matplotlib.
    plt.imshow(res[:, :, ::-1])