In [1]:
import pickle
import os
import numpy as np
import geovoronoi
import shapely
import matplotlib.pyplot as plt
import pandas as pd
import collections
from collections import defaultdict
import time
import copy

import torch

from DataManager import dataManager
from graphgen import neighborFinder

from visualize_2 import plotHeatmaps, plotAUROC
import PIL
from tqdm import tqdm, trange

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

# os.environ["CUDA_VISIBLE_DEVICES"] = '3'

In [2]:
def load_cell_data():
    store = '/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Lutz/CELL2RNA/spatial_omics/code/GNN/store'
   
    with open(f"{store}/xcoords.pickle", "rb") as f:
        xcoords = pickle.load(f)

    with open(f"{store}/ycoords.pickle", "rb") as f:
        ycoords = pickle.load(f)

    with open(f"{store}/patchxy.pickle", 'rb') as f:
        patch_xy = pickle.load(f)
        
    with open(f"{store}/cnn_embed.pickle", "rb") as f:
        embeddings = pickle.load(f)

    slide_cell_locations_all = {}
    slide_embeddings_all = {}
    for slide in ['A1', 'B1', 'C1', 'D1']:
        cell_x = []
        cell_y = []
        embeddings_ = []
        
        xy_set = set()

        cell_counter = 0
        for patch_i, (patch_x, patch_y) in enumerate(patch_xy[slide]):
            for cell_i in range(len(xcoords[slide][patch_i])):
                x = xcoords[slide][patch_i][cell_i] + patch_x
                y = ycoords[slide][patch_i][cell_i] + patch_y
                
                embedding = embeddings[slide][cell_counter]
                
                cell_counter += 1
                
                # Remove duplicates, which does happen sometimes
                if (x, y) in xy_set:
                    continue
                    
                xy_set.add((x, y))

                cell_x.append(x)
                cell_y.append(y)
                embeddings_.append(embedding)
                
        slide_cell_locations_all[slide] = np.vstack([cell_x, cell_y]).T
        slide_embeddings_all[slide] = np.array(embeddings_)
        
    return slide_cell_locations_all, slide_embeddings_all

In [3]:
cell_locations, cell_embeddings = load_cell_data()
with open('/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Lutz/CELL2RNA/spatial_omics/DH/visium/preprocessed_data/visium_data_filtered_processed.pkl', "rb") as f:
    visium = pd.read_pickle(f)
    
with open('/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Lutz/CELL2RNA/all_code/adjacency_matrices.pkl', "rb") as f:
    adjacency_matrices = pd.read_pickle(f)
    

In [4]:
visium['A1']

