In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import torch.optim as optim


def generate_graphs(sequencing, scores, cell_min, gene_min_read):
    # Load and preprocess sequencing (gene) data
    gene_df = pd.read_csv(sequencing)
    gene_df = gene_df.rename(columns={'prc': 'well_id', 'grna': 'gene_id', 'count': 'read_count'})
    # Filter out genes with read counts less than gene_min_read
    gene_df = gene_df[gene_df['read_count'] >= gene_min_read]
    total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
    gene_df = gene_df.merge(total_reads_per_well, on='well_id')
    gene_df['well_read_fraction'] = gene_df['read_count'] / gene_df['total_reads']

    # Mapping genes to indices
    gene_id_to_index = {gene: i for i, gene in enumerate(gene_df['gene_id'].unique())}
    feature_size = len(gene_id_to_index)

    # Load and preprocess cell score data
    cell_df = pd.read_csv(scores)
    cell_df = cell_df[['prcfo', 'prc', 'pred']].rename(columns={'prcfo': 'cell_id', 'prc': 'well_id', 'pred': 'score'})

    graphs = []
    for well_id in pd.unique(gene_df['well_id']):
        well_genes = gene_df[gene_df['well_id'] == well_id]
        well_cells = cell_df[cell_df['well_id'] == well_id]
        
        # Skip this well if the number of cells is less than cell_min
        if well_cells.empty or well_genes.empty or len(well_cells) < cell_min:
            continue
        
        # Prepare gene features (well_read_fraction)
        gene_features = torch.tensor(well_genes['well_read_fraction'].values, dtype=torch.float).view(-1, 1)
        # Prepare cell features (scores)
        cell_features = torch.tensor(well_cells['score'].values, dtype=torch.float).view(-1, 1)

        num_genes = gene_features.size(0)
        num_cells = cell_features.size(0)
        num_nodes = num_genes + num_cells
        
        # Create dense adjacency matrix connecting each cell to all genes
        adj = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
        adj[num_genes:, :num_genes] = 1  # Assuming cells come after genes in node ordering

        graph = {
            'adjacency_matrix': adj,
            'gene_features': gene_features,
            'cell_features': cell_features,
            'num_cells': num_cells,
            'num_genes': num_genes
        }
        graphs.append(graph)
    print(f'Generated dataset with {len(graphs)} graphs, and {len(gene_id_to_index)} genes')
    return graphs, feature_size, gene_id_to_index

class Attention(nn.Module):
    def __init__(self, feature_dim, attn_dim, dropout_rate=0.1):
        super(Attention, self).__init__()
        self.query = nn.Linear(feature_dim, attn_dim)
        self.key = nn.Linear(feature_dim, attn_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.scale = 1.0 / (attn_dim ** 0.5)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, gene_features, cell_features):
        # Queries come from the cell features
        q = self.query(cell_features)
        # Keys and values come from the gene features
        k = self.key(gene_features)
        v = self.value(gene_features)
        
        # Compute attention weights
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        # Apply dropout to attention weights
        attn_weights = self.dropout(attn_weights)  

        # Apply attention weights to the values
        attn_output = torch.matmul(attn_weights, v)
        
        return attn_output, attn_weights

