In [1]:
import os
import glob
import random
import pickle

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from tqdm import tqdm
from IPython import display

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import relu
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

In [2]:
# Mount Google Drive (works on Google Colab only). Checkpoints, plots and example
# images are written under SAVE_PATH ("drive/MyDrive/...", a relative path), which
# presumably resolves into this mount when the working directory is /content.
import google.colab.drive as drive

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Clone the GitHub repository that provides the training data (read from ./cvproj/data/).
# The clone is skipped when the directory already exists, so re-running this cell
# (or "Run All" on a restarted kernel) does not fail with
# "destination path 'cvproj' already exists".
import os
import subprocess

REPO_URL = "https://github.com/Elias-Buerger/cvproj.git"
REPO_DIR = "cvproj"

if os.path.isdir(REPO_DIR):
    print(f"'{REPO_DIR}' already exists; skipping clone.")
else:
    # check=True: a failed clone raises instead of silently leaving no data behind.
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

fatal: destination path 'cvproj' already exists and is not an empty directory.


In [2]:
# Input data (integrated images + their "_y" target files) live in the cloned repo;
# outputs (checkpoints, plots, example images) go to Google Drive.
# Both paths keep their trailing "/" because later cells build paths with f-strings.
DATA_PATH, SAVE_PATH = "./cvproj/data/", "drive/MyDrive/CV_V3/"

In [3]:
# NOTE: AOSDataset assumes every data file holds exactly 32 samples, so the last
# batch files (data file plus its "_y" target file) are removed here — they
# presumably hold fewer samples. The dataloader will NOT work otherwise!!!
# WARNING: this permanently deletes these files from the cloned data directory.
files_to_delete = ["332.npy", "332_y.npy", "670.npy", "670_y.npy", "1004.npy", "1004_y.npy"]

existing_paths = [
    path
    for path in (os.path.join(DATA_PATH, name) for name in files_to_delete)
    if os.path.exists(path)
]
for path in existing_paths:
    os.remove(path)

In [4]:
class AOSDataset(Dataset):
    """Per-sample dataset over batched AOS ``.npy`` files.

    Every data file ``<name>.npy`` in ``file_list`` stores a batch of samples along
    axis 0; the matching targets are read from ``<target_path>/<name>_y.npy`` using
    the same position. A flat dataset index ``idx`` maps to
    ``(file_list[idx // samples_per_file], idx % samples_per_file)``.

    Each returned input contains, for every stored channel, the raw channel followed
    by its histogram-equalized version (so C stored channels become 2*C channels).
    Inputs and targets are divided by 255 and clamped to [0, 1]; the target gets a
    leading channel axis.

    NOTE(review): presumably a stored sample is (C, H, W) uint8 and a target (H, W)
    uint8 — confirm against the data.

    Args:
        file_list: paths of the data ``.npy`` files (target files excluded).
        target_path: directory containing the ``*_y.npy`` target files.
        random_aug: if True, apply training-time augmentation: random channel
            order, rotation by a random multiple of 90 degrees and random
            horizontal/vertical flips (applied identically to input and target).
            If False, samples are deterministic.
        samples_per_file: number of samples stored in each file (default 32).
            Every file must contain at least this many samples.
    """

    def __init__(self, file_list, target_path, random_aug=False, samples_per_file=32):
        self.file_list = file_list
        self.target_path = target_path
        self.random_aug = random_aug
        self.samples_per_file = samples_per_file

    def __len__(self):
        return len(self.file_list) * self.samples_per_file

    def preprocess_channel(self, channel):
        """Return the histogram-equalized version of one 2-D channel as a numpy array.

        NOTE(review): ``TF.equalize`` on a PIL image presumably needs 8-bit data
        (mode "L"), i.e. ``channel`` is assumed to be uint8 — confirm.
        """
        return np.array(TF.equalize(Image.fromarray(channel)))

    @staticmethod
    def _load_sample(path, position):
        """Read sample ``position`` (axis 0) of the ``.npy`` file at ``path``.

        The file is memory-mapped so only the requested sample is read from disk
        instead of the whole batch; the result is copied into a regular array.
        NOTE(review): memory-mapping requires a non-object (non-pickled) array file.
        """
        return np.array(np.load(path, mmap_mode="r")[position])

    def __getitem__(self, idx):
        file_idx, batch_idx = divmod(idx, self.samples_per_file)
        data_file = self.file_list[file_idx]

        data = self._load_sample(data_file, batch_idx)

        target_name = os.path.splitext(os.path.basename(data_file))[0]
        target = self._load_sample(os.path.join(self.target_path, f"{target_name}_y.npy"), batch_idx)

        # Channel order is only randomized as augmentation; evaluation stays deterministic.
        channel_ids = list(range(data.shape[0]))
        if self.random_aug:
            random.shuffle(channel_ids)

        # Interleave raw and equalized versions: [c_a, eq(c_a), c_b, eq(c_b), ...]
        preprocessed_channels = []
        for channel_id in channel_ids:
            channel_data = data[channel_id]
            preprocessed_channels.extend([channel_data, self.preprocess_channel(channel_data)])
        data = np.array(preprocessed_channels)
        del preprocessed_channels

        # Normalize to [0, 1]; the target gets a leading channel axis (1 x ...)
        data = torch.clamp(torch.Tensor(data) / 255.0, 0, 1)
        target = torch.clamp(torch.Tensor(target).unsqueeze(0) / 255.0, 0, 1)

        if self.random_aug:
            # The same rotation/flips are applied to input and target to keep them aligned.
            angle = random.choice([0, 90, 180, 270])
            data = TF.rotate(data, angle)
            target = TF.rotate(target, angle)

            if random.random() > 0.5:
                data = TF.hflip(data)
                target = TF.hflip(target)

            if random.random() > 0.5:
                data = TF.vflip(data)
                target = TF.vflip(target)

        return data, target

