In [None]:
# %pip install torch torchvision tensorboard  # torch.utils.tensorboard (used below) needs the "tensorboard" package, not tensorboardX

In [None]:
import os
import glob
import random

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from tqdm import tqdm
from IPython import display

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import relu
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

In [None]:
# Connect to Google Drive on Google Colab
# from google.colab import drive
# drive.mount("/content/drive")

In [None]:
# Clone the GitHub repository
# !git clone https://github.com/Elias-Buerger/cvproj.git

In [None]:
# Where the cloned repository keeps its .npy batches, and where plots and checkpoints are written
# (a Google Drive folder when running on Colab). The trailing slash is kept so each path can be
# prefixed directly onto a file name.
DATA_PATH, SAVE_PATH = "./cvproj/data/", "drive/MyDrive/CV_5/"

In [None]:
# NOTE: This code deletes the last batch (i.e., numpy file - including target, "_y" file)! The dataloader will NOT work otherwise!!!
files_to_delete = ["332.npy", "332_y.npy"]
for file_to_delete in files_to_delete:
    file_to_delete_path = os.path.join(DATA_PATH, file_to_delete)
    if os.path.exists(file_to_delete_path):
        os.remove(file_to_delete_path)

In [None]:
class AOSDataset(Dataset):
    """Sample-level dataset over pre-batched ``.npy`` files.

    Every input file ``<stem>.npy`` in ``file_list`` stores ``samples_per_file``
    samples stacked along axis 0. The matching target is
    ``<target_path>/<stem>_y.npy``, indexed the same way. Each input channel is
    complemented by two preprocessed copies (contrast-boosted and
    histogram-equalized), so a sample with C channels is returned with 3*C.

    Args:
        file_list: paths of the input ``.npy`` files (target files excluded).
        target_path: directory containing the ``*_y.npy`` target files.
        transform: if truthy, apply the same random 90-degree rotation and
            random horizontal/vertical flips to input and target.
        samples_per_file: number of samples stored in every input file. Every
            file is assumed to be full; a shorter file makes indexing fail.

    Returns (from ``__getitem__``):
        ``(data, target)`` float tensors scaled by 1/255 and clamped to [0, 1].
        ``data`` is ``(3*C, H, W)`` and ``target`` is ``(1, H, W)``.
        Assumes inputs are ``(C, H, W)`` and targets ``(H, W)`` per sample (TODO confirm).
    """

    def __init__(self, file_list, target_path, transform=False, samples_per_file=32):
        self.file_list = file_list
        self.target_path = target_path
        self.transform = transform
        self.samples_per_file = samples_per_file

    def __len__(self):
        return len(self.file_list) * self.samples_per_file

    def _target_file(self, file_idx):
        """Path of the ``<stem>_y.npy`` target paired with input file ``file_idx``."""
        # os.path handles both "/" and OS-specific separators, unlike split("/").
        stem = os.path.splitext(os.path.basename(self.file_list[file_idx]))[0]
        return os.path.join(self.target_path, f"{stem}_y.npy")

    def preprocess_channel(self, channel):
        """Return ``(contrast_boosted, histogram_equalized)`` copies of one 2-D channel.

        The channel goes through a PIL image. ``TF.equalize`` only accepts 8-bit
        images, so ``channel`` is presumably uint8 (TODO confirm the .npy dtype).
        """
        image = Image.fromarray(channel)
        # Fix: the contrast step previously received the undefined name `channel_data`.
        channel_contrast = np.array(TF.adjust_contrast(image, 100))
        channel_equalized = np.array(TF.equalize(image))
        return channel_contrast, channel_equalized

    def __getitem__(self, idx):
        file_idx, batch_idx = divmod(idx, self.samples_per_file)

        # mmap_mode="r" reads only the requested sample instead of the whole batch file;
        # np.array(...) copies it into a regular, writable in-memory array.
        data = np.array(np.load(self.file_list[file_idx], mmap_mode="r")[batch_idx])
        target = np.array(np.load(self._target_file(file_idx), mmap_mode="r")[batch_idx])

        # Extra channels, ordered [c0_contrast, c0_equalized, c1_contrast, c1_equalized, ...].
        preprocessed_channels = []
        for channel in range(data.shape[0]):
            # Fix: previously always data[1], so every copy came from channel 1.
            preprocessed_channels.extend(self.preprocess_channel(data[channel]))

        # The original channels come first, followed by the preprocessed ones.
        data = np.concatenate([data, np.array(preprocessed_channels)], axis=0)

        # Scale to [0, 1] and give the target an explicit channel axis.
        data = torch.clamp(torch.Tensor(data) / 255.0, 0, 1)
        target = torch.clamp(torch.Tensor(target).unsqueeze(0) / 255.0, 0, 1)

        # Random augmentations, applied identically to input and target.
        if self.transform:
            angle = random.choice([0, 90, 180, 270])
            data = TF.rotate(data, angle)
            target = TF.rotate(target, angle)

            if random.random() > 0.5:
                data = TF.hflip(data)
                target = TF.hflip(target)

            if random.random() > 0.5:
                data = TF.vflip(data)
                target = TF.vflip(target)

        return data, target