class GraphTransformer(nn.Module):
    def __init__(self, gene_feature_size, cell_feature_size, hidden_dim, output_dim, attn_dim, dropout_rate=0.1):
        super(GraphTransformer, self).__init__()
        self.gene_transform = nn.Linear(gene_feature_size, hidden_dim)
        self.cell_transform = nn.Linear(cell_feature_size, hidden_dim)
        self.dropout = nn.Dropout(dropout_rate)

        # Attention layer to let each cell attend to all genes
        self.attention = Attention(hidden_dim, attn_dim)

        # This layer is used to transform the combined features after attention
        self.combine_transform = nn.Linear(2 * hidden_dim, hidden_dim)

        # Output layer for predicting cell scores, ensuring it matches the number of cells
        self.cell_output = nn.Linear(hidden_dim, output_dim)

    def forward(self, adjacency_matrix, gene_features, cell_features):
        # Apply initial transformation to gene and cell features
        transformed_gene_features = F.relu(self.gene_transform(gene_features))
        transformed_cell_features = F.relu(self.cell_transform(cell_features))

        # Incorporate attention mechanism
        attn_output, attn_weights = self.attention(transformed_gene_features, transformed_cell_features)

        # Combine the transformed cell features with the attention output features
        combined_cell_features = torch.cat((transformed_cell_features, attn_output), dim=1)
        
        # Apply dropout here as well
        combined_cell_features = self.dropout(combined_cell_features)  

        combined_cell_features = F.relu(self.combine_transform(combined_cell_features))

        # Combine gene and cell features for message passing
        combined_features = torch.cat((transformed_gene_features, combined_cell_features), dim=0)

        # Apply message passing via adjacency matrix multiplication
        message_passed_features = torch.matmul(adjacency_matrix, combined_features)

        # Predict cell scores from the post-message passed cell features
        cell_scores = self.cell_output(message_passed_features[-cell_features.size(0):])

        return cell_scores, attn_weights
    
def train_graph_transformer(graphs, lr=0.01, dropout_rate=0.1, epochs=100, save_fldr='', acc_threshold = 0.1):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GraphTransformer(gene_feature_size=1, cell_feature_size=1, hidden_dim=256, output_dim=1, attn_dim=128, dropout_rate=dropout_rate).to(device)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    training_log = []
    
    accumulate_grad_batches=1
    threshold=acc_threshold
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        total_correct = 0
        total_samples = 0
        optimizer.zero_grad()
        batch_count = 0  # Initialize batch_count
        
        for graph in graphs:
            adjacency_matrix = graph['adjacency_matrix'].to(device)
            gene_features = graph['gene_features'].to(device)
            cell_features = graph['cell_features'].to(device)
            num_cells = graph['num_cells']
            predictions, attn_weights = model(adjacency_matrix, gene_features, cell_features)
            predictions = predictions.squeeze()
            true_scores = cell_features[:num_cells, 0]
            loss = criterion(predictions, true_scores) / accumulate_grad_batches
            loss.backward()

            # Calculate "accuracy"
            with torch.no_grad():
                correct_predictions = (torch.abs(predictions - true_scores) / true_scores <= threshold).sum().item()
                total_correct += correct_predictions
                total_samples += num_cells

            batch_count += 1  # Increment batch_count
            if batch_count % accumulate_grad_batches == 0 or batch_count == len(graphs):
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item() * accumulate_grad_batches
        
        accuracy = total_correct / total_samples
        training_log.append({"Epoch": epoch+1, "Average Loss": total_loss / len(graphs), "Accuracy": accuracy})
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(graphs)}, Accuracy: {accuracy}", end="\r", flush=True)
    
    # Save the training log and model as before
    os.makedirs(save_fldr, exist_ok=True)
    log_path = os.path.join(save_fldr, 'training_log.csv')
    training_log_df = pd.DataFrame(training_log)
    training_log_df.to_csv(log_path, index=False)
    print(f"Training log saved to {log_path}")
    
    model_path = os.path.join(save_fldr, 'model.pth')
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    return model
        
def annotate_cells_with_genes(graphs, model, gene_id_to_index):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    annotated_data = []
    with torch.no_grad():  # Disable gradient computation
        for graph in graphs:
            adjacency_matrix = graph['adjacency_matrix'].to(device)
            gene_features = graph['gene_features'].to(device)
            cell_features = graph['cell_features'].to(device)

            predictions, attn_weights = model(adjacency_matrix, gene_features, cell_features)
            predictions = np.atleast_1d(predictions.squeeze().cpu().numpy())
            attn_weights = np.atleast_2d(attn_weights.squeeze().cpu().numpy())

            if attn_weights.shape[0] != cell_features.size(0):
                # Skip if the first dimension of attn_weights does not match the number of cells
                continue
            
            for cell_idx in range(cell_features.size(0)):
                true_score = cell_features[cell_idx, 0].item()
                predicted_score = predictions[cell_idx]
                most_probable_gene_idx = attn_weights[cell_idx].argmax()
                most_probable_gene_score = attn_weights[cell_idx, most_probable_gene_idx]

                gene_id = list(gene_id_to_index.keys())[most_probable_gene_idx]

                annotated_data.append({
                    "Cell ID": cell_idx,
                    "Most Probable Gene": gene_id,
                    "Cell Score": true_score,
                    "Predicted Cell Score": predicted_score,
                    "Probability Score for Highest Gene": most_probable_gene_score
                })

    return pd.DataFrame(annotated_data)

