In [1]:
!pip install torch-scatter
!pip install torch-geometric
!pip install plotnine




In [None]:
import random
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, InMemoryDataset, Dataset
from torch_geometric.nn import GATConv, GATv2Conv, DMoNPooling
from torch_geometric.utils import to_dense_adj, to_dense_batch
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder

import matplotlib.pyplot as plt
import pandas as pd

from plotnine import ggplot

In [3]:
# load data
metacells = pd.read_csv("./oligo-SCZ-metacellExpr.csv", index_col=0)
metadata = pd.read_csv("./oligo-SCZ-meta.csv", index_col=0)
tom = pd.read_csv("./oligo-SCZ-tom.csv", index_col=0)

In [4]:
metadata.iloc[1:5000:50,:]["disorder"]

Oligo#CON1_2             Control
Oligo#CON10_5            Control
Oligo#CON10_55           Control
Oligo#CON10_105          Control
Oligo#CON11_41           Control
                       ...      
Oligo#SZ15_10      Schizophrenia
Oligo#SZ16_47      Schizophrenia
Oligo#SZ16_97      Schizophrenia
Oligo#SZ16_147     Schizophrenia
Oligo#SZ18_1       Schizophrenia
Name: disorder, Length: 100, dtype: object

In [5]:
metacells.iloc[:,0].values

array([0.        , 0.        , 1.28492008, ..., 3.60950374, 1.82917698,
       0.83603014])

In [6]:
# hyperparameters
TOM_THRESHOLD = 0.025  # value below which we zero out the similarity that will be used as attention priors

In [7]:
# create graph dataset
edges = []
for i in tqdm(range(len(tom.columns))):
    for j in range(i):
        if tom.iloc[i,j] > TOM_THRESHOLD:
            edges.extend([[i,j],[j,i]])
edges = torch.tensor(edges)

100%|██████████| 6679/6679 [05:47<00:00, 19.22it/s] 


In [8]:
metacellsReduced = metacells.iloc[:,0:5000:50]
metadataReduced = metadata.iloc[0:5000:50,:]

In [9]:
# # NOT USING CURRENTLY

# class GeneGraphDataset(InMemoryDataset):
#     def __init__(self, root, edge_index, expr_mat, meta, transform=None, pre_transform=None):
#         self.edge_index = edge_index
#         self.num_graphs = expr_mat.shape[1]
#         self.expr_mat = expr_mat
#         self.num_classes = 2
#         self.y = meta["disorder"]
#         super().__init__(root, transform, pre_transform)
#         self.load("./graphData/processed/combined.pt")

#     @property
#     def raw_file_names(self):
#         pass

#     @property
#     def processed_file_names(self):
#         return [f"./graphData/processed/{self.expr_mat.columns[i]}-graph.pt" for i in range(self.num_graphs)]

#     def process(self):
#         data_list = []
#         for i in tqdm(range(self.num_graphs)):
#             node_features = torch.tensor(self.expr_mat.iloc[:,i].values)
#             data = Data(x=node_features, edge_index=self.edge_index)
#             data_list.append(data)
#             torch.save(data, f"./graphData/processed/{self.expr_mat.columns[i]}-graph.pt")

#         data, slices = self.collate(data_list)
#         torch.save((data, slices), f"./graphData/processed/combined.pt")

#     def get(self, idx=None, sample=None):
#         if idx is None:
#             data = torch.load(f"./graphData/processed/{sample}-graph.pt", weights_only=False)
#         else:
#             data = torch.load(f"./graphData/processed/{self.expr_mat.columns[idx]}-graph.pt", weights_only=False)
#         return data


# dataset = GeneGraphDataset(root='./graphData', edge_index=edges.t().contiguous(), expr_mat=metacellsReduced, meta=metadataReduced)

In [None]:
le = LabelEncoder()
metadataReduced.loc[:,"disorder_encoded"] = le.fit_transform(metadataReduced["disorder"])

dataset = []
nGenes = metacellsReduced.shape[0]
edgesT = edges.t().contiguous()
num_samples = min(100, metacellsReduced.shape[1], metadataReduced.shape[0])  # Ensure we don't exceed available data
for i in range(num_samples):
    dataset.append(
    Data(
        x=torch.tensor(metacellsReduced.iloc[:, i].values, dtype=torch.float32).reshape((nGenes, 1)),
        edge_index=edgesT,
        y=torch.tensor(metadataReduced["disorder_encoded"].iloc[i], dtype=torch.long)
    )
)
random.shuffle(dataset)
loader = DataLoader(dataset, batch_size=4)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


