In [3]:
import cv2
import torch
import wandb
import itertools
import numpy as np
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Authenticate with Weights & Biases (use your own key).
wandb.login()

In [5]:
class Encoder(nn.Module):
    """Double-convolution stage used on the contracting path of the U-Net.

    Two 3x3 convolutions with padding 1 (so height and width are unchanged),
    each followed by an in-place ReLU.

    Args:
        input_channels: channel count of the incoming feature map.
        output_channels: channel count produced by both convolutions.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        first_conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
        second_conv = nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1)
        # The convolutions sit at indices 0 and 2 of `self.block`, so their
        # state_dict keys are block.0.* and block.2.*.
        self.block = nn.Sequential(
            first_conv,
            nn.ReLU(inplace=True),
            second_conv,
            nn.ReLU(inplace=True),
        )

    def forward(self, feature_bank):
        """Run the conv-ReLU-conv-ReLU stack and return the resulting feature map."""
        return self.block(feature_bank)

In [6]:
class Decoder(nn.Module):
    """Expanding-path stage: merge a skip connection, then apply two conv+ReLU layers.

    Args:
        input_channels: channel count of the concatenated (decoder + encoder)
            feature map fed to the first convolution.
        output_channels: channel count produced by both convolutions.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        layers = [
            nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, decoder_feature_bank, encoder_feature_bank):
        """Fuse the upsampled decoder features with the matching encoder features.

        If the two inputs disagree on any dimension from index 2 onward, the
        decoder features are bilinearly resized to the encoder's size before
        the channel-wise concatenation (decoder channels first).
        """
        skip_size = encoder_feature_bank.shape[2:]
        needs_resize = decoder_feature_bank.shape[2:] != skip_size

        if needs_resize:
            decoder_feature_bank = F.interpolate(
                decoder_feature_bank,
                size=skip_size,
                mode='bilinear',
                align_corners=False,
            )

        merged = torch.cat((decoder_feature_bank, encoder_feature_bank), dim=1)
        return self.block(merged)

In [7]:
class U_Net(nn.Module):
    """U-Net style encoder/decoder network producing per-pixel class logits.

    The contracting path applies one `Encoder` per entry in `features`, with
    2x2 max pooling between stages. The expanding path upsamples with 2x2
    transposed convolutions and fuses each result with the matching encoder
    output through a `Decoder`. A final 3x3 convolution maps the shallowest
    feature width to `num_classes` channels.

    Args:
        in_channels: channel count of the input image (default 3).
        num_classes: channel count of the output logit map (default 2).
        features: channel width of every encoder stage, shallowest first
            (default (64, 128, 256, 512, 1024)).

    Raises:
        ValueError: if `features` is empty.
    """

    def __init__(self, in_channels=3, num_classes=2, features=(64, 128, 256, 512, 1024)):
        super().__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        # Contracting path: in_channels -> features[0] -> ... -> features[-1].
        encoder_widths = [in_channels] + features
        self.encoder = nn.ModuleList([
            Encoder(c_in, c_out)
            for c_in, c_out in zip(encoder_widths[:-1], encoder_widths[1:])
        ])

        # Expanding path walks the widths deepest-first. Each decoder sees the
        # upsampled map (c_out channels) concatenated with the skip connection
        # (also c_out channels), hence 2 * c_out input channels. With the
        # default doubling widths this equals the deeper stage's width.
        decoder_widths = features[::-1]
        stage_pairs = list(zip(decoder_widths[:-1], decoder_widths[1:]))
        self.decoder = nn.ModuleList([
            Decoder(2 * c_out, c_out) for _, c_out in stage_pairs
        ])

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.convUps = nn.ModuleList([
            nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2)
            for c_in, c_out in stage_pairs
        ])

        # Tied to the shallowest feature width instead of a literal 64 so the
        # head stays valid when `features` is customised.
        self.final_conv = nn.Conv2d(features[0], num_classes, kernel_size=3, padding=1)

    def forward(self, img):
        """Return the logit map for `img`.

        Every encoder output except the deepest one is kept as a skip
        connection and pooled before the next stage. The decoder then consumes
        the skips deepest-first.
        """
        skip_connections = []
        feature_bank = img
        deepest = len(self.encoder) - 1

        for idx, block in enumerate(self.encoder):
            feature_bank = block(feature_bank)
            # The bottleneck (last encoder stage) is neither stored nor pooled.
            if idx < deepest:
                skip_connections.append(feature_bank)
                feature_bank = self.max_pool(feature_bank)

        for upsample, block, skip in zip(self.convUps, self.decoder, reversed(skip_connections)):
            feature_bank = block(upsample(feature_bank), skip)

        return self.final_conv(feature_bank)

