In [7]:
import spacr
import pickle
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader, NeighborSampler
from sklearn.metrics import mean_squared_error
from torch_geometric.nn import SAGEConv
from torch.nn import Linear

def generate_graph(gene_df, cell_df):
    # Ensure the dataframes are sorted to maintain consistent ordering
    gene_df = gene_df.sort_values(by=['gene_id', 'well_id'])
    cell_df = cell_df.sort_values(by=['cell_id', 'well_id'])
    
    # Map gene_id and cell_id to unique indices for graph construction
    gene_id_to_index = {gene_id: idx for idx, gene_id in enumerate(gene_df.gene_id.unique())}
    cell_id_to_index = {cell_id: idx+len(gene_id_to_index) for idx, cell_id in enumerate(cell_df.cell_id.unique())}

    # Creating a mapping from well_id to gene fractions
    well_to_gene_fractions = gene_df.groupby('well_id').apply(lambda x: dict(zip(x.gene_id, x.well_read_fraction))).to_dict()

    # Prepare edges and edge attributes
    edge_index = []
    edge_attr = []
    # Cell scores as labels
    cell_scores = []

    for _, cell_row in cell_df.iterrows():
        cell_idx = cell_id_to_index[cell_row['cell_id']]
        cell_scores.append(cell_row['score'])  # Add cell score to labels list
        if cell_row['well_id'] in well_to_gene_fractions:
            for gene_id, fraction in well_to_gene_fractions[cell_row['well_id']].items():
                if gene_id in gene_id_to_index:  # Check if gene is in the index map
                    gene_idx = gene_id_to_index[gene_id]
                    edge_index.append([cell_idx, gene_idx])  # Note the order: cell to gene
                    edge_attr.append(fraction)

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).view(-1, 1)
    cell_scores = torch.tensor(cell_scores, dtype=torch.float)

    # Node features for genes could be one-hot encoded or gene fractions; placeholder for now
    gene_features = torch.eye(len(gene_id_to_index))  # One-hot encoding of genes as placeholder
    cell_features = torch.zeros(len(cell_id_to_index), gene_features.size(1))  # Placeholder for cell features

    # Combine features for all nodes
    x = torch.cat([cell_features, gene_features], dim=0)  # Ensure cell features come first to align with cell_scores

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=cell_scores)

    # Return both the graph data and the gene_id_to_index dictionary
    return data, gene_id_to_index

    
class GNN(torch.nn.Module):
    def __init__(self, feature_size):
        super(GNN, self).__init__()
        self.conv1 = SAGEConv(feature_size, 16)
        self.conv2 = SAGEConv(16, 32)
        self.out = Linear(32, 1)

    def forward(self, x, edge_index=None, adjs=None):
        if adjs is None:
            # Assume full graph processing for feature importance evaluation
            x = F.relu(self.conv1(x, edge_index))
            x = F.dropout(x, training=self.training)
            x = F.relu(self.conv2(x, edge_index))
        else:
            # Batch processing with NeighborSampler
            for i, (edge_index, _, size) in enumerate(adjs):
                x_target = x[:size[1]]  # Target nodes for this layer
                if i == 0:
                    x = self.conv1((x, x_target), edge_index)
                else:
                    x = self.conv2((x, x_target), edge_index)
                x = F.relu(x)
                x = F.dropout(x, training=self.training)
        
        x = self.out(x)
        return x

def train_gnn(graph_data, model_save_path, feature_size, lr, epochs, batch_size, size):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    model = GNN(feature_size=feature_size).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    graph_data = graph_data.to(device)
    train_loader = NeighborSampler(graph_data.edge_index,
                                   node_idx=None,
                                   sizes=[size, size],
                                   batch_size=batch_size,
                                   shuffle=True,
                                   drop_last=False)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch_data in train_loader:
            batch_size, n_id, adjs = batch_data
            adjs = [adj.to(device) for adj in adjs]  # Ensure adjs are on the correct device
            optimizer.zero_grad()
            out = model(x=graph_data.x[n_id], adjs=adjs)  # Correctly pass adjs
            loss = criterion(out, graph_data.y[n_id[:batch_size]])
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}', end='\r', flush=True)
    del optimizer, criterion, train_loader, batch_data
    torch.cuda.empty_cache()
    torch.save(model.state_dict(), model_save_path)
    return model




