In [1]:
# Global configuration shared by the cells below.
BATCH_SIZE = 64  # samples per DataLoader batch
N_CHANNELS = 5   # channel count of each `im_seg` network input

In [2]:
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import torch
from network_definitions.u_net import UNet
from network_definitions.fcn import FCN32s as FCN
from network_definitions.simple_network import SimpleNet
from network_definitions.pyramid_network import PyramidNet
from torchvision.models.segmentation import fcn_resnet101 as FCN_Res101

# Dataset Import

In [3]:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import pickle

# Ignore warnings
import warnings
# Suppress every warning category for the rest of the kernel session (global side effect).
warnings.simplefilter("ignore")

plt.ion()  # switch matplotlib into interactive mode

class EnsembleDataset(Dataset):
    """Map-style dataset over a pickled sequence of ensemble results.

    Each record is indexed positionally as (name, valid, im_seg, im_res) and exposed
    as a dict with those keys.
    """

    def __init__(self, results_file, transform=None):
        """
        Args:
            results_file (string): Path to the pickle holding all result records.
            transform (callable, optional): Applied to each sample dict before it is returned.
        """
        # NOTE: pickle.load can execute arbitrary code -- only open trusted result files.
        with open(results_file, 'rb') as handle:
            self.results = pickle.load(handle)
        self.transform = transform

    def __len__(self):
        return len(self.results)

    def __getitem__(self, idx):
        # Samplers may pass a tensor index; turn it into a plain Python value first.
        key = idx.tolist() if torch.is_tensor(idx) else idx

        record = self.results[key]
        field_names = ('name', 'valid', 'im_seg', 'im_res')
        sample = {field: record[pos] for pos, field in enumerate(field_names)}

        return self.transform(sample) if self.transform else sample

In [4]:
from skimage.transform import resize
from torchvision import transforms, utils

class Resize(object):
    """Resize both images of a sample onto a square ``size`` x ``size`` grid.

    ``name`` and ``valid`` are passed through unchanged. Both images are resized with
    ``preserve_range=True``, so their original value range is kept rather than rescaled.

    Args:
        size (int): Target height and width in pixels.
        n_channels (int, optional): Channel count of the resized ``im_seg``. When omitted,
            the notebook-level ``N_CHANNELS`` constant is used, which matches the previous
            behavior.
    """

    def __init__(self, size, n_channels=None):
        self.size = size
        self.n_channels = n_channels

    def __call__(self, sample):
        name, valid, im_seg, im_res = sample["name"], sample["valid"], sample["im_seg"], sample["im_res"]

        # Resolve the default lazily so it still tracks the notebook-level constant.
        n_channels = N_CHANNELS if self.n_channels is None else self.n_channels

        # NOTE(review): skimage's default interpolation/anti-aliasing blurs im_res edges, which
        # presumably is a binary mask -- confirm that soft targets are acceptable for the loss.
        return {"name": name,
                "valid": valid,
                "im_seg": resize(im_seg, (self.size, self.size, n_channels), preserve_range=True),
                "im_res": resize(im_res, (self.size, self.size, 1), preserve_range=True)}

class ToTensor(object):
    """Convert the image arrays of a sample from numpy H x W x C to torch C x H x W.

    ``name`` and ``valid`` are passed through untouched.
    """

    def __call__(self, sample):
        name, valid, im_seg, im_res = [sample[key] for key in ("name", "valid", "im_seg", "im_res")]

        # Move the channel axis to the front, which is the layout torch convolutions expect.
        channels_first = (2, 0, 1)
        seg_chw = im_seg.transpose(channels_first)
        res_chw = im_res.transpose(channels_first)

        return {
            "name": name,
            "valid": valid,
            "im_seg": torch.from_numpy(seg_chw),
            "im_res": torch.from_numpy(res_chw),
        }

In [5]:
# Training data: pickled ensemble results, resized to 572x572 and converted to C x H x W tensors.
# NOTE: EnsembleDataset unpickles the file, which can execute arbitrary code -- use trusted files only.
trainset = EnsembleDataset(results_file='work_dirs/dataset_generation/dataset_no_img.pkl',
                           transform=transforms.Compose([Resize(572),
                                                         ToTensor()]))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=6)

# Training

In [6]:
from torch.utils.tensorboard import SummaryWriter

#PATH = "work_dirs/simplenet_1/"