In [8]:
class ImageDataset(Dataset):
    """Thin `Dataset` wrapper around an indexable collection of (image, mask) pairs."""

    def __init__(self, data):
        # Each `data[idx]` must unpack into exactly two items (see __getitem__).
        self.data = data

    def __len__(self):
        """Number of samples, as reported by the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the pair stored at `idx` as an `(image, mask)` tuple."""
        img, target = self.data[idx]
        return (img, target)

In [9]:
def retrieve_data(type_, img_nums):
    """Load the images of one dataset split and build two-channel masks for them.

    For every number `n` in `img_nums` this reads
    `./images-1024x768/{type_}/image-{n}.png` and, as grayscale,
    `./masks-1024x768/{type_}/mask-{n}.png`.

    Args:
        type_: split sub-directory name (the notebook uses 'train', 'val', 'test').
        img_nums: iterable of image numbers to load.

    Returns:
        A list of `[image, mask]` pairs. `image` is the array returned by
        `cv2.imread`. `mask` is a float32 array stacked along axis 0 as
        `[background, foreground]`, where foreground is 1.0 wherever the
        grayscale mask is > 0 and background is its complement.

    Raises:
        FileNotFoundError: if an image or mask file cannot be read.
    """
    data = []

    for img_num in img_nums:
        image_path = f'./images-1024x768/{type_}/image-{img_num}.png'
        mask_path = f'./masks-1024x768/{type_}/mask-{img_num}.png'

        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread reports a missing or unreadable file by returning None
        # instead of raising, so fail here with the offending path rather than
        # with an opaque TypeError further down.
        for path, loaded in ((image_path, image), (mask_path, mask)):
            if loaded is None:
                raise FileNotFoundError(f"Could not read image file: {path}")

        foreground = (mask > 0).astype(np.float32)
        background = 1.0 - foreground
        data.append([image, np.stack([background, foreground], axis=0)])

    return data

# Load every split up front; the image numbers belonging to each split are fixed here.
train_data = retrieve_data(type_='train', img_nums=[2, 7, 10, 12, 21, 24, 27, 28, 30, 43])
val_data = retrieve_data(type_='val', img_nums=[1, 11, 22, 32])
test_data = retrieve_data(type_='test', img_nums=[4, 16, 29, 36])

In [10]:
def preprocess_img(img):
    """Convert an (H, W, C) image array into a (C, H, W) float32 tensor divided by 255."""
    channels_last = torch.tensor(img, dtype=torch.float32)
    # Move the channel axis to the front before scaling.
    channels_first = channels_last.permute(2, 0, 1)
    return channels_first.div(255.0)

def preprocess_mask(img):
    """Copy a mask array into a new float32 tensor; the axis layout is left as-is."""
    return torch.tensor(img, dtype=torch.float32)

def _to_tensor_pairs(entries):
    """Convert the raw [image, mask] array pairs of one split into tensor pairs.

    Pairs whose image is already a tensor are passed through untouched. This
    cell overwrites its own inputs, so without the check a second run would
    permute and rescale the already-converted images again.
    """
    return [
        entry if torch.is_tensor(entry[0])
        else [preprocess_img(entry[0]), preprocess_mask(entry[1])]
        for entry in entries
    ]

train_data = _to_tensor_pairs(train_data)
val_data = _to_tensor_pairs(val_data)
test_data = _to_tensor_pairs(test_data)

In [11]:
def metrics(true_masks, predictions, num_classes=2):
    """Compute pixel-level classification metrics for a set of label maps.

    Args:
        true_masks: non-empty sequence of integer label arrays (one per batch).
        predictions: sequence of predicted label arrays, matching `true_masks`
            element-wise once concatenated and flattened.
        num_classes: number of classes to report IoU for (default 2). Precision,
            recall and F1 use scikit-learn's default binary averaging, so the
            labels are expected to be 0/1.

    Returns:
        `(accuracy, precision, recall, f1, iou)` where `iou` is a list with one
        entry per class id in `range(num_classes)`. A class that appears in
        neither the truth nor the prediction gets an IoU of 0.

    Raises:
        ValueError: from `np.concatenate` if either sequence is empty.
    """
    predictions = np.concatenate(predictions).flatten()
    true_masks = np.concatenate(true_masks).flatten()

    accuracy = accuracy_score(true_masks, predictions)
    precision = precision_score(true_masks, predictions, zero_division = 0)
    recall = recall_score(true_masks, predictions, zero_division = 0)
    f1 = f1_score(true_masks, predictions, zero_division = 0)

    # Iterate over a fixed class range rather than the labels that happen to be
    # present in the ground truth. Otherwise a batch without foreground yields
    # a one-element list and `iou[1]` in the callers raises IndexError.
    iou = []
    for class_id in range(num_classes):
        in_truth = true_masks == class_id
        in_prediction = predictions == class_id
        intersection = np.sum(in_truth & in_prediction)
        union = np.sum(in_truth | in_prediction)
        iou.append(intersection / union if union != 0 else 0)

    return accuracy, precision, recall, f1, iou

