In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.nn.functional as F
import scipy
from scipy.io import loadmat
from PIL import Image
import matplotlib.pyplot as plt
import os
import tqdm

In [None]:
# Files available in the dataset directory; sorted because os.listdir returns them
# in arbitrary order, which would make the displayed output non-reproducible.
sorted(os.listdir("../input/chiu-2015/2015_BOE_Chiu"))


In [None]:
class Unet(torch.nn.Module):
    """U-Net segmentation network built from unpadded 3x3 convolutions.

    Three encoder levels (64/128/256 channels, each followed by 2x2 max-pooling), a
    512-channel bottleneck, and a decoder that upsamples with stride-2 transposed
    convolutions and concatenates centre-cropped encoder features. Because the
    convolutions are unpadded, the output is spatially smaller than the input:
    a 284x284 input gives a 196x196 map (and e.g. 92x92 gives 4x4).

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map.
    """

    def downBlock(self, in_channels, out_channels, kernel_size = 3):
        """Two (conv -> ReLU -> BatchNorm) stages.

        Each unpadded conv shrinks height and width by ``kernel_size - 1``.
        """
        block = torch.nn.Sequential(torch.nn.Conv2d(in_channels, out_channels, kernel_size),
                                    torch.nn.ReLU(),
                                    torch.nn.BatchNorm2d(out_channels),
                                    torch.nn.Conv2d(out_channels, out_channels, kernel_size),
                                    torch.nn.ReLU(),
                                    torch.nn.BatchNorm2d(out_channels))
        return block

    def upBlock(self, in_channels, mid_channels, out_channels, kernel_size):
        """Two (conv -> ReLU -> BatchNorm) stages followed by a 2x upsampling.

        The transposed conv (kernel 3, stride 2, padding 1, output_padding 1) maps a
        spatial size n to exactly 2n.
        """
        block = torch.nn.Sequential(torch.nn.Conv2d(in_channels, mid_channels, kernel_size),
                                    torch.nn.ReLU(),
                                    torch.nn.BatchNorm2d(mid_channels),
                                    torch.nn.Conv2d(mid_channels, mid_channels, kernel_size),
                                    torch.nn.ReLU(),
                                    torch.nn.BatchNorm2d(mid_channels),
                                    torch.nn.ConvTranspose2d(mid_channels, out_channels, kernel_size = 3, stride = 2, padding = 1, output_padding = 1))
        return block

    def finalBlock(self, in_channels, mid_channels, out_channels, kernel_size):
        """Two (conv -> ReLU -> BatchNorm) stages, then a padded conv to ``out_channels``.

        The last conv uses padding 1, so it preserves the spatial size when
        ``kernel_size`` is 3.
        """
        # NOTE(review): the network output also passes through ReLU and BatchNorm. If these
        # outputs are used as logits for a cross-entropy loss, dropping those two layers is
        # usually preferable. They are kept here so the parameter layout (state_dict keys)
        # stays unchanged.
        block = torch.nn.Sequential(torch.nn.Conv2d(in_channels, mid_channels, kernel_size),
                                   torch.nn.ReLU(),
                                   torch.nn.BatchNorm2d(mid_channels),
                                   torch.nn.Conv2d(mid_channels, mid_channels, kernel_size),
                                   torch.nn.ReLU(),
                                   torch.nn.BatchNorm2d(mid_channels),
                                   torch.nn.Conv2d(mid_channels, out_channels, kernel_size, padding = 1),
                                   torch.nn.ReLU(),
                                   torch.nn.BatchNorm2d(out_channels))
        return block

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Submodules are created in a fixed order: that order determines how parameters
        # are drawn from the RNG at initialisation and the resulting state_dict keys.

        # Encoder
        self.encode1 = self.downBlock(in_channels, out_channels = 64, kernel_size = 3)
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size = 2)
        self.encode2 = self.downBlock(64, 128, 3)
        self.maxpool2 = torch.nn.MaxPool2d(kernel_size = 2)
        self.encode3 = self.downBlock(128, 256, 3)
        self.maxpool3 = torch.nn.MaxPool2d(kernel_size = 2)

        # Bottleneck: two convs at 512 channels, then upsample 2x back to 256 channels.
        self.bottleneck = torch.nn.Sequential(torch.nn.Conv2d(256,512,3),
                                             torch.nn.ReLU(),
                                             torch.nn.BatchNorm2d(512),
                                             torch.nn.Conv2d(512,512,3),
                                             torch.nn.ReLU(),
                                             torch.nn.BatchNorm2d(512),
                                             torch.nn.ConvTranspose2d(512,256, kernel_size = 3, stride = 2, padding = 1, output_padding = 1 ))
        # Decoder: each stage's input is (upsampled features ++ cropped skip features).
        self.decode3 = self.upBlock(512,256,128,3)
        self.decode2 = self.upBlock(256, 128, 64, 3)
        self.finalLayer = self.finalBlock(128, 64, out_channels, 3)

    def copy_concat(self, upsampled, bypass, crop=False):
        """Concatenate decoder features with an encoder skip tensor along the channel axis.

        Args:
            upsampled: decoder tensor of shape (N, C1, H, W).
            bypass: encoder tensor of shape (N, C2, H2, W2).
            crop: if True, first bring ``bypass`` to H x W by removing (or, if it is
                smaller, zero-padding) a centred border.

        Returns:
            Tensor of shape (N, C1 + C2, H, W) with the ``upsampled`` channels first.
        """
        if crop:
            # Height and width are handled independently, and an odd difference is split
            # unevenly (the extra row/column comes off the bottom/right), so the result
            # always matches ``upsampled`` exactly. Negative F.pad amounts crop.
            dh = bypass.shape[2] - upsampled.shape[2]
            dw = bypass.shape[3] - upsampled.shape[3]
            top, left = dh // 2, dw // 2
            bypass = F.pad(bypass, (-left, -(dw - left), -top, -(dh - top)))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        """Run the network on ``x`` of shape (N, in_channels, H, W).

        Returns a (N, out_channels, H', W') tensor with H' < H and W' < W because of
        the unpadded convolutions (284 -> 196).
        """
        # encoder
        encodeBlock1 = self.encode1(x)
        encodePool1 = self.maxpool1(encodeBlock1)
        encodeBlock2 = self.encode2(encodePool1)
        encodePool2 = self.maxpool2(encodeBlock2)
        encodeBlock3 = self.encode3(encodePool2)
        encodePool3 = self.maxpool3(encodeBlock3)

        # bottleneck
        bottleneck1 = self.bottleneck(encodePool3)

        # decoder: skip connections are centre-cropped to the upsampled size
        cat3 = self.copy_concat(bottleneck1, encodeBlock3, True)
        decodeBlock3 = self.decode3(cat3)
        cat2 = self.copy_concat(decodeBlock3, encodeBlock2, True)
        decodeBlock2 = self.decode2(cat2)
        cat1 = self.copy_concat(decodeBlock2, encodeBlock1, True)
        finalBlock = self.finalLayer(cat1)
        return finalBlock


