In [None]:
!pip install torch-scatter
!pip install torch-geometric
!pip install plotnine


In [1]:
import random
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, InMemoryDataset, Dataset
from torch_geometric.nn import GATConv, GATv2Conv, DMoNPooling
from torch_geometric.utils import to_dense_adj, to_dense_batch
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder

import matplotlib.pyplot as plt
import pandas as pd

from plotnine import ggplot

In [2]:
# load data
metacells = pd.read_csv("./oligo-SCZ-metacellExpr.csv", index_col=0)
metadata = pd.read_csv("./oligo-SCZ-meta.csv", index_col=0)
tom = pd.read_csv("./oligo-SCZ-tom.csv", index_col=0)

In [16]:
le = LabelEncoder()
metadata.loc[:,"disorder_encoded"] = le.fit_transform(metadata["disorder"])

In [3]:
# hyperparameters
TOM_THRESHOLD = 0.03  # value below which we zero out the similarity that will be used as attention priors

In [4]:
# create graph dataset
edges = []
for i in tqdm(range(len(tom.columns))):
    for j in range(i):
        if tom.iloc[i,j] > TOM_THRESHOLD:
            edges.extend([[i,j],[j,i]])
#edges = torch.tensor(edges)

100%|██████████| 6679/6679 [04:45<00:00, 23.38it/s] 


In [11]:
# # NOT USING CURRENTLY

# class GeneGraphDataset(InMemoryDataset):
#     def __init__(self, root, edge_index, expr_mat, meta, transform=None, pre_transform=None):
#         self.edge_index = edge_index
#         self.num_graphs = expr_mat.shape[1]
#         self.expr_mat = expr_mat
#         self.num_classes = 2
#         self.y = meta["disorder"]
#         super().__init__(root, transform, pre_transform)
#         self.load("./graphData/processed/combined.pt")

#     @property
#     def raw_file_names(self):
#         pass

#     @property
#     def processed_file_names(self):
#         return [f"./graphData/processed/{self.expr_mat.columns[i]}-graph.pt" for i in range(self.num_graphs)]

#     def process(self):
#         data_list = []
#         for i in tqdm(range(self.num_graphs)):
#             node_features = torch.tensor(self.expr_mat.iloc[:,i].values)
#             data = Data(x=node_features, edge_index=self.edge_index)
#             data_list.append(data)
#             torch.save(data, f"./graphData/processed/{self.expr_mat.columns[i]}-graph.pt")

#         data, slices = self.collate(data_list)
#         torch.save((data, slices), f"./graphData/processed/combined.pt")

#     def get(self, idx=None, sample=None):
#         if idx is None:
#             data = torch.load(f"./graphData/processed/{sample}-graph.pt", weights_only=False)
#         else:
#             data = torch.load(f"./graphData/processed/{self.expr_mat.columns[idx]}-graph.pt", weights_only=False)
#         return data


# dataset = GeneGraphDataset(root='./graphData', edge_index=edges.t().contiguous(), expr_mat=metacellsReduced, meta=metadataReduced)

In [None]:
# NOT IN USE CURRENTLY
# le = LabelEncoder()
# metadataReduced.loc[:,"disorder_encoded"] = le.fit_transform(metadataReduced["disorder"])

# dataset = []
# nGenes = metacellsReduced.shape[0]
# num_samples = min(100, metacellsReduced.shape[1], metadataReduced.shape[0])  # Ensure we don't exceed available data
# for i in tqdm(range(num_samples)):
#     dataset.append(
#     Data(
#         x=torch.tensor(metacellsReduced.iloc[:, i].values, dtype=torch.float32).reshape((nGenes, 1)),
#         edge_index=torch.tensor(edges).t().contiguous(),
#         y=torch.tensor(metadataReduced["disorder_encoded"].iloc[i], dtype=torch.long)
#     )
# )
# random.shuffle(dataset)
# loader = DataLoader(dataset, batch_size=4)

