# Initialisation

In [3]:
import os
import torch
os.environ['TORCH'] = torch.__version__
from torch_geometric.data import Data
from torch_geometric.nn import GATv2Conv
from torch_geometric.transforms import NormalizeFeatures

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_geometric
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

In [4]:
Dataframe_Labels = pd.read_csv("../../../BLCA_DATA/Workspace/labels_str.csv")
Dataframe_link = pd.read_csv("../../../BLCA_DATA/Workspace/patient_norm.csv")
Dataframe_node= pd.read_csv("../../../BLCA_DATA/Workspace/node_embedding.csv")

Dataframe_Labels['class_int'], uniques = pd.factorize(Dataframe_Labels['class'])
Dataframe_Labels

Unnamed: 0,Patient,class,class_int
0,TCGA-2F-A9KO,LumP,0
1,TCGA-2F-A9KP,LumP,0
2,TCGA-2F-A9KQ,LumP,0
3,TCGA-2F-A9KR,Ba/Sq,1
4,TCGA-2F-A9KT,Ba/Sq,1
...,...,...,...
399,TCGA-ZF-AA56,Ba/Sq,1
400,TCGA-ZF-AA58,Ba/Sq,1
401,TCGA-ZF-AA5H,Ba/Sq,1
402,TCGA-ZF-AA5N,LumP,0


# GNN - Mise en place

Encodage des labels :

In [5]:
node_features = Dataframe_node.drop(columns=['Patient']).values
node_features = torch.tensor(node_features, dtype=torch.float)

x = node_features
patient_similarity = cosine_similarity(Dataframe_link.iloc[:, 1:])
similarity_threshold = 0.5  # Exemple de seuil de similarité

edge_index = []
edge_attr = []

for i in range(patient_similarity.shape[0]):
    for j in range(i + 1, patient_similarity.shape[0]):
        if patient_similarity[i, j] > similarity_threshold:
            edge_index.append([i, j])
            edge_attr.append((patient_similarity[i, j] - similarity_threshold)/(1 - similarity_threshold))
        patient_similarity[i, i] = 0

edge_index = torch.tensor(edge_index, dtype=torch.int64).t().contiguous()
edge_features = torch.tensor(Dataframe_link.drop(columns=['Patient']).values, dtype=torch.float)
edge_attr = torch.tensor(edge_attr, dtype=torch.float)

temporary_node_tab = Dataframe_Labels["class_int"].values
node_labels = torch.tensor(temporary_node_tab, dtype=torch.long)
node_labels