In [6]:
model_save_path = '/home/olafsson/Desktop/gnn/mode.pth'
compute_gene_importance(model,
                        graph_data,
                        model_save_path,
                        n_permutations=10)

graph_data.y shape: torch.Size([710938])
original_preds shape: torch.Size([712070])


ValueError: Found input variables with inconsistent numbers of samples: [710938, 712070]

In [8]:
sequencing = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/sequencing.csv'
score = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/dv_cell.csv'

# Example loading step
gene_df = pd.read_csv(sequencing)
cell_df = pd.read_csv(score)

gene_df = gene_df.rename(columns={"prc": "well_id", "grna": "gene_id", "count": "read_count"})
gene_df = gene_df.drop(columns=['Unnamed: 0', 'plate', 'row', 'col', 'grna_seq', 'gene'])
total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
gene_df = gene_df.merge(total_reads_per_well, on='well_id')
gene_df['well_read_fraction'] = gene_df['read_count']/gene_df['total_reads']
gene_df = gene_df.drop(columns=['read_count', 'total_reads'])

cell_df = cell_df.rename(columns={"prcfo": "cell_id", "prc": "well_id", "pred": "score"})
cell_df = cell_df.drop(columns=['parasite_area', 'parasite_area', 'recruitment'])
display(gene_df)
display(cell_df)
feature_size = len(gene_df['gene_id'].unique())
feature_size

Unnamed: 0,gene_id,well_id,well_read_fraction
0,TGGT1_313050_2,p1_r1_c1,0.274431
1,TGGT1_207865_1,p1_r1_c1,0.249161
2,TGGT1_409250_65,p1_r1_c1,0.159549
3,TGGT1_212410_369,p1_r1_c1,0.102387
4,TGGT1_239600_2,p1_r1_c1,0.068351
...,...,...,...
17735,TGGT1_269885A_3,p2_r5_c10,0.061794
17736,TGGT1_235140_3,p2_r5_c2,0.259861
17737,TGGT1_000000_14,p2_r5_c2,0.252900
17738,TGGT1_310010_1,p2_r5_c2,0.250580


Unnamed: 0,cell_id,well_id,score
0,p1_r10_c10_f1_o123,p1_r10_c10,0.999409
1,p1_r10_c10_f1_o159,p1_r10_c10,0.998264
2,p1_r10_c10_f2_o120,p1_r10_c10,0.998019
3,p1_r10_c10_f3_o48,p1_r10_c10,0.998897
4,p1_r10_c10_f3_o70,p1_r10_c10,0.013018
...,...,...,...
710933,p9_r9_c9_f9_o50,p9_r9_c9,0.041357
710934,p9_r9_c9_f9_o80,p9_r9_c9,0.001581
710935,p9_r9_c9_f9_o81,p9_r9_c9,0.005244
710936,p9_r9_c9_f9_o86,p9_r9_c9,0.406849


1132

In [4]:
print(f'feature_size: {feature_size}')
graph_data, gene_id_to_index = generate_graph(gene_df,cell_df)
dict_file_path = '/home/olafsson/Desktop/gnn/dict.pth'

with open(dict_file_path, 'wb') as file:
    pickle.dump(gene_id_to_index, file)

feature_size: 1132


In [9]:
model = train_gnn(graph_data,
                  model_save_path='/home/olafsson/Desktop/gnn/mode.pth',
                  feature_size=feature_size,
                  lr=0.001,
                  epochs=10000,
                  batch_size=2048,
                  size=2000)

Epoch 10000, Loss: 0.1930284325854253

