In [None]:
!pip install torch-scatter
!pip install torch-geometric
!pip install plotnine


In [27]:
import random
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, InMemoryDataset, Dataset
from torch_geometric.nn import GATConv, GATv2Conv, DMoNPooling
from torch_geometric.utils import to_dense_adj, to_dense_batch
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder

import matplotlib.pyplot as plt
import pandas as pd

from plotnine import ggplot

In [8]:
# load data
metacells = pd.read_csv("./oligo-SCZ-metacellExpr.csv", index_col=0)
metadata = pd.read_csv("./oligo-SCZ-meta.csv", index_col=0)
tom = pd.read_csv("./oligo-SCZ-tom.csv", index_col=0)

In [6]:
metadata.iloc[1:5000:50,:]["disorder"]

Oligo#CON1_2             Control
Oligo#CON10_5            Control
Oligo#CON10_55           Control
Oligo#CON10_105          Control
Oligo#CON11_41           Control
                       ...      
Oligo#SZ15_10      Schizophrenia
Oligo#SZ16_47      Schizophrenia
Oligo#SZ16_97      Schizophrenia
Oligo#SZ16_147     Schizophrenia
Oligo#SZ18_1       Schizophrenia
Name: disorder, Length: 100, dtype: object

In [7]:
metacells.iloc[:,0].values

array([0.        , 0.        , 1.28492008, ..., 3.60950374, 1.82917698,
       0.83603014])

In [139]:
# hyperparameters
TOM_THRESHOLD = 0.03  # value below which we zero out the similarity that will be used as attention priors

In [140]:
# create graph dataset
edges = []
for i in tqdm(range(len(tom.columns))):
    for j in range(i):
        if tom.iloc[i,j] > TOM_THRESHOLD:
            edges.extend([[i,j],[j,i]])
#edges = torch.tensor(edges)

100%|██████████| 6679/6679 [04:56<00:00, 22.54it/s] 


In [11]:
# # NOT USING CURRENTLY

# class GeneGraphDataset(InMemoryDataset):
#     def __init__(self, root, edge_index, expr_mat, meta, transform=None, pre_transform=None):
#         self.edge_index = edge_index
#         self.num_graphs = expr_mat.shape[1]
#         self.expr_mat = expr_mat
#         self.num_classes = 2
#         self.y = meta["disorder"]
#         super().__init__(root, transform, pre_transform)
#         self.load("./graphData/processed/combined.pt")

#     @property
#     def raw_file_names(self):
#         pass

#     @property
#     def processed_file_names(self):
#         return [f"./graphData/processed/{self.expr_mat.columns[i]}-graph.pt" for i in range(self.num_graphs)]

#     def process(self):
#         data_list = []
#         for i in tqdm(range(self.num_graphs)):
#             node_features = torch.tensor(self.expr_mat.iloc[:,i].values)
#             data = Data(x=node_features, edge_index=self.edge_index)
#             data_list.append(data)
#             torch.save(data, f"./graphData/processed/{self.expr_mat.columns[i]}-graph.pt")

#         data, slices = self.collate(data_list)
#         torch.save((data, slices), f"./graphData/processed/combined.pt")

#     def get(self, idx=None, sample=None):
#         if idx is None:
#             data = torch.load(f"./graphData/processed/{sample}-graph.pt", weights_only=False)
#         else:
#             data = torch.load(f"./graphData/processed/{self.expr_mat.columns[idx]}-graph.pt", weights_only=False)
#         return data


# dataset = GeneGraphDataset(root='./graphData', edge_index=edges.t().contiguous(), expr_mat=metacellsReduced, meta=metadataReduced)

In [141]:
le = LabelEncoder()
metadataReduced.loc[:,"disorder_encoded"] = le.fit_transform(metadataReduced["disorder"])

dataset = []
nGenes = metacellsReduced.shape[0]
num_samples = min(100, metacellsReduced.shape[1], metadataReduced.shape[0])  # Ensure we don't exceed available data
for i in tqdm(range(num_samples)):
    dataset.append(
    Data(
        x=torch.tensor(metacellsReduced.iloc[:, i].values, dtype=torch.float32).reshape((nGenes, 1)),
        edge_index=torch.tensor(edges).t().contiguous(),
        y=torch.tensor(metadataReduced["disorder_encoded"].iloc[i], dtype=torch.long)
    )
)
random.shuffle(dataset)
loader = DataLoader(dataset, batch_size=4)

KeyboardInterrupt: 