def _show_training_sample(im_seg, out_segm, im_res, out_class):
    """Plot the first item of a batch: every input channel, the raw and rounded prediction, and the target.

    Args:
        im_seg: Input batch, N x C x H x W.
        out_segm: Predicted segmentation batch, N x 1 x H x W (the single channel is squeezed out).
        im_res: Target batch with the same layout as ``out_segm``.
        out_class: Classification-head output; its first row is printed.
    """
    inp = im_seg.detach().cpu().numpy()[0].transpose((1, 2, 0))
    output_t = out_segm.detach().cpu()
    output = output_t.numpy()[0].transpose((1, 2, 0)).squeeze(axis=2)
    output_rounded = torch.round(output_t).numpy()[0].transpose((1, 2, 0)).squeeze(axis=2)
    gt_output = im_res.detach().cpu().numpy()[0].transpose((1, 2, 0)).squeeze(axis=2)

    # One panel per input channel (previously hard-coded to 5), plus three output panels.
    n_inputs = inp.shape[2]
    panels = [("Input " + str(ch + 1), inp[:, :, ch]) for ch in range(n_inputs)]
    panels += [("Output", output), ("Rounded Output", output_rounded), ("GT Output", gt_output)]

    fig, axes = plt.subplots(nrows=1, ncols=len(panels), figsize=(15, 15))
    for ax, (title, image) in zip(axes.flat, panels):
        ax.set_title(title)
        ax.imshow(image, cmap='gray', vmin=0, vmax=1)
    fig.tight_layout()
    plt.show()
    # Release the figure; one is created per logging step, and they would otherwise accumulate.
    plt.close(fig)

    print("Class Evaluation: ", out_class.detach().cpu()[0])
    print("Max Value: ", output.max(), " Min Value: ", output.min())


def train(net, trainloader, criterion, optimizer, save_path, tensorboard_path, checkpoint=None,
          device=None, n_epochs=25, log_every=50, save_every=5):
    """Train ``net`` on the segmentation target, logging to TensorBoard and checkpointing periodically.

    Args:
        net: Model whose forward pass returns ``(segmentation_output, class_output)``.
        trainloader: Iterable of sample dicts with ``"im_seg"`` (input) and ``"im_res"`` (target) tensors.
        criterion: Pair ``(segmentation_loss, classification_loss)``. Only the first is used; the
            classification loss is currently disabled.
        optimizer: Optimizer over ``net``'s parameters.
        save_path (str): Prefix (including the trailing separator) for ``epoch_<k>.pt`` checkpoints.
        tensorboard_path (str): Log directory for the ``SummaryWriter``.
        checkpoint (str, optional): Checkpoint written by this function to resume from. Training
            continues with the epoch after the one stored in it.
        device (torch.device, optional): Device batches are moved to. Defaults to the device of
            ``net``'s first parameter instead of relying on a notebook-global ``device``.
        n_epochs (int): Total number of epochs, counted from epoch 0 even when resuming.
        log_every (int): Print the running loss and plot a sample every this many batches.
        save_every (int): Save a checkpoint every this many epochs.

    The scalar ``Loss`` logged per epoch, and the ``loss`` stored in checkpoints, are the mean
    batch loss of that epoch as a Python float.
    """
    if device is None:
        device = next(net.parameters()).device

    start_epoch = 0
    if checkpoint is not None:
        state = torch.load(checkpoint, map_location=device)
        net.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        # 'epoch' is the 0-based index of the last *completed* epoch, so resume with the next one.
        start_epoch = state['epoch'] + 1
    net.train()

    writer = SummaryWriter(log_dir=tensorboard_path)
    try:
        for epoch in range(start_epoch, n_epochs):
            running_loss = 0.0  # reset every `log_every` batches for the progress print
            epoch_loss = 0.0    # accumulated over the whole epoch for TensorBoard/checkpoints
            n_batches = 0

            for i, data in enumerate(trainloader):
                im_seg = data["im_seg"].to(device, dtype=torch.float)
                im_res = data["im_res"].to(device, dtype=torch.float)

                optimizer.zero_grad()
                out_segm, out_class = net(im_seg)
                # Classification loss (criterion[1]) is disabled; only the segmentation head is trained.
                loss = criterion[0](out_segm, im_res)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                running_loss += batch_loss
                epoch_loss += batch_loss
                n_batches += 1

                if i % log_every == log_every - 1:
                    print('[%d, %5d] loss: %.6f' %
                          (epoch + 1, i + 1, running_loss / log_every))
                    running_loss = 0.0
                    _show_training_sample(im_seg, out_segm, im_res, out_class)

            # NaN (rather than a NameError) when the loader yields no batches.
            mean_loss = epoch_loss / n_batches if n_batches else float("nan")
            writer.add_scalar('Loss', mean_loss, epoch)

            if epoch % save_every == save_every - 1:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': mean_loss,
                    }, save_path + "epoch_" + str(epoch + 1) + ".pt")
    finally:
        # Flush and close the event file even if training is interrupted (e.g. KeyboardInterrupt).
        writer.close()

    print('Finished Training')

