In [None]:
!pip install torch-scatter
!pip install torch-geometric
!pip install plotnine


In [27]:
import random
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, InMemoryDataset, Dataset
from torch_geometric.nn import GATConv, GATv2Conv, DMoNPooling
from torch_geometric.utils import to_dense_adj, to_dense_batch
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder

import matplotlib.pyplot as plt
import pandas as pd

from plotnine import ggplot

In [8]:
# load data
metacells = pd.read_csv("./oligo-SCZ-metacellExpr.csv", index_col=0)
metadata = pd.read_csv("./oligo-SCZ-meta.csv", index_col=0)
tom = pd.read_csv("./oligo-SCZ-tom.csv", index_col=0)

In [6]:
metadata.iloc[1:5000:50,:]["disorder"]

Oligo#CON1_2             Control
Oligo#CON10_5            Control
Oligo#CON10_55           Control
Oligo#CON10_105          Control
Oligo#CON11_41           Control
                       ...      
Oligo#SZ15_10      Schizophrenia
Oligo#SZ16_47      Schizophrenia
Oligo#SZ16_97      Schizophrenia
Oligo#SZ16_147     Schizophrenia
Oligo#SZ18_1       Schizophrenia
Name: disorder, Length: 100, dtype: object

In [7]:
metacells.iloc[:,0].values

array([0.        , 0.        , 1.28492008, ..., 3.60950374, 1.82917698,
       0.83603014])

In [9]:
# hyperparameters
TOM_THRESHOLD = 0.025  # value below which we zero out the similarity that will be used as attention priors

In [10]:
# create graph dataset
edges = []
for i in tqdm(range(len(tom.columns))):
    for j in range(i):
        if tom.iloc[i,j] > TOM_THRESHOLD:
            edges.extend([[i,j],[j,i]])
edges = torch.tensor(edges)

100%|██████████| 6679/6679 [04:53<00:00, 22.75it/s] 


In [11]:
metacellsReduced = metacells.iloc[:,0:5000:50]
metadataReduced = metadata.iloc[0:5000:50,:]

In [11]:
# # NOT USING CURRENTLY

# class GeneGraphDataset(InMemoryDataset):
#     def __init__(self, root, edge_index, expr_mat, meta, transform=None, pre_transform=None):
#         self.edge_index = edge_index
#         self.num_graphs = expr_mat.shape[1]
#         self.expr_mat = expr_mat
#         self.num_classes = 2
#         self.y = meta["disorder"]
#         super().__init__(root, transform, pre_transform)
#         self.load("./graphData/processed/combined.pt")

#     @property
#     def raw_file_names(self):
#         pass

#     @property
#     def processed_file_names(self):
#         return [f"./graphData/processed/{self.expr_mat.columns[i]}-graph.pt" for i in range(self.num_graphs)]

#     def process(self):
#         data_list = []
#         for i in tqdm(range(self.num_graphs)):
#             node_features = torch.tensor(self.expr_mat.iloc[:,i].values)
#             data = Data(x=node_features, edge_index=self.edge_index)
#             data_list.append(data)
#             torch.save(data, f"./graphData/processed/{self.expr_mat.columns[i]}-graph.pt")

#         data, slices = self.collate(data_list)
#         torch.save((data, slices), f"./graphData/processed/combined.pt")

#     def get(self, idx=None, sample=None):
#         if idx is None:
#             data = torch.load(f"./graphData/processed/{sample}-graph.pt", weights_only=False)
#         else:
#             data = torch.load(f"./graphData/processed/{self.expr_mat.columns[idx]}-graph.pt", weights_only=False)
#         return data


# dataset = GeneGraphDataset(root='./graphData', edge_index=edges.t().contiguous(), expr_mat=metacellsReduced, meta=metadataReduced)

In [33]:
le = LabelEncoder()
metadataReduced.loc[:,"disorder_encoded"] = le.fit_transform(metadataReduced["disorder"])