tensor([0, 0, 0, 1, 1, 2, 2, 3, 0, 1, 2, 0, 1, 0, 2, 1, 1, 0, 1, 1, 2, 0, 2, 4,
        1, 1, 5, 1, 1, 1, 1, 1, 1, 0, 3, 3, 0, 1, 1, 2, 1, 5, 4, 1, 0, 1, 1, 0,
        1, 1, 5, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 1, 0, 2, 1, 2, 0, 0, 1, 2, 1, 2, 0, 0, 1, 2, 4, 1, 3, 0,
        3, 4, 1, 1, 1, 4, 2, 1, 1, 3, 0, 4, 1, 2, 1, 1, 3, 0, 1, 0, 0, 2, 4, 1,
        0, 1, 1, 0, 1, 1, 1, 3, 2, 5, 0, 0, 1, 2, 0, 0, 1, 1, 0, 1, 0, 2, 0, 0,
        0, 0, 0, 0, 0, 3, 2, 0, 0, 1, 0, 1, 2, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        0, 2, 3, 3, 1, 1, 1, 3, 2, 1, 2, 2, 3, 3, 0, 1, 0, 3, 1, 1, 0, 3, 1, 3,
        3, 1, 1, 0, 1, 1, 4, 3, 3, 1, 0, 3, 3, 1, 4, 1, 2, 0, 2, 3, 0, 1, 0, 1,
        1, 0, 5, 0, 1, 1, 0, 2, 0, 1, 0, 2, 0, 1, 3, 0, 1, 0, 1, 1, 2, 1, 2, 4,
        4, 1, 1, 0, 2, 2, 1, 0, 1, 0, 1, 1, 0, 3, 2, 1, 4, 0, 2, 2, 0, 3, 2, 0,
        0, 1, 2, 0, 2, 0, 0, 2, 1, 1, 3, 4, 1, 1, 1, 0, 1, 1, 1, 1, 2, 3, 3, 0,
        0, 0, 2, 0, 2, 1, 2, 0, 2, 1, 1,

In [6]:
def set_mask(start, length):
    mask = []
    for i in range(404):
        if i < start or i >= start + length: 
            mask.append(False)
        else : 
            mask.append(True)
    return mask      

# data on which the model will be trained
train_mask = torch.tensor(set_mask(start=0, length=300), dtype=torch.bool)
val_mask = torch.tensor(set_mask(start=300, length=50), dtype=torch.bool)
test_mask = torch.tensor(set_mask(start=350, length=50), dtype=torch.bool)

train_mask  

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True, 

In [7]:
data1 = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr, y=node_labels, num_classes=6, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)
print(data1)

print(len(data1.y.tolist()))
print(data1.y.tolist())
print(data1.train_mask.tolist())
print(data1.y[data1.train_mask].tolist())
print(data1.y.tolist())
print(data1.test_mask.tolist())
print(data1.y[data1.test_mask].tolist())
print(data1.y.tolist())
print(data1.val_mask.tolist())
print(data1.y[data1.val_mask].tolist())

print(data1.x)
print(data1.edge_index)

data = data1

Data(x=[404, 825], edge_index=[2, 49308], edge_attr=[49308], y=[404], num_classes=6, train_mask=[404], val_mask=[404], test_mask=[404])
404
[0, 0, 0, 1, 1, 2, 2, 3, 0, 1, 2, 0, 1, 0, 2, 1, 1, 0, 1, 1, 2, 0, 2, 4, 1, 1, 5, 1, 1, 1, 1, 1, 1, 0, 3, 3, 0, 1, 1, 2, 1, 5, 4, 1, 0, 1, 1, 0, 1, 1, 5, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 1, 2, 0, 0, 1, 2, 1, 2, 0, 0, 1, 2, 4, 1, 3, 0, 3, 4, 1, 1, 1, 4, 2, 1, 1, 3, 0, 4, 1, 2, 1, 1, 3, 0, 1, 0, 0, 2, 4, 1, 0, 1, 1, 0, 1, 1, 1, 3, 2, 5, 0, 0, 1, 2, 0, 0, 1, 1, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 2, 0, 0, 1, 0, 1, 2, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 3, 3, 1, 1, 1, 3, 2, 1, 2, 2, 3, 3, 0, 1, 0, 3, 1, 1, 0, 3, 1, 3, 3, 1, 1, 0, 1, 1, 4, 3, 3, 1, 0, 3, 3, 1, 4, 1, 2, 0, 2, 3, 0, 1, 0, 1, 1, 0, 5, 0, 1, 1, 0, 2, 0, 1, 0, 2, 0, 1, 3, 0, 1, 0, 1, 1, 2, 1, 2, 4, 4, 1, 1, 0, 2, 2, 1, 0, 1, 0, 1, 1, 0, 3, 2, 1, 4, 0, 2, 2, 0, 3, 2, 0, 0, 1, 2, 0, 2, 0, 0, 2, 1, 1, 3, 4, 1, 1, 1, 0, 1, 1, 1, 1, 2, 3, 3

#### Edges

In [8]:
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold
import seaborn as sn

class GATv2(torch.nn.Module):
    def __init__(self, hidden_channels, heads):
        super(GATv2, self).__init__()
        torch.manual_seed(1234)
        self.conv1 = GATv2Conv(data.num_features, hidden_channels, heads=heads, edge_dim=1)
        self.conv2 = GATv2Conv(hidden_channels * heads, data.num_classes, edge_dim=1)
        
    def forward(self, x, edge_index, edge_attr):
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv1(x, edge_index, edge_attr)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index, edge_attr)
        return x

def train(model, data, optimizer, criterion):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index, data.edge_attr)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss

def test(model, data, mask):
    model.eval()
    out = model(data.x, data.edge_index, data.edge_attr)
    pred = out.argmax(dim=1)
    correct = pred[mask] == data.y[mask]
    acc = int(correct.sum()) / int(mask.sum())
    return acc, pred[mask]

def cross_validation(data, k_folds=5):
    skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=1234)
    all_test_acc = []

    for fold, (train_index, test_index) in enumerate(skf.split(data.x, data.y)):
        # Define masks
        data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        data.train_mask[train_index] = True
        data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        data.test_mask[test_index] = True

        # Initialize model, optimizer, and loss function
        model = GATv2(hidden_channels=8, heads=8)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
        criterion = torch.nn.CrossEntropyLoss()

        # Training loop
        for epoch in range(1, 1001):
            loss = train(model, data, optimizer, criterion)
            val_acc, _ = test(model, data, data.val_mask)
            test_acc, _ = test(model, data, data.test_mask)
            print(f'Fold: {fold + 1}, Epoch: {epoch:03d}, Loss: {loss:.4f}, Val {val_acc:.4f}, Test {test_acc:.4f}')

        # Evaluate on test set
        test_acc, y_pred = test(model, data, data.test_mask)
        y_true = data.y[data.test_mask]

        # Store results
        all_test_acc.append(test_acc)
        conf_matrix = confusion_matrix(y_true.cpu().tolist(), y_pred.cpu().tolist())

        # Plot confusion matrix
        classes = ('LumP', 'Ba/Sq', 'LumU', 'Stroma-rich', 'LumNS', 'NE-like')
        df_cm = pd.DataFrame(conf_matrix, index = [i for i in classes], columns = [i for i in classes])
        #plt.figure(figsize=(12,7))
        sn.heatmap(df_cm, annot=True)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('Confusion Matrix')
        plt.savefig('output.png')
        plt.show()
            
    # Calculate and print overall metrics
    mean_test_acc = np.mean(all_test_acc)
    std_test_acc = np.std(all_test_acc)
    print(f'Mean Test Accuracy: {mean_test_acc:.4f}, Std Test Accuracy: {std_test_acc:.4f}')
    
cross_validation(data, k_folds=5)

Fold: 1, Epoch: 001, Loss: 1.8446, Val 0.2200, Test 0.3210
Fold: 1, Epoch: 002, Loss: 8.8413, Val 0.2200, Test 0.3210
Fold: 1, Epoch: 003, Loss: 3.4744, Val 0.4000, Test 0.3704
Fold: 1, Epoch: 004, Loss: 3.8248, Val 0.4000, Test 0.3704
Fold: 1, Epoch: 005, Loss: 3.8794, Val 0.4000, Test 0.3704
Fold: 1, Epoch: 006, Loss: 3.7523, Val 0.4000, Test 0.3704
Fold: 1, Epoch: 007, Loss: 3.2025, Val 0.4000, Test 0.3704
Fold: 1, Epoch: 008, Loss: 3.2805, Val 0.1200, Test 0.1358
Fold: 1, Epoch: 009, Loss: 2.7699, Val 0.1200, Test 0.1358
Fold: 1, Epoch: 010, Loss: 2.4546, Val 0.1200, Test 0.1358
Fold: 1, Epoch: 011, Loss: 2.4440, Val 0.1200, Test 0.1358
Fold: 1, Epoch: 012, Loss: 1.7970, Val 0.4000, Test 0.3704
Fold: 1, Epoch: 013, Loss: 1.8054, Val 0.4000, Test 0.3704
Fold: 1, Epoch: 014, Loss: 1.6183, Val 0.4000, Test 0.3704
Fold: 1, Epoch: 015, Loss: 1.6240, Val 0.4000, Test 0.3704
Fold: 1, Epoch: 016, Loss: 1.6903, Val 0.4000, Test 0.3704
Fold: 1, Epoch: 017, Loss: 1.6699, Val 0.4000, Test 0.37

KeyboardInterrupt: 