### Imports and Set Up

In [2]:
import scanpy as sc
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score 
import numpy as np
import pandas as pd
import time
import modelMLP 
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

In [3]:
# Scanpy logging / plotting defaults for this notebook.
# Verbosity levels: errors (0), warnings (1), info (2), hints (3).
# (Previously this line was a bare `sc.settings.verbosity` expression, which had no effect.)
sc.settings.verbosity = 3
sc.logging.print_header()  # print versions of scanpy and its dependencies
sc.settings.set_figure_params(dpi=80, facecolor='white')

### Prepping Data

##### Encoding gene position

In [1]:
# Gene annotation source used for chromosome positions: Ensembl release 109 (GRCh38).
from tqdm import tqdm
import pyensembl

ENSEMBL_RELEASE = 109

data = pyensembl.EnsemblRelease(release=ENSEMBL_RELEASE)
# Make sure the annotation files are present locally, then build pyensembl's index.
data.download()
data.index()

INFO:pyensembl.sequence_data:Loaded sequence dictionary from /Users/steveyin/Library/Caches/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.cdna.all.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /Users/steveyin/Library/Caches/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.ncrna.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /Users/steveyin/Library/Caches/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.pep.all.fa.gz.pickle


In [None]:
from scipy import sparse

# Norman et al. (2019) Perturb-seq AnnData — replace with your path.
adata = sc.read_h5ad("data/Norman_2019.h5ad")

# Dense cells x genes expression matrix. `.toarray()` only exists on sparse
# matrices, so fall back to np.asarray when X is already dense.
# NOTE(review): ~111k cells x ~19k genes is several GB when dense — consider
# keeping X sparse or subsetting genes before densifying.
ddata = adata.X.toarray() if sparse.issparse(adata.X) else np.asarray(adata.X)

# One label string per cell, e.g. 'control', 'TSC22D1', 'KLF1+MAP2K6'.
labels = adata.obs['perturbation_name'].to_numpy()
# Combinatorial perturbations are '+'-joined gene names; controls have no target genes.
parsed_labels = [p.split('+') if p != 'control' else [] for p in labels]
# Multi-hot target matrix: one column per perturbed gene (controls are all-zero rows).
mlb = MultiLabelBinarizer()
labels_int = mlb.fit_transform(parsed_labels)

In [6]:
# GRCh38 primary-assembly chromosome lengths in bp: total sequence length,
# including N-gaps (NCBI GRC GRCh38 assembly / UCSC hg38.chrom.sizes).
# Ensembl gene coordinates (gene.start) are positions on this full sequence, so
# normalizing by the *total* length keeps every encoding value in [0, 1].
# Ungapped lengths would push genes near chromosome ends above 1.
chromosome_lengths = {
    # Autosomes
    '1': 248956422,
    '2': 242193529,
    '3': 198295559,
    '4': 190214555,
    '5': 181538259,
    '6': 170805979,
    '7': 159345973,
    '8': 145138636,
    '9': 138394717,
    '10': 133797422,
    '11': 135086622,
    '12': 133275309,
    '13': 114364328,
    '14': 107043718,
    '15': 101991189,
    '16': 90338345,
    '17': 83257441,
    '18': 80373285,
    '19': 58617616,
    '20': 64444167,
    '21': 46709983,
    '22': 50818468,
    # Sex Chromosomes
    'X': 156040895,
    'Y': 57227415,
}

# Fixed chromosome order; a chromosome's index is its slot in the positional encoding.
chr_names = [str(i) for i in range(1, 23)] + ['X', 'Y']
chromosome_map = {name: i for i, name in enumerate(chr_names)}

In [65]:
# One slot per chromosome (22 autosomes + X + Y); a gene's slot holds start / chromosome length.
n_chrom = len(chr_names)
zero_encoding = np.zeros(n_chrom)