dataset = []
nGenes = metacellsReduced.shape[0]
edgesT = edges.t().contiguous()
num_samples = min(100, metacellsReduced.shape[1], metadataReduced.shape[0])  # Ensure we don't exceed available data
for i in range(num_samples):
    dataset.append(
    Data(
        x=torch.tensor(metacellsReduced.iloc[:, i].values, dtype=torch.float32).reshape((nGenes, 1)),
        edge_index=edgesT,
        y=torch.tensor(metadataReduced["disorder_encoded"].iloc[i], dtype=torch.long)
    )
)
random.shuffle(dataset)
loader = DataLoader(dataset, batch_size=4)

In [70]:
def train_test(metacells, metadata, edges, idx_train, idx_test, bs):
    train = []
    test = []
    nGenes = metacells.shape[0]
    edgesT = edges.t().contiguous()
    for i in idx_train:
        train.append(Data(
            x=torch.tensor(metacells.iloc[:,i].values, dtype=torch.float32).reshape((nGenes,1)),
            edge_index=edgesT,
            y=torch.tensor(metadata["disorder_encoded"].iloc[i], dtype=torch.long)
        ))
    for i in idx_test:
        test.append(Data(
            x=torch.tensor(metacells.iloc[:,i].values, dtype=torch.float32).reshape((nGenes,1)),
            edge_index=edgesT,
            y=torch.tensor(metadata["disorder_encoded"].iloc[i], dtype=torch.long)
        )) 
    random.shuffle(train)
    random.shuffle(test)
    return DataLoader(train, batch_size=bs), DataLoader(test, batch_size=bs)

In [62]:
metadata.loc[:,"disorder_encoded"] = le.fit_transform(metadata["disorder"])

In [119]:
# define model
class GAT(torch.nn.Module):
    def __init__(self):
        super(GAT, self).__init__()
        self.num_features = 1
        self.hidden_layers = 1
        self.k = 512
        self.in_heads = 8
        self.mid_heads = 2
        self.out_heads = 1

        self.conv1 = GATv2Conv(self.num_features, self.hidden_layers, heads=self.in_heads, dropout=0.4)
        self.bn1 = torch.nn.BatchNorm1d(self.hidden_layers * self.in_heads)
        self.conv2 = GATv2Conv(self.hidden_layers*self.in_heads, self.hidden_layers, heads=self.mid_heads, dropout=0.4)
        self.bn2 = torch.nn.BatchNorm1d(self.hidden_layers * self.mid_heads)
        self.conv3 = GATv2Conv(self.hidden_layers*self.mid_heads, self.hidden_layers, heads=self.out_heads, dropout=0.4)
        self.bn3 = torch.nn.BatchNorm1d(self.hidden_layers)

        self.pool = DMoNPooling([self.hidden_layers, self.hidden_layers], k=self.k, dropout=0.4)

        self.proj = torch.nn.Linear(self.hidden_layers, 1)

        self.classifier = torch.nn.Linear(self.k, 2)

    def _dense_adj(self, edge_index, alpha, batch):
        # alpha = alpha.mean(dim=1)
        # adj = to_dense_adj(edge_index, batch=batch, edge_attr=alpha)
        # out = []
        # for i in range(adj.size(0)):
        #     A = adj[i]
        #     deg_inv = A.sum(-1).clamp(min=1e-12).pow(-0.5)
        #     norm_adj = deg_inv.unsqueeze(1) * A * deg_inv.unsqueeze(0)
        #     out.append(norm_adj)
        # return torch.stack(out)
        alpha = alpha.mean(dim=1)
        adj = to_dense_adj(edge_index, batch=batch, edge_attr=alpha)  # shape: [B, N, N]
        deg = adj.sum(-1).clamp(min=1e-12)
        deg_inv_sqrt = deg.pow(-0.5)
        norm_adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
        return norm_adj


    def forward(self, data):
        x, ei, batch = data.x, data.edge_index, data.batch

        x = F.elu(self.conv1(x, ei))
        x = F.elu(self.conv2(x, ei))                      
        x, (ei3, alpha3) = self.conv3(x, ei, return_attention_weights=True) 
        x = F.elu(self.bn3(x))

        # convert to dense batch and corresponding mask
        x_dense, mask = to_dense_batch(x, batch)

        alpha = alpha3.mean(dim=1)
        adj_dense = to_dense_adj(ei3, batch=batch, edge_attr=alpha)
        deg = adj_dense.sum(-1).clamp(min=1e-12)
        deg_inv_sqrt = deg.pow(-0.5)
        adj_dense = deg_inv_sqrt.unsqueeze(-1) * adj_dense * deg_inv_sqrt.unsqueeze(-2)

        # # build dense Laplacian surrogate from layer‑3 attention
        # adj = self._dense_adj(ei3, alpha3, batch)
        # mask = torch.ones(adj.size(0), adj.size(1), dtype=torch.bool, device=adj.device)

        S, x, adj, mod, ort, clu = self.pool(x_dense, adj_dense, mask)

        # read‑out
        x = self.proj(x).squeeze(-1)
        logits = self.classifier(x)

        pool_reg = mod + clu + 0.1 * ort
        return logits, pool_reg
        