(                 SAMD11  NOC2L  KLHL17  PLEKHN1  PERM1  HES4  ISG15  AGRN  \
 array_x array_y                                                             
 0       0             0      0       0        0      0     1      1     0   
 1       1             0      1       0        0      0     1      0     2   
 2       0             0      3       0        1      0     1      1     0   
 3       1             0      2       0        0      0     0      1     3   
 4       0             0      0       0        0      0     0      1     2   
 ...                 ...    ...     ...      ...    ...   ...    ...   ...   
 122     76            0      0       0        0      0     0      0     0   
 123     77            0      0       0        0      0     0      0     0   
 124     76            0      0       0        0      0     0      0     0   
 125     77            0      0       0        0      0     0      0     0   
 126     76            0      0       0        0      0     0   

In [5]:
def create_voronoi_regions(cell_locations):
    min_x = np.min(cell_locations[:, 0])
    max_x = np.max(cell_locations[:, 0])
    min_y = np.min(cell_locations[:, 1])
    max_y = np.max(cell_locations[:, 1])
    
    bounding_rect = shapely.geometry.Polygon([
        [min_x, min_y],
        [min_x, max_y],
        [max_x, max_y],
        [max_x, min_y],
    ])
    
    region_polys, region_pts = geovoronoi.voronoi_regions_from_coords(cell_locations, bounding_rect)
    
    filtered_polys = []
    filtered_region_pts = []
    
    for i, polygon in region_polys.items():
        polygon_min_x, polygon_min_y, polygon_max_x, polygon_max_y = polygon.bounds
        if (polygon_min_x == min_x or
            polygon_min_y == min_y or
            polygon_max_x == max_x or
            polygon_max_y == max_y):
            continue
        
        filtered_polys.append(polygon)
        filtered_region_pts.append(region_pts[i])
    
    return filtered_polys, filtered_region_pts

In [6]:
regions = {}
# Create final locations and embeddings dicts, removing droped cells
locations = {}
embeddings = {}

for slide_id in cell_locations.keys():
    regions[slide_id], preserved_indexes = create_voronoi_regions(cell_locations[slide_id])
    locations[slide_id] = cell_locations[slide_id][[x[0] for x in preserved_indexes]]
    embeddings[slide_id] = cell_embeddings[slide_id][[x[0] for x in preserved_indexes]]

In [7]:
def get_indexes_of_cells_under_each_spot(slide, cell_locations, r):
    indexes = []
    counts, visium_locations = slide
    r2 = r * r
    
    # Sort cells into tiles
    tiles = defaultdict(set)
    for cell_i, (cell_x, cell_y) in enumerate(cell_locations):
        for dx in [-r, 0, r]:
            for dy in [-r, 0, r]:
                tiles[(cell_x + dx) // (2 * r), (cell_y + dy) // (2 * r)].add(cell_i)

    for visium_i in range(len(visium_locations)):
        indexes.append([])

        visium_x, visium_y = visium_locations.iloc[visium_i]
        
        search_space = set()
        for dx in [-r, 0, r]:
            for dy in [-r, 0, r]:
                search_space = search_space.union(
                    tiles[(visium_x + dx) // (2 * r), (visium_y + dy) // (2 * r)]
                )

        for cell_i in search_space:
            cell_x, cell_y = cell_locations[cell_i]

            distance = ((cell_x - visium_x) ** 2 + (cell_y - visium_y) ** 2)
            if distance <= r2:
                indexes[-1].append(cell_i)
                
    return indexes

In [8]:
cells_by_spot = {}

for slide_id in visium.keys():
    cells_by_spot[slide_id] = get_indexes_of_cells_under_each_spot(visium[slide_id], locations[slide_id], r=80)#80#128#100#55
    
    

In [9]:
cells_by_spot['A1']

[[],
 [],
 [1153, 1154, 1155, 1151],
 [1152],
 [],
 [1149],
 [1150],
 [29],
 [28],
 [386, 387, 389],
 [105, 117],
 [],
 [116],
 [],
 [30, 43],
 [106],
 [32, 36, 37, 114],
 [407, 408, 5574],
 [1156, 410],
 [5571, 9330, 9331, 9335],
 [1158, 2719],
 [9329],
 [5775, 2835],
 [17789, 17790, 17791, 5774, 5777, 5570],
 [2834],
 [2833, 9332, 9604, 5773],
 [2837, 5781, 5782],
 [9598, 9602, 9603, 9605, 9601, 9599],
 [2836, 9609, 5785, 5786, 5787, 5790, 5791, 5795],
 [2817, 2821, 2822, 5758, 9568, 9571, 9572, 9574, 9600],
 [5757, 9611, 13967, 13968, 13969, 5788, 5789],
 [2827, 5759],
 [2852, 9360, 2728],
 [1204, 2769, 5593],
 [9362, 9363, 9364, 1178, 5592, 5597, 5598, 5599],
 [2816,
  21014,
  21012,
  21013,
  23385,
  23388,
  23389,
  13682,
  13685,
  1203,
  5594,
  5596,
  5626,
  2815],
 [13684, 9370, 2736, 5607],
 [21011, 21030, 13683],
 [1187, 9392, 9394, 2762, 2763, 5623, 5624],
 [21028, 21029, 9404, 9405],
 [13672, 13673, 13674, 13675, 13677, 13678, 13680, 9389, 9391, 9393],
 [21021, 21

In [10]:
def create_neighborhood(adjacency_matrix, starting_indexes, hop_count):
    accessible = np.zeros(adjacency_matrix.shape[0], dtype=bool)
    accessible[starting_indexes] = True
    
    for hop in range(hop_count):
        accessible = accessible | (accessible @ adjacency_matrix)
        
    return np.where(accessible)[0]

def create_neighborhood_edge_list(edge_list, starting_indexes, hop_count):
    accessible = set(starting_indexes)
    
    for hop in range(hop_count):
        next_accessible = {*accessible}
        for accessible_index in accessible:
            next_accessible.update(edge_list[accessible_index])
            
        accessible = next_accessible
        
    return np.array([*sorted(accessible)])

In [11]:
# neighborhoods_by_spot_0 = {}
# for visium_key in ['A1', 'B1', 'C1', 'D1']:
#     neighborhoods_by_spot_0[visium_key] = []
#     for one_spot in cells_by_spot['A1']:
#         neighborhoods_by_spot_0[visium_key].append(np.array(one_spot))

# neighborhoods_by_spot = {}
neighborhoods_by_spot_3 = {}
neighborhoods_by_spot_4 = {}
neighborhoods_by_spot_5 = {}
edge_lists = {}

for slide_id in cells_by_spot.keys():
    neighborhoods = []
    adj = adjacency_matrices[slide_id]
    edge_lists[slide_id] = [np.where(adj[i])[0] for i in range(adj.shape[0])]
    for spot_i in tqdm(range(len(cells_by_spot[slide_id])), desc='Creating neighborhoods for ' + str(slide_id)):
        starting_cells = cells_by_spot[slide_id][spot_i]
        
        neighborhoods.append(
            create_neighborhood_edge_list(edge_lists[slide_id], starting_cells, hop_count=3)#3
        )
    
    neighborhoods_by_spot_3[slide_id] = neighborhoods
    
    
for slide_id in cells_by_spot.keys():
    neighborhoods = []
    adj = adjacency_matrices[slide_id]
    edge_lists[slide_id] = [np.where(adj[i])[0] for i in range(adj.shape[0])]
    for spot_i in tqdm(range(len(cells_by_spot[slide_id])), desc='Creating neighborhoods for ' + str(slide_id)):
        starting_cells = cells_by_spot[slide_id][spot_i]
        
        neighborhoods.append(
            create_neighborhood_edge_list(edge_lists[slide_id], starting_cells, hop_count=4)#3
        )
    
    neighborhoods_by_spot_4[slide_id] = neighborhoods
    
    
for slide_id in cells_by_spot.keys():
    neighborhoods = []
    adj = adjacency_matrices[slide_id]
    edge_lists[slide_id] = [np.where(adj[i])[0] for i in range(adj.shape[0])]
    for spot_i in tqdm(range(len(cells_by_spot[slide_id])), desc='Creating neighborhoods for ' + str(slide_id)):
        starting_cells = cells_by_spot[slide_id][spot_i]
        
        neighborhoods.append(
            create_neighborhood_edge_list(edge_lists[slide_id], starting_cells, hop_count=5)#3
        )
    
    neighborhoods_by_spot_5[slide_id] = neighborhoods

Creating neighborhoods for A1: 100%|███████████████████████████████████| 4950/4950 [00:00<00:00, 11672.32it/s]
Creating neighborhoods for B1: 100%|███████████████████████████████████| 4922/4922 [00:00<00:00, 13467.14it/s]
Creating neighborhoods for C1: 100%|███████████████████████████████████| 4887/4887 [00:00<00:00, 13348.66it/s]
Creating neighborhoods for D1: 100%|███████████████████████████████████| 4169/4169 [00:00<00:00, 12792.84it/s]
Creating neighborhoods for A1: 100%|████████████████████████████████████| 4950/4950 [00:00<00:00, 6262.87it/s]
Creating neighborhoods for B1: 100%|████████████████████████████████████| 4922/4922 [00:00<00:00, 6847.76it/s]
Creating neighborhoods for C1: 100%|████████████████████████████████████| 4887/4887 [00:00<00:00, 6706.15it/s]
Creating neighborhoods for D1: 100%|████████████████████████████████████| 4169/4169 [00:00<00:00, 6594.77it/s]
Creating neighborhoods for A1: 100%|████████████████████████████████████| 4950/4950 [00:01<00:00, 3469.00it/s]
C

In [12]:
neighborhoods_by_spot_3['A1']#7

[array([], dtype=float64),
 array([], dtype=float64),
 array([   1,    6,   28,   29,   38,   39,   40,  117,  388,  389, 1149,
        1150, 1151, 1152, 1153, 1154, 1155, 1174, 2710, 2711, 2712, 2713,
        2714, 2715]),
 array([   1,    6,   28,   29,   40,  386,  388,  389,  437, 1149, 1150,
        1151, 1152, 1153, 1154, 1155, 1169, 1173, 1174, 1175, 2710, 2711,
        2712, 2713, 2714, 2715, 2716, 2717, 2718]),
 array([], dtype=float64),
 array([   1,    6,   28,   29,   38,   39,   40,  117,  386,  388,  389,
         437, 1149, 1150, 1151, 1152, 1153, 1154, 1155, 1169, 1173, 1174,
        1175, 2710, 2711, 2712, 2713, 2714, 2715, 2716]),
 array([   1,    6,   28,   29,   38,   39,   40,   43,   44,  105,  115,
         116,  117,  118,  386,  387,  388,  389, 1149, 1150, 1151, 1152,
        1153, 1154, 1155, 1174, 2710, 2711, 2712, 2714]),
 array([   1,    6,   28,   29,   38,   39,   40,   43,   44,  105,  115,
         116,  117,  118,  386,  387,  388,  389,  437, 1149, 1

In [13]:
import torch.utils.data

def ensure_tensorlist(L):
    return [ensure_tensor(x) for x in L]

def ensure_tensor(x):
    return torch.tensor(x) if type(x) is not torch.Tensor else x

class CellSubgraphDataset(torch.utils.data.Dataset):
    """
    One slide at a time.
    
    Parameters:
    - cell_embeddings: torch.Tensor of cell embeddings
    - cells_by_spot: List of cell indexes corresponding to Visium node
    - neighborhoods_by_spot: List of cell indexes corresponding to Visium node
    - adjacency_matrix: torch.Tensor containing adjacency matrix (cell -> cell)
    - counts: torch.Tensor containing gene counts of each Visium node
    """
    
    def __init__(self, cell_embeddings, cells_by_spot, 
#                  neighborhoods_by_spot_0, 
                 neighborhoods_by_spot_3, 
                 neighborhoods_by_spot_4, 
                 neighborhoods_by_spot_5, 
                 adjacency_matrix, counts):
        self.cells_by_spot = ensure_tensorlist(cells_by_spot)
#         self.neighborhoods_by_spot_0 = ensure_tensorlist(neighborhoods_by_spot_0)
        self.neighborhoods_by_spot_3 = ensure_tensorlist(neighborhoods_by_spot_3)
        self.neighborhoods_by_spot_4 = ensure_tensorlist(neighborhoods_by_spot_4)
        self.neighborhoods_by_spot_5 = ensure_tensorlist(neighborhoods_by_spot_5)
        self.cell_embeddings = ensure_tensor(cell_embeddings)
        self.adjacency_matrix = ensure_tensor(adjacency_matrix)
        self.counts = ensure_tensor(counts).type(torch.float)

        assert len(self.cells_by_spot) == len(self.neighborhoods_by_spot_3)
        assert len(cell_embeddings) == adjacency_matrix.shape[0]
        
    def __len__(self):
        return len(self.cells_by_spot)
    
    def __getitem__(self, index):
        cells_in_spot = self.cells_by_spot[index]
#         cells_in_neighborhood_0 = self.neighborhoods_by_spot_0[index]
        cells_in_neighborhood_3 = self.neighborhoods_by_spot_3[index]
        cells_in_neighborhood_4 = self.neighborhoods_by_spot_4[index]
        cells_in_neighborhood_5 = self.neighborhoods_by_spot_5[index]
        
        # Re-index cells. Treat cells_in_neighborhood as the universal set. Create an edge list.
        # https://stackoverflow.com/questions/22927181/selecting-specific-rows-and-columns-from-numpy-array
#         adjacency_matrix_subgraph_0 = self.adjacency_matrix[cells_in_neighborhood_0, :][:, cells_in_neighborhood_0]
        adjacency_matrix_subgraph_3 = self.adjacency_matrix[cells_in_neighborhood_3, :][:, cells_in_neighborhood_3]
        adjacency_matrix_subgraph_4 = self.adjacency_matrix[cells_in_neighborhood_4, :][:, cells_in_neighborhood_4]
        adjacency_matrix_subgraph_5 = self.adjacency_matrix[cells_in_neighborhood_5, :][:, cells_in_neighborhood_5]
        
        # tr = transformed
        cells_in_neighborhood_tr = torch.arange(len(cells_in_neighborhood_3))
#         cells_in_neighborhood_tr = torch.arange(len(cells_in_neighborhood_0))
        
#         cell_embeddings_0 = self.cell_embeddings[cells_in_neighborhood_0]
        cell_embeddings_3 = self.cell_embeddings[cells_in_neighborhood_3]
        cell_embeddings_4 = self.cell_embeddings[cells_in_neighborhood_4]
        cell_embeddings_5 = self.cell_embeddings[cells_in_neighborhood_5]
        
#         print(cells_in_neighborhood_5)
        
        from_5_4 = []
        from_4_3 = []
#         from_3_0 = []
        
        for one_index in range(cells_in_neighborhood_5.numpy().shape[0]):
            one = cells_in_neighborhood_5.numpy()[one_index]
            if one in cells_in_neighborhood_4.numpy():
                from_5_4.append(one_index)
                
        for one_index in range(cells_in_neighborhood_4.numpy().shape[0]):
            one = cells_in_neighborhood_4.numpy()[one_index]
            if one in cells_in_neighborhood_3.numpy():
                from_4_3.append(one_index)
                
#         for one_index in range(cells_in_neighborhood_3.numpy().shape[0]):
#             one = cells_in_neighborhood_3.numpy()[one_index]
#             if one in cells_in_neighborhood_0.numpy():
#                 from_3_0.append(one_index)
                
        from_5_4 = torch.tensor(np.array(from_5_4)).reshape(1, len(from_5_4))
        from_4_3 = torch.tensor(np.array(from_4_3)).reshape(1, len(from_4_3))
#         from_3_0 = torch.tensor(np.array(from_3_0)).reshape(1, len(from_3_0))
        
        cells_in_spot_tr = np.array([
            # Find index within cells_in_neighborhood of corresponding index
            # np.where returns a tuple of parallel arrays. In this case, it is
            # a tuple of length 1 with an array of length 1.
#             torch.where(cells_in_neighborhood_0 == cell_in_spot)[0][0]
            torch.where(cells_in_neighborhood_3 == cell_in_spot)[0][0]
            for cell_in_spot in cells_in_spot
        ])
        
#         edge_index_0 = torch.stack(torch.where(adjacency_matrix_subgraph_0))
        edge_index_3 = torch.stack(torch.where(adjacency_matrix_subgraph_3))
        edge_index_4 = torch.stack(torch.where(adjacency_matrix_subgraph_4))
        edge_index_5 = torch.stack(torch.where(adjacency_matrix_subgraph_4))
        
        counts = self.counts[index]
        
        return (
#                 cell_embeddings_0,
                cell_embeddings_3,
                cell_embeddings_4, 
                cell_embeddings_5, 
#                 edge_index_0, 
                edge_index_3, 
                edge_index_4, 
                edge_index_5, 
#                 cells_in_neighborhood_5,
                from_5_4,
                from_4_3,
#                 from_3_0,
                cells_in_spot_tr, counts)


In [14]:
datasets = {}

for slide_id in visium.keys():
# for slide_id in ['A1']:
    counts, locations = visium[slide_id]
    counts_tensor = torch.tensor(counts.values)
    
    included_spots = {i for i in range(counts.shape[0]) if len(cells_by_spot[slide_id][i]) > 0}
    
    datasets[slide_id] = CellSubgraphDataset(
        embeddings[slide_id],
        [x for i, x in enumerate(cells_by_spot[slide_id]) if i in included_spots],
#         [x for i, x in enumerate(neighborhoods_by_spot_0[slide_id]) if i in included_spots],
        [x for i, x in enumerate(neighborhoods_by_spot_3[slide_id]) if i in included_spots],
        [x for i, x in enumerate(neighborhoods_by_spot_4[slide_id]) if i in included_spots],
        [x for i, x in enumerate(neighborhoods_by_spot_5[slide_id]) if i in included_spots],
        adjacency_matrices[slide_id],
        torch.stack([
            counts_tensor[i] for i in range(counts_tensor.shape[0]) if i in included_spots
        ]),
    )

In [15]:
embeddings['A1'][0]

array([ 2.15851879e+00, -4.54090506e-01, -6.12399161e-01,  1.89687788e-01,
       -9.67256129e-01, -1.10435426e+00, -2.28791341e-01,  3.81311059e-01,
        2.47201607e-01,  6.06539100e-02,  1.80102062e+00, -7.77545333e-01,
       -9.12513733e-01,  2.48170829e+00,  1.33282459e+00,  2.18308955e-01,
        8.61798525e-01, -8.41341257e-01,  1.12776995e+00, -2.00717402e+00,
        2.39386654e+00, -4.62100357e-01,  8.92302275e-01, -9.67932999e-01,
       -7.30184093e-03, -6.82025313e-01,  6.85053885e-01, -2.18302560e+00,
        1.76220424e-02,  3.15338087e+00,  5.49499393e-01,  7.19597399e-01,
       -4.94046718e-01, -2.17067480e+00, -4.16925043e-01, -9.16610181e-01,
       -3.26585889e-01,  5.22920012e-01, -2.85119891e+00,  7.54085302e-01,
       -5.71950339e-02, -2.84478450e+00,  1.21916068e+00, -1.39333606e+00,
       -4.67354476e-01, -5.27500153e-01, -2.05268681e-01,  1.40883172e+00,
       -5.37662566e-01, -2.29343221e-01,  4.31148261e-02, -2.85426378e-01,
        2.16352797e+00, -

In [16]:
import torch_geometric as pyg

train_dataset = torch.utils.data.ConcatDataset([
    datasets['A1'],
    datasets['B1'],
    datasets['C1'],
])

valid_dataset = datasets['D1']

train_loader = torch.utils.data.DataLoader(train_dataset, 1, shuffle=True, num_workers=80)
valid_loader = torch.utils.data.DataLoader(valid_dataset, 1, shuffle=True, num_workers=80)

# train_loader = torch.utils.data.DataLoader(train_dataset, 32, shuffle=True, num_workers=80)
# valid_loader = torch.utils.data.DataLoader(valid_dataset, 32, shuffle=True, num_workers=80)

dataset = {'train': train_loader, 'val': valid_loader}

In [17]:
def try_all_gpus():
    """Return all available GPUs, or [cpu(),] if no GPU exists.
    Defined in :numref:`sec_use_gpu`"""
    devices = [torch.device(f'cuda:{i}')
             for i in range(torch.cuda.device_count())]
    return devices if devices else [torch.device('cpu')]

In [18]:
class CellSubgraphModel(torch.nn.Module):
    def __init__(self, n_genes, d_model):
        super(CellSubgraphModel, self).__init__()
        self.n_genes = n_genes
        self.d_model = d_model
        
#         self.graph_layer_5_4 = pyg.nn.GCNConv(d_model, d_model)
#         self.graph_layer_4_3 = pyg.nn.GCNConv(d_model, d_model)
#         self.graph_layer_3_1 = pyg.nn.GCNConv(d_model, d_model)
#         self.graph_layer_3_2 = pyg.nn.GCNConv(d_model, d_model)
#         self.graph_layer_3_0 = pyg.nn.GCNConv(d_model, d_model)
#         self.graph_layer_0 = pyg.nn.GCNConv(d_model, d_model)
        self.graph_layer_5_4 = pyg.nn.GATConv(d_model, d_model)
        self.graph_layer_4_3 = pyg.nn.GATConv(d_model, d_model)
        self.graph_layer_3_1 = pyg.nn.GATConv(d_model, d_model)
        self.graph_layer_3_2 = pyg.nn.GATConv(d_model, d_model)
        self.graph_layer_3_0 = pyg.nn.GATConv(d_model, d_model)
        self.graph_layer_0 = pyg.nn.GATConv(d_model, d_model)
        
        self.gene_head = nn.Linear(d_model, n_genes)

    def forward(self, 
                embeddings_3, 
                embeddings_4,
                embeddings_5,
                edge_index_3, 
                edge_index_4, 
                edge_index_5, 
                from_5_4,
                from_4_3,
                covered_indexes, counts,
                   ):
        
        embeddings_5 = self.graph_layer_5_4(embeddings_5[0], edge_index_5[0])
        embeddings_4 = embeddings_5[from_5_4[0][0]]
        embeddings_4 = self.graph_layer_4_3(embeddings_4, edge_index_4[0])
        embeddings_3 = embeddings_4[from_4_3[0][0]]
        embeddings_3 = self.graph_layer_3_1(embeddings_3, edge_index_3[0])
        embeddings_3 = self.graph_layer_3_2(embeddings_3, edge_index_3[0])
        embeddings_3 = self.graph_layer_3_0(embeddings_3, edge_index_3[0])
        
        # Final activation is softplus, to ensure that the results are positive
        cell_predictions = F.softplus(self.gene_head(embeddings_3))
#         spot_prediction_test = cell_predictions[covered_indexes]
        spot_prediction = cell_predictions[covered_indexes].mean(dim=0).mean(dim=0)
        spot_prediction = spot_prediction.reshape(1, spot_prediction.shape[0])
#         print('cell_predictions')
#         print(cell_predictions)
#         print(cell_predictions.shape)
#         print('spot_prediction_test')
#         print(spot_prediction_test)
#         print(spot_prediction_test.shape)
#         print('spot_prediction')
#         print(spot_prediction)
#         print(spot_prediction.shape)
#         print('end')
        return spot_prediction

In [19]:
def train_model(model, 
#                 criterion, 
                optimizer, 
#                 scheduler, 
                out_dir, out_file, num_epochs=10):
    
    devices = try_all_gpus()
#     model = nn.DataParallel(model, device_ids=devices)
    model.to(devices[1])
    
    t_0 = time.time()

    best_model = copy.deepcopy(model.state_dict())
    best_loss = 100000

    for epoch in range(num_epochs):
        print(f"Epoch {epoch}")
        for phase in ['train', 'val',]:
            
            model.train() if phase == 'train' else model.eval()
#             Y = []
#             Y_Pred = []
#             Y_Pred_prob = []
            loss_list = []
            for one_batch in tqdm(dataset[phase]):
#                 print(type(one_batch))
#                 print(len(one_batch))
#                 print(one_batch)
                embeddings_3, embeddings_4,embeddings_5,edge_index_3, edge_index_4, edge_index_5, from_5_4,from_4_3, covered_indexes, counts = one_batch
                embeddings_3 = embeddings_3.to(devices[1])
                embeddings_4 = embeddings_4.to(devices[1])
                embeddings_5 = embeddings_5.to(devices[1])
                edge_index_3 = edge_index_3.to(devices[1])
                edge_index_4 = edge_index_4.to(devices[1])
                edge_index_5 = edge_index_5.to(devices[1])
                from_5_4 = from_5_4.to(devices[1])
                from_4_3 = from_4_3.to(devices[1])
                covered_indexes = covered_indexes.to(devices[1])
                counts = counts.to(devices[1])
            
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(embeddings_3, embeddings_4,embeddings_5,edge_index_3, edge_index_4, edge_index_5, from_5_4,from_4_3, covered_indexes, counts)
#                     _, preds = torch.max(outputs, 1)
#                     print(outputs)
#                     print(outputs.shape)
#                     print(counts)
#                     print(counts.shape)
#                     vfxbvdfbgf
                    loss = F.mse_loss(torch.log(1 + outputs), torch.log(1 + counts))
    
                    loss_list.append(float(loss))
    
#                     loss = criterion(outputs, target)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
#                     Y.append(target.cpu())
#                     Y_Pred.append(preds.cpu())
#                     Y_Pred_prob.append(outputs.cpu().detach().numpy())
        
#             if phase == 'train': scheduler.step()
                
#             f1 = f1_score(np.hstack(Y),np.hstack(Y_Pred),average='weighted')
#             roc = roc_auc_score(np.hstack(Y),np.hstack(Y_Pred))
#             b_accuracy = balanced_accuracy_score(np.hstack(Y),np.hstack(Y_Pred))
#             accuracy = accuracy_score(np.hstack(Y),np.hstack(Y_Pred))
   
            print('loss:', sum(loss_list)/len(loss_list))
            if phase == 'val' and sum(loss_list)/len(loss_list) < best_loss:
                best_loss = sum(loss_list)/len(loss_list)
                best_model = copy.deepcopy(model.state_dict())
                
            out_file_temp = 'model_epoch_' + str(epoch)
            torch.save(model.state_dict(),os.path.join(out_dir,f"{out_file_temp}.params"))
            
        print()

    dt = time.time() - t_0
    print(f"Training completed in {dt//60}m {dt%60}s")
    print(f'Best Loss: {best_loss}')

    model.load_state_dict(best_model)
    torch.save(model.state_dict(),os.path.join(out_dir,f"{out_file}.params"))
    return model

In [20]:
model = CellSubgraphModel(n_genes=17943, d_model=1000)

def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
model.apply(init_weights)

# criterion = nn.CrossEntropyLoss()
# optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
optimizer_ft = torch.optim.Adam(model.parameters(), lr=1e-4)
# exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.05)

In [21]:
model_ft = train_model(model, 
#                        criterion, 
                       optimizer_ft, 
#                        exp_lr_scheduler,
                       out_dir="../../../all_code/GNN_models/GNN_model_1",
                       out_file="model_best",
                       num_epochs=20)

Epoch 0


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [03:54<00:00, 51.10it/s]


loss: 0.12021787484148667


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:42<00:00, 76.11it/s]


loss: 0.07111460913467166

Epoch 1


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:09<00:00, 48.00it/s]


loss: 0.11863918954821356


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:39<00:00, 81.32it/s]


loss: 0.08188189421826512

Epoch 2


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:11<00:00, 47.66it/s]


loss: 0.11492467517675008


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:44<00:00, 72.97it/s]


loss: 0.07807811781375595

Epoch 3


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:22<00:00, 45.78it/s]


loss: 0.1143062487378894


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:42<00:00, 76.77it/s]


loss: 0.07116562139218682

Epoch 4


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:21<00:00, 45.88it/s]


loss: 0.11392337657148889


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:38<00:00, 83.98it/s]


loss: 0.08649726494994604

Epoch 5


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:02<00:00, 49.55it/s]