In [12]:
def train_test(metacells, metadata, edges, idx_train, idx_test, bs):
    train = []
    test = []
    nGenes = metacells.shape[0]
    print("-- Loading training data --")
    for i in tqdm(idx_train):
        train.append(Data(
            x=torch.tensor(metacells.iloc[:,i].values, dtype=torch.float32).reshape((nGenes,1)),
            edge_index=torch.tensor(edges).t().contiguous(),
            y=torch.tensor(metadata["disorder_encoded"].iloc[i], dtype=torch.long)
        ))
    print("-- Loading testing data --")
    for i in tqdm(idx_test):
        test.append(Data(
            x=torch.tensor(metacells.iloc[:,i].values, dtype=torch.float32).reshape((nGenes,1)),
            edge_index=torch.tensor(edges).t().contiguous(),
            y=torch.tensor(metadata["disorder_encoded"].iloc[i], dtype=torch.long)
        )) 
    random.shuffle(train)
    random.shuffle(test)
    return DataLoader(train, batch_size=bs, shuffle=True), DataLoader(test, batch_size=bs, shuffle=True)

In [13]:
# define model
class GAT(torch.nn.Module):
    def __init__(self):
        super(GAT, self).__init__()
        self.num_features = 1
        self.hidden_layers = 1
        self.k = 64
        self.in_heads = 4
        self.mid_heads = 2
        self.out_heads = 1

        self.conv1 = GATv2Conv(self.num_features, self.hidden_layers, heads=self.in_heads, dropout=0.2)
        self.bn1 = torch.nn.BatchNorm1d(self.hidden_layers * self.in_heads)
        self.conv2 = GATv2Conv(self.hidden_layers*self.in_heads, self.hidden_layers, heads=self.mid_heads, dropout=0.2)
        self.bn2 = torch.nn.BatchNorm1d(self.hidden_layers * self.mid_heads)
        self.conv3 = GATv2Conv(self.hidden_layers*self.mid_heads, self.hidden_layers, heads=self.out_heads, dropout=0.2)
        self.bn3 = torch.nn.BatchNorm1d(self.hidden_layers)

        self.pool = DMoNPooling([self.hidden_layers, self.hidden_layers], k=self.k, dropout=0.2)

        self.proj = torch.nn.Linear(self.hidden_layers, 1)

        self.classifier = torch.nn.Linear(self.k, 2)

    def _dense_adj(self, edge_index, alpha, batch):
        # alpha = alpha.mean(dim=1)
        # adj = to_dense_adj(edge_index, batch=batch, edge_attr=alpha)
        # out = []
        # for i in range(adj.size(0)):
        #     A = adj[i]
        #     deg_inv = A.sum(-1).clamp(min=1e-12).pow(-0.5)
        #     norm_adj = deg_inv.unsqueeze(1) * A * deg_inv.unsqueeze(0)
        #     out.append(norm_adj)
        # return torch.stack(out)
        alpha = alpha.mean(dim=1)
        adj = to_dense_adj(edge_index, batch=batch, edge_attr=alpha)  # shape: [B, N, N]
        deg = adj.sum(-1).clamp(min=1e-12)
        deg_inv_sqrt = deg.pow(-0.5)
        norm_adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
        return norm_adj


    def forward(self, data):
        x, ei, batch = data.x, data.edge_index, data.batch

        x = F.elu(self.conv1(x, ei))
        x = F.elu(self.conv2(x, ei))                      
        x, (ei3, alpha3) = self.conv3(x, ei, return_attention_weights=True) 
        x = F.elu(self.bn3(x))

        # convert to dense batch and corresponding mask
        x_dense, mask = to_dense_batch(x, batch)

        alpha = alpha3.mean(dim=1)
        adj_dense = to_dense_adj(ei3, batch=batch, edge_attr=alpha)
        deg = adj_dense.sum(-1).clamp(min=1e-12)
        deg_inv_sqrt = deg.pow(-0.5)
        adj_dense = deg_inv_sqrt.unsqueeze(-1) * adj_dense * deg_inv_sqrt.unsqueeze(-2)

        # # build dense Laplacian surrogate from layer‑3 attention
        # adj = self._dense_adj(ei3, alpha3, batch)
        # mask = torch.ones(adj.size(0), adj.size(1), dtype=torch.bool, device=adj.device)

        S, x, adj, mod, ort, clu = self.pool(x_dense, adj_dense, mask)

        # read‑out
        x = self.proj(x).squeeze(-1)
        logits = self.classifier(x)

        pool_reg = mod + clu + 0.1 * ort
        return logits, pool_reg
        