In [142]:
def train_test(metacells, metadata, edges, idx_train, idx_test, bs):
    train = []
    test = []
    nGenes = metacells.shape[0]
    for i in tqdm(idx_train):
        train.append(Data(
            x=torch.tensor(metacells.iloc[:,i].values, dtype=torch.float32).reshape((nGenes,1)),
            edge_index=torch.tensor(edges).t().contiguous(),
            y=torch.tensor(metadata["disorder_encoded"].iloc[i], dtype=torch.long)
        ))
    for i in tqdm(idx_test):
        test.append(Data(
            x=torch.tensor(metacells.iloc[:,i].values, dtype=torch.float32).reshape((nGenes,1)),
            edge_index=torch.tensor(edges).t().contiguous(),
            y=torch.tensor(metadata["disorder_encoded"].iloc[i], dtype=torch.long)
        )) 
    random.shuffle(train)
    random.shuffle(test)
    return DataLoader(train, batch_size=bs), DataLoader(test, batch_size=bs)

In [62]:
metadata.loc[:,"disorder_encoded"] = le.fit_transform(metadata["disorder"])

In [143]:
# define model
class GAT(torch.nn.Module):
    def __init__(self):
        super(GAT, self).__init__()
        self.num_features = 1
        self.hidden_layers = 1
        self.k = 64
        self.in_heads = 4
        self.mid_heads = 2
        self.out_heads = 1

        self.conv1 = GATv2Conv(self.num_features, self.hidden_layers, heads=self.in_heads, dropout=0.2)
        self.bn1 = torch.nn.BatchNorm1d(self.hidden_layers * self.in_heads)
        self.conv2 = GATv2Conv(self.hidden_layers*self.in_heads, self.hidden_layers, heads=self.mid_heads, dropout=0.2)
        self.bn2 = torch.nn.BatchNorm1d(self.hidden_layers * self.mid_heads)
        self.conv3 = GATv2Conv(self.hidden_layers*self.mid_heads, self.hidden_layers, heads=self.out_heads, dropout=0.2)
        self.bn3 = torch.nn.BatchNorm1d(self.hidden_layers)

        self.pool = DMoNPooling([self.hidden_layers, self.hidden_layers], k=self.k, dropout=0.2)

        self.proj = torch.nn.Linear(self.hidden_layers, 1)

        self.classifier = torch.nn.Linear(self.k, 2)

    def _dense_adj(self, edge_index, alpha, batch):
        # alpha = alpha.mean(dim=1)
        # adj = to_dense_adj(edge_index, batch=batch, edge_attr=alpha)
        # out = []
        # for i in range(adj.size(0)):
        #     A = adj[i]
        #     deg_inv = A.sum(-1).clamp(min=1e-12).pow(-0.5)
        #     norm_adj = deg_inv.unsqueeze(1) * A * deg_inv.unsqueeze(0)
        #     out.append(norm_adj)
        # return torch.stack(out)
        alpha = alpha.mean(dim=1)
        adj = to_dense_adj(edge_index, batch=batch, edge_attr=alpha)  # shape: [B, N, N]
        deg = adj.sum(-1).clamp(min=1e-12)
        deg_inv_sqrt = deg.pow(-0.5)
        norm_adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
        return norm_adj


    def forward(self, data):
        x, ei, batch = data.x, data.edge_index, data.batch

        x = F.elu(self.conv1(x, ei))
        x = F.elu(self.conv2(x, ei))                      
        x, (ei3, alpha3) = self.conv3(x, ei, return_attention_weights=True) 
        x = F.elu(self.bn3(x))

        # convert to dense batch and corresponding mask
        x_dense, mask = to_dense_batch(x, batch)

        alpha = alpha3.mean(dim=1)
        adj_dense = to_dense_adj(ei3, batch=batch, edge_attr=alpha)
        deg = adj_dense.sum(-1).clamp(min=1e-12)
        deg_inv_sqrt = deg.pow(-0.5)
        adj_dense = deg_inv_sqrt.unsqueeze(-1) * adj_dense * deg_inv_sqrt.unsqueeze(-2)

        # # build dense Laplacian surrogate from layer‑3 attention
        # adj = self._dense_adj(ei3, alpha3, batch)
        # mask = torch.ones(adj.size(0), adj.size(1), dtype=torch.bool, device=adj.device)

        S, x, adj, mod, ort, clu = self.pool(x_dense, adj_dense, mask)

        # read‑out
        x = self.proj(x).squeeze(-1)
        logits = self.classifier(x)

        pool_reg = mod + clu + 0.1 * ort
        return logits, pool_reg
        