In [124]:
def train_epoch(loader, model, optimizer, device):
    model.train()
    total_loss = 0
    iter_loss = []
    train_iter = tqdm(loader)
    correct_readouts = []
    for data in train_iter:
        data = data.to(device)
        # data.y = data.y.long()
        optimizer.zero_grad()
        out, pool_reg = model(data)
        loss = F.cross_entropy(out, data.y.squeeze()) + pool_reg.mean()

        pred = out.argmax(dim=1)
        
        correct_readout = ["SCZ" if item == 1 else "CON" for item in data.y]
        for i, item in enumerate(pred == data.y):
            correct_readout[i] += "(√)" if item else "(x)"
        correct_readouts.extend(correct_readout)
        train_iter.set_description(f"(loss {loss.item()}; {(pred == data.y).sum().item()}/{loader.batch_size} [{" ".join(correct_readout)}])")
        
        loss.backward()
        optimizer.step()
        #train_iter.set_description(f"(loss {loss.item()})")
        iter_loss.append(loss.item())
        total_loss += loss.item() * data.num_graphs
    print(iter_loss)
    return total_loss / len(loader.dataset)

def test_epoch(loader, model, device):
    model.eval()
    total_loss = 0
    correct = 0
    correct_readouts = []
    test_iter = tqdm(loader)
    for data in test_iter:
        data = data.to(device)
        # data.y = data.y.long()
        out, pool_reg = model(data)
        loss = F.cross_entropy(out, data.y.squeeze()) + pool_reg.mean()
        pred = out.argmax(dim=1)
        correct += (pred == data.y).sum().item()
        correct_readout = ["SCZ" if item == 1 else "CON" for item in data.y]
        for i, item in enumerate(pred == data.y):
            correct_readout[i] += "(√)" if item else "(x)"
        correct_readouts.extend(correct_readout)
        test_iter.set_description(f"(loss {loss.item()}; {(pred == data.y).sum().item()}/{loader.batch_size} [{" ".join(correct_readout)}])")
        total_loss += loss.item() * data.num_graphs
    print(correct_readouts)
    return total_loss / len(loader.dataset), correct / len(loader.dataset)

In [None]:
model = GAT()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0)
device = 'cpu'

#train_loader, test_loader = train_test(metacells, metadata, edges, list(range(0, 5000, 50)), list(range(25, 5025, 50)))
#samples_use = random.sample(list(range(metacells.shape[1])), 300)
con_samples = [metacells.columns.get_loc(c) for c in metadata[metadata["disorder_encoded"] == 1].sample(n=150).index.values]
scz_samples = [metacells.columns.get_loc(c) for c in metadata[metadata["disorder_encoded"] == 0].sample(n=150).index.values]
samples_train = con_samples[:100]
samples_train.extend(scz_samples[:100])
random.shuffle(samples_train)
samples_test = con_samples[100:]
samples_test.extend(scz_samples[100:])
random.shuffle(samples_test)
train_loader, test_loader = train_test(metacells, metadata, edges, samples_train, samples_test, 8)

loss = train_epoch(train_loader, model, optimizer, device)
print(loss)

