In [2]:
import os

## Set directory
os.chdir('/hpc/group/pbenfeylab/CheWei/CW_data/genesys')

import networkx as nx
from genesys_evaluate_v1 import *
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import seaborn as sns
import anndata

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
## Conda Env genesys on DCC
print(torch.__version__)
print(sc.__version__) 

1.11.0
1.9.6


In [4]:
## Genes considered/used (shared among samples) 
gene_list = pd.read_csv('./gene_list_1108.csv')

## Load data

In [5]:
with open("./genesys_root_data.pkl", 'rb') as file_handle:
    data = pickle.load(file_handle)
    
batch_size = 2000
dataset = Root_Dataset(data['X_test'], data['y_test'])
loader = DataLoader(dataset,
                         batch_size = batch_size,
                         shuffle = True, drop_last=True)
train_dataset = Root_Dataset(data['X_train'], data['y_train'])
train_loader = DataLoader(train_dataset,
                         batch_size = batch_size,
                         shuffle = True, drop_last=True)

In [6]:
with open("./genesys_rswt_data.pkl", 'rb') as file_handle:
    data = pickle.load(file_handle)

X_all = np.vstack((data['X_train'],data['X_val'],data['X_test']))
y_all = pd.concat((data['y_train'],data['y_val'],data['y_test']))
unseen_dataset = Root_Dataset_NoQC(X_all, y_all)
unseen_loader = DataLoader(unseen_dataset,
                         batch_size = batch_size,
                         shuffle = True, drop_last=True)

In [7]:
input_size = data['X_train'].shape[1]
## 10 cell types 
output_size = 10
embedding_dim = 256
hidden_dim = 256
n_layers = 2
device = "cpu"
path = "./"

## Load trained GeneSys model (Evaluate)

In [8]:
model = ClassifierLSTM(input_size, output_size, embedding_dim, hidden_dim, n_layers).to(device)
model.load_state_dict(torch.load(path+"/workstation/genesys_model_trained_on_root_atlas_20240308_continue4.pth", map_location=torch.device('cpu')))
model = model
model.eval()