def _encode_gene_position(gene_name):
    """Return the chromosome-slot encoding of one gene name, or None if it cannot be looked up.

    The vector is zero except at the gene's chromosome index, which holds
    gene.start / chromosome length. If pyensembl returns several genes for the
    name, the first one is used. Genes on contigs outside `chromosome_map`
    (e.g. scaffolds) get an all-zero vector.
    """
    try:
        genes = data.genes_by_name(gene_name)
    except (ValueError, KeyError):
        # Presumably raised by pyensembl for names absent from this release.
        return None
    if not genes:
        return None
    contig = genes[0].contig
    encoding = np.zeros(n_chrom)
    if pd.notna(contig) and contig in chromosome_map:
        chr_length = chromosome_lengths.get(contig, 0)
        if chr_length > 0:
            encoding[chromosome_map[contig]] = genes[0].start / chr_length
    return encoding


def _encode_perturbation(perturbation_string):
    """Average the encodings of the genes in one perturbation label ('A' or 'A+B').

    'control', and labels with no mappable gene, give the zero vector. Unmapped
    genes in a combination are skipped rather than averaged in as zeros, which
    would otherwise halve the known partner's position.
    """
    if perturbation_string == 'control':
        return zero_encoding
    known = [gene_to_pos_encoding[g] for g in perturbation_string.split('+')
             if g in gene_to_pos_encoding]
    return np.mean(known, axis=0) if known else zero_encoding


# Encode every measured gene plus every perturbation target, so a target that is
# not among adata.var_names still gets its position.
perturbation_targets = {
    g
    for p in adata.obs['perturbation_name'].unique()
    if p != 'control'
    for g in p.split('+')
}
genes_to_encode = list(adata.var_names) + sorted(perturbation_targets - set(adata.var_names))

gene_to_pos_encoding = {}
for gene_name in tqdm(genes_to_encode):
    encoding = _encode_gene_position(gene_name)
    if encoding is not None:
        gene_to_pos_encoding[gene_name] = encoding

# Per-cell matrix: one row per cell (adata.obs order), one column per chromosome.
final_positional_encoding = np.array(
    [_encode_perturbation(p) for p in tqdm(adata.obs['perturbation_name'])]
)

# Report perturbation targets that silently fell back to zeros.
missing_targets = sorted(perturbation_targets.difference(gene_to_pos_encoding))
if missing_targets:
    print(f"WARNING: no position found for {len(missing_targets)} perturbation target(s): {missing_targets[:10]}")

print(f"\nSUCCESS: Final matrix generated with shape {final_positional_encoding.shape}")

100%|██████████| 19018/19018 [00:00<00:00, 232113.72it/s]
100%|██████████| 111255/111255 [00:00<00:00, 310936.89it/s]


SUCCESS: Final matrix generated with shape (111255, 24)





##### creating transcriptome encoding

##### concatenating representations

In [None]:
# NOTE(review): this cell is disabled — the code is wrapped in a string literal, so
# running it only echoes the text. `positional_encodings` is not defined anywhere in
# this notebook; TODO define it or delete this cell.
# Intended idea (as written): append a per-gene positional vector to every expression
# value, giving a (n_cells, n_genes, 1 + encoding_dim) array, assuming
# positional_encodings is (n_genes, encoding_dim) — TODO confirm.
# NOTE(review): np.tile over all cells would materialize n_cells x n_genes x
# encoding_dim floats, which for this dataset is far beyond typical RAM.
'''
ddata = adata.X.toarray()
ddata_reshaped = ddata[:, :, np.newaxis]
n_cells = ddata.shape[0]
positional_encodings_expanded = np.tile(positional_encodings, (n_cells, 1, 1))
X_sequence_features = np.concatenate([ddata_reshaped, positional_encodings_expanded], axis=2)
print("Shape of the new feature matrix for a sequence model:", X_sequence_features.shape)
'''

In [38]:
# Cell x cell sparse distance graph stored in adata.obsp (presumably from sc.pp.neighbors).
neighbor_distances = adata.obsp["distances"]
neighbor_distances

