In [1]:
import cv2
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class Encoder(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(output_channels, output_channels, kernel_size = 3, padding = 1),
            nn.ReLU(inplace = True)
        )

    def forward(self, feature_bank):
        feature_bank = self.block(feature_bank)
        return feature_bank

In [4]:
class Decoder(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(output_channels, output_channels, kernel_size = 3, padding = 1),
            nn.ReLU(inplace = True)
        )

    def forward(self, decoder_feature_bank, encoder_feature_bank):

        if decoder_feature_bank.shape[2:] != encoder_feature_bank.shape[2:]:
            decoder_feature_bank = F.interpolate(
                decoder_feature_bank,
                size=encoder_feature_bank.shape[2:],
                mode='bilinear',
                align_corners=False
            )

        feature_bank = torch.cat([decoder_feature_bank, encoder_feature_bank], dim = 1)
        feature_bank = self.block(feature_bank)
        return feature_bank

In [5]:
class U_Net(nn.Module):
    def __init__(self):
        super().__init__()
        channels = [3, 64, 128, 256, 512, 1024]
        self.encoder = nn.ModuleList([
            Encoder(channels[i], channels[i + 1]) for i in range(len(channels) - 1)
        ])
        channels.reverse()
        self.decoder = nn.ModuleList([
            Decoder(channels[i], channels[i + 1]) for i in range(len(channels) - 2)
        ])

        self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.convUps = nn.ModuleList([
            nn.ConvTranspose2d(channels[i], channels[i + 1], kernel_size = 2, stride = 2) for i in range(len(channels) - 2)
        ])

        self.final_conv = nn.Conv2d(64, 2, kernel_size = 3, padding = 1)


    def forward(self, img):

        encoder_feature_banks = []
        feature_bank = img

        for idx, block in enumerate(self.encoder):
            feature_bank = block(feature_bank)

            if idx < len(self.encoder) - 1:
                encoder_feature_banks.append(feature_bank)
                feature_bank = self.max_pool(feature_bank)

        encoder_feature_banks.reverse()

        for idx, block in enumerate(self.decoder):
            feature_bank = self.convUps[idx](feature_bank)
            feature_bank = block(feature_bank, encoder_feature_banks[idx])

        output = self.final_conv(feature_bank)

        return output

In [6]:
class ImageDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, mask = self.data[idx]
        return image, mask

In [7]:
def retrieve_data(type_, img_nums):
    data = []

    for img_num in img_nums:
        img_label = cv2.imread(f'./images-1024x768/{type_}/image-{img_num}.png')
        img_target = cv2.imread(f'./masks-1024x768/{type_}/mask-{img_num}.png', cv2.IMREAD_GRAYSCALE)
        foreground = (img_target > 0).astype(np.float32)
        background = 1.0 - foreground
        img_target = np.stack([foreground, background], axis = 0)
        data.append([img_label, img_target])

    return data

train_data = retrieve_data('train', [2, 7, 10, 12, 21, 24, 27, 28, 30, 43])
val_data = retrieve_data('val', [1, 11, 22, 32])
test_data = retrieve_data('test', [4, 16, 29, 36])

In [8]:
def preprocess_img(img):
    image_tensor = torch.tensor(img, dtype = torch.float32)
    image_tensor = image_tensor.permute(2, 0, 1)
    return image_tensor / 255.0

def preprocess_mask(img):
    image_tensor = torch.tensor(img, dtype = torch.float32)
    return image_tensor

train_data = [[preprocess_img(entry[0]), preprocess_mask(entry[1])] for entry in train_data]
val_data = [[preprocess_img(entry[0]), preprocess_mask(entry[1])] for entry in val_data]
test_data = [[preprocess_img(entry[0]), preprocess_mask(entry[1])] for entry in test_data]

In [9]:
def training_loop(train_data, val_data, num_epochs = 10, batch_size = 1, learning_rate = 1e-2, device = "cuda"):

    train_dataset = ImageDataset(train_data)
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle = True)

    model = U_Net().to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

    model.train()

    for epoch in range(num_epochs):
        for image, mask in train_dataloader:
            image = image.to(device)
            mask = mask.to(device)
            optimizer.zero_grad()
            logits = model(image)
            loss = loss_function(logits, mask)
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch} | Loss: {loss.item()}')

        validation_loop(model, val_data, 1024, device)

    return model

In [10]:
def validation_loop(model, data, batch_size = 1, device = "cuda"):

    dataset = ImageDataset(data)
    dataloader = DataLoader(dataset, batch_size, shuffle = False)

    loss_function = nn.CrossEntropyLoss()
    total_loss = 0.0
    correct = 0
    total = 0

    model.eval()
    with torch.no_grad():
        for image, mask in dataloader:
            image = image.to(device)
            mask = mask.to(device)

            logits = model(image)
            loss = loss_function(logits, mask)
            total_loss += loss.item()

            total += image.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    print(f"Validation Loss: {avg_loss:.4f}")

In [48]:
training_loop(
    train_data,
    val_data,
    num_epochs = 10,
    batch_size = 1,
    learning_rate = 1e-2,
    device = "cuda"
)

Epoch 0 | Loss: 1088750080.0
Validation Loss: 15893183.0000
Epoch 1 | Loss: 6698553.5
Validation Loss: 82457.6953
Epoch 2 | Loss: 1.817455768585205
Validation Loss: 2.4299
Epoch 3 | Loss: 0.7457923889160156
Validation Loss: 0.7851
Epoch 4 | Loss: 0.6285450458526611
Validation Loss: 0.5739
Epoch 5 | Loss: 0.5655540227890015
Validation Loss: 0.5644
Epoch 6 | Loss: 0.5307329893112183
Validation Loss: 0.5518
Epoch 7 | Loss: 0.6060792207717896
Validation Loss: 0.5563
Epoch 8 | Loss: 0.5280147790908813
Validation Loss: 0.5512
Epoch 9 | Loss: 0.5297563672065735
Validation Loss: 0.5577


U_Net(
  (encoder): ModuleList(
    (0): Encoder(
      (block): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
      )
    )
    (1): Encoder(
      (block): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
      )
    )
    (2): Encoder(
      (block): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
      )
    )
    (3): Encoder(
      (block): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(