# Import packages


In [None]:
from pathlib import Path
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import skimage
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from joblib import Parallel, delayed
from skimage.color import rgb2hed, rgba2rgb
from skimage.io import imread
from sklearn.decomposition import PCA
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from torchvision import datasets

In [None]:
# TensorBoard writer shared by the training/testing loops below; logs go to runs/experiment_1.
writer = SummaryWriter(log_dir="runs/experiment_1")

# Utility functions:


Define some utility functions for working with images.

In [None]:
def read_image(image_id: str, image_folder: Path = Path("data/patches_256")) -> np.ndarray:
    """Reads an image patch from the dataset, dropping the alpha channel if present.

    Args:
        image_id (str): The id of the image to be read; the file loaded is
            ``<image_folder>/<image_id>.png``.
        image_folder (Path): Directory containing the PNG patches. Defaults to
            ``data/patches_256`` so existing callers are unaffected.

    Returns:
        np.ndarray: The image as returned by ``imread`` or, for 4-channel images,
        the result of ``rgba2rgb``.
    """
    image_path = Path(image_folder) / f"{image_id}.png"

    rgb_image = imread(image_path)

    # Only a 3-D (H, W, 4) array is RGBA; checking ndim avoids mis-treating a 2-D
    # grayscale image that merely happens to be 4 pixels wide.
    # NOTE(review): rgba2rgb may return a different dtype/scale than imread gives for
    # 3-channel PNGs (float vs uint8); if patches mix both formats, the RGB statistics
    # computed downstream are on different scales — confirm all patches share one format.
    if rgb_image.ndim == 3 and rgb_image.shape[-1] == 4:
        rgb_image = rgba2rgb(rgb_image)

    return rgb_image

In [None]:
def convert_rgb_to_hed(input_rgb_image: np.array) -> np.array:
    """
    Transform an RGB image into HED space using ``skimage.color.rgb2hed``.

    Parameters:
        input_rgb_image (np.array): Image in RGB channel order.

    Returns:
        np.array: The ``rgb2hed`` result; channel 0 is the H channel used elsewhere
        in this notebook.
    """
    return rgb2hed(input_rgb_image)

In [None]:
def calculate_intensity_avg(input_image: np.array, channel: int) -> float:
    """
    Mean pixel value of one channel of a channels-last image (RGB or HED).

    Parameters:
        input_image (np.array): Image indexed as ``[row, column, channel]``.
        channel (int): Index of the channel to average.

    Returns:
        float: Mean over all pixels of the selected channel.
    """
    channel_plane = input_image[:, :, channel]
    return np.mean(channel_plane)

In [None]:
def calculate_intensity_std(input_image: np.array, channel: int) -> float:
    """
    Standard deviation of pixel values in one channel of a channels-last image (RGB or HED).

    Parameters:
        input_image (np.array): Image indexed as ``[row, column, channel]``.
        channel (int): Index of the channel to summarise.

    Returns:
        float: Standard deviation (numpy default, ddof=0) over all pixels of the channel.
    """
    channel_plane = input_image[:, :, channel]
    return np.std(channel_plane)

In [None]:
def calculate_avg_h_intensity(image_id: str) -> dict:
    """
    Compute the mean H-channel (HED channel 0) value of one image patch.

    Parameters:
        image_id (str): Identifier passed to ``read_image``.

    Returns:
        dict: ``{"image_id": image_id, "avg_h_intensity": <mean of HED channel 0>}``,
        one row of the table built from the parallel results.
    """
    hed_image = convert_rgb_to_hed(read_image(image_id))
    return {
        "image_id": image_id,
        "avg_h_intensity": calculate_intensity_avg(hed_image, 0),
    }