In [None]:
def train_test(metacells, metadata, edges, idx_train, idx_test):
    train = []
    test = []
    nGenes = metacells.shape[0]
    edgesT = edges.t().contiguous()
    for i in idx_train:
        train.append(Data(
            x=torch.tensor(metacells.iloc[:,i].values, dtype=torch.float32).reshape((nGenes,1)),
            edge_index=edgesT,
            y=torch.tensor(metadata["disorder_encoded"].iloc[i], dtype=torch.long)
        ))
    for i in idx_test:
        test.append(Data(
            x=torch.tensor(metacells.iloc[:,i].values, dtype=torch.float32).reshape((nGenes,1)),
            edge_index=edgesT,
            y=torch.tensor(metadata["disorder_encoded"].iloc[i], dtype=torch.long)
        )) 
    random.shuffle(train)
    random.shuffle(test)
    return DataLoader(train, batch_size=4), DataLoader(test, batch_size=4)

In [None]:
metadata.loc[:,"disorder_encoded"] = le.fit_transform(metadata["disorder"])

In [12]:
# define model
class GAT(torch.nn.Module):
    def __init__(self):
        super(GAT, self).__init__()
        self.num_features = 1
        self.hidden_layers = 1
        self.k = 128
        self.in_heads = 8
        self.mid_heads = 2
        self.out_heads = 1

        self.conv1 = GATv2Conv(self.num_features, self.hidden_layers, heads=self.in_heads, dropout=0.4)
        self.bn1 = torch.nn.BatchNorm1d(self.hidden_layers * self.in_heads)
        self.conv2 = GATv2Conv(self.hidden_layers*self.in_heads, self.hidden_layers, heads=self.mid_heads, dropout=0.4)
        self.bn2 = torch.nn.BatchNorm1d(self.hidden_layers * self.mid_heads)
        self.conv3 = GATv2Conv(self.hidden_layers*self.mid_heads, self.hidden_layers, heads=self.out_heads, dropout=0.4)
        self.bn3 = torch.nn.BatchNorm1d(self.hidden_layers)

        self.pool = DMoNPooling([self.hidden_layers, self.hidden_layers], k=self.k, dropout=0.4)

        self.proj = torch.nn.Linear(self.hidden_layers, 1)

        self.classifier = torch.nn.Linear(self.k, 2)

    def _dense_adj(self, edge_index, alpha, batch):
        alpha = alpha.mean(dim=1)
        adj = to_dense_adj(edge_index, batch=batch, edge_attr=alpha)  # shape: [B, N, N]
        deg = adj.sum(-1).clamp(min=1e-12)
        deg_inv_sqrt = deg.pow(-0.5)
        norm_adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
        return norm_adj


    def forward(self, data):
        x, ei, batch = data.x, data.edge_index, data.batch
        # print(x)
        # print(ei)
        # print(batch)

        x = F.elu(self.conv1(x, ei))
        x = F.elu(self.conv2(x, ei))                      
        x, (ei3, alpha3) = self.conv3(x, ei, return_attention_weights=True) 
        x = F.elu(self.bn3(x))

        # convert to dense batch and corresponding mask
        x_dense, mask = to_dense_batch(x, batch)

        # build dense adjacency matrix
        alpha = alpha3.mean(dim=1)
        adj_dense = to_dense_adj(ei3, batch=batch, edge_attr=alpha)
        deg = adj_dense.sum(-1).clamp(min=1e-12)
        deg_inv_sqrt = deg.pow(-0.5)
        adj_dense = deg_inv_sqrt.unsqueeze(-1) * adj_dense * deg_inv_sqrt.unsqueeze(-2)

        S, x, adj, mod, ort, clu = self.pool(x_dense, adj_dense, mask)

        # read‑out
        x = self.proj(x).squeeze(-1)
        logits = self.classifier(x)

        pool_reg = mod + clu + 0.1 * ort
        return logits, pool_reg
        

In [None]:
def train_epoch(loader, model, optimizer, device):
    model.train()
    total_loss = 0
    iter_loss = []
    train_iter = tqdm(loader)
    for data in train_iter:
        data = data.to(device)
        # data.y = data.y.long()
        optimizer.zero_grad()
        out, pool_reg = model(data)
        loss = F.cross_entropy(out, data.y.squeeze()) + pool_reg.mean()
        loss.backward()
        optimizer.step()
        train_iter.set_description(f"(loss {loss.item()})")
        iter_loss.append(loss.item())
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset), iter_loss