In [149]:
def train_epoch(loader, model, optimizer, device):
    model.train()
    total_loss = 0
    iter_loss = []
    train_iter = tqdm(loader)
    correct_readouts = []
    for data in train_iter:
        data = data.to(device)
        # data.y = data.y.long()
        optimizer.zero_grad()
        out, pool_reg = model(data)
        loss = F.cross_entropy(out, data.y.squeeze()) + pool_reg.mean()
        print(out)
        pred = out.argmax(dim=1)
        
        correct_readout = ["SCZ" if item == 1 else "CON" for item in data.y]
        for i, item in enumerate(pred == data.y):
            correct_readout[i] += "(√)" if item else "(x)"
        correct_readouts.extend(correct_readout)
        train_iter.set_description(f"(loss {loss.item()}; {(pred == data.y).sum().item()}/{loader.batch_size} [{" ".join(correct_readout)}])")
        
        loss.backward()
        optimizer.step()
        #train_iter.set_description(f"(loss {loss.item()})")
        iter_loss.append(loss.item())
        total_loss += loss.item() * data.num_graphs
    print(iter_loss)
    return total_loss / len(loader.dataset)

def test_epoch(loader, model, device):
    model.eval()
    total_loss = 0
    correct = 0
    correct_readouts = []
    test_iter = tqdm(loader)
    for data in test_iter:
        data = data.to(device)
        # data.y = data.y.long()
        out, pool_reg = model(data)
        print(out)
        loss = F.cross_entropy(out, data.y.squeeze()) + pool_reg.mean()
        pred = out.argmax(dim=1)
        correct += (pred == data.y).sum().item()
        correct_readout = ["SCZ" if item == 1 else "CON" for item in data.y]
        for i, item in enumerate(pred == data.y):
            correct_readout[i] += "(√)" if item else "(x)"
        correct_readouts.extend(correct_readout)
        test_iter.set_description(f"(loss {loss.item()}; {(pred == data.y).sum().item()}/{loader.batch_size} [{" ".join(correct_readout)}])")
        total_loss += loss.item() * data.num_graphs
    print(correct_readouts)
    return total_loss / len(loader.dataset), correct / len(loader.dataset)

In [146]:
model = GAT()
optimizer = torch.optim.Adam(model.parameters(), lr=0.003, weight_decay=5e-4)
device = 'cpu'

#train_loader, test_loader = train_test(metacells, metadata, edges, list(range(0, 5000, 50)), list(range(25, 5025, 50)))
#samples_use = random.sample(list(range(metacells.shape[1])), 300)
con_samples = [metacells.columns.get_loc(c) for c in metadata[metadata["disorder_encoded"] == 1].sample(n=120).index.values]
scz_samples = [metacells.columns.get_loc(c) for c in metadata[metadata["disorder_encoded"] == 0].sample(n=120).index.values]

samples_train = con_samples[:100]
samples_train.extend(scz_samples[:100])
random.shuffle(samples_train)
samples_test = con_samples[20:]
samples_test.extend(scz_samples[20:])
random.shuffle(samples_test)

train_loader, test_loader = train_test(metacells, metadata, edges, samples_train, samples_test, 16)

loss = train_epoch(train_loader, model, optimizer, device)
print(loss)

100%|██████████| 100/100 [04:26<00:00,  2.67s/it]
100%|██████████| 20/20 [00:53<00:00,  2.67s/it]
(loss 7.755601406097412; 4/8 [SCZ(√) CON(x) SCZ(√) SCZ(√) CON(x) CON(x) SCZ(√) CON(x)]):   0%|          | 0/13 [00:08<?, ?it/s]

tensor([[-8.3350,  7.8380],
        [-8.1754,  7.5557],
        [-7.6658,  8.3093],
        [-7.9162,  7.3422],
        [-8.5165,  8.6737],
        [-4.6252,  5.1736],
        [-5.1895,  5.2223],
        [-8.6381,  9.0780]], grad_fn=<AddmmBackward0>)


(loss 5.965826511383057; 5/8 [SCZ(√) SCZ(√) CON(x) CON(x) SCZ(√) CON(x) SCZ(√) SCZ(√)]):   8%|▊         | 1/13 [00:28<03:55, 19.60s/it]

tensor([[-3.7563,  2.4700],
        [-7.5753,  7.3161],
        [-9.5765, 10.3384],
        [-4.3211,  5.1288],
        [-9.6896, 10.3454],
        [-8.6414,  8.1325],
        [-3.9907,  3.3896],
        [-7.3901,  7.3329]], grad_fn=<AddmmBackward0>)