In [5]:
# Define the local paths for data (i.e., integrated images) and targets
data_path = DATA_PATH
target_path = DATA_PATH

# Collect all input files, excluding the "<name>_y.npy" target files. Sorting makes
# the order independent of the filesystem's directory listing, so the seeded
# shuffle below yields the same split on every run and machine.
file_list = sorted(
    path
    for path in glob.glob(os.path.join(data_path, "*.npy"))
    if not path.endswith("_y.npy")
)
assert file_list, f"No input .npy files found in {data_path!r}"

# Shuffle reproducibly
random.seed(42)
random.shuffle(file_list)

# Define train, test, validation splits (validation receives the remainder)
train_ratio = 0.8
test_ratio = 0.1
val_ratio = 0.1
assert abs(train_ratio + test_ratio + val_ratio - 1.0) < 1e-9, "split ratios must sum to 1"

train_split = int(train_ratio * len(file_list))
test_split = train_split + int(test_ratio * len(file_list))

train_file_list = file_list[:train_split]
test_file_list = file_list[train_split:test_split]
val_file_list = file_list[test_split:]

# Create AOSDataset instances (augmentation only for training)
train_dataset = AOSDataset(train_file_list, target_path, random_aug=True)
test_dataset = AOSDataset(test_file_list, target_path)
val_dataset = AOSDataset(val_file_list, target_path)

# Create DataLoader instances
BATCH_SIZE = 6
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [6]:
# Retrieve samples if needed...
# for sample in train_loader:
#     sample