In [None]:
[2K
[2K


In [None]:
import networkx as nx
import matplotlib.pyplot as plt

def visualize_graph(data):
    G = nx.Graph()
    edge_index = data.edge_index.cpu().numpy()
    # Add edges
    for i in range(edge_index.shape[1]):
        source = edge_index[0, i]
        target = edge_index[1, i]
        G.add_edge(source, target)

    plt.figure(figsize=(10, 10))
    nx.draw(G, with_labels=True, node_color='lightblue', edge_color='gray', node_size=50, font_size=6)
    plt.show()
    
# Define a simple GNN model
class GNN(torch.nn.Module):
    def __init__(self, feature_size):
        super(GNN, self).__init__()
        self.conv1 = SAGEConv(feature_size, 16)
        self.conv2 = SAGEConv(16, 32)
        self.out = Linear(32, 1)  # The output size here aligns with the final layer's output

    def forward(self, x, adjs):
        # adjs is a list of tuples provided by NeighborSampler, each for a layer
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes for this layer
            if i == 0:
                x = self.conv1((x, x_target), edge_index)
            else:
                x = self.conv2((x, x_target), edge_index)
            x = F.relu(x)
            x = F.dropout(x, training=self.training)
        
        # Apply the final linear transformation
        x = self.out(x)
        return x
    
    
def train_gnn(graph_data, feature_size, lr=0.01, epochs=100, batch_size=1024, size=100):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    model = GNN(feature_size=feature_size).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    # Assume graph_data already in device
    graph_data = graph_data.to(device)

    # Set up the NeighborSampler
    # Here, you need to define the sizes of the neighborhoods for each layer
    # This is just an example configuration
    train_loader = NeighborSampler(
        graph_data.edge_index,
        node_idx=None,  # sample from the entire graph; adjust as needed
        sizes=[size, size],  # Sizes of the neighborhoods for each layer
        batch_size=batch_size,  # Adjust based on your GPU memory and task
        shuffle=True,
        drop_last=False,
    )

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch_size, n_id, adjs in train_loader:
            adjs = [adj.to(device) for adj in adjs]
            optimizer.zero_grad()
            out = model(graph_data.x[n_id], adjs)  # Note the adjustment here
            loss = criterion(out, graph_data.y[n_id[:batch_size]])
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        print(f'Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}', end='\r', flush=True)
    return model

def generate_graph_v1(gene_df, cell_df):
    # Ensure the dataframes are sorted to maintain consistent ordering
    gene_df = gene_df.sort_values(by=['gene_id', 'well_id'])
    cell_df = cell_df.sort_values(by=['cell_id', 'well_id'])
    
    # Map gene_id and cell_id to unique indices for graph construction
    gene_id_to_index = {gene_id: idx for idx, gene_id in enumerate(gene_df.gene_id.unique())}
    cell_id_to_index = {cell_id: idx+len(gene_id_to_index) for idx, cell_id in enumerate(cell_df.cell_id.unique())}

    # Creating a mapping from well_id to gene fractions
    well_to_gene_fractions = gene_df.groupby('well_id').apply(lambda x: dict(zip(x.gene_id, x.well_read_fraction))).to_dict()

    # Prepare edges and edge attributes
    edge_index = []
    edge_attr = []
    # Cell scores as labels
    cell_scores = []

    for _, cell_row in cell_df.iterrows():
        cell_idx = cell_id_to_index[cell_row['cell_id']]
        cell_scores.append(cell_row['score'])  # Add cell score to labels list
        if cell_row['well_id'] in well_to_gene_fractions:
            for gene_id, fraction in well_to_gene_fractions[cell_row['well_id']].items():
                if gene_id in gene_id_to_index:  # Check if gene is in the index map
                    gene_idx = gene_id_to_index[gene_id]
                    edge_index.append([cell_idx, gene_idx])  # Note the order: cell to gene
                    edge_attr.append(fraction)

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).view(-1, 1)
    cell_scores = torch.tensor(cell_scores, dtype=torch.float)

    # Node features for genes could be one-hot encoded or gene fractions; placeholder for now
    gene_features = torch.eye(len(gene_id_to_index))  # One-hot encoding of genes as placeholder
    cell_features = torch.zeros(len(cell_id_to_index), gene_features.size(1))  # Placeholder for cell features

    # Combine features for all nodes
    x = torch.cat([cell_features, gene_features], dim=0)  # Ensure cell features come first to align with cell_scores

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=cell_scores)

    return data

In [None]:
'pip install torch-sparse -f https://data.pyg.org/whl/torch-2.2.1+cu121.html'
'pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.1+cu121.html'