In [14]:
def train_epoch(loader, model, optimizer, device):
    model.train()
    total_loss = 0
    iter_loss = []
    train_iter = tqdm(loader)
    correct_readouts = []
    for data in train_iter:
        data = data.to(device)
        # data.y = data.y.long()
        optimizer.zero_grad()
        out, pool_reg = model(data)
        loss = F.cross_entropy(out, data.y.squeeze()) + pool_reg.mean()
        #print(out)
        pred = out.argmax(dim=1)
        
        correct_readout = ["SCZ" if item == 1 else "CON" for item in data.y]
        for i, item in enumerate(pred == data.y):
            correct_readout[i] += "(√)" if item else "(x)"
        correct_readouts.extend(correct_readout)
        train_iter.set_description(f"(loss {loss.item()}; {(pred == data.y).sum().item()}/{loader.batch_size} [{" ".join(correct_readout)}])")
        
        loss.backward()
        optimizer.step()
        #train_iter.set_description(f"(loss {loss.item()})")
        iter_loss.append(loss.item())
        total_loss += loss.item() * data.num_graphs
    print(iter_loss)
    return total_loss / len(loader.dataset)

def test_epoch(loader, model, device):
    model.eval()
    total_loss = 0
    correct = 0
    correct_readouts = []
    test_iter = tqdm(loader)
    for data in test_iter:
        data = data.to(device)
        # data.y = data.y.long()
        out, pool_reg = model(data)
        #print(out)
        loss = F.cross_entropy(out, data.y.squeeze()) + pool_reg.mean()
        pred = out.argmax(dim=1)
        correct += (pred == data.y).sum().item()
        correct_readout = ["SCZ" if item == 1 else "CON" for item in data.y]
        for i, item in enumerate(pred == data.y):
            correct_readout[i] += "(√)" if item else "(x)"
        correct_readouts.extend(correct_readout)
        test_iter.set_description(f"(loss {loss.item()}; {(pred == data.y).sum().item()}/{loader.batch_size} [{" ".join(correct_readout)}])")
        total_loss += loss.item() * data.num_graphs
    print(correct_readouts)
    return total_loss / len(loader.dataset), correct / len(loader.dataset)

In [18]:
model = GAT()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
device = 'cpu'

#train_loader, test_loader = train_test(metacells, metadata, edges, list(range(0, 5000, 50)), list(range(25, 5025, 50)))
#samples_use = random.sample(list(range(metacells.shape[1])), 300)
con_samples = [metacells.columns.get_loc(c) for c in metadata[metadata["disorder_encoded"] == 1].sample(n=180).index.values]
scz_samples = [metacells.columns.get_loc(c) for c in metadata[metadata["disorder_encoded"] == 0].sample(n=180).index.values]

samples_train = con_samples[:150]
samples_train.extend(scz_samples[:150])
random.shuffle(samples_train)

samples_test = con_samples[150:]
samples_test.extend(scz_samples[150:])
random.shuffle(samples_test)

train_loader, test_loader = train_test(metacells, metadata, edges, samples_train, samples_test, 16)

loss = train_epoch(train_loader, model, optimizer, device)
print(loss)
print("---------")
loss = train_epoch(train_loader, model, optimizer, device)
print(loss)
print("---------")
loss = train_epoch(train_loader, model, optimizer, device)
print(loss)
print("---------")

-- Loading training data --


100%|██████████| 300/300 [13:24<00:00,  2.68s/it]


-- Loading testing data --


100%|██████████| 60/60 [02:42<00:00,  2.71s/it]
(loss 1.0976862907409668; 8/16 [SCZ(√) CON(√) CON(√) SCZ(√) CON(x) SCZ(x) SCZ(x) SCZ(√) SCZ(x) CON(√) CON(√) SCZ(√)]): 100%|██████████| 19/19 [10:29<00:00, 33.14s/it]                             