In [7]:
class DoubleConv(nn.Module):
    """Two stacked ``Conv2d -> BatchNorm2d -> SiLU`` stages.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions; a falsy value
            (``None`` or ``0``) means ``out_channels``.
        kernel_size, padding: forwarded to both convolutions (the defaults 3 / 1
            keep the spatial size unchanged).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, kernel_size=3, padding=1):
        super().__init__()
        mid_channels = mid_channels or out_channels

        layers = []
        for stage_in, stage_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
            layers += [
                nn.Conv2d(stage_in, stage_out, kernel_size=kernel_size, padding=padding),
                nn.BatchNorm2d(stage_out),
                nn.SiLU(inplace=True),
            ]
        # Attribute name and layer order define the state_dict keys of saved checkpoints.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder step: 2x2 max-pooling (halves H and W, rounding down), then ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Attribute name and element order define the checkpoint state_dict keys.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder step: bilinear 2x upsampling, gated skip connection, then ``DoubleConv``.

    ``forward(x1, x2)``: ``x1`` is the coarser decoder feature map and ``x2`` the
    encoder skip connection. For even ``in_channels`` both are expected to carry
    ``in_channels // 2`` channels: the gate convolution consumes the upsampled
    ``x1``, and the concatenation of both feeds a ``DoubleConv`` that expects
    ``in_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        half_channels = in_channels // 2

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, half_channels)
        # 3x3 conv whose sigmoid output multiplies the skip features element-wise.
        self.gate = nn.Conv2d(half_channels, half_channels, kernel_size=3, padding=1)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Zero-pad the upsampled map to the skip's spatial size (handles odd sizes);
        # F.pad's list is (left, right, top, bottom) for the last two dimensions.
        diff_h = x2.size(2) - upsampled.size(2)
        diff_w = x2.size(3) - upsampled.size(3)
        upsampled = F.pad(upsampled, [diff_w // 2, diff_w - diff_w // 2,
                                      diff_h // 2, diff_h - diff_h // 2])

        gated_skip = x2 * torch.sigmoid(self.gate(upsampled))
        return self.conv(torch.cat([gated_skip, upsampled], dim=1))


class Out(nn.Module):
    """Final 1x1 convolution projecting features to the output channels (no activation)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [8]:
# Gated U-Net with Bilinear Upsampling
# U-Net Architecture Adapted from https://github.com/milesial/Pytorch-UNet
class GatedUNetWithBilinearUpsampling(nn.Module):
    """U-Net with sigmoid-gated skip connections and bilinear upsampling.

    Encoder: ``inc`` (64 channels) and four ``Down`` stages (128, 256, 512, 512).
    Decoder: four gated ``Up`` stages (256, 128, 64, 64), then a 1x1 ``Out``
    projection. Bilinear upsampling does not change the channel count, so the
    bottleneck and the first three ``Up`` outputs are halved (``// factor``).

    Args:
        n_channels_in: number of input channels.
        n_channels_out: number of output channels.
        bilinear: stored as ``self.bilinear`` only. NOTE(review): the flag is not
            consulted anywhere here — ``Up`` always upsamples bilinearly, so
            ``bilinear=False`` builds the same network.
    """

    def __init__(self, n_channels_in, n_channels_out, bilinear=True):
        super().__init__()
        self.n_channels_in = n_channels_in
        self.n_channels_out = n_channels_out
        self.bilinear = bilinear

        factor = 2  # bilinear upsampling keeps channels, so they are halved instead

        # Submodule names/order define the checkpoint state_dict keys — keep them.
        self.inc = DoubleConv(n_channels_in, 64)

        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        self.up1 = Up(1024, 512 // factor)
        self.up2 = Up(512, 256 // factor)
        self.up3 = Up(256, 128 // factor)
        self.up4 = Up(128, 64)

        self.out = Out(64, n_channels_out)

    def forward(self, x):
        # Encoder: keep every intermediate map as a skip connection.
        features = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            features.append(down(features[-1]))

        # Decoder: start from the bottleneck and consume skips deepest-first.
        x = features.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, features.pop())

        return self.out(x)

In [9]:
def save_batch_images(data, save_path, prefix):
    """Save every channel of every batch item as an 8-bit grayscale PNG.

    Output layout: ``<save_path>/<prefix>_<item>/channel_<channel>.png`` with
    1-based item and channel numbers. Values are clipped to [0, 1] and mapped to
    [0, 255] with rounding; NaN is written as 0.

    Args:
        data: batch tensor indexed as ``data[item][channel]`` -> 2-D image;
            presumably shape (N, C, H, W) with values in [0, 1] — TODO confirm.
        save_path: output directory (created if missing).
        prefix: folder-name prefix for each item's sub-directory.
    """
    os.makedirs(save_path, exist_ok=True)

    # detach() so tensors that still require grad can be converted as well.
    batch = data.detach().cpu().numpy()

    for item_id, item_data in enumerate(batch, start=1):
        # One folder per batch item, created once (not once per channel).
        folder_path = os.path.join(save_path, f"{prefix}_{item_id}")
        os.makedirs(folder_path, exist_ok=True)

        # Round rather than truncate: x / 255 * 255 in floating point can land just
        # below the integer, which truncation would turn into (value - 1).
        item_data = np.nan_to_num(item_data, nan=0.0, posinf=1.0, neginf=0.0)
        pixels = np.rint(np.clip(item_data, 0, 1) * 255).astype(np.uint8)

        for channel_id, channel_data in enumerate(pixels, start=1):
            # A 2-D uint8 array becomes a mode "L" (8-bit grayscale) image.
            image = Image.fromarray(channel_data)
            image.save(os.path.join(folder_path, f"channel_{channel_id}.png"))

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = GatedUNetWithBilinearUpsampling(6, 1).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Halve the learning rate every 5 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

num_epochs = 50
# Number of example batches dumped as PNGs per evaluation split and epoch.
NUM_EXAMPLE_BATCHES = 3

# Checkpoints, plots, example images and loss pickles are all written below SAVE_PATH.
os.makedirs(SAVE_PATH, exist_ok=True)

# Track losses (one mean per-batch loss per epoch)
train_losses = []
test_losses = []
val_losses = []


def _train_one_epoch(epoch):
    """Run one optimization pass over `train_loader`; return the mean per-batch loss."""
    train_loss = 0.0
    model.train()
    for batch_data, batch_targets in tqdm(train_loader, desc=f"[TRAIN] Epoch {epoch+1}/{num_epochs}", unit="Batches"):
        batch_data, batch_targets = batch_data.to(device), batch_targets.to(device)

        optimizer.zero_grad()
        outputs = model(batch_data)
        loss = criterion(outputs, batch_targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Release references before the next batch is loaded.
        del loss, batch_data, batch_targets, outputs

    return train_loss / len(train_loader)


def _evaluate(loader, label, tag, epoch):
    """Return the mean per-batch loss of `model` on `loader` (no weight updates).

    Every ``len(loader) // NUM_EXAMPLE_BATCHES``-th batch (inputs, targets, outputs)
    is saved as PNGs under ``<SAVE_PATH><tag>_batch_{data,targets,outputs}_<epoch+1>``.
    The batch number is part of each folder prefix, so the example batches of one
    epoch do not overwrite each other.

    Args:
        loader: evaluation DataLoader.
        label: progress-bar label, e.g. "TEST".
        tag: file-name tag for saved examples, e.g. "test".
        epoch: 0-based epoch index.
    """
    total_loss = 0.0
    # max(1, ...) avoids a modulo-by-zero for loaders with fewer than NUM_EXAMPLE_BATCHES batches.
    save_every = max(1, len(loader) // NUM_EXAMPLE_BATCHES)
    model.eval()
    with torch.no_grad():
        for batch_idx, (batch_data, batch_targets) in enumerate(tqdm(loader, desc=f"[{label}] Epoch {epoch+1}/{num_epochs}", unit="Batches")):
            batch_data, batch_targets = batch_data.to(device), batch_targets.to(device)
            outputs = model(batch_data)
            loss = criterion(outputs, batch_targets)

            total_loss += loss.item()

            # Save some batches as examples for the report...
            if (batch_idx + 1) % save_every == 0:
                batch_tag = f"b{batch_idx + 1}"
                save_batch_images(batch_data, f"{SAVE_PATH}{tag}_batch_data_{epoch + 1}", f"batch_data_{batch_tag}")
                save_batch_images(batch_targets, f"{SAVE_PATH}{tag}_batch_targets_{epoch + 1}", f"batch_targets_{batch_tag}")
                save_batch_images(outputs, f"{SAVE_PATH}{tag}_batch_outputs_{epoch + 1}", f"outputs_{batch_tag}")

            del loss, batch_data, batch_targets, outputs

    return total_loss / len(loader)


def _save_loss_history(epoch):
    """Plot train/validation losses so far and pickle all three loss histories."""
    epochs_so_far = range(1, len(train_losses) + 1)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(epochs_so_far, train_losses, label="Train Loss", marker="o")
    ax.plot(epochs_so_far, val_losses, label="Validation Loss", marker="o")
    ax.set_title("Training and Validation Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True)
    fig.savefig(f"{SAVE_PATH}loss_plot_{epoch + 1}.png")
    plt.close(fig)

    loss_dict = {"train_losses": train_losses, "test_losses": test_losses, "val_losses": val_losses}
    with open(f"{SAVE_PATH}losses_{epoch + 1}.pkl", "wb") as file:
        pickle.dump(loss_dict, file)


# Training loop
for epoch in range(num_epochs):
    avg_train_loss = _train_one_epoch(epoch)
    train_losses.append(avg_train_loss)

    # Save model weights (checkpoint names use the 0-based epoch; other artifacts use epoch + 1)
    model_save_path = f"{SAVE_PATH}unet_epoch_{epoch}.pth"
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved at epoch: {epoch}")

    # Update learning rate
    scheduler.step()

    # NOTE(review): the test split is evaluated every epoch for monitoring only; using
    # these numbers for model selection would bias the final test score.
    avg_test_loss = _evaluate(test_loader, "TEST", "test", epoch)
    test_losses.append(avg_test_loss)

    avg_val_loss = _evaluate(val_loader, "VALIDATION", "val", epoch)
    val_losses.append(avg_val_loss)

    _save_loss_history(epoch)

[TRAIN] Epoch 1/50: 100%|██████████| 4272/4272 [2:10:34<00:00,  1.83s/Batches]


Model saved at epoch: 0


[TEST] Epoch 1/50: 100%|██████████| 534/534 [03:05<00:00,  2.89Batches/s]
[VALIDATION] Epoch 1/50: 100%|██████████| 539/539 [03:07<00:00,  2.88Batches/s]
[TRAIN] Epoch 2/50: 100%|██████████| 4272/4272 [2:14:41<00:00,  1.89s/Batches]


Model saved at epoch: 1


[TEST] Epoch 2/50: 100%|██████████| 534/534 [03:02<00:00,  2.93Batches/s]
[VALIDATION] Epoch 2/50: 100%|██████████| 539/539 [03:03<00:00,  2.93Batches/s]
[TRAIN] Epoch 3/50: 100%|██████████| 4272/4272 [2:15:47<00:00,  1.91s/Batches]


Model saved at epoch: 2


[TEST] Epoch 3/50: 100%|██████████| 534/534 [03:05<00:00,  2.87Batches/s]
[VALIDATION] Epoch 3/50: 100%|██████████| 539/539 [03:07<00:00,  2.87Batches/s]
[TRAIN] Epoch 4/50: 100%|██████████| 4272/4272 [2:18:21<00:00,  1.94s/Batches]


Model saved at epoch: 3


[TEST] Epoch 4/50: 100%|██████████| 534/534 [03:04<00:00,  2.89Batches/s]
[VALIDATION] Epoch 4/50: 100%|██████████| 539/539 [03:05<00:00,  2.90Batches/s]
[TRAIN] Epoch 5/50: 100%|██████████| 4272/4272 [2:19:30<00:00,  1.96s/Batches]


Model saved at epoch: 4


[TEST] Epoch 5/50: 100%|██████████| 534/534 [03:06<00:00,  2.87Batches/s]
[VALIDATION] Epoch 5/50: 100%|██████████| 539/539 [03:09<00:00,  2.85Batches/s]
[TRAIN] Epoch 6/50: 100%|██████████| 4272/4272 [2:19:40<00:00,  1.96s/Batches]


Model saved at epoch: 5


[TEST] Epoch 6/50: 100%|██████████| 534/534 [03:06<00:00,  2.86Batches/s]
[VALIDATION] Epoch 6/50: 100%|██████████| 539/539 [03:09<00:00,  2.85Batches/s]
[TRAIN] Epoch 7/50: 100%|██████████| 4272/4272 [2:17:01<00:00,  1.92s/Batches]


Model saved at epoch: 6


[TEST] Epoch 7/50: 100%|██████████| 534/534 [03:05<00:00,  2.88Batches/s]
[VALIDATION] Epoch 7/50: 100%|██████████| 539/539 [03:09<00:00,  2.85Batches/s]
[TRAIN] Epoch 8/50:   6%|▌         | 235/4272 [07:57<2:12:38,  1.97s/Batches]