In [3]:
sequencing = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/sequencing.csv'
scores = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/dv_cell.csv'

graphs, feature_size, gene_id_to_index = generate_graphs(sequencing, scores, cell_min=50, gene_min_read=200)

Generated dataset with 1860 graphs, and 1054 genes


In [None]:
model = train_graph_transformer(graphs,
                        lr=0.00001,
                        dropout_rate=0.1,
                        epochs=10000,
                        save_fldr='/home/olafsson/Desktop/gnn',
                        acc_threshold = 0.25)

Epoch 41, Loss: 0.15246710189034843, Accuracy: 0.35204392922513734

In [None]:
result = annotate_cells_with_genes(graphs, model, gene_id_to_index)

In [None]:
result

In [None]:
class MPNN(MessagePassing):
    def __init__(self, node_in_features, edge_in_features, out_features):
        super(MPNN, self).__init__(aggr='mean')  # 'mean' aggregation.
        self.message_mlp = Sequential(
            Linear(node_in_features + edge_in_features, 128),
            ReLU(),
            Linear(128, out_features)
        )
        self.update_mlp = Sequential(
            Linear(out_features, out_features),
            ReLU(),
            Linear(out_features, out_features)
        )

    def forward(self, x, edge_index, edge_attr):
        # x: Node features [N, node_in_features]
        # edge_index: Graph connectivity [2, E]
        # edge_attr: Edge attributes/features [E, edge_in_features]
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # x_j: Input features of neighbors [E, node_in_features]
        # edge_attr: Edge attributes [E, edge_in_features]
        tmp = torch.cat([x_j, edge_attr], dim=-1)  # Concatenate node features with edge attributes
        return self.message_mlp(tmp)

    def update(self, aggr_out):
        # aggr_out: Aggregated messages [N, out_features]
        return self.update_mlp(aggr_out)
    
def weighted_mse_loss(output, target, score_threshold=0.8, high_score_weight=10):
    # Assumes output and target are the predicted and true scores, respectively
    weights = torch.ones_like(target)
    high_score_mask = target >= score_threshold
    weights[high_score_mask] = high_score_weight
    return ((output - target) ** 2 * weights).mean()

def generate_single_graph(sequencing, scores):
    # Load and preprocess sequencing data
    gene_df = pd.read_csv(sequencing)
    gene_df = gene_df.rename(columns={"prc": "well_id", "grna": "gene_id", "count": "read_count"})
    total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
    gene_df = gene_df.merge(total_reads_per_well, on='well_id')
    gene_df['well_read_fraction'] = gene_df['read_count']/gene_df['total_reads']

    # Load and preprocess cell score data
    cell_df = pd.read_csv(scores)
    cell_df = cell_df[['prcfo', 'prc', 'pred']].rename(columns={'prcfo': 'cell_id', 'prc': 'well_id', 'pred': 'score'})

    # Initialize mappings
    gene_id_to_index = {gene: i for i, gene in enumerate(gene_df['gene_id'].unique())}
    cell_id_to_index = {cell: i + len(gene_id_to_index) for i, cell in enumerate(cell_df['cell_id'].unique())}

    # Initialize edge indices and attributes
    edge_index = []
    edge_attr = []

    # Associate each cell with all genes in the same well
    for well_id, group in gene_df.groupby('well_id'):
        if well_id in cell_df['well_id'].values:
            cell_indices = cell_df[cell_df['well_id'] == well_id]['cell_id'].map(cell_id_to_index).values
            gene_indices = group['gene_id'].map(gene_id_to_index).values
            fractions = group['well_read_fraction'].values
            
            for cell_idx in cell_indices:
                for gene_idx, fraction in zip(gene_indices, fractions):
                    edge_index.append([cell_idx, gene_idx])
                    edge_attr.append([fraction])

    # Convert lists to PyTorch tensors
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    cell_scores = torch.tensor(cell_df['score'].values, dtype=torch.float)

    # One-hot encoding for genes, and zero features for cells (could be replaced with real features if available)
    gene_features = torch.eye(len(gene_id_to_index))
    cell_features = torch.zeros(len(cell_id_to_index), gene_features.size(1))

    # Combine features
    x = torch.cat([cell_features, gene_features], dim=0)

    # Create the graph data object
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=cell_scores)

    return data, gene_id_to_index, len(gene_id_to_index)