loss: 0.11333261181683138


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:42<00:00, 75.80it/s]


loss: 0.07718033263273076

Epoch 6


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [03:44<00:00, 53.45it/s]


loss: 0.11322055100322832


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:42<00:00, 75.35it/s]


loss: 0.07305909185393686

Epoch 7


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:14<00:00, 47.24it/s]


loss: 0.11270781120080156


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:37<00:00, 86.94it/s]


loss: 0.07281666734552075

Epoch 8


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [03:54<00:00, 51.09it/s]


loss: 0.11265320841315839


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:42<00:00, 76.07it/s]


loss: 0.07605009878026599

Epoch 9


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [03:44<00:00, 53.45it/s]


loss: 0.11237259054903309


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:42<00:00, 75.42it/s]


loss: 0.08258629261803038

Epoch 10


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [03:49<00:00, 52.36it/s]


loss: 0.1124705636493022


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:39<00:00, 81.76it/s]


loss: 0.06981616784555672

Epoch 11


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:08<00:00, 48.34it/s]


loss: 0.11262176638963715


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:43<00:00, 74.53it/s]


loss: 0.06906085066862709

Epoch 12


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [03:58<00:00, 50.40it/s]


loss: 0.11234799706661035


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:42<00:00, 76.26it/s]


loss: 0.07708392631018184