def test_epoch(loader, model, device):
    model.eval()
    total_loss = 0
    correct = 0
    correct_readouts = []
    test_iter = tqdm(loader)
    for data in test_iter:
        data = data.to(device)
        # data.y = data.y.long()
        out, pool_reg = model(data)
        loss = F.cross_entropy(out, data.y.squeeze()) + pool_reg.mean()
        pred = out.argmax(dim=1)
        correct += (pred == data.y).sum().item()
        correct_readout = ["SCZ" if item == 1 else "CON" for item in data.y]
        for i, item in enumerate(pred == data.y):
            correct_readout[i] += "(√)" if item else "(x)"
        correct_readouts.extend(correct_readout)
        test_iter.set_description(f"(loss {loss.item()}; {(pred == data.y).sum().item()}/{loader.batch_size} [{" ".join(correct_readout)}])")
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset), correct / len(loader.dataset), correct_readouts

In [None]:
model = GAT()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
device = 'cpu'

train_loader, test_loader = train_test(metacells, metadata, edges, list(range(0, 5000, 50)), list(range(25, 5025, 50)))

loss, loss_arr = train_epoch(train_loader, model, optimizer, device)
print(loss)

  4%|▍         | 1/25 [00:45<18:21, 45.90s/it]

Batch 1/25: Loss: 0.5056


  8%|▊         | 2/25 [01:25<16:10, 42.21s/it]

Batch 2/25: Loss: 0.4776


 12%|█▏        | 3/25 [02:11<16:05, 43.89s/it]

Batch 3/25: Loss: 0.4883


 16%|█▌        | 4/25 [02:50<14:39, 41.89s/it]

Batch 4/25: Loss: 0.4961


 20%|██        | 5/25 [03:37<14:39, 43.98s/it]

Batch 5/25: Loss: 0.3855


 24%|██▍       | 6/25 [04:16<13:18, 42.04s/it]

Batch 6/25: Loss: 0.5240


 28%|██▊       | 7/25 [04:59<12:43, 42.39s/it]

Batch 7/25: Loss: 0.3922


 32%|███▏      | 8/25 [05:40<11:52, 41.92s/it]

Batch 8/25: Loss: 0.6588


 36%|███▌      | 9/25 [06:20<11:02, 41.38s/it]

Batch 9/25: Loss: 0.3633


 40%|████      | 10/25 [07:02<10:23, 41.58s/it]

Batch 10/25: Loss: 0.3893


 44%|████▍     | 11/25 [07:46<09:53, 42.39s/it]

Batch 11/25: Loss: 0.4746


 48%|████▊     | 12/25 [08:15<08:17, 38.26s/it]

Batch 12/25: Loss: 0.5020


 52%|█████▏    | 13/25 [08:43<07:00, 35.05s/it]

Batch 13/25: Loss: 0.3590


 56%|█████▌    | 14/25 [09:08<05:54, 32.24s/it]

Batch 14/25: Loss: 0.4852


 60%|██████    | 15/25 [09:32<04:57, 29.75s/it]

Batch 15/25: Loss: 0.3858


 64%|██████▍   | 16/25 [09:57<04:13, 28.14s/it]

Batch 16/25: Loss: 0.5541


 68%|██████▊   | 17/25 [10:20<03:33, 26.67s/it]

Batch 17/25: Loss: 0.6258


 72%|███████▏  | 18/25 [10:44<03:00, 25.77s/it]

Batch 18/25: Loss: 0.7096


 76%|███████▌  | 19/25 [11:08<02:31, 25.21s/it]

Batch 19/25: Loss: 0.4449


 80%|████████  | 20/25 [11:31<02:02, 24.52s/it]

Batch 20/25: Loss: 3.3261


 84%|████████▍ | 21/25 [11:55<01:38, 24.53s/it]

Batch 21/25: Loss: 5.1089


 88%|████████▊ | 22/25 [12:17<01:11, 23.88s/it]

Batch 22/25: Loss: 4.6731


 92%|█████████▏| 23/25 [12:41<00:47, 23.83s/it]

Batch 23/25: Loss: 5.8178


 96%|█████████▌| 24/25 [13:06<00:24, 24.15s/it]

Batch 24/25: Loss: 5.6066


100%|██████████| 25/25 [13:29<00:00, 32.39s/it]


Batch 25/25: Loss: 5.6775
Train Loss: {epoch_loss:.4f}


100%|██████████| 25/25 [06:29<00:00, 15.56s/it]


Test Loss: 1.1116, Test Accuracy: 0.2200


In [None]:
total_loss, correct, correct_readouts = test_epoch(test_loader, model, device)
print("Test Loss: ", total_loss)
print("Test Accuracy: ", correct)