<Compressed Sparse Row sparse matrix of dtype 'float64'
	with 1557570 stored elements and shape (111255, 111255)>

In [39]:
# Per-cell metadata: guide identity, QC metrics, perturbation annotations, ...
cell_metadata = adata.obs
cell_metadata

Unnamed: 0_level_0,guide_identity,read_count,UMI_count,coverage,gemgroup,good_coverage,number_of_cells,guide_AHR,guide_ARID1A,guide_ARRDC3,...,n_genes,n_genes_by_counts,total_counts,total_counts_mt,pct_counts_mt,leiden,perturbation_name,perturbation_type,perturbation_value,perturbation_unit
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCTGAGAAGAAGC-1,NegCtrl0_NegCtrl0__NegCtrl0_NegCtrl0,1252,67,18.686567,1,True,2,0,0,0,...,4108,4108,19413.0,1327.0,6.835625,10,control,genetic,,
AAACCTGAGGCATGTG-1,TSC22D1_NegCtrl0__TSC22D1_NegCtrl0,2151,104,20.682692,1,True,1,0,0,0,...,3142,3142,13474.0,962.0,7.139676,3,TSC22D1,genetic,,
AAACCTGAGGCCCTTG-1,KLF1_MAP2K6__KLF1_MAP2K6,1037,59,17.576271,1,True,1,0,0,0,...,4229,4229,23228.0,1548.0,6.664371,7,KLF1+MAP2K6,genetic,,
AAACCTGCACGAAGCA-1,NegCtrl10_NegCtrl0__NegCtrl10_NegCtrl0,958,39,24.564103,1,True,1,0,0,0,...,2114,2114,6842.0,523.0,7.643963,2,control,genetic,,
AAACCTGCAGACGTAG-1,CEBPE_RUNX1T1__CEBPE_RUNX1T1,244,14,17.428571,1,True,1,0,0,0,...,2753,2753,9130.0,893.0,9.780942,10,CEBPE+RUNX1T1,genetic,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTCATCAGTACGT-8,FOXA3_NegCtrl0__FOXA3_NegCtrl0,2068,95,21.768421,8,True,1,0,0,0,...,3305,3305,14727.0,898.0,6.097644,3,FOXA3,genetic,,
TTTGTCATCCACTCCA-8,CELF2_NegCtrl0__CELF2_NegCtrl0,829,33,25.121212,8,True,1,0,0,0,...,2842,2842,9750.0,566.0,5.805128,10,CELF2,genetic,,
TTTGTCATCCCAACGG-8,BCORL1_NegCtrl0__BCORL1_NegCtrl0,136,9,15.111111,8,True,1,0,0,0,...,2824,2824,8670.0,490.0,5.651672,4,BCORL1,genetic,,
TTTGTCATCCTCCTAG-8,ZBTB10_PTPN12__ZBTB10_PTPN12,1254,59,21.254237,8,True,3,0,0,0,...,5180,5179,29247.0,1551.0,5.303108,5,PTPN12+ZBTB10,genetic,,


In [42]:
# Spot-check one pyensembl lookup; TSC22D1 is one of the perturbation targets.
# (Removed `list(gene_info_df['gene_name'])`: `gene_info_df` is never defined in this
# notebook, so it raised NameError on a fresh kernel, and its value was discarded.)
gene = data.genes_by_name('TSC22D1')
gene

[Gene(gene_id='ENSG00000102804', gene_name='TSC22D1', biotype='protein_coding', contig='13', start=44432143, end=44577147, strand='-', genome='GRCh38')]

In [None]:
# NOTE(review): `gene_names` and `positional_encodings` are not defined anywhere in this
# notebook, so this cell only runs with leftover kernel state — TODO define them or
# delete the cell.
# Its keys are Ensembl gene IDs (e.g. 'ENSG00000243485' in the saved output), while
# perturbation labels use gene names ('KLF1', 'TSC22D1', ...). Storing it under its own
# name keeps it from overwriting the name-keyed `gene_to_pos_encoding` built via
# data.genes_by_name. Overwriting it would make every per-cell lookup fall back to zeros.
gene_id_to_pos_encoding = dict(zip(gene_names, positional_encodings))