Epoch 13


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:16<00:00, 46.79it/s]


loss: 0.11222633499778517


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:44<00:00, 72.97it/s]


loss: 0.07610986120180495

Epoch 14


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:15<00:00, 46.96it/s]


loss: 0.11213081183221242


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:43<00:00, 74.40it/s]


loss: 0.07460523741538774

Epoch 15


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:12<00:00, 47.49it/s]


loss: 0.111873963202708


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:42<00:00, 76.32it/s]


loss: 0.07483529166369192

Epoch 16


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:15<00:00, 46.90it/s]


loss: 0.11186861042465793


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:42<00:00, 75.69it/s]


loss: 0.0689109206835896

Epoch 17


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:14<00:00, 47.20it/s]


loss: 0.11188188749574315


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:42<00:00, 75.52it/s]


loss: 0.07046174862344823

Epoch 18


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:22<00:00, 45.71it/s]


loss: 0.1117191539054566


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:42<00:00, 75.78it/s]


loss: 0.07281846721975899

Epoch 19


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:11<00:00, 47.77it/s]


loss: 0.11178864478511644


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:42<00:00, 75.37it/s]


loss: 0.07245605984082593

Training completed in 96.0m 52.93061804771423s
Best Loss: 0.016596835106611252


In [22]:
model = CellSubgraphModel(n_genes=17943, d_model=1000)

