In [None]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from data_handling import batchloader as loader
import torch
from torch.optim import Adam
import matplotlib.pyplot as plt

In [None]:
im_directory = 'data/images/'
mask_directory = 'data/masks/'


def _load_pair(index, image_size):
    """Load ``<index>.jpg`` from the image and mask directories.

    Returns a float32 image array of shape (3, image_size, image_size) and a
    float32 mask array of shape (1, image_size, image_size), channel-first.
    Raises ValueError if either file does not have the expected size.
    """
    filename = str(index) + ".jpg"
    # `with` closes the underlying file; a bare Image.open keeps it open.
    with Image.open(os.path.join(im_directory, filename)) as im:
        # convert("RGB") guarantees 3 channels even for grayscale/RGBA files.
        im_data = np.asarray(im.convert("RGB"), dtype=np.float32)
    with Image.open(os.path.join(mask_directory, filename)) as mask:
        mask_data = np.asarray(mask.convert("L"), dtype=np.float32)

    expected = (image_size, image_size)
    if im_data.shape[:2] != expected or mask_data.shape != expected:
        raise ValueError(
            "{}: expected {}x{} image and mask, got {} and {}".format(
                filename, image_size, image_size, im_data.shape[:2], mask_data.shape
            )
        )
    # (H, W, C) -> (C, H, W) for the image; add a leading channel axis to the mask.
    return im_data.transpose(2, 0, 1), mask_data[np.newaxis]


def batchloader(batchsize=None, batch=None, num_images=5108, image_size=256):
    """Load a batch of images and their masks as float tensors.

    Exactly one of ``batchsize`` / ``batch`` must be given.

    Args:
        batchsize: number of random indices to draw from ``range(num_images)``
            (drawn with replacement via ``np.random.randint``).
        batch: explicit iterable of image indices (list, range, NumPy array, ...).
        num_images: size of the index range used when sampling randomly.
            NOTE(review): presumably files are named 0.jpg .. num_images-1.jpg — confirm.
        image_size: expected height and width of every image and mask.

    Returns:
        ``(images, masks)``: float32 tensors of shape (N, 3, S, S) and
        (N, 1, S, S) with S = ``image_size``. Pixel values are left in the raw
        0-255 range. An empty batch yields tensors with N = 0.

    Raises:
        ValueError: if neither or both of ``batchsize`` and ``batch`` are given,
            or if a file does not have the expected size.
    """
    # `is None` instead of `== None`: comparing a NumPy array to None is
    # elementwise and its truth value is ambiguous.
    if batchsize is None and batch is None:
        raise ValueError("Weder batchsize noch batch deklariert!")
    if batchsize is not None and batch is not None:
        raise ValueError("Specify either batchsize or batch, not both.")

    if batch is None:
        batch = np.random.randint(num_images, size=batchsize)

    pairs = [_load_pair(index, image_size) for index in batch]
    if not pairs:
        return (torch.zeros((0, 3, image_size, image_size)),
                torch.zeros((0, 1, image_size, image_size)))

    images, masks = zip(*pairs)
    # One np.stack instead of repeated np.concatenate (which is quadratic).
    return torch.from_numpy(np.stack(images)), torch.from_numpy(np.stack(masks))

In [None]:
class Block(nn.Module):
    """Two unpadded 3x3 convolutions separated by a ReLU.

    Args:
        in_ch: number of channels of the input tensor.
        out_ch: number of output channels of both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names and creation order are kept as-is: they determine
        # state_dict keys and the order in which weights are initialised.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 (no activation after conv2)."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)


class Encoder(nn.Module):
    """Contracting path of the U-Net.

    Builds one ``Block`` per consecutive channel pair of ``chs`` and applies a
    2x2 max-pool after each block.

    Args:
        chs: channel sizes; block k maps ``chs[k]`` -> ``chs[k+1]`` channels.
    """

    def __init__(self, chs=(3,64,128,256)):
        super().__init__()
        channel_pairs = zip(chs[:-1], chs[1:])
        self.enc_blocks = nn.ModuleList([Block(c_in, c_out) for c_in, c_out in channel_pairs])
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return every block's output (taken before pooling), shallowest first."""
        features = []
        for enc_block in self.enc_blocks:
            x = enc_block(x)
            features.append(x)
            # Downsample for the next block; the result after the last block is unused.
            x = self.pool(x)
        return features


