In [None]:
# We start by installing the ProteinShake package (and its requirements) directly from GitHub.
# Uncomment the line below on a fresh environment (e.g. Colab); it only needs to run once.
# !pip install git+https://github.com/BorgwardtLab/proteinshake.git

'\nWe start by cloning the repository and installing the required packages for the ProteinShake dataset.\n'

In [None]:
# We are running on Google Colab, so the PyTorch Geometric packages are installed below.
# Installation details: https://pytorch-geometric.readthedocs.io/en/2.5.2/notes/installation.html
# Run the pip commands only once per runtime; once installed they can stay commented out.

import os
import torch

# Export the installed torch version as $TORCH so the shell commands below select the
# matching wheel index, and echo it for reference.
os.environ.update(TORCH=torch.__version__)
print(os.environ['TORCH'])

# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
# !pip install wandb

2.5.1+cu124


# Let's start by creating a PyG `Data` instance from a point cloud using k-nearest neighbors (k-NN).

In [None]:
import torch
from torch_geometric.transforms import KNNGraph
from torch_geometric.data import Data
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # Necessary for 3D plotting


def create_knn_point_cloud_data(num_points=100, k=5):
    """
    Create a random 3D point cloud and connect it with a k-nearest-neighbour graph.

    Args:
        num_points: number of points, sampled uniformly in the unit cube [0, 1)^3.
        k: number of nearest neighbours each node is connected to.

    Returns:
        A PyG ``Data`` object whose ``pos`` has shape ``[num_points, 3]`` and whose
        ``edge_index`` (shape ``[2, num_edges]``) holds the k-NN edges.
    """
    # Transform that computes data.edge_index from data.pos via k-nearest neighbours.
    knn_graph = KNNGraph(k=k)

    # Random 3D coordinates, one row per point.
    pos = torch.rand((num_points, 3))

    # A bare point cloud: positions only, no node features or edges yet.
    data = Data(pos=pos)

    # Apply the transform to populate data.edge_index.
    data = knn_graph(data)

    return data


def visualize_point_cloud(data, show_edges=False):
    """
    Draw the nodes of `data` (its `pos` coordinates) as a 3D scatter plot.

    When `show_edges` is True, every column of `data.edge_index` is additionally
    drawn as a red segment joining the positions of its two endpoints.
    """
    # matplotlib needs host-side NumPy data
    coords = data.pos.cpu().numpy()

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    xs, ys, zs = coords[:, 0], coords[:, 1], coords[:, 2]
    ax.scatter(xs, ys, zs, c='b', marker='o', s=20, alpha=0.8)

    if show_edges:
        # edge_index is [2, num_edges]; transposing yields one (source, target) pair per edge
        for src, dst in data.edge_index.t().tolist():
            segment = coords[[src, dst]]  # 2 x 3 array: the two endpoints
            ax.plot(segment[:, 0], segment[:, 1], segment[:, 2], c='r', alpha=0.3)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    if show_edges:
        title = '3D Point Cloud with k-NN Graph'
    else:
        title = '3D Point Cloud'
    plt.title(title)
    plt.show()


if __name__ == "__main__":
    # One random cloud of 100 points with 5 neighbours each, shown twice:
    # first the bare points, then the points together with their k-NN edges.
    data = create_knn_point_cloud_data(num_points=100, k=5)
    visualize_point_cloud(data)
    visualize_point_cloud(data, True)


In [None]:
import torch
import torch_geometric
import wandb
import copy
from tqdm import tqdm
import torch.nn.functional as F
from proteinshake import tasks as ps_tasks
import torch_geometric.transforms as T
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
from torch.nn import Linear
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv,global_mean_pool, ChebConv,global_add_pool
import torch.nn as nn

# Arguments for the main model parameters we want to try in this notebook.
# parse_args([]) ignores sys.argv, so the defaults apply inside Jupyter.