def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
model.apply(init_weights)

# criterion = nn.CrossEntropyLoss()
# optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
optimizer_ft = torch.optim.Adam(model.parameters(), lr=5e-4)
# exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.05)

In [23]:
model_ft = train_model(model, 
#                        criterion, 
                       optimizer_ft, 
#                        exp_lr_scheduler,
                       out_dir="../../../all_code/GNN_models/GNN_model_2",
                       out_file="model_best",
                       num_epochs=20)

Epoch 0


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:14<00:00, 47.11it/s]


loss: 0.2022596218392077


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:43<00:00, 74.64it/s]


loss: 0.10039035354424918

Epoch 1


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:11<00:00, 47.70it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:42<00:00, 76.83it/s]


loss: 0.10039035354424918

Epoch 2


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:24<00:00, 45.33it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:43<00:00, 75.03it/s]


loss: 0.10039035354424918

Epoch 3


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:18<00:00, 46.45it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:42<00:00, 75.92it/s]


loss: 0.10039035354424918

Epoch 4


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:21<00:00, 45.95it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:43<00:00, 74.58it/s]


loss: 0.10039035354424918

Epoch 5


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:10<00:00, 47.85it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:43<00:00, 74.07it/s]


loss: 0.10039035354424918

Epoch 6


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:22<00:00, 45.70it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:41<00:00, 77.03it/s]


