# HW4 - Graph Neural Networks

## Course Name: Analysis of Medical Images
#### Lecturers: Dr. Rohban
#### Name: 
#### Student ID: 

---

**Contact**: Ask your questions on Quera

---

### Instructions:
- Complete all exercises presented in this notebook.
- Ensure you run each cell after you've entered your solution.




# Introduction

This notebook is designed to introduce you to the basics of Graph Neural Networks (GNNs) using a dataset of pathology images. The steps you will follow are:

- Download the Dataset: Access and download a set of pathology images.
- Data Preprocessing: Prepare the data for processing.
- Nuclei Extraction: Identify the location and dimensions of nuclei in the images.
- GNN Model Implementation: Develop a GNN for classification.
- Experiments: Conduct experiments to test the effectiveness of your model.

You are encouraged to use relevant libraries, but please provide brief explanations for your choices and describe how they work.

## Requirements

In [1]:
!pip install histomicstk --find-links https://girder.github.io/large_image_wheels
!pip install ogb
!pip install gdown
!pip install torchvision
!pip install torch_geometric

## Imports

In [None]:
import histomicstk as htk  # HistomicsTK for pathology image analysis
import numpy as np
import scipy as sp
import skimage.io
import skimage.measure
import skimage.color
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import pandas as pd
import torch.nn as nn

import copy

from torchvision.models import resnet18, ResNet18_Weights
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.model_selection import train_test_split

## Dataset

We will use the BRACS test dataset, a collection of Hematoxylin and Eosin (H&E) stained histopathological images for breast tumor classification. The dataset includes over 4,000 tumor regions-of-interest labeled in 7 categories. You can download it using the following commands. Alternatively, you may use any suitable dataset, but ensure it is accessible during evaluation.

In [None]:
!gdown "https://drive.google.com/uc?export=download&id=1Hk7rtmyR65y4ZXUt9fNN_qlv-A8JyJVd" -O data.zip
!unzip data.zip
!rm -rf data.zip

## Preprocessing and Graph Model Extraction (15 points)

In this section, you will preprocess the images to facilitate model training. Then, identify the position and dimensions of cells using either classic methods or neural networks. Finally, use a neural network (e.g., ResNet18) to embed cell information for node embedding in the GNN.


In [None]:
def images_preprocess(images):
    # Add your preprocessing steps here
    pass

def nuclei_extracting(images):
    # Implement nuclei extraction here
    pass
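
Once nuclei centroids are available, one common way to turn them into a graph is to connect each nucleus to its nearest neighbours. The block below is only a sketch under that assumption: the helper name `knn_edge_index`, the choice `k=2`, and the toy centroids are hypothetical, and the notebook's imported `pairwise_distances` could replace the manual distance computation.

```python
import numpy as np

def knn_edge_index(centroids, k=5):
    """Return a (2, N*k) array of directed edges from each nucleus to its k nearest neighbours."""
    centroids = np.asarray(centroids, dtype=float)
    n = len(centroids)
    k = max(0, min(k, n - 1))  # cannot have more neighbours than other nuclei
    # Pairwise Euclidean distances between all centroids (N x N)
    diff = centroids[:, None, :] - centroids[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)  # a nucleus is not its own neighbour
    # Column indices of the k closest nuclei for every row
    neighbours = np.argsort(dist, axis=1)[:, :k]
    sources = np.repeat(np.arange(n), k)
    targets = neighbours.reshape(-1)
    return np.stack([sources, targets])

# Toy (x, y) centroids: three nuclei close together and one far away
centroids = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [10.0, 10.0]])
edge_index = knn_edge_index(centroids, k=2)
print(edge_index.shape)  # (2, 8)
```

The resulting array follows the `[source; target]` layout that `edge_index` uses in PyTorch Geometric. A distance threshold instead of a fixed `k` is an equally reasonable design choice when nuclei density varies strongly across images.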


## Visualization (3 points)

Visualize the results to check accuracy and quality. For example, display a random pathology image alongside the graph of adjacent nuclei.


In [None]:
# Visualization code here

## Dataset Object

Create a dataset class suitable for your data. You may use geometric_Dataset, Dataset, or any other method. Please explain your implementation logic if you choose an alternative approach.

In [None]:
import torch

from torch.utils.data import Dataset
from torch_geometric.data import Dataset as geometric_Dataset
from torch_geometric.data import Data


class PathologyDataset(Dataset):
    def __init__(self, directory):
        pass

    def __len__(self):
        pass

    def __getitem__(self, idx):
        pass

    def get_idx_split(self):

        train_indices = None
        test_indices = None
        val_indices = None

        return {'train': train_indices, 'val': val_indices, 'test': test_indices}
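
The `get_idx_split` template leaves the index arrays empty. As a rough sketch, assuming no predefined split is shipped with your data, a seeded random permutation can produce disjoint train/validation/test indices. The function name `split_indices` and the 70/15/15 fractions are hypothetical choices, not part of the assignment.

```python
import numpy as np

def split_indices(num_graphs, val_frac=0.15, test_frac=0.15, seed=0):
    """Randomly partition range(num_graphs) into train/val/test index arrays."""
    rng = np.random.default_rng(seed)  # fixed seed keeps the split reproducible
    perm = rng.permutation(num_graphs)
    n_test = int(round(test_frac * num_graphs))
    n_val = int(round(val_frac * num_graphs))
    test_idx = perm[:n_test]
    val_idx = perm[n_test:n_test + n_val]
    train_idx = perm[n_test + n_val:]
    return {'train': train_idx, 'val': val_idx, 'test': test_idx}

split = split_indices(100)
print({name: len(idx) for name, idx in split.items()})
```

If the classes are imbalanced, a stratified split (for instance `train_test_split` with its `stratify` argument, which the notebook already imports) may be preferable to a purely random one.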

## GCN Model Implementation (25 points)

Implement the GCN model as per the provided architecture. Ensure each step is well-documented. [Please follow the figure below to implement your `forward` function.]

![test](https://drive.google.com/uc?id=128AuYAXNXGg7PIhJJ7e420DoPWKb-RtL)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers,
                 dropout, return_embeds=False):
        # TODO: Implement a function that initializes self.convs,
        # self.bns, and self.softmax.

        super(GCN, self).__init__()

        # A list of GCNConv layers
        self.convs = None

        # A list of 1D batch normalization layers
        self.bns = None

        # The log softmax layer
        self.softmax = None

        ############# Your code here ############
        ## Note:
        ## 1. You should use torch.nn.ModuleList for self.convs and self.bns
        ## 2. self.convs has num_layers GCNConv layers
        ## 3. self.bns has num_layers - 1 BatchNorm1d layers
        ## 4. You should use torch.nn.LogSoftmax for self.softmax
        ## 5. The parameters you can set for GCNConv include 'in_channels' and
        ## 'out_channels'. For more information please refer to the documentation:
        ## https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv
        ## 6. The only parameter you need to set for BatchNorm1d is 'num_features'
        ## For more information please refer to the documentation:
        ## https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html

        #########################################

        # Probability of an element getting zeroed
        self.dropout = dropout

        # Skip classification layer and return node embeddings
        self.return_embeds = return_embeds

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t):
        # TODO: Implement a function that takes the feature tensor x and
        # edge_index tensor adj_t and returns the output tensor as
        # shown in the figure.

        out = None

        ############# Your code here ############
        ## Note:
        ## 1. Construct the network as shown in the figure
        ## 2. torch.nn.functional.relu and torch.nn.functional.dropout are useful
        ## For more information please refer to the documentation:
        ## https://pytorch.org/docs/stable/nn.functional.html
        ## 3. Don't forget to set F.dropout training to self.training
        ## 4. If return_embeds is True, then skip the last softmax layer

        #########################################

        return out
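
To build intuition for what a single graph convolution does, the following NumPy sketch applies the standard GCN propagation rule $\hat{D}^{-1/2}(A + I)\hat{D}^{-1/2} X W$ on a tiny graph. It is an illustration only, not PyG's implementation: it assumes an undirected graph (every edge listed in both directions), omits the bias term, and the name `gcn_layer` is hypothetical.

```python
import numpy as np

def gcn_layer(x, edge_index, weight):
    """One GCN propagation step on a dense adjacency matrix (no bias)."""
    n = x.shape[0]
    adj = np.zeros((n, n))
    adj[edge_index[1], edge_index[0]] = 1.0   # messages flow source -> target
    adj_hat = adj + np.eye(n)                 # add self-loops
    deg = adj_hat.sum(axis=1)                 # degree including the self-loop
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Symmetric normalization: D^{-1/2} (A + I) D^{-1/2}
    norm_adj = d_inv_sqrt[:, None] * adj_hat * d_inv_sqrt[None, :]
    return norm_adj @ x @ weight

# Path graph 0 - 1 - 2, each undirected edge stored in both directions
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
x = np.eye(3)        # one-hot node features
weight = np.eye(3)   # identity weights expose the normalized adjacency itself
out = gcn_layer(x, edge_index, weight)
print(out)
```

With identity features and weights, each row of `out` shows how strongly a node mixes its own features with those of its neighbours.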

## Graph Prediction Model (15 points)

Implement a GCN Graph Prediction model using the node embeddings from the GCN model and global pooling to create graph-level embeddings.

### Graph Mini-Batching
Before diving into the actual model, we introduce the concept of mini-batching with graphs. In order to parallelize the processing of a mini-batch of graphs, PyG combines the graphs into a single disconnected graph data object (*torch_geometric.data.Batch*). *torch_geometric.data.Batch* inherits from *torch_geometric.data.Data* (introduced earlier) and contains an additional attribute called `batch`.

The `batch` attribute is a vector mapping each node to the index of its corresponding graph within the mini-batch:

    batch = [0, ..., 0, 1, ..., n - 2, n - 1, ..., n - 1]

This attribute records which graph each node belongs to and can be used, for example, to average the node embeddings of each graph individually to compute graph-level embeddings.
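
As a small illustration of the `batch` vector, the NumPy sketch below builds it for two toy graphs and uses it to average node embeddings per graph. This mirrors the idea behind mean pooling but is not PyG code; the graph sizes and embedding values are made up for the example.

```python
import numpy as np

# Two toy graphs with 3 and 2 nodes; every node carries a 2-d embedding
graph_sizes = [3, 2]
node_embeddings = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],   # graph 0
                            [10.0, 0.0], [20.0, 0.0]])            # graph 1

# batch[i] is the index of the graph that node i belongs to
batch = np.repeat(np.arange(len(graph_sizes)), graph_sizes)
print(batch)  # [0 0 0 1 1]

# Mean pooling: sum the rows of each graph, then divide by its node count
num_graphs = len(graph_sizes)
sums = np.zeros((num_graphs, node_embeddings.shape[1]))
np.add.at(sums, batch, node_embeddings)
counts = np.bincount(batch, minlength=num_graphs)[:, None]
graph_embeddings = sums / counts
print(graph_embeddings)  # graph 0 -> [3, 4], graph 1 -> [15, 0]
```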


### Implementation
Now, we have all of the tools to implement a GCN Graph Prediction model!  

We will reuse the existing GCN model to generate `node_embeddings` and then use `Global Pooling` over the nodes to create graph-level embeddings that can be used to predict properties for each graph. Remember that the `batch` attribute will be essential for performing Global Pooling over our mini-batch of graphs.


In [None]:
from ogb.graphproppred.mol_encoder import AtomEncoder
from torch_geometric.nn import global_add_pool, global_mean_pool

### GCN to predict graph property
class GCN_Graph(torch.nn.Module):
    def __init__(self, hidden_dim, output_dim, num_layers, dropout):
        super(GCN_Graph, self).__init__()

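        # Load encoders for Atoms in molecule graphs.
        # NOTE: AtomEncoder expects OGB's integer-coded atom features; for
        # continuous nuclei features (e.g., ResNet18 embeddings), swap in a
        # suitable projection such as torch.nn.Linear.
        self.node_encoder = AtomEncoder(hidden_dim)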

        # Node embedding model
        # Note that the input_dim and output_dim are set to hidden_dim
        self.gnn_node = GCN(hidden_dim, hidden_dim,
            hidden_dim, num_layers, dropout, return_embeds=True)

        self.pool = None

        ############# Your code here ############
        ## Note:
        ## 1. Initialize self.pool as a global mean pooling layer
        ## For more information please refer to the documentation:
        ## https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#global-pooling-layers

        #########################################

        # Output layer
        self.linear = torch.nn.Linear(hidden_dim, output_dim)


    def reset_parameters(self):
        self.gnn_node.reset_parameters()
        self.linear.reset_parameters()

    def forward(self, batched_data):
        # TODO: Implement a function that takes as input a
        # mini-batch of graphs (torch_geometric.data.Batch) and
        # returns the predicted graph property for each graph.
        #
        # NOTE: Since we are predicting graph level properties,
        # your output will be a tensor with dimension equaling
        # the number of graphs in the mini-batch


        # Extract important attributes of our mini-batch
        x, edge_index, batch = batched_data.x, batched_data.edge_index, batched_data.batch
        embed = self.node_encoder(x)

        out = None

        ############# Your code here ############
        ## Note:
        ## 1. Construct node embeddings using existing GCN model
        ## 2. Use the global pooling layer to aggregate features for each individual graph
        ## For more information please refer to the documentation:
        ## https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#global-pooling-layers
        ## 3. Use a linear layer to predict each graph's property

        #########################################

        return out

In [None]:
from tqdm import tqdm

def train(model, device, data_loader, optimizer, loss_fn):
    # TODO: Implement a function that trains your model by
    # using the given optimizer and loss_fn.
    model.train()
    loss = 0

    for step, batch in enumerate(tqdm(data_loader, desc="Iteration")):
        batch = batch.to(device)

        # Skip degenerate batches: a single node or a single graph
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            pass
        else:
            ## ignore nan targets (unlabeled) when computing training loss.
            is_labeled = batch.y == batch.y

            ############# Your code here ############
            ## Note:
            ## 1. Zero grad the optimizer
            ## 2. Feed the data into the model
            ## 3. Use `is_labeled` mask to filter output and labels
            ## 4. You may need to change the type of label to torch.float32
            ## 5. Feed the output and label to the loss_fn
            ## (~3 lines of code)

            #########################################

            loss.backward()
            optimizer.step()

    # float() handles both a loss tensor and the initial int 0
    # (the latter occurs if every batch was skipped)
    return float(loss)
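
The `is_labeled = batch.y == batch.y` line relies on the fact that NaN never compares equal to itself, so the comparison is `False` exactly where a label is missing. A minimal NumPy illustration of the same trick, with made-up label values:

```python
import numpy as np

# Labels for three graphs; the second one is unlabeled (NaN)
y = np.array([1.0, np.nan, 0.0])

# NaN != NaN, so comparing an array with itself flags the valid entries
is_labeled = y == y
print(is_labeled)      # [ True False  True]
print(y[is_labeled])   # [1. 0.]
```

If every graph in your dataset has a label, the mask is simply all `True` and filtering with it changes nothing.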

In [None]:
from tqdm import tqdm

# The evaluation function
def eval(model, device, loader, evaluator, save_model_results=False, save_file=None):
    model.eval()
    y_true = []
    y_pred = []

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)

        if batch.x.shape[0] == 1:
            pass
        else:
            with torch.no_grad():
                pred = model(batch)

            y_true.append(batch.y.view(pred.shape).detach().cpu())
            y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim = 0).numpy()
    y_pred = torch.cat(y_pred, dim = 0).numpy()

    input_dict = {"y_true": y_true, "y_pred": y_pred}

    if save_model_results:
        print("Saving Model Predictions")

        # Create a pandas dataframe with two columns
        # y_pred | y_true
        data = {}
        data['y_pred'] = y_pred.reshape(-1)
        data['y_true'] = y_true.reshape(-1)

        df = pd.DataFrame(data=data)
        # Save to csv
        df.to_csv('our_graph_' + save_file + '.csv', sep=',', index=False)

    return evaluator.eval(input_dict)

## Training and Testing

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
args = {'device': device, 'num_layers': 5, 'hidden_dim': 256, 'dropout': 0.5, 'lr': 0.001, 'epochs': 30}

In [None]:
evaluator = None
dataset = None

# Ex. (assuming `from torch_geometric.loader import DataLoader`):
# DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, num_workers=0)
train_loader = None
test_loader = None
val_loader = None

In [None]:
model = GCN_Graph(
    args['hidden_dim'], dataset.num_tasks, args['num_layers'], args['dropout']).to(device)

model.reset_parameters()

optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
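# NOTE: BCEWithLogitsLoss suits binary or multi-label targets; for single-label
# multi-class data (such as the 7 BRACS categories), torch.nn.CrossEntropyLoss
# is the usual choice. Adapt the loss and label dtype to your dataset.
loss_fn = torch.nn.BCEWithLogitsLoss()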

best_model = None
best_valid_acc = 0

for epoch in range(1, 1 + args["epochs"]):

    print('Training...')
    loss = train(model, device, train_loader, optimizer, loss_fn)

    print('Evaluating...')
    train_result = eval(model, device, train_loader, evaluator)
    val_result = eval(model, device, val_loader, evaluator)
    test_result = eval(model, device, test_loader, evaluator)

    train_acc, valid_acc, test_acc = train_result[dataset.eval_metric], val_result[dataset.eval_metric], test_result[dataset.eval_metric]
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        best_model = copy.deepcopy(model)

    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {100 * train_acc:.2f}%, Valid: {100 * valid_acc:.2f}%, Test: {100 * test_acc:.2f}%')

## Other Experiments (5 points)

Experiment with different global pooling layers in PyTorch Geometric and observe how model performance changes. [Try two global pooling layers other than mean pooling.]
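
Two natural candidates are sum and max pooling (PyG provides `global_add_pool` and `global_max_pool`). The NumPy sketch below only illustrates how these readouts differ on toy data; the helper name `pool` and the embedding values are hypothetical, and it is not a substitute for the PyG layers in your experiments.

```python
import numpy as np

def pool(node_embeddings, batch, num_graphs, mode):
    """Aggregate node embeddings per graph using sum ('add') or element-wise max ('max')."""
    dim = node_embeddings.shape[1]
    if mode == "add":
        out = np.zeros((num_graphs, dim))
        np.add.at(out, batch, node_embeddings)      # scatter-add rows into their graph
    elif mode == "max":
        out = np.full((num_graphs, dim), -np.inf)
        np.maximum.at(out, batch, node_embeddings)  # scatter-max rows into their graph
    else:
        raise ValueError(f"unknown pooling mode: {mode}")
    return out

node_embeddings = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],   # graph 0
                            [10.0, 0.0], [20.0, 0.0]])            # graph 1
batch = np.array([0, 0, 0, 1, 1])

sum_pooled = pool(node_embeddings, batch, 2, "add")
max_pooled = pool(node_embeddings, batch, 2, "max")
print(sum_pooled)  # graph 0 -> [9, 12], graph 1 -> [30, 0]
print(max_pooled)  # graph 0 -> [5, 6],  graph 1 -> [20, 0]
```

Sum pooling grows with the number of nodes, so it retains information about how many nuclei a graph contains, whereas max pooling keeps only the strongest activation per feature. Which behaviour helps on pathology graphs is an empirical question for your experiments.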

In [None]:
# Experiment_1

In [None]:
# Experiment_2

Important Reminder: The functions provided in this notebook serve as initial templates and examples. You are encouraged to modify and adapt them as needed to suit your specific requirements or to improve performance. However, while making these changes, please endeavor to preserve the overall structure and objectives of the exercise.