In [None]:
def load_data(
    source="https://warwick.ac.uk/fac/sci/dcs/teaching/material/cs909/protein_expression_data.csv",
    train_specimens: Tuple[str, ...] = ("A1", "B1", "D1"),
    test_specimens: Tuple[str, ...] = ("C1",),
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Loads protein expression data from a CSV file and splits it by specimen into
    training and testing datasets.

    Args:
        source: Anything ``pandas.read_csv`` accepts (URL, path or file-like object).
            Defaults to the course-provided CSV.
        train_specimens: Specimen ids whose rows form the training set
            (default: A1, B1 and D1).
        test_specimens: Specimen ids whose rows form the testing set (default: C1).

    Returns:
        training_data (pandas.DataFrame): Rows whose specimen is in ``train_specimens``.
        testing_data (pandas.DataFrame): Rows whose specimen is in ``test_specimens``.

        Both frames are indexed by ``image_id`` (``"<specimen_id>_<id>"``) and sorted by
        that index. Rows belonging to neither list are dropped.
    """
    df = pd.read_csv(source)

    # create specimen id field: the third "-"-separated field of VisSpot
    df["specimen_id"] = df["VisSpot"].apply(lambda spot: spot.split("-")[2])

    # create image id field from the specimen id computed above; astype(str) keeps the
    # concatenation working if pandas parsed the id column as numbers
    df["image_id"] = df["specimen_id"] + "_" + df["id"].astype(str)

    df = df.set_index("image_id").sort_index()

    training_data = df.loc[df["specimen_id"].isin(list(train_specimens))]
    testing_data = df.loc[df["specimen_id"].isin(list(test_specimens))]

    return training_data, testing_data

# Load data


Load the protein expression data and split it into training and testing sets:

In [None]:
# Training set: specimens A1, B1, D1; testing set: specimen C1 (both indexed by image_id).
training_data, testing_data = load_data()

In [None]:
# Report how many rows ended up in each split.
print(f"Number of training samples: {len(training_data)}")
print(f"Number of testing samples: {len(testing_data)}")

# Question No. 1: (Data Analysis)


For the following questions, we use only `training_data`.

## Counting Examples:


In [None]:
# Number of training rows per specimen (non-null `id` values counted), largest first.
(
    training_data.groupby("specimen_id", as_index=False)
    .agg(n_sample=("id", "count"))
    .sort_values("n_sample", ascending=False)
)

## Protein Expression Histograms


In [None]:
# Histogram of NESTIN expression, one facet (and colour) per training specimen.
ax = sns.displot(data=training_data, x="NESTIN", col="specimen_id", hue="specimen_id")

# seaborn fills {col_name} with each facet's specimen id, so every panel gets its own title.
ax.set_titles("NESTIN expression in specimen {col_name}")
ax.set_xlabels("NESTIN expression")
ax.set_ylabels("Frequency")

In [None]:
# Histogram of cMYC expression, one facet (and colour) per training specimen.
ax = sns.displot(data=training_data, x="cMYC", col="specimen_id", hue="specimen_id")

# seaborn fills {col_name} with each facet's specimen id, so every panel gets its own title.
ax.set_titles("cMYC expression in specimen {col_name}")
ax.set_xlabels("cMYC expression")
ax.set_ylabels("Frequency")

In [None]:
# Histogram of MET expression, one facet (and colour) per training specimen.
ax = sns.displot(data=training_data, x="MET", col="specimen_id", hue="specimen_id")

# seaborn fills {col_name} with each facet's specimen id, so every panel gets its own title.
ax.set_titles("MET expression in specimen {col_name}")
ax.set_xlabels("MET expression")
ax.set_ylabels("Frequency")

From the above plots, we notice the following:

1. Different proteins have different ranges: `NESTIN` values lie in the range `[-7, 1]`, `cMYC` values in `[-10.5, 3.2]`, and `MET` values in `[-10.7, 1.58]`.

2. Across specimens, most protein expression values are centered around 0, with fewer values toward the extremes.

## Image Pre-processing


In [None]:
# Show a few randomly chosen training patches next to their H channel (HED channel 0).
np.random.seed(42)

# Sample without replacement so the same patch is never shown twice; min() keeps this
# valid even if the training set has fewer than 10 images.
random_image_ids = np.random.choice(
    training_data.index, size=min(10, len(training_data.index)), replace=False
)

for image_id in random_image_ids:
    rgb_image = read_image(image_id)

    hed_image = convert_rgb_to_hed(rgb_image)

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    fig.suptitle(image_id)

    ax[0].imshow(rgb_image)
    ax[0].set_title("RGB Image")
    ax[0].axis("off")

    ax[1].imshow(hed_image[:, :, 0], cmap="gray")
    ax[1].set_title("H Channel")
    ax[1].axis("off")

    plt.show()

## H-channel Analysis


In [None]:
# Compute the mean H intensity of every training image across all CPU cores;
# the result is a list of {"image_id", "avg_h_intensity"} dicts.
avg_h_intensity_list = Parallel(n_jobs=-1, verbose=10)(
    map(delayed(calculate_avg_h_intensity), training_data.index)
)

In [None]:
# One row per image, indexed by image_id, with its average H intensity.
avg_h_intensity_df = pd.DataFrame(avg_h_intensity_list).set_index("image_id")

In [None]:
# Attach the NESTIN target and specimen id; join aligns rows on the shared image_id index.
avg_h_intensity_df = avg_h_intensity_df.join(training_data[["NESTIN", "specimen_id"]])

In [None]:
# Preview the first rows of the H-intensity table.
avg_h_intensity_df.head()

In [None]:
# Per-image average H intensity against NESTIN expression, coloured by specimen;
# a low alpha reduces overplotting.
ax = sns.scatterplot(
    data=avg_h_intensity_df, x="avg_h_intensity", y="NESTIN", hue="specimen_id", alpha=0.2,
)

ax.set_title("Average H intensity vs NESTIN expression")
ax.set_xlabel("Average H intensity")
ax.set_ylabel("NESTIN expression")

In [None]:
# Correlation between average H intensity and NESTIN (Series.corr with its default method).
correlation = avg_h_intensity_df["avg_h_intensity"].corr(avg_h_intensity_df["NESTIN"])

In [None]:
# Report the coefficient rounded to two decimals.
print(f"The correlation between average H intensity and NESTIN expression is {correlation:.2f}")

From the scatter plot and the correlation value, we can see a positive relationship between the average intensity of the `H` channel and the expression level of `NESTIN`.

However, this correlation is weak, so the average `H` intensity alone is unlikely to capture the true relationship with the target variable.

## Performance Metrics for Prediction

# Question No. 2: (Feature Extraction and Classical Regression)


In [None]:
def calculate_image_channel_stats(image_id: str):
    """
    Compute per-channel mean and standard deviation features for one image patch.

    The H channel comes from the HED conversion (channel 0); R, G and B come from
    channels 0, 1 and 2 of the image returned by ``read_image``.

    Args:
        image_id (str): The ID of the image.

    Returns:
        dict: Keys, in this order: "image_id", then "<c>_intensity_avg" and
        "<c>_intensity_std" for c in h, r, g, b.
    """
    rgb_image = read_image(image_id)
    hed_image = convert_rgb_to_hed(rgb_image)

    # (mean key, std key, source image, channel index) in output order.
    channel_specs = (
        ("h_intensity_avg", "h_intensity_std", hed_image, 0),
        ("r_intensity_avg", "r_intensity_std", rgb_image, 0),
        ("g_intensity_avg", "g_intensity_std", rgb_image, 1),
        ("b_intensity_avg", "b_intensity_std", rgb_image, 2),
    )

    stats = {"image_id": image_id}
    for avg_key, std_key, source_image, channel in channel_specs:
        stats[avg_key] = calculate_intensity_avg(source_image, channel)
        stats[std_key] = calculate_intensity_std(source_image, channel)
    return stats

In [None]:
# Compute the H/R/G/B mean and std features for every training image across all CPU cores.
image_channels_stats_list = Parallel(n_jobs=-1, verbose=10)(
    map(delayed(calculate_image_channel_stats), training_data.index)
)

In [None]:
# One row per image, indexed by image_id, with one column per intensity statistic.
image_channels_stats_df = pd.DataFrame(image_channels_stats_list).set_index("image_id")

In [None]:
# Attach the NESTIN target and specimen id; join aligns rows on the shared image_id index.
image_channels_stats_df = image_channels_stats_df.join(training_data[["NESTIN", "specimen_id"]])

In [None]:
# Preview the first rows of the feature table.
image_channels_stats_df.head()

In [None]:
# pca.fit_transform(training_images.reshape(-1, 256 * 256))

In [None]:
# # find the number of principal components that explain 95% of the variance
# cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
# n_components = np.argmax(cumulative_variance > 0.95) + 1
# print(f"Number of components explaining 95% of the variance: {n_components}")
# # Number of components explaining 95% of the variance: 3433

# Question No. 3 (Using Convolutional Neural Networks)

In [None]:
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
# Report which device was selected.
print(f"Using {device} device")

In [None]:
class ConvNet(nn.Module):
    """Small two-stage CNN regressing a single value from a 1-channel image.

    Spatial sizes for a 128x128 input (the size CustomDataset resizes to):
    conv 5x5/pad 2 keeps 128 -> pool 64 -> conv 3x3/pad 2 grows to 66 -> pool 33,
    hence the 32 * 33 * 33 inputs of the final linear layer. Other input sizes
    would not match that layer.
    """

    def __init__(self):
        super().__init__()

        # stage 1: 1 -> 16 channels, 5x5 kernel, then 2x2 max-pool and ReLU
        self.conv_layer_1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
        )

        # stage 2: 16 -> 32 channels, 3x3 kernel with padding 2, then 2x2 max-pool and ReLU
        self.conv_layer_2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
        )

        # regression head producing one output per sample
        self.fc = nn.Linear(32 * 33 * 33, 1)

    def forward(self, x):
        """Map a (batch, 1, 128, 128) tensor to a (batch, 1) prediction."""
        features = self.conv_layer_2(self.conv_layer_1(x))
        return self.fc(torch.flatten(features, start_dim=1))

In [None]:
# Instantiate the CNN and move its parameters to the selected device.
model = ConvNet().to(device)

In [None]:
# print the number of parameters in the model
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of parameters in the model: {n_params}")

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset yielding ``(h_channel_tensor, target)`` pairs.

    Each item reads a patch with ``read_image``, converts it with ``convert_rgb_to_hed``,
    keeps channel 0 (H), scales it, adds a leading channel axis and resizes it, giving a
    ``(1, H, W)`` float32 tensor; the target is a float32 scalar tensor.
    """

    def __init__(
        self,
        image_data: pd.DataFrame,
        target_column: str = "NESTIN",
        image_size: Tuple[int, int] = (128, 128),
    ):
        """
        Args:
            image_data: Table with an ``image_id`` column and ``target_column``. Rows are
                addressed by position, so any index (not only a RangeIndex) works.
            target_column: Column holding the regression target. Defaults to "NESTIN".
            image_size: (height, width) passed to ``transforms.Resize``. ConvNet's linear
                layer is sized for the default 128x128.
        """
        self.image_data = image_data
        self.target_column = target_column
        self.transform = transforms.Resize(image_size)

    def __len__(self):
        return len(self.image_data)

    def __getitem__(self, idx):
        # DataLoader passes positions 0..len-1; use positional lookup so this does not
        # depend on the frame having a default RangeIndex (as .loc would).
        image_id = self.image_data["image_id"].iloc[idx]
        label = self.image_data[self.target_column].iloc[idx]

        rgb_image = read_image(image_id)
        hed_image = convert_rgb_to_hed(rgb_image)
        h_channel = hed_image[:, :, 0]

        # NOTE(review): the HED values from rgb2hed are not 8-bit intensities, so dividing
        # by 255 is only a constant rescale (and makes inputs very small) — confirm the
        # intended normalisation before relying on it.
        h_channel_normalized = h_channel / 255.0

        # add the channel axis expected by Conv2d: (H, W) -> (1, H, W)
        h_channel_normalized = np.expand_dims(h_channel_normalized, axis=0)

        h_channel_tensor = torch.tensor(h_channel_normalized, dtype=torch.float32)
        h_channel_tensor_resized = self.transform(h_channel_tensor)

        return (
            h_channel_tensor_resized,
            torch.tensor(label, dtype=torch.float32),
        )

In [None]:
# Move image_id from the index back into a column (giving a default RangeIndex) and keep
# only the columns CustomDataset reads: the image id and the NESTIN target.
training_data_meta = training_data.reset_index().loc[:, ["image_id", "NESTIN"]]
testing_data_meta = testing_data.reset_index().loc[:, ["image_id", "NESTIN"]]

In [None]:
# (rows, columns) of the training metadata.
training_data_meta.shape

In [None]:
# (rows, columns) of the testing metadata.
testing_data_meta.shape

In [None]:
# Wrap the metadata tables; images are read lazily in __getitem__.
training_dataset = CustomDataset(training_data_meta)
testing_dataset = CustomDataset(testing_data_meta)

In [None]:
# Batches of 100; only the training order is shuffled so evaluation is deterministic.
train_dataloader = DataLoader(training_dataset, batch_size=100, shuffle=True)
test_dataloader = DataLoader(testing_dataset, batch_size=100, shuffle=False)

In [None]:
# Mean-squared-error regression loss; plain SGD (lr=1e-3) over all model parameters.
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In [None]:
def train(
    dataloader: DataLoader,
    model: ConvNet,
    loss_fn: nn.MSELoss,
    optimizer: torch.optim.SGD,
    epoch: int,
):
    """Run one training epoch and log the average per-sample loss to TensorBoard.

    Batches are moved to the module-level ``device``; the epoch loss is written to the
    module-level ``writer`` under "Loss/Train".

    Args:
        dataloader (DataLoader): Yields ``(X, y)`` batches.
        model (ConvNet): Model being trained (switched to training mode).
        loss_fn (nn.MSELoss): Loss function; assumed to return the batch mean, as the
            ``nn.MSELoss()`` used in this notebook does by default.
        optimizer (torch.optim.SGD): Optimizer over ``model``'s parameters.
        epoch (int): Epoch number, used as the TensorBoard step.

    Returns:
        float: Average training loss per sample (NaN if the dataset is empty).
    """
    size = len(dataloader.dataset)
    model.train()

    # Sum of per-sample losses: each batch mean is weighted by its batch size so a
    # smaller final batch is not over-counted.
    train_loss = 0.0

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error. The model's output is (batch, 1) while scalar labels
        # collate to (batch,); without aligning the shapes MSELoss would broadcast to a
        # (batch, batch) comparison of every prediction with every label.
        pred = model(X)
        loss = loss_fn(pred, y.reshape_as(pred))

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss * len(X)

        if batch % 100 == 0:
            current = (batch + 1) * len(X)
            print(f"loss: {batch_loss:>7f}  [{current:>5d}/{size:>5d}]")

    avg_train_loss = train_loss / size if size else float("nan")
    writer.add_scalar("Loss/Train", avg_train_loss, epoch)
    return avg_train_loss

In [None]:
def test(dataloader, model, loss_fn, epoch):
    """Evaluate ``model`` on ``dataloader`` and log the average per-sample loss.

    Batches are moved to the module-level ``device``; the loss is written to the
    module-level ``writer`` under "Loss/Test" and printed.

    Args:
        dataloader: Yields ``(X, y)`` batches.
        model: Model to evaluate (switched to eval mode; no gradients are tracked).
        loss_fn: Loss function; assumed to return the batch mean, as the
            ``nn.MSELoss()`` used in this notebook does by default.
        epoch: Epoch number, used as the TensorBoard step.

    Returns:
        float: Average loss per sample (NaN if the dataset is empty).
    """
    size = len(dataloader.dataset)

    model.eval()

    # Sum of per-sample losses (batch mean weighted by batch size).
    test_loss = 0.0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            pred = model(X)

            # Align (batch,) targets with the (batch, 1) output to avoid broadcasting.
            test_loss += loss_fn(pred, y.reshape_as(pred)).item() * len(X)

    avg_val_loss = test_loss / size if size else float("nan")
    writer.add_scalar("Loss/Test", avg_val_loss, epoch)

    print(f"Test Error: \n Avg loss: {avg_val_loss:>8f} \n")
    return avg_val_loss

In [None]:
# Train for a fixed number of epochs, evaluating on the held-out specimen after each one.
epochs = 5
for t in range(epochs):

    print(f"Epoch {t+1}\n-------------------------------")

    # epochs are reported 1-based, which is also the TensorBoard step
    train(train_dataloader, model, loss_fn, optimizer, epoch=t + 1)
    test(test_dataloader, model, loss_fn, epoch=t + 1)

print("Done!")

# Write any pending TensorBoard events to disk.
writer.flush()