# Preview a few entries instead of dumping every gene.
dict(list(gene_id_to_pos_encoding.items())[:3])

{'ENSG00000243485': array([0.00012782, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        ]),
 'ENSG00000238009': array([0.00038618, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        ]),
 'ENSG00000279457': array([0.00080103, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0. 

In [None]:
# One perturbation label per cell, in adata.obs order (e.g. 'control', 'AHR+FEV').
perturbation_labels = adata.obs['perturbation_name'].tolist()

# Zero vector (one slot per chromosome) for controls and perturbations with no mapped gene.
zero_encoding = np.zeros(len(chr_names))


def _average_known_encodings(pert_name):
    """Average the positional encodings of a perturbation's target genes.

    Genes missing from `gene_to_pos_encoding` are skipped rather than averaged in
    as zeros, so one unmapped gene in a pair does not halve its partner's value.
    Controls, and perturbations with no mapped gene, get `zero_encoding`.
    For a single gene the result is just that gene's own vector.
    """
    if pert_name == 'control':
        return zero_encoding
    known = [gene_to_pos_encoding[g] for g in pert_name.split('+') if g in gene_to_pos_encoding]
    return np.mean(known, axis=0) if known else zero_encoding


print("Generating positional features for each cell...")
# reshape keeps a 2-D (0, n_chrom) result even for an empty label list.
cell_perturbation_encodings = np.array(
    [_average_known_encodings(p) for p in tqdm(perturbation_labels)]
).reshape(-1, len(zero_encoding))

print("\nShape of the new per-cell positional feature matrix:", cell_perturbation_encodings.shape)
# One row per cell, one column per chromosome — (111255, 24) for this dataset.
assert cell_perturbation_encodings.shape == (len(perturbation_labels), len(zero_encoding))

In [None]:
# Random seed for the train/test split (kept at the original value so the split is unchanged).
SEED = 67

# Hold out 20% of cells for evaluation.
# NOTE(review): only the expression matrix `ddata` is split — the positional encodings
# computed above are not part of X yet. Stratification on the multi-hot labels
# (stratify=labels_int) is disabled; confirm whether it is wanted.
X_train, X_test, y_train, y_test = train_test_split(
    ddata,
    labels_int,
    test_size=0.2,
    random_state=SEED,
)

##### Prepping Model

In [None]:
# torch.as_tensor reuses the NumPy buffer when the dtype already matches (e.g. float32
# expression data); torch.tensor always copies, which doubles a multi-GB matrix.
# The tensors therefore share memory with X_train / X_test — do not mutate those arrays.
X_train_tensor = torch.as_tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.as_tensor(y_train, dtype=torch.float32)  # multi-hot targets -> float for BCE
X_test_tensor = torch.as_tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.as_tensor(y_test, dtype=torch.float32)

train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

batch_size = 64
# Seeded generator (same seed as the train/test split) so batch order is reproducible.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(67),
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
from modelTRAN import TransMLP  # NOTE(review): unused below — the model is modelMLP.MLP

input_size = X_train.shape[1]  # number of genes (features per cell)
num_classes = labels_int.shape[1]  # one output per perturbation target (multi-hot columns)
learning_rate = 0.00026
num_epochs = 25

torch.manual_seed(67)  # reproducible weight initialization (same seed as the split)
model = modelMLP.MLP(input_size=input_size, num_classes=num_classes)
# BCEWithLogitsLoss applies an independent sigmoid + binary cross-entropy per class:
# the right loss for multi-label targets (a cell can carry several perturbed genes).
# For single-label multi-class problems, CrossEntropyLoss would be the usual choice.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=2e-6)