In [35]:
from torchinfo import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input shape mirrors the data pipeline: N_CHANNELS input channels resized by Resize(572).
SUMMARY_INPUT_SIZE = 572
model = SimpleNet(N_CHANNELS, [1], activation="sigmoid").float().to(device)

summary(model, (1, N_CHANNELS, SUMMARY_INPUT_SIZE, SUMMARY_INPUT_SIZE))

Layer (type:depth-idx)                   Output Shape              Param #
SimpleNet                                --                        --
├─Sequential: 1-1                        [1, 1, 572, 572]          --
│    └─Conv2D: 2-1                       [1, 1, 572, 572]          --
│    │    └─Conv2d: 3-1                  [1, 1, 572, 572]          6
│    └─BatchNorm: 2-2                    [1, 1, 572, 572]          --
│    │    └─BatchNorm2d: 3-2             [1, 1, 572, 572]          2
│    └─Sigmoid: 2-3                      [1, 1, 572, 572]          --
├─Sequential: 1-2                        [1, 1]                    --
│    └─Conv2d: 2-4                       [1, 1, 286, 286]          2
│    └─Conv2d: 2-5                       [1, 1, 143, 143]          2
│    └─Conv2d: 2-6                       [1, 1, 72, 72]            2
│    └─Flatten: 2-7                      [1, 5184]                 --
│    └─Linear: 2-8                       [1, 1024]                 5,309,440
│    └─LeakyR

In [7]:
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
OPTIMIZER = "Adam"
ACTIVATION = "sigmoid"
LOSS = "BCELoss"
LAYER_CONFIGS = [[1], [3], [5]]  # `layers` argument for SimpleNet; one training run per entry
SEED = 0  # re-applied per run so each configuration gets reproducible init and shuffling

for layers in LAYER_CONFIGS:
    print("Starting training on network ",layers)
    torch.manual_seed(SEED)

    net = SimpleNet(N_CHANNELS, layers, activation=ACTIVATION)
    net = net.to(device).float()

    # Fail fast on unsupported settings. Previously an unknown value left `criterion`/`optimizer`
    # undefined (NameError) or silently reused the object built for the previous configuration.
    if LOSS == "BCELoss":
        criterion = nn.BCELoss()
    else:
        raise ValueError("Unsupported LOSS: {!r}".format(LOSS))
    # Passed to train() as the second criterion; train() currently does not use it.
    criterion_class = nn.NLLLoss()

    if OPTIMIZER == "SGD":
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    elif OPTIMIZER == "Adam":
        optimizer = optim.Adam(net.parameters(), lr=0.001)
    else:
        raise ValueError("Unsupported OPTIMIZER: {!r}".format(OPTIMIZER))

    # e.g. work_dirs/simplenet_1/Adam_sigmoid_BCELoss/
    checkpoint_path = ("work_dirs/simplenet" + "".join("_" + str(layer) for layer in layers)
                       + "/" + OPTIMIZER + "_" + ACTIVATION + "_" + LOSS + "/")
    tensorboard_path = checkpoint_path + "tb/"
    # tensorboard_path lives inside checkpoint_path, so this creates both directories.
    os.makedirs(tensorboard_path, exist_ok=True)

    train(net, trainloader, (criterion, criterion_class), optimizer, checkpoint_path, tensorboard_path)

Starting training on network  [1]


KeyboardInterrupt: 

In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
# NOTE(review): this directory uses an older naming scheme ("<activation>_<loss>", layers [1, 1, 1]).
# The training cell writes to work_dirs/simplenet_<layers>/<optimizer>_<activation>_<loss>/tb/ --
# confirm which run is meant to be inspected here.
tb_log_dir = 'work_dirs/simplenet_1_1_1/sigmoid_BCELoss/tb'
event_acc = EventAccumulator(tb_log_dir)
event_acc.Reload()  # parse the event files currently on disk

# List every tag recorded in the log
print(event_acc.Tags())

# Unzip the 'Loss' scalar events into parallel tuples: wall-clock times, step numbers, values
loss_events = event_acc.Scalars('Loss')
w_times, step_nums, vals = zip(*loss_events)

# Network Summary

In [None]:
# Visual sanity check of the dataset: show the first three input channels of a sample as RGB.
# (Previously the loop header was commented out while its body stayed indented, so this
# cell raised IndentationError and `i` was undefined.)
N_PREVIEW = 1
for i in range(N_PREVIEW):
    data = trainset[i]

    im_seg = data['im_seg']
    im_res = data['im_res']

    # im_seg is C x H x W (see ToTensor); matplotlib expects H x W x C
    res = im_seg[0:3, :, :].numpy().transpose((1, 2, 0))

    fig = plt.figure()
    plt.title(str(data['name']))
    plt.imshow(res)
    plt.show()