(loss 1.406540870666504; 7/8 [SCZ(√) SCZ(√) SCZ(√) CON(x) SCZ(√) SCZ(√) SCZ(√) SCZ(√)]):  15%|█▌        | 2/13 [00:47<03:35, 19.56s/it]

tensor([[-6.5134,  8.3789],
        [-3.2028,  2.2591],
        [-1.9902,  1.4948],
        [-4.8717,  4.6920],
        [-5.3413,  5.4917],
        [-2.2897,  0.9355],
        [-3.5557,  2.7508],
        [-4.1006,  3.7156]], grad_fn=<AddmmBackward0>)


(loss 6.334936141967773; 3/8 [SCZ(√) CON(x) CON(x) SCZ(√) SCZ(√) CON(x) CON(x) CON(x)]):  23%|██▎       | 3/13 [01:05<03:11, 19.13s/it]

tensor([[-1.9975,  1.8896],
        [-2.2335,  1.3377],
        [-3.1560,  5.0133],
        [-2.8360,  3.8423],
        [-3.2403,  6.2218],
        [-6.0087,  6.6277],
        [-5.9726,  7.2975],
        [-4.7863,  6.6047]], grad_fn=<AddmmBackward0>)


(loss 5.068115711212158; 3/8 [CON(x) SCZ(√) CON(x) SCZ(√) CON(x) CON(x) CON(x) SCZ(√)]):  31%|███       | 4/13 [01:25<02:51, 19.09s/it]

tensor([[-4.3908,  4.9766],
        [-4.0711,  6.6481],
        [-3.6122,  3.2154],
        [-2.4490,  3.3109],
        [-2.9604,  3.1101],
        [-4.6331,  5.4177],
        [-3.6830,  2.9533],
        [-2.7056,  2.6399]], grad_fn=<AddmmBackward0>)


(loss 3.068268299102783; 3/8 [CON(x) CON(x) SCZ(x) CON(x) SCZ(√) CON(x) SCZ(√) SCZ(√)]):  38%|███▊      | 5/13 [01:44<02:32, 19.08s/it]

tensor([[-1.1328,  1.9511],
        [-1.9524,  1.7715],
        [-0.1078, -0.1396],
        [-5.9280,  7.1430],
        [-2.7101,  3.7611],
        [-1.1157,  0.9273],
        [-3.8785,  3.9066],
        [-0.5194,  1.4269]], grad_fn=<AddmmBackward0>)


(loss 1.3569869995117188; 7/8 [CON(√) SCZ(√) CON(√) CON(√) SCZ(√) SCZ(√) SCZ(√) CON(x)]):  46%|████▌     | 6/13 [02:03<02:13, 19.08s/it]

tensor([[ 0.8878, -2.1236],
        [-1.1889,  1.2264],
        [ 1.4898, -0.8963],
        [ 1.3275, -1.8695],
        [-2.5372,  3.6928],
        [-2.7458,  4.1521],
        [-0.5886, -0.0566],
        [-3.5248,  5.0081]], grad_fn=<AddmmBackward0>)


(loss 1.8510932922363281; 5/8 [CON(x) CON(√) SCZ(√) CON(√) CON(x) SCZ(x) CON(√) SCZ(√)]):  54%|█████▍    | 7/13 [02:22<01:55, 19.17s/it]

tensor([[-1.5166,  2.3181],
        [-0.0971, -0.2287],
        [-2.1619,  2.5948],
        [ 0.6362, -1.1685],
        [-0.7364,  1.0006],
        [ 3.1607, -3.2469],
        [ 1.0557, -1.3973],
        [-1.1923,  0.4189]], grad_fn=<AddmmBackward0>)


(loss 1.639204978942871; 5/8 [SCZ(√) CON(√) CON(x) CON(√) CON(x) CON(√) CON(√) CON(x)]):  62%|██████▏   | 8/13 [02:41<01:35, 19.20s/it] 

tensor([[-3.1452,  1.9779],
        [ 0.5176, -0.6947],
        [-1.8562,  2.6809],
        [-0.0320, -0.1834],
        [-2.2266,  2.5920],
        [ 2.3947, -2.4158],
        [ 0.8213, -1.0809],
        [ 0.7137,  1.4583]], grad_fn=<AddmmBackward0>)


(loss 3.2221479415893555; 3/8 [SCZ(√) SCZ(√) CON(x) SCZ(x) CON(x) SCZ(√) CON(x) SCZ(x)]):  69%|██████▉   | 9/13 [03:01<01:17, 19.41s/it]