In [None]:
import pickle
import numpy as np
import pandas as pd
from torch_geometric.data import Data
from collections import defaultdict
import torch
from torch.nn import Sequential, Linear, ReLU
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import degree, add_self_loops, softmax
from torch_geometric.loader import DataLoader, NeighborSampler
from sklearn.metrics import mean_squared_error
from torch_geometric.nn import SAGEConv, global_mean_pool, Linear, TransformerConv, GCNConv, GATConv, MessagePassing
from torch import Tensor, nn
from torch_geometric.data import Batch
from torch.utils.data import DataLoader as TorchDataLoader
from torch.nn import Linear, Module
import torch
import torch.nn.functional as F
from torch.nn import Linear, Module
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn.inits import reset
from torch_geometric.nn.conv import MessagePassing

def collate(batch):
    data_list = [data for _, data in batch]
    return Batch.from_data_list(data_list)


def generate_well_graphs(sequencing, scores):
    # Load and preprocess sequencing data
    gene_df = pd.read_csv(sequencing)
    gene_df = gene_df.rename(columns={'prc': 'well_id', 'grna': 'gene_id', 'count': 'read_count'})
    total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
    gene_df = gene_df.merge(total_reads_per_well, on='well_id')
    gene_df['well_read_fraction'] = gene_df['read_count'] / gene_df['total_reads']

    # Load and preprocess cell score data
    cell_df = pd.read_csv(scores)
    cell_df = cell_df[['prcfo', 'prc', 'pred']].rename(columns={'prcfo': 'cell_id', 'prc': 'well_id', 'pred': 'score'})

    # Initialize mappings
    gene_id_to_index = {gene: i for i, gene in enumerate(gene_df['gene_id'].unique())}
    cell_id_to_index = {cell: i + len(gene_id_to_index) for i, cell in enumerate(cell_df['cell_id'].unique())}

    # Initialize a dictionary to store edge information for each well subgraph
    wells_subgraphs = defaultdict(lambda: {'edge_index': [], 'edge_attr': []})

    # Associate each cell with all genes in the same well
    for well_id, group in gene_df.groupby('well_id'):
        if well_id in cell_df['well_id'].values:
            cell_indices = cell_df[cell_df['well_id'] == well_id]['cell_id'].map(cell_id_to_index).values
            gene_indices = group['gene_id'].map(gene_id_to_index).values
            fractions = group['well_read_fraction'].values

            for cell_idx in cell_indices:
                for gene_idx, fraction in zip(gene_indices, fractions):
                    wells_subgraphs[well_id]['edge_index'].append([cell_idx, gene_idx])
                    wells_subgraphs[well_id]['edge_attr'].append([fraction])

    # Process well subgraphs into PyTorch Geometric Data objects
    well_data_list = []
    for well_id, subgraph in wells_subgraphs.items():
        edge_index = torch.tensor(subgraph['edge_index'], dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(subgraph['edge_attr'], dtype=torch.float)
        num_nodes = max(max(edge) for edge in subgraph['edge_index']) + 1
        x = torch.ones((num_nodes, 1))  # Feature matrix with a single feature set to 1 for each node

        # Retrieve cell scores for the current well
        cell_scores = cell_df[cell_df['well_id'] == well_id]['score'].values
        # Create a tensor for cell scores, ensuring the order matches that of the nodes in the graph
        y = torch.tensor(cell_scores, dtype=torch.float)
        
        subgraph_data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
        well_data_list.append((well_id, subgraph_data))
    
    return well_data_list, gene_id_to_index, len(gene_id_to_index), cell_id_to_index

class CustomTransformerConv(MessagePassing):
    def __init__(self, in_channels, out_channels, heads=1, concat=True, beta=False, dropout=0.0, edge_dim=None):
        super().__init__(node_dim=0, aggr='add')  # Specify aggregation
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.beta = beta
        self.dropout = dropout
        self.edge_dim = edge_dim

        # The linear layers for the multi-head attention mechanism
        self.lin_query = Linear(in_channels, heads * out_channels, bias=False)
        self.lin_key = Linear(in_channels, heads * out_channels, bias=False)
        self.lin_value = Linear(in_channels, heads * out_channels, bias=False)

        # Optional edge transformation
        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)

        # Optional beta parameter for combining aggregation and skip connection
        if self.beta:
            self.lin_gate = torch.nn.Linear(in_channels + out_channels, 1, bias=True)
        
        # The final linear transformation that is applied to each node feature vector
        self.lin_out = Linear(heads * out_channels, out_channels, bias=True) if concat else Linear(out_channels, out_channels, bias=True)

        # For storing the attention weights
        self.att = None 

        # Initialize the parameters
        self.reset_parameters()

    def reset_parameters(self):
        # Reset the parameters here
        self.lin_query.reset_parameters()
        self.lin_key.reset_parameters()
        self.lin_value.reset_parameters()
        if self.edge_dim is not None:
            self.lin_edge.reset_parameters()
        if self.beta:
            self.lin_gate.reset_parameters()
        self.lin_out.reset_parameters()

    def forward(self, x, edge_index, edge_attr=None):
        print(f"Input features shape: {x.shape}")
        query = self.lin_query(x).view(-1, self.heads, self.out_channels)
        query = self.lin_query(x)
        print(f"Query shape (pre-view): {query.shape}")
        key = self.lin_key(x).view(-1, self.heads, self.out_channels)
        value = self.lin_value(x).view(-1, self.heads, self.out_channels)
        
        # Propagate the messages
        out = self.propagate(edge_index, x=(query, key, value), edge_attr=edge_attr, size=None)
        
        # Reshape and concatenate head outputs if required
        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)
        
        # Apply root node transformation with skip connection if required
        if self.root_weight:
            out = out + self.lin_root(x[:out.size(0), :])
        
        return out

    def message(self, x_j, x_i, edge_attr, index, ptr, size_i):
        # Compute messages
        # This needs to be implemented based on your model's specifics
        query, key, value = x_i[0], x_j[1], x_j[2]
        # Compute the attention scores
        alpha = (query * key).sum(dim=-1) / self.scale
        alpha = softmax(alpha, index, ptr, size_i)
        
        # Apply attention scores to the values
        out = value * alpha.view(-1, self.heads, 1)
        return out.view(-1, self.heads * self.out_channels)


