In [None]:
!pip install torch-scatter

In [82]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, InMemoryDataset, Dataset
from torch_geometric.nn import GATConv, GATv2Conv, DMoNPooling
from torch_geometric.utils import to_dense_adj
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from tqdm import tqdm

import matplotlib.pyplot as plt
import pandas as pd

from plotnine import ggplot

In [10]:
# load data
metacells = pd.read_csv("./oligo-SCZ-metacellExpr.csv", index_col=0)
metadata = pd.read_csv("./oligo-SCZ-meta.csv", index_col=0)
tom = pd.read_csv("./oligo-SCZ-tom.csv", index_col=0)

In [62]:
metadata.iloc[1:5000:50,:]["disorder"]

Oligo#CON1_2             Control
Oligo#CON10_5            Control
Oligo#CON10_55           Control
Oligo#CON10_105          Control
Oligo#CON11_41           Control
                       ...      
Oligo#SZ15_10      Schizophrenia
Oligo#SZ16_47      Schizophrenia
Oligo#SZ16_97      Schizophrenia
Oligo#SZ16_147     Schizophrenia
Oligo#SZ18_1       Schizophrenia
Name: disorder, Length: 100, dtype: object

In [12]:
metacells.iloc[:,0].values

array([0.        , 0.        , 1.28492008, ..., 3.60950374, 1.82917698,
       0.83603014], shape=(6679,))

In [13]:
# hyperparameters
TOM_THRESHOLD = 0.025  # value below which we zero out the similarity that will be used as attention priors

In [14]:
# create graph dataset
edges = []
for i in tqdm(range(len(tom.columns))):
    for j in range(i):
        if tom.iloc[i,j] > TOM_THRESHOLD:
            edges.extend([[i,j],[j,i]])
edges = torch.tensor(edges)

100%|██████████| 6679/6679 [04:21<00:00, 25.57it/s] 


In [65]:
metacellsReduced = metacells.iloc[:,0:5000:50]
metadataReduced = metadata.iloc[0:5000:50,:]

In [66]:
# NOT USING CURRENTLY

class GeneGraphDataset(InMemoryDataset):
    def __init__(self, root, edge_index, expr_mat, meta, transform=None, pre_transform=None):
        self.edge_index = edge_index
        self.num_graphs = expr_mat.shape[1]
        self.expr_mat = expr_mat
        self.num_classes = 2
        self.y = meta["disorder"]
        super().__init__(root, transform, pre_transform)
        self.load("./graphData/processed/combined.pt")

    @property
    def raw_file_names(self):
        pass

    @property
    def processed_file_names(self):
        return [f"./graphData/processed/{self.expr_mat.columns[i]}-graph.pt" for i in range(self.num_graphs)]

    def process(self):
        data_list = []
        for i in tqdm(range(self.num_graphs)):
            node_features = torch.tensor(self.expr_mat.iloc[:,i].values)
            data = Data(x=node_features, edge_index=self.edge_index)
            data_list.append(data)
            torch.save(data, f"./graphData/processed/{self.expr_mat.columns[i]}-graph.pt")

        data, slices = self.collate(data_list)
        torch.save((data, slices), f"./graphData/processed/combined.pt")

    def get(self, idx=None, sample=None):
        if idx is None:
            data = torch.load(f"./graphData/processed/{sample}-graph.pt", weights_only=False)
        else:
            data = torch.load(f"./graphData/processed/{self.expr_mat.columns[idx]}-graph.pt", weights_only=False)
        return data


dataset = GeneGraphDataset(root='./graphData', edge_index=edges.t().contiguous(), expr_mat=metacellsReduced, meta=metadataReduced)

AttributeError: property 'num_classes' of 'GeneGraphDataset' object has no setter

In [104]:
dataset = []
nGenes = metacellsReduced.iloc[:,i].values.shape[0]
edgesT = edges.t().contiguous()
for i in range(100):
    dataset.append(Data(x=torch.tensor(metacellsReduced.iloc[:,i].values, dtype=torch.float32).reshape((nGenes, 1)), edge_index=edgesT, y=metadataReduced["disorder"].iloc[i]))