tensor([[-1.6790,  2.0018],
        [-0.5998,  0.8099],
        [-1.3778,  1.4988],
        [ 4.3874, -4.2341],
        [-2.6363,  2.7654],
        [-1.6076,  4.0961],
        [-0.5761,  1.3819],
        [ 2.3507, -2.5429]], grad_fn=<AddmmBackward0>)


(loss 2.5492827892303467; 4/8 [CON(√) SCZ(√) SCZ(x) CON(x) SCZ(x) SCZ(√) SCZ(x) SCZ(√)]):  77%|███████▋  | 10/13 [03:21<00:58, 19.39s/it]

tensor([[ 0.6244, -0.3667],
        [-1.9587,  1.0724],
        [ 2.0684, -1.5145],
        [-1.8887,  2.9309],
        [ 2.7886, -1.2110],
        [-1.6545,  2.1598],
        [ 2.6585, -2.6945],
        [-0.7705, -0.6245]], grad_fn=<AddmmBackward0>)


(loss 4.637099742889404; 3/8 [SCZ(x) SCZ(x) CON(x) SCZ(√) CON(√) CON(√) CON(x) SCZ(x)]):  85%|████████▍ | 11/13 [03:40<00:38, 19.47s/it] 

tensor([[ 5.6644, -8.1291],
        [ 1.3003, -0.2403],
        [-2.4543,  3.9574],
        [-1.6726, -0.7041],
        [ 2.6408, -2.2582],
        [ 2.2616, -1.5346],
        [-1.5574,  2.3413],
        [ 4.0824, -5.2320]], grad_fn=<AddmmBackward0>)


(loss 0.31108415126800537; 4/8 [CON(√) SCZ(√) CON(√) CON(√)]):  92%|█████████▏| 12/13 [03:55<00:19, 19.31s/it]                          

tensor([[ 1.4181, -1.1514],
        [-2.6380,  2.6309],
        [ 2.6778, -2.7521],
        [ 0.2987, -0.4599]], grad_fn=<AddmmBackward0>)


(loss 0.31108415126800537; 4/8 [CON(√) SCZ(√) CON(√) CON(√)]): 100%|██████████| 13/13 [04:01<00:00, 18.55s/it]

[7.755601406097412, 5.965826511383057, 1.406540870666504, 6.334936141967773, 5.068115711212158, 3.068268299102783, 1.3569869995117188, 1.8510932922363281, 1.639204978942871, 3.2221479415893555, 2.5492827892303467, 4.637099742889404, 0.31108415126800537]
3.600851740837097





In [150]:
total_loss, correct = test_epoch(test_loader, model, device)

(loss 4.835780143737793; 5/8 [CON(√) CON(√) CON(√) CON(√) SCZ(x) SCZ(x) SCZ(x) CON(√)]):  33%|███▎      | 1/3 [00:07<00:15,  7.84s/it]

tensor([[ 6.4445, -7.0102],
        [ 6.6995, -7.3085],
        [ 6.1920, -6.7347],
        [ 7.2206, -7.8991],
        [ 5.7347, -6.2378],
        [ 4.9250, -5.3053],
        [ 7.0807, -7.7250],
        [ 4.3374, -4.6316]], grad_fn=<AddmmBackward0>)
tensor([[ 6.1008, -6.6298],
        [ 8.5782, -9.4206],
        [ 7.0398, -7.6904],
        [ 1.5264, -1.4489],
        [ 5.9499, -6.4575],
        [ 5.5295, -5.9902],
        [ 5.7168, -6.1994],
        [ 7.0232, -7.6850]], grad_fn=<AddmmBackward0>)


(loss 6.2702860832214355; 4/8 [SCZ(x) SCZ(x) SCZ(x) SCZ(x) CON(√) CON(√) CON(√) CON(√)]):  67%|██████▋   | 2/3 [00:16<00:08,  8.34s/it]

tensor([[ 8.5071, -9.3462],
        [ 7.2600, -7.9316],
        [ 6.6295, -7.2367],
        [ 2.3972, -2.4397]], grad_fn=<AddmmBackward0>)


(loss 11.939884185791016; 1/8 [SCZ(x) SCZ(x) SCZ(x) CON(√)]): 100%|██████████| 3/3 [00:21<00:00,  7.14s/it]                            


['CON(√)', 'CON(√)', 'CON(√)', 'CON(√)', 'SCZ(x)', 'SCZ(x)', 'SCZ(x)', 'CON(√)', 'SCZ(x)', 'SCZ(x)', 'SCZ(x)', 'SCZ(x)', 'CON(√)', 'CON(√)', 'CON(√)', 'CON(√)', 'SCZ(x)', 'SCZ(x)', 'SCZ(x)', 'CON(√)']