In [15]:
def nontraining_loop(type_, model, data, batch_size = 1, device="cuda"):
    """Evaluate `model` on `data` without updating weights and log the results.

    Args:
        type_: label used as a prefix for the wandb keys and the printed line
            (the notebook passes "Validation" and "Testing").
        model: network to evaluate; called as `model(image)` and expected to
            return per-class logits with classes on dim 1.
        data: indexable collection of (image, mask) tensor pairs.
            Assumes each mask carries one channel per class on its first axis,
            since the class index is recovered with an argmax over dim 1 of
            the batched mask.
        batch_size: evaluation batch size.
        device: device the batches are moved to.

    Returns:
        `(avg_loss, accuracy, precision, f1, iou)`. `avg_loss` is the mean of
        the per-batch losses. Recall is logged and printed but is not part of
        the returned tuple.

    Raises:
        ValueError: if `data` yields no batches.
    """
    dataset = ImageDataset(data)
    dataloader = DataLoader(dataset, batch_size, shuffle=False)
    if len(dataloader) == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError(f"{type_}: cannot evaluate on an empty dataset")

    loss_function = nn.CrossEntropyLoss()
    total_loss = 0.0

    predictions = []
    true_masks = []

    # Remember the caller's mode so that evaluating in the middle of training
    # does not leave the model stuck in eval mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for image, mask in dataloader:
                image = image.to(device)
                mask = mask.to(device)

                # Class indices serve both as the loss target and as the
                # reference labels for the metrics.
                true_mask = torch.argmax(mask, dim=1)

                logits = model(image)
                total_loss += loss_function(logits, true_mask).item()

                prediction = torch.argmax(logits, dim=1)
                predictions.append(prediction.cpu().numpy())
                true_masks.append(true_mask.cpu().numpy())
    finally:
        model.train(was_training)

    avg_loss = total_loss / len(dataloader)
    accuracy, precision, recall, f1, iou = metrics(true_masks, predictions)

    wandb.log({
          f"{type_} Loss": avg_loss,
          f"{type_} Accuracy": accuracy,
          f"{type_} Precision": precision,
          f"{type_} Recall": recall,
          f"{type_} F1": f1,
          f"{type_} IoU": iou[1]
    })

    print(f"{type_} | Loss: {avg_loss:.4f} | Accuracy: {accuracy:.4f} | Precision: {precision} | Recall: {recall} | F1 Score: {f1} | IoU: {iou}  \n")

    return avg_loss, accuracy, precision, f1, iou