import argparse
import random
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser()
parser.add_argument('--hidden', type=int, default=320, help='Latent Dimension')
parser.add_argument('--seed', type=int, default=1234, help='Random Seed')
parser.add_argument('--batch_size', type=int, default=100, help='Batch Size')
parser.add_argument('--num_layers', type=int, default=7, help='Number of graph convolutional layers')
parser.add_argument('--backbone', type=str, default='SAGE', help='Use GCN as backbone- otherwise SAGE')
parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate')
args = parser.parse_args([])

# Apply the seed so weight initialisation and DataLoader shuffling are reproducible.
# torch.manual_seed seeds the CPU and all CUDA generators; some GPU scatter kernels
# may still be nondeterministic.
random.seed(args.seed)
torch.manual_seed(args.seed)


# **Load the dataset and convert the protein structures to graphs using an ε-neighborhood (radius) graph**

In [None]:
# Load the binding-site detection task and its underlying protein dataset.
# NOTE(review): the directory is named 'ec' although this is the binding-site task;
# kept unchanged so an existing download is reused.
datapath = './data/ec'
task = ps_tasks.BindingSiteDetectionTask(root=datapath)
dset = task.dataset

# We convert the protein 3D structures to ε-graphs (ε = 8 here, presumably in Ångström):
# two residues are connected when they lie within distance ε of each other.

def transform(data):
    """Attach the task target to a converted protein graph.

    Args:
        data: the ``(graph, protein_dict)`` pair produced by ProteinShake's ``.pyg()``
            conversion.

    Returns:
        The PyG graph with ``y`` set to ``task.target(protein_dict)``, which is used
        downstream as the per-node labels.
    """
    graph, protein_dict = data
    graph.y = task.target(protein_dict)
    return graph

dset = dset.to_graph(eps=8.0).pyg(transform=transform)

# **Let's plot the first sample of the dataset to get an idea about the geometry of the protein graphs we are dealing with.**

In [None]:
# First protein graph of the dataset, used by the visualizations below.
sample = dset[0]


In [None]:
sample  # show the repr of the first PyG Data object loaded above

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx

# Convert the PyG graph to networkx; edges are treated as undirected for drawing.
G = to_networkx(sample, to_undirected=True)

# Force-directed layout: a small k keeps nodes close together, and the fixed seed
# makes the picture reproducible across runs.
pos = nx.spring_layout(G, k=0.1, seed=42)

# Plot the graph
plt.figure(figsize=(8, 8))
nx.draw(G, pos, with_labels=True, node_color='lightblue', edge_color='gray', node_size=300, font_size=10)
plt.title("Visualization of the Expanded Protein Graph")
plt.show()

# **Do you notice anything specific about this sample visualization? Perhaps a constraint we discussed in the slides? 🧠**

# **Load the train/val/test splits: we can now create data loaders for the train/val/test sets provided by ProteinShake.**

In [None]:
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader

# One PyG DataLoader per ProteinShake split (each split is a list of dataset indices).
# Only the training split is shuffled; all loaders share the same batch size.
train_loader, val_loader, test_loader = (
    DataLoader(Subset(dset, split_index), batch_size=args.batch_size,
               shuffle=is_train, num_workers=0)
    for split_index, is_train in (
        (task.train_index, True),
        (task.val_index, False),
        (task.test_index, False),
    )
)

# **Plot the distribution of the number of nodes over all samples in the dataset**

In [None]:
import matplotlib.pyplot as plt

# Number of nodes in every protein graph of the dataset.
num_nodes = [graph.num_nodes for graph in dset]

# Histogram with one bin per 100 nodes, spanning 0 up to the largest graph.
bins = range(0, max(num_nodes) + 100, 100)
fig, ax = plt.subplots(figsize=(12, 6))
ax.hist(num_nodes, bins=bins, edgecolor='black', alpha=0.75)