loss: 0.10039035354424918

Epoch 7


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:16<00:00, 46.71it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:43<00:00, 74.75it/s]


loss: 0.10039035354424918

Epoch 8


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:24<00:00, 45.36it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:43<00:00, 74.45it/s]


loss: 0.10039035354424918

Epoch 9


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:28<00:00, 44.72it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:45<00:00, 71.75it/s]


loss: 0.10039035354424918

Epoch 10


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:19<00:00, 46.22it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:41<00:00, 78.09it/s]


loss: 0.10039035354424918

Epoch 11


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:11<00:00, 47.67it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:43<00:00, 74.61it/s]


loss: 0.10039035354424918

Epoch 12


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:16<00:00, 46.81it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:41<00:00, 77.68it/s]


loss: 0.10039035354424918

Epoch 13


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:05<00:00, 48.90it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:54<00:00, 59.52it/s]


loss: 0.10039035354424918

Epoch 14


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:13<00:00, 47.31it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:48<00:00, 66.14it/s]


loss: 0.10039035354424918

Epoch 15


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:06<00:00, 48.58it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:40<00:00, 79.02it/s]


loss: 0.10039035354424918

Epoch 16


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:04<00:00, 49.15it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:50<00:00, 63.59it/s]


loss: 0.10039035354424918

