In [6]:
import os
import cv2
import tqdm
import itertools
import numpy as np
from tqdm.contrib import tzip
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [7]:
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.utils import make_grid
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset

In [8]:
class SegNetBase(nn.Module):
    """SegNet-style encoder/decoder producing a per-pixel class score map.

    Four encoder stages (7x7 conv + batch norm + ReLU, then 2x2 max-pooling
    with indices) are mirrored by four decoder stages. Each decoder stage
    first max-unpools with the indices of the matching encoder stage. The last
    decoder stage is a 1x1 convolution mapping 64 features to
    ``out_channels`` class scores.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of output channels (class scores per pixel).
        debug: if True, ``forward`` prints intermediate tensor sizes.

    Note:
        The attribute names (``encoder_0`` .. ``decoder_0``) and the layer order
        inside each ``nn.Sequential`` (conv at index 0, batch norm at index 1)
        define the ``state_dict`` keys. Keep them unchanged so previously saved
        checkpoints still load.
    """

    def __init__(self, in_channels = 3, out_channels = 32, debug = False):
        super().__init__()
        self.debug = debug
        self.in_channels = in_channels
        self.out_channels = out_channels

        ## encode: 7x7 conv with padding 3 keeps H and W; pooling happens in forward.
        self.encoder_0 = self._conv_bn(self.in_channels, 64)
        self.encoder_1 = self._conv_bn(64, 64)
        self.encoder_2 = self._conv_bn(64, 64)
        self.encoder_3 = self._conv_bn(64, 64)

        ## decode: mirrored stages, then a 1x1 conv producing the class scores.
        self.decoder_3 = self._conv_bn(64, 64)
        self.decoder_2 = self._conv_bn(64, 64)
        self.decoder_1 = self._conv_bn(64, 64)
        self.decoder_0 = nn.Sequential(nn.Conv2d(in_channels = 64, out_channels = self.out_channels,
                                                 kernel_size = 1, padding = 0))

    @staticmethod
    def _conv_bn(in_channels, out_channels):
        """Build one stage: a 7x7 'same'-padded Conv2d followed by BatchNorm2d."""
        return nn.Sequential(
            nn.Conv2d(in_channels = in_channels, out_channels = out_channels,
                      kernel_size = 7, padding = 3),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """Encode ``x``, decode it back to input resolution, and score each pixel.

        Args:
            x: input batch, assumed (N, in_channels, H, W) since 2D conv/pooling
               are used. Four 2x2 poolings floor-halve H and W, so H and W
               should be at least 16.

        Returns:
            (logits, probs): ``logits`` are the raw (unclamped) per-pixel class
            scores of shape (N, out_channels, H, W). ``probs`` is
            ``softmax(logits, dim=1)``.
        """
        encoders = (self.encoder_0, self.encoder_1, self.encoder_2, self.encoder_3)
        # Decoders in application order: deepest stage first.
        decoders = (self.decoder_3, self.decoder_2, self.decoder_1, self.decoder_0)

        stage_input_sizes = []  # size of x entering each encoder stage
        pool_indices = []
        for encoder in encoders:
            stage_input_sizes.append(x.size())
            x = F.relu(encoder(x))
            x, indices = F.max_pool2d(x, kernel_size = 2, stride = 2, return_indices = True)
            pool_indices.append(indices)
        encoded_size = x.size()

        decoded_sizes = []
        for stage, decoder in zip(range(len(encoders) - 1, -1, -1), decoders):
            # output_size pins the unpooled H, W to the pre-pool size, which
            # undoes the flooring of odd dimensions done by max_pool2d.
            x = F.max_unpool2d(x, pool_indices[stage], kernel_size = 2, stride = 2,
                               output_size = stage_input_sizes[stage])
            x = decoder(x)
            if stage > 0:
                # Hidden stages use ReLU. The final 1x1 head (stage 0) stays
                # linear: its output is the logits fed to the loss and softmax,
                # and a ReLU there would clamp every class score at 0.
                x = F.relu(x)
            decoded_sizes.append(x.size())

        x_softmax = F.softmax(x, dim=1)

        if self.debug:
            for stage, size in enumerate(stage_input_sizes):
                print("x_{}_size: {}".format(stage, size))
            print("encoded_size: {}".format(encoded_size))
            for stage, size in zip(range(len(decoded_sizes) - 1, -1, -1), decoded_sizes):
                print("x_d{}_size: {}".format(stage, size))

        return x, x_softmax

In [4]:
def train(model, train_loader, criterion, optimizer, NUM_EPOCHS = 30, log = True, load = False,
          test_loader = None, checkpoint_path = "./model_best.pth"):
    """Train ``model`` for ``NUM_EPOCHS`` epochs and checkpoint the best epoch.

    Each epoch takes one optimizer step per batch of ``train_loader``. It then
    scores the model on the *first* batch of ``test_loader`` in eval mode,
    without gradients. Whenever an epoch's mean training loss is the lowest seen
    in this call, ``model.state_dict()`` is saved to ``checkpoint_path``.

    Args:
        model: module whose forward returns ``(scores, probabilities)``. Only the
            first element is passed to ``criterion``.
        train_loader: iterable of ``(input, target)`` batches.
        criterion: loss called as ``criterion(scores, target.long())``.
        optimizer: optimizer built over ``model.parameters()``.
        NUM_EPOCHS: number of passes over ``train_loader``.
        log: if True, print every batch loss.
        load: if True, first load ``checkpoint_path`` into ``model`` in place,
            so ``optimizer`` keeps updating the same parameters.
        test_loader: iterable of ``(input, target)`` batches for the per-epoch
            test loss. If None, a module-level ``test_loader`` is used (legacy
            behaviour). If that does not exist either, the test loss is NaN.
        checkpoint_path: file to load from / save the best weights to.

    Returns:
        (losses, t_losses): per-epoch mean training loss and test loss.
    """
    print('Epochs:\t', NUM_EPOCHS)

    if test_loader is None:
        # Backward compatibility: this function used to read a notebook-global
        # ``test_loader``. Prefer passing it explicitly.
        test_loader = globals().get("test_loader")

    # Send batches to whatever device the model's parameters live on.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    if load:
        # Load into the *existing* module. ``optimizer`` references its
        # parameters, so replacing ``model`` with a new instance would leave the
        # optimizer updating weights that are never used.
        model.load_state_dict(torch.load(checkpoint_path, map_location = device))

    losses = []
    t_losses = []
    best_train_loss = float('inf')

    model.train()

    for epoch in range(NUM_EPOCHS):
        batch_losses = []
        t_start = time.time()

        for input_tensor, target_tensor in train_loader:
            input_tensor = input_tensor.to(device)
            # Targets are cast to int64 (as CrossEntropyLoss expects).
            # .long() keeps the device, unlike .type(torch.LongTensor), which
            # forces a CPU tensor.
            target_tensor = target_tensor.to(device).long()

            predicted_tensor, _ = model(input_tensor)
            loss = criterion(predicted_tensor, target_tensor)
            batch_loss = loss.item()

            if log: print('batch loss:', batch_loss)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_losses.append(batch_loss)

        delta = time.time() - t_start

        if test_loader is None:
            test_loss = float('nan')
        else:
            # Eval mode makes BatchNorm use (and not update) its running stats.
            # no_grad avoids building an autograd graph for the test batch.
            model.eval()
            try:
                with torch.no_grad():
                    images, label = next(iter(test_loader))
                    predicted_tensor, _ = model(images.to(device))
                    test_loss = criterion(predicted_tensor, label.to(device).long()).item()
            finally:
                model.train()

        train_loss = float(np.mean(batch_losses)) if batch_losses else float('nan')

        # NOTE(review): the "best" checkpoint is selected by training loss, not
        # test loss; kept as in the original.
        if train_loss < best_train_loss:
            best_train_loss = train_loss
            torch.save(model.state_dict(), checkpoint_path)

        losses.append(train_loss)
        t_losses.append(test_loss)

        print("Epoch #{}\ttrain loss: {:.8f}\ttest loss: {:.8f}\t Time: {:2f}s".format(epoch+1, train_loss, test_loss, delta))

    return losses, t_losses

In [5]:
def val(test_loader, model = None, model_path = '', max_batches = 11):
    """Run inference on the first ``max_batches`` batches of ``test_loader``.

    Args:
        test_loader: iterable of ``(image, label)`` batches.
        model: an already-built model. It is used when ``model_path`` is empty.
        model_path: if non-empty, a fresh ``SegNetBase()`` is built and its
            weights are loaded from this file (mapped to CPU). This overrides
            ``model``.
        max_batches: number of batches to evaluate (previously hard-coded to
            11). ``None`` evaluates every batch.

    Returns:
        (images, labels, results, correctnesses), one entry per batch:
        - the input tensor;
        - the squeezed label array;
        - ``Data_reader().rev_translate`` applied to the predicted class map;
        - the pixel accuracy as a 0-dim tensor, i.e. the fraction of pixels
          whose argmax class equals the label.

    Raises:
        ValueError: if neither ``model`` nor ``model_path`` is provided.
    """
    dr = Data_reader()
    if model_path:
        model = SegNetBase()
        # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only
        # host. The freshly built model lives on the CPU anyway.
        model.load_state_dict(torch.load(model_path, map_location = "cpu"))
    elif model is None:
        raise ValueError("val() needs either `model` or a non-empty `model_path`")

    # Feed inputs to whatever device the model's parameters are on.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    images = []
    labels = []
    results = []
    correctnesses = []

    with torch.no_grad():
        for img, label in itertools.islice(test_loader, max_batches):
            img = img.to(device)
            label = label.cpu().numpy().squeeze()
            _, class_prob = model(img)
            pred = torch.argmax(class_prob, dim = 1).cpu()
            # Average over every predicted pixel. This replaces the hard-coded
            # 360*480 denominator, which only held for one 360x480 image per batch.
            accuracy = (pred == torch.as_tensor(label)).float().mean()
            model_output = dr.rev_translate(pred)

            images.append(img)
            labels.append(label)
            results.append(model_output)
            correctnesses.append(accuracy)
    return images, labels, results, correctnesses