ax.set_xlabel('Number of Nodes', fontsize=14)
ax.set_ylabel('Frequency', fontsize=14)
ax.set_title('Distribution of Number of Nodes in the dataset', fontsize=16)
# One tick per bin edge, labels rotated and right-aligned so they do not overlap.
ax.set_xticks(bins)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.grid(axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()
plt.show()

# **Below we define the Graph Neural Network (GNN) that takes the protein graphs constructed above as input and embeds their nodes in a latent space. In this space, the GNN learns the weights needed to classify each node as belonging (or NOT) to the binding site.**

In [None]:
"""
Our model starts by embedding each residue into a latent space of dimension
"hidden_dims", a hyper-parameter to tune.

Then "num_layers" graph convolutional blocks are applied to the graph, so that
each node aggregates information from progressively larger neighborhoods.

Finally, the node embeddings are passed to a linear classifier that predicts,
for every node, whether it belongs to the binding site or not.
"""

from torch_geometric.utils import to_undirected
from torch_geometric.data import Data

class ProteinModel(torch.nn.Module):
    """Node-level classifier: embedding -> num_layers x (conv -> BatchNorm -> ReLU) -> linear head.

    Args:
        hidden_dims: width of the latent node representation.
        num_layers: number of graph-convolution blocks.
        num_classes: number of output logits per node.
        backbone: 'GCN' selects GCNConv; any other value selects SAGEConv, matching
            the help text of the ``--backbone`` argument. Defaults to 'SAGE'.
            NOTE: pass ``backbone=args.backbone`` when constructing the model for the
            ``--backbone`` setting to take effect.
    """

    def __init__(self, hidden_dims, num_layers, num_classes, backbone='SAGE'):
        super().__init__()

        # Lookup table with 20 entries (one per residue index) -> hidden_dims vector.
        self.x_embedding = nn.Embedding(20, hidden_dims)

        conv_cls = GCNConv if backbone == 'GCN' else SAGEConv
        self.convs = nn.ModuleList(conv_cls(hidden_dims, hidden_dims) for _ in range(num_layers))
        self.norms = nn.ModuleList(nn.BatchNorm1d(hidden_dims) for _ in range(num_layers))
        self.classifier = nn.Linear(hidden_dims, num_classes)

    def forward(self, x, edge_index, batch, device):
        """Return per-node class logits of shape [num_nodes, num_classes].

        ``batch`` and ``device`` are accepted for call-site compatibility but are not
        used: the task is node-level, so no graph pooling is needed.
        """
        # Initial embedding into the latent space.
        x = self.x_embedding(x)
        for conv, norm in zip(self.convs, self.norms):
            x = F.relu(norm(conv(x, edge_index)))
        return self.classifier(x)


In [None]:
# Build the model from the arguments defined above (latent size, number of layers)
# and the number of output classes reported by the task.
num_classes = task.num_classes
model = ProteinModel(args.hidden, args.num_layers, num_classes)

# Optimizer and loss used by the train / eval functions below.
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
criterion = nn.CrossEntropyLoss()

# Prefer the currently selected CUDA device; fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda', torch.cuda.current_device())
else:
    device = torch.device('cpu')


def retrieve_sets(flattened_list, length_indices):
    """Cut a flat tensor into consecutive chunks and return them as NumPy arrays.

    Args:
        flattened_list: 1-D tensor holding the concatenated values of several samples.
        length_indices: size of each chunk, in order.

    Returns:
        A list with one detached CPU NumPy array per entry of ``length_indices``.
    """
    chunks = []
    offset = 0
    for size in length_indices:
        piece = flattened_list[offset:offset + size]
        chunks.append(piece.detach().cpu().numpy())
        offset += size
    return chunks


def train_epoch(model):
    """Run one optimisation pass over ``train_loader`` and return the epoch loss.

    Uses the notebook-level ``train_loader``, ``optimizer``, ``criterion`` and ``device``.
    The returned value is each batch's mean loss weighted by the number of graphs in
    that batch, divided by the number of training graphs.
    """
    model.train()
    # Follow the configured device instead of hard-coding CUDA, so this also runs on CPU.
    model.to(device)
    running_loss = 0.
    for batch in train_loader:
        # batch.y holds one label list per graph, so its length is the graph count.
        num_graphs = len(batch.y)
        batch = batch.to(device)

        optimizer.zero_grad()

        y_hat = model(batch.x, batch.edge_index, batch.batch, device)

        # Concatenate the per-graph label lists into one per-node class-index vector
        # aligned with the rows of y_hat (CrossEntropyLoss expects int64 targets).
        target = torch.tensor(
            [label for labels in batch.y for label in labels],
            dtype=torch.long,
            device=device,
        )

        loss = criterion(y_hat, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * num_graphs

    return running_loss / len(train_loader.dataset)


# ProteinShake provides an evaluation function for each task: `task.evaluate(y_true, y_pred)`.

@torch.no_grad()
def eval_epoch(model, loader):
    """Return the mean per-protein MCC of ``model`` over every graph in ``loader``.

    Node predictions are the arg-max class, split back into one array per protein
    with ``retrieve_sets`` and scored with ``task.evaluate(...)['mcc']``. Every
    protein contributes equally, regardless of how the loader batches them, so a
    smaller final batch is no longer over-weighted.

    Raises:
        ValueError: if ``loader`` yields no proteins.
    """
    model.eval()

    mcc_scores = []
    for batch in loader:
        batch = batch.to(device)

        y_hat = model(batch.x, batch.edge_index, batch.batch, device)
        y_pred = y_hat.argmax(-1)

        # Split the flat node predictions back into one array per protein.
        lengths = [len(labels) for labels in batch.y]
        predictions = retrieve_sets(y_pred, lengths)

        for labels, pred in zip(batch.y, predictions):
            mcc_scores.append(task.evaluate(labels, pred)['mcc'])

    if not mcc_scores:
        raise ValueError("eval_epoch received a loader with no proteins")
    return sum(mcc_scores) / len(mcc_scores)



In [None]:
args.backbone  # used as the W&B run name below; NOTE(review): the model cell above does not pass it to ProteinModel — confirm the intended backbone is used

'SAGE'

In [None]:
# ---------------------------------------------------------------- Training

model.to(device)

epochs = 100  # more epochs may improve performance further, at a proportional time cost

config = dict(
    Changes="None",
    backbone=args.backbone,
    hidden_dim=args.hidden,
    batch_size=args.batch_size,
    learning_rate=args.lr,
    seed=args.seed,
    layers=args.num_layers,
)

import wandb

wandb.init(
    project="Winter School Tutorial",
    name=args.backbone,
    config=config,
)

# Start below any reachable score and snapshot the initial weights. MCC can be <= 0
# (a constant predictor typically scores 0). With a 0.0 start, best_weights could
# remain undefined and load_state_dict would raise NameError.
best_val_score = float('-inf')
best_weights = copy.deepcopy(model.state_dict())
pbar = tqdm(range(epochs))
for epoch in pbar:
    train_loss = train_epoch(model)
    val_score = eval_epoch(model, val_loader)  # mean per-protein MCC

    # A single log call keeps all metrics of an epoch on the same W&B step.
    # The "Val Acc" key is kept for comparability with earlier runs; its value is MCC.
    wandb.log({"Val Acc": val_score, "Train Loss": train_loss, "Epoch": epoch})
    pbar.set_postfix({'train_loss': train_loss, 'val_mcc': val_score})

    if val_score > best_val_score:
        best_val_score = val_score
        best_weights = copy.deepcopy(model.state_dict())

# Restore the weights with the best validation score before testing.
model.load_state_dict(best_weights)

# ---------------------------------------------------------------- Testing the trained model

test_scores = eval_epoch(model, test_loader)
wandb.log({"Test": test_scores})
print(test_scores)

# Close the run so re-running this cell starts a fresh W&B run.
wandb.finish()


[34m[1mwandb[0m: Currently logged in as: [33mahariri[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


100%|██████████| 100/100 [32:22<00:00, 19.42s/it, train_loss=0.121, val_acc=0.746]


0.703429291462553


# **You have trained your first GNN model on proteins, congrats! Now that this task is done, re-run the cells above with a different GNN backbone. For instance, if you started by training a GCN backbone, switch to a SAGE backbone and compare the plots on Weights & Biases (W&B).**

# **Now try increasing the number of layers for each backbone. How do the results change?**