In [105]:
loader = DataLoader(dataset, batch_size=4)

In [106]:
# define model
class GAT(torch.nn.Module):
    def __init__(self):
        super(GAT, self).__init__()
        self.num_features = 1
        self.hidden_layers = 1
        self.k = 128
        self.in_heads = 8
        self.mid_heads = 2
        self.out_heads = 1

        self.conv1 = GATv2Conv(self.num_features, self.hidden_layers, heads=self.in_heads, dropout=0.4)
        self.bn1 = torch.nn.BatchNorm1d(self.hidden_layers * self.in_heads)
        self.conv2 = GATv2Conv(self.hidden_layers*self.in_heads, self.hidden_layers, heads=self.mid_heads, dropout=0.4)
        self.bn2 = torch.nn.BatchNorm1d(self.hidden_layers * self.mid_heads)
        self.conv3 = GATv2Conv(self.hidden_layers*self.mid_heads, self.hidden_layers, heads=self.out_heads, dropout=0.4)
        self.bn3 = torch.nn.BatchNorm1d(self.hidden_layers)

        self.pool = DMoNPooling([self.hidden_layers, self.hidden_layers], k=self.k, dropout=0.4)

        self.proj = torch.nn.Linear(self.hidden_layers, 1)

        self.classifier = torch.nn.Linear(self.k, 2)

    def _dense_adj(edge_index, alpha, batch):
        alpha = alpha.mean(dim=1)
        adj = to_dense_adj(edge_index, batch=batch, edge_attr=alpha).squeeze(0)
        deg_inv = adj.sum(-1).clamp(min=1e-12).pow(-0.5)
        return deg_inv.unsqueeze(1) * adj * deg_inv.unsqueeze(0)

    def forward(self, data):
        x, ei, batch = data.x, data.edge_index, data.batch
        print(x)
        print(ei)
        print(batch)

        x = F.elu(self.conv1(x, ei))
        x = F.elu(self.conv2(x, ei))                      
        x, (ei3, alpha3) = self.conv3(x, ei, return_attention_weights=True) 

        # build dense Laplacian surrogate from layer‑3 attention
        adj = self._dense_adj(ei3, alpha3, batch)
        mask = torch.ones(adj.size(0), dtype=torch.bool, device=adj.device)

        S, x, adj, mod, ort, clu = self.pool1(x, adj, mask)

        # read‑out
        x_pool = self.proj(x_pool).squeeze(-1)
        logits = self.classifier(x_pool)

        pool_reg = mod + clu + 0.1 * ort
        return logits, pool_reg
        

In [107]:
def train_epoch(loader, model, optimizer, device):
    model.train()
    total_loss = 0
    for data in tqdm(loader):
        data = data.to(device)
        optimizer.zero_grad()
        out, pool_reg = model(data)
        loss = F.cross_entropy(out, data.y) + pool_reg.mean()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset)

def test_epoch(loader, model, device):
    model.eval()
    total_loss = 0
    correct = 0
    for data in tqdm(loader):
        data = data.to(device)
        out, pool_reg = model(data)
        loss = F.cross_entropy(out, data.y) + pool_reg.mean()
        pred = out.argmax(dim=1)
        correct += (pred == data.y).sum().item()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset), correct / len(loader.dataset)

In [108]:
model = GAT()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
device = 'cpu'

loss = train_epoch(loader, model, optimizer, device)

  0%|          | 0/25 [00:00<?, ?it/s]

tensor([[0.0000],
        [0.0000],
        [1.2849],
        ...,
        [3.9047],
        [1.8204],
        [0.6522]])
tensor([[    3,     0,    12,  ..., 26683, 26715, 26690],
        [    0,     3,     0,  ..., 26715, 26690, 26715]])
tensor([0, 0, 0,  ..., 3, 3, 3])


  0%|          | 0/25 [00:09<?, ?it/s]


TypeError: GAT._dense_adj() takes 3 positional arguments but 4 were given