In [None]:
# Inputs and targets live side by side in the cloned repository's data directory.
data_path = DATA_PATH
target_path = DATA_PATH


def _input_sort_key(path):
    """Order numbered files numerically ("2.npy" before "10.npy"); any other names follow, by name."""
    stem = os.path.splitext(os.path.basename(path))[0]
    return (0, int(stem), "") if stem.isdigit() else (1, 0, stem)


# All input files (targets are the "*_y.npy" ones). Sorted because glob order depends on the
# filesystem; without sorting, the train/test split below could differ between machines and runs.
file_list = sorted(
    (path for path in glob.glob(os.path.join(data_path, "*.npy")) if not path.endswith("_y.npy")),
    key=_input_sort_key,
)
assert file_list, f"No input .npy files found in {data_path!r}"

# Contiguous 80/20 split at file level, so no file contributes samples to both sets.
split_idx = int(0.8 * len(file_list))
train_file_list = file_list[:split_idx]
test_file_list = file_list[split_idx:]

# AOSDataset instances (augmentation only for training)
train_dataset = AOSDataset(train_file_list, target_path, transform=True)
test_dataset = AOSDataset(test_file_list, target_path)

# DataLoader instances
train_loader = DataLoader(train_dataset, batch_size=6, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=6, shuffle=False)

In [None]:
# Retrieve samples if needed...
# for sample in train_loader:
#    sample

In [None]:
# Source: https://github.com/milesial/Pytorch-UNet
class DoubleConv(nn.Module):
    """Two stacked (Conv2d -> BatchNorm2d -> SiLU) stages.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions; any falsy value
            (None, 0) falls back to ``out_channels``.
        kernel_size, padding: passed unchanged to both convolutions.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, kernel_size=3, padding=1):
        super().__init__()
        mid_channels = mid_channels or out_channels
        conv_kwargs = dict(kernel_size=kernel_size, padding=padding)
        # Layer order (and therefore the state_dict indices 0..5) matches the reference UNet.
        layers = [
            nn.Conv2d(in_channels, mid_channels, **conv_kwargs),
            nn.BatchNorm2d(mid_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, **conv_kwargs),
            nn.BatchNorm2d(out_channels),
            nn.SiLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling (halves H and W) followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Decoder stage: upsample, sigmoid-gate the skip connection, concatenate, then DoubleConv.

    Args:
        in_channels: channel count after concatenating skip and upsampled maps.
            The gate convolution assumes both halves carry ``in_channels // 2``
            channels.
        out_channels: channels produced by the final DoubleConv.
        bilinear: if True, use parameter-free bilinear upsampling (channels are
            reduced by the DoubleConv). Otherwise use a learned ConvTranspose2d
            that halves the channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # NOTE: With bilinear upsampling the regular convolutions reduce the number of channels.
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

        self.gate = nn.Conv2d(half, half, kernel_size=3, padding=1)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Pad (NCHW layout) so the upsampled map matches the skip connection's H and W.
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        diff_y = x2.size()[2] - upsampled.size()[2]
        diff_x = x2.size()[3] - upsampled.size()[3]
        pad_left, pad_top = diff_x // 2, diff_y // 2
        upsampled = F.pad(upsampled, [pad_left, diff_x - pad_left, pad_top, diff_y - pad_top])

        # Element-wise gate on the skip features, computed from the decoder features.
        gated_skip = x2 * torch.sigmoid(self.gate(upsampled))

        return self.conv(torch.cat([gated_skip, upsampled], dim=1))


class Out(nn.Module):
    """Output head: a 1x1 convolution mapping ``in_channels`` feature maps to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # A 1x1 kernel leaves the spatial size unchanged.
        projected = self.conv(x)
        return projected

