In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset
from torch_geometric.nn import GATConv, GATv2Conv, DMoNPooling
from torch_geometric.utils import to_dense_adj
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from tqdm import tqdm

import matplotlib.pyplot as plt
import pandas as pd

from plotnine import ggplot

In [2]:
# load data
metacells = pd.read_csv("./final/oligo-SCZ-metacellExpr.csv", index_col=0)
metadata = pd.read_csv("./final/oligo-SCZ-meta.csv", index_col=0)
tom = pd.read_csv("./final/oligo-SCZ-tom.csv", index_col=0)

In [5]:
metacells.shape

(6679, 7378)

In [6]:
metacells.iloc[:,0].values

array([0.        , 0.        , 1.28492008, ..., 3.60950374, 1.82917698,
       0.83603014], shape=(6679,))

In [7]:
# hyperparameters
TOM_THRESHOLD = 0.025  # value below which we zero out the similarity that will be used as attention priors

In [8]:
# create graph dataset
edges = []
for i in tqdm(range(len(tom.columns))):
    for j in range(i):
        if tom.iloc[i,j] > TOM_THRESHOLD:
            edges.extend([[i,j],[j,i]])
edges = torch.tensor(edges)

100%|██████████| 6679/6679 [04:11<00:00, 26.53it/s] 


In [17]:
metacells.columns[0]

'Oligo#CON1_1'

In [None]:
class GeneGraphDataset(Dataset):
    def __init__(self, root, edge_index, expr_mat, transform=None, pre_transform=None):
        self.edge_index = edge_index
        self.num_graphs = expr_mat.shape[1]
        self.expr_mat = expr_mat
        super().__init__(root, transform, pre_transform)
        # self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        #data_list = []
        for i in tqdm(range(self.num_graphs)):
            num_nodes = 3
            node_features = torch.tensor(self.expr_mat.iloc[:,i].values)
            data = Data(x=node_features, edge_index=self.edge_index)
            #data_list.append(data)
            torch.save(data, f"./final/graphData/processed/{self.expr_mat.columns[i]}-graph.pt")

        # data, slices = self.collate(data_list)
        # torch.save((data, slices), self.processed_paths[0])

    def get(self, idx=None, sample=None):
        if idx is None:
            data = torch.load(f"./final/graphData/processed/{sample}-graph.pt")
        else:
            data = torch.load(f"./final/graphData/processed/{self.expr_mat.columns[idx]}-graph.pt")

# Usage
dataset = GeneGraphDataset(root='./final/graphData', edge_index=edges.t().contiguous(), expr_mat=metacells)
loader = DataLoader(dataset, batch_size=32)

Processing...
  3%|▎         | 228/7378 [00:59<31:18,  3.81it/s]


KeyboardInterrupt: 

In [None]:
geneGraph.num_features

# change: each metacell gets its own graph

7378

In [None]:
# define model
class GAT(torch.nn.Module):
    def __init__(self):
        super(GAT, self).__init__()
        self.hidden_layers = 8
        self.k1 = 512
        self.k2 = 128
        self.in_heads = 8
        self.mid_heads = 2
        self.out_heads = 1

        self.conv1 = GATv2Conv(geneGraph.num_features, self.hidden_layers, heads=self.in_heads, dropout=0.4)
        self.bn1 = torch.nn.BatchNorm1d(self.hidden_layers * self.in_heads)
        self.conv2 = GATv2Conv(self.hidden_layers*self.in_heads, self.hidden_layers, heads=self.mid_heads, dropout=0.4)
        self.bn2 = torch.nn.BatchNorm1d(self.hidden_layers * self.mid_heads)
        self.conv3 = GATv2Conv(self.hidden_layers*self.mid_heads, self.hidden_layers, heads=self.out_heads, dropout=0.4)
        self.bn3 = torch.nn.BatchNorm1d(self.hidden_layers)

        self.pool1 = DMoNPooling([self.hidden_layers, self.hidden_layers], k=self.k1, dropout=0.4)
        self.pool2 = DMoNPooling([self.hidden_layers, self.hidden_layers], k=self.k2, dropout=0.4)

        self.proj = torch.nn.Linear(self.hidden_layers, 1)

        self.classifier = torch.nn.Linear(128, 2)

    def _dense_adj(edge_index, alpha, batch):
        alpha = alpha.mean(dim=1)
        adj = to_dense_adj(edge_index, batch=batch, edge_attr=alpha).squeeze(0)
        deg_inv = adj.sum(-1).clamp(min=1e-12).pow(-0.5)
        return deg_inv.unsqueeze(1) * adj * deg_inv.unsqueeze(0)

    def forward(self, data):
        x, ei, batch = data.x, data.edge_index, data.batch

        x = F.elu(self.conv1(x, ei))
        x = F.elu(self.conv2(x, ei))                      
        x, (ei3, alpha3) = self.conv3(x, ei, return_attention_weights=True) 

        # build dense Laplacian surrogate from layer‑3 attention
        adj = self._dense_adj(ei3, alpha3, batch)
        mask = torch.ones(adj.size(0), dtype=torch.bool, device=adj.device)

        S1, x1, adj1, mod1, ort1, clu1 = self.pool1(x, adj, mask)
        S2, x2, adj2, mod2, ort2, clu2 = self.pool2(x1, adj1)

        # read‑out
        x2 = self.proj(x2).squeeze(-1)
        logits = self.classifier(x2)

        pool_reg = (mod1 + clu1 + 0.1 * ort1) + (mod2 + clu2 + 0.1 * ort2)
        return logits, pool_reg

        

np.float64(0.0624200607218792)