Epoch 17


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:24<00:00, 45.39it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:49<00:00, 65.14it/s]


loss: 0.10039035354424918

Epoch 18


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [04:31<00:00, 44.22it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:48<00:00, 66.91it/s]


loss: 0.10039035354424918

Epoch 19


100%|███████████████████████████████████████████████████████████████████| 11998/11998 [03:44<00:00, 53.36it/s]


loss: 0.20108499122365486


100%|█████████████████████████████████████████████████████████████████████| 3232/3232 [00:37<00:00, 86.57it/s]


loss: 0.10039035354424918

Training completed in 100.0m 39.72027325630188s
Best Loss: 0.003789865644648671


# Test

In [55]:
for i in dataset['train']:
    print(i)
    break

[tensor([[[ 2.5336, -0.5515, -0.8042,  ..., -0.2916, -0.0327,  1.1158],
         [ 2.5496, -0.5587, -0.8136,  ..., -0.2960, -0.0289,  1.1246],
         [ 2.5267, -0.5525, -0.8031,  ..., -0.2925, -0.0433,  1.1142],
         ...,
         [ 2.4233, -0.5167, -0.7456,  ..., -0.2724, -0.0853,  1.0600],
         [ 2.3882, -0.5137, -0.7317,  ..., -0.2704, -0.1135,  1.0465],
         [ 2.4902, -0.5372, -0.7808,  ..., -0.2845, -0.0528,  1.0939]]]), tensor([[[ 2.5336, -0.5515, -0.8042,  ..., -0.2916, -0.0327,  1.1158],
         [ 2.5496, -0.5587, -0.8136,  ..., -0.2960, -0.0289,  1.1246],
         [ 2.5517, -0.5604, -0.8150,  ..., -0.2969, -0.0291,  1.1259],
         ...,
         [ 2.4902, -0.5372, -0.7808,  ..., -0.2845, -0.0528,  1.0939],
         [ 2.4698, -0.5319, -0.7702,  ..., -0.2790, -0.0653,  1.0838],
         [ 2.3867, -0.5097, -0.7284,  ..., -0.2680, -0.1098,  1.0436]]]), tensor([[[ 2.5336, -0.5515, -0.8042,  ..., -0.2916, -0.0327,  1.1158],
         [ 2.5670, -0.5630, -0.8214,  ...,

In [56]:
len(i)

10

In [57]:
gatnn = pyg.nn.GATConv(1000, 1000)

In [58]:
out = gatnn(i[2][0], i[5][0])

In [65]:
out

tensor([[ 2.1633,  0.4447, -0.4929,  ...,  1.1411, -0.4601, -0.3592],
        [ 2.1771,  0.4528, -0.4600,  ...,  1.1350, -0.4759, -0.4231],
        [ 2.1805,  0.4558, -0.4493,  ...,  1.1337, -0.4808, -0.4413],
        ...,
        [ 2.1604,  0.4482, -0.4871,  ...,  1.1415, -0.4626, -0.3596],
        [ 2.1509,  0.4416, -0.5112,  ...,  1.1460, -0.4508, -0.3142],
        [ 2.1669,  0.4510, -0.4741,  ...,  1.1385, -0.4691, -0.3873]],
       grad_fn=<AddBackward0>)

In [70]:
out[i[-4][0][0]].shape

torch.Size([100, 1000])

In [60]:
out[i[-4].numpy()[0]].shape

torch.Size([100, 1000])

In [62]:
i[1].shape

torch.Size([1, 100, 1000])

In [64]:
i[2].shape

torch.Size([1, 158, 1000])

In [137]:
test = torch.tensor(0.5422)

In [140]:
float(test)

0.5422000288963318

In [80]:
test = torch.tensor([[0.4043, 0.5290, 0.5692, 0.8819, 0.8630, 0.8656],
        [0.4053, 0.5295, 0.5710, 0.8820, 0.8632, 0.8651]])

In [82]:
test.mean(dim = 0)

tensor([0.4048, 0.5293, 0.5701, 0.8820, 0.8631, 0.8654])

In [113]:
test = torch.tensor([0.8182, 0.3813, 0.9209, 0.3871, 1.0588, 0.8985])

In [114]:
test.shape

torch.Size([6])

In [124]:
test = test.reshape(1, 1, 6)

In [127]:
test.shape[2]

6

In [57]:
# Should not use embedding_3 as final, should more centered to the spot_cells

class CellSubgraphModelV2(pl.LightningModule):
    def __init__(self, n_genes: int, GraphLayer: type, d_model: int, n_layers: int):
        super().__init__()
        
        self.n_genes = n_genes
        self.d_model = d_model
        self.n_layers = n_layers
        
        self.graph_layer_5_4 = GraphLayer(d_model, d_model)
        self.graph_layer_4_3 = GraphLayer(d_model, d_model)
        self.graph_layer_3_1 = GraphLayer(d_model, d_model)
        self.graph_layer_3_2 = GraphLayer(d_model, d_model)
        self.graph_layer_3_0 = GraphLayer(d_model, d_model)
        self.graph_layer_0 = GraphLayer(d_model, d_model)
        
#         graph_layers = [
#             GraphLayer(d_model, d_model) for _ in range(n_layers)
#         ]
#         self.graph_layers = nn.ModuleList(graph_layers)
        
        self.gene_head = nn.Linear(d_model, n_genes)
        
        self.save_hyperparameters()
        
    def forward(self, 
#                 embeddings_0, 
                embeddings_3, 
                embeddings_4,
                embeddings_5,
#                 edge_index_0,
                edge_index_3, 
                edge_index_4, 
                edge_index_5, 
                from_5_4,
                from_4_3,
#                 from_3_0,
                   ):
        
        embeddings_5 = self.graph_layer_5_4(embeddings_5, edge_index_5)
        # multiple layers for each step, try 2 first
#         print(embeddings_5)
#         print(from_5_4)
        embeddings_4 = embeddings_5[from_5_4.numpy()[0]]
#         print(embeddings_4)
        embeddings_4 = self.graph_layer_4_3(embeddings_4, edge_index_4)
        embeddings_3 = embeddings_4[from_4_3.numpy()[0]]
        embeddings_3 = self.graph_layer_3_1(embeddings_3, edge_index_3)
        embeddings_3 = self.graph_layer_3_2(embeddings_3, edge_index_3)
        embeddings_3 = self.graph_layer_3_0(embeddings_3, edge_index_3)
#         embeddings_0 = embeddings_3[from_3_0.numpy()[0]]
#         embeddings_0 = self.graph_layer_0(embeddings_0, edge_index_0)
        
#         for layer_i in range(self.n_layers):
#             embeddings = self.graph_layers[layer_i](embeddings, edge_index)
        
        # Final activation is softplus, to ensure that the results are positive
        return F.softplus(self.gene_head(embeddings_3))
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)
    
    # Predicts log1p. To find raw counts, use expm1 
    def spot_prediction(self, instance):
#         embeddings_0, embeddings_3, embeddings_4,embeddings_5,edge_index_0, edge_index_3, edge_index_4, edge_index_5, from_5_4,from_4_3, from_3_0, covered_indexes, counts = instance
        embeddings_3, embeddings_4,embeddings_5,edge_index_3, edge_index_4, edge_index_5, from_5_4,from_4_3, covered_indexes, counts = instance
        
        cell_predictions = self.forward(
#                                         embeddings_0,
                                        embeddings_3, 
                                        embeddings_4,
                                        embeddings_5,
#                                         edge_index_0,
                                        edge_index_3,
                                        edge_index_4, 
                                        edge_index_5, 
                                        from_5_4,
                                        from_4_3,
#                                         from_3_0,
                                            )
        spot_prediction = cell_predictions[covered_indexes].mean(dim=0)
        
        return spot_prediction
    
    def training_step(self, batch, batch_idx):
        embeddings_3, embeddings_4,embeddings_5,edge_index_3, edge_index_4, edge_index_5, from_5_4,from_4_3, covered_indexes, counts = instance = [x[0] for x in batch]
        
        loss = F.mse_loss(torch.log(1 + self.spot_prediction(instance)), torch.log(1 + counts))
        
        self.log('train_loss', loss.item(), prog_bar=True, on_epoch=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        embeddings_3, embeddings_4,embeddings_5,edge_index_3, edge_index_4, edge_index_5, from_5_4,from_4_3, covered_indexes, counts = instance = [x[0] for x in batch]
        
        loss = F.mse_loss(torch.log(1 + self.spot_prediction(instance)), torch.log(1 + counts))
        
        self.log('validation_loss', loss.item(), prog_bar=True, on_epoch=True)
        
        return loss
    
    def test_step(self, batch, batch_idx):
        embeddings_3, embeddings_4,embeddings_5,edge_index_3, edge_index_4, edge_index_5, from_5_4,from_4_3, covered_indexes, counts = instance = [x[0] for x in batch]
        
        loss = F.mse_loss(torch.log(1 + self.spot_prediction(instance)), torch.log(1 + counts))
        
        self.log('test_loss', loss.item(), prog_bar=True, on_epoch=True)
        
        return loss
    
    def predict_step(self, batch, batch_idx):
        return self.spot_prediction([x[0] for x in batch])

In [61]:
# 3 layers for 3 hops
model = CellSubgraphModelV2(n_genes=17943, GraphLayer=pyg.nn.GATConv, d_model=1000, n_layers=3)

trainer = pl.Trainer(
    max_epochs=32,
)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [62]:
trainer.fit(model, train_loader, valid_loader)


  | Name            | Type    | Params
--------------------------------------------
0 | graph_layer_5_4 | GATConv | 1.0 M 
1 | graph_layer_4_3 | GATConv | 1.0 M 
2 | graph_layer_3_1 | GATConv | 1.0 M 
3 | graph_layer_3_2 | GATConv | 1.0 M 
4 | graph_layer_3_0 | GATConv | 1.0 M 
5 | graph_layer_0   | GATConv | 1.0 M 
6 | gene_head       | Linear  | 18.0 M
--------------------------------------------
24.0 M    Trainable params
0         Non-trainable params
24.0 M    Total params
95.916    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


# Test