(loss 1.2066829204559326; 4/8 [CON(√) SCZ(x) CON(√) SCZ(x) CON(√) SCZ(x) CON(√) SCZ(x)]):   0%|          | 0/25 [00:18<?, ?it/s]

torch.Size([8, 512, 1])
torch.Size([8, 512])


(loss 1.2541576623916626; 6/8 [SCZ(√) CON(x) CON(x) SCZ(√) SCZ(√) SCZ(√) SCZ(√) SCZ(√)]):   4%|▍         | 1/25 [01:05<18:50, 47.09s/it]

torch.Size([8, 512, 1])
torch.Size([8, 512])


(loss 1.4801230430603027; 5/8 [SCZ(√) CON(x) SCZ(√) CON(x) SCZ(√) SCZ(√) SCZ(√) CON(x)]):   8%|▊         | 2/25 [01:50<17:41, 46.14s/it]

torch.Size([8, 512, 1])
torch.Size([8, 512])


(loss 1.4291549921035767; 4/8 [CON(√) SCZ(x) SCZ(√) CON(x) SCZ(x) SCZ(x) SCZ(√) SCZ(√)]):  12%|█▏        | 3/25 [02:31<16:00, 43.67s/it]

torch.Size([8, 512, 1])
torch.Size([8, 512])


(loss 0.9442040324211121; 5/8 [CON(√) SCZ(√) CON(x) CON(√) SCZ(√) CON(x) SCZ(√) SCZ(x)]):  16%|█▌        | 4/25 [03:15<15:22, 43.95s/it]

torch.Size([8, 512, 1])
torch.Size([8, 512])


(loss 1.1846413612365723; 3/8 [CON(x) SCZ(x) CON(√) SCZ(x) CON(x) SCZ(x) CON(√) SCZ(√)]):  20%|██        | 5/25 [03:57<14:23, 43.15s/it]

torch.Size([8, 512, 1])
torch.Size([8, 512])


In [None]:
total_loss, correct = test_epoch(test_loader, model, device)

In [93]:
samples_use = random.sample(list(range(metacells.shape[1])), 300)
print(metadata.iloc[samples_use,]["disorder_encoded"].sum())

132


In [104]:
[metacells.columns.get_loc(c) for c in metadata[metadata["disorder_encoded"] == 1].sample(n=150).index.values]

[4888,
 7295,
 5850,
 4063,
 6223,
 5086,
 4807,
 5386,
 4141,
 4477,
 4416,
 4579,
 4599,
 5944,
 5482,
 6598,
 5067,
 4232,
 6042,
 6029,
 6201,
 5396,
 5458,
 5735,
 6649,
 6530,
 5194,
 7174,
 5763,
 6811,
 6145,
 7233,
 6696,
 4564,
 6024,
 5272,
 4671,
 6258,
 5767,
 4584,
 6991,
 4566,
 5987,
 5633,
 7089,
 5874,
 4626,
 5434,
 5400,
 4361,
 5621,
 3953,
 4428,
 4221,
 5658,
 7059,
 4282,
 4442,
 6070,
 6806,
 6588,
 5982,
 6407,
 6021,
 5724,
 4741,
 5449,
 7066,
 5305,
 6802,
 4959,
 5596,
 5520,
 6705,
 6333,
 4871,
 3914,
 4078,
 7370,
 5450,
 4743,
 6714,
 4112,
 6515,
 5374,
 3928,
 5755,
 5215,
 4866,
 6190,
 5203,
 4168,
 6293,
 6526,
 7164,
 4122,
 5192,
 5156,
 4710,
 5321,
 6224,
 4303,
 5791,
 5731,
 6548,
 6905,
 5617,
 5652,
 4177,
 6079,
 4314,
 3913,
 6008,
 6038,
 5903,
 6456,
 4453,
 6426,
 5201,
 5048,
 4188,
 5038,
 4862,
 7249,
 4642,
 4057,
 5199,
 6141,
 4869,
 4991,
 5236,
 6265,
 5796,
 6063,
 4191,
 6511,
 6320,
 6733,
 6697,
 5369,
 4162,
 5785,
 4616,