In [21]:
def training_loop(train_data, val_data, num_epochs = 10, batch_size = 1, learning_rate = 1e-3, device = "cuda"):
    """Train a fresh `U_Net` and validate it after every epoch.

    Args:
        train_data: indexable collection of (image, mask) tensor pairs used for
            optimisation. Each mask is passed to `CrossEntropyLoss` as a
            floating-point target with one channel per class, and is reduced to
            class indices with an argmax over dim 1 for the metrics.
        val_data: collection evaluated with `nontraining_loop` after each epoch.
        num_epochs: number of passes over `train_data`.
        batch_size: training batch size (validation always uses 1).
        learning_rate: Adam learning rate.
        device: device the model and batches are moved to.

    Returns:
        `(model, iou)` where `iou` is the class-1 training IoU of the last
        epoch, or 0.0 when `num_epochs` is 0.
        NOTE(review): callers that use this score to pick a model are selecting
        on training data; consider returning the validation IoU instead.

    Raises:
        ValueError: if `train_data` yields no batches.
    """
    train_dataset = ImageDataset(train_data)
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle = True)
    if len(train_dataloader) == 0:
        raise ValueError("Training: cannot train on an empty dataset")

    model = U_Net().to(device)

    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

    # Returned as-is when no epoch runs, instead of failing on an unbound name.
    foreground_iou = 0.0

    for epoch in range(num_epochs):
        # Validation at the end of the previous epoch switches the model to
        # eval mode, so training mode has to be re-enabled every epoch.
        model.train()

        # Per-epoch accumulators. Resetting them here keeps the logged numbers
        # specific to this epoch and stops the lists growing across epochs.
        epoch_loss = 0.0
        predictions = []
        true_masks = []

        for image, mask in train_dataloader:
            image = image.to(device)
            mask = mask.to(device)
            optimizer.zero_grad()
            logits = model(image)
            loss = loss_function(logits, mask)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

            prediction = torch.argmax(logits, dim=1)
            true_mask = torch.argmax(mask, dim=1)

            predictions.append(prediction.cpu().numpy())
            true_masks.append(true_mask.cpu().numpy())

        # Mean per-batch loss: comparable across batch sizes and with the
        # averaged loss that nontraining_loop reports.
        train_loss = epoch_loss / len(train_dataloader)
        accuracy, precision, recall, f1, iou = metrics(true_masks, predictions)
        foreground_iou = iou[1]

        wandb.log({
              "Training Loss": train_loss,
              "Training Accuracy": accuracy,
              "Training Precision": precision,
              "Training Recall": recall,
              "Training F1": f1,
              "Training IoU": foreground_iou
        })

        print(f"Epoch: {epoch} | Training | Loss: {train_loss:.4f} | Accuracy: {accuracy:.4f} | Precision: {precision} | Recall: {recall} | F1 Score: {f1} | IoU: {iou}")
        # Logs and prints the validation metrics; its return value is not used here.
        nontraining_loop("Validation", model, val_data, 1, device)

    return model, foreground_iou

In [None]:
def find_best_hyperparameters():
    """Grid-search epochs, batch size and learning rate, and return the best model.

    Every combination is trained with `training_loop` inside its own wandb
    run (`run_{idx}`). The model with the highest IoU score returned by
    `training_loop` is kept.

    Returns:
        The best-scoring model, or None if the grid is empty.
    """
    num_epochs = [10]
    batch_size = [1, 2, 3, 4, 5]
    learning_rate = [1e-2, 1e-3, 1e-4]

    # Fall back to the CPU instead of crashing on machines without CUDA.
    run_device = "cuda" if torch.cuda.is_available() else "cpu"

    best_model = None
    # Start below any possible score so that a first run scoring exactly 0 is
    # still recorded as the best so far.
    best_iou = float("-inf")

    grid = itertools.product(num_epochs, batch_size, learning_rate)
    for idx, (epochs, batch, lr) in enumerate(grid):

        wandb.init(entity = "computer-vision-wits", project = "U-Net", name = f"run_{idx}")
        try:
            config = wandb.config
            config.epochs = epochs
            config.batch_size = batch
            config.learning_rate = lr

            model, iou = training_loop(
                train_data,
                val_data,
                num_epochs = epochs,
                batch_size = batch,
                learning_rate = lr,
                device = run_device
            )

            if iou > best_iou:
                best_iou = iou
                best_model = model
        finally:
            # Close the run even if training fails, so the next wandb.init
            # starts from a clean state.
            wandb.finish()

    # Return the tracked best model, not simply the last one trained.
    return best_model

# Run the grid search; `model` holds whatever find_best_hyperparameters returns
# and is the model evaluated by the testing cell further down.
model = find_best_hyperparameters()

In [None]:
# Single full-batch training run (batch size = whole training set), tracked as "u-run-5".
wandb.init(entity = "computer-vision-wits", project = "U-Net", name = "u-run-5")

try:
    # Use the notebook-level `device` so the cell also runs on CPU-only machines.
    # NOTE(review): the (model, iou) result is discarded, so later cells keep
    # using the `model` produced by the hyperparameter search; confirm intent.
    training_loop(train_data, val_data, num_epochs = 10, batch_size = len(train_data), learning_rate = 1e-3, device = device)
finally:
    # Always close the wandb run, even if training raises.
    wandb.finish()

In [None]:
# Evaluate the selected `model` on the held-out test split in its own wandb run.
wandb.init(entity = "computer-vision-wits", project = "U-Net", name = "u-net-base-testing")

try:
    nontraining_loop(
        "Testing",
        model = model,
        data = test_data,
        batch_size = 4,
        # Use the notebook-level `device` so the cell also runs on CPU-only machines.
        device = device
    )
finally:
    # Always close the wandb run, even if evaluation raises.
    wandb.finish()