In [1]:
import cv2
import time
import torch
import numpy as np
from torch import nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset

### Setup

In [25]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### U-Net Architecture

In [2]:
class TwoConvolutions(nn.Module):
    """Two stacked 3x3 convolutions, each followed by a ReLU.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        ]
        # A single Sequential named `block` keeps parameter names
        # (block.0.*, block.2.*) stable for saved checkpoints.
        self.block = nn.Sequential(*layers)

    def forward(self, input_):
        """Apply conv -> ReLU -> conv -> ReLU to ``input_``."""
        return self.block(input_)

In [3]:
class EncoderBlock(nn.Module):
    """One contracting step of the U-Net encoder.

    Runs a ``TwoConvolutions`` block, then halves the spatial resolution with a
    2x2 max pool of stride 2.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # These attribute names appear in the state_dict keys; keep them as-is.
        self.TwoConvolutions = TwoConvolutions(in_channels, out_channels)
        self.max_pooling = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, image):
        """Return ``(pooled, skip)``.

        ``skip`` is the full-resolution convolution output, kept for the
        decoder's skip connection; ``pooled`` is its 2x-downsampled version.
        """
        convolved = self.TwoConvolutions(image)
        return self.max_pooling(convolved), convolved

In [4]:
class DecoderBlock(nn.Module):
    """One expanding step of the U-Net decoder.

    Upsamples with a 2x2 transposed convolution (stride 2), concatenates the
    matching encoder skip features along the channel axis, then applies a
    ``TwoConvolutions`` block.

    Args:
        in_channels: channels of the incoming (coarser) feature map.
        out_channels: channels after upsampling. The skip features must carry
            the same number of channels, because the concatenation feeds
            ``out_channels * 2`` channels into ``TwoConvolutions``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.upConvolution = nn.ConvTranspose2d(in_channels, out_channels, kernel_size = 2, stride = 2)
        self.TwoConvolutions = TwoConvolutions(out_channels * 2, out_channels)

    def forward(self, input_, skip_input):
        """Upsample ``input_``, fuse it with ``skip_input`` and convolve.

        If an encoder level had an odd height or width, max pooling floored it.
        The upsampled map then comes out one row/column short of the skip
        features, and ``torch.cat`` would fail. In that case the upsampled map
        is zero-padded to the skip size. When the sizes already agree, the
        result is identical to a plain concatenation.
        """
        features = self.upConvolution(input_)

        diff_h = skip_input.shape[-2] - features.shape[-2]
        diff_w = skip_input.shape[-1] - features.shape[-1]
        if diff_h or diff_w:
            # pad() takes (left, right, top, bottom) for the last two dims;
            # split any odd difference so the extra pixel goes right/bottom.
            features = nn.functional.pad(
                features,
                (diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2),
            )

        features = torch.cat([features, skip_input], dim = 1)
        return self.TwoConvolutions(features)

In [5]:
class U_Net(nn.Module):
    """U-Net encoder/decoder producing a per-pixel output map.

    Channel widths double at each encoder level, starting from
    ``base_channels``. With the default of 64 and ``depth=4``, the encoder
    levels have 64, 128, 256 and 512 channels and the bottleneck has 1024.
    The decoder mirrors the encoder and consumes the skip features deepest
    first. A final 3x3 convolution maps to ``out_channels``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map.
        depth: number of encoder/decoder levels; must be >= 0.
        base_channels: width of the first encoder level (default 64).

    Raises:
        ValueError: if ``depth`` is negative.
    """

    def __init__(self, in_channels, out_channels, depth, base_channels = 64):
        super().__init__()

        if depth < 0:
            raise ValueError(f"depth must be >= 0, got {depth}")

        # widths[i] is the channel count of encoder level i; widths[depth] is the bottleneck.
        widths = [base_channels * (2 ** i) for i in range(depth + 1)]
        encoder_channels = [in_channels] + widths

        self.Encoder = nn.ModuleList([
            EncoderBlock(encoder_channels[i], encoder_channels[i + 1]) for i in range(depth)
        ])

        self.Bottleneck = TwoConvolutions(encoder_channels[depth], encoder_channels[depth + 1])

        # The decoder walks the widths back down from the bottleneck to base_channels.
        decoder_channels = widths[::-1]
        self.Decoder = nn.ModuleList([
            DecoderBlock(decoder_channels[i], decoder_channels[i + 1]) for i in range(depth)
        ])

        # Same final value the attribute had before: the decoder widths.
        self.channels = decoder_channels

        self.FinalConvolution = nn.Conv2d(in_channels = base_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = (1,1))

    def forward(self, image):
        """Run the encoder, bottleneck and decoder on ``image``.

        Spatial sizes line up between the upsampled maps and the skip features
        when the input height and width are divisible by ``2 ** depth``.
        """
        skips = []
        features = image

        for block in self.Encoder:
            features, skip = block(features)
            skips.append(skip)

        features = self.Bottleneck(features)

        # The deepest skip connection pairs with the first decoder block.
        for block, skip in zip(self.Decoder, reversed(skips)):
            features = block(features, skip)

        return self.FinalConvolution(features)

### Training Loop and Evaluation Report

In [30]:
def training_loop(model, training_data, validation_data, batch = 1, learning_rate = 1e-2, num_epochs = 10):
    """Train ``model`` with Adam on pixel-wise cross-entropy.

    After each epoch, the mean per-batch training loss is printed. If
    ``validation_data`` is given, ``performance_report`` is also run on it.

    Args:
        model: network mapping an image batch to per-pixel class logits
            (presumably (N, C, H, W), as ``nn.CrossEntropyLoss`` expects;
            confirm).
        training_data: sequence of (image, mask) pairs.
        validation_data: sequence of (image, mask) pairs, or None to skip the
            per-epoch validation report.
        batch: mini-batch size for training and validation.
        learning_rate: Adam learning rate.
        num_epochs: number of passes over ``training_data``.

    Returns:
        The trained model, moved to ``device``.

    Raises:
        ValueError: if ``training_data`` is empty.
    """
    dataset = ImageDataset(training_data)
    if len(dataset) == 0:
        raise ValueError("training_data is empty")

    dataloader = DataLoader(dataset, batch_size = batch, shuffle = True)
    model = model.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

    for epoch in range(num_epochs):
        # performance_report() calls model.eval(), so re-enable training mode
        # at the start of every epoch. Otherwise any dropout/batch-norm layers
        # would run in inference mode from the second epoch on.
        model.train()
        total_loss = 0.0

        for image, mask in dataloader:
            image = image.to(device)
            mask = mask.to(device)
            optimizer.zero_grad()
            logits = model(image)
            loss = loss_function(logits, mask)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        average_loss = total_loss / len(dataloader)
        print(f'Epoch: {epoch} | Loss: {average_loss}')
        if validation_data is not None:
            performance_report("Validation", model, validation_data, batch)

    return model

In [32]:
def performance_report(type_, model, data, batch_size):
    """Print and return the mean per-pixel cross-entropy of ``model`` on ``data``.

    Losses are summed over all counted mask pixels and divided by that count.
    A smaller final batch is therefore weighted correctly, whereas averaging
    per-batch means would over-weight it. The model is evaluated in eval mode
    without gradients, and its previous train/eval mode is restored before
    returning.

    Args:
        type_: label printed in front of the loss (e.g. "Validation").
        model: network to evaluate; assumed to already live on ``device``.
        data: sequence of (image, mask) pairs.
        batch_size: evaluation batch size.

    Returns:
        The mean loss as a float.

    Raises:
        ValueError: if ``data`` is empty.
    """
    dataset = ImageDataset(data)
    if len(dataset) == 0:
        raise ValueError(f"{type_}: no samples to evaluate")

    dataloader = DataLoader(dataset, batch_size, shuffle = False)
    # Sum rather than mean, so the total can be normalised by the exact pixel
    # count across batches of unequal size.
    loss_function = nn.CrossEntropyLoss(reduction = "sum")
    total_loss = 0.0
    total_pixels = 0
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for image, mask in dataloader:
                image = image.to(device)
                mask = mask.to(device)
                logits = model(image)
                total_loss += loss_function(logits, mask).item()
                # Count only pixels that contribute to the loss, as 'mean' would.
                total_pixels += int((mask != loss_function.ignore_index).sum())
    finally:
        # Do not leave a model that was training stuck in eval mode.
        model.train(was_training)

    average_loss = total_loss / total_pixels

    print(f"{type_} | Loss: {average_loss} \n")
    return average_loss

### Data & Preprocessing

In [8]:
class ImageDataset(Dataset):
    """Thin ``Dataset`` wrapper around an indexable collection of pairs.

    Each element of ``data`` must unpack into exactly two items. It is
    returned as an ``(image, mask)`` tuple.

    Args:
        data: sequence supporting ``len()`` and integer indexing, such as the
            list of ``[image, mask]`` pairs built by ``prep_data``.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Number of stored pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the pair stored at ``idx`` as a tuple."""
        sample_image, sample_mask = self.data[idx]
        return (sample_image, sample_mask)

In [16]:
def prep_data(type_, image_nums, image_root = './images-1024x768', mask_root = './masks-1024x768'):
    """Load image/mask pairs for one dataset split.

    For every ``n`` in ``image_nums``:
      - ``{image_root}/{type_}/image-{n}.png`` is read with ``cv2.imread``'s
        default colour mode;
      - ``{mask_root}/{type_}/mask-{n}.png`` is read as greyscale.

    Args:
        type_: split sub-directory name, e.g. 'train', 'val' or 'test'.
        image_nums: iterable of image numbers to load.
        image_root: directory holding the per-split image folders.
        mask_root: directory holding the per-split mask folders.

    Returns:
        List of ``[image, mask]`` pairs.
          - ``image`` is a float32 tensor of shape (3, H, W), scaled to
            [0, 1]. Channels are in OpenCV's BGR order, since no conversion
            is done.
          - ``mask`` is an int64 (H, W) tensor: 1 where the mask pixel is
            non-zero, 0 elsewhere.

    Raises:
        FileNotFoundError: if an image or mask cannot be read. ``cv2.imread``
            returns None instead of raising, which would otherwise surface as
            an obscure error further down.
    """
    data = []

    for image_num in image_nums:
        image_path = f'{image_root}/{type_}/image-{image_num}.png'
        mask_path = f'{mask_root}/{type_}/mask-{image_num}.png'

        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')
        # HWC uint8 -> CHW float32 in [0, 1].
        image = np.transpose(image, (2, 0, 1))
        image = torch.tensor(image, dtype = torch.float32)
        image = image / 255.0

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f'Could not read mask: {mask_path}')
        # Binarise: any non-zero pixel is foreground (class 1).
        mask = (mask > 0).astype('int')
        mask = torch.tensor(mask, dtype = torch.long)
        data.append([image, mask])

    return data

In [17]:
def view(image, title = "View Image"):
    """Display one image or mask with a colour bar, using the 'binary' colormap.

    Accepts anything ``imshow`` accepts, plus:
      - tensors, which are detached and moved to the CPU;
      - arrays with leading singleton axes (e.g. a batch of 1), which are
        dropped;
      - channel-first (3, H, W) / (4, H, W) images, such as those built by
        ``prep_data``, which are transposed to (H, W, C) for ``imshow``.

    NOTE(review): images loaded with ``cv2.imread`` are BGR, so their colours
    will appear channel-swapped here.

    Args:
        image: 2-D mask, (H, W, C) image, or channel-first tensor/array.
        title: figure title.
    """
    if hasattr(image, "detach"):
        image = image.detach().cpu().numpy()

    # asanyarray keeps masked arrays masked, as imshow would have seen them.
    pixels = np.asanyarray(image)

    # Drop leading singleton axes such as (1, H, W) masks or (1, C, H, W) batches.
    while pixels.ndim > 2 and pixels.shape[0] == 1:
        if pixels.ndim == 3 and pixels.shape[-1] in (3, 4):
            break  # already (1, W, C): a one-row colour image imshow shows as-is
        pixels = pixels[0]

    # Channel-first colour image -> channel-last, which is what imshow expects.
    if pixels.ndim == 3 and pixels.shape[0] in (3, 4) and pixels.shape[-1] not in (3, 4):
        pixels = np.transpose(pixels, (1, 2, 0))

    fig, ax = plt.subplots()
    shown = ax.imshow(pixels, cmap = "binary")
    fig.colorbar(shown, ax = ax)
    ax.set_title(title)
    plt.show()

### Playground

In [18]:
# Load each split from disk. The numbers pick which image-N.png / mask-N.png
# files belong to that split.
train_data = prep_data(
    type_='train',
    image_nums=[2, 7, 10, 12, 21, 24, 27, 28, 30, 43],
)
val_data = prep_data(type_='val', image_nums=[1, 11, 22, 32])
test_data = prep_data(type_='test', image_nums=[4, 16, 29, 36])