class GraphTransformer(torch.nn.Module):
    def __init__(self, num_node_features, dropout_rate=0.1):
        super(GraphTransformer, self).__init__()
        # Assuming you want to predict a single value per graph, adjust the out_channels as needed.
        num_heads = 4  # Example: 4 attention heads
        out_channels = 1  # Example: predicting a single score per graph
        self.conv1 = CustomTransformerConv(num_node_features, 128, heads=num_heads, dropout=dropout_rate, edge_dim=1)
        self.conv2 = CustomTransformerConv(128 * num_heads, 256, heads=num_heads, dropout=dropout_rate, edge_dim=1)
        self.lin = Linear(256 * num_heads, out_channels)  # Adjusted for a single output feature

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # Here we call forward on the CustomTransformerConv.
        # Make sure edge_attr is only passed if you have edge features.
        # Adjust the head dimensions and any additional logic based on your architecture.

        x = F.relu(self.conv1(x, edge_index, edge_attr=edge_attr))
        # more layers...

        return x

def train_graph_network(graph_data_list, feature_size, model_path, batch_size=8, epochs=100, lr=0.001):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GraphTransformer(num_node_features=feature_size).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss(reduction='mean')

    data_loader = TorchDataLoader(graph_data_list, batch_size=batch_size, shuffle=True, collate_fn=collate)
    
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for data in data_loader:
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out.view(-1), data.y.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(data_loader)}')
    
    torch.save(model.state_dict(), model_path)