ClassifierLSTM(
  (fc1): Sequential(
    (0): Linear(in_features=17513, out_features=256, bias=True)
    (1): Dropout(p=0.2, inplace=False)
    (2): GaussianNoise()
  )
  (fc): Sequential(
    (0): ReLU()
    (1): Linear(in_features=512, out_features=512, bias=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=10, bias=True)
  )
  (lstm): LSTM(256, 256, num_layers=2, batch_first=True, dropout=0.2, bidirectional=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (b_to_z): DBlock(
    (fc1): Linear(in_features=512, out_features=256, bias=True)
    (fc2): Linear(in_features=512, out_features=256, bias=True)
    (fc_mu): Linear(in_features=256, out_features=512, bias=True)
    (fc_logsigma): Linear(in_features=256, out_features=512, bias=True)
  )
  (bz2_infer_z1): DBlock(
    (fc1): Linear(in_features=1024, out_features=256, bias=True)
    (fc2): Linear(in_features=1024, out_features=256, bias=True)
    (fc_mu): Linear(in_features=256, out_features=512, bias=True)
    (fc

### Sample data (2000 cells)

In [9]:
with open("./genesys_root_data.pkl", 'rb') as file_handle:
    data = pickle.load(file_handle)
    
batch_size = 2000
dataset = Root_Dataset(data['X_test'], data['y_test'])
loader = DataLoader(dataset,
                         batch_size = batch_size,
                         shuffle = True, drop_last=True)
train_dataset = Root_Dataset(data['X_train'], data['y_train'])
train_loader = DataLoader(train_dataset,
                         batch_size = batch_size,
                         shuffle = True, drop_last=True)

In [10]:
classes = ['Columella', 'Lateral Root Cap', 'Phloem', 'Xylem', 'Procambium', 'Pericycle', 'Endodermis', 'Cortex', 'Atrichoblast', 'Trichoblast']
class2num = {c: i for (i, c) in enumerate(classes)}
num2class = {i: c for (i, c) in enumerate(classes)}

In [11]:
sample = next(iter(loader))
xo = sample['x'].to(device)
y = sample['y'].to(device)
y_label = [num2class[i] for i in y.tolist()]

In [12]:
## 2000 cell type trajectories (11 dev stage) sampled, each stage has 17513 gene expression)
xo.shape

torch.Size([2000, 11, 17513])

In [13]:
## How many cell type trajectories are sampled for each cell type?
pd.Series(y_label).value_counts()

Xylem               231
Phloem              217
Pericycle           209
Columella           206
Atrichoblast        205
Endodermis          204
Trichoblast         195
Procambium          190
Cortex              182
Lateral Root Cap    161
dtype: int64

### GEP impact on development

In [14]:
model = ClassifierLSTM(input_size, output_size, embedding_dim, hidden_dim, n_layers).to(device)
model.load_state_dict(torch.load(path+"/workstation/genesys_model_trained_on_root_atlas_20240308_continue4.pth", map_location=torch.device('cpu')))
model = model
model.eval()

ClassifierLSTM(
  (fc1): Sequential(
    (0): Linear(in_features=17513, out_features=256, bias=True)
    (1): Dropout(p=0.2, inplace=False)
    (2): GaussianNoise()
  )
  (fc): Sequential(
    (0): ReLU()
    (1): Linear(in_features=512, out_features=512, bias=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=10, bias=True)
  )
  (lstm): LSTM(256, 256, num_layers=2, batch_first=True, dropout=0.2, bidirectional=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (b_to_z): DBlock(
    (fc1): Linear(in_features=512, out_features=256, bias=True)
    (fc2): Linear(in_features=512, out_features=256, bias=True)
    (fc_mu): Linear(in_features=256, out_features=512, bias=True)
    (fc_logsigma): Linear(in_features=256, out_features=512, bias=True)
  )
  (bz2_infer_z1): DBlock(
    (fc1): Linear(in_features=1024, out_features=256, bias=True)
    (fc2): Linear(in_features=1024, out_features=256, bias=True)
    (fc_mu): Linear(in_features=256, out_features=512, bias=True)
    (fc

In [84]:
GEP_raw_score = pd.read_csv('../../cNMF/20250722_root_atlas_cnmf_parallel.gene_spectra_score.k_30.dt_0_1.txt', delim_whitespace=True)
GEP_raw_score = GEP_raw_score.T

In [85]:
GEP_raw_score

Unnamed: 0,1,2,3,4,5,6,7,8,9,10,...,21,22,23,24,25,26,27,28,29,30
AT1G01010,-0.000918,-0.000449,-0.000904,0.000053,0.000019,-0.000163,-0.000538,-0.000526,0.002932,-6.148138e-05,...,-0.000003,-0.000611,-0.000151,0.000103,-2.267858e-04,3.507503e-04,-0.000168,-0.000149,-1.151191e-04,2.233542e-04
AT1G01020,-0.000036,0.000268,-0.000233,-0.000134,0.000773,-0.000113,0.000124,-0.000402,0.000288,-1.836733e-04,...,0.000245,-0.000026,-0.000146,0.000148,1.615548e-04,3.322189e-04,-0.000034,0.000142,7.106826e-05,-2.479525e-05
AT1G03987,-0.000006,-0.000004,0.000046,-0.000011,-0.000008,-0.000008,-0.000011,-0.000010,-0.000008,7.104151e-08,...,-0.000009,-0.000005,-0.000004,-0.000008,-1.398544e-06,9.708034e-07,-0.000003,-0.000007,-8.340389e-07,-8.712404e-07
AT1G01030,0.000060,-0.000126,0.000092,-0.000279,-0.000195,-0.000052,-0.000141,-0.000451,-0.000140,1.716998e-03,...,0.000403,-0.000163,-0.000090,0.000056,4.537205e-05,-1.224061e-04,-0.000063,-0.000004,8.596555e-06,-1.130477e-04
AT1G01040,-0.000328,-0.000093,-0.000133,0.000053,-0.000164,-0.000138,-0.000211,-0.000035,0.000131,2.383958e-04,...,-0.000222,0.000295,0.000155,-0.000132,6.112147e-06,1.942994e-04,-0.000200,-0.000077,-3.206050e-05,-4.143671e-05
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
AT4G04815,0.000095,-0.000015,-0.000030,-0.000005,-0.000008,-0.000005,-0.000021,-0.000006,-0.000006,-7.108223e-06,...,0.000083,-0.000005,-0.000003,-0.000021,8.958760e-06,-4.608114e-06,-0.000001,-0.000018,-1.998597e-07,-3.624372e-06
AT4G08485,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000e+00,...,0.000000,0.000000,0.000000,0.000000,0.000000e+00,0.000000e+00,0.000000,0.000000,0.000000e+00,0.000000e+00
AT4G06380,-0.000039,0.000053,0.000003,-0.000036,0.000006,-0.000024,-0.000067,-0.000024,-0.000027,-2.869888e-05,...,0.000100,-0.000013,-0.000002,0.000410,-4.087492e-07,-7.544947e-06,-0.000004,-0.000036,7.929812e-07,7.557334e-06
AT4G21030,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000e+00,...,0.000000,0.000000,0.000000,0.000000,0.000000e+00,0.000000e+00,0.000000,0.000000,0.000000e+00,0.000000e+00


In [86]:
# Find the GEP with the highest absolute loading for each gene
gep_assignments = GEP_raw_score.idxmax(axis=1)  # Get GEP with highest value
gep_scores = GEP_raw_score.lookup(GEP_raw_score.index, gep_assignments)  # Get actual values

  gep_scores = GEP_raw_score.lookup(GEP_raw_score.index, gep_assignments)  # Get actual values


In [87]:
# Add the assignment and score to the dataframe
GEP_score = pd.DataFrame({
    'Gene': GEP_raw_score.index,
    'GEP': gep_assignments,
    'Score': gep_scores
})

In [88]:
GEP_score

Unnamed: 0,Gene,GEP,Score
AT1G01010,AT1G01010,9,0.002932
AT1G01020,AT1G01020,5,0.000773
AT1G03987,AT1G03987,3,0.000046
AT1G01030,AT1G01030,10,0.001717
AT1G01040,AT1G01040,22,0.000295
...,...,...,...
AT4G04815,AT4G04815,1,0.000095
AT4G08485,AT4G08485,1,0.000000
AT4G06380,AT4G06380,24,0.000410
AT4G21030,AT4G21030,1,0.000000


In [89]:
# Rank genes within each GEP by score
GEP_score['Rank'] = GEP_score.groupby('GEP')['Score'].rank(method='first', ascending=False).astype(int)

In [90]:
# Reset index and merge back the original data
GEP_score = GEP_score.merge(
    GEP_raw_score.reset_index(), left_on='Gene', right_on='index'
).drop(columns='index')

In [91]:
# Rename columns to match second image
gep_columns = [f'GEP_{i+1}' for i in range(GEP_raw_score.shape[1])]
GEP_score.columns = ['Gene', 'GEP', 'Score', 'Rank'] + gep_columns

In [92]:
GEP_score

Unnamed: 0,Gene,GEP,Score,Rank,GEP_1,GEP_2,GEP_3,GEP_4,GEP_5,GEP_6,...,GEP_21,GEP_22,GEP_23,GEP_24,GEP_25,GEP_26,GEP_27,GEP_28,GEP_29,GEP_30
0,AT1G01010,9,0.002932,79,-0.000918,-0.000449,-0.000904,0.000053,0.000019,-0.000163,...,-0.000003,-0.000611,-0.000151,0.000103,-2.267858e-04,3.507503e-04,-0.000168,-0.000149,-1.151191e-04,2.233542e-04
1,AT1G01020,5,0.000773,381,-0.000036,0.000268,-0.000233,-0.000134,0.000773,-0.000113,...,0.000245,-0.000026,-0.000146,0.000148,1.615548e-04,3.322189e-04,-0.000034,0.000142,7.106826e-05,-2.479525e-05
2,AT1G03987,3,0.000046,1404,-0.000006,-0.000004,0.000046,-0.000011,-0.000008,-0.000008,...,-0.000009,-0.000005,-0.000004,-0.000008,-1.398544e-06,9.708034e-07,-0.000003,-0.000007,-8.340389e-07,-8.712404e-07
3,AT1G01030,10,0.001717,176,0.000060,-0.000126,0.000092,-0.000279,-0.000195,-0.000052,...,0.000403,-0.000163,-0.000090,0.000056,4.537205e-05,-1.224061e-04,-0.000063,-0.000004,8.596555e-06,-1.130477e-04
4,AT1G01040,22,0.000295,609,-0.000328,-0.000093,-0.000133,0.000053,-0.000164,-0.000138,...,-0.000222,0.000295,0.000155,-0.000132,6.112147e-06,1.942994e-04,-0.000200,-0.000077,-3.206050e-05,-4.143671e-05
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
28953,AT4G04815,1,0.000095,1133,0.000095,-0.000015,-0.000030,-0.000005,-0.000008,-0.000005,...,0.000083,-0.000005,-0.000003,-0.000021,8.958760e-06,-4.608114e-06,-0.000001,-0.000018,-1.998597e-07,-3.624372e-06
28954,AT4G08485,1,0.000000,1667,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000e+00,0.000000e+00,0.000000,0.000000,0.000000e+00,0.000000e+00
28955,AT4G06380,24,0.000410,444,-0.000039,0.000053,0.000003,-0.000036,0.000006,-0.000024,...,0.000100,-0.000013,-0.000002,0.000410,-4.087492e-07,-7.544947e-06,-0.000004,-0.000036,7.929812e-07,7.557334e-06
28956,AT4G21030,1,0.000000,1668,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000e+00,0.000000e+00,0.000000,0.000000,0.000000e+00,0.000000e+00


In [93]:
# Remove rows with Score < or = 0
GEP_score = GEP_score[GEP_score['Score'] > 0]

In [94]:
GEP_score = GEP_score.sort_values(by=['GEP', 'Rank'], ascending=[True, True])
GEP_score['GEP'] = 'GEP_' + GEP_score['GEP'].astype(str)

In [95]:
GEP_score

Unnamed: 0,Gene,GEP,Score,Rank,GEP_1,GEP_2,GEP_3,GEP_4,GEP_5,GEP_6,...,GEP_21,GEP_22,GEP_23,GEP_24,GEP_25,GEP_26,GEP_27,GEP_28,GEP_29,GEP_30
7836,AT5G10130,GEP_1,0.010744,1,0.010744,-8.624264e-04,0.000157,-0.000671,-0.000445,-0.000588,...,-0.000950,-0.000357,-2.891213e-04,-0.000539,-0.000401,-0.000339,-0.000197,0.000346,-1.135266e-05,0.000012
4287,AT1G54010,GEP_1,0.010547,2,0.010547,-8.080902e-04,-0.001929,-0.000414,-0.000364,-0.000358,...,-0.000526,-0.000225,-1.939703e-04,-0.000810,-0.000218,-0.000210,-0.000095,-0.000283,-7.620894e-06,-0.000056
25759,AT4G37160,GEP_1,0.009257,3,0.009257,-6.513025e-04,-0.001888,-0.000467,-0.000399,-0.000380,...,-0.000635,-0.000249,-2.003414e-04,-0.000524,-0.000201,-0.000230,-0.000150,0.000043,-4.744473e-06,-0.000042
21638,AT2G43610,GEP_1,0.009036,4,0.009036,-9.237657e-04,0.002393,-0.000711,-0.000435,-0.000756,...,0.000712,-0.000424,-3.099640e-04,-0.000453,-0.000017,-0.000355,-0.000304,-0.000296,-2.049222e-05,-0.000089
19534,AT2G23630,GEP_1,0.009002,5,0.009002,-7.681538e-04,-0.001475,-0.000317,-0.000298,-0.000274,...,-0.000402,-0.000171,-1.672692e-04,-0.000916,-0.000130,-0.000181,-0.000125,-0.000895,-7.551204e-06,-0.000057
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
26832,AT1G07723,GEP_30,0.000045,213,-0.000011,-6.728264e-06,-0.000011,0.000033,-0.000005,-0.000003,...,0.000003,-0.000012,-5.033787e-07,0.000001,-0.000005,-0.000008,-0.000003,0.000014,2.250776e-07,0.000045
27032,AT2G24810,GEP_30,0.000039,214,-0.000018,-1.634842e-07,-0.000035,-0.000004,-0.000019,-0.000004,...,-0.000018,-0.000006,-7.665937e-06,-0.000015,0.000028,-0.000012,-0.000002,-0.000006,2.131939e-05,0.000039
662,AT1G07025,GEP_30,0.000036,215,0.000004,-1.288605e-05,-0.000033,0.000011,-0.000017,0.000017,...,0.000019,-0.000017,-4.986085e-06,-0.000007,-0.000008,-0.000016,-0.000006,-0.000011,2.507198e-05,0.000036
25432,AT4G34139,GEP_30,0.000027,216,-0.000014,-6.897318e-06,-0.000013,-0.000011,-0.000005,0.000013,...,-0.000010,-0.000005,-3.690098e-06,-0.000010,-0.000004,-0.000004,-0.000003,-0.000004,-3.789185e-07,0.000027


In [96]:
GEP_score.to_csv('../../cNMF/20250722_root_atlas_cnmf_parallel.gene_spectra_score.k_30.dt_0_1_annotated.csv', index=False)

In [97]:
gene_list.columns = ['Gene'] + ['features']

In [98]:
## Keep genes that can only be found in GeneSys trained model
GEP_score = gene_list.merge(GEP_score, on='Gene', how='left')

In [99]:
GEP_score = GEP_score.sort_values(by=['GEP', 'Rank'], ascending=[True, True])

In [100]:
GEP_score

Unnamed: 0,Gene,features,GEP,Score,Rank,GEP_1,GEP_2,GEP_3,GEP_4,GEP_5,...,GEP_21,GEP_22,GEP_23,GEP_24,GEP_25,GEP_26,GEP_27,GEP_28,GEP_29,GEP_30
47,AT1G54010,AT1G54010,GEP_1,0.010547,2,0.010547,-0.000808,-0.001929,-0.000414,-0.000364,...,-0.000526,-0.000225,-0.000194,-0.000810,-0.000218,-0.000210,-0.000095,-0.000283,-7.620894e-06,-0.000056
173,AT4G37160,AT4G37160,GEP_1,0.009257,3,0.009257,-0.000651,-0.001888,-0.000467,-0.000399,...,-0.000635,-0.000249,-0.000200,-0.000524,-0.000201,-0.000230,-0.000150,0.000043,-4.744473e-06,-0.000042
60,AT1G28290,AT1G28290,GEP_1,0.008768,7,0.008768,-0.000347,-0.002265,-0.000628,-0.000731,...,-0.000101,-0.000306,-0.000265,-0.000187,-0.000160,-0.000297,-0.000152,0.000113,3.302496e-06,-0.000028
1169,AT1G54030,AT1G54030,GEP_1,0.008284,8,0.008284,-0.000294,-0.000343,-0.000886,-0.000621,...,-0.000760,-0.000422,-0.000345,-0.000187,-0.000225,-0.000141,-0.000198,0.000247,1.680429e-05,0.000029
630,AT1G80240,AT1G80240,GEP_1,0.007489,9,0.007489,-0.000836,-0.001939,-0.000661,-0.000159,...,-0.000143,0.000062,-0.000259,-0.000690,0.001121,-0.000258,-0.000256,-0.000132,-6.468640e-06,-0.000384
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
13000,AT5G37410,AT5G37410,GEP_9,0.000114,899,-0.000087,0.000027,-0.000089,0.000047,0.000009,...,0.000085,-0.000037,-0.000024,-0.000036,-0.000012,-0.000019,-0.000020,0.000025,-4.075121e-06,-0.000007
14642,AT5G37050,AT5G37050,GEP_9,0.000112,900,-0.000039,0.000066,-0.000061,-0.000042,0.000043,...,0.000061,-0.000008,-0.000030,-0.000115,-0.000027,0.000005,-0.000024,-0.000046,-1.079057e-05,-0.000012
9436,AT5G61920,AT5G61920,GEP_9,0.000104,905,0.000060,-0.000022,0.000021,0.000003,-0.000011,...,0.000073,-0.000056,0.000032,0.000038,0.000038,-0.000027,0.000018,-0.000068,-8.103302e-06,0.000037
7354,AT5G14350,AT5G14350,GEP_9,0.000096,915,-0.000057,0.000030,-0.000069,0.000014,-0.000032,...,-0.000030,-0.000028,-0.000015,-0.000039,-0.000019,-0.000006,-0.000016,-0.000002,-9.798316e-06,0.000002


In [101]:
# Create a list of top genes for each GEP_ column
GEP_list = []
n_top = 50  # Number of top genes to select
# Loop through each GEP_ column (GEP_1 to GEP_31)
for i in range(1, 31):  # 1 to 30
    col = 'Rank'
    gep = f'GEP_{i}'
    # Sort by the current GEP_ column in descending order, select top 20, and get gene_IDs
    top_genes = GEP_score[GEP_score['GEP'] == gep].sort_values(by=col, ascending=True).head(n_top)['Gene'].tolist()
    GEP_list.append(top_genes)
# Name the list elements (equivalent to names(GEP_list) in R)
GEP_dict = {f'GEP_{i}': genes for i, genes in enumerate(GEP_list, 1)}

In [102]:
{key: len(value) for key, value in GEP_dict.items()}

{'GEP_1': 50,
 'GEP_2': 50,
 'GEP_3': 50,
 'GEP_4': 50,
 'GEP_5': 50,
 'GEP_6': 50,
 'GEP_7': 50,
 'GEP_8': 50,
 'GEP_9': 50,
 'GEP_10': 50,
 'GEP_11': 50,
 'GEP_12': 50,
 'GEP_13': 50,
 'GEP_14': 50,
 'GEP_15': 50,
 'GEP_16': 50,
 'GEP_17': 50,
 'GEP_18': 50,
 'GEP_19': 50,
 'GEP_20': 50,
 'GEP_21': 50,
 'GEP_22': 50,
 'GEP_23': 50,
 'GEP_24': 50,
 'GEP_25': 50,
 'GEP_26': 50,
 'GEP_27': 50,
 'GEP_28': 50,
 'GEP_29': 50,
 'GEP_30': 49}

## Define functions

In [104]:
def recovery(matched_idx, output_dir_and_file_name):
    y_pred, y_true = [], []
    
    with torch.no_grad():
        for sample in loader:
            x = sample['x'].to(device)
            
            # remove/keep genes
            to_keep = x[:,:,matched_idx]
            # remove all gene expression
            x[:,:,:]=x[:,:,:].zero_()
            # add back the gene expression for selected genes
            x[:,:,matched_idx]= to_keep
            
            x10 = torch.stack([x[:,0,:],x[:,1,:],x[:,2,:],x[:,3,:],x[:,4,:],x[:,5,:],x[:,6,:],x[:,7,:],x[:,8,:],x[:,9,:],x[:,10,:]],dim=1)
            y = sample['y'].to(device)
            y_true.append(y.cpu().detach().numpy())
            test_h = model.init_hidden(batch_size)
            p, pred_h = model.predict_proba(x10, test_h, 10)
            y_pred.append(p.cpu().detach().numpy())
    
            
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    # Compute overall recovery
    overall_recovery = (y_true == np.argmax(y_pred, axis=1)).mean()
    
    # Store results in a dictionary
    results = {'Celltype': ['Overall'], 'Recovery': [overall_recovery]}
    
    # Define cell types
    classes = ['Columella', 'Lateral Root Cap', 'Phloem', 'Xylem', 'Procambium',
               'Pericycle', 'Endodermis', 'Cortex', 'Atrichoblast', 'Trichoblast']
    
    # Compute recovery per cell type
    for ct in range(10):
        idx = np.where(y_true == ct)
        recovery = (y_true[idx] == np.argmax(y_pred, axis=1)[idx]).mean()
        results['Celltype'].append(classes[ct])
        results['Recovery'].append(recovery)
    
    # Convert to DataFrame
    df = pd.DataFrame(results)
    print(df)
    
    # Save to CSV
    df.to_csv(output_dir_and_file_name, index=False)

In [105]:
def t_recovery(matched_idx, time):
    y_pred, y_true = [], []
    
    with torch.no_grad():
        for sample in loader:
            x = sample['x'].to(device)
            
            # remove/keep genes
            to_keep = x[:,:,matched_idx]
            # remove all gene expression
            x[:,:,:]=x[:,:,:].zero_()
            # add back the gene expression for selected genes
            x[:,:,matched_idx]= to_keep
            
            x10 = torch.stack([x[:,0,:],x[:,1,:],x[:,2,:],x[:,3,:],x[:,4,:],x[:,5,:],x[:,6,:],x[:,7,:],x[:,8,:],x[:,9,:],x[:,10,:]],dim=1)
            y = sample['y'].to(device)
            y_true.append(y.cpu().detach().numpy())
            test_h = model.init_hidden(batch_size)
            p, pred_h = model.predict_proba(x10, test_h, time)
            y_pred.append(p.cpu().detach().numpy())
    
            
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    # Compute overall recovery
    overall_recovery = (y_true == np.argmax(y_pred, axis=1)).mean()
    
    # Store results in a dictionary
    results = {'Celltype': ['Overall'], 'Recovery_t' + str(int(time)): [overall_recovery]}
    
    # Define cell types
    classes = ['Columella', 'Lateral Root Cap', 'Phloem', 'Xylem', 'Procambium',
               'Pericycle', 'Endodermis', 'Cortex', 'Atrichoblast', 'Trichoblast']
    
    # Compute recovery per cell type
    for ct in range(10):
        idx = np.where(y_true == ct)
        recovery = (y_true[idx] == np.argmax(y_pred, axis=1)[idx]).mean()
        results['Celltype'].append(classes[ct])
        results['Recovery_t' + str(int(time))].append(recovery)
    
    # Convert to DataFrame
    df = pd.DataFrame(results)
    #print(df)
    
    # Return df
    return df

In [106]:
def gof(matched_idx):
    # Find indices where query_list elements match
    #matched_idx = [i for i, x in enumerate(gene_list['features']) if x in set(GEP_dict[GEP])]
    matched_idx = matched_idx
    #len(gene_list['features'][matched_idx])
    
    #Prepare 
    xm = xo.clone()
    # remove/keep genes
    to_keep = xm[:,:,matched_idx]
    xm[:,:,:]=xm[:,:,:].zero_()
    xm[:,:,matched_idx]= to_keep
    ## Provide entire tracks
    x = torch.stack([xm[:,0,:],xm[:,1,:],xm[:,2,:],xm[:,3,:],xm[:,4,:],xm[:,5,:],xm[:,6,:],xm[:,7,:],xm[:,8,:],xm[:,9,:],xm[:,10,:]],dim=1)
    ## Provide the first two bins 
    #x = torch.stack([xm[:,0,:],xm[:,1,:],xm[:,2,:],xm[:,2,:],xm[:,2,:],xm[:,2,:],xm[:,2,:],xm[:,2,:],xm[:,2,:],xm[:,2,:],xm[:,2,:]],dim=1)

    ## Initialize hidden state
    pred_h = model.init_hidden(batch_size)
    
    # t0 and t1 prediction based on data from the first two time points
    t0 = model.generate_current(x, pred_h, 0)
    
    ## predict t1 label
    y0, pred_h = model.predict_proba(x, pred_h, 0)
    y0 = [num2class[i] for i in np.argmax(y0.cpu().detach().numpy(), axis=1)]
    
    t1 = model.generate_next(x, pred_h, 0)
    
    y1, pred_h = model.predict_proba(x, pred_h, 1)
    y1 = [num2class[i] for i in np.argmax(y1.cpu().detach().numpy(), axis=1)]

    t2 = model.generate_next(x, pred_h, 1)
    
    y2, pred_h = model.predict_proba(x, pred_h, 2)
    y2 = [num2class[i] for i in np.argmax(y2.cpu().detach().numpy(), axis=1)]
    
    t3 = model.generate_next(x, pred_h, 2)
    
    y3, pred_h = model.predict_proba(x, pred_h, 3)
    y3 = [num2class[i] for i in np.argmax(y3.cpu().detach().numpy(), axis=1)]
    
    t4 = model.generate_next(x, pred_h, 3)
    
    y4, pred_h = model.predict_proba(x, pred_h, 4)
    y4 = [num2class[i] for i in np.argmax(y4.cpu().detach().numpy(), axis=1)]

    t5 = model.generate_next(x, pred_h, 4)
    
    y5, pred_h = model.predict_proba(x, pred_h, 5)
    y5 = [num2class[i] for i in np.argmax(y5.cpu().detach().numpy(), axis=1)]
    
    t6 = model.generate_next(x, pred_h, 5)
    
    y6, pred_h = model.predict_proba(x, pred_h, 6)
    y6 = [num2class[i] for i in np.argmax(y6.cpu().detach().numpy(), axis=1)]
    
    t7 = model.generate_next(x, pred_h, 6)

    y7, pred_h = model.predict_proba(x, pred_h, 7)
    y7 = [num2class[i] for i in np.argmax(y7.cpu().detach().numpy(), axis=1)]
    
    t8 = model.generate_next(x, pred_h, 7)
    
    y8, pred_h = model.predict_proba(x, pred_h, 8)
    y8 = [num2class[i] for i in np.argmax(y8.cpu().detach().numpy(), axis=1)]
    
    t9 = model.generate_next(x, pred_h, 8)

    y9, pred_h = model.predict_proba(x, pred_h, 9)
    y9 = [num2class[i] for i in np.argmax(y9.cpu().detach().numpy(), axis=1)]
    
    t10 = model.generate_next(x, pred_h, 9)
    
    y10, pred_h = model.predict_proba(x, pred_h, 10)
    y10 = [num2class[i] for i in np.argmax(y10.cpu().detach().numpy(), axis=1)]

    t0 = t0.to(device).detach().numpy()
    t1 = t1.to(device).detach().numpy()
    t2 = t2.to(device).detach().numpy()
    t3 = t3.to(device).detach().numpy()
    t4 = t4.to(device).detach().numpy()
    t5 = t5.to(device).detach().numpy()
    t6 = t6.to(device).detach().numpy()
    t7 = t7.to(device).detach().numpy()
    t8 = t8.to(device).detach().numpy()
    t9 = t9.to(device).detach().numpy()
    t10 = t10.to(device).detach().numpy()
    
    pred_X = np.vstack((t0,t1,t2,t3,t4,t5,t6,t7,t8,t9,t10))
    pred_Y = np.concatenate((['Stem Cell']*batch_size, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10)).tolist()
    pred_T = ['t0']*batch_size + ['t1']*batch_size + ['t2']*batch_size + ['t3']*batch_size + ['t4']*batch_size + ['t5']*batch_size + ['t6']*batch_size + ['t7']*batch_size + ['t8']*batch_size+ ['t9']*batch_size + ['t10']*batch_size

    # Create AnnData object
    cell_names = [f"Cell_{i}" for i in range(pred_X.shape[0])]
    adata = anndata.AnnData(
        X=pred_X,
        obs=pd.DataFrame(index=cell_names),  # Cell annotations
        var=pd.DataFrame(index=gene_list['features'])   # Gene annotations
    )
    adata.obs['celltype'] = pred_Y
    adata.obs['timebin'] = pred_T

    sc.pp.scale(adata, max_value=10)
    sc.tl.pca(adata, svd_solver='arpack')
    sc.pp.neighbors(adata, n_neighbors=30, n_pcs=50)
    sc.tl.leiden(adata)
    sc.tl.paga(adata)
    sc.pl.paga(adata) 

    sc.tl.umap(adata, init_pos='paga')
    adata.uns['celltype_colors'] = np.array([ "#9400d3","#5ab953", "#bfef45", "#008080", "#21B6A8", "#82b6ff", "#0000FF","#e6194b", "#9a6324", "#ffe119", "#ff9900", "#ffd4e3", "#9a6324", "#ddaa6f"], dtype=object)
    adata.obs['celltype'] = pd.Categorical(adata.obs['celltype'], categories=["Stem Cell","Columella", "Lateral Root Cap", "Atrichoblast", "Trichoblast", "Cortex", "Endodermis", "Phloem", "Xylem", "Procambium","Pericycle"])
    sc.pl.umap(adata, color=['celltype'])

    adata.uns['timebin_colors'] = np.array([ '#5E4FA2', '#3288BD', '#66C2A5', '#ABDDA4', '#E6F598', '#FFFFBF', '#FEE08B', '#FDAE61', '#F46D43', '#D53E4F','#9E0142'])
    sc.pl.umap(adata, color=['timebin'])

    ## AT1G79840
    sc.pl.umap(adata, color='AT1G79840', title='AT1G79840 (GL2, Atrichoblast)')
    ## AT5G49270
    sc.pl.umap(adata, color='AT5G49270', title='AT5G49270 (COBL9, Trichoblast)')
    sc.pl.umap(adata, color='AT1G09750', title='AT1G09750 (CORTEX, Cortex)')
    sc.pl.umap(adata, color='AT5G57620', title='AT5G57620 (MYB36, Endodermis)')
    sc.pl.umap(adata, color='AT1G79430', title='AT1G79430 (APL, Phloem)')
    sc.pl.umap(adata, color='AT1G71930', title='AT1G71930 (VND7, Xylem)')
    #sc.pl.umap(adata, color='AT4G37650', title='AT4G37650 (SHR)')    

    return adata

## TFs

In [107]:
TF = pd.read_csv("../../CW_data/celloracle/Kay_TF_thalemine_annotations.csv")

In [108]:
TF.columns = ['Gene'] + ['Name'] + ['Description']

In [109]:
TF

Unnamed: 0,Gene,Name,Description
0,AT1G01010,NAC001,NAC domain containing protein 1
1,AT1G01030,NGA3,AP2/B3-like transcriptional factor family protein
2,AT1G01060,LHY,Homeodomain-like superfamily protein
3,AT1G01250,AT1G01250,Integrase-type DNA-binding superfamily protein
4,AT1G01260,AT1G01260,basic helix-loop-helix (bHLH) DNA-binding supe...
...,...,...,...
2479,AT5G67430,AT5G67430,Acyl-CoA N-acyltransferases (NAT) superfamily ...
2480,AT5G67450,ZF1,zinc-finger protein 1
2481,AT5G67480,BT4,BTB and TAZ domain protein 4
2482,AT5G67580,TRB2,Homeodomain-like/winged-helix DNA-binding fami...


## Top 30

In [110]:
# Create a list of top genes for each GEP_ column
GEP_list = []
n_top = 30  # Number of top genes to select
# Loop through each GEP_ column (GEP_1 to GEP_31)
for i in range(1, 31):  # 1 to 30
    col = 'Rank'
    gep = f'GEP_{i}'
    # Sort by the current GEP_ column in descending order, select top 20, and get gene_IDs
    top_genes = GEP_score[GEP_score['GEP'] == gep].sort_values(by=col, ascending=True).head(n_top)['Gene'].tolist()
    GEP_list.append(top_genes)
# Name the list elements (equivalent to names(GEP_list) in R)
GEP_dict = {f'GEP_{i}': genes for i, genes in enumerate(GEP_list, 1)}

##### GEP recovery

In [111]:
for g in GEP_dict.keys():
    matched_idx = [i for i, x in enumerate(gene_list['features']) if x in set(GEP_dict[g])]
    print(len(matched_idx))
    recovery(matched_idx, "./" + g + "_celltype_recovery.csv")

30
            Celltype  Recovery
0            Overall  0.300909
1          Columella  0.062072
2   Lateral Root Cap  0.999523
3             Phloem  0.000000
4              Xylem  0.000000
5         Procambium  0.000000
6          Pericycle  0.000000
7         Endodermis  0.000000
8             Cortex  0.371507
9       Atrichoblast  0.678522
10       Trichoblast  0.975023
30
            Celltype  Recovery
0            Overall  0.101273
1          Columella  0.000000
2   Lateral Root Cap  0.000000
3             Phloem  0.000000
4              Xylem  0.000000
5         Procambium  0.000000
6          Pericycle  0.000000
7         Endodermis  0.000000
8             Cortex  0.000000
9       Atrichoblast  0.000000
10       Trichoblast  1.000000
30
            Celltype  Recovery
0            Overall  0.198091
1          Columella  1.000000
2   Lateral Root Cap  0.000000
3             Phloem  0.000000
4              Xylem  0.000000
5         Procambium  0.000000
6          Pericycle  0.000000

In [112]:
import glob
# Locate all recovery CSV files
file_list = sorted(glob.glob("./GEP_*_celltype_recovery.csv"))

In [113]:
# Read and merge
merged_df = None

for file in file_list:
    # Extract GEP name from filename (e.g., 'GEP_1')
    gep_name = os.path.basename(file).split('_celltype')[0]
    
    # Read the CSV
    df = pd.read_csv(file)
    
    # Rename "Recovery" column to current GEP
    df = df.rename(columns={'Recovery': gep_name})
    
    if merged_df is None:
        merged_df = df
    else:
        merged_df = pd.merge(merged_df, df, on='Celltype', how='outer')

# Optional: Set Celltype as index
merged_df = merged_df.set_index('Celltype')

In [116]:
# Sort columns by GEP number if desired
merged_df = merged_df[sorted(merged_df.columns, key=lambda x: int(x.split('_')[1]))]

In [121]:
# Define your custom order
custom_order = [
    "Xylem", "Phloem", "Procambium", "Pericycle", "Endodermis",
    "Cortex", "Trichoblast", "Atrichoblast", "Lateral Root Cap",
    "Columella", "Overall"
]

# Make sure index is Celltype
merged_df.index = pd.Categorical(merged_df.index, categories=custom_order, ordered=True)

# Sort by the categorical index
merged_df = merged_df.sort_index()

In [122]:
merged_df

Unnamed: 0,GEP_1,GEP_2,GEP_3,GEP_4,GEP_5,GEP_6,GEP_7,GEP_8,GEP_9,GEP_10,...,GEP_21,GEP_22,GEP_23,GEP_24,GEP_25,GEP_26,GEP_27,GEP_28,GEP_29,GEP_30
Xylem,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.726437,1.0,0.0,0.0,0.0,0.996438,0.0,0.0,0.0
Phloem,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.167194,0.0
Procambium,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.132373,0.0,...,0.0,0.002795,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Pericycle,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.986642,0.005018,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Endodermis,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.011845,0.473397,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Cortex,0.371507,0.0,0.0,0.0,0.0,1.0,0.113225,0.0,0.004269,0.0,...,0.0,0.0,0.0,0.0,0.122348,0.0,0.0,0.0,0.0,0.0
Trichoblast,0.975023,1.0,1.0,1.0,1.0,0.999532,0.990359,1.0,0.999542,0.999089,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
Atrichoblast,0.678522,0.0,0.0,0.0,0.0,0.0,0.316369,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Lateral Root Cap,0.999523,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Columella,0.062072,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.057487,0.0,0.0,0.0,0.0,0.0


In [123]:
merged_df.to_csv('./Merged_GEP_celltype_recovery.csv', index=True)

## Test GEP 17 and 20 (top30)

In [124]:
# Without GEP17 and GEP20 
matched_idx = [i for i, x in enumerate(gene_list['features']) if x in set(GEP_dict['GEP_1'] + GEP_dict['GEP_3'] + GEP_dict['GEP_4'] 
                                                                          + GEP_dict['GEP_6'] + GEP_dict['GEP_8'] + GEP_dict['GEP_10'] 
                                                                          + GEP_dict['GEP_11'] + GEP_dict['GEP_12']+ GEP_dict['GEP_13']
                                                                          + GEP_dict['GEP_16']
                                                                          + GEP_dict['GEP_19'] + GEP_dict['GEP_22']+ GEP_dict['GEP_23']
                                                                          + GEP_dict['GEP_26'] + GEP_dict['GEP_27'])]

In [125]:
len(matched_idx)

450

In [126]:
## Without 17 and 20
GEP_com_wo_17_20 = recovery(matched_idx, "./GEP_com_wo_17_20celltype_recovery.csv")

            Celltype  Recovery
0            Overall  0.897364
1          Columella  1.000000
2   Lateral Root Cap  1.000000
3             Phloem  1.000000
4              Xylem  1.000000
5         Procambium  1.000000
6          Pericycle  0.977941
7         Endodermis  1.000000
8             Cortex  1.000000
9       Atrichoblast  1.000000
10       Trichoblast  0.003162


In [127]:
# With 17
matched_idx = [i for i, x in enumerate(gene_list['features']) if x in set(GEP_dict['GEP_1'] + GEP_dict['GEP_3'] + GEP_dict['GEP_4'] 
                                                                          + GEP_dict['GEP_6'] + GEP_dict['GEP_8'] + GEP_dict['GEP_10'] 
                                                                          + GEP_dict['GEP_11'] + GEP_dict['GEP_12']+ GEP_dict['GEP_13']
                                                                          + GEP_dict['GEP_16'] + GEP_dict['GEP_17']
                                                                          + GEP_dict['GEP_19'] + GEP_dict['GEP_22']+ GEP_dict['GEP_23']
                                                                          + GEP_dict['GEP_26'] + GEP_dict['GEP_27'])]

In [128]:
len(matched_idx)

480

In [129]:
## With 17
GEP_com_with_17 = recovery(matched_idx, "./GEP_com_with_17_celltype_recovery.csv")

            Celltype  Recovery
0            Overall  0.996227
1          Columella  1.000000
2   Lateral Root Cap  1.000000
3             Phloem  1.000000
4              Xylem  1.000000
5         Procambium  1.000000
6          Pericycle  0.973944
7         Endodermis  1.000000
8             Cortex  0.999543
9       Atrichoblast  1.000000
10       Trichoblast  0.989218


In [130]:
# With 20
matched_idx = [i for i, x in enumerate(gene_list['features']) if x in set(GEP_dict['GEP_1'] + GEP_dict['GEP_3'] + GEP_dict['GEP_4'] 
                                                                          + GEP_dict['GEP_6'] + GEP_dict['GEP_8'] + GEP_dict['GEP_10'] 
                                                                          + GEP_dict['GEP_11'] + GEP_dict['GEP_12']+ GEP_dict['GEP_13']
                                                                          + GEP_dict['GEP_16'] + GEP_dict['GEP_20']
                                                                          + GEP_dict['GEP_19'] + GEP_dict['GEP_22']+ GEP_dict['GEP_23']
                                                                          + GEP_dict['GEP_26'] + GEP_dict['GEP_27'])]

In [131]:
len(matched_idx)

480

In [132]:
## With 20
GEP_com_with_20 = recovery(matched_idx, "./GEP_com_with_20_celltype_recovery.csv")

            Celltype  Recovery
0            Overall  0.993409
1          Columella  1.000000
2   Lateral Root Cap  1.000000
3             Phloem  1.000000
4              Xylem  1.000000
5         Procambium  1.000000
6          Pericycle  0.986995
7         Endodermis  1.000000
8             Cortex  1.000000
9       Atrichoblast  1.000000
10       Trichoblast  0.947202