In [None]:
# Source: https://github.com/milesial/Pytorch-UNet
class UNet(nn.Module):
    """Four-level U-Net with sigmoid-gated skip connections.

    Args:
        n_channels_in: channels of the input tensor.
        n_channels_out: channels produced by the 1x1 output head.
        bilinear: use bilinear upsampling in the decoder. This halves the
            deepest encoder width and the decoder outputs via ``factor``.
    """

    def __init__(self, n_channels_in, n_channels_out, bilinear=True):
        super().__init__()
        self.n_channels_in = n_channels_in
        self.n_channels_out = n_channels_out
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels_in, 64)

        # Encoder: each stage halves H and W.
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 1 if not bilinear else 2
        self.down4 = Down(512, 1024 // factor)

        # Decoder: each stage doubles H and W and consumes one skip connection.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        self.out = Out(64, n_channels_out)

    def forward(self, x):
        # Collect encoder outputs: [inc, down1, down2, down3].
        skips = [self.inc(x)]
        for encoder in (self.down1, self.down2, self.down3):
            skips.append(encoder(skips[-1]))

        features = self.down4(skips[-1])

        # Pair decoder stages with skips from deepest to shallowest.
        for decoder, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            features = decoder(features, skip)

        return self.out(features)

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir=runs

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG used below: model init and DataLoader shuffling (torch),
# augmentation choices in AOSDataset (random), and numpy for completeness.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# 9 input channels, presumably 3 raw channels plus 2 preprocessed copies each (TODO confirm).
model = UNet(9, 1).to(device)  # alternative tried: UNet(12, 1)
criterion = nn.MSELoss()  # alternative tried: DiceLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# TensorBoard
writer = SummaryWriter()

# savefig / torch.save below fail if the output folder does not exist yet.
os.makedirs(SAVE_PATH, exist_ok=True)

# One reusable figure: input (first channel), target, output.
fig, ax = plt.subplots(1, 3, figsize=(12, 4))

num_epochs = 50
log_interval = 100


def show_training_sample(inputs, targets, outputs, epoch, batch_idx):
    """Log the batch to TensorBoard; plot, save and display its first sample."""
    step = epoch * len(train_loader) + batch_idx
    # Inputs are not logged: add_images cannot render the 9-channel input as-is.
    writer.add_images("Target/train", targets.detach(), step)
    writer.add_images("Output/train", torch.clamp(outputs.detach(), 0, 1), step)

    display.clear_output(wait=True)
    panels = [
        ("Input", inputs[0][0]),  # first input channel only
        ("Target", targets[0]),
        ("Output", torch.clamp(outputs[0], 0, 1)),
    ]
    for axis, (title, image) in zip(ax, panels):
        # Clear first: calling imshow on the same axes adds a new image each time (memory grows).
        axis.clear()
        axis.imshow(TF.to_pil_image(image.detach().cpu()), cmap="gray")
        axis.set_title(title)

    # Save the figure to the Google Drive folder
    fig.savefig(f"{SAVE_PATH}plot_epoch_{epoch}_batch_{batch_idx}_0.png")
    display.display(fig)


def evaluate(loader):
    """Return the mean criterion over ``loader``, with outputs clamped to [0, 1]."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = torch.clamp(model(inputs), 0, 1)
            total += criterion(outputs, targets).item()
    return total / max(len(loader), 1)


for epoch in range(num_epochs):
    model.train()
    # Kept as a device tensor so no per-batch .item() sync is forced; read once per epoch.
    epoch_loss = torch.zeros((), device=device)

    with tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
        # Iterate the progress bar itself so it actually advances.
        for batch_idx, (batch_data, batch_targets) in enumerate(pbar):
            batch_data, batch_targets = batch_data.to(device), batch_targets.to(device)

            optimizer.zero_grad()
            outputs = model(batch_data)
            loss = criterion(outputs, batch_targets)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.detach()

            if batch_idx > 0 and batch_idx % log_interval == 0:
                show_training_sample(batch_data, batch_targets, outputs, epoch, batch_idx)

            del batch_data, batch_targets, outputs, loss  # release references early

    # Update learning rate
    scheduler.step()

    # Log the epoch-mean training loss (previously only the last batch's loss was logged).
    avg_train_loss = epoch_loss.item() / max(len(train_loader), 1)
    writer.add_scalar("Loss/train", avg_train_loss, epoch)

    # Validation
    avg_val_loss = evaluate(test_loader)
    torch.cuda.empty_cache()
    print(f"Validation Loss: {avg_val_loss}")
    writer.add_scalar("Loss/test", avg_val_loss, epoch)

    # Save model weights
    model_save_path = f"{SAVE_PATH}unet_epoch_{epoch}.pth"
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved at epoch {epoch}")

# Close TensorBoard writer; close the figure so the cell doesn't render it once more at the end.
writer.close()
plt.close(fig)