[3.814974784851074, 2.5474331378936768, 2.6891894340515137, 1.3881561756134033, 1.7792565822601318, 1.254065752029419, 1.6825120449066162, 2.746351480484009, 2.3953263759613037, 1.3268976211547852, 1.4262665510177612, 1.181106448173523, 1.120349645614624, 1.1021437644958496, 0.8499464988708496, 1.1406341791152954, 1.156118631362915, 1.1296062469482422, 1.0976862907409668]
1.6828586705525717
---------


(loss 1.3325178623199463; 6/16 [SCZ(√) SCZ(x) CON(x) CON(√) CON(x) SCZ(x) SCZ(√) SCZ(√) SCZ(x) CON(x) SCZ(√) CON(√)]): 100%|██████████| 19/19 [10:14<00:00, 32.34s/it]                             


[1.1380420923233032, 0.8862426280975342, 1.1923143863677979, 1.2817268371582031, 1.5939185619354248, 1.0858101844787598, 1.0377044677734375, 1.019392490386963, 1.4167453050613403, 1.0878490209579468, 0.8589027523994446, 1.0719878673553467, 1.275989294052124, 1.3627748489379883, 1.0286309719085693, 1.2043269872665405, 1.0068129301071167, 1.0057168006896973, 1.3325178623199463]
1.1495614306131998
---------


(loss 1.006194829940796; 6/16 [SCZ(√) SCZ(x) SCZ(x) SCZ(√) CON(√) SCZ(x) SCZ(x) SCZ(x) CON(√) CON(√) SCZ(x) CON(√)]): 100%|██████████| 19/19 [10:30<00:00, 33.16s/it]                              

[0.8160133361816406, 1.182992935180664, 1.032702088356018, 1.143090844154358, 1.3953118324279785, 1.0840619802474976, 0.8551602959632874, 1.1616489887237549, 1.2275526523590088, 1.3359923362731934, 1.067302942276001, 1.0237374305725098, 0.8108710050582886, 1.1549307107925415, 0.7621551752090454, 1.106541633605957, 0.943183183670044, 1.1669431924819946, 1.006194829940796]
1.0679913965861003
---------





In [None]:
loss = train_epoch(train_loader, model, optimizer, device)
print(loss)
print("---------")
loss = train_epoch(train_loader, model, optimizer, device)
print(loss)
print("---------")
loss = train_epoch(train_loader, model, optimizer, device)
print(loss)
print("---------")
loss = train_epoch(train_loader, model, optimizer, device)
print(loss)
print("---------")
loss = train_epoch(train_loader, model, optimizer, device)
print(loss)
print("---------")

In [19]:
total_loss, correct = test_epoch(test_loader, model, device)

(loss 0.956275463104248; 9/16 [CON(√) SCZ(√) SCZ(x) SCZ(√) CON(x) CON(√) SCZ(√) CON(√) SCZ(√) SCZ(x) CON(√) SCZ(√)]): 100%|██████████| 4/4 [00:55<00:00, 13.99s/it]                              


['SCZ(x)', 'CON(√)', 'SCZ(√)', 'CON(√)', 'SCZ(x)', 'CON(√)', 'SCZ(√)', 'CON(√)', 'SCZ(√)', 'SCZ(√)', 'CON(√)', 'SCZ(x)', 'CON(√)', 'CON(√)', 'SCZ(√)', 'SCZ(√)', 'CON(x)', 'SCZ(√)', 'CON(x)', 'CON(√)', 'CON(√)', 'CON(x)', 'CON(√)', 'SCZ(x)', 'SCZ(√)', 'CON(x)', 'SCZ(√)', 'SCZ(x)', 'SCZ(√)', 'SCZ(x)', 'CON(x)', 'SCZ(√)', 'CON(x)', 'CON(x)', 'SCZ(x)', 'CON(√)', 'CON(√)', 'CON(x)', 'CON(√)', 'CON(√)', 'SCZ(√)', 'SCZ(√)', 'CON(x)', 'SCZ(√)', 'CON(√)', 'SCZ(x)', 'SCZ(√)', 'CON(√)', 'CON(√)', 'SCZ(√)', 'SCZ(x)', 'SCZ(√)', 'CON(x)', 'CON(√)', 'SCZ(√)', 'CON(√)', 'SCZ(√)', 'SCZ(x)', 'CON(√)', 'SCZ(√)']


In [20]:
correct

0.6666666666666666