In [None]:
train_graph_network(graph_data_list=graph_data,
                    feature_size=feature_size,
                    model_path='/home/olafsson/Desktop/gnn/model/pth',
                    batch_size=8,
                    epochs=100,
                    lr=0.001)

In [None]:
sequencing = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/sequencing.csv'
scores = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/dv_cell.csv'
graph_data, gene_id_to_index, feature_size, cell_id_to_index = generate_well_graphs(sequencing,scores)

In [None]:
model_save_path = '/home/olafsson/Desktop/gnn/mode.pth'
compute_gene_importance(model,
                        graph_data,
                        model_save_path,
                        n_permutations=10)

In [None]:
sequencing = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/sequencing.csv'
score = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/dv_cell.csv'

# Example loading step
gene_df = pd.read_csv(sequencing)
cell_df = pd.read_csv(score)

gene_df = gene_df.rename(columns={"prc": "well_id", "grna": "gene_id", "count": "read_count"})
gene_df = gene_df.drop(columns=['Unnamed: 0', 'plate', 'row', 'col', 'grna_seq', 'gene'])
total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
gene_df = gene_df.merge(total_reads_per_well, on='well_id')
gene_df['well_read_fraction'] = gene_df['read_count']/gene_df['total_reads']
gene_df = gene_df.drop(columns=['read_count', 'total_reads'])

cell_df = cell_df.rename(columns={"prcfo": "cell_id", "prc": "well_id", "pred": "score"})
cell_df = cell_df.drop(columns=['parasite_area', 'parasite_area', 'recruitment'])
display(gene_df)
display(cell_df)
feature_size = len(gene_df['gene_id'].unique())
feature_size

In [None]:
print(f'feature_size: {feature_size}')
graph_data, gene_id_to_index = generate_graph(gene_df,cell_df)
dict_file_path = '/home/olafsson/Desktop/gnn/dict.pth'

with open(dict_file_path, 'wb') as file:
    pickle.dump(gene_id_to_index, file)

In [None]:
'pip install torch-sparse -f https://data.pyg.org/whl/torch-2.2.1+cu121.html'
'pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.1+cu121.html'
print(f'feature_size: {feature_size}')

dict_file_path = '/home/olafsson/Desktop/gnn/dict.pth'

with open(dict_file_path, 'wb') as file:
    pickle.dump(gene_id_to_index, file)

In [None]:
I first transect a library of gRNAs targeting ~1400 genes into Toxoplasma tachyzoites and grow the parasites under selection for 1 week. This generates a pooled population of mutant parasites, each parasite is missing one gene. I then seed HFF cells in 384 well plates and transfer 10 mutants on average to each well. These parasites grow for a few days to generate sub-pools of parasite populations consisting of on average 10 unique mutants. At this point i transfer mutants to corresponding wells in new 384 well plates, these plates have cells that the parasites will infect. I then fix, stain and image these new plates. The rest of the parasites in the original 384 well plates are sequenced so i know which mutants were present in each well. Single cell images are then cropped from each field of view classified by a CNN. I only include cells infected by one parasite. So at the end of the experiment i have infected cells infected by 1 mutant parasite with phenotype scores and i know the genes that are knocked out in the parasites in each well. I also know the relative abundance of each mutant in each well through the proportion of sequencing reads in each well. 