unet = Unet(in_channels=1,out_channels=2)

In [None]:
# Dataset directory (the same one listed earlier); the Subject_XX.mat files are read from here.
path = "../input/chiu-2015/2015_BOE_Chiu"

In [None]:
# One .mat file per subject, Subject_01.mat ... Subject_10.mat (two-digit, zero-padded ids),
# built from a single format so no subject number needs special-casing.
dataPath = [os.path.join(path, 'Subject_{:02d}.mat'.format(i)) for i in range(1, 11)]

In [None]:
# Network input size (HEIGHT x WIDTH) and prediction size (HEIGHTOUT x WIDTHOUT):
# the Unet's unpadded 3x3 convolutions turn a 284x284 input into a 196x196 output.
WIDTH = HEIGHT = 284
WIDTHOUT = HEIGHTOUT = 196
# Slice indices taken from every subject's volume when the datasets are built.
indices = list(range(5, 55, 5))

In [None]:
# Inspect the first subject file: its 'images' array and its 'manualFluid1' annotation.
mat = loadmat(dataPath[0])
img_tensor, manual_fluid_tensor = mat['images'], mat['manualFluid1']

In [None]:
# Move the last axis to the front so the arrays can be indexed slice-first
# (assumes the .mat stacks are stored as (rows, cols, slices) — TODO confirm).
img_array = img_tensor.transpose(2, 0, 1)
manual_fluid_array = manual_fluid_tensor.transpose(2, 0, 1)

In [None]:
# Resample every slice to 250x250. np.resize is not an image resize: it flattens the
# array and truncates/repeats the values into the new shape, scrambling the rows.
# Interpolating per slice keeps the image content, and the slice count comes from the
# data instead of a hard-coded 61.
temp = F.interpolate(
    torch.from_numpy(img_array.astype(np.float32)).unsqueeze(1),
    size=(250, 250),
    mode='bilinear',
    align_corners=False,
).squeeze(1).numpy()

In [None]:
# Preview the manual fluid annotation of slice 30 of the first subject.
fig, ax = plt.subplots()
im = ax.imshow(manual_fluid_array[30])
ax.set_title('manualFluid1, slice 30')
fig.colorbar(im, ax=ax);

In [None]:
def threshold(x):
    """Map one mask value to a binary class: 1 when ``x == 0``, otherwise 0.

    Any non-zero value maps to 0, including NaN (``NaN == 0`` is False).
    NOTE(review): this makes zero-valued mask pixels class 1 and every non-zero mask
    value class 0 — confirm this polarity is intended.
    """
    return 1 if x == 0 else 0


def thresh(x):
    """Elementwise :func:`threshold` over an array, returned as an integer array.

    Previously ``np.vectorize(threshold, otypes=[np.int])``. The ``np.int`` alias was
    removed in NumPy 1.24, so that line raised AttributeError on current NumPy. A direct
    array comparison gives the same values without a Python call per pixel.
    """
    return (np.asarray(x) == 0).astype(int)

def _resample(stack, size, mode):
    """Resize a (slices, rows, cols) array to (slices, *size) with torch interpolation."""
    tensor = torch.from_numpy(np.ascontiguousarray(stack, dtype=np.float32)).unsqueeze(1)
    # align_corners is only accepted by the linear-family interpolation modes.
    options = {} if mode == 'nearest' else {'align_corners': False}
    return F.interpolate(tensor, size=size, mode=mode, **options).squeeze(1).numpy()


def _center_crop(stack, size):
    """Return the central ``size`` = (rows, cols) window of a (slices, rows, cols) array."""
    rows, cols = size
    top = (stack.shape[1] - rows) // 2
    left = (stack.shape[2] - cols) // 2
    return stack[:, top:top + rows, left:left + cols]


def createDataSet(paths):
    """Load the selected slices of every subject file into network-ready arrays.

    For each ``.mat`` file, ``mat['images']`` and ``mat['manualFluid1']`` are transposed
    slice-first and only the slices in the module-level ``indices`` are kept. Then:

    * images are scaled by 1/255 and bilinearly resampled to (HEIGHT, WIDTH);
    * masks are binarised with ``thresh``, nearest-neighbour resampled onto the same
      (HEIGHT, WIDTH) grid, and centre-cropped to (HEIGHTOUT, WIDTHOUT). The Unet's
      unpadded convolutions trim the border of a (HEIGHT, WIDTH) input, so its
      (HEIGHTOUT, WIDTHOUT) prediction lines up with this central window.

    Args:
        paths: iterable of ``.mat`` file paths.

    Returns:
        ``(x, y)``: x has shape (N, 1, HEIGHT, WIDTH) and dtype float32; y has shape
        (N, 1, HEIGHTOUT, WIDTHOUT) and dtype int64, where N = len(paths) * len(indices).
    """
    x = []
    y = []
    for mat_path in tqdm.tqdm(paths):
        mat = loadmat(mat_path)
        images = np.transpose(mat['images'], (2, 0, 1))[indices] / 255
        fluid = np.transpose(mat['manualFluid1'], (2, 0, 1))[indices]

        # Binarise before resampling so NaNs / raw label values are never interpolated.
        # (The previous version stored thresh's result under a misspelled name and then
        # resized the raw mask, so the threshold was silently never applied.)
        labels = thresh(fluid)

        # np.resize would only reshuffle the flattened data; interpolate instead, and put
        # images and masks on the same grid so input and target pixels stay aligned.
        images = _resample(images, (HEIGHT, WIDTH), 'bilinear')
        labels = _resample(labels, (HEIGHT, WIDTH), 'nearest')
        labels = _center_crop(labels, (HEIGHTOUT, WIDTHOUT))

        # Add a channel axis; extend() appends one (1, H, W) array per slice.
        x.extend(images[:, None].astype(np.float32))
        y.extend(labels[:, None].astype(np.int64))
    return np.array(x), np.array(y)

# The first nine files of dataPath are used for training; the remaining file is held out for validation.
(trainX, trainY), (valX, valY) = createDataSet(dataPath[:9]), createDataSet(dataPath[9:])

In [None]:
# Shapes of the four dataset arrays, in (trainX, trainY, valX, valY) order.
tuple(arr.shape for arr in (trainX, trainY, valX, valY))

In [None]:
# Summarise the training labels rather than printing the entire array:
# the distinct label values and how many pixels carry each one.
np.unique(trainY, return_counts=True)