class Decoder(nn.Module):
    """Expanding path of the U-Net.

    Each stage upsamples with a kernel-2 / stride-2 transposed convolution,
    center-crops the matching encoder feature map to the upsampled size,
    concatenates the two along the channel axis, and applies a ``Block``.

    Args:
        chs: channel sizes; stage k maps ``chs[k]`` -> ``chs[k+1]`` channels.
    """

    def __init__(self, chs=(256, 128, 64)):
        super().__init__()
        self.chs = chs
        stages = range(len(chs) - 1)
        # All upconvs are created before all blocks (same init order as before).
        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(chs[k], chs[k + 1], 2, 2) for k in stages])
        self.dec_blocks = nn.ModuleList([Block(chs[k], chs[k + 1]) for k in stages])

    def forward(self, x, encoder_features):
        """Run all stages; ``encoder_features[k]`` is the skip input of stage k."""
        for stage in range(len(self.chs) - 1):
            upsampled = self.upconvs[stage](x)
            skip = self.crop(encoder_features[stage], upsampled)
            x = self.dec_blocks[stage](torch.cat([upsampled, skip], dim=1))
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop ``enc_ftrs`` to the spatial size (H, W) of 4-D tensor ``x``."""
        _, _, height, width = x.shape
        return torchvision.transforms.CenterCrop([height, width])(enc_ftrs)


class UNet(nn.Module):
    """U-Net: ``Encoder`` + ``Decoder`` + 1x1 convolution head.

    Args:
        enc_chs: encoder channel sizes (first entry = input channels).
        dec_chs: decoder channel sizes (first entry = deepest encoder channels).
        num_class: number of output channels produced by the head.
        retain_dim: if True, resize the output to ``out_sz`` with
            ``F.interpolate`` (default mode); otherwise return the head output
            at whatever size the unpadded convolutions leave.
        out_sz: target (H, W) used when ``retain_dim`` is True.
    """

    def __init__(self, enc_chs=(3,64,128,256), dec_chs=(256, 128, 64), num_class=1, retain_dim=False, out_sz=(256,256)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim = retain_dim
        self.out_sz = out_sz

    def forward(self, x):
        """Encode, decode with skip connections (deepest first), apply the head."""
        deepest_first = self.encoder(x)[::-1]
        bottleneck = deepest_first[0]
        skips = deepest_first[1:]
        out = self.head(self.decoder(bottleneck, skips))
        if not self.retain_dim:
            return out
        return F.interpolate(out, self.out_sz)



In [None]:
# Run configuration.
NUM_EPOCHS = 10
BATCH_SIZE = 10
SEED = 42

# Seed torch (weight init) and numpy (batchloader draws indices via np.random).
torch.manual_seed(SEED)
np.random.seed(SEED)

Model = UNet(retain_dim=True)

lossFunc = torch.nn.MSELoss()
opt = Adam(Model.parameters())

print("Setup fertig")

train_steps = 10  # mini-batches per epoch
test_steps = 0    # TODO: no validation loop yet

# Per-epoch history; "test_loss" stays empty until a validation loop exists.
H = {"train_loss": [], "test_loss": []}

for e in range(NUM_EPOCHS):
    # Set the model to training mode and reset the running loss.
    Model.train()
    totalTrainLoss = 0.0

    for _ in range(train_steps):
        # Use the batchloader defined in this notebook (not the data_handling
        # copy) so the code that runs is the code shown above.
        # NOTE(review): images and masks are in the raw 0-255 range; consider scaling.
        x, y = batchloader(batchsize=BATCH_SIZE)

        # Call the module, not .forward(), so registered hooks run.
        pred = Model(x)
        loss = lossFunc(pred, y)

        opt.zero_grad()
        loss.backward()
        opt.step()
        # .item() yields a plain float; adding the loss tensor itself would
        # chain every step's autograd history into the running total.
        totalTrainLoss += loss.item()

    avgTrainLoss = totalTrainLoss / train_steps
    H["train_loss"].append(avgTrainLoss)
    print("[INFO] EPOCH: {}/{}".format(e + 1, NUM_EPOCHS))
    print("Train loss: {:.6f}".format(avgTrainLoss))

# Plot the per-epoch training loss.
plt.style.use("ggplot")
fig, ax = plt.subplots()
ax.plot(H["train_loss"], label="train_loss")
ax.set(title="Training loss per epoch", xlabel="Epoch", ylabel="MSE loss")
ax.legend()
plt.show()