In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

Importing Data

In [9]:
# Resize to the 572x572 network input size and convert to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((572, 572)),
    transforms.ToTensor()])

# Open inside a context manager so the file handle is released, and force RGB:
# a JPEG may be grayscale or CMYK, which would not yield the 3-channel tensor
# that the first convolution of the model expects.
with Image.open('test_image.jpeg') as image_file:
    image = transform(image_file.convert('RGB'))
image.shape

torch.Size([3, 572, 572])

UNET Model

In [14]:
class UNET(nn.Module):
    """U-Net style encoder/decoder built from unpadded 3x3 convolutions.

    The encoder applies five (conv, ReLU, conv, ReLU, 2x2 max-pool) blocks; the
    decoder applies four (transposed-conv upscale, centre-cropped skip
    concatenation, ReLU, conv, ReLU, conv, ReLU) blocks followed by a 1x1
    convolution that produces the segmentation map.

    All layers are created once in ``__init__`` so their weights are registered
    as parameters (visible to optimisers, ``state_dict`` and ``.to(device)``)
    and stay the same from one forward call to the next.

    Inputs may be unbatched ``(C, H, W)`` or batched ``(N, C, H, W)``; H and W
    do not have to be equal. With the default arguments a ``(3, 572, 572)``
    input produces a ``(2, 164, 164)`` output.

    NOTE(review): skip connections are taken *after* max pooling and the
    bottleneck block is pooled as well. The U-Net paper takes skips before
    pooling and does not pool the bottleneck; the layout here is kept as-is so
    that output shapes do not change.

    Args:
        in_channels: number of channels of the input image.
        num_classes: number of channels of the output segmentation map.
        base_channels: width of the first encoder block; each subsequent
            encoder block doubles it.
    """

    def __init__(self, in_channels=3, num_classes=2, base_channels=64):
        super().__init__()

        # Encoder widths c, 2c, 4c, 8c, 16c (64 ... 1024 by default).
        widths = [base_channels * 2 ** i for i in range(5)]
        self.channels = [in_channels] + widths
        # Decoder (input, output) widths: (16c, 8c), (8c, 4c), (4c, 2c), (2c, c).
        self.decoder_pairs = list(zip(widths[:0:-1], widths[-2::-1]))

        # Layers are stored by "<in>_<out>" key so that the public
        # *_block(input, input_channel, output_channel) methods can find them.
        self.encoder_blocks = nn.ModuleDict()
        self.upsamplers = nn.ModuleDict()
        self.decoder_blocks = nn.ModuleDict()

        for cin, cout in zip(self.channels[:-1], self.channels[1:]):
            self.encoder_blocks[self._key(cin, cout)] = self._make_contraction_block(cin, cout)
        for cin, cout in self.decoder_pairs:
            self.upsamplers[self._key(cin, cout)] = self._make_upscale_layer(cin, cout)
            self.decoder_blocks[self._key(cin, cout)] = self._make_expansion_block(cin, cout)

        self.output_layer = nn.Conv2d(base_channels, num_classes, 1)

        # Skip-connection activations, filled in by contraction_pass().
        self.feature_map_1 = None
        self.feature_map_2 = None
        self.feature_map_3 = None
        self.feature_map_4 = None

    @staticmethod
    def _key(input_channel, output_channel):
        """Return the ModuleDict key for an (input, output) channel pair."""
        return f"{input_channel}_{output_channel}"

    @staticmethod
    def _make_contraction_block(input_channel, output_channel):
        """Build one encoder block: two unpadded 3x3 convs + ReLU, then 2x2 max-pool."""
        return nn.Sequential(
            nn.Conv2d(input_channel, output_channel, 3),
            nn.ReLU(),
            nn.Conv2d(output_channel, output_channel, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

    @staticmethod
    def _make_expansion_block(input_channel, output_channel):
        """Build the two unpadded 3x3 convs + ReLU applied after a skip concatenation."""
        return nn.Sequential(
            nn.Conv2d(input_channel, output_channel, 3),
            nn.ReLU(),
            nn.Conv2d(output_channel, output_channel, 3),
            nn.ReLU(),
        )

    @staticmethod
    def _make_upscale_layer(input_channel, output_channel):
        """Build the 2x2, stride-2 transposed convolution that doubles H and W."""
        return nn.ConvTranspose2d(input_channel, output_channel, 2, 2)

    def _layer_for(self, registry, factory, input_channel, output_channel, like):
        """Return the registered layer for a channel pair, creating it if missing.

        Channel pairs outside the configured architecture are still accepted
        (as they were before layers became persistent): the layer is created on
        first use, placed on the device of ``like`` and registered. Parameters
        created this way are not known to an optimiser constructed earlier.
        """
        key = self._key(input_channel, output_channel)
        if key not in registry:
            registry[key] = factory(input_channel, output_channel).to(like.device)
        return registry[key]

    @staticmethod
    def _as_pair(size):
        """Normalise an int or an (h, w) sequence to an (h, w) tuple of ints."""
        if isinstance(size, int):
            return size, size
        height, width = size
        return int(height), int(width)

    def crop(self, input, initial_size, desired_size):
        """Centre-crop the last two dimensions of ``input``.

        Args:
            input: tensor whose last two dimensions are (H, W).
            initial_size: current spatial size, an int (square) or (h, w).
            desired_size: target spatial size, an int (square) or (h, w).

        Returns:
            A view of ``input`` with spatial size ``desired_size``.

        Raises:
            ValueError: if ``desired_size`` exceeds ``initial_size``.
        """
        initial_h, initial_w = self._as_pair(initial_size)
        desired_h, desired_w = self._as_pair(desired_size)
        if desired_h > initial_h or desired_w > initial_w:
            raise ValueError(
                f"cannot crop from {(initial_h, initial_w)} to larger size {(desired_h, desired_w)}"
            )

        top = (initial_h - desired_h) // 2
        left = (initial_w - desired_w) // 2

        # Ellipsis keeps every leading dimension (channels, optional batch).
        return input[..., top:top + desired_h, left:left + desired_w]

    def contraction_block(self, input, input_channel, output_channel):
        """Apply the encoder block registered for this channel pair to ``input``."""
        block = self._layer_for(
            self.encoder_blocks, self._make_contraction_block, input_channel, output_channel, input
        )
        return block(input)

    def contraction_pass(self, input):
        """Run the encoder, storing the first four block outputs as skip connections.

        Returns:
            The output of the fifth (bottleneck) encoder block.
        """
        c = self.channels
        self.feature_map_1 = self.contraction_block(input, c[0], c[1])
        self.feature_map_2 = self.contraction_block(self.feature_map_1, c[1], c[2])
        self.feature_map_3 = self.contraction_block(self.feature_map_2, c[2], c[3])
        self.feature_map_4 = self.contraction_block(self.feature_map_3, c[3], c[4])
        return self.contraction_block(self.feature_map_4, c[4], c[5])

    def expansion_block(self, input, feature_map, input_channel, output_channel):
        """Upscale ``input``, concatenate the cropped skip connection, then convolve.

        Args:
            input: decoder tensor with ``input_channel`` channels.
            feature_map: encoder activation to concatenate; its spatial size
                must be at least that of the upscaled input.
            input_channel: channels of ``input`` (and of the concatenation fed
                to the convolutions).
            output_channel: channels produced by the upscale and by the block.
        """
        upscale_layer = self._layer_for(
            self.upsamplers, self._make_upscale_layer, input_channel, output_channel, input
        )
        block = self._layer_for(
            self.decoder_blocks, self._make_expansion_block, input_channel, output_channel, input
        )

        upscale_input = upscale_layer(input)
        # Index from the end so the same code handles (C, H, W) and (N, C, H, W).
        cropped_feature_map = self.crop(
            feature_map, feature_map.shape[-2:], upscale_input.shape[-2:]
        )
        # dim=-3 is the channel axis for both unbatched and batched tensors.
        concat = F.relu(torch.cat((cropped_feature_map, upscale_input), dim=-3))
        return block(concat)

    def expansion_pass(self, input):
        """Run the decoder on the bottleneck tensor using the stored skip connections.

        Raises:
            RuntimeError: if ``contraction_pass`` has not been run yet.
        """
        skips = (self.feature_map_4, self.feature_map_3, self.feature_map_2, self.feature_map_1)
        if any(skip is None for skip in skips):
            raise RuntimeError("expansion_pass() requires contraction_pass() to be run first")

        upscaled = input
        for skip, (cin, cout) in zip(skips, self.decoder_pairs):
            upscaled = self.expansion_block(upscaled, skip, cin, cout)

        return self.output_layer(upscaled)

    def forward(self, input):
        """Return the segmentation map for ``input`` (encoder followed by decoder)."""
        contraction_result = self.contraction_pass(input)
        expansion_result = self.expansion_pass(contraction_result)
        return expansion_result

# Seed PyTorch's RNG before the network is built so that its randomly
# initialised weights (and therefore the outputs shown below) are the same
# on every Restart & Run All.
torch.manual_seed(0)
model = UNET()

Training Split

In [15]:
# Shape sanity check only: gradients are not needed here, so skip building the
# autograd graph (it is large for a 572x572 input).
with torch.no_grad():
    result = model(image)
result.shape

torch.Size([1024, 14, 14])


torch.Size([2, 164, 164])

Test Split: What is my output going to look like? 