In [None]:
# Mini-batch size and number of passes over the training set.
BATCHSIZE, EPOCHS = 18, 100
# Not used by any cell shown here; presumably meant as a cut-off on predicted scores.
THRESHOLD = 0.5


In [None]:
def train(optimizer, criterion, inputs, labels, model=None):
    """Run one optimisation step on a batch and return its loss.

    Args:
        optimizer: optimizer over the model's parameters.
        criterion: loss called as ``criterion(scores, targets)`` with scores of shape
            (pixels, n_classes) and targets of shape (pixels,), e.g. CrossEntropyLoss.
        inputs: input batch for the model.
        labels: integer targets with one entry per output pixel, in
            (batch, row, col) order (e.g. shape (B, 1, H_out, W_out)).
        model: network to train; defaults to the module-level ``unet``.

    Returns:
        The batch loss as a detached 0-d tensor.
    """
    if model is None:
        model = unet
    optimizer.zero_grad()
    outputs = model(inputs)  # (B, n_classes, H_out, W_out)
    n_classes = outputs.shape[1]
    # Put the class axis last so every pixel becomes one row of scores, in the same
    # (batch, row, col) order as the flattened labels. Sizes come from the tensors
    # themselves, so any batch size works (not only BATCHSIZE).
    outputs = outputs.permute(0, 2, 3, 1).reshape(-1, n_classes)
    labels = labels.reshape(-1)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    # Detached so callers that accumulate losses do not keep each batch's graph alive.
    return loss.detach()

In [None]:
# SGD with momentum 0.99 over all Unet parameters; per-pixel cross-entropy as the loss.
# NOTE(review): the PyTorch docs recommend moving a model to the GPU before building its
# optimizer; here the model is moved afterwards.
learningRate = 0.01
optimizer = torch.optim.SGD(params=unet.parameters(), lr=learningRate, momentum=0.99)
criterion = torch.nn.CrossEntropyLoss()


In [None]:
def validationLoss(x_val, y_val, model=None):
    """Mean per-pixel cross-entropy of the model on a held-out set.

    Args:
        x_val: numpy array of model inputs.
        y_val: numpy array of integer targets with one entry per output pixel, in
            (sample, row, col) order (e.g. shape (M, 1, H_out, W_out)).
        model: network to evaluate; defaults to the module-level ``unet``.

    Returns:
        0-d tensor holding the loss, with no autograd graph attached.
    """
    if model is None:
        model = unet
    # Put the data on the model's device (the model may have been moved to the GPU).
    device = next(model.parameters()).device
    x_val = torch.from_numpy(x_val).float().to(device)
    y_val = torch.from_numpy(y_val).long().to(device)

    was_training = model.training
    model.eval()  # BatchNorm uses running statistics, not this batch's
    try:
        with torch.no_grad():
            outputs = model(x_val)  # (M, n_classes, H_out, W_out)
            n_classes = outputs.shape[1]
            # Sizes come from the tensors; the old code used undefined width_out/height_out.
            outputs = outputs.permute(0, 2, 3, 1).reshape(-1, n_classes)
            labels = y_val.reshape(-1)
            loss = F.cross_entropy(outputs, labels)
    finally:
        model.train(was_training)  # restore the caller's mode
    return loss
    

In [None]:
# Number of full mini-batches per epoch (any remainder of trainX is not used).
BATCHES = trainX.shape[0] // BATCHSIZE
# Use the GPU when one is available, instead of failing in .cuda() on CPU-only machines.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
unet = unet.to(DEVICE)

In [None]:
# Train for EPOCHS passes over trainX/trainY in mini-batches of BATCHSIZE. The sample
# order is reshuffled every epoch, so batches are not always the same subjects in the
# same order. The mean batch loss is reported every 20 epochs and after the last one.
device = next(unet.parameters()).device  # follow wherever the model was placed
shuffle_rng = np.random.default_rng(0)
unet.train()
for epoch in range(EPOCHS):
    order = shuffle_rng.permutation(trainX.shape[0])
    totalLoss = 0.0
    for batch in range(BATCHES):
        idx = order[batch * BATCHSIZE:(batch + 1) * BATCHSIZE]
        batchX = torch.from_numpy(trainX[idx]).float().to(device)
        batchY = torch.from_numpy(trainY[idx]).long().to(device)
        batchLoss = train(optimizer, criterion, batchX, batchY)
        # float() keeps a plain number, so no autograd graph is chained across batches.
        totalLoss += float(batchLoss)
    if epoch % 20 == 0 or epoch == EPOCHS - 1:
        print('Loss at epoch {} is {}'.format(epoch, totalLoss / BATCHES))
        