## 1. Methylation

In [None]:
import pandas as pd

In [None]:
# Load the raw methylation table (tab-separated text file) and show it.
raw_methy_df = pd.read_csv('./data/raw_data/CCLE_RRBS_TSS1kb_20181022.txt', sep='\t')
display(raw_methy_df)

In [None]:
# Clean the raw methylation table in one pass:
#   1. drop rows missing either the locus identifier or the CpG annotation, so the
#      string split below never sees a NaN (a NaN locus_id used to raise on `.split`);
#   2. keep only the text before the first '_' in locus_id (presumably the gene symbol;
#      a later cell joins it against HGNC_Symbol);
#   3. remove the annotation columns CpG_sites_hg19 and avg_coverage, leaving locus_id
#      plus the per-sample value columns.
raw_methy_df = (
    raw_methy_df
    .dropna(subset=['locus_id', 'CpG_sites_hg19'])
    .assign(locus_id=lambda df: df['locus_id'].str.split('_').str[0])
    .drop(columns=['CpG_sites_hg19', 'avg_coverage'])
)
display(raw_methy_df)

In [None]:
# Every column except 'locus_id' is treated as a per-sample value column.
cols_to_convert = raw_methy_df.columns.difference(['locus_id'])
# Coerce those columns to numeric; entries that cannot be parsed become NaN.
raw_methy_df[cols_to_convert] = raw_methy_df[cols_to_convert].apply(pd.to_numeric, errors='coerce')
# NOTE(review): NaN (missing or unparseable) is replaced by 0.0 here, so such entries
# count as real zeros in the per-locus mean taken afterwards — confirm this is intended.
raw_methy_df = raw_methy_df.fillna(0.0)
display(raw_methy_df)

In [None]:
# Collapse rows sharing a locus_id into a single row by averaging every sample column.
methy_df = raw_methy_df.groupby('locus_id', as_index=False).mean()
display(methy_df)

In [None]:
# Load the BioMedGraphica-Conn promoter entity table.
bmg_promoter_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Promoter/BioMedGraphica_Conn_Promoter.csv')
# Keep only the entity key (BioMedGraphica_Conn_ID) and the HGNC_Symbol used for joining.
bmg_promoter_df = bmg_promoter_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']]
display(bmg_promoter_df)

In [None]:
# Inner join: keep only promoter entities whose HGNC_Symbol has a methylation row,
# then discard both join keys so only BioMedGraphica_Conn_ID + sample columns remain.
merged_methy_df = (
    bmg_promoter_df
    .merge(methy_df, left_on='HGNC_Symbol', right_on='locus_id', how='inner')
    .drop(columns=['HGNC_Symbol', 'locus_id'])
)
display(merged_methy_df)

In [None]:
# Left join: keep every promoter entity; entities without a methylation row get 0.0
# in every sample column. The join keys are dropped after the fill.
final_merged_methy_df = (
    bmg_promoter_df
    .merge(methy_df, left_on='HGNC_Symbol', right_on='locus_id', how='left')
    .fillna(0.0)
    .drop(columns=['HGNC_Symbol', 'locus_id'])
)
display(final_merged_methy_df)

## 2. Gene

In [None]:
# Load the raw gene-level table (OmicsCNGene.csv) and show it.
raw_gene_df = pd.read_csv('./data/raw_data/OmicsCNGene.csv')
display(raw_gene_df)

In [None]:
# Re-orient the table: "Unnamed: 0" becomes the index, the frame is transposed so the
# former column headers become rows, and those headers are exposed as a 'gene_name' column.
raw_gene_t_df = (
    raw_gene_df
    .set_index("Unnamed: 0")
    .transpose()
    .reset_index()
    .rename(columns={'index': 'gene_name'})
)
# Keep only the text before the first '(' in each gene_name and trim surrounding whitespace.
raw_gene_t_df['gene_name'] = raw_gene_t_df['gene_name'].apply(
    lambda label: label.split('(')[0].strip()
)
display(raw_gene_t_df)

In [None]:
# Every column except 'gene_name' is treated as a per-sample value column.
# ('Unnamed: 0' is also excluded; Index.difference ignores labels that are not present.)
cols_to_convert = raw_gene_t_df.columns.difference(['Unnamed: 0', 'gene_name'])
# Coerce those columns to numeric; entries that cannot be parsed become NaN.
raw_gene_t_df[cols_to_convert] = raw_gene_t_df[cols_to_convert].apply(pd.to_numeric, errors='coerce')
# NOTE(review): NaN is replaced by 0.0, so missing entries count as zeros in the
# per-gene mean taken afterwards — confirm this is intended.
raw_gene_t_df = raw_gene_t_df.fillna(0.0)
display(raw_gene_t_df)

In [None]:
# Collapse rows sharing a gene_name into a single row by averaging every sample column.
gene_df = raw_gene_t_df.groupby('gene_name', as_index=False).mean()
display(gene_df)

In [None]:
# Load the BioMedGraphica-Conn gene entity table.
bmg_gene_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene.csv')
# Keep only the entity key (BioMedGraphica_Conn_ID) and the HGNC_Symbol used for joining.
bmg_gene_df = bmg_gene_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']]
display(bmg_gene_df)

In [None]:
# Inner join: keep only gene entities whose HGNC_Symbol appears in gene_df,
# then discard both join keys.
merged_gene_df = (
    bmg_gene_df
    .merge(gene_df, left_on='HGNC_Symbol', right_on='gene_name', how='inner')
    .drop(columns=['HGNC_Symbol', 'gene_name'])
)
display(merged_gene_df)

In [None]:
# Left join: keep every gene entity; entities with no row in gene_df are filled with 0.0.
# NOTE(review): the previous comment said -1.0, but the fill value has always been 0.0 —
# confirm which one is intended.
final_merged_gene_df = (
    bmg_gene_df
    .merge(gene_df, left_on='HGNC_Symbol', right_on='gene_name', how='left')
    .fillna(0.0)
    .drop(columns=['HGNC_Symbol', 'gene_name'])
)
display(final_merged_gene_df)

## 3. Transcript

In [None]:
# Load the raw expression table (file name indicates batch-corrected log(TPM+1) values
# for protein-coding genes) and show it.
raw_transcript_df = pd.read_csv('./data/raw_data/OmicsExpressionProteinCodingGenesTPMLogp1BatchCorrected.csv')
display(raw_transcript_df)

In [None]:
# Re-orient the table: "Unnamed: 0" becomes the index, the frame is transposed so the
# former column headers become rows, and those headers are exposed as a 'gene_name' column.
# The result overwrites raw_transcript_df, so this cell cannot be re-run without reloading.
raw_transcript_df = (
    raw_transcript_df
    .set_index("Unnamed: 0")
    .transpose()
    .reset_index()
    .rename(columns={'index': 'gene_name'})
)
# Keep only the text before the first '(' in each gene_name and trim surrounding whitespace.
raw_transcript_df['gene_name'] = raw_transcript_df['gene_name'].apply(
    lambda label: label.split('(')[0].strip()
)
display(raw_transcript_df)

In [None]:
# Every column except 'gene_name' is treated as a per-sample value column.
# ('Unnamed: 0' is also excluded; Index.difference ignores labels that are not present.)
cols_to_convert = raw_transcript_df.columns.difference(['Unnamed: 0', 'gene_name'])
# Coerce those columns to numeric; entries that cannot be parsed become NaN.
raw_transcript_df[cols_to_convert] = raw_transcript_df[cols_to_convert].apply(pd.to_numeric, errors='coerce')
# Missing / unparseable values are replaced with 0.0.
raw_transcript_df = raw_transcript_df.fillna(0.0)
display(raw_transcript_df)

In [None]:
# Load the BioMedGraphica-Conn transcript entity table.
bmg_transcript_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Transcript/BioMedGraphica_Conn_Transcript.csv')
# Keep only the entity key (BioMedGraphica_Conn_ID) and the HGNC_Symbol used for joining.
bmg_transcript_df = bmg_transcript_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']]
display(bmg_transcript_df)

In [None]:
# Inner join: keep only transcript entities whose HGNC_Symbol matches a gene_name in the
# expression table, then discard both join keys.
merge_transcript_df = (
    bmg_transcript_df
    .merge(raw_transcript_df, left_on='HGNC_Symbol', right_on='gene_name', how='inner')
    .drop(columns=['HGNC_Symbol', 'gene_name'])
)
display(merge_transcript_df)

In [None]:
# Left join: keep every transcript entity; entities with no matching gene_name are
# filled with 0.0.
# NOTE(review): the previous comment said -1.0, but the fill value has always been 0.0 —
# confirm which one is intended.
final_merged_transcript_df = (
    bmg_transcript_df
    .merge(raw_transcript_df, left_on='HGNC_Symbol', right_on='gene_name', how='left')
    .fillna(0.0)
    .drop(columns=['HGNC_Symbol', 'gene_name'])
)
display(final_merged_transcript_df)

## 4. Protein

In [None]:
# Load the protein quantification table and strip the descriptive columns.
raw_protein_df = pd.read_csv('./data/raw_data/protein_quant_current_normalized.csv').drop(
    columns=['Protein_Id', 'Gene_Symbol', 'Description', 'Group_ID', 'Uniprot']
)
# Additionally remove every column whose name contains "Peptides".
raw_protein_df = raw_protein_df.drop(columns=raw_protein_df.filter(regex='Peptides').columns)
display(raw_protein_df)

In [None]:
# Every column except the 'Uniprot_Acc' key is treated as a per-sample value column.
cols_to_convert = raw_protein_df.columns.difference(['Uniprot_Acc'])
# Coerce those columns to numeric; entries that cannot be parsed become NaN.
raw_protein_df[cols_to_convert] = raw_protein_df[cols_to_convert].apply(pd.to_numeric, errors='coerce')
# NOTE(review): NaN is replaced by 0.0, so missing entries count as zeros in the
# per-accession mean taken afterwards — confirm this is intended.
raw_protein_df = raw_protein_df.fillna(0.0)
display(raw_protein_df)

In [None]:
# Collapse rows sharing a Uniprot_Acc into a single row by averaging every sample column.
protein_df = raw_protein_df.groupby('Uniprot_Acc', as_index=False).mean()
display(protein_df)


In [None]:
# Load the BioMedGraphica-Conn protein entity table and keep only the entity key
# (BioMedGraphica_Conn_ID) and the Uniprot_ID used for joining.
bmg_protein_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein.csv')
bmg_protein_df = bmg_protein_df[['BioMedGraphica_Conn_ID', 'Uniprot_ID']]
display(bmg_protein_df)

In [None]:
# Inner join: keep only protein entities whose Uniprot_ID matches a Uniprot_Acc in
# protein_df, then discard both join keys.
merged_protein_df = (
    bmg_protein_df
    .merge(protein_df, left_on='Uniprot_ID', right_on='Uniprot_Acc', how='inner')
    .drop(columns=['Uniprot_ID', 'Uniprot_Acc'])
)
display(merged_protein_df)

In [None]:
# Left join: keep every protein entity; entities with no matching Uniprot_Acc are
# filled with 0.0. The join keys are dropped after the fill.
final_merged_protein_df = (
    bmg_protein_df
    .merge(protein_df, left_on='Uniprot_ID', right_on='Uniprot_Acc', how='left')
    .fillna(0.0)
    .drop(columns=['Uniprot_ID', 'Uniprot_Acc'])
)
display(final_merged_protein_df)

## 5. Drug

In [None]:
# Load the raw dose-response table and show it.
raw_drug_df = pd.read_csv('./data/raw_data/sanger-dose-response.csv')
display(raw_drug_df)

In [None]:
# Keep the sample id, drug name and the two published response measures
# (IC50_PUBLISHED and AUC_PUBLISHED).
raw_drug_df = raw_drug_df[['ARXSPAN_ID', 'DRUG_NAME', 'IC50_PUBLISHED', 'AUC_PUBLISHED']]
# Report the per-column NaN count before cleaning.
print(raw_drug_df.isnull().sum())
# Drop every row that has a NaN in any of the four kept columns.
raw_drug_df = raw_drug_df.dropna().reset_index(drop=True)
# Report again; all counts should now be zero.
print(raw_drug_df.isnull().sum())
display(raw_drug_df)

In [None]:
# Extract the DRUG_NAME column as its own single-column dataframe.
drug_name_df = raw_drug_df[['DRUG_NAME']]
# Reduce it to the distinct drug names.
drug_name_df = drug_name_df.drop_duplicates().reset_index(drop=True)
display(drug_name_df)

In [None]:
# Load the BioMedGraphica-Conn drug entity table.
bmg_drug_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Drug/BioMedGraphica_Conn_Drug.csv')
display(bmg_drug_df)
# Keep the entity key plus every name-bearing column. `.copy()` makes this an independent
# frame, so adding columns to it later does not trigger pandas' SettingWithCopyWarning.
bmg_drug_name_df = bmg_drug_df[['BioMedGraphica_Conn_ID', 'PubChem_Name', 'IUPAC_Name', 'UNII_Name', 'DrugBank_Name', 'PubChem_Synonym']].copy()
display(bmg_drug_name_df)

In [None]:
# Combine 'PubChem_Name', 'IUPAC_Name', 'UNII_Name', 'DrugBank_Name', 'PubChem_Synonym'
# into a single ';'-joined 'Drug_Name' column (missing names are skipped per row).
# `assign` builds a new frame instead of writing into a column slice, and the final
# `.copy()` keeps the two-column result independent, so neither this cell nor later
# column assignments raise SettingWithCopyWarning.
bmg_drug_name_df = bmg_drug_name_df.assign(
    Drug_Name=bmg_drug_name_df[['PubChem_Name', 'IUPAC_Name', 'UNII_Name', 'DrugBank_Name', 'PubChem_Synonym']].apply(
        lambda x: ';'.join(x.dropna().astype(str)), axis=1
    )
)[['BioMedGraphica_Conn_ID', 'Drug_Name']].copy()
display(bmg_drug_name_df)

In [None]:
# Drug_Name holds many names separated by ';'. Split both sides into lists of trimmed
# names so that a drug from drug_name_df can later be matched against any single name.
# Each step uses `assign` (returns a new frame) rather than writing a column into a
# frame that may be a slice of another one, which avoids SettingWithCopyWarning.

# Report how many Drug_Name values are missing before they are blanked out.
print(f"Number of NaN values in Drug_Name column: {bmg_drug_name_df['Drug_Name'].isna().sum()}")

# Replace missing names with the empty string so the split below is always valid.
bmg_drug_name_df = bmg_drug_name_df.assign(Drug_Name=bmg_drug_name_df['Drug_Name'].fillna(''))

# One list of stripped names per BioMedGraphica drug; an empty string yields an empty list.
bmg_drug_name_df = bmg_drug_name_df.assign(
    drug_name_list=bmg_drug_name_df['Drug_Name'].astype(str).apply(
        lambda x: [name.strip() for name in x.split(';')] if x else []
    )
)

# Same treatment for the drug names coming from the dose-response table.
drug_name_df = drug_name_df.assign(
    drug_name_list=drug_name_df['DRUG_NAME'].astype(str).apply(
        lambda x: [name.strip() for name in x.split(';')] if x else []
    )
)

In [None]:
# Inspect both frames after the drug_name_list column has been added.
display(bmg_drug_name_df)
display(drug_name_df)

In [None]:
# Index every non-empty BioMedGraphica drug name (upper-cased) by its entity ID.
# When the same name occurs under several IDs, the last one seen wins.
drug_name_to_bmg = {}
for bmg_id, names in zip(bmg_drug_name_df['BioMedGraphica_Conn_ID'], bmg_drug_name_df['drug_name_list']):
    for drug_name in names:
        if not drug_name:
            continue
        drug_name_to_bmg[drug_name.upper()] = bmg_id

# Map each dose-response DRUG_NAME to the ID of its first candidate name found in the index.
# Despite its name, arxspan_to_bmg is keyed by DRUG_NAME, not by ARXSPAN_ID.
arxspan_to_bmg = {}
for source_name, candidates in zip(drug_name_df['DRUG_NAME'], drug_name_df['drug_name_list']):
    hits = [candidate.upper() for candidate in candidates if candidate.upper() in drug_name_to_bmg]
    if hits:
        arxspan_to_bmg[source_name] = drug_name_to_bmg[hits[0]]

# Two-column lookup table: DRUG_NAME -> BioMedGraphica_Conn_ID
mapping_df = pd.DataFrame(
    list(arxspan_to_bmg.items()),
    columns=['DRUG_NAME', 'BioMedGraphica_Conn_ID'],
)

# Report the match rate
print(f"Successfully matched {len(mapping_df)} out of {len(drug_name_df)} drugs")

# Show the resulting mapping
print("\nSample of drug name mappings:")
display(mapping_df)

# The mapping can then be joined onto a drug-response frame, e.g.:
# merged_df = pd.merge(drug_response_df, mapping_df, on='DRUG_NAME', how='left')


In [None]:
# Work on an independent copy of the DRUG_NAME -> BioMedGraphica_Conn_ID mapping.
final_merged_drug_df = mapping_df.copy()
# Order the rows by BioMedGraphica_Conn_ID and renumber the index.
final_merged_drug_df = final_merged_drug_df.sort_values(by='BioMedGraphica_Conn_ID').reset_index(drop=True)
display(final_merged_drug_df)

In [None]:
# Keep only the dose-response rows whose DRUG_NAME was matched to a BioMedGraphica entity.
final_drug_df = raw_drug_df[raw_drug_df['DRUG_NAME'].isin(final_merged_drug_df['DRUG_NAME'])].reset_index(drop=True)
display(final_drug_df)

## 6. CRISPR

In [None]:
# Load the combined BioMedGraphica-Conn entity table (all entity types).
bmgc_entity_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/BioMedGraphica_Conn_Entity.csv')
# Per-column NaN count of the full entity table.
print(bmgc_entity_df.isnull().sum())
# Keep only the rows whose Type is one of 'Promoter', 'Gene', 'Transcript', 'Protein'.
bmgc_omics_df = bmgc_entity_df[bmgc_entity_df['Type'].isin(['Promoter', 'Gene', 'Transcript', 'Protein'])].reset_index(drop=True)
# Per-column NaN count of the filtered table.
print(bmgc_omics_df.isnull().sum())
display(bmgc_omics_df)

In [None]:
# Load the BioMedGraphica-Conn relation (edge) table.
bmgc_relation_df = pd.read_csv('./data/BioMedGraphica-Conn/Relation/BioMedGraphica_Conn_Relation.csv')
# Per-column NaN count of the full relation table.
print(bmgc_relation_df.isnull().sum())

# Keep only the relations whose Type is one of
# 'Promoter-Gene', 'Gene-Transcript', 'Transcript-Protein', 'Protein-Protein'.
bmgc_omics_relation_df = bmgc_relation_df[bmgc_relation_df['Type'].isin(['Promoter-Gene', 'Gene-Transcript', 'Transcript-Protein', 'Protein-Protein'])].reset_index(drop=True)
# Per-column NaN count of the filtered relation table.
print(bmgc_omics_relation_df.isnull().sum())
display(bmgc_omics_relation_df)

In [None]:
# Split the omics entities and relations by Type so that the
# promoter -> gene -> transcript -> protein chain can be followed step by step.
# One independent entity frame per Type:
promoter_entity_df = bmgc_omics_df[bmgc_omics_df['Type'] == 'Promoter'].copy()
gene_entity_df = bmgc_omics_df[bmgc_omics_df['Type'] == 'Gene'].copy()
transcript_entity_df = bmgc_omics_df[bmgc_omics_df['Type'] == 'Transcript'].copy()
protein_entity_df = bmgc_omics_df[bmgc_omics_df['Type'] == 'Protein'].copy()

display(bmgc_omics_relation_df)
# Re-check the per-column NaN count of the omics relation table.
print("Null values in bmgc_omics_relation_df:")
print(bmgc_omics_relation_df.isnull().sum())

# One independent relation frame per chain step ('Protein-Protein' is not split out here).
promoter_gene_relation_df = bmgc_omics_relation_df[bmgc_omics_relation_df['Type'] == 'Promoter-Gene'].copy()
gene_transcript_relation_df = bmgc_omics_relation_df[bmgc_omics_relation_df['Type'] == 'Gene-Transcript'].copy()
transcript_protein_relation_df = bmgc_omics_relation_df[bmgc_omics_relation_df['Type'] == 'Transcript-Protein'].copy()

In [None]:
# Step 1: attach Gene-Transcript edges to the gene entities (gene ID == BMGC_From_ID).
gene_transcript_entity_df = pd.merge(gene_entity_df, gene_transcript_relation_df[['BMGC_From_ID', 'BMGC_To_ID']], left_on='BioMedGraphica_Conn_ID', right_on='BMGC_From_ID', how='outer')
# Step 2: attach Transcript-Protein edges (transcript reached in step 1 == BMGC_From_ID).
# Both inputs carry BMGC_From_ID / BMGC_To_ID, so pandas suffixes them: _x = Gene-Transcript
# edge, _y = Transcript-Protein edge.
gene_transcript_protein_entity_df = pd.merge(gene_transcript_entity_df, transcript_protein_relation_df[['BMGC_From_ID', 'BMGC_To_ID']], left_on='BMGC_To_ID', right_on='BMGC_From_ID', how='outer')
# The outer joins leave NaN wherever a step has no partner; dropping those rows keeps
# only complete gene -> transcript -> protein chains.
gene_transcript_protein_entity_df = gene_transcript_protein_entity_df.dropna(subset=['BMGC_From_ID_x', 'BMGC_To_ID_x', 'BMGC_From_ID_y', 'BMGC_To_ID_y']).reset_index(drop=True)
# Keep only the chain endpoints: gene ID -> 'BMGC_GN_ID', protein ID -> 'BMGC_PT_ID'
# (the intermediate transcript ID is not retained), sorted by gene ID.
gene_transcript_protein_entity_df = gene_transcript_protein_entity_df[['BioMedGraphica_Conn_ID', 'BMGC_To_ID_y']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_GN_ID', 'BMGC_To_ID_y': 'BMGC_PT_ID'}).sort_values(by='BMGC_GN_ID').reset_index(drop=True)
# Several transcripts can link the same gene to the same protein; keep each pair once.
gene_transcript_protein_entity_df = gene_transcript_protein_entity_df.drop_duplicates().reset_index(drop=True)
display(gene_transcript_protein_entity_df)

In [None]:
# Extend the gene -> protein pairs to promoter -> protein pairs.
bmgc_promoter_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Promoter/BioMedGraphica_Conn_Promoter.csv')
bmgc_gene_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene.csv')
# Start from an independent copy of the gene -> protein pairs.
promoter_gene_transcript_protein_entity_df = gene_transcript_protein_entity_df.copy()
# Pair promoters with genes through the explicit Promoter-Gene relation rows instead of
# placing the two entity files side by side: a positional concat silently mis-pairs IDs
# whenever the files differ in length or row order.
# NOTE(review): assumes BMGC_From_ID is the promoter and BMGC_To_ID the gene, mirroring
# how Gene-Transcript rows are consumed (gene == BMGC_From_ID) — confirm.
promoter_gene_df = (
    promoter_gene_relation_df[['BMGC_From_ID', 'BMGC_To_ID']]
    .rename(columns={'BMGC_From_ID': 'BMGC_PM_ID', 'BMGC_To_ID': 'BMGC_GN_ID'})
    .drop_duplicates()
    .reset_index(drop=True)
)
# Left join keeps every gene -> protein pair; genes without a promoter get NaN in BMGC_PM_ID.
promoter_protein_entity_df = pd.merge(
    promoter_gene_transcript_protein_entity_df,
    promoter_gene_df,
    on='BMGC_GN_ID',
    how='left',
)
# Keep only the chain endpoints (this also discards BMGC_GN_ID).
promoter_protein_entity_df = promoter_protein_entity_df[['BMGC_PM_ID', 'BMGC_PT_ID']]
display(promoter_protein_entity_df)

In [None]:
# Attach Transcript-Protein edges to the transcript entities (transcript ID == BMGC_From_ID).
transcript_protein_entity_df = pd.merge(transcript_entity_df, transcript_protein_relation_df[['BMGC_From_ID', 'BMGC_To_ID']], left_on='BioMedGraphica_Conn_ID', right_on='BMGC_From_ID', how='outer')
# The outer join leaves NaN for unmatched rows on either side; drop them so only
# transcripts with a protein partner remain.
transcript_protein_entity_df = transcript_protein_entity_df.dropna(subset=['BMGC_From_ID', 'BMGC_To_ID']).reset_index(drop=True)
# Keep transcript ID -> 'BMGC_TS_ID' and protein ID -> 'BMGC_PT_ID', sorted by transcript ID.
transcript_protein_entity_df = transcript_protein_entity_df[['BioMedGraphica_Conn_ID', 'BMGC_To_ID']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_TS_ID', 'BMGC_To_ID': 'BMGC_PT_ID'}).sort_values(by='BMGC_TS_ID').reset_index(drop=True)
# Keep each transcript -> protein pair once.
transcript_protein_entity_df = transcript_protein_entity_df.drop_duplicates().reset_index(drop=True)
display(transcript_protein_entity_df)

In [None]:
# Single-column table of protein entity IDs: BioMedGraphica_Conn_ID renamed to
# 'BMGC_PT_ID', sorted ascending with a fresh index.
only_protein_entity_df = protein_entity_df[['BioMedGraphica_Conn_ID']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_PT_ID'}).sort_values(by='BMGC_PT_ID').reset_index(drop=True)
display(only_protein_entity_df)

### 6.1 CRISPR top 100 gene entities

In [None]:
# pandas is already imported in the first cell; repeating the import is harmless.
import pandas as pd
# Load the raw CRISPR gene-effect table and show it.
raw_crispr_df = pd.read_csv('./data/raw_data/CRISPRGeneEffect.csv')
display(raw_crispr_df)

In [None]:
# Collect every column name except the first one as a plain Python list.
raw_crispr_df_columns = raw_crispr_df.columns[1:].tolist()
print(raw_crispr_df_columns)

## 6. Coalign the samples

In [None]:
# Sample identifiers per modality. For the omics frames these are the column labels
# after the first column (the first column is the BioMedGraphica_Conn_ID key).
methy_samples = final_merged_methy_df.columns[1:]
gene_samples = final_merged_gene_df.columns[1:]
transcript_samples = final_merged_transcript_df.columns[1:]
protein_samples = final_merged_protein_df.columns[1:]
# For drugs, samples are the distinct ARXSPAN_ID values (order is not preserved by set()).
drug_samples = list(set(list(final_drug_df['ARXSPAN_ID'])))

# Print every sample collection for inspection.
print(methy_samples)
print(gene_samples)
print(transcript_samples)
print(protein_samples)
print(drug_samples)

In [None]:
# Reduce each protein sample label to its first two '_'-separated fields,
# discarding everything after the second '_'.
# NOTE(review): a label whose cell-line/tissue part itself contains more than one '_'
# would be truncated too early by this rule — confirm against the actual column names.
protein_split_samples = ['_'.join(s.split('_')[:2]) for s in protein_samples]
print(protein_split_samples)

In [None]:
# Load the cell-line annotation table and keep only rows that have both a depMapID
# and a Pathology value.
cell_line_anno_df = (
    pd.read_csv('./data/raw_data/Cell_lines_annotations_20181226.txt', sep='\t')
    .dropna(subset=['depMapID', 'Pathology'])
    .reset_index(drop=True)
)
# Exclude rows whose PATHOLOGIST_ANNOTATION mentions 'benign' (case-insensitive);
# a missing annotation is treated as an empty string and therefore kept.
is_benign = cell_line_anno_df['PATHOLOGIST_ANNOTATION'].fillna('').str.contains('benign', case=False)
cell_line_anno_df = cell_line_anno_df[~is_benign]
display(cell_line_anno_df)

In [None]:
# Number of distinct shortened protein sample names.
print(len(set(protein_split_samples)))
# Shortened names that occur more than once, i.e. cell lines measured in several columns.
import collections
rep_list = [item for item, count in collections.Counter(protein_split_samples).items() if count > 1]
# Collect the columns belonging to those replicated cell lines. The column label is
# shortened with the same rule used to build protein_split_samples and compared for
# equality; a plain substring test (`rep in col`) would also catch unrelated columns
# whose label merely contains a replicated name.
rep_set = set(rep_list)
rep_col_list = [col for col in final_merged_protein_df.columns if '_'.join(col.split('_')[:2]) in rep_set]
print(rep_col_list)
# Show the replicated columns before removing them.
display(final_merged_protein_df[rep_col_list])

# Remove every replicated column (all copies are dropped, none is kept).
final_merged_protein_df = final_merged_protein_df.drop(columns=rep_col_list)
display(final_merged_protein_df)

In [None]:
# Match methylation sample labels to the annotation table by CCLE identifier.
# First wrap the labels in a one-column dataframe.
methy_samples_df = pd.DataFrame(methy_samples, columns=['CCLE_Name'])
# Inner join on CCLE_Name == CCLE_ID keeps only the samples that are annotated.
merged_methy_samples_df = pd.merge(methy_samples_df, cell_line_anno_df, left_on='CCLE_Name', right_on='CCLE_ID', how='inner')
display(merged_methy_samples_df)

# Lookup from methylation sample label (CCLE_Name) to depMapID, built from matched rows only.
methy_map_dict = dict(zip(merged_methy_samples_df['CCLE_Name'], merged_methy_samples_df['depMapID']))
print(methy_map_dict)

In [None]:
# Protein sample labels still present after the replicated columns were removed.
original_protein_samples = final_merged_protein_df.columns[1:]
# Shorten each label to its first two '_'-separated fields.
# NOTE(review): labels whose cell-line/tissue part contains more than one '_' are cut
# too early by this rule and will not match a CCLE_ID — confirm against the data.
cleaned_protein_samples = ['_'.join(s.split('_')[:2]) for s in original_protein_samples]
# Pair every original label with its shortened form.
cleaned_protein_samples_dict_df = pd.DataFrame({'CCLE_Name': original_protein_samples, 'cleaned_CCLE_Name': cleaned_protein_samples})
# Inner join on shortened label == CCLE_ID keeps only the samples that are annotated.
merged_cleaned_protein_samples_df = pd.merge(cleaned_protein_samples_dict_df, cell_line_anno_df, left_on='cleaned_CCLE_Name', right_on='CCLE_ID', how='inner')
display(merged_cleaned_protein_samples_df)
# depMapIDs of the matched protein samples.
depmap_protein_samples = merged_cleaned_protein_samples_df['depMapID'].to_list()

# Lookup from original protein column label to depMapID. Both sides are taken from the
# merged rows: the inner join drops unmatched samples, so zipping the full label list
# against depmap_protein_samples would pair labels with the wrong IDs.
protein_map_dict = dict(zip(merged_cleaned_protein_samples_df['CCLE_Name'], merged_cleaned_protein_samples_df['depMapID']))
print(protein_map_dict)

In [None]:
# depMapIDs of the annotated methylation and protein samples, taken from the merged frames.
methy_mapped_samples = list(merged_methy_samples_df['depMapID'])
protein_mapped_samples = list(merged_cleaned_protein_samples_df['depMapID'])

### 6.1 Aligning over omics

#### 6.1.1 Aligning with intersection

In [None]:
# Samples present in all four omics modalities (set intersection), sorted.
omics_overlapped_samples = sorted(list(set(methy_mapped_samples) & set(gene_samples) & set(transcript_samples) & set(protein_mapped_samples)))
print("length of omics overlapped samples: ", len(omics_overlapped_samples))
# Of those, the ones that also carry a depMapID in the annotation table.
annotation_samples = list(cell_line_anno_df['depMapID'])
overlapped_omics_annotation_samples = sorted(list(set(omics_overlapped_samples) & set(annotation_samples)))
print("length of overlapped omics and annotation samples: ", len(overlapped_omics_annotation_samples))

#### 6.1.2 Aligning with union on omics and intersection with annotation

In [None]:
# Samples present in at least one of the four omics modalities (set union), sorted.
omics_union_samples = sorted(list(set(methy_mapped_samples) | set(gene_samples) | set(transcript_samples) | set(protein_mapped_samples)))
print("length of omics union samples: ", len(omics_union_samples))
# Restrict the union to samples that also appear in the annotation table.
# Note: the printed label below is the same text used for the intersection-based count.
overlapped_omics_union_annotation_samples = sorted(list(set(omics_union_samples) & set(annotation_samples)))
print("length of overlapped omics and annotation samples: ", len(overlapped_omics_union_annotation_samples))
print(overlapped_omics_union_annotation_samples)

#### 6.1.3 overlapped_omics_union_annotation_samples (cancerous / non-cancerous) Ratio

In [None]:
# Load the list of cell lines flagged as non-cancerous.
non_cancerous_samples_df = pd.read_csv('./data/raw_data/cell-lines-in-Non-Cancerous.csv')
display(non_cancerous_samples_df)
# Their identifiers come from the 'Depmap Id' column.
non_cancerous_samples = list(non_cancerous_samples_df['Depmap Id'])
# Non-cancerous samples that are also in the annotated omics union.
overlapped_non_cancerous_samples = sorted(list(set(overlapped_omics_union_annotation_samples) & set(non_cancerous_samples)))
# Print both counts so the cancerous / non-cancerous ratio can be read off.
print("length of overlapped omics and annotation samples: ", len(overlapped_omics_union_annotation_samples))
print("length of overlapped non cancerous samples: ", len(overlapped_non_cancerous_samples))


### 6.2 Overlapping over DTI

In [None]:
# Samples in the annotated omics union that also have drug-response data.
dti_overlapped_samples = sorted(set(overlapped_omics_union_annotation_samples) & set(drug_samples))
print("len(overlapped_samples):", len(dti_overlapped_samples))
# Any of those that are flagged as non-cancerous (expected to be none).
dti_overlapped_non_cancerous_samples = sorted(set(dti_overlapped_samples) & set(non_cancerous_samples))
print("len(overlapped_non_cancerous_samples):", len(dti_overlapped_non_cancerous_samples))
# Remove the non-cancerous ones; subtracting an empty set leaves the list unchanged.
dti_overlapped_samples = sorted(set(dti_overlapped_samples) - set(dti_overlapped_non_cancerous_samples))
# Attach the annotation columns to the surviving depMapIDs.
dti_overlapped_samples_df = pd.DataFrame(dti_overlapped_samples, columns=['depMapID']).merge(
    cell_line_anno_df, on='depMapID', how='inner'
)
display(dti_overlapped_samples_df)

In [None]:
# Ensure the output folder for processed data exists. `exist_ok=True` makes the call a
# no-op when the folder is already there, replacing the separate exists-then-create check.
import os
os.makedirs('./data/process_data', exist_ok=True)

In [None]:
# Persist the annotated DTI sample table (the dataframe index is not written).
dti_overlapped_samples_df.to_csv('./data/process_data/dti_overlapped_samples.csv', index=False)

In [None]:
# List the columns available on the annotated DTI sample table.
print(dti_overlapped_samples_df.columns)

## 7. Annotate the samples with disease

### 7.1 Add annotation from Cellosaurus

In [None]:
# import pandas as pd
# import requests

# url = 'https://ftp.expasy.org/databases/cellosaurus/cellosaurus.obo'
# response = requests.get(url)

# with open('./data/raw_data/cellosaurus.obo', 'wb') as f:
#     f.write(response.content)

# # Path to the downloaded OBO file
# obo_path = './data/raw_data/cellosaurus.obo'

# entries = []
# current_entry = {}

# # Read and parse
# with open(obo_path, 'r', encoding='utf-8') as file:
#     for line in file:
#         line = line.strip()

#         if line == "[Term]":
#             if current_entry.get("xref") and any("NCBI_TaxID:9606" in x for x in current_entry["xref"]):
#                 # Separate NCIt and ORDO xrefs
#                 ncit_refs = [x for x in current_entry["xref"] if x.startswith("NCIt:")]
#                 ordo_refs = [x for x in current_entry["xref"] if x.startswith("ORDO:")]
#                 current_entry["xref_NCIt"] = "; ".join(ncit_refs)
#                 current_entry["xref_ORDO"] = "; ".join(ordo_refs)
#                 entries.append({
#                     "id": current_entry.get("id", ""),
#                     "name": current_entry.get("name", ""),
#                     "synonym": "; ".join(current_entry.get("synonym", [])),
#                     "xref_NCIt": current_entry.get("xref_NCIt", ""),
#                     "xref_ORDO": current_entry.get("xref_ORDO", "")
#                 })
#             current_entry = {"synonym": [], "xref": []}

#         elif line.startswith("id:"):
#             current_entry["id"] = line.split("id:")[1].strip()

#         elif line.startswith("name:"):
#             current_entry["name"] = line.split("name:")[1].strip()

#         elif line.startswith("synonym:"):
#             synonym = line.split("synonym:")[1].split("RELATED")[0].strip().strip('"')
#             current_entry["synonym"].append(synonym)

#         elif line.startswith("xref:"):
#             current_entry["xref"].append(line.split("xref:")[1].strip())

# # Handle last entry
# if current_entry.get("xref") and any("NCBI_TaxID:9606" in x for x in current_entry["xref"]):
#     ncit_refs = [x for x in current_entry["xref"] if x.startswith("NCIt:")]
#     ordo_refs = [x for x in current_entry["xref"] if x.startswith("ORDO:")]
#     entries.append({
#         "id": current_entry.get("id", ""),
#         "name": current_entry.get("name", ""),
#         "synonym": "; ".join(current_entry.get("synonym", [])),
#         "xref_NCIt": "; ".join(ncit_refs),
#         "xref_ORDO": "; ".join(ordo_refs)
#     })

# # Create DataFrame
# cellosaurus_parsed_df = pd.DataFrame(entries)
# cellosaurus_parsed_df = cellosaurus_parsed_df[["id", "name", "synonym", "xref_NCIt", "xref_ORDO"]]
# cellosaurus_parsed_df.to_csv('./data/raw_data/cellosaurus_parsed.csv', index=False)
# # Show preview
# display(cellosaurus_parsed_df)

In [None]:
# # Step 1: Keep the desired columns
# dti_overlapped_samples_desc_df = dti_overlapped_samples_df[['depMapID', 'Name', 'Pathology', 'Histology', 'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code']]
# display(dti_overlapped_samples_desc_df)

# # Step 2: Try direct merge on Name == name using left join
# matched_df = pd.merge(
#     dti_overlapped_samples_desc_df,
#     cellosaurus_parsed_df,
#     left_on='Name',
#     right_on='name',
#     how='left'
# )

# # Step 3: Find rows where no match occurred (name is NaN)
# unmatched_df = matched_df[matched_df['name'].isna()].copy()

# # Step 4: Drop the Cellosaurus columns to prep for synonym match
# cols_to_drop = [col for col in matched_df.columns if col in cellosaurus_parsed_df.columns]
# unmatched_df = unmatched_df.drop(columns=cols_to_drop)

# # Step 5: Expand cellosaurus synonyms
# synonym_expanded_df = cellosaurus_parsed_df.copy()
# synonym_expanded_df = synonym_expanded_df.dropna(subset=['synonym'])
# synonym_expanded_df = synonym_expanded_df.assign(
#     synonym=synonym_expanded_df['synonym'].str.split(';')
# ).explode('synonym')
# synonym_expanded_df['synonym'] = synonym_expanded_df['synonym'].str.strip()

# # Step 6: Left join unmatched rows with synonym-expanded Cellosaurus
# synonym_matched_df = pd.merge(
#     unmatched_df,
#     synonym_expanded_df,
#     left_on='Name',
#     right_on='synonym',
#     how='left'
# )

# # Step 7: Combine direct-matched (non-NaN name) and synonym-matched rows
# all_match_df = pd.concat(
#     [matched_df[matched_df['name'].notna()], synonym_matched_df],
#     ignore_index=True
# )

# # Optional: Drop duplicates based on depMapID if needed
# all_match_df = all_match_df.drop_duplicates(subset=['depMapID'])
# display(all_match_df)

In [None]:
# # Function to store primary and extra xref matches with better column handling
# def process_xref_columns_with_extras(df):
#     # Create a copy to avoid modifying the original
#     result_df = df.copy()
    
#     # Always create all the columns, even if they might end up empty
#     result_df['xref_NCIt_ID'] = ''
#     result_df['xref_NCIt_name'] = ''
#     result_df['xref_NCIt_extra_ID'] = ''
#     result_df['xref_NCIt_extra_name'] = ''
#     result_df['xref_ORDO_ID'] = ''
#     result_df['xref_ORDO_name'] = ''
#     result_df['xref_ORDO_extra_ID'] = ''
#     result_df['xref_ORDO_extra_name'] = ''
    
#     # Process xref_NCIt if it exists
#     if 'xref_NCIt' in result_df.columns:
#         # Only process rows that have xref_NCIt values
#         mask = result_df['xref_NCIt'].notna() & (result_df['xref_NCIt'] != '')
        
#         for idx in result_df[mask].index:
#             xref_value = result_df.loc[idx, 'xref_NCIt']
            
#             if '; ' in xref_value:  # Multiple entries
#                 entries = xref_value.split('; ')
                
#                 # Process first entry for primary columns
#                 if ' ! ' in entries[0]:
#                     id_val, name_val = entries[0].split(' ! ', 1)
#                     result_df.loc[idx, 'xref_NCIt_ID'] = id_val
#                     result_df.loc[idx, 'xref_NCIt_name'] = name_val
                
#                 # Process additional entries for extra columns
#                 extra_ids = []
#                 extra_names = []
#                 for entry in entries[1:]:
#                     if ' ! ' in entry:
#                         id_val, name_val = entry.split(' ! ', 1)
#                         extra_ids.append(id_val)
#                         extra_names.append(name_val)
                
#                 if extra_ids:
#                     result_df.loc[idx, 'xref_NCIt_extra_ID'] = '; '.join(extra_ids)
#                     result_df.loc[idx, 'xref_NCIt_extra_name'] = '; '.join(extra_names)
                    
#             elif ' ! ' in xref_value:  # Single entry
#                 id_val, name_val = xref_value.split(' ! ', 1)
#                 result_df.loc[idx, 'xref_NCIt_ID'] = id_val
#                 result_df.loc[idx, 'xref_NCIt_name'] = name_val
        
#         # Drop the original column
#         result_df = result_df.drop(columns=['xref_NCIt'])
    
#     # Process xref_ORDO if it exists
#     if 'xref_ORDO' in result_df.columns:
#         # Only process rows that have xref_ORDO values
#         mask = result_df['xref_ORDO'].notna() & (result_df['xref_ORDO'] != '')
        
#         for idx in result_df[mask].index:
#             xref_value = result_df.loc[idx, 'xref_ORDO']
            
#             if '; ' in xref_value:  # Multiple entries
#                 entries = xref_value.split('; ')
                
#                 # Process first entry for primary columns
#                 if ' ! ' in entries[0]:
#                     id_val, name_val = entries[0].split(' ! ', 1)
#                     result_df.loc[idx, 'xref_ORDO_ID'] = id_val
#                     result_df.loc[idx, 'xref_ORDO_name'] = name_val
                
#                 # Process additional entries for extra columns
#                 extra_ids = []
#                 extra_names = []
#                 for entry in entries[1:]:
#                     if ' ! ' in entry:
#                         id_val, name_val = entry.split(' ! ', 1)
#                         extra_ids.append(id_val)
#                         extra_names.append(name_val)
                
#                 if extra_ids:
#                     result_df.loc[idx, 'xref_ORDO_extra_ID'] = '; '.join(extra_ids)
#                     result_df.loc[idx, 'xref_ORDO_extra_name'] = '; '.join(extra_names)
                    
#             elif ' ! ' in xref_value:  # Single entry
#                 id_val, name_val = xref_value.split(' ! ', 1)
#                 result_df.loc[idx, 'xref_ORDO_ID'] = id_val
#                 result_df.loc[idx, 'xref_ORDO_name'] = name_val
        
#         # Drop the original column
#         result_df = result_df.drop(columns=['xref_ORDO'])
    
#     return result_df

# # Apply the function
# all_match_df = process_xref_columns_with_extras(all_match_df)

# # Update the column list to include the extra columns
# dti_overlapped_samples_desc_co_df = all_match_df[['depMapID', 'Name', 'Pathology', 'Histology', 
#                                              'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code', 'id', 
#                                              'xref_NCIt_ID', 'xref_NCIt_name', 'xref_NCIt_extra_ID', 'xref_NCIt_extra_name',
#                                              'xref_ORDO_ID', 'xref_ORDO_name', 'xref_ORDO_extra_ID', 'xref_ORDO_extra_name']]
# display(dti_overlapped_samples_desc_co_df)

### 7.2 Cell line disease match

#### 7.2.1 Cell line disease hard match

In [None]:
# bmg_disease_name_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Disease/BioMedGraphica_Conn_Disease_GUI_Name.csv')
# display(bmg_disease_name_df)

In [None]:
# # Step 1: Expand BMG_Disease_Name (still lowercased and exploded)
# bmg_expanded = bmg_disease_name_df.copy()

# # First, strip any surrounding quotes and fill NaN values with empty strings
# bmg_expanded['Disease_Name_List'] = bmg_expanded['Disease_Name_List'].fillna('')
# bmg_expanded['Disease_Name_List'] = bmg_expanded['Disease_Name_List'].str.replace('^"|"$', '', regex=True)

# # Split on pipe with spaces
# bmg_expanded['Disease_Name_List'] = bmg_expanded['Disease_Name_List'].str.split(' \| ')

# # Apply strip only to non-empty lists
# bmg_expanded['Disease_Name_List'] = bmg_expanded['Disease_Name_List'].apply(
#     lambda x: [name.strip() for name in x] if isinstance(x, list) else []
# )

# # Explode the list into separate rows
# bmg_expanded = bmg_expanded.explode('Disease_Name_List')

# # Remove rows with empty disease names after exploding
# bmg_expanded = bmg_expanded[bmg_expanded['Disease_Name_List'].str.len() > 0]

# # Create lowercase version for easier matching
# bmg_expanded['Disease_Name_List_lower'] = bmg_expanded['Disease_Name_List'].str.lower()

# display(bmg_expanded)

In [None]:
# # Create a lookup list of tuples: (Conn_ID, Original_Name, Lower_Name)
# bmg_records = bmg_expanded[['BioMedGraphica_Conn_ID', 'Disease_Name_List', 'Disease_Name_List_lower']].to_records(index=False)

# # Updated matching function: look for exact full-name match
# def match_bmg_disease_exact(xref_name, bmg_records):
#     if pd.isna(xref_name):
#         return (None, None)
#     xref_name_lower = str(xref_name).strip().lower()
    
#     for conn_id, original_name, bmg_name_lower in bmg_records:
#         if xref_name_lower == bmg_name_lower:
#             return (conn_id, original_name)
    
#     return (None, None)

# # NCIt
# ncit_matches = dti_overlapped_samples_desc_co_df['xref_NCIt_name'].apply(lambda x: match_bmg_disease_exact(x, bmg_records))
# dti_overlapped_samples_desc_co_df['NCIt_BMGC_ID'] = ncit_matches.apply(lambda x: x[0])
# dti_overlapped_samples_desc_co_df['NCIt_BMGC_name'] = ncit_matches.apply(lambda x: x[1])

# # ORDO
# ordo_matches = dti_overlapped_samples_desc_co_df['xref_ORDO_name'].apply(lambda x: match_bmg_disease_exact(x, bmg_records))
# dti_overlapped_samples_desc_co_df['ORDO_BMGC_ID'] = ordo_matches.apply(lambda x: x[0])
# dti_overlapped_samples_desc_co_df['ORDO_BMGC_name'] = ordo_matches.apply(lambda x: x[1])

# display(dti_overlapped_samples_desc_co_df)


In [None]:
# # keep the columns ['depMapID', 'Name', 'Pathology', 'Histology', 'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code', 'id', 'xref_NCIt_ID', 'xref_NCIt_name', 'NCIt_BMGC_ID', 'NCIt_BMGC_name', 'xref_ORDO_ID', 'xref_ORDO_name', 'ORDO_BMGC_ID', 'ORDO_BMGC_name']
# dti_overlapped_samples_desc_co_df = dti_overlapped_samples_desc_co_df[['depMapID', 'Name', 'Pathology', 'Histology', 'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code', 'id', 'xref_NCIt_ID', 'xref_NCIt_name', 'NCIt_BMGC_ID', 'NCIt_BMGC_name', 'xref_ORDO_ID', 'xref_ORDO_name', 'ORDO_BMGC_ID', 'ORDO_BMGC_name']]
# display(dti_overlapped_samples_desc_co_df)
# dti_overlapped_samples_desc_co_df.to_csv('./data/process_data/dti_overlapped_samples_desc_co.csv', index=False)

In [None]:
# Merge NCIt_BMGC_ID/ORDO_BMGC_ID and NCIt_BMGC_name/ORDO_BMGC_name into BMGC_Matched columns
# Strategy:
# 1. Prioritize NCIt data when both are available
# 2. Use NCIt data if only NCIt is available
# 3. Use ORDO data if only ORDO is available
# 4. Leave as null if both are null

# dti_matched_df = dti_overlapped_samples_desc_co_df.copy()

# # Create new merged ID column
# dti_matched_df['BMGC_Matched_ID'] = (
#     dti_matched_df['NCIt_BMGC_ID'].fillna(
#         dti_matched_df['ORDO_BMGC_ID']
#     )
# )

# # Create new merged name column
# dti_matched_df['BMGC_Matched_name'] = (
#     dti_matched_df['NCIt_BMGC_name'].fillna(
#         dti_matched_df['ORDO_BMGC_name']
#     )
# )

# # Display the dataframe with new columns
# display(dti_matched_df)

# # Check number of null values in new columns
# print(f"Null values in BMGC_Matched_ID: {dti_matched_df['BMGC_Matched_ID'].isna().sum()}")
# print(f"Null values in BMGC_Matched_name: {dti_matched_df['BMGC_Matched_name'].isna().sum()}")

# # only keep the columns where the BMGC_Matched_ID is not null
# dti_matched_df = dti_matched_df[dti_matched_df['BMGC_Matched_ID'].notna()].reset_index(drop=True)
# display(dti_matched_df)

# # Save to CSV
# dti_matched_df.to_csv('./data/process_data/dti_matched.csv', index=False)

#### 7.2.2 Cell line disease soft match

In [None]:
# # check the unmatched samples in the dti_overlapped_samples_desc_co_df by checking the NaN in the NCIt_BMGC_ID and ORDO_BMGC_ID
# unmatched_samples_df = dti_overlapped_samples_desc_co_df[dti_overlapped_samples_desc_co_df['NCIt_BMGC_ID'].isna() & dti_overlapped_samples_desc_co_df['ORDO_BMGC_ID'].isna()]
# # keep the columns ['depMapID', 'Name', 'Pathology', 'Histology', 'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code', 'id', 'xref_NCIt_name', 'xref_ORDO_name']
# unmatched_samples_df = unmatched_samples_df[['depMapID', 'Name', 'Pathology', 'Histology', 'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code', 'id', 'xref_NCIt_name', 'xref_ORDO_name']]
# display(unmatched_samples_df)

In [None]:
# embed the disease name for bmg_disease_df by bert-based model
import torch
import numpy as np
from tqdm import tqdm
from typing import List, Tuple, Dict
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader, Dataset

class SentenceDataset(Dataset):
    """Minimal map-style dataset that serves raw sentences by position."""

    def __init__(self, sentences: List[str]):
        self.sentences = sentences

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        return self.sentences[idx]


class TextEncoder():
    """Turn sentences into fixed-size vectors with a Hugging Face encoder model.

    Each sentence is tokenized, passed through the model, mean-pooled over its
    real (non-padding) tokens and finally resized along the hidden dimension
    with adaptive average pooling.
    """

    def __init__(self, model_path: str = "dmis-lab/biobert-v1.1", device: str = "cuda"):
        """
        Args:
            model_path (str, optional): Model identifier or local path handed to
                ``AutoTokenizer`` / ``AutoModel``. Defaults to 'dmis-lab/biobert-v1.1'.
            device (str, optional): Device to run the model on ('cpu' or 'cuda').
                Defaults to 'cuda'.
        """
        self.model_path = model_path
        self.device = device
        # Both are populated by load_model().
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """
        Load the tokenizer and model from ``self.model_path`` and move the model
        to ``self.device``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModel.from_pretrained(self.model_path).to(self.device)
        # Inference only: make sure dropout and similar layers are disabled.
        self.model.eval()

    def generate_embeddings(self, sentences: List[str], batch_size: int = 32, seq_emb_dim: int = 64) -> torch.Tensor:
        """
        Generate one embedding vector per sentence.

        Args:
            sentences (List[str]): List of sentences to embed.
            batch_size (int, optional): Batch size for the DataLoader. Defaults to 32.
            seq_emb_dim (int, optional): Length of each returned vector. The pooled
                hidden vector is resized to this length with adaptive average
                pooling. Defaults to 64.

        Returns:
            torch.Tensor: Tensor of shape [len(sentences), seq_emb_dim] on ``self.device``.
        """
        # Nothing to embed: return an empty result instead of failing in torch.cat.
        if len(sentences) == 0:
            return torch.empty((0, seq_emb_dim), device=self.device)

        # Load lazily so that a forgotten load_model() call does not surface as
        # an opaque "'NoneType' object is not callable" error.
        if self.model is None or self.tokenizer is None:
            self.load_model()

        dataset = SentenceDataset(sentences)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        embedding_batches = []
        for batch in tqdm(dataloader, desc="Embedding sentences", unit="batch"):
            inputs = self.tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
                hidden = outputs.last_hidden_state  # [n_in_batch, seq_len, hidden_dim]

                # Average over real tokens only. A plain mean over dim=1 also
                # averages the padding positions, which makes a sentence's vector
                # depend on the longest sentence that shares its batch.
                # NOTE: vectors saved with the earlier pad-inclusive pooling differ
                # for padded batches and should be regenerated.
                if "attention_mask" in inputs:
                    token_mask = inputs["attention_mask"]
                else:
                    token_mask = torch.ones(hidden.shape[:2], device=hidden.device)
                token_mask = token_mask.unsqueeze(-1).to(hidden.dtype)
                token_counts = token_mask.sum(dim=1).clamp(min=1)
                mean_embeddings = (hidden * token_mask).sum(dim=1) / token_counts  # [n_in_batch, hidden_dim]

                # adaptive_avg_pool1d pools the last axis of a 3-D input:
                # [n_in_batch, 1, hidden_dim] -> [n_in_batch, 1, seq_emb_dim] -> [n_in_batch, seq_emb_dim]
                projected = torch.nn.functional.adaptive_avg_pool1d(
                    mean_embeddings.unsqueeze(1), output_size=seq_emb_dim
                ).squeeze(1)  # squeeze only dim 1 so a batch of one keeps its batch axis
            embedding_batches.append(projected)
        return torch.cat(embedding_batches, dim=0)

    def save_embeddings(self, embeddings, output_npy_path):
        """
        Save embeddings to a .npy file.

        Args:
            embeddings (torch.Tensor): The embeddings to save.
            output_npy_path (str): Path to save the embeddings file.
        """
        # Detach and move to CPU before converting to numpy.
        embeddings_cpu = embeddings.detach().cpu().numpy()
        np.save(output_npy_path, embeddings_cpu)
        print(f"Embeddings saved at {output_npy_path} with shape {embeddings_cpu.shape}")

In [None]:
# # convert the bmg_disease_name_df to the list of disease names of each rows
# disease_names = list(bmg_disease_name_df['Disease_Name_List'])
# # to make sure each disease name is a string
# disease_names = [str(name) for name in disease_names]
# print(len(disease_names))
# print(disease_names[:5])
# print(disease_names[-1])

In [None]:
#  # Use language model to embed the name and description
# name_sentence_list = disease_names
# text_encoder = TextEncoder()
# text_encoder.load_model()
# name_embeddings = text_encoder.generate_embeddings(name_sentence_list, batch_size=32, seq_emb_dim=768)
# print(name_embeddings.shape)
# # mkdir folder BMG_emb
# if not os.path.exists('./data/BMG_emb'):
#     os.makedirs('./data/BMG_emb')
# text_encoder.save_embeddings(name_embeddings, './data/BMG_emb/disease_name_embeddings.npy')

In [None]:
# # convert each row in the unmatched_samples_df to a sentence like "Name: NIH:OVCAR-3, Pathology: OVARY, Histology: CARCINOMA, type: CANCER, PATHOLOGIST_ANNOTATION: OVARY, tcga_code: OV"
# depmap_desc_sentence_list = []
# for idx, row in unmatched_samples_df.iterrows():
#     depmap_desc_sentence_list.append(f" {row['type']}, {row['PATHOLOGIST_ANNOTATION']}")

# # and create the ncit_desc_sentence_list and ordo_desc_sentence_list
# ncit_desc_sentence_list = list(unmatched_samples_df['xref_NCIt_name'])
# ordo_desc_sentence_list = list(unmatched_samples_df['xref_ORDO_name'])

# # make sure each element in the depmap_desc_sentence_list, ncit_desc_sentence_list and ordo_desc_sentence_list is a string
# depmap_desc_sentence_list = [str(name) for name in depmap_desc_sentence_list]
# ncit_desc_sentence_list = [str(name) for name in ncit_desc_sentence_list]
# ordo_desc_sentence_list = [str(name) for name in ordo_desc_sentence_list]

# print(depmap_desc_sentence_list)
# print(ncit_desc_sentence_list)
# print(ordo_desc_sentence_list)

In [None]:
# # Define the sentence lists to process and their corresponding prefixes
# sentence_list_mapping = {
#     'depmap': depmap_desc_sentence_list,
#     'ncit': ncit_desc_sentence_list,
#     'ordo': ordo_desc_sentence_list
# }

# # Define top_k for this matching task
# top_k = 2

# # Process each sentence list and add the matches to the DataFrame
# for prefix, sentence_list in sentence_list_mapping.items():
#     # Lists to store the matches for each description
#     all_matched_lists = [[] for _ in range(top_k)]
#     all_matched_name_lists = [[] for _ in range(top_k)]
#     all_similarity_lists = [[] for _ in range(top_k)]
    
#     for desc in sentence_list:
#         desc_embeddings = text_encoder.generate_embeddings([desc], batch_size=1, seq_emb_dim=768)
#         # Calculate cosine similarity
#         cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
#         similarity = cos(name_embeddings, desc_embeddings)
#         # Find the index of top k most similar diseases
#         top_k_idx = torch.argsort(similarity, descending=True)[:top_k]
        
#         # Process each of the top k matches
#         for rank, idx in enumerate(top_k_idx):
#             idx_int = idx.item()
#             disease_id = bmg_disease_name_df.iloc[idx_int]['BioMedGraphica_Conn_ID']
#             disease_name = disease_names[idx_int]
#             sim_score = similarity[idx_int].item()
            
#             # Store in the corresponding lists
#             all_matched_lists[rank].append(disease_id)
#             all_matched_name_lists[rank].append(disease_name)
#             all_similarity_lists[rank].append(sim_score)
    
#     # Add the matched data to the DataFrame with the appropriate prefix
#     for i in range(top_k):
#         rank = i + 1
#         unmatched_samples_df[f'{prefix}_match_{rank}_disease'] = all_matched_lists[i]
#         unmatched_samples_df[f'{prefix}_match_{rank}_disease_name'] = all_matched_name_lists[i]
#         unmatched_samples_df[f'{prefix}_match_{rank}_similarity'] = all_similarity_lists[i]
    
#     print(f"Finished processing {prefix} descriptions")

# # Display the updated DataFrame
# display(unmatched_samples_df)

# # Save the updated DataFrame to CSV
# unmatched_samples_df.to_csv('./data/process_data/unmatched_samples_softmatch.csv', index=False)

### 7.3 Manual filter and match

In [None]:
# unmatched_samples_df = pd.read_csv('./data/process_data/unmatched_samples_softmatch.csv')
# # add the 2 empty columns 'BMGC_manual_ID' and 'BMGC_manual_name' to the unmatched_samples_manual_df
# unmatched_samples_manual_df = unmatched_samples_df.copy()
# unmatched_samples_manual_df['BMGC_manual_ID'] = None
# unmatched_samples_manual_df['BMGC_manual_name'] = None
# # change the order of the columns to make the 'BMGC_manual_ID' and 'BMGC_manual_name' after id
# unmatched_samples_manual_df = unmatched_samples_manual_df[['depMapID', 'Name', 'Pathology', 'Histology', 'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code', 'id', 'BMGC_manual_ID', 'BMGC_manual_name', 'xref_NCIt_name', 'xref_ORDO_name', 'depmap_match_1_disease', 'depmap_match_1_disease_name', 'depmap_match_2_disease', 'depmap_match_2_disease_name', 'ncit_match_1_disease', 'ncit_match_1_disease_name', 'ncit_match_2_disease', 'ncit_match_2_disease_name', 'ordo_match_1_disease', 'ordo_match_1_disease_name', 'ordo_match_2_disease', 'ordo_match_2_disease_name']]
# display(unmatched_samples_manual_df)
# unmatched_samples_manual_df.to_csv('./data/process_data/unmatched_samples_manual.csv', index=False)

### 7.4 Combine matched and manually curated unmatched samples

In [None]:
# Hard-matched cell lines: expose the matched disease columns under the shared
# BMGC_Disease_ID / BMGC_Disease_name names, then keep only the identifier,
# disease and DepMap annotation columns (the xref_* columns are not carried forward).
dti_matched_df = (
    pd.read_csv('./data/process_data/dti_matched.csv')
    .rename(columns={'BMGC_Matched_ID': 'BMGC_Disease_ID', 'BMGC_Matched_name': 'BMGC_Disease_name'})
    [['depMapID', 'Name', 'BMGC_Disease_ID', 'BMGC_Disease_name', 'Pathology', 'Histology', 'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code']]
)
display(dti_matched_df)

# Manual-match template for the unmatched cell lines, with its manual columns
# renamed to the same BMGC_Disease_ID / BMGC_Disease_name names.
dti_unmatched_manual_df = (
    pd.read_csv('./data/process_data/unmatched_samples_manual.csv')
    .rename(columns={'BMGC_manual_ID': 'BMGC_Disease_ID', 'BMGC_manual_name': 'BMGC_Disease_name'})
)
display(dti_unmatched_manual_df)

In [None]:
# Validated manual matches, using the same BMGC_Disease_* column names as dti_matched_df.
dti_unmatched_manual_valid_df = (
    pd.read_csv('./data/manual_match_data/unmatched_samples_manual_valid.csv')
    .rename(columns={'BMGC_manual_ID': 'BMGC_Disease_ID', 'BMGC_manual_name': 'BMGC_Disease_name'})
)
display(dti_unmatched_manual_valid_df)

# Stack the hard matches and the validated manual matches, ordered by depMapID.
dti_combined_df = (
    pd.concat([dti_matched_df, dti_unmatched_manual_valid_df], ignore_index=True)
    .sort_values(by='depMapID')
    .reset_index(drop=True)
)

# Report how many rows still have no disease ID / disease name.
print(dti_combined_df[['BMGC_Disease_ID', 'BMGC_Disease_name']].isnull().sum())

# Lower-case every disease name before saving.
dti_combined_df['BMGC_Disease_name'] = dti_combined_df['BMGC_Disease_name'].str.lower()
display(dti_combined_df)
dti_combined_df.to_csv('./data/process_data/dti_combined_samples.csv', index=False)

## 8. CRISPR/RNAi Biomarkers

In [None]:
# Make sure both export folders exist before anything is written to them.
for pretrain_dir in ('./data/pretrain_plain_data', './data/pretrain_status_data'):
    os.makedirs(pretrain_dir, exist_ok=True)

In [None]:
# Build the name file: load the "names and IDs" text of every omics entity type
# (the BioMedGraphica_ID column is dropped from each table).
bmgc_promoter_name_id_df = (
    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Promoter/BioMedGraphica_Conn_Promoter_LLM_Name_ID_Combined.csv')
    .drop(columns=['BioMedGraphica_ID'])
)
bmgc_gene_name_id_df = (
    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene_LLM_Name_ID_Combined.csv')
    .drop(columns=['BioMedGraphica_ID'])
)
bmgc_transcript_name_id_df = (
    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Transcript/BioMedGraphica_Conn_Transcript_LLM_Name_ID_Combined.csv')
    .drop(columns=['BioMedGraphica_ID'])
)
bmgc_protein_name_id_df = (
    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein_LLM_Name_ID_Combined.csv')
    .drop(columns=['BioMedGraphica_ID'])
)

# Stack the four tables, then left-join them onto the IDs of bmgc_omics_df so the
# result follows that table's rows.
bmgc_omics_name_id_tmp_df = pd.concat(
    [bmgc_promoter_name_id_df, bmgc_gene_name_id_df, bmgc_transcript_name_id_df, bmgc_protein_name_id_df],
    ignore_index=True,
)
bmgc_omics_name_id_df = bmgc_omics_df[['BioMedGraphica_Conn_ID']].merge(
    bmgc_omics_name_id_tmp_df, on='BioMedGraphica_Conn_ID', how='left'
)

# Null counts before filling.
print("Null values in bmgc_omics_name_id_df:")
print(bmgc_omics_name_id_df.isnull().sum())
# Entities without any name text get a single space (not an empty string).
bmgc_omics_name_id_df['Names_and_IDs'] = bmgc_omics_name_id_df['Names_and_IDs'].fillna(' ')
# Null counts after filling.
print("Null values in bmgc_omics_name_id_df:")
print(bmgc_omics_name_id_df.isnull().sum())
display(bmgc_omics_name_id_df)

# The same table is written to both pretraining folders.
for name_csv_path in ('./data/pretrain_plain_data/bmgc_omics_name.csv', './data/pretrain_status_data/bmgc_omics_name.csv'):
    bmgc_omics_name_id_df.to_csv(name_csv_path, index=False)

In [None]:
# Promoters: reuse the promoter ID table and attach an all-NaN Description column.
bmgc_promoter_desc_df = bmgc_promoter_name_id_df.drop(columns=['Names_and_IDs']).assign(Description=np.nan)
# Genes, transcripts and proteins: load their description tables without BioMedGraphica_ID.
bmgc_gene_desc_df = (
    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene_Description_Combined.csv')
    .drop(columns=['BioMedGraphica_ID'])
)
bmgc_transcript_desc_df = (
    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Transcript/BioMedGraphica_Conn_Transcript_Description_Combined.csv')
    .drop(columns=['BioMedGraphica_ID'])
)
bmgc_protein_desc_df = (
    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein_Description_Combined.csv')
    .drop(columns=['BioMedGraphica_ID'])
)

# Stack the four description tables and report their null counts.
bmgc_omics_desc_tmp_df = pd.concat(
    [bmgc_promoter_desc_df, bmgc_gene_desc_df, bmgc_transcript_desc_df, bmgc_protein_desc_df],
    ignore_index=True,
)
print(bmgc_omics_desc_tmp_df.isnull().sum())

# Left-join onto the IDs of bmgc_omics_df and report the null counts again.
bmgc_omics_desc_df = bmgc_omics_df[['BioMedGraphica_Conn_ID']].merge(
    bmgc_omics_desc_tmp_df, on='BioMedGraphica_Conn_ID', how='left'
)
print(bmgc_omics_desc_df.isnull().sum())

# Missing descriptions become a single space (not an empty string); re-check nulls afterwards.
bmgc_omics_desc_df['Description'] = bmgc_omics_desc_df['Description'].fillna(' ')
print(bmgc_omics_desc_df.isnull().sum())
display(bmgc_omics_desc_df)

# The same table is written to both pretraining folders.
for desc_csv_path in ('./data/pretrain_plain_data/bmgc_omics_desc.csv', './data/pretrain_status_data/bmgc_omics_desc.csv'):
    bmgc_omics_desc_df.to_csv(desc_csv_path, index=False)

### 8.0 Entity markers

In [None]:
# Translation chain converging to the same node.
# One independent copy of the entity table per entity type.
promoter_entity_df = bmgc_omics_df.loc[bmgc_omics_df['Type'].eq('Promoter')].copy()
gene_entity_df = bmgc_omics_df.loc[bmgc_omics_df['Type'].eq('Gene')].copy()
transcript_entity_df = bmgc_omics_df.loc[bmgc_omics_df['Type'].eq('Transcript')].copy()
protein_entity_df = bmgc_omics_df.loc[bmgc_omics_df['Type'].eq('Protein')].copy()

# Show the relation table together with its per-column null counts.
display(bmgc_omics_relation_df)
print("Null values in bmgc_omics_relation_df:")
print(bmgc_omics_relation_df.isnull().sum())

# One independent copy of the relation table per step of the chain.
promoter_gene_relation_df = bmgc_omics_relation_df.loc[bmgc_omics_relation_df['Type'].eq('Promoter-Gene')].copy()
gene_transcript_relation_df = bmgc_omics_relation_df.loc[bmgc_omics_relation_df['Type'].eq('Gene-Transcript')].copy()
transcript_protein_relation_df = bmgc_omics_relation_df.loc[bmgc_omics_relation_df['Type'].eq('Transcript-Protein')].copy()

In [None]:
# --- Gene -> Protein pairs --------------------------------------------------
# Attach every Gene-Transcript relation to its gene entity (outer join on the gene ID).
gene_transcript_entity_df = pd.merge(gene_entity_df, gene_transcript_relation_df[['BMGC_From_ID', 'BMGC_To_ID']], left_on='BioMedGraphica_Conn_ID', right_on='BMGC_From_ID', how='outer')
# Follow each transcript (left BMGC_To_ID) to its Transcript-Protein relation.
# Both sides carry BMGC_From_ID / BMGC_To_ID, so pandas suffixes them with
# _x (gene -> transcript) and _y (transcript -> protein).
gene_transcript_protein_entity_df = pd.merge(gene_transcript_entity_df, transcript_protein_relation_df[['BMGC_From_ID', 'BMGC_To_ID']], left_on='BMGC_To_ID', right_on='BMGC_From_ID', how='outer')
# Keep only rows where both links are present, i.e. drop NaN in
# BMGC_From_ID_x, BMGC_To_ID_x, BMGC_From_ID_y and BMGC_To_ID_y.
# NOTE(review): BioMedGraphica_Conn_ID is not part of the subset, so a relation whose
# gene is absent from gene_entity_df would survive here with a NaN gene ID - confirm
# that this cannot happen in the relation table.
gene_transcript_protein_entity_df = gene_transcript_protein_entity_df.dropna(subset=['BMGC_From_ID_x', 'BMGC_To_ID_x', 'BMGC_From_ID_y', 'BMGC_To_ID_y']).reset_index(drop=True)
# Keep the gene ID and the protein ID only (the transcript ID is not kept), rename them to
# ['BMGC_GN_ID', 'BMGC_PT_ID'] and sort by gene ID.
gene_transcript_protein_entity_df = gene_transcript_protein_entity_df[['BioMedGraphica_Conn_ID', 'BMGC_To_ID_y']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_GN_ID', 'BMGC_To_ID_y': 'BMGC_PT_ID'}).sort_values(by='BMGC_GN_ID').reset_index(drop=True)
# Several transcripts can link the same gene to the same protein: drop the duplicated pairs.
gene_transcript_protein_entity_df = gene_transcript_protein_entity_df.drop_duplicates().reset_index(drop=True)
display(gene_transcript_protein_entity_df)

# --- Promoter -> Protein pairs ----------------------------------------------
# Start from a copy of the gene -> protein pairs and swap the gene ID for a promoter ID.
bmgc_promoter_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Promoter/BioMedGraphica_Conn_Promoter.csv')
bmgc_gene_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene.csv')
promoter_gene_transcript_protein_entity_df = gene_transcript_protein_entity_df.copy()
# NOTE(review): promoter and gene IDs are paired purely by row position (concat along
# axis=1). This presumes row i of the Promoter file describes the promoter of the gene in
# row i of the Gene file and that both files have the same length - TODO confirm, or
# derive the pairing from promoter_gene_relation_df instead.
promoter_gene_df = pd.concat([bmgc_promoter_df[['BioMedGraphica_Conn_ID']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_PM_ID'}), bmgc_gene_df[['BioMedGraphica_Conn_ID']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_GN_ID'})], axis=1)
# Left join: a gene without a paired promoter keeps its row with NaN in BMGC_PM_ID.
promoter_protein_entity_df = pd.merge(promoter_gene_transcript_protein_entity_df, promoter_gene_df, left_on='BMGC_GN_ID', right_on='BMGC_GN_ID', how='left').drop(columns=['BMGC_GN_ID'])
promoter_protein_entity_df = promoter_protein_entity_df[['BMGC_PM_ID', 'BMGC_PT_ID']]
display(promoter_protein_entity_df)

# --- Transcript -> Protein pairs --------------------------------------------
# Attach every Transcript-Protein relation to its transcript entity (outer join on the transcript ID).
transcript_protein_entity_df = pd.merge(transcript_entity_df, transcript_protein_relation_df[['BMGC_From_ID', 'BMGC_To_ID']], left_on='BioMedGraphica_Conn_ID', right_on='BMGC_From_ID', how='outer')
# Drop rows with NaN in BMGC_From_ID or BMGC_To_ID (transcripts without a protein link).
transcript_protein_entity_df = transcript_protein_entity_df.dropna(subset=['BMGC_From_ID', 'BMGC_To_ID']).reset_index(drop=True)
# Keep the columns ['BioMedGraphica_Conn_ID', 'BMGC_To_ID'], rename them to
# ['BMGC_TS_ID', 'BMGC_PT_ID'] and sort by transcript ID.
transcript_protein_entity_df = transcript_protein_entity_df[['BioMedGraphica_Conn_ID', 'BMGC_To_ID']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_TS_ID', 'BMGC_To_ID': 'BMGC_PT_ID'}).sort_values(by='BMGC_TS_ID').reset_index(drop=True)
# Drop duplicated transcript -> protein pairs.
transcript_protein_entity_df = transcript_protein_entity_df.drop_duplicates().reset_index(drop=True)
display(transcript_protein_entity_df)

In [None]:
# Protein IDs on their own: a single BMGC_PT_ID column, sorted ascending.
only_protein_entity_df = (
    protein_entity_df[['BioMedGraphica_Conn_ID']]
    .rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_PT_ID'})
    .sort_values(by='BMGC_PT_ID')
    .reset_index(drop=True)
)
display(only_protein_entity_df)

### 8.1 CRISPR top 100 gene entities

In [None]:
import pandas as pd

# Raw CRISPR gene-effect table, loaded exactly as stored on disk.
raw_crispr_df = pd.read_csv(
    './data/raw_data/CRISPRGeneEffect.csv'
)
display(raw_crispr_df)

In [None]:
# Every column label of raw_crispr_df except the first one, as a plain list.
raw_crispr_df_columns = list(raw_crispr_df.columns[1:])
print(raw_crispr_df_columns)

In [None]:
# Turn the table on its side: index by "Unnamed: 0", transpose so that the former
# column labels become rows, then move those labels into a 'gene_name' column.
raw_crispr_t_df = (
    raw_crispr_df
    .set_index("Unnamed: 0")
    .T
    .reset_index()
    .rename(columns={'index': 'gene_name'})
)
# Keep only the text before the first "(" of each label, without surrounding whitespace.
raw_crispr_t_df['gene_name'] = raw_crispr_t_df['gene_name'].apply(
    lambda label: label.split('(')[0].strip()
)
display(raw_crispr_t_df)

In [None]:
# Every column except 'gene_name' is coerced to numeric; values that cannot be
# parsed become NaN.
cols_to_convert = raw_crispr_t_df.columns.difference(['gene_name'])
numeric_values = raw_crispr_t_df[cols_to_convert].apply(pd.to_numeric, errors='coerce')
raw_crispr_t_df[cols_to_convert] = numeric_values
# Remaining NaN values are replaced by 0.0.
raw_crispr_t_df = raw_crispr_t_df.fillna(0.0)
display(raw_crispr_t_df)

# Collapse rows that share a gene_name into their per-column mean.
crispr_df = raw_crispr_t_df.groupby('gene_name', as_index=False).mean()
display(crispr_df)

In [None]:
# Gene entity table reduced to BioMedGraphica_Conn_ID and HGNC_Symbol.
bmg_gene_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene.csv')[
    ['BioMedGraphica_Conn_ID', 'HGNC_Symbol']
]
display(bmg_gene_df)

# Attach the BioMedGraphica_Conn_ID to the CRISPR rows by matching HGNC_Symbol
# against gene_name (inner join), then drop both symbol columns.
merged_crispr_df = (
    bmg_gene_df
    .merge(crispr_df, left_on='HGNC_Symbol', right_on='gene_name', how='inner')
    .drop(columns=['HGNC_Symbol', 'gene_name'])
)
display(merged_crispr_df)

In [None]:
# Give every gene -> protein pair the CRISPR values of its gene (left join on the
# gene ID), then discard every row that contains a NaN.
protein_crispr_df = (
    gene_transcript_protein_entity_df
    .merge(merged_crispr_df, left_on='BMGC_GN_ID', right_on='BioMedGraphica_Conn_ID', how='left')
    .drop(columns=['BioMedGraphica_Conn_ID'])
    .dropna()
    .reset_index(drop=True)
)
display(protein_crispr_df)

In [None]:
# Echo the sample identifiers (defined in an earlier cell) that the next cell uses
# to decide which columns of protein_crispr_df are kept.
overlapped_omics_union_annotation_samples

In [None]:
# CRISPR export: keep the two ID columns plus every column whose label appears in
# overlapped_omics_union_annotation_samples (in the order of protein_crispr_df).
columns_to_keep = ['BMGC_GN_ID', 'BMGC_PT_ID']
columns_to_keep = columns_to_keep + [
    col for col in protein_crispr_df.columns
    if col in overlapped_omics_union_annotation_samples
]
processed_crispr_df = protein_crispr_df[columns_to_keep]
display(processed_crispr_df)

# The same table is written to both pretraining folders.
for crispr_csv_path in ('./data/pretrain_plain_data/processed_crispr.csv', './data/pretrain_status_data/processed_crispr.csv'):
    processed_crispr_df.to_csv(crispr_csv_path, index=False)

### 8.2 RNAi top 100 transcript entities

In [None]:
# Load the combined RNAi dependency scores; every missing value becomes 0.0.
raw_rna_combined_df = pd.read_csv('./data/raw_data/D2_combined_gene_dep_scores.csv').fillna(0.0)
display(raw_rna_combined_df)

# The "Unnamed: 0" column holds the row labels: call it gene_name and keep only
# the text before the first "(" of each label, without surrounding whitespace.
raw_rna_combined_df = raw_rna_combined_df.rename(columns={'Unnamed: 0': 'gene_name'})
raw_rna_combined_df['gene_name'] = raw_rna_combined_df['gene_name'].apply(
    lambda label: label.split('(')[0].strip()
)
display(raw_rna_combined_df)

In [None]:
# RNAi sample labels are every column after the first one (gene_name).
rna_combined_samples = raw_rna_combined_df.columns[1:].tolist()
# Put the labels in a one-column frame and keep those that have an annotation row
# (inner join of CCLE_Name against CCLE_ID in cell_line_anno_df).
rna_combined_samples_df = pd.DataFrame(rna_combined_samples, columns=['CCLE_Name'])
rna_combined_samples_anno_df = rna_combined_samples_df.merge(
    cell_line_anno_df, left_on='CCLE_Name', right_on='CCLE_ID', how='inner'
)
display(rna_combined_samples_anno_df)

# Lookup from CCLE name to depMapID, used to relabel the sample columns.
ccle_names = rna_combined_samples_anno_df['CCLE_Name']
depmap_ids = rna_combined_samples_anno_df['depMapID']
rna_combined_samples_map_dict = dict(zip(ccle_names, depmap_ids))
print(rna_combined_samples_map_dict)
maped_rna_combined_df = raw_rna_combined_df.rename(columns=rna_combined_samples_map_dict)

# Keep gene_name plus, in sorted order, every column whose label appears in
# overlapped_omics_union_annotation_samples.
columns_to_keep = ['gene_name'] + sorted(
    col for col in maped_rna_combined_df.columns
    if col in overlapped_omics_union_annotation_samples
)
filtered_rna_df = maped_rna_combined_df[columns_to_keep]
display(filtered_rna_df)

In [None]:
# Load the transcript entity table and narrow it to its ID and HGNC symbol columns
bmg_transcript_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Transcript/BioMedGraphica_Conn_Transcript.csv')
bmg_transcript_df = bmg_transcript_df.loc[:, ['BioMedGraphica_Conn_ID', 'HGNC_Symbol']]
display(bmg_transcript_df)
# Inner-join the filtered RNAi scores onto the transcript entities by gene symbol,
# then remove both symbol columns so only the entity ID and the sample columns remain
merge_rna_df = (
    bmg_transcript_df
    .merge(filtered_rna_df, how='inner', left_on='HGNC_Symbol', right_on='gene_name')
    .drop(columns=['HGNC_Symbol', 'gene_name'])
)
display(merge_rna_df)

In [None]:
# Left-join the RNAi scores onto the transcript/protein entity table by transcript ID and discard the join key
protein_rna_df = transcript_protein_entity_df.merge(
    merge_rna_df, how='left', left_on='BMGC_TS_ID', right_on='BioMedGraphica_Conn_ID'
).drop(columns=['BioMedGraphica_Conn_ID'])
# Rows containing any missing value are removed before the table is written to both pretrain folders
processed_rna_df = protein_rna_df.dropna().reset_index(drop=True)
display(processed_rna_df)
processed_rna_df.to_csv('./data/pretrain_plain_data/processed_rna.csv', index=False)
processed_rna_df.to_csv('./data/pretrain_status_data/processed_rna.csv', index=False)

In [None]:
# Sample columns start after the two leading columns of each processed table
processed_crispr_samples = processed_crispr_df.columns[2:].to_list()
processed_rna_samples = processed_rna_df.columns[2:].to_list()

# Samples present in both the CRISPR and the RNAi table, in sorted order
overlapped_crispr_rna_samples = sorted(set(processed_crispr_samples).intersection(processed_rna_samples))
print("The overlapped samples between protein_crispr_samples and protein_rna_samples are:")
print(len(overlapped_crispr_rna_samples))
print(overlapped_crispr_rna_samples)

# Samples present in at least one of the two tables, in sorted order
union_crispr_rna_samples = sorted(set(processed_crispr_samples).union(processed_rna_samples))
print("The union samples between protein_crispr_samples and protein_rna_samples are:")
print(len(union_crispr_rna_samples))
print(union_crispr_rna_samples)

In [None]:
# Restrict the CRISPR/RNAi intersection to the samples that are also in
# overlapped_omics_annotation_samples (the overlap across omics, not the union across omics)
overlapped_omics_crispr_rna_samples = sorted(set(overlapped_crispr_rna_samples) & set(overlapped_omics_annotation_samples))
# The label names the two sets actually intersected, so this output can be told apart
# from the plain CRISPR/RNAi intersection printed in the previous cell
print("The overlapped samples between overlapped_crispr_rna_samples and overlapped_omics_annotation_samples are:")
print(len(overlapped_omics_crispr_rna_samples))
print(overlapped_omics_crispr_rna_samples)

# Same restriction, applied to the CRISPR/RNAi union instead of the intersection
overlapped_omics_union_crispr_rna_samples = sorted(set(union_crispr_rna_samples) & set(overlapped_omics_annotation_samples))
print("The overlapped samples between union_crispr_rna_samples and overlapped_omics_annotation_samples are:")
print(len(overlapped_omics_union_crispr_rna_samples))
print(overlapped_omics_union_crispr_rna_samples)

## 9. Sample Splits

### 9.1 Pretrain, Drug and Target Sample Splits

In [None]:
# depMapID values of the DTI-overlapped samples as a plain list (duplicates, if any, are kept)
dti_overlapped_samples = [sample_id for sample_id in dti_overlapped_samples_df['depMapID']]
print(dti_overlapped_samples)
# The count printed here is the number of distinct IDs
print("length of dti_overlapped_samples: ", len(set(dti_overlapped_samples)))

#### 9.1.1 Pretrain Samples

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# Samples left for pretraining: everything in the union annotation set that is not a DTI sample
rest_samples = sorted(set(overlapped_omics_union_annotation_samples).difference(dti_overlapped_samples))
print("len(rest_samples):", len(rest_samples))
print(rest_samples[:5])

# Members of rest_samples that are listed in the 'Depmap Id' column of non_cancerous_samples_df
non_cancerous_rest_samples = sorted(set(non_cancerous_samples_df['Depmap Id']).intersection(rest_samples))
print("len(non_cancerous_rest_samples):", len(non_cancerous_rest_samples))
print(non_cancerous_rest_samples[:5])

# All other rest samples count as cancerous
cancerous_rest_samples = sorted(set(rest_samples).difference(non_cancerous_rest_samples))
print("len(cancerous_rest_samples):", len(cancerous_rest_samples))
print(cancerous_rest_samples[:5])

# One row per rest sample, labelled with the group it belongs to.
# Each rest sample is in exactly one of the two groups, so testing the non-cancerous
# group yields the same labels as testing the cancerous one.
rest_samples_df = pd.DataFrame({
    'depMapID': rest_samples,
    'cancerous_status': [
        'non-cancerous' if sample in non_cancerous_rest_samples else 'cancerous'
        for sample in rest_samples
    ],
})
display(rest_samples_df)

# Split the rest_samples into pretrain_plain_samples and pretrain_status_samples
# Sample 20% of the rest_samples for pretrain_plain_samples while maintaining the cancerous/non-cancerous ratio

def stratified_split_pretrain_samples(rest_samples_df, pretrain_plain_ratio=0.2, random_state=42):
    """
    Divide the rest samples into a pretrain_plain part and a pretrain_status part,
    stratified on cancerous_status so that both parts keep the cancerous/non-cancerous ratio.

    Args:
        rest_samples_df: DataFrame with depMapID and cancerous_status columns
        pretrain_plain_ratio: Share of the samples assigned to pretrain_plain (default 0.2);
            the remaining 1 - pretrain_plain_ratio is passed to train_test_split as test_size
        random_state: Random seed passed to train_test_split for reproducibility

    Returns:
        tuple: (pretrain_plain_samples, pretrain_status_samples, pretrain_plain_df, pretrain_status_df)
            where the two sample lists hold the sorted depMapID values of the two DataFrames
    """
    # The "test" side of train_test_split is the pretrain_status part
    status_share = 1 - pretrain_plain_ratio
    plain_part, status_part = train_test_split(
        rest_samples_df,
        test_size=status_share,
        stratify=rest_samples_df['cancerous_status'],
        random_state=random_state,
    )

    # Sorted ID lists for each part
    plain_ids = plain_part['depMapID'].tolist()
    plain_ids.sort()
    status_ids = status_part['depMapID'].tolist()
    status_ids.sort()

    return plain_ids, status_ids, plain_part, status_part

# Run the stratified split: 20% of the rest samples go to pretrain_plain, the remainder to pretrain_status
pretrain_plain_samples, pretrain_status_samples, pretrain_plain_df, pretrain_status_df = stratified_split_pretrain_samples(
    rest_samples_df, 
    pretrain_plain_ratio=0.2, 
    random_state=42
)

# Report the size of each part, absolute and as a percentage of all rest samples
print(f"\n📊 Pretrain Sample Split Summary:")
print(f"Total rest samples: {len(rest_samples_df)}")
print(f"Pretrain plain samples: {len(pretrain_plain_samples)} ({len(pretrain_plain_samples)/len(rest_samples_df)*100:.1f}%)")
print(f"Pretrain status samples: {len(pretrain_status_samples)} ({len(pretrain_status_samples)/len(rest_samples_df)*100:.1f}%)")

# Class counts and class proportions inside the pretrain_plain part
print(f"\n📋 Pretrain Plain Sample Distribution:")
print(pretrain_plain_df['cancerous_status'].value_counts())
print(pretrain_plain_df['cancerous_status'].value_counts(normalize=True))

# Class counts and class proportions inside the pretrain_status part
print(f"\n📋 Pretrain Status Sample Distribution:")
print(pretrain_status_df['cancerous_status'].value_counts())
print(pretrain_status_df['cancerous_status'].value_counts(normalize=True))

# Save the detailed DataFrames (depMapID plus cancerous_status) of both parts
pretrain_plain_df.to_csv('./data/pretrain_plain_data/pretrain_plain_samples.csv', index=False)
pretrain_status_df.to_csv('./data/pretrain_status_data/pretrain_status_samples.csv', index=False)

# Also save just the sorted sample IDs for convenience
pd.DataFrame(pretrain_plain_samples, columns=['depMapID']).to_csv('./data/pretrain_plain_data/pretrain_plain_sample_ids.csv', index=False)
pd.DataFrame(pretrain_status_samples, columns=['depMapID']).to_csv('./data/pretrain_status_data/pretrain_status_sample_ids.csv', index=False)

print(f"\n✅ Saved pretrain sample splits to:")
print(f"- ./data/pretrain_plain_data/pretrain_plain_samples.csv ({len(pretrain_plain_samples)} samples)")
print(f"- ./data/pretrain_status_data/pretrain_status_samples.csv ({len(pretrain_status_samples)} samples)")

# Verify the two parts share no sample (the printed count is expected to be 0; nothing is asserted)
overlap_check = set(pretrain_plain_samples) & set(pretrain_status_samples)
print(f"\n🔍 Overlap check: {len(overlap_check)} overlapping samples (should be 0)")

# Per-class counts of each part, reused for the one-line summary below
pretrain_plain_cancerous = pretrain_plain_df['cancerous_status'].value_counts()
pretrain_status_cancerous = pretrain_status_df['cancerous_status'].value_counts()

# .get(..., 0) prints 0 for a class that is absent from a part
print(f"\n📈 Final Cancerous/Non-cancerous Distribution:")
print(f"Pretrain Plain - Cancerous: {pretrain_plain_cancerous.get('cancerous', 0)}, Non-cancerous: {pretrain_plain_cancerous.get('non-cancerous', 0)}")
print(f"Pretrain Status - Cancerous: {pretrain_status_cancerous.get('cancerous', 0)}, Non-cancerous: {pretrain_status_cancerous.get('non-cancerous', 0)}")

# Split the pretrain_status_samples into training and test sets according to the cancerous status
def stratified_split_status_samples(pretrain_status_df, test_size=0.2, random_state=42):
    """
    Divide the pretrain status samples into a training and a test part,
    stratified on cancerous_status so that both parts keep the cancerous/non-cancerous ratio.

    Args:
        pretrain_status_df: DataFrame with depMapID and cancerous_status columns
        test_size: Proportion of the samples placed in the test part (default 0.2)
        random_state: Random seed passed to train_test_split for reproducibility

    Returns:
        tuple: (train_samples, test_samples, train_df, test_df)
            where the two sample lists hold the sorted depMapID values of the two DataFrames
    """
    status_labels = pretrain_status_df['cancerous_status']
    train_part, test_part = train_test_split(
        pretrain_status_df,
        test_size=test_size,
        stratify=status_labels,
        random_state=random_state,
    )

    # Sorted ID lists for each part
    train_ids = train_part['depMapID'].tolist()
    train_ids.sort()
    test_ids = test_part['depMapID'].tolist()
    test_ids.sort()

    return train_ids, test_ids, train_part, test_part

# Split the pretrain_status part again: 80% training, 20% test, stratified on cancerous_status
pretrain_status_train_samples, pretrain_status_test_samples, pretrain_status_train_df, pretrain_status_test_df = stratified_split_status_samples(
    pretrain_status_df, 
    test_size=0.2, 
    random_state=42
)

# Report sizes (absolute and relative to the whole pretrain_status part) and the class counts of each part.
# Class proportions are printed for the training part only.
print(f"\n📊 Pretrain Status Sample Split Summary:"
      f"\nTotal pretrain status samples: {len(pretrain_status_df)}")
print(f"Training samples: {len(pretrain_status_train_samples)} ({len(pretrain_status_train_samples)/len(pretrain_status_df)*100:.1f}%)")
print(f"Test samples: {len(pretrain_status_test_samples)} ({len(pretrain_status_test_samples)/len(pretrain_status_df)*100:.1f}%)")
print(f"\n📋 Training Sample Distribution:"
      f"\n{pretrain_status_train_df['cancerous_status'].value_counts()}")
print(pretrain_status_train_df['cancerous_status'].value_counts(normalize=True))
print(f"\n📋 Test Sample Distribution:"
      f"\n{pretrain_status_test_df['cancerous_status'].value_counts()}")

# Save the detailed DataFrames (depMapID plus cancerous_status) of the training and test parts
pretrain_status_train_df.to_csv('./data/pretrain_status_data/pretrain_status_train_samples.csv', index=False)
pretrain_status_test_df.to_csv('./data/pretrain_status_data/pretrain_status_test_samples.csv', index=False)


#### 9.1.2 Target-CRISPR Samples

In [None]:
# DTI samples that are also available in both the CRISPR and the RNAi table, in sorted order
overlapped_dti_crispr_rna_samples = sorted(set(dti_overlapped_samples).intersection(overlapped_crispr_rna_samples))
print("The overlapped samples between dti_overlapped_samples and overlapped_crispr_rna_samples are:")
print(len(overlapped_dti_crispr_rna_samples))
print(overlapped_dti_crispr_rna_samples)

# Annotation rows of those samples taken from dti_overlapped_samples_df, re-indexed from 0
overlapped_dti_crispr_rna_samples_df = dti_overlapped_samples_df.loc[
    dti_overlapped_samples_df['depMapID'].isin(overlapped_dti_crispr_rna_samples)
].reset_index(drop=True)
# Missing TCGA codes become the literal 'Unknown'
overlapped_dti_crispr_rna_samples_df['tcga_code'] = overlapped_dti_crispr_rna_samples_df['tcga_code'].fillna('Unknown')
display(overlapped_dti_crispr_rna_samples_df)

In [None]:
import pandas as pd
import numpy as np
import os

def stratified_tcga_split_target(df, test_ratio=0.2, random_state=42):
    """
    Split samples into train and test lists, one TCGA code at a time.

    Group handling:
    - TCGA codes with >=5 samples: max(1, int(n * test_ratio)) samples go to test
    - TCGA codes with 2-4 samples: exactly 1 sample goes to test
    - TCGA codes with 1 sample: that sample goes to train

    Args:
        df: DataFrame with depMapID and tcga_code columns. groupby skips rows whose
            tcga_code is NaN by default, so fill missing codes before calling.
        test_ratio: Fraction of each group with >=5 samples that is sent to test (default 0.2)
        random_state: Seed handed to np.random.seed; this reseeds NumPy's global
            random state, which stays seeded after the call

    Returns:
        tuple: (train_samples, test_samples), two lists of depMapID values.
            Both lists are empty when df has no rows.
    """
    np.random.seed(random_state)

    train_samples = []
    test_samples = []

    print("TCGA Code Distribution and Split Strategy for Target Samples:")
    print("-" * 70)

    for tcga_code, group in df.groupby('tcga_code'):
        samples = group['depMapID'].tolist()
        n_samples = len(samples)

        # Shuffle inside every group (single-sample groups included) so the
        # random stream is consumed the same way regardless of group size
        np.random.shuffle(samples)

        if n_samples == 1:
            # Single sample: nothing goes to test
            n_test = 0
        elif n_samples < 5:
            # 2-4 samples: exactly one goes to test
            n_test = 1
        else:
            # 5+ samples: use the requested ratio, but always at least one test sample
            n_test = max(1, int(n_samples * test_ratio))

        # The first n_test shuffled samples form the test part, the rest the train part
        group_test = samples[:n_test]
        group_train = samples[n_test:]
        test_samples.extend(group_test)
        train_samples.extend(group_train)

        if n_samples == 1:
            strategy = "1 sample -> train only"
        else:
            strategy = f"{n_samples} samples -> {len(group_test)} test, {len(group_train)} train"
        print(f"TCGA {tcga_code}: {strategy}")

    total_samples = len(train_samples) + len(test_samples)

    print("-" * 70)
    print(f"Total train samples: {len(train_samples)}")
    print(f"Total test samples: {len(test_samples)}")
    if total_samples:
        print(f"Train ratio: {len(train_samples)/total_samples:.3f}")
        print(f"Test ratio: {len(test_samples)/total_samples:.3f}")
    else:
        # An empty input has no ratio to report; avoid dividing by zero
        print("Train ratio: n/a (no samples)")
        print("Test ratio: n/a (no samples)")

    return train_samples, test_samples

# Work on a copy so the annotation frame built in the previous cell is left untouched
overlapped_dti_crispr_rna_samples_df_annotated = overlapped_dti_crispr_rna_samples_df.copy()
# Fill NaN TCGA codes with 'Unknown' if any exist (the previous cell already did this on the source frame)
overlapped_dti_crispr_rna_samples_df_annotated['tcga_code'] = overlapped_dti_crispr_rna_samples_df_annotated['tcga_code'].fillna('Unknown')

# Perform the per-TCGA-code split; each list holds depMapID values
target_crispr_train_samples, target_crispr_test_samples = stratified_tcga_split_target(
    overlapped_dti_crispr_rna_samples_df_annotated, 
    test_ratio=0.2, 
    random_state=42
)

# Report absolute counts and percentages relative to the number of annotated rows
print(f"\n📊 Target Sample Count Summary:")
print(f"Total samples in dataset: {len(overlapped_dti_crispr_rna_samples_df_annotated)}")
print(f"Training samples: {len(target_crispr_train_samples)}")
print(f"Test samples: {len(target_crispr_test_samples)}")
print(f"Training percentage: {len(target_crispr_train_samples)/len(overlapped_dti_crispr_rna_samples_df_annotated)*100:.1f}%")
print(f"Test percentage: {len(target_crispr_test_samples)/len(overlapped_dti_crispr_rna_samples_df_annotated)*100:.1f}%")

# Create the output directories for the target tasks.
# exist_ok=True makes the call a no-op when the directory is already there, which removes
# the check-then-create gap of os.path.exists() followed by os.makedirs().
os.makedirs('./data/TargetQA', exist_ok=True)
os.makedirs('./data/TargetScreen', exist_ok=True)

# Sort the samples before saving so the files do not depend on the shuffle order
train_samples_sorted = sorted(target_crispr_train_samples)
test_samples_sorted = sorted(target_crispr_test_samples)

# Save the sorted train and test IDs; TargetQA and TargetScreen receive identical files
pd.DataFrame(train_samples_sorted, columns=['depMapID']).to_csv('./data/TargetQA/train_samples.csv', index=False)
pd.DataFrame(test_samples_sorted, columns=['depMapID']).to_csv('./data/TargetQA/test_samples.csv', index=False)
pd.DataFrame(train_samples_sorted, columns=['depMapID']).to_csv('./data/TargetScreen/train_samples.csv', index=False)
pd.DataFrame(test_samples_sorted, columns=['depMapID']).to_csv('./data/TargetScreen/test_samples.csv', index=False)

# Annotation rows of the train samples (row order follows the annotated frame, not the sorted ID list)
train_target_samples_df = overlapped_dti_crispr_rna_samples_df_annotated[
    overlapped_dti_crispr_rna_samples_df_annotated['depMapID'].isin(target_crispr_train_samples)
].reset_index(drop=True)

# Annotation rows of the test samples
test_target_samples_df = overlapped_dti_crispr_rna_samples_df_annotated[
    overlapped_dti_crispr_rna_samples_df_annotated['depMapID'].isin(target_crispr_test_samples)
].reset_index(drop=True)

# Display the TCGA distributions in train and test sets, ordered by TCGA code
print("\n📋 Train set TCGA distribution:")
print(train_target_samples_df['tcga_code'].value_counts().sort_index())
print("\n📋 Test set TCGA distribution:")
print(test_target_samples_df['tcga_code'].value_counts().sort_index())

# Save detailed sample information; again the same files go to both task folders
train_target_samples_df.to_csv('./data/TargetQA/train_samples_detailed.csv', index=False)
test_target_samples_df.to_csv('./data/TargetQA/test_samples_detailed.csv', index=False)
train_target_samples_df.to_csv('./data/TargetScreen/train_samples_detailed.csv', index=False)
test_target_samples_df.to_csv('./data/TargetScreen/test_samples_detailed.csv', index=False)

# Summarise what was written
print(f"\n✅ Saved stratified train and test splits to:")
print(f"- ./data/TargetQA/train_samples.csv ({len(train_samples_sorted)} samples)")
print(f"- ./data/TargetQA/test_samples.csv ({len(test_samples_sorted)} samples)")
print(f"- ./data/TargetScreen/train_samples.csv ({len(train_samples_sorted)} samples)")
print(f"- ./data/TargetScreen/test_samples.csv ({len(test_samples_sorted)} samples)")
print(f"- Detailed sample information with TCGA annotations also saved")

#### 9.1.3 Drug Samples

In [None]:
# DTI samples that are not part of the CRISPR/RNAi target set, in sorted order
remaining_dti_samples = sorted(set(dti_overlapped_samples).difference(overlapped_dti_crispr_rna_samples))
print("The remaining samples in dti_overlapped_samples after removing overlapped_dti_crispr_rna_samples are:")
print(len(remaining_dti_samples))

In [None]:
import pandas as pd
import numpy as np
import os

# Recompute the DTI samples that are not part of the CRISPR/RNAi target set, in sorted order
remaining_dti_samples = sorted(set(dti_overlapped_samples).difference(overlapped_dti_crispr_rna_samples))
print("The remaining samples in dti_overlapped_samples after removing overlapped_dti_crispr_rna_samples are:")
print(len(remaining_dti_samples))

# Annotation rows of those samples taken from dti_overlapped_samples_df, re-indexed from 0
remaining_dti_samples_df = dti_overlapped_samples_df.loc[
    dti_overlapped_samples_df['depMapID'].isin(remaining_dti_samples)
].reset_index(drop=True)

# Missing TCGA codes become the literal 'Unknown'
remaining_dti_samples_df['tcga_code'] = remaining_dti_samples_df['tcga_code'].fillna('Unknown')

def stratified_tcga_split_dti(df, test_ratio=0.2, random_state=42):
    """
    Split samples into train and test lists, one TCGA code at a time.

    Group handling:
    - TCGA codes with >=5 samples: max(1, int(n * test_ratio)) samples go to test
    - TCGA codes with 2-4 samples: exactly 1 sample goes to test
    - TCGA codes with 1 sample: that sample goes to train

    Args:
        df: DataFrame with depMapID and tcga_code columns. groupby skips rows whose
            tcga_code is NaN by default, so fill missing codes before calling.
        test_ratio: Fraction of each group with >=5 samples that is sent to test (default 0.2)
        random_state: Seed handed to np.random.seed; this reseeds NumPy's global
            random state, which stays seeded after the call

    Returns:
        tuple: (train_samples, test_samples), two lists of depMapID values.
            Both lists are empty when df has no rows.
    """
    np.random.seed(random_state)

    train_samples = []
    test_samples = []

    print("TCGA Code Distribution and Split Strategy for DTI Drug Samples:")
    print("-" * 70)

    for tcga_code, group in df.groupby('tcga_code'):
        samples = group['depMapID'].tolist()
        n_samples = len(samples)

        # Shuffle inside every group (single-sample groups included) so the
        # random stream is consumed the same way regardless of group size
        np.random.shuffle(samples)

        if n_samples == 1:
            # Single sample: nothing goes to test
            n_test = 0
        elif n_samples < 5:
            # 2-4 samples: exactly one goes to test
            n_test = 1
        else:
            # 5+ samples: use the requested ratio, but always at least one test sample
            n_test = max(1, int(n_samples * test_ratio))

        # The first n_test shuffled samples form the test part, the rest the train part
        group_test = samples[:n_test]
        group_train = samples[n_test:]
        test_samples.extend(group_test)
        train_samples.extend(group_train)

        if n_samples == 1:
            strategy = "1 sample -> train only"
        else:
            strategy = f"{n_samples} samples -> {len(group_test)} test, {len(group_train)} train"
        print(f"TCGA {tcga_code}: {strategy}")

    total_samples = len(train_samples) + len(test_samples)

    print("-" * 70)
    print(f"Total train samples: {len(train_samples)}")
    print(f"Total test samples: {len(test_samples)}")
    if total_samples:
        print(f"Train ratio: {len(train_samples)/total_samples:.3f}")
        print(f"Test ratio: {len(test_samples)/total_samples:.3f}")
    else:
        # An empty input has no ratio to report; avoid dividing by zero
        print("Train ratio: n/a (no samples)")
        print("Test ratio: n/a (no samples)")

    return train_samples, test_samples

# Perform the per-TCGA-code split on the remaining DTI samples; each list holds depMapID values
dti_train_samples, dti_test_samples = stratified_tcga_split_dti(
    remaining_dti_samples_df, 
    test_ratio=0.2, 
    random_state=42
)

# Report absolute counts and percentages relative to the number of remaining DTI rows
print(f"\n📊 DTI Drug Sample Count Summary:")
print(f"Total remaining DTI samples: {len(remaining_dti_samples_df)}")
print(f"Training samples: {len(dti_train_samples)}")
print(f"Test samples: {len(dti_test_samples)}")
print(f"Training percentage: {len(dti_train_samples)/len(remaining_dti_samples_df)*100:.1f}%")
print(f"Test percentage: {len(dti_test_samples)/len(remaining_dti_samples_df)*100:.1f}%")

# Create the output directories for the drug tasks.
# exist_ok=True makes the call a no-op when the directory is already there, which removes
# the check-then-create gap of os.path.exists() followed by os.makedirs().
os.makedirs('./data/DrugQA', exist_ok=True)
os.makedirs('./data/DrugScreen', exist_ok=True)

# Sort the samples before saving so the files do not depend on the shuffle order
dti_train_samples_sorted = sorted(dti_train_samples)
dti_test_samples_sorted = sorted(dti_test_samples)

# Save the full annotation table of the remaining DTI samples to both drug task folders
remaining_dti_samples_df.to_csv('./data/DrugQA/remaining_dti_samples.csv', index=False)
remaining_dti_samples_df.to_csv('./data/DrugScreen/remaining_dti_samples.csv', index=False)
# Save the sorted train and test IDs; DrugQA and DrugScreen receive identical files
pd.DataFrame(dti_train_samples_sorted, columns=['depMapID']).to_csv('./data/DrugQA/train_samples.csv', index=False)
pd.DataFrame(dti_test_samples_sorted, columns=['depMapID']).to_csv('./data/DrugQA/test_samples.csv', index=False)
pd.DataFrame(dti_train_samples_sorted, columns=['depMapID']).to_csv('./data/DrugScreen/train_samples.csv', index=False)
pd.DataFrame(dti_test_samples_sorted, columns=['depMapID']).to_csv('./data/DrugScreen/test_samples.csv', index=False)

# Annotation rows of the train samples (row order follows remaining_dti_samples_df, not the sorted ID list)
train_dti_samples_df = remaining_dti_samples_df[
    remaining_dti_samples_df['depMapID'].isin(dti_train_samples)
].reset_index(drop=True)

# Annotation rows of the test samples
test_dti_samples_df = remaining_dti_samples_df[
    remaining_dti_samples_df['depMapID'].isin(dti_test_samples)
].reset_index(drop=True)

# Display the TCGA distributions in train and test sets, ordered by TCGA code
print("\n📋 DTI Drug Train set TCGA distribution:")
print(train_dti_samples_df['tcga_code'].value_counts().sort_index())
print("\n📋 DTI Drug Test set TCGA distribution:")
print(test_dti_samples_df['tcga_code'].value_counts().sort_index())

# Save detailed sample information; again the same files go to both task folders
train_dti_samples_df.to_csv('./data/DrugQA/train_samples_detailed.csv', index=False)
test_dti_samples_df.to_csv('./data/DrugQA/test_samples_detailed.csv', index=False)
train_dti_samples_df.to_csv('./data/DrugScreen/train_samples_detailed.csv', index=False)
test_dti_samples_df.to_csv('./data/DrugScreen/test_samples_detailed.csv', index=False)

# Summarise what was written
print(f"\n✅ Saved stratified DTI drug train and test splits to:")
print(f"- ./data/DrugQA/train_samples.csv ({len(dti_train_samples_sorted)} samples)")
print(f"- ./data/DrugQA/test_samples.csv ({len(dti_test_samples_sorted)} samples)")
print(f"- ./data/DrugScreen/train_samples.csv ({len(dti_train_samples_sorted)} samples)")
print(f"- ./data/DrugScreen/test_samples.csv ({len(dti_test_samples_sorted)} samples)")
print(f"- Detailed sample information with TCGA annotations also saved")

# Summary of all splits: the DTI samples are divided into the target set and the remaining drug set
print(f"\n📈 Complete Sample Split Summary:")
print(f"Original DTI overlapped samples: {len(dti_overlapped_samples)}")
print(f"Target identification samples (CRISPR/RNA): {len(overlapped_dti_crispr_rna_samples)}")
print(f"Drug screening samples (remaining): {len(remaining_dti_samples)}")
print(f"  - DTI Drug Train: {len(dti_train_samples)}")
print(f"  - DTI Drug Test: {len(dti_test_samples)}")

### 9.2 Data Integration

#### 9.2.0 Data Preparation

In [None]:
# Relabel the columns of the methylation and protein tables with their respective lookup dictionaries
maped_methy_df = final_merged_methy_df.rename(columns=methy_map_dict)
mapped_protein_df = final_merged_protein_df.rename(columns=protein_map_dict)
display(maped_methy_df)
display(mapped_protein_df)
# Sanity checks over every column except the first: grand total, and whether every entry equals zero
print("Sum of all values in mapped_protein_df:", mapped_protein_df.iloc[:, 1:].sum().sum())
print("Are all values in mapped_protein_df zero?", mapped_protein_df.iloc[:, 1:].eq(0).all().all())

#### 9.2.0 All tasks share the same edge_index and nodes_index system

In [None]:
# Build the node index table and save it to every task folder
nodes = bmgc_entity_df['BioMedGraphica_Conn_ID'].tolist()
# create nodes index ranging from 0 to len(nodes)-1 (position of the entity in bmgc_entity_df)
nodes_index = np.arange(len(nodes))
nodes_index_data = pd.DataFrame({'Node': nodes, 'Index': nodes_index})
# Attach the entity Type to every node; the duplicated join key is dropped afterwards
nodes_index_data = pd.merge(nodes_index_data, bmgc_entity_df[['BioMedGraphica_Conn_ID', 'Type']], left_on='Node', right_on='BioMedGraphica_Conn_ID', how='left').drop(columns=['BioMedGraphica_Conn_ID'])
display(nodes_index_data)
# The same node index file is written to all six task folders
nodes_index_data.to_csv('./data/pretrain_plain_data/nodes_index.csv', index=False)
nodes_index_data.to_csv('./data/pretrain_status_data/nodes_index.csv', index=False)
nodes_index_data.to_csv('./data/TargetQA/nodes_index.csv', index=False)
nodes_index_data.to_csv('./data/TargetScreen/nodes_index.csv', index=False)
nodes_index_data.to_csv('./data/DrugQA/nodes_index.csv', index=False)
nodes_index_data.to_csv('./data/DrugScreen/nodes_index.csv', index=False)
# Lookup from entity ID to integer node index, used to translate the relation tables below
node_index_dict = dict(zip(nodes_index_data['Node'], nodes_index_data['Index']))
# Print the first 10 items of node_index_dict
print("\nFirst 10 items in node_index_dict:")
print(list(node_index_dict.items())[:10])

# Convert bmgc_relation_df to a numpy array and save it to every task folder as edge_index.npy
# keep the columns ['BMGC_From_ID', 'BMGC_To_ID']
edge_index_df = bmgc_relation_df[['BMGC_From_ID', 'BMGC_To_ID']].copy()
# Map the BMGC_From_ID and BMGC_To_ID by the node_index_dict
edge_index_df['BMGC_From_ID'] = edge_index_df['BMGC_From_ID'].map(node_index_dict)
edge_index_df['BMGC_To_ID'] = edge_index_df['BMGC_To_ID'].map(node_index_dict)
# check the null values in the edge_index_df (IDs missing from node_index_dict map to NaN)
print("Null values in edge_index_df:")
print(edge_index_df.isnull().sum())
display(edge_index_df)
# Transpose so the array has one row of source indices and one row of destination indices
# NOTE(review): if the null counts above are non-zero the array presumably becomes float with NaN entries;
# the counts are only printed, not enforced - confirm they are 0 before using the saved files
edge_index_array = edge_index_df.to_numpy().T
print('The shape of edge_index_array is:', edge_index_array.shape)
print(edge_index_array)
# save the numpy array to every task folder
np.save('./data/pretrain_plain_data/edge_index.npy', edge_index_array)
np.save('./data/pretrain_status_data/edge_index.npy', edge_index_array)
np.save('./data/TargetQA/edge_index.npy', edge_index_array)
np.save('./data/TargetScreen/edge_index.npy', edge_index_array)
np.save('./data/DrugQA/edge_index.npy', edge_index_array)
np.save('./data/DrugScreen/edge_index.npy', edge_index_array)
# generate the internal_edge_index by selecting bmgc_relation_df in Type ['Promoter-Gene', 'Gene-Transcript', 'Transcript-Protein']
internal_edge_index_df = bmgc_relation_df[bmgc_relation_df['Type'].isin(['Promoter-Gene', 'Gene-Transcript', 'Transcript-Protein'])].copy()
# keep the columns ['BMGC_From_ID', 'BMGC_To_ID']
internal_edge_index_df = internal_edge_index_df[['BMGC_From_ID', 'BMGC_To_ID']].copy()
# Map the BMGC_From_ID and BMGC_To_ID by the node_index_dict
internal_edge_index_df['BMGC_From_ID'] = internal_edge_index_df['BMGC_From_ID'].map(node_index_dict)
internal_edge_index_df['BMGC_To_ID'] = internal_edge_index_df['BMGC_To_ID'].map(node_index_dict)
# check the null values in the internal_edge_index_df
print("Null values in internal_edge_index_df:")
print(internal_edge_index_df.isnull().sum())
# convert the internal_edge_index_df to a numpy array with one row per edge end
internal_edge_index_array = internal_edge_index_df.to_numpy().T
print('The shape of internal_edge_index_array is:', internal_edge_index_array.shape)
print(internal_edge_index_array)
# save the numpy array to every task folder
np.save('./data/pretrain_plain_data/internal_edge_index.npy', internal_edge_index_array)
np.save('./data/pretrain_status_data/internal_edge_index.npy', internal_edge_index_array)
np.save('./data/TargetQA/internal_edge_index.npy', internal_edge_index_array)
np.save('./data/TargetScreen/internal_edge_index.npy', internal_edge_index_array)
np.save('./data/DrugQA/internal_edge_index.npy', internal_edge_index_array)
np.save('./data/DrugScreen/internal_edge_index.npy', internal_edge_index_array)
# generate the ppi_edge_index by selecting bmgc_relation_df in Type ['Protein-Protein']
ppi_edge_index_df = bmgc_relation_df[bmgc_relation_df['Type'].isin(['Protein-Protein'])].copy()
# keep the columns ['BMGC_From_ID', 'BMGC_To_ID']
ppi_edge_index_df = ppi_edge_index_df[['BMGC_From_ID', 'BMGC_To_ID']].copy()
# Map the BMGC_From_ID and BMGC_To_ID by the node_index_dict
ppi_edge_index_df['BMGC_From_ID'] = ppi_edge_index_df['BMGC_From_ID'].map(node_index_dict)
ppi_edge_index_df['BMGC_To_ID'] = ppi_edge_index_df['BMGC_To_ID'].map(node_index_dict)
# check the null values in the ppi_edge_index_df
print("Null values in ppi_edge_index_df:")
print(ppi_edge_index_df.isnull().sum())
# convert the ppi_edge_index_df to a numpy array with one row per edge end
ppi_edge_index_array = ppi_edge_index_df.to_numpy().T
print('The shape of ppi_edge_index_array is:', ppi_edge_index_array.shape)
print(ppi_edge_index_array)
# save the numpy array to every task folder
np.save('./data/pretrain_plain_data/ppi_edge_index.npy', ppi_edge_index_array)
np.save('./data/pretrain_status_data/ppi_edge_index.npy', ppi_edge_index_array)
np.save('./data/TargetQA/ppi_edge_index.npy', ppi_edge_index_array)
np.save('./data/TargetScreen/ppi_edge_index.npy', ppi_edge_index_array)
np.save('./data/DrugQA/ppi_edge_index.npy', ppi_edge_index_array)
np.save('./data/DrugScreen/ppi_edge_index.npy', ppi_edge_index_array)

#### 9.2.0 All tasks share the same textual description

In [None]:
# LLM name/ID table of every entity type, keyed by the variable it is bound to below
name_id_csv_by_frame_name = {
    'bmgc_promoter_name_id_df': './data/BioMedGraphica-Conn/Entity/Promoter/BioMedGraphica_Conn_Promoter_LLM_Name_ID_Combined.csv',
    'bmgc_gene_name_id_df': './data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene_LLM_Name_ID_Combined.csv',
    'bmgc_transcript_name_id_df': './data/BioMedGraphica-Conn/Entity/Transcript/BioMedGraphica_Conn_Transcript_LLM_Name_ID_Combined.csv',
    'bmgc_protein_name_id_df': './data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein_LLM_Name_ID_Combined.csv',
    'bmgc_pathway_name_id_df': './data/BioMedGraphica-Conn/Entity/Pathway/BioMedGraphica_Conn_Pathway_LLM_Name_ID_Combined.csv',
    'bmgc_metabolite_name_id_df': './data/BioMedGraphica-Conn/Entity/Metabolite/BioMedGraphica_Conn_Metabolite_LLM_Name_ID_Combined.csv',
    'bmgc_microbiota_name_id_df': './data/BioMedGraphica-Conn/Entity/Microbiota/BioMedGraphica_Conn_Microbiota_LLM_Name_ID_Combined.csv',
    'bmgc_exposure_name_id_df': './data/BioMedGraphica-Conn/Entity/Exposure/BioMedGraphica_Conn_Exposure_LLM_Name_ID_Combined.csv',
    'bmgc_phenotype_name_id_df': './data/BioMedGraphica-Conn/Entity/Phenotype/BioMedGraphica_Conn_Phenotype_LLM_Name_ID_Combined.csv',
    'bmgc_disease_name_id_df': './data/BioMedGraphica-Conn/Entity/Disease/BioMedGraphica_Conn_Disease_LLM_Name_ID_Combined.csv',
    'bmgc_drug_name_id_df': './data/BioMedGraphica-Conn/Entity/Drug/BioMedGraphica_Conn_Drug_LLM_Name_ID_Combined.csv',
}
# Read every table and discard its BioMedGraphica_ID column
name_id_df_by_frame_name = {
    frame_name: pd.read_csv(csv_path).drop(columns=['BioMedGraphica_ID'])
    for frame_name, csv_path in name_id_csv_by_frame_name.items()
}
# Bind each table to its own variable, as later cells refer to them by name
bmgc_promoter_name_id_df = name_id_df_by_frame_name['bmgc_promoter_name_id_df']
bmgc_gene_name_id_df = name_id_df_by_frame_name['bmgc_gene_name_id_df']
bmgc_transcript_name_id_df = name_id_df_by_frame_name['bmgc_transcript_name_id_df']
bmgc_protein_name_id_df = name_id_df_by_frame_name['bmgc_protein_name_id_df']
bmgc_pathway_name_id_df = name_id_df_by_frame_name['bmgc_pathway_name_id_df']
bmgc_metabolite_name_id_df = name_id_df_by_frame_name['bmgc_metabolite_name_id_df']
bmgc_microbiota_name_id_df = name_id_df_by_frame_name['bmgc_microbiota_name_id_df']
bmgc_exposure_name_id_df = name_id_df_by_frame_name['bmgc_exposure_name_id_df']
bmgc_phenotype_name_id_df = name_id_df_by_frame_name['bmgc_phenotype_name_id_df']
bmgc_disease_name_id_df = name_id_df_by_frame_name['bmgc_disease_name_id_df']
bmgc_drug_name_id_df = name_id_df_by_frame_name['bmgc_drug_name_id_df']
# Row count of each table, then the grand total over all of them
for frame_name, name_id_df in name_id_df_by_frame_name.items():
    print(f"{frame_name}:", len(name_id_df))
print("Total number of rows in all dataframes:", sum(len(name_id_df) for name_id_df in name_id_df_by_frame_name.values()))

In [None]:
# Start from the ID column of the full entity table
bmgc_entity_cp_df = bmgc_entity_df[['BioMedGraphica_Conn_ID']].copy()
display(bmgc_entity_cp_df)
# Stack the per-entity name/ID tables into one lookup table with a fresh 0..n-1 index
name_id_frames = [
    bmgc_promoter_name_id_df,
    bmgc_gene_name_id_df,
    bmgc_transcript_name_id_df,
    bmgc_protein_name_id_df,
    bmgc_pathway_name_id_df,
    bmgc_metabolite_name_id_df,
    bmgc_microbiota_name_id_df,
    bmgc_exposure_name_id_df,
    bmgc_phenotype_name_id_df,
    bmgc_disease_name_id_df,
    bmgc_drug_name_id_df,
]
bmgc_name_id_tmp_df = pd.concat(name_id_frames, ignore_index=True)
display(bmgc_name_id_tmp_df)
# Left join on BioMedGraphica_Conn_ID so that every row of bmgc_entity_df is kept
bmgc_entity_cp_df = bmgc_entity_cp_df.merge(bmgc_name_id_tmp_df, on='BioMedGraphica_Conn_ID', how='left')
display(bmgc_entity_cp_df)
# Write the same name table into every task folder
for bmgc_name_path in (
    './data/pretrain_plain_data/bmgc_name.csv',
    './data/pretrain_status_data/bmgc_name.csv',
    './data/TargetScreen/bmgc_name.csv',
    './data/TargetQA/bmgc_name.csv',
    './data/DrugScreen/bmgc_name.csv',
    './data/DrugQA/bmgc_name.csv',
):
    bmgc_entity_cp_df.to_csv(bmgc_name_path, index=False)

In [None]:
# Promoter and microbiota descriptions are not read from a file here: take their name tables
# without the Names_and_IDs column and add a Description column that is entirely NaN
bmgc_promoter_desc_df = bmgc_promoter_name_id_df.drop(columns=['Names_and_IDs']).assign(Description=np.nan)
bmgc_microbiota_desc_df = bmgc_microbiota_name_id_df.drop(columns=['Names_and_IDs']).assign(Description=np.nan)
# Description tables that are read from disk, keyed by the variable each one is bound to below
desc_csv_by_frame_name = {
    'bmgc_gene_desc_df': './data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene_Description_Combined.csv',
    'bmgc_transcript_desc_df': './data/BioMedGraphica-Conn/Entity/Transcript/BioMedGraphica_Conn_Transcript_Description_Combined.csv',
    'bmgc_protein_desc_df': './data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein_Description_Combined.csv',
    'bmgc_pathway_desc_df': './data/BioMedGraphica-Conn/Entity/Pathway/BioMedGraphica_Conn_Pathway_Description_Combined.csv',
    'bmgc_metabolite_desc_df': './data/BioMedGraphica-Conn/Entity/Metabolite/BioMedGraphica_Conn_Metabolite_Description_Combined.csv',
    'bmgc_exposure_desc_df': './data/BioMedGraphica-Conn/Entity/Exposure/BioMedGraphica_Conn_Exposure_Description_Combined.csv',
    'bmgc_phenotype_desc_df': './data/BioMedGraphica-Conn/Entity/Phenotype/BioMedGraphica_Conn_Phenotype_Description_Combined.csv',
    'bmgc_disease_desc_df': './data/BioMedGraphica-Conn/Entity/Disease/BioMedGraphica_Conn_Disease_Description_Combined.csv',
    'bmgc_drug_desc_df': './data/BioMedGraphica-Conn/Entity/Drug/BioMedGraphica_Conn_Drug_Description_Combined.csv',
}
# Read every table and discard its BioMedGraphica_ID column
desc_df_by_frame_name = {
    frame_name: pd.read_csv(csv_path).drop(columns=['BioMedGraphica_ID'])
    for frame_name, csv_path in desc_csv_by_frame_name.items()
}
bmgc_gene_desc_df = desc_df_by_frame_name['bmgc_gene_desc_df']
bmgc_transcript_desc_df = desc_df_by_frame_name['bmgc_transcript_desc_df']
bmgc_protein_desc_df = desc_df_by_frame_name['bmgc_protein_desc_df']
bmgc_pathway_desc_df = desc_df_by_frame_name['bmgc_pathway_desc_df']
bmgc_metabolite_desc_df = desc_df_by_frame_name['bmgc_metabolite_desc_df']
bmgc_exposure_desc_df = desc_df_by_frame_name['bmgc_exposure_desc_df']
bmgc_phenotype_desc_df = desc_df_by_frame_name['bmgc_phenotype_desc_df']
bmgc_disease_desc_df = desc_df_by_frame_name['bmgc_disease_desc_df']
bmgc_drug_desc_df = desc_df_by_frame_name['bmgc_drug_desc_df']

In [None]:
# Stack the per-entity description tables into one lookup table with a fresh 0..n-1 index
desc_frames = [
    bmgc_promoter_desc_df,
    bmgc_gene_desc_df,
    bmgc_transcript_desc_df,
    bmgc_protein_desc_df,
    bmgc_pathway_desc_df,
    bmgc_metabolite_desc_df,
    bmgc_microbiota_desc_df,
    bmgc_exposure_desc_df,
    bmgc_phenotype_desc_df,
    bmgc_disease_desc_df,
    bmgc_drug_desc_df,
]
bmgc_desc_tmp_df = pd.concat(desc_frames, ignore_index=True)
display(bmgc_desc_tmp_df)
# Left join the descriptions onto the ID column of the full entity table
bmgc_entity_cp_df = bmgc_entity_df[['BioMedGraphica_Conn_ID']].copy()
bmgc_desc_df = bmgc_entity_cp_df.merge(bmgc_desc_tmp_df, on='BioMedGraphica_Conn_ID', how='left')
# Null counts before filling
print("Null values in bmgc_desc_df:")
print(bmgc_desc_df.isnull().sum())
# Missing descriptions are replaced by a single space (not an empty string)
bmgc_desc_df['Description'] = bmgc_desc_df['Description'].fillna(' ')
# Null counts after filling
print("Null values in bmgc_desc_df:")
print(bmgc_desc_df.isnull().sum())
display(bmgc_desc_df)
# Write the same description table into every task folder
for bmgc_desc_path in (
    './data/pretrain_plain_data/bmgc_desc.csv',
    './data/pretrain_status_data/bmgc_desc.csv',
    './data/TargetScreen/bmgc_desc.csv',
    './data/TargetQA/bmgc_desc.csv',
    './data/DrugScreen/bmgc_desc.csv',
    './data/DrugQA/bmgc_desc.csv',
):
    bmgc_desc_df.to_csv(bmgc_desc_path, index=False)

#### 9.2.1 Pretrain data integration (plain + status)

In [None]:
print("The overlapped samples dataframe for methylation between pretrain_plain_samples and pretrain_status_samples are:")
# Methylation, plain samples: add an all-zero column for every sample that has none yet
for sample in pretrain_plain_samples:
    if sample in maped_methy_df.columns:
        continue
    maped_methy_df[sample] = 0.0
# Keep the ID column followed by the plain samples, in list order
pretrain_plain_methy_df = maped_methy_df[['BioMedGraphica_Conn_ID', *pretrain_plain_samples]]
display(pretrain_plain_methy_df)
# Methylation, status samples: same zero-filling, then the same column selection
for sample in pretrain_status_samples:
    if sample in maped_methy_df.columns:
        continue
    maped_methy_df[sample] = 0.0
pretrain_status_methy_df = maped_methy_df[['BioMedGraphica_Conn_ID', *pretrain_status_samples]]
display(pretrain_status_methy_df)

print("The overlapped samples dataframe for protein between pretrain_plain_samples and pretrain_status_samples are:")
# Protein, plain samples: add an all-zero column for every sample that has none yet
for sample in pretrain_plain_samples:
    if sample in mapped_protein_df.columns:
        continue
    mapped_protein_df[sample] = 0.0
# Keep the ID column followed by the plain samples, in list order
pretrain_plain_protein_df = mapped_protein_df[['BioMedGraphica_Conn_ID', *pretrain_plain_samples]]
display(pretrain_plain_protein_df)
# Protein, status samples: same zero-filling, then the same column selection
for sample in pretrain_status_samples:
    if sample in mapped_protein_df.columns:
        continue
    mapped_protein_df[sample] = 0.0
pretrain_status_protein_df = mapped_protein_df[['BioMedGraphica_Conn_ID', *pretrain_status_samples]]
display(pretrain_status_protein_df)

In [None]:
# Plain samples: give the gene and transcript tables an all-zero column for every sample they lack
for sample in pretrain_plain_samples:
    for omics_df in (final_merged_gene_df, final_merged_transcript_df):
        if sample not in omics_df.columns:
            omics_df[sample] = 0.0
# Keep the ID column followed by the plain samples, in list order
pretrain_plain_gene_df = final_merged_gene_df[['BioMedGraphica_Conn_ID', *pretrain_plain_samples]]
display(pretrain_plain_gene_df)
pretrain_plain_transcript_df = final_merged_transcript_df[['BioMedGraphica_Conn_ID', *pretrain_plain_samples]]
display(pretrain_plain_transcript_df)

# Status samples: same zero-filling for both tables
for sample in pretrain_status_samples:
    for omics_df in (final_merged_gene_df, final_merged_transcript_df):
        if sample not in omics_df.columns:
            omics_df[sample] = 0.0
# Keep the ID column followed by the status samples, in list order
pretrain_status_gene_df = final_merged_gene_df[['BioMedGraphica_Conn_ID', *pretrain_status_samples]]
display(pretrain_status_gene_df)
pretrain_status_transcript_df = final_merged_transcript_df[['BioMedGraphica_Conn_ID', *pretrain_status_samples]]
display(pretrain_status_transcript_df)

In [None]:
# Plain samples: stack the methylation, gene, transcript and protein tables row-wise
pretrain_plain_omics_df = pd.concat(
    [pretrain_plain_methy_df, pretrain_plain_gene_df, pretrain_plain_transcript_df, pretrain_plain_protein_df],
    axis=0,
    ignore_index=True,
)
display(pretrain_plain_omics_df)
# Left join onto the full entity table so every entity has a row; rows without omics values become 0.0
pretrain_plain_feat_df = (
    bmgc_entity_df
    .merge(pretrain_plain_omics_df, left_on='BioMedGraphica_Conn_ID', right_on='BioMedGraphica_Conn_ID', how='left')
    .drop(columns=['BioMedGraphica_ID', 'Type'])
    .fillna(0.0)
)
display(pretrain_plain_feat_df)
# Status samples: same stacking and alignment
pretrain_status_omics_df = pd.concat(
    [pretrain_status_methy_df, pretrain_status_gene_df, pretrain_status_transcript_df, pretrain_status_protein_df],
    axis=0,
    ignore_index=True,
)
display(pretrain_status_omics_df)
pretrain_status_feat_df = (
    bmgc_entity_df
    .merge(pretrain_status_omics_df, left_on='BioMedGraphica_Conn_ID', right_on='BioMedGraphica_Conn_ID', how='left')
    .drop(columns=['BioMedGraphica_ID', 'Type'])
    .fillna(0.0)
)
display(pretrain_status_feat_df)

In [None]:
# Drop the ID column and transpose, so each array row is one remaining (sample) column
# and each array column is one row of the feature table
pretrain_plain_values_df = pretrain_plain_feat_df.drop(columns=['BioMedGraphica_Conn_ID'])
pretrain_status_values_df = pretrain_status_feat_df.drop(columns=['BioMedGraphica_Conn_ID'])
pretrain_plain_array = pretrain_plain_values_df.to_numpy().T
pretrain_status_array = pretrain_status_values_df.to_numpy().T
print("Shape of pretrain_plain_array:", pretrain_plain_array.shape)
print("Shape of pretrain_status_array:", pretrain_status_array.shape)
# Persist both arrays as .npy files
np.save('./data/pretrain_plain_data/pretrain_plain_feature.npy', pretrain_plain_array)
np.save('./data/pretrain_status_data/pretrain_status_feature.npy', pretrain_status_array)

##### 9.2.1.1 Pretrain Status

In [None]:
# Binary label per row of pretrain_status_df: 1 when cancerous_status is 'cancerous', otherwise 0
pretrain_status_label = np.array([int(status == 'cancerous') for status in pretrain_status_df['cancerous_status']])
np.save('./data/pretrain_status_data/pretrain_status_label.npy', pretrain_status_label)
print("Shape of pretrain_status_label:", pretrain_status_label.shape)

# Labels of the train and test splits, encoded the same way
pretrain_status_train_label = np.array([int(status == 'cancerous') for status in pretrain_status_train_df['cancerous_status']])
pretrain_status_test_label = np.array([int(status == 'cancerous') for status in pretrain_status_test_df['cancerous_status']])
# Features of the train and test splits: pick the split's sample columns and transpose
pretrain_status_train_feat = pretrain_status_feat_df[pretrain_status_train_samples].to_numpy().T
pretrain_status_test_feat = pretrain_status_feat_df[pretrain_status_test_samples].to_numpy().T

# Persist features first, then labels
np.save('./data/pretrain_status_data/pretrain_status_train_feature.npy', pretrain_status_train_feat)
np.save('./data/pretrain_status_data/pretrain_status_test_feature.npy', pretrain_status_test_feat)
np.save('./data/pretrain_status_data/pretrain_status_train_label.npy', pretrain_status_train_label)
np.save('./data/pretrain_status_data/pretrain_status_test_label.npy', pretrain_status_test_label)
print("Shape of pretrain_status_train_feat:", pretrain_status_train_feat.shape)
print("Shape of pretrain_status_test_feat:", pretrain_status_test_feat.shape)
print("Shape of pretrain_status_train_label:", pretrain_status_train_label.shape)
print("Shape of pretrain_status_test_label:", pretrain_status_test_label.shape)
# Class balance of the saved train and test labels
print(f"\n📈 Pretrain Status Training Set Distribution (Numpy Files):\n"
      f"Cancerous: {np.sum(pretrain_status_train_label == 1)}, Non-cancerous: {np.sum(pretrain_status_train_label == 0)}")
print(f"📈 Pretrain Status Test Set Distribution (Numpy Files):\n"
      f"Cancerous: {np.sum(pretrain_status_test_label == 1)}, Non-cancerous: {np.sum(pretrain_status_test_label == 0)}")

#### 9.2.2 Target data integration

In [None]:
print("The overlapped samples dataframe for methylation between overlapped_dti_crispr_rna_samples:")
# Methylation: add an all-zero column for every overlapped sample that has none yet
for sample in overlapped_dti_crispr_rna_samples:
    if sample in maped_methy_df.columns:
        continue
    maped_methy_df[sample] = 0.0
# Keep the ID column followed by the overlapped samples, in list order
target_crispr_methy_df = maped_methy_df[['BioMedGraphica_Conn_ID', *overlapped_dti_crispr_rna_samples]]
display(target_crispr_methy_df)

print("The overlapped samples dataframe for protein between overlapped_dti_crispr_rna_samples are:")
# Protein: same zero-filling, then the same column selection
for sample in overlapped_dti_crispr_rna_samples:
    if sample in mapped_protein_df.columns:
        continue
    mapped_protein_df[sample] = 0.0
target_crispr_protein_df = mapped_protein_df[['BioMedGraphica_Conn_ID', *overlapped_dti_crispr_rna_samples]]
display(target_crispr_protein_df)

In [None]:
# Give the gene and transcript tables an all-zero column for every overlapped sample they lack
for sample in overlapped_dti_crispr_rna_samples:
    for omics_df in (final_merged_gene_df, final_merged_transcript_df):
        if sample not in omics_df.columns:
            omics_df[sample] = 0.0
# Keep the ID column followed by the overlapped samples, in list order
target_crispr_gene_df = final_merged_gene_df[['BioMedGraphica_Conn_ID', *overlapped_dti_crispr_rna_samples]]
display(target_crispr_gene_df)
target_crispr_transcript_df = final_merged_transcript_df[['BioMedGraphica_Conn_ID', *overlapped_dti_crispr_rna_samples]]
display(target_crispr_transcript_df)

In [None]:
# Stack the methylation, gene, transcript and protein tables row-wise
target_crispr_omics_df = pd.concat(
    [target_crispr_methy_df, target_crispr_gene_df, target_crispr_transcript_df, target_crispr_protein_df],
    axis=0,
    ignore_index=True,
)
display(target_crispr_omics_df)
# Left join onto the full entity table so every entity has a row; rows without omics values become 0.0
target_crispr_df = (
    bmgc_entity_df
    .merge(target_crispr_omics_df, left_on='BioMedGraphica_Conn_ID', right_on='BioMedGraphica_Conn_ID', how='left')
    .drop(columns=['BioMedGraphica_ID', 'Type'])
    .fillna(0.0)
)
display(target_crispr_df)

##### 9.2.2.1 Target Screen

In [None]:
# Pick the sample columns of each split from target_crispr_df and transpose,
# so each array row is one sample of that split
target_crispr_train_feat = target_crispr_df[target_crispr_train_samples].to_numpy().T
target_crispr_test_feat = target_crispr_df[target_crispr_test_samples].to_numpy().T
# Persist both split arrays as .npy files
np.save('./data/TargetScreen/target_crispr_train_feature.npy', target_crispr_train_feat)
np.save('./data/TargetScreen/target_crispr_test_feature.npy', target_crispr_test_feat)
# Report the array shapes
print("Shape of target_crispr_train_feat:", target_crispr_train_feat.shape)
print("Shape of target_crispr_test_feat:", target_crispr_test_feat.shape)

In [None]:
# Everything in target_crispr_df except the ID column, transposed
target_crispr_values_df = target_crispr_df.drop(columns=['BioMedGraphica_Conn_ID'])
target_crispr_feat = target_crispr_values_df.to_numpy().T
# Persist the array for the TargetScreen task
np.save('./data/TargetScreen/target_crispr_feature.npy', target_crispr_feat)
# Report the array shape
print("Shape of target_crispr_feat:", target_crispr_feat.shape)

In [None]:
# Build up internal relation for KO drop:
# split the omics relation table by relation type and name the endpoints after what they hold
relation_type = bmgc_omics_relation_df['Type']
bmgc_promoter_gene_df = (
    bmgc_omics_relation_df[relation_type == 'Promoter-Gene']
    .drop(columns=['Type'])
    .rename(columns={'BMGC_From_ID': 'promoterID', 'BMGC_To_ID': 'geneID'})
)
bmgc_gene_transcript_df = (
    bmgc_omics_relation_df[relation_type == 'Gene-Transcript']
    .drop(columns=['Type'])
    .rename(columns={'BMGC_From_ID': 'geneID', 'BMGC_To_ID': 'transcriptID'})
)
bmgc_transcript_protein_df = (
    bmgc_omics_relation_df[relation_type == 'Transcript-Protein']
    .drop(columns=['Type'])
    .rename(columns={'BMGC_From_ID': 'transcriptID', 'BMGC_To_ID': 'proteinID'})
)
# Outer join transcript-protein with gene-transcript on the shared transcript
bmgc_gene_transcript_protein_df = pd.merge(bmgc_transcript_protein_df, bmgc_gene_transcript_df, on='transcriptID', how='outer')
display(bmgc_gene_transcript_protein_df)
# Outer join the promoter-gene relation on the shared gene
bmgc_promoter_gene_transcript_protein_df = pd.merge(bmgc_gene_transcript_protein_df, bmgc_promoter_gene_df, on='geneID', how='outer')
display(bmgc_promoter_gene_transcript_protein_df)
# Keep only rows that have a gene, with the columns in promoter -> gene -> transcript -> protein order
internal_relation_df = (
    bmgc_promoter_gene_transcript_protein_df
    .dropna(subset=['geneID'])
    .loc[:, ['promoterID', 'geneID', 'transcriptID', 'proteinID']]
    .copy()
)
display(internal_relation_df)

In [None]:
raw_crispr_df = pd.read_csv('./data/raw_data/CRISPRGeneEffect.csv')
# Whatever the first column is called in the file, expose it as 'Sample'
# (it is the key that is later matched against depMapID)
first_column_name = raw_crispr_df.columns[0]
raw_crispr_df = raw_crispr_df.rename(columns={first_column_name: 'Sample'})
display(raw_crispr_df)

In [None]:
# One-column frames with the sample IDs of each split, named 'Sample' to match raw_crispr_df
train_target_samples_id_df = train_target_samples_df[['depMapID']].rename(columns={'depMapID': 'Sample'})
test_target_samples_id_df = test_target_samples_df[['depMapID']].rename(columns={'depMapID': 'Sample'})
# Inner join: keep only the CRISPR rows whose Sample belongs to the split
train_crispr_score_df = pd.merge(raw_crispr_df, train_target_samples_id_df, on="Sample", how="inner")
display(train_crispr_score_df)
test_crispr_score_df = pd.merge(raw_crispr_df, test_target_samples_id_df, on="Sample", how="inner")
display(test_crispr_score_df)

# Transpose so that rows are the former column labels (exposed as HGNC_Symbol) and columns are samples
train_crispr_score_t_df = (
    train_crispr_score_df
    .set_index("Sample")
    .T
    .reset_index()
    .rename(columns={"index": "HGNC_Symbol"})
)
display(train_crispr_score_t_df)
test_crispr_score_t_df = (
    test_crispr_score_df
    .set_index("Sample")
    .T
    .reset_index()
    .rename(columns={"index": "HGNC_Symbol"})
)
display(test_crispr_score_t_df)

In [None]:
import re
def extract_gene_name(gene):
    """Return ``str(gene)`` with every parenthesised group removed.

    Any whitespace directly in front of a group is removed with it, e.g.
    ``"SOX2 (6657)"`` becomes ``"SOX2"``. Non-string input is converted
    with ``str`` first.
    """
    parenthesised_group = re.compile(r"\s*\(.*?\)")
    return parenthesised_group.sub("", str(gene))

# Columns of bmgc_gene_df that are dropped once the CRISPR scores are attached
gene_annotation_columns = ['BioMedGraphica_ID','NCBI_Gene_ID','Ensembl_Gene_ID_Version','Gene_Start', 'Gene_End','Chromosome','Gene_Type','Ensembl_Gene_ID','HGNC_ID','Gene_Name','RefSeq_ID','OMIM_ID','HGNC_Symbol']

# --- Train split ---
# Strip the parenthesised part of each label so that it can be compared with HGNC_Symbol
train_crispr_score_t_df["HGNC_Symbol"] = train_crispr_score_t_df["HGNC_Symbol"].apply(extract_gene_name)
unique_Gene_map = bmgc_gene_df[['HGNC_Symbol']].drop_duplicates()
unique_omics_Gene = train_crispr_score_t_df[['HGNC_Symbol']].drop_duplicates()
# Share of the distinct bmgc_gene_df symbols that also occur in the CRISPR table
intersection = set(unique_Gene_map['HGNC_Symbol']).intersection(unique_omics_Gene['HGNC_Symbol'])
total_Gene = len(unique_Gene_map)
match_rate = 0 if total_Gene <= 0 else len(intersection) / total_Gene
print(f"match rate:{match_rate:.2%}")
# Inner join on the symbol, then drop the annotation columns
train_crispr_score_bmgc_df = (
    pd.merge(bmgc_gene_df, train_crispr_score_t_df, on="HGNC_Symbol", how="inner")
    .drop(columns=gene_annotation_columns)
)
display(train_crispr_score_bmgc_df)

# --- Test split: same steps ---
test_crispr_score_t_df["HGNC_Symbol"] = test_crispr_score_t_df["HGNC_Symbol"].apply(extract_gene_name)
unique_Gene_map = bmgc_gene_df[['HGNC_Symbol']].drop_duplicates()
unique_omics_Gene = test_crispr_score_t_df[['HGNC_Symbol']].drop_duplicates()
intersection = set(unique_Gene_map['HGNC_Symbol']).intersection(unique_omics_Gene['HGNC_Symbol'])
total_Gene = len(unique_Gene_map)
match_rate = 0 if total_Gene <= 0 else len(intersection) / total_Gene
print(f"match rate:{match_rate:.2%}")
test_crispr_score_bmgc_df = (
    pd.merge(bmgc_gene_df, test_crispr_score_t_df, on="HGNC_Symbol", how="inner")
    .drop(columns=gene_annotation_columns)
)
display(test_crispr_score_bmgc_df)

In [None]:
# Wide -> long: one row per (BioMedGraphica_Conn_ID, ACH_ID) pair with its Value; rows without a Value are dropped
train_crispr_score_label = train_crispr_score_bmgc_df.copy()
train_crispr_score_label_melted = (
    train_crispr_score_label
    .melt(id_vars=["BioMedGraphica_Conn_ID"], var_name="ACH_ID", value_name="Value")
    .dropna(subset=['Value'])
)
display(train_crispr_score_label_melted)
# Same reshaping for the test split
test_crispr_score_label = test_crispr_score_bmgc_df.copy()
test_crispr_score_label_melted = (
    test_crispr_score_label
    .melt(id_vars=["BioMedGraphica_Conn_ID"], var_name="ACH_ID", value_name="Value")
    .dropna(subset=['Value'])
)
display(test_crispr_score_label_melted)

##### Merge into internal relation

In [None]:
# Lookup from depMapID to the index label of its row in overlapped_dti_crispr_rna_samples_df
target_crispr_sample_index_dict = dict(zip(overlapped_dti_crispr_rna_samples_df['depMapID'], overlapped_dti_crispr_rna_samples_df.index))
# Replace the sample IDs (ACH_ID) by their sample index and the gene IDs by their node index.
# Values that are missing from the respective dict become NaN.
train_crispr_score_label_melted['ACH_ID'] = train_crispr_score_label_melted['ACH_ID'].map(target_crispr_sample_index_dict)
train_crispr_score_label_melted['BioMedGraphica_Conn_ID'] = train_crispr_score_label_melted['BioMedGraphica_Conn_ID'].map(node_index_dict)
test_crispr_score_label_melted['ACH_ID'] = test_crispr_score_label_melted['ACH_ID'].map(target_crispr_sample_index_dict)
test_crispr_score_label_melted['BioMedGraphica_Conn_ID'] = test_crispr_score_label_melted['BioMedGraphica_Conn_ID'].map(node_index_dict)

# Replace every ID of the internal relation table by its node index
internal_relation_map_df = internal_relation_df.copy()
for relation_column in ['promoterID', 'geneID', 'transcriptID', 'proteinID']:
    internal_relation_map_df[relation_column] = internal_relation_map_df[relation_column].map(node_index_dict)

ko_internal_relation_map_df = internal_relation_map_df.copy()
display(ko_internal_relation_map_df)
# Collapse to one row per (promoter, gene) pair, collecting the distinct non-null transcript and
# protein indices of that pair into lists.
# dropna=False is required: by default groupby discards every row whose key contains NaN, which
# would silently remove all genes whose promoterID is NaN (no promoter relation, or a promoter
# that is missing from node_index_dict) together with their transcripts and proteins.
ko_internal_relation_index_df = ko_internal_relation_map_df.groupby(
    ["promoterID", "geneID"], as_index=False, dropna=False
).agg({
    "transcriptID": lambda x: list(set(x.dropna())),
    "proteinID": lambda x: list(set(x.dropna()))
})

In [None]:
# Define a function to remove .0 suffix from numbers in lists
def remove_decimal_suffix(list_with_decimals):
    """Cast the non-missing entries of a list to ``int``.

    Entries for which ``pd.isna`` is true are kept as they are. Anything
    that is not a ``list`` is returned unchanged.
    """
    if not isinstance(list_with_decimals, list):
        return list_with_decimals
    converted = []
    for entry in list_with_decimals:
        converted.append(entry if pd.isna(entry) else int(entry))
    return converted

# Cast the entries of both list columns to int
for list_column in ('transcriptID', 'proteinID'):
    ko_internal_relation_index_df[list_column] = ko_internal_relation_index_df[list_column].apply(remove_decimal_suffix)

# Show the first rows to verify the result
display(ko_internal_relation_index_df.head())

In [None]:
# Attach the promoter / gene / transcript / protein indices to every training label row (left join on the gene index)
train_crispr_score_label_melted_ko = (
    train_crispr_score_label_melted
    .merge(ko_internal_relation_index_df, left_on='BioMedGraphica_Conn_ID', right_on='geneID', how='left')
    .drop(columns=['BioMedGraphica_Conn_ID'])
)
display(train_crispr_score_label_melted_ko)

# merged_ids = [promoterID, geneID] followed by the transcript list and the protein list;
# a transcriptID / proteinID cell that is not a list contributes nothing
train_crispr_score_label_melted_ko['merged_ids'] = train_crispr_score_label_melted_ko.apply(
    lambda row: [
        row['promoterID'],
        row['geneID'],
        *(row['transcriptID'] if isinstance(row['transcriptID'], list) else []),
        *(row['proteinID'] if isinstance(row['proteinID'], list) else []),
    ],
    axis=1
)

# The four source columns are no longer needed once merged_ids exists
train_crispr_score_label_melted_ko = train_crispr_score_label_melted_ko.drop(columns=['promoterID', 'geneID', 'transcriptID', 'proteinID'])
display(train_crispr_score_label_melted_ko)

# Fix the column order to ['ACH_ID', 'merged_ids', 'Value']
train_crispr_score_label_melted_ko = train_crispr_score_label_melted_ko.loc[:, ['ACH_ID', 'merged_ids', 'Value']].copy()
display(train_crispr_score_label_melted_ko)
# Convert to a numpy array (merged_ids holds Python lists)
train_crispr_score_label_melted_ko_array = train_crispr_score_label_melted_ko.to_numpy()
print('The shape of train_crispr_score_label_melted_ko_array is:', train_crispr_score_label_melted_ko_array.shape)
print(train_crispr_score_label_melted_ko_array[:10])  # Print first 10 rows as a sample
# Save the array into the TargetScreen folder
np.save('./data/TargetScreen/train_crispr_score_label_melted_ko.npy', train_crispr_score_label_melted_ko_array)

In [None]:
# Attach the promoter / gene / transcript / protein indices to every test label row (left join on the gene index)
test_crispr_score_label_melted_ko = (
    test_crispr_score_label_melted
    .merge(ko_internal_relation_index_df, left_on='BioMedGraphica_Conn_ID', right_on='geneID', how='left')
    .drop(columns=['BioMedGraphica_Conn_ID'])
)
display(test_crispr_score_label_melted_ko)

# merged_ids = [promoterID, geneID] followed by the transcript list and the protein list;
# a transcriptID / proteinID cell that is not a list contributes nothing
test_crispr_score_label_melted_ko['merged_ids'] = test_crispr_score_label_melted_ko.apply(
    lambda row: [
        row['promoterID'],
        row['geneID'],
        *(row['transcriptID'] if isinstance(row['transcriptID'], list) else []),
        *(row['proteinID'] if isinstance(row['proteinID'], list) else []),
    ],
    axis=1
)

# The four source columns are no longer needed once merged_ids exists
test_crispr_score_label_melted_ko = test_crispr_score_label_melted_ko.drop(columns=['promoterID', 'geneID', 'transcriptID', 'proteinID'])
display(test_crispr_score_label_melted_ko)

# Fix the column order to ['ACH_ID', 'merged_ids', 'Value']
test_crispr_score_label_melted_ko = test_crispr_score_label_melted_ko.loc[:, ['ACH_ID', 'merged_ids', 'Value']].copy()
display(test_crispr_score_label_melted_ko)
# Convert to a numpy array (merged_ids holds Python lists)
test_crispr_score_label_melted_ko_array = test_crispr_score_label_melted_ko.to_numpy()
print('The shape of test_crispr_score_label_melted_ko_array is:', test_crispr_score_label_melted_ko_array.shape)
print(test_crispr_score_label_melted_ko_array[:10])  # Print first 10 rows as a sample
# Save the array into the TargetScreen folder
np.save('./data/TargetScreen/test_crispr_score_label_melted_ko.npy', test_crispr_score_label_melted_ko_array)

##### 9.2.2.2 TargetQA

##### Omic Feature

In [None]:
# Everything in target_crispr_df except the ID column, transposed
target_crispr_values_df = target_crispr_df.drop(columns=['BioMedGraphica_Conn_ID'])
target_crispr_feat = target_crispr_values_df.to_numpy().T
# Persist the array for the TargetQA task
np.save('./data/TargetQA/target_crispr_feature.npy', target_crispr_feat)
# Report the array shape
print("Shape of target_crispr_feat:", target_crispr_feat.shape)

##### Omic Information

In [None]:
# Restrict gene_df to gene_name plus the overlapped samples, with the sample columns in sorted order
sorted_overlapped_samples = sorted(overlapped_dti_crispr_rna_samples)
dti_crispr_rna_gene_df = gene_df[['gene_name', *sorted_overlapped_samples]].copy()
display(dti_crispr_rna_gene_df)
# Protein LLM name/ID table, read with all of its columns
bmgc_protein_llmnameid_combined_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein_LLM_Name_ID_Combined.csv')
display(bmgc_protein_llmnameid_combined_df)

def extract_gn_info(dti_crispr_rna_gene_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k=10):
    """
    Collect the k highest-valued genes of one sample and the proteins linked to them.

    Args:
        dti_crispr_rna_gene_df (pd.DataFrame): Has a 'gene_name' column plus one column per sample.
        bmgc_gene_df (pd.DataFrame): Has 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol' columns.
        gene_transcript_protein_entity_df (pd.DataFrame): Has 'BMGC_GN_ID' and 'BMGC_PT_ID' columns.
        bmgc_protein_llmnameid_combined_df (pd.DataFrame): Has 'BioMedGraphica_Conn_ID',
            'BioMedGraphica_ID' and 'Names_and_IDs' columns.
        sample_ach_name (str): Name of the sample column to rank by.
        k (int): Number of top rows to keep.

    Returns:
        tuple: (gene names in descending value order, linked protein IDs, linked protein
        'Names_and_IDs' strings with ' | ' replaced by ' or '). The two protein lists follow
        the row order produced by the inner merges, not the ranking. If the sample column
        is absent, all three elements are the string "non-existed".
    """
    absent = "non-existed"
    if sample_ach_name not in dti_crispr_rna_gene_df.columns:
        return absent, absent, absent

    # Rank: k largest values, explicitly sorted descending, with a clean 0..k-1 index
    ranked = (
        dti_crispr_rna_gene_df
        .nlargest(k, sample_ach_name)[['gene_name', sample_ach_name]]
        .sort_values(by=sample_ach_name, ascending=False)
        .reset_index(drop=True)
    )
    hgnc_names = ranked['gene_name'].tolist()

    # Gene symbol -> gene entity ID
    gene_ids = bmgc_gene_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].copy()
    ranked_with_ids = gene_ids.merge(ranked, left_on='HGNC_Symbol', right_on='gene_name', how='inner')

    # Gene entity ID -> protein entity ID
    gene_protein = gene_transcript_protein_entity_df.merge(
        ranked_with_ids, left_on='BMGC_GN_ID', right_on='BioMedGraphica_Conn_ID', how='inner'
    ).drop(columns=['BioMedGraphica_Conn_ID', 'gene_name'])

    # Protein entity ID -> descriptive name string
    protein_info = gene_protein.merge(
        bmgc_protein_llmnameid_combined_df, left_on='BMGC_PT_ID', right_on='BioMedGraphica_Conn_ID', how='inner'
    ).drop(columns=['BioMedGraphica_Conn_ID', 'BioMedGraphica_ID', sample_ach_name])

    protein_ids = protein_info['BMGC_PT_ID'].tolist()
    protein_labels = protein_info['Names_and_IDs'].replace(r' \| ', ' or ', regex=True).tolist()
    return hgnc_names, protein_ids, protein_labels

# Example usage: top-k genes for a single sample.
# Note: this rebinds the notebook-level names sample_ach_name and k.
sample_ach_name = 'ACH-000001'
k=10
top_k_gene_hgnc_name_list, top_k_gene_protein_bmgc_id_list, top_k_gene_protein_bmgc_llmnameid_list = extract_gn_info(dti_crispr_rna_gene_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)
print(f"Top {k} Gene HGNC Names:", top_k_gene_hgnc_name_list)
print(f"Top {k} Gene Protein BMGC IDs:", top_k_gene_protein_bmgc_id_list)
print(f"Top {k} Gene Protein BMGC LLM Name IDs:", top_k_gene_protein_bmgc_llmnameid_list)

In [None]:
# Samples present both in raw_transcript_df and in overlapped_dti_crispr_rna_samples.
# columns[1:] skips the first column, assumed to be the non-sample 'gene_name' column - TODO confirm
transcript_overlapped_dti_crispr_rna_samples = sorted(list(set(raw_transcript_df.columns[1:]) & set(overlapped_dti_crispr_rna_samples)))
# Restrict raw_transcript_df to 'gene_name' plus those samples
dti_crispr_rna_transcript_df = raw_transcript_df[['gene_name'] + transcript_overlapped_dti_crispr_rna_samples].copy()
display(dti_crispr_rna_transcript_df)
# Transcript entity table; extract_ts_info reads its 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol' columns
bmgc_transcript_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Transcript/BioMedGraphica_Conn_Transcript.csv')

def extract_ts_info(dti_crispr_rna_transcript_df, bmgc_transcript_df, transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k=10):
    """
    Collect the k highest-valued transcript rows of one sample and the proteins linked to them.

    Args:
        dti_crispr_rna_transcript_df (pd.DataFrame): Has a 'gene_name' column plus one column per sample.
        bmgc_transcript_df (pd.DataFrame): Has 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol' columns.
        transcript_protein_entity_df (pd.DataFrame): Has 'BMGC_TS_ID' and 'BMGC_PT_ID' columns.
        bmgc_protein_llmnameid_combined_df (pd.DataFrame): Has 'BioMedGraphica_Conn_ID',
            'BioMedGraphica_ID' and 'Names_and_IDs' columns.
        sample_ach_name (str): Name of the sample column to rank by.
        k (int): Number of top rows to keep.

    Returns:
        tuple: ('gene_name' values in descending value order, linked protein IDs, linked
        protein 'Names_and_IDs' strings with ' | ' replaced by ' or '). The two protein lists
        follow the row order produced by the inner merges, not the ranking. If the sample
        column is absent, all three elements are the string "non-existed".
    """
    absent = "non-existed"
    if sample_ach_name not in dti_crispr_rna_transcript_df.columns:
        return absent, absent, absent

    # Rank: k largest values, explicitly sorted descending, with a clean 0..k-1 index
    ranked = (
        dti_crispr_rna_transcript_df
        .nlargest(k, sample_ach_name)[['gene_name', sample_ach_name]]
        .sort_values(by=sample_ach_name, ascending=False)
        .reset_index(drop=True)
    )
    hgnc_names = ranked['gene_name'].tolist()

    # Symbol -> transcript entity ID
    transcript_ids = bmgc_transcript_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].copy()
    ranked_with_ids = transcript_ids.merge(ranked, left_on='HGNC_Symbol', right_on='gene_name', how='inner')

    # Transcript entity ID -> protein entity ID
    transcript_protein = transcript_protein_entity_df.merge(
        ranked_with_ids, left_on='BMGC_TS_ID', right_on='BioMedGraphica_Conn_ID', how='inner'
    ).drop(columns=['BioMedGraphica_Conn_ID', 'gene_name'])

    # Protein entity ID -> descriptive name string
    protein_info = transcript_protein.merge(
        bmgc_protein_llmnameid_combined_df, left_on='BMGC_PT_ID', right_on='BioMedGraphica_Conn_ID', how='inner'
    ).drop(columns=['BioMedGraphica_Conn_ID', 'BioMedGraphica_ID', sample_ach_name])

    protein_ids = protein_info['BMGC_PT_ID'].tolist()
    protein_labels = protein_info['Names_and_IDs'].replace(r' \| ', ' or ', regex=True).tolist()
    return hgnc_names, protein_ids, protein_labels

# Example usage: top-k transcript rows for a single sample.
# Note: this rebinds the notebook-level names sample_ach_name and k.
sample_ach_name = 'ACH-000001'
k=10
top_k_transcript_hgnc_name_list, top_k_transcript_protein_bmgc_id_list, top_k_transcript_protein_bmgc_llmnameid_list = extract_ts_info(dti_crispr_rna_transcript_df, bmgc_transcript_df, transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)
print(f"Top {k} Transcript HGNC Names:", top_k_transcript_hgnc_name_list)
print(f"Top {k} Transcript Protein BMGC IDs:", top_k_transcript_protein_bmgc_id_list)
print(f"Top {k} Transcript Protein BMGC LLM Name IDs:", top_k_transcript_protein_bmgc_llmnameid_list)

In [None]:
# Protein entity table reduced to the three columns needed for the Uniprot -> HGNC mapping
bmg_protein_all_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein.csv')
bmg_protein_all_df = bmg_protein_all_df[['BioMedGraphica_Conn_ID', 'Uniprot_ID', 'HGNC_Symbol']].copy()
display(bmg_protein_all_df)

# Rename columns in raw_protein_df using the provided mapping
raw_protein_map_df = raw_protein_df.rename(columns=protein_map_dict)
# Merge with bmg_protein_all_df to get HGNC symbols (rows without a matching Uniprot ID are dropped)
symbol_protein_map_df = pd.merge(raw_protein_map_df, bmg_protein_all_df, left_on='Uniprot_Acc', right_on='Uniprot_ID', how='inner')
# Reorder columns: the three identifier columns first, then every other column in sorted order
symbol_protein_map_df = symbol_protein_map_df[['Uniprot_ID', 'Uniprot_Acc', 'HGNC_Symbol'] + sorted(set(symbol_protein_map_df.columns) - {'Uniprot_ID', 'Uniprot_Acc', 'HGNC_Symbol'})]
# Identify overlapping samples between protein data and the provided sample list
protein_overlapped_samples = sorted(set(symbol_protein_map_df.columns) & set(overlapped_dti_crispr_rna_samples))
# Select only HGNC symbol and overlapping sample columns
dti_crispr_rna_protein_df = symbol_protein_map_df[['HGNC_Symbol'] + protein_overlapped_samples].copy()
# Split multiple HGNC symbols by ";" and expand into multiple rows
# (each resulting row repeats the sample values of the original row)
dti_crispr_rna_protein_df = dti_crispr_rna_protein_df.assign(HGNC_Symbol=dti_crispr_rna_protein_df['HGNC_Symbol'].str.split(';')).explode('HGNC_Symbol')
# Remove leading/trailing whitespace in gene symbols
dti_crispr_rna_protein_df['HGNC_Symbol'] = dti_crispr_rna_protein_df['HGNC_Symbol'].str.strip()
# Drop rows with empty or missing gene symbols, and rebuild a unique 0..n-1 index
dti_crispr_rna_protein_df = dti_crispr_rna_protein_df[dti_crispr_rna_protein_df['HGNC_Symbol'].notna() & (dti_crispr_rna_protein_df['HGNC_Symbol'] != '')].reset_index(drop=True)
# Display final dataframe
display(dti_crispr_rna_protein_df)

# Full (all-column) protein entity table; same CSV as bmg_protein_all_df above, read a second time
bmgc_protein_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein.csv')

def extract_pt_info(dti_crispr_rna_protein_df, bmgc_protein_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k=10):
    """
    Collect the k highest-valued protein rows of one sample and their entity information.

    Args:
        dti_crispr_rna_protein_df (pd.DataFrame): Has an 'HGNC_Symbol' column plus one column per sample.
        bmgc_protein_df (pd.DataFrame): Has 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol' columns.
        bmgc_protein_llmnameid_combined_df (pd.DataFrame): Has 'BioMedGraphica_Conn_ID',
            'BioMedGraphica_ID' and 'Names_and_IDs' columns.
        sample_ach_name (str): Name of the sample column to rank by.
        k (int): Number of top rows to keep.

    Returns:
        tuple: ('HGNC_Symbol' values in descending value order, matching protein IDs, matching
        'Names_and_IDs' strings with both ' | ' and ';' replaced by ' or '). The two protein
        lists follow the row order produced by the inner merges, not the ranking. If the
        sample column is absent, all three elements are the string "non-existed".
    """
    absent = "non-existed"
    if sample_ach_name not in dti_crispr_rna_protein_df.columns:
        return absent, absent, absent

    # Rank: k largest values, explicitly sorted descending, with a clean 0..k-1 index
    ranked = (
        dti_crispr_rna_protein_df
        .nlargest(k, sample_ach_name)[['HGNC_Symbol', sample_ach_name]]
        .sort_values(by=sample_ach_name, ascending=False)
        .reset_index(drop=True)
    )
    hgnc_names = ranked['HGNC_Symbol'].tolist()

    # Symbol -> protein entity ID
    protein_ids_df = bmgc_protein_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].copy()
    ranked_with_ids = protein_ids_df.merge(ranked, left_on='HGNC_Symbol', right_on='HGNC_Symbol', how='inner')

    # Protein entity ID -> descriptive name string
    protein_info = ranked_with_ids.merge(
        bmgc_protein_llmnameid_combined_df,
        left_on='BioMedGraphica_Conn_ID', right_on='BioMedGraphica_Conn_ID', how='inner'
    ).drop(columns=['BioMedGraphica_ID', sample_ach_name])

    protein_ids = protein_info['BioMedGraphica_Conn_ID'].tolist()
    # Both separators are rewritten as " or "
    separators = [r' \| ', r';']
    protein_labels = protein_info['Names_and_IDs'].replace(separators, ' or ', regex=True).tolist()
    return hgnc_names, protein_ids, protein_labels

# Example usage: top-k protein rows for a single sample.
# Note: this rebinds the notebook-level names sample_ach_name and k.
sample_ach_name = 'ACH-000001'
k = 10
top_k_protein_hgnc_name_list, top_k_protein_bmgc_id_list, top_k_protein_bmgc_llmnameid_list = extract_pt_info(dti_crispr_rna_protein_df, bmgc_protein_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)
print(f"Top {k} Protein HGNC Names:", top_k_protein_hgnc_name_list)
print(f"Top {k} Protein BMGC IDs:", top_k_protein_bmgc_id_list)
print(f"Top {k} Protein BMGC LLM Name IDs:", top_k_protein_bmgc_llmnameid_list)

##### Related Proteins

In [None]:
def bmgc_pt_id_to_hgnc(bmgc_id_list, bmgc_protein_df):
    """
    Convert a list of BioMedGraphica IDs to their corresponding HGNC symbols.

    The 'HGNC_Symbol' cell of the first row matching each ID is split on ';'.
    Symbols are stripped of surrounding whitespace *before* de-duplication, and
    output order is deterministic (first-seen order), so repeated runs produce
    identical results.

    Args:
        bmgc_id_list (list): List of BioMedGraphica IDs (a single non-list value is
            wrapped into a one-element list)
        bmgc_protein_df (pd.DataFrame): DataFrame with 'BioMedGraphica_Conn_ID' and
            'HGNC_Symbol' columns

    Returns:
        tuple: (
            dict: Dictionary mapping each BioMedGraphica ID to its list of HGNC symbols
                  (empty list when the ID is unknown or its symbol is missing),
            list: Combined, duplicate-free list of all HGNC symbols
        )
    """
    # Ensure bmgc_id_list is actually a list
    if not isinstance(bmgc_id_list, list):
        bmgc_id_list = [bmgc_id_list]

    # Build the ID -> raw symbol lookup once instead of scanning the frame per ID.
    # keep='first' mirrors the previous "first matching row" behaviour.
    first_rows = bmgc_protein_df.drop_duplicates(subset='BioMedGraphica_Conn_ID', keep='first')
    symbol_by_id = dict(zip(first_rows['BioMedGraphica_Conn_ID'], first_rows['HGNC_Symbol']))

    results = {}
    # dict used as an insertion-ordered set
    all_hgnc_symbols = {}
    for bmgc_id in bmgc_id_list:
        hgnc_value = symbol_by_id.get(bmgc_id)
        # Unknown ID or missing symbol -> empty result for this ID
        if hgnc_value is None or pd.isna(hgnc_value):
            results[bmgc_id] = []
            continue
        # Strip first, then drop blanks and duplicates while keeping order
        stripped = (symbol.strip() for symbol in hgnc_value.split(';'))
        hgnc_list = list(dict.fromkeys(symbol for symbol in stripped if symbol != ''))
        results[bmgc_id] = hgnc_list
        all_hgnc_symbols.update(dict.fromkeys(hgnc_list))
    return results, list(all_hgnc_symbols)

def hgnc_to_bmgc_pt_id(hgnc_list, bmgc_protein_df):
    """
    Convert a list of HGNC symbols to their corresponding BioMedGraphica IDs.

    Every protein row whose 'HGNC_Symbol' cell contains the symbol is reported, not
    only the first one. A cell may list several symbols separated by ';'; each
    symbol is matched after whitespace stripping. Output order is deterministic
    (row order of bmgc_protein_df).

    Args:
        hgnc_list (list): List of HGNC symbols (a single non-list value is wrapped
            into a one-element list)
        bmgc_protein_df (pd.DataFrame): DataFrame with 'BioMedGraphica_Conn_ID' and
            'HGNC_Symbol' columns

    Returns:
        tuple: (
            dict: Dictionary mapping each HGNC symbol to its list of BioMedGraphica IDs
                  (empty list when nothing matches),
            list: Combined, duplicate-free list of all BioMedGraphica IDs
        )
    """
    # Ensure hgnc_list is actually a list
    if not isinstance(hgnc_list, list):
        hgnc_list = [hgnc_list]

    # One row per (protein ID, single symbol): split ';'-joined symbol cells and expand
    symbol_table = bmgc_protein_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].dropna()
    symbol_table = symbol_table.assign(
        HGNC_Symbol=symbol_table['HGNC_Symbol'].astype(str).str.split(';')
    ).explode('HGNC_Symbol').reset_index(drop=True)
    symbol_table['HGNC_Symbol'] = symbol_table['HGNC_Symbol'].str.strip()
    # Only the requested symbols are needed for the lookup
    symbol_table = symbol_table[symbol_table['HGNC_Symbol'].isin(hgnc_list)]

    # symbol -> insertion-ordered set of IDs
    ids_by_symbol = {}
    for symbol, raw_id in zip(symbol_table['HGNC_Symbol'], symbol_table['BioMedGraphica_Conn_ID']):
        bucket = ids_by_symbol.setdefault(symbol, {})
        # As before, an ID cell is split on ';' in case it holds several IDs
        for piece in str(raw_id).split(';'):
            piece = piece.strip()
            if piece != '':
                bucket[piece] = None

    results = {}
    all_bmgc_ids = {}
    for hgnc in hgnc_list:
        bmgc_list = list(ids_by_symbol.get(hgnc, ()))
        results[hgnc] = bmgc_list
        all_bmgc_ids.update(dict.fromkeys(bmgc_list))
    return results, list(all_bmgc_ids)

# Example usage: protein IDs -> HGNC symbols
bmgc_ids = ['BMGC_PT000001', 'BMGC_PT013541']
hgnc_dict, all_hgnc_symbols = bmgc_pt_id_to_hgnc(bmgc_ids, bmgc_protein_df)
print(hgnc_dict)
print(all_hgnc_symbols)
# Example usage: HGNC symbols -> protein IDs
hgnc_list = ['BRCA1', 'TP53']
bmgc_dict, all_bmgc_ids = hgnc_to_bmgc_pt_id(hgnc_list, bmgc_protein_df)
print(bmgc_dict)
print(all_bmgc_ids)

In [None]:
def extract_disease_protein(selected_sample_disease_bmgc_id, edge_index, 
                    node_index_df, nodeid_index_dict, index_nodeid_dict):
    """
    Find the protein nodes that share an edge (in either direction) with a disease node.

    Args:
        selected_sample_disease_bmgc_id (str): Node ID of the disease
        edge_index (np.ndarray): Array indexed as edge_index[0, :] (sources) and
            edge_index[1, :] (targets)
        node_index_df (pd.DataFrame): DataFrame with 'Type' and 'Index' columns
        nodeid_index_dict (dict): Mapping from node ID to node index
        index_nodeid_dict (dict): Mapping from node index to node ID

    Returns:
        tuple: (
            list: Sorted, duplicate-free protein node indices as plain Python ints
                  (so the list is JSON-serializable and its element type does not
                  depend on set sizes),
            list: Node IDs for those indices, in the same order
        )

    Raises:
        KeyError: If the disease ID is not in nodeid_index_dict, or a resulting index
            is not in index_nodeid_dict.
    """
    # Extract the index based on the selected disease BMGC ID
    disease_node_index = nodeid_index_dict[selected_sample_disease_bmgc_id]
    # Sources of edges pointing to the disease, and targets of edges leaving it
    incoming_source_nodes = edge_index[0, edge_index[1, :] == disease_node_index]
    outgoing_target_nodes = edge_index[1, edge_index[0, :] == disease_node_index]
    # Sorted unique union of all direct neighbours
    neighbour_nodes = np.union1d(incoming_source_nodes, outgoing_target_nodes)
    # Keep only neighbours that are protein nodes (result is sorted and unique)
    protein_node_indices = node_index_df.loc[node_index_df['Type'] == 'Protein', 'Index'].to_numpy()
    disease_protein_index = [int(i) for i in np.intersect1d(neighbour_nodes, protein_node_indices)]
    # Map protein index to BMGC id
    disease_protein_bmgc_id = [index_nodeid_dict[i] for i in disease_protein_index]
    return disease_protein_index, disease_protein_bmgc_id

def extract_ppi_nodes(disease_protein_index, edge_index, node_index_df, index_nodeid_dict):
    """
    Find protein nodes that share an edge (in either direction) with any seed protein.

    The seed proteins themselves are excluded from the result. All seeds are matched
    against the edge list in one vectorized pass instead of one full scan per seed.

    Args:
        disease_protein_index (list): Node indices of the seed proteins
        edge_index (np.ndarray): Array indexed as edge_index[0, :] (sources) and
            edge_index[1, :] (targets)
        node_index_df (pd.DataFrame): DataFrame with 'Type' and 'Index' columns
        index_nodeid_dict (dict): Mapping from node index to node ID

    Returns:
        tuple: (
            list: Sorted, duplicate-free neighbour protein indices as plain Python ints,
            list: Node IDs for those indices, in the same order
        )
    """
    # Get protein node index
    protein_node_indices = node_index_df.loc[node_index_df['Type'] == 'Protein', 'Index'].to_numpy()
    seed_indices = np.asarray(disease_protein_index)
    # Sources of edges pointing at any seed, and targets of edges leaving any seed
    incoming_sources = edge_index[0, np.isin(edge_index[1, :], seed_indices)]
    outgoing_targets = edge_index[1, np.isin(edge_index[0, :], seed_indices)]
    # Sorted unique neighbours, with the seeds themselves removed
    neighbour_nodes = np.setdiff1d(np.union1d(incoming_sources, outgoing_targets), seed_indices)
    # Filter to only keep protein nodes among the neighbours
    ppi_nodes_index = [int(i) for i in np.intersect1d(neighbour_nodes, protein_node_indices)]
    # Map PPI node index to BMGC id
    ppi_nodes_bmgc_id = [index_nodeid_dict[i] for i in ppi_nodes_index]
    return ppi_nodes_index, ppi_nodes_bmgc_id

def extract_kg_related_proteins(selected_sample_disease_bmgc_id, edge_index,
                               node_index_df, nodeid_index_dict, index_nodeid_dict):
    """
    Two-step neighbourhood lookup around a disease node in the knowledge graph.

    Step 1 delegates to extract_disease_protein, step 2 feeds the resulting protein
    indices to extract_ppi_nodes.

    Args:
        selected_sample_disease_bmgc_id (str): BMGC ID of the selected disease
        edge_index: Edge index object forwarded unchanged to both helpers. The example
            usage in this notebook passes an array loaded with np.load, not a file path.
        node_index_df (pd.DataFrame): DataFrame with node type information
        nodeid_index_dict (dict): Mapping from node ID to index
        index_nodeid_dict (dict): Mapping from index to node ID

    Returns:
        tuple: (disease_protein_index, disease_protein_bmgc_id, ppi_nodes_index, ppi_nodes_bmgc_id)
    """
    # Step 1: proteins directly connected to the disease
    first_hop = extract_disease_protein(
        selected_sample_disease_bmgc_id,
        edge_index,
        node_index_df,
        nodeid_index_dict,
        index_nodeid_dict,
    )
    disease_protein_index, disease_protein_bmgc_id = first_hop

    # Step 2: proteins connected to those proteins
    # (could be replaced by LLM-generated PPI; that may need NER and mapping to BMGC IDs)
    second_hop = extract_ppi_nodes(
        disease_protein_index,
        edge_index,
        node_index_df,
        index_nodeid_dict,
    )
    ppi_nodes_index, ppi_nodes_bmgc_id = second_hop

    # No BMGC-ID -> HGNC conversion happens here; only indices and BMGC IDs are returned
    return disease_protein_index, disease_protein_bmgc_id, ppi_nodes_index, ppi_nodes_bmgc_id

# Example usage
selected_sample_disease_bmgc_id = 'BMGC_DS07934'
edge_index = np.load('./data/TargetQA/edge_index.npy')
# Reverse lookup: node index -> node ID
index_node_dict = dict(zip(nodes_index_data['Index'], nodes_index_data['Node']))

# NOTE(review): node_index_dict (node ID -> index) is not defined in this cell;
# it must already exist from an earlier cell - confirm on a fresh kernel run
disease_protein_index, disease_protein_bmgc_id, ppi_nodes_index, ppi_nodes_bmgc_id = extract_kg_related_proteins(
    selected_sample_disease_bmgc_id=selected_sample_disease_bmgc_id,
    edge_index=edge_index,
    node_index_df=nodes_index_data,
    nodeid_index_dict=node_index_dict,
    index_nodeid_dict=index_node_dict
)
print("Disease Protein Index:", disease_protein_index)
print("Disease Protein BMGC ID:", disease_protein_bmgc_id)

##### CRISPR Answer

In [None]:
# Restrict crispr_df to 'gene_name' plus the overlapped samples (columns in sorted order);
# extract_answer ranks rows of this frame per sample
answer_crispr_df = crispr_df[['gene_name'] + sorted(overlapped_dti_crispr_rna_samples)].copy()
display(answer_crispr_df)

def extract_answer(answer_crispr_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, top_bm=100):
    """
    Collect the top_bm lowest-valued genes of one sample and the proteins linked to them.

    Unlike the extract_gn_info / extract_ts_info / extract_pt_info helpers, there is no
    guard for a missing sample column here; the column lookup simply fails in that case.

    Args:
        answer_crispr_df (pd.DataFrame): Has a 'gene_name' column plus one column per sample.
        bmgc_gene_df (pd.DataFrame): Has 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol' columns.
        gene_transcript_protein_entity_df (pd.DataFrame): Has 'BMGC_GN_ID' and 'BMGC_PT_ID' columns.
        bmgc_protein_llmnameid_combined_df (pd.DataFrame): Has 'BioMedGraphica_Conn_ID',
            'BioMedGraphica_ID' and 'Names_and_IDs' columns.
        sample_ach_name (str): Name of the sample column to rank by.
        top_bm (int): Number of lowest-valued rows to keep.

    Returns:
        tuple: (gene names in ascending value order, linked protein IDs, linked protein
        'Names_and_IDs' strings with ' | ' replaced by ' or '). The two protein lists follow
        the row order produced by the inner merges, not the ranking.
    """
    # Rank: top_bm smallest values, explicitly sorted ascending, with a clean index
    ranked = (
        answer_crispr_df
        .nsmallest(top_bm, sample_ach_name)[['gene_name', sample_ach_name]]
        .sort_values(by=sample_ach_name, ascending=True)
        .reset_index(drop=True)
    )
    hgnc_names = ranked['gene_name'].tolist()

    # Gene symbol -> gene entity ID
    gene_ids = bmgc_gene_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].copy()
    ranked_with_ids = gene_ids.merge(ranked, left_on='HGNC_Symbol', right_on='gene_name', how='inner')

    # Gene entity ID -> protein entity ID
    gene_protein = gene_transcript_protein_entity_df.merge(
        ranked_with_ids, left_on='BMGC_GN_ID', right_on='BioMedGraphica_Conn_ID', how='inner'
    ).drop(columns=['BioMedGraphica_Conn_ID', 'gene_name'])

    # Protein entity ID -> descriptive name string
    protein_info = gene_protein.merge(
        bmgc_protein_llmnameid_combined_df, left_on='BMGC_PT_ID', right_on='BioMedGraphica_Conn_ID', how='inner'
    ).drop(columns=['BioMedGraphica_Conn_ID', 'BioMedGraphica_ID', sample_ach_name])

    protein_ids = protein_info['BMGC_PT_ID'].tolist()
    protein_labels = protein_info['Names_and_IDs'].replace(r' \| ', ' or ', regex=True).tolist()
    return hgnc_names, protein_ids, protein_labels

# Example usage: lowest-valued CRISPR genes for a single sample.
# Note: this rebinds the notebook-level names sample_ach_name and top_bm.
sample_ach_name = 'ACH-000001'
top_bm = 10
# Re-reads the same CSV path that an earlier cell loads into this same variable
bmgc_protein_llmnameid_combined_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein_LLM_Name_ID_Combined.csv')
top_bm_gene_hgnc_name_list, top_bm_gene_protein_bmgc_id_list, top_bm_gene_protein_bmgc_llmnameid_list = extract_answer(answer_crispr_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, top_bm)
print(f"Top {top_bm} Gene HGNC Names:", top_bm_gene_hgnc_name_list)
print(f"Top {top_bm} Gene Protein BMGC IDs:", top_bm_gene_protein_bmgc_id_list)
print(f"Top {top_bm} Gene Protein BMGC LLM Name IDs:", top_bm_gene_protein_bmgc_llmnameid_list)

##### Knowledge Graph Information

In [None]:
def find_protein_relationships(hgnc_symbols, bmgc_protein_df, bmgc_relation_df):
    """
    Find relationships between a list of proteins based on HGNC symbols.

    Every protein row whose HGNC_Symbol is in hgnc_symbols takes part in the search,
    including several protein IDs that share one symbol (previously only one ID per
    symbol was considered, so relations through the others were dropped).

    Parameters:
    -----------
    hgnc_symbols : list
        List of HGNC symbols to find relationships between
    bmgc_protein_df : pandas DataFrame
        DataFrame containing BioMedGraphica_Conn_ID and HGNC_Symbol columns
    bmgc_relation_df : pandas DataFrame
        DataFrame containing BMGC_From_ID and BMGC_To_ID columns, and optionally a
        relation_type column

    Returns:
    --------
    tuple (pandas DataFrame, list)
        - DataFrame with source_symbol, target_symbol (and relation_type when the
          relation table has that column); the columns are present even when no
          relationship is found
        - List of text descriptions of relationships in "A -> B" format, or
          "A -relation_type-> B" when relation_type is available
    """
    # Filter the protein DataFrame to only include the proteins we care about
    filtered_proteins = bmgc_protein_df[bmgc_protein_df['HGNC_Symbol'].isin(hgnc_symbols)]
    # Map every matching BMGC ID back to its HGNC symbol; the keys are ALL matching IDs
    bmgc_to_hgnc = dict(zip(filtered_proteins['BioMedGraphica_Conn_ID'], filtered_proteins['HGNC_Symbol']))
    bmgc_ids = list(bmgc_to_hgnc)
    # Filter the relationship DataFrame to only include relationships between our proteins
    protein_relations = bmgc_relation_df[
        bmgc_relation_df['BMGC_From_ID'].isin(bmgc_ids) &
        bmgc_relation_df['BMGC_To_ID'].isin(bmgc_ids)
    ]
    has_relation_type = 'relation_type' in bmgc_relation_df.columns
    result_columns = ['source_symbol', 'target_symbol']
    if has_relation_type:
        result_columns.append('relation_type')
        relation_types = protein_relations['relation_type'].tolist()
    else:
        relation_types = [None] * len(protein_relations)
    # Map the BMGC IDs back to HGNC symbols (column-wise iteration instead of iterrows)
    result_data = []
    text_descriptions = []
    for source_bmgc, target_bmgc, relation_type in zip(
        protein_relations['BMGC_From_ID'], protein_relations['BMGC_To_ID'], relation_types
    ):
        if source_bmgc not in bmgc_to_hgnc or target_bmgc not in bmgc_to_hgnc:
            continue
        source_symbol = bmgc_to_hgnc[source_bmgc]
        target_symbol = bmgc_to_hgnc[target_bmgc]
        relation_info = {
            'source_symbol': source_symbol,
            'target_symbol': target_symbol
        }
        if has_relation_type:
            relation_info['relation_type'] = relation_type
            text_descriptions.append(f"{source_symbol} -{relation_type}-> {target_symbol}")
        else:
            text_descriptions.append(f"{source_symbol} -> {target_symbol}")
        result_data.append(relation_info)
    # Create a DataFrame from the results (explicit columns keep the schema when empty)
    result_df = pd.DataFrame(result_data, columns=result_columns)
    return result_df, text_descriptions

# Define the HGNC symbols (hard-coded example list used only for this demonstration)
hgnc_symbols = ['SNRPD3', 'RAN', 'RPS8', 'UBL5', 'SMU1', 'RRM1', 'PSMA6', 'PSMB3', 'WEE1', 
                'PHB1', 'BANF1', 'KIF11', 'SNRPD1', 'PSMA3', 'PSMD11', 'PRPF19', 'SNRPF', 
                'RPS29', 'CDC27', 'SRSF3', 'TUBGCP2', 'ECD', 'RPS20', 'PCNA', 'PSMA7', 'CDC7', 
                'RPL17', 'GINS1', 'PHB2', 'SRSF2', 'MAD2L1', 'MED14']

# Call the function to find relationships
# NOTE(review): bmgc_relation_df is not defined in this cell; it must come from an earlier cell
relationships_df, relationship_texts = find_protein_relationships(hgnc_symbols, bmgc_protein_df, bmgc_relation_df)

# You can now use both the DataFrame and text descriptions
# Example usage:
print(f"Found {len(relationship_texts)} relationships between the proteins")
print("Example relationships:")
print(relationship_texts[:5])  # Print first 5 relationships

##### Formulate QA JSON

In [None]:
def qa_sample_info(sample_ach_name, selected_sample_disease_bmgc_id, k=100, top_bm=100):
    """
    Gather every piece of information used to build one sample's QA record.

    Runs, in order: the three omics extractors (gene, transcript, protein), the
    knowledge-graph neighbourhood lookup for the disease, the ID -> HGNC mapping of
    those neighbours, the relationship search among all collected symbols, and the
    CRISPR answer extraction. Progress is printed along the way.

    Relies on notebook-level data: dti_crispr_rna_gene_df, dti_crispr_rna_transcript_df,
    dti_crispr_rna_protein_df, bmgc_gene_df, bmgc_transcript_df, bmgc_protein_df,
    gene_transcript_protein_entity_df, transcript_protein_entity_df,
    bmgc_protein_llmnameid_combined_df, bmgc_relation_df, nodes_index_data,
    node_index_dict, index_node_dict and answer_crispr_df. The edge index is loaded
    from './data/TargetQA/edge_index.npy' on every call.

    Args:
        sample_ach_name (str): Sample column name passed to the extractors.
        selected_sample_disease_bmgc_id (str): Disease node ID for the graph lookup.
        k (int): Number of top rows for the gene/transcript/protein extractors.
        top_bm (int): Number of lowest-valued rows for the CRISPR answer.

    Returns:
        tuple: 21 elements - the three results of each omics extractor (gene,
        transcript, protein), the four graph results, the two ID -> HGNC mapping pairs
        (disease proteins, PPI proteins), the relationship texts, and the three answer
        results. A "non-existed" HGNC name list from an extractor is returned as [];
        the other two elements of that extractor's result are returned unchanged.
    """
    missing_marker = "non-existed"

    # LLM Info
    print(f"Sample ACH Name: {sample_ach_name}")
    print(f"Extracting top {k} gene information for {sample_ach_name}...")
    gene_hgnc_names, gene_protein_ids, gene_protein_labels = extract_gn_info(
        dti_crispr_rna_gene_df, bmgc_gene_df, gene_transcript_protein_entity_df,
        bmgc_protein_llmnameid_combined_df, sample_ach_name, k)
    print(f"Extracting top {k} transcript information for {sample_ach_name}...")
    transcript_hgnc_names, transcript_protein_ids, transcript_protein_labels = extract_ts_info(
        dti_crispr_rna_transcript_df, bmgc_transcript_df, transcript_protein_entity_df,
        bmgc_protein_llmnameid_combined_df, sample_ach_name, k)
    print(f"Extracting top {k} protein information for {sample_ach_name}...")
    protein_hgnc_names, protein_ids, protein_labels = extract_pt_info(
        dti_crispr_rna_protein_df, bmgc_protein_df, bmgc_protein_llmnameid_combined_df,
        sample_ach_name, k)

    # KG Info
    kg_edge_index = np.load('./data/TargetQA/edge_index.npy')
    print(f"Extracting disease-related proteins index and bmgc id for {selected_sample_disease_bmgc_id} ({sample_ach_name}) ...")
    disease_protein_index, disease_protein_bmgc_id, ppi_nodes_index, ppi_nodes_bmgc_id = extract_kg_related_proteins(
        selected_sample_disease_bmgc_id, kg_edge_index, nodes_index_data, node_index_dict, index_node_dict)
    print(f"Knowledge Graph Info: Found {len(disease_protein_index)} disease-related proteins directly connected to {selected_sample_disease_bmgc_id} and {len(ppi_nodes_index)} proteins in their PPI network")
    print("Mapping disease-related proteins to HGNC symbols...")
    disease_protein_hgnc_dict, disease_protein_hgnc_list = bmgc_pt_id_to_hgnc(disease_protein_bmgc_id, bmgc_protein_df)
    print("Mapping PPI-related proteins to HGNC symbols...")
    ppi_hgnc_dict, ppi_hgnc_list = bmgc_pt_id_to_hgnc(ppi_nodes_bmgc_id, bmgc_protein_df)

    # LLM Used KG Info
    print("Extracting protein relationships from BMGC...")
    # A name list that came back as the "non-existed" marker is treated as empty
    if gene_hgnc_names == missing_marker:
        gene_hgnc_names = []
    if transcript_hgnc_names == missing_marker:
        transcript_hgnc_names = []
    if protein_hgnc_names == missing_marker:
        protein_hgnc_names = []
    if disease_protein_hgnc_list == missing_marker:
        disease_protein_hgnc_list = []
    # Combine all the HGNC symbols into a single duplicate-free list for relationship extraction
    omics_disease_protein_hgnc_list = list(set(
        gene_hgnc_names + transcript_hgnc_names + protein_hgnc_names + disease_protein_hgnc_list
    ))
    relationships_df, relationship_texts = find_protein_relationships(
        omics_disease_protein_hgnc_list, bmgc_protein_df, bmgc_relation_df)
    print(f"Knowledge Graph Info: Found {len(omics_disease_protein_hgnc_list)} unique proteins and {len(relationship_texts)} relationships between them")

    # Answer Info
    print(f"Extracting top {top_bm} CRISPR gene information for {sample_ach_name}...")
    answer_hgnc_names, answer_protein_ids, answer_protein_labels = extract_answer(
        answer_crispr_df, bmgc_gene_df, gene_transcript_protein_entity_df,
        bmgc_protein_llmnameid_combined_df, sample_ach_name, top_bm)

    return (
        gene_hgnc_names, gene_protein_ids, gene_protein_labels,
        transcript_hgnc_names, transcript_protein_ids, transcript_protein_labels,
        protein_hgnc_names, protein_ids, protein_labels,
        disease_protein_index, disease_protein_bmgc_id, ppi_nodes_index, ppi_nodes_bmgc_id,
        disease_protein_hgnc_dict, disease_protein_hgnc_list, ppi_hgnc_dict, ppi_hgnc_list, relationship_texts,
        answer_hgnc_names, answer_protein_ids, answer_protein_labels,
    )

# Example usage: gather all QA information for one sample / disease pair.
# Note: this rebinds the notebook-level names k and top_bm (a later cell sets k again).
sample_ach_name = 'ACH-000001'
selected_sample_disease_bmgc_id = 'BMGC_DS07934'
k = 100
top_bm = 100
# return_tuples holds the 21-element tuple returned by qa_sample_info
return_tuples = qa_sample_info(sample_ach_name, selected_sample_disease_bmgc_id, k=k, top_bm=top_bm)

In [None]:
import json
import os
from tqdm import tqdm

# Parameters for the batch run. `k` and `top_bm` are passed to qa_sample_info and
# are also embedded in the output file name; `save_every_n` is the checkpoint interval.
k = 10
top_bm = 100
save_every_n = 10

# Output folder and filename (created if missing)
output_dir = "./data/TargetQA"
os.makedirs(output_dir, exist_ok=True)
output_filename = f"target_qa_k{k}_bm{top_bm}.json"
output_path = os.path.join(output_dir, output_filename)

# Resume support: reuse the results of a previous run if the output file already exists
if os.path.exists(output_path):
    with open(output_path, "r") as f:
        multi_sample_qa_json = json.load(f)
    print(f"Loaded existing JSON with {len(multi_sample_qa_json)} processed samples")
else:
    multi_sample_qa_json = {}
    print("No existing JSON found, starting fresh")

# Load sample info data and keep only rows whose depMapID also appears in
# overlapped_dti_crispr_rna_samples_df / overlapped_dti_crispr_rna_samples (both defined in earlier cells)
dti_sample_info_index = pd.read_csv('./data/process_data/dti_combined_samples.csv')
target_sample_info_index = pd.merge(dti_sample_info_index, overlapped_dti_crispr_rna_samples_df['depMapID'], how='inner', on='depMapID')
target_sample_info_index = target_sample_info_index[target_sample_info_index['depMapID'].isin(overlapped_dti_crispr_rna_samples)].reset_index(drop=True)
# Rewrite the ' | ' separator inside disease names as ' or '
target_sample_info_index['BMGC_Disease_name'] = target_sample_info_index['BMGC_Disease_name'].replace(r' \| ', ' or ', regex=True)
# Insert a new 1-based column "Index" in the first position
target_sample_info_index.insert(0, 'Index', range(1, len(target_sample_info_index) + 1))
display(target_sample_info_index)

In [None]:
# Resume support: every key already present in the loaded JSON counts as processed.
processed_samples = set(multi_sample_qa_json.keys())
print(f"Found {len(processed_samples)} already processed samples")

count = len(processed_samples)
total_to_process = len(target_sample_info_index)
remaining = total_to_process - count
print(f"Total samples to process: {total_to_process}, already processed: {count}, remaining: {remaining}")


def _to_json_serializable(obj):
    """Fallback for ``json.dump``: convert NumPy/pandas values to plain Python.

    ``tolist()`` turns NumPy scalars into the matching Python scalar (floats are
    NOT truncated to int) and arrays/Series into lists. Anything else raises
    ``TypeError`` so the real offending type is reported instead of a misleading
    "Circular reference detected" error.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _save_qa_json(qa_json, path):
    """Write ``qa_json`` to ``path`` via a temp file and an atomic rename.

    An interrupted write therefore never leaves a truncated checkpoint that
    would break the resume logic on the next run.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(qa_json, f, indent=2, default=_to_json_serializable)
    os.replace(tmp_path, path)


# Number of successfully built samples that have not been written to disk yet.
unsaved_samples = 0

# Iterate through the sample info dataframe, skipping already processed samples
for _, row in tqdm(target_sample_info_index.iterrows(), total=total_to_process):
    sample_ach_name = row["depMapID"]

    # Skip if already processed
    if sample_ach_name in processed_samples:
        continue

    # `count` tracks attempted samples (including failures) for the progress messages.
    count += 1
    target_sample_index = row["Index"]
    cell_line_name = row["Name"]
    disease = row["BMGC_Disease_name"]
    disease_bmgc_id = row["BMGC_Disease_ID"]

    print(f"Processing sample {count}/{total_to_process}: {sample_ach_name} ({cell_line_name})")
    print(f"Sample Index: {target_sample_index}")
    print(f"Sample Disease: {disease}")
    print(f"Sample Disease BMGC ID: {disease_bmgc_id}")

    # Best effort: a failing sample is reported and skipped so the batch keeps going.
    try:
        (top_k_gene_hgnc, top_k_gene_bmgc, top_k_gene_llm,
        top_k_ts_hgnc, top_k_ts_bmgc, top_k_ts_llm,
        top_k_pt_hgnc, top_k_pt_bmgc, top_k_pt_llm,
        dis_pt_idx, dis_pt_bmgc, ppi_idx, ppi_bmgc,
        dis_pt_hgnc_dict, dis_pt_hgnc, ppi_hgnc_dict, ppi_hgnc, relationship_texts,
        ans_hgnc, ans_bmgc, ans_llm) = qa_sample_info(sample_ach_name, disease_bmgc_id, k=k, top_bm=top_bm)
    except Exception as e:
        print(f"⚠️ Error processing {sample_ach_name}: {e}")
        continue

    multi_sample_qa_json[sample_ach_name] = {
        "cell_line_name": cell_line_name,
        "sample_index": target_sample_index,
        "disease": disease,
        "disease_bmgc_id": disease_bmgc_id,
        "input": {
            "top_k_gene": {
                "hgnc_symbols": top_k_gene_hgnc,
                "protein_bmgc_ids": top_k_gene_bmgc,
                "protein_llmname_ids": top_k_gene_llm
            },
            "top_k_transcript": {
                "hgnc_symbols": top_k_ts_hgnc,
                "protein_bmgc_ids": top_k_ts_bmgc,
                "protein_llmname_ids": top_k_ts_llm
            },
            "top_k_protein": {
                "hgnc_symbols": top_k_pt_hgnc,
                "protein_bmgc_ids": top_k_pt_bmgc,
                "protein_llmname_ids": top_k_pt_llm
            },
            "knowledge_graph": {
                "disease_protein": {
                    "bmgc_ids": dis_pt_bmgc,
                    "hgnc_symbols": dis_pt_hgnc,
                    "indices": dis_pt_idx
                },
                "ppi_neighbors": {
                    "bmgc_ids": ppi_bmgc,
                    "hgnc_symbols": ppi_hgnc,
                    "indices": ppi_idx
                },
                "protein_relationships": relationship_texts,
            }
        },
        "ground_truth_answer": {
            "top_bm_gene": {
                "hgnc_symbols": ans_hgnc,
                "protein_bmgc_ids": ans_bmgc,
                "protein_llmname_ids": ans_llm
            }
        }
    }
    unsaved_samples += 1

    # Periodic save: triggered by the number of NEW results, so a failed sample
    # can no longer cause a checkpoint to be skipped.
    if unsaved_samples >= save_every_n:
        _save_qa_json(multi_sample_qa_json, output_path)
        unsaved_samples = 0
        print(f"💾 Auto-saved JSON at {count}/{total_to_process} samples to: {output_path}")
        print(f"Last processed sample: {sample_ach_name}")
        processed = len(multi_sample_qa_json)
        remaining = total_to_process - processed
        print(f"Progress: {processed}/{total_to_process} ({processed/total_to_process*100:.1f}%), Remaining: {remaining}")

# Final save after loop
_save_qa_json(multi_sample_qa_json, output_path)

print(f"✅ Final JSON saved to: {output_path}")
print(f"Total samples processed: {len(multi_sample_qa_json)}")

##### Split QA into Train and Test Sets

In [None]:
import json
import pandas as pd

def read_json(file_path):
    """Read the file at ``file_path`` and return its parsed JSON content."""
    with open(file_path, 'r') as handle:
        return json.loads(handle.read())

def separate_data(data, train_csv='./data/TargetQA/train_samples.csv', test_csv='./data/TargetQA/test_samples.csv'):
    """Split a sample-keyed QA dict into training and testing dicts.

    Args:
        data: dict mapping sample ID -> sample info (as loaded from the QA JSON).
        train_csv: CSV file whose 'depMapID' column lists the training samples.
        test_csv: CSV file whose 'depMapID' column lists the testing samples.

    Returns:
        (tr_data, te_data): two dicts holding the entries of ``data`` whose key is
        listed in the train / test file. Keys listed in neither file are dropped.
        A key listed in BOTH files is placed in the training dict only.
    """
    # Sets give O(1) membership tests instead of scanning a list per sample.
    target_tr_ids = set(pd.read_csv(train_csv)['depMapID'])
    target_te_ids = set(pd.read_csv(test_csv)['depMapID'])
    tr_data = {}
    te_data = {}
    for sample_id, sample_info in data.items():
        if sample_id in target_tr_ids:
            tr_data[sample_id] = sample_info
        elif sample_id in target_te_ids:
            te_data[sample_id] = sample_info
    print(f"Training data count: {len(tr_data)}")
    print(f"Testing data count: {len(te_data)}")
    # The output dicts are disjoint by construction (if/elif above), so leakage
    # has to be detected on the ID lists themselves.
    shared_ids = target_tr_ids & target_te_ids
    if shared_ids:
        print("Error: Training and testing data are not mutually exclusive.")
        print(f"{len(shared_ids)} sample ID(s) appear in both split files.")
    else:
        print("Training and testing data are mutually exclusive.")
    return tr_data, te_data

def save_to_json(data, file_path):
    """Write ``data`` to ``file_path`` as JSON indented by two spaces, then report the path."""
    with open(file_path, 'w') as out_file:
        out_file.write(json.dumps(data, indent=2))
    print(f"Data saved to {file_path}")

In [None]:
# Read the QA JSON file. The name is hard-coded and corresponds to the
# f"target_qa_k{k}_bm{top_bm}.json" pattern with k=10 and top_bm=100.
data = read_json('./data/TargetQA/target_qa_k10_bm100.json')

# Separate the data into training and testing sets
tr_data, te_data = separate_data(data)
# Save the separated data to new JSON files ("_tr" = training, "_te" = testing)
save_to_json(tr_data, './data/TargetQA/target_qa_k10_bm100_tr.json')
save_to_json(te_data, './data/TargetQA/target_qa_k10_bm100_te.json')

##### Formulate QA text

#### 9.2.3 Drug data integration

In [None]:
# Add an all-zero column for every sample in remaining_dti_samples that is missing from maped_methy_df.
# Note: this mutates maped_methy_df in place.
for sample in remaining_dti_samples:
    if sample not in maped_methy_df.columns:
        maped_methy_df[sample] = 0.0
# Keep only the ID column plus the remaining_dti_samples columns, in that order
dti_methy_df = maped_methy_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]
display(dti_methy_df)

# Same zero-fill for mapped_protein_df (also mutated in place)
for sample in remaining_dti_samples:
    if sample not in mapped_protein_df.columns:
        mapped_protein_df[sample] = 0.0
# Keep only the ID column plus the remaining_dti_samples columns, in that order
dti_protein_df = mapped_protein_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]
display(dti_protein_df)

In [None]:
# Add an all-zero column to final_merged_gene_df for every missing sample (in-place mutation)
for sample in remaining_dti_samples:
    if sample not in final_merged_gene_df.columns:
        final_merged_gene_df[sample] = 0.0
# Add an all-zero column to final_merged_transcript_df for every missing sample (in-place mutation)
for sample in remaining_dti_samples:
    if sample not in final_merged_transcript_df.columns:
        final_merged_transcript_df[sample] = 0.0
# Build the gene and transcript frames: ID column plus the remaining_dti_samples columns, in that order
dti_gene_df = final_merged_gene_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]
display(dti_gene_df)
dti_transcript_df = final_merged_transcript_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]
display(dti_transcript_df)
# Drug screen rows restricted to the remaining samples, then left-joined with
# dti_combined_df (on ARXSPAN_ID == depMapID) and final_merged_drug_df (on DRUG_NAME);
# columns are renamed to the BMGC_* / Cell_Line_Name / AUC naming used below.
dti_drug_overlap_df = final_drug_df[final_drug_df['ARXSPAN_ID'].isin(remaining_dti_samples)].reset_index(drop=True)
dti_drug_overlap_df = pd.merge(dti_drug_overlap_df, dti_combined_df, left_on='ARXSPAN_ID', right_on='depMapID', how='left')
dti_drug_overlap_df = pd.merge(dti_drug_overlap_df, final_merged_drug_df, on='DRUG_NAME', how='left').rename(columns={'BMGC_Disease_name': 'BMGC_Disease_Name', 'DRUG_NAME': 'BMGC_Drug_Name', 'BioMedGraphica_Conn_ID': 'BMGC_Drug_ID', 'Name': 'Cell_Line_Name', 'AUC_PUBLISHED': 'AUC'})
# only keep columns ['depMapID', 'Cell_Line_Name', 'BMGC_Drug_ID', 'BMGC_Drug_Name', 'BMGC_Disease_ID', 'BMGC_Disease_Name', 'AUC']
dti_drug_overlap_df = dti_drug_overlap_df[['depMapID', 'Cell_Line_Name', 'BMGC_Drug_ID', 'BMGC_Drug_Name', 'BMGC_Disease_ID', 'BMGC_Disease_Name', 'AUC']]
# Report the number of null values per column (left joins can leave unmatched rows as NaN)
print(dti_drug_overlap_df.isnull().sum())
display(dti_drug_overlap_df)

In [None]:
# Stack the four omics frames row-wise (methylation, gene, transcript, protein)
dti_omics_df = pd.concat([dti_methy_df, dti_gene_df, dti_transcript_df, dti_protein_df], axis=0).reset_index(drop=True)
display(dti_omics_df)
# Left-join onto bmgc_entity_df so the rows follow bmgc_entity_df; entities without
# omics rows get NaN, which is then replaced by 0.0
dti_feat_df = pd.merge(bmgc_entity_df, dti_omics_df, left_on='BioMedGraphica_Conn_ID', right_on='BioMedGraphica_Conn_ID', how='left').drop(columns=['BioMedGraphica_ID', 'Type'])
dti_feat_df = dti_feat_df.fillna(0.0)
display(dti_feat_df)

# Convert dti_feat_df to a numpy array and transpose it: one row per remaining
# (sample) column, one column per row of dti_feat_df
dti_array = dti_feat_df.drop(columns=['BioMedGraphica_Conn_ID']).values.T
print("Shape of dti_array:", dti_array.shape)
# Save the same array to both the DrugQA and DrugScreen folders
np.save('./data/DrugQA/dti_feature.npy', dti_array)
np.save('./data/DrugScreen/dti_feature.npy', dti_array)

In [None]:
# Training drug screen: rows whose depMapID is in dti_train_samples
final_training_drug_screen_df = dti_drug_overlap_df[dti_drug_overlap_df['depMapID'].isin(dti_train_samples)].reset_index(drop=True)
display(final_training_drug_screen_df)
# Unique drug IDs and cell line IDs present in the training rows
training_drug_ids = list(set(final_training_drug_screen_df['BMGC_Drug_ID']))
training_cell_line_ids = list(set(final_training_drug_screen_df['depMapID']))
print("len(training_drug_ids):", len(training_drug_ids))
print("len(training_cell_line_ids):", len(training_cell_line_ids))

# Test drug screen: rows whose depMapID is in dti_test_samples
# (the split is by cell line only; drugs are not filtered here)
final_test_drug_screen_df = dti_drug_overlap_df[dti_drug_overlap_df['depMapID'].isin(dti_test_samples)].reset_index(drop=True)
display(final_test_drug_screen_df)
# Unique drug IDs and cell line IDs present in the test rows
test_drug_ids = list(set(final_test_drug_screen_df['BMGC_Drug_ID']))
test_cell_line_ids = list(set(final_test_drug_screen_df['depMapID']))
print("len(test_drug_ids):", len(test_drug_ids))
print("len(test_cell_line_ids):", len(test_cell_line_ids))

In [None]:
# NOTE(review): this cell repeats the methylation/protein zero-fill and selection
# done earlier in this section; only the print statements are new.
print("The overlapped samples dataframe for methylation between remaining_dti_samples:")
# Add an all-zero column for every sample in remaining_dti_samples that is missing from maped_methy_df
for sample in remaining_dti_samples:
    if sample not in maped_methy_df.columns:
        maped_methy_df[sample] = 0.0
# Keep only the ID column plus the remaining_dti_samples columns
dti_methy_df = maped_methy_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]
display(dti_methy_df)

print("The overlapped samples dataframe for protein between remaining_dti_samples are:")
# Add an all-zero column for every sample in remaining_dti_samples that is missing from mapped_protein_df
for sample in remaining_dti_samples:
    if sample not in mapped_protein_df.columns:
        mapped_protein_df[sample] = 0.0
# Keep only the ID column plus the remaining_dti_samples columns
dti_protein_df = mapped_protein_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]
display(dti_protein_df)

In [None]:
# NOTE(review): repeats the gene/transcript zero-fill and selection done earlier in this section.
# Add an all-zero column to final_merged_gene_df for every missing sample in remaining_dti_samples
for sample in remaining_dti_samples:
    if sample not in final_merged_gene_df.columns:
        final_merged_gene_df[sample] = 0.0
# Add an all-zero column to final_merged_transcript_df for every missing sample in remaining_dti_samples
for sample in remaining_dti_samples:
    if sample not in final_merged_transcript_df.columns:
        final_merged_transcript_df[sample] = 0.0
# Build the gene and transcript frames: ID column plus the remaining_dti_samples columns
dti_gene_df = final_merged_gene_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]
display(dti_gene_df)
dti_transcript_df = final_merged_transcript_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]
display(dti_transcript_df)

In [None]:
# Stack the four omics frames row-wise (methylation, gene, transcript, protein)
dti_omics_df = pd.concat([dti_methy_df, dti_gene_df, dti_transcript_df, dti_protein_df], axis=0).reset_index(drop=True)
display(dti_omics_df)
# Left-join onto bmgc_entity_df so rows follow bmgc_entity_df; unmatched entities become 0.0
dti_df = pd.merge(bmgc_entity_df, dti_omics_df, left_on='BioMedGraphica_Conn_ID', right_on='BioMedGraphica_Conn_ID', how='left').drop(columns=['BioMedGraphica_ID', 'Type'])
dti_df = dti_df.fillna(0.0)
display(dti_df)

In [None]:
# Select the features from dti_df: drop the ID column and transpose, giving one
# row per remaining (sample) column of dti_df
dti_feat = dti_df.drop(columns=['BioMedGraphica_Conn_ID']).values.T
# Save the features to .npy files (same array written to both folders)
np.save('./data/DrugQA/dti_feature.npy', dti_feat)
np.save('./data/DrugScreen/dti_feature.npy', dti_feat)
# Print the shapes of the features
print("Shape of dti_feat:", dti_feat.shape)

# Create a dictionary mapping each depMapID to its corresponding row index (default DataFrame index)
# NOTE(review): the rows of dti_feat follow the order of remaining_dti_samples; this mapping
# presumably assumes remaining_dti_samples_df has a default 0..n-1 index in that same order — confirm.
dti_sample_index_dict = dict(zip(remaining_dti_samples_df['depMapID'], remaining_dti_samples_df.index))
print(dti_sample_index_dict)

##### 9.2.3.1 Drug Screen Integration

In [None]:
# For final_training_drug_screen_df, only keep the columns ['depMapID', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']
final_training_drug_screen_dfc = final_training_drug_screen_df[['depMapID', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']].copy()
# Map depMapID to dti_sample_index, BMGC_Drug_ID to node_index, BMGC_Disease_ID to node_index.
# Series.map yields NaN for any key that is missing from the dictionary.
final_training_drug_screen_dfc['dti_sample_index'] = final_training_drug_screen_dfc['depMapID'].map(dti_sample_index_dict)
final_training_drug_screen_dfc['BMGC_Drug_ID'] = final_training_drug_screen_dfc['BMGC_Drug_ID'].map(node_index_dict)
final_training_drug_screen_dfc['BMGC_Disease_ID'] = final_training_drug_screen_dfc['BMGC_Disease_ID'].map(node_index_dict)
final_training_drug_screen_dfc = final_training_drug_screen_dfc.drop(columns=['depMapID'])
# reorder the columns to ['dti_sample_index', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']
final_training_drug_screen_dfc = final_training_drug_screen_dfc[['dti_sample_index', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']]
display(final_training_drug_screen_dfc)

# Same steps for final_test_drug_screen_df: keep ['depMapID', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']
final_test_drug_screen_dfc = final_test_drug_screen_df[['depMapID', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']].copy()
# Map depMapID to dti_sample_index, BMGC_Drug_ID to node_index, BMGC_Disease_ID to node_index
final_test_drug_screen_dfc['dti_sample_index'] = final_test_drug_screen_dfc['depMapID'].map(dti_sample_index_dict)
final_test_drug_screen_dfc['BMGC_Drug_ID'] = final_test_drug_screen_dfc['BMGC_Drug_ID'].map(node_index_dict)
final_test_drug_screen_dfc['BMGC_Disease_ID'] = final_test_drug_screen_dfc['BMGC_Disease_ID'].map(node_index_dict)
final_test_drug_screen_dfc = final_test_drug_screen_dfc.drop(columns=['depMapID'])
# reorder the columns to ['dti_sample_index', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']
final_test_drug_screen_dfc = final_test_drug_screen_dfc[['dti_sample_index', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']]
display(final_test_drug_screen_dfc)

In [None]:
# Convert final_training_drug_screen_dfc and final_test_drug_screen_dfc to numpy arrays.
# Column order is ['dti_sample_index', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC'].
final_training_drug_screen_array = final_training_drug_screen_dfc.values
final_test_drug_screen_array = final_test_drug_screen_dfc.values
# Save the numpy arrays to .npy files
np.save('./data/DrugScreen/final_training_drug_screen.npy', final_training_drug_screen_array)
np.save('./data/DrugScreen/final_test_drug_screen.npy', final_test_drug_screen_array)
print("Shape of final_training_drug_screen_array:", final_training_drug_screen_array.shape)
print("Shape of final_test_drug_screen_array:", final_test_drug_screen_array.shape)

##### 9.2.3.2 DrugQA Integration

In [None]:
# Select the remaining_dti_samples columns from gene_df.
# Only samples that actually exist as gene_df columns are included; unlike the
# zero-filled frames built earlier, missing samples are simply left out here.
# Note: this rebinds dti_gene_df, replacing the frame of the same name built earlier.
available_samples = [sample for sample in remaining_dti_samples if sample in gene_df.columns]
dti_gene_df = gene_df[['gene_name'] + sorted(available_samples)].copy()
print(f"Total remaining_dti_samples: {len(remaining_dti_samples)}")
print(f"Available samples in gene_df: {len(available_samples)}")
print(f"Missing samples: {len(remaining_dti_samples) - len(available_samples)}")
display(dti_gene_df)
# Protein table with the combined 'Names_and_IDs' text used by the extract_* helpers
bmgc_protein_llmnameid_combined_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein_LLM_Name_ID_Combined.csv')
display(bmgc_protein_llmnameid_combined_df)

def extract_dti_gn_info(dti_gene_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k=10):
    """Return the top-k genes of one sample and the proteins linked to them.

    Args:
        dti_gene_df: frame with a 'gene_name' column and one value column per sample.
        bmgc_gene_df: gene table providing 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol'.
        gene_transcript_protein_entity_df: link table with 'BMGC_GN_ID' and 'BMGC_PT_ID'.
        bmgc_protein_llmnameid_combined_df: protein table with 'BioMedGraphica_Conn_ID',
            'BioMedGraphica_ID' and 'Names_and_IDs'.
        sample_ach_name: name of the sample column to rank by.
        k: number of rows with the largest values to keep.

    Returns:
        (gene names, protein BMGC IDs, protein 'Names_and_IDs' strings with ' | '
        rewritten as ' or '). All three are the string "non-existed" when the
        sample column is absent from ``dti_gene_df``.
    """
    # Guard clause: unknown sample -> sentinel triple
    if sample_ach_name not in dti_gene_df.columns:
        sentinel = "non-existed"
        return sentinel, sentinel, sentinel

    # k largest values of the sample column, ordered from highest to lowest
    ranked_genes = (
        dti_gene_df
        .nlargest(k, sample_ach_name)[['gene_name', sample_ach_name]]
        .sort_values(by=sample_ach_name, ascending=False)
        .reset_index(drop=True)
    )

    # Attach the gene BioMedGraphica_Conn_ID through the HGNC symbol
    gene_id_lookup = bmgc_gene_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].copy()
    matched_genes = gene_id_lookup.merge(ranked_genes, left_on='HGNC_Symbol', right_on='gene_name', how='inner')

    # Gene -> protein links for the matched genes
    gene_protein_links = gene_transcript_protein_entity_df.merge(
        matched_genes, left_on='BMGC_GN_ID', right_on='BioMedGraphica_Conn_ID', how='inner'
    ).drop(columns=['BioMedGraphica_Conn_ID', 'gene_name'])

    # Protein descriptions for the linked proteins
    protein_info = gene_protein_links.merge(
        bmgc_protein_llmnameid_combined_df, left_on='BMGC_PT_ID', right_on='BioMedGraphica_Conn_ID', how='inner'
    ).drop(columns=['BioMedGraphica_Conn_ID', 'BioMedGraphica_ID', sample_ach_name])

    return (
        ranked_genes['gene_name'].tolist(),
        protein_info['BMGC_PT_ID'].tolist(),
        protein_info['Names_and_IDs'].replace(r' \| ', ' or ', regex=True).tolist(),
    )

# Example usage: top-10 genes and their linked proteins for a single sample
sample_ach_name = 'ACH-000002'
k=10
top_k_gene_hgnc_name_list, top_k_gene_protein_bmgc_id_list, top_k_gene_protein_bmgc_llmnameid_list = extract_dti_gn_info(dti_gene_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)
print(f"Top {k} Gene HGNC Names:", top_k_gene_hgnc_name_list)
print(f"Top {k} Gene Protein BMGC IDs:", top_k_gene_protein_bmgc_id_list)
print(f"Top {k} Gene Protein BMGC LLM Name IDs:", top_k_gene_protein_bmgc_llmnameid_list)

In [None]:
# Select the remaining_dti_samples columns from raw_transcript_df.
# Only samples that actually exist as raw_transcript_df columns are included (no zero-fill).
# Note: this rebinds dti_transcript_df, replacing the frame of the same name built earlier.
available_transcript_samples = [sample for sample in remaining_dti_samples if sample in raw_transcript_df.columns]
dti_transcript_df = raw_transcript_df[['gene_name'] + sorted(available_transcript_samples)].copy()
print(f"Total remaining_dti_samples: {len(remaining_dti_samples)}")
print(f"Available samples in raw_transcript_df: {len(available_transcript_samples)}")
print(f"Missing samples: {len(remaining_dti_samples) - len(available_transcript_samples)}")
display(dti_transcript_df)

def extract_dti_ts_info(dti_transcript_df, bmgc_transcript_df, transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k=10):
    """Return the top-k transcript rows of one sample and the proteins linked to them.

    Args:
        dti_transcript_df: frame with a 'gene_name' column and one value column per sample.
        bmgc_transcript_df: transcript table providing 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol'.
        transcript_protein_entity_df: link table with 'BMGC_TS_ID' and 'BMGC_PT_ID'.
        bmgc_protein_llmnameid_combined_df: protein table with 'BioMedGraphica_Conn_ID',
            'BioMedGraphica_ID' and 'Names_and_IDs'.
        sample_ach_name: name of the sample column to rank by.
        k: number of rows with the largest values to keep.

    Returns:
        (names from 'gene_name', protein BMGC IDs, protein 'Names_and_IDs' strings
        with ' | ' rewritten as ' or '). All three are the string "non-existed"
        when the sample column is absent from ``dti_transcript_df``.
    """
    # Guard clause: unknown sample -> sentinel triple
    if sample_ach_name not in dti_transcript_df.columns:
        sentinel = "non-existed"
        return sentinel, sentinel, sentinel

    # k largest values of the sample column, ordered from highest to lowest
    ranked_transcripts = (
        dti_transcript_df
        .nlargest(k, sample_ach_name)[['gene_name', sample_ach_name]]
        .sort_values(by=sample_ach_name, ascending=False)
        .reset_index(drop=True)
    )

    # Attach the transcript BioMedGraphica_Conn_ID through the HGNC symbol
    transcript_id_lookup = bmgc_transcript_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].copy()
    matched_transcripts = transcript_id_lookup.merge(ranked_transcripts, left_on='HGNC_Symbol', right_on='gene_name', how='inner')

    # Transcript -> protein links for the matched transcripts
    transcript_protein_links = transcript_protein_entity_df.merge(
        matched_transcripts, left_on='BMGC_TS_ID', right_on='BioMedGraphica_Conn_ID', how='inner'
    ).drop(columns=['BioMedGraphica_Conn_ID', 'gene_name'])

    # Protein descriptions for the linked proteins
    protein_info = transcript_protein_links.merge(
        bmgc_protein_llmnameid_combined_df, left_on='BMGC_PT_ID', right_on='BioMedGraphica_Conn_ID', how='inner'
    ).drop(columns=['BioMedGraphica_Conn_ID', 'BioMedGraphica_ID', sample_ach_name])

    return (
        ranked_transcripts['gene_name'].tolist(),
        protein_info['BMGC_PT_ID'].tolist(),
        protein_info['Names_and_IDs'].replace(r' \| ', ' or ', regex=True).tolist(),
    )

# Example usage: top-10 transcript rows and their linked proteins for a single sample
sample_ach_name = 'ACH-000002'
k=10
top_k_transcript_hgnc_name_list, top_k_transcript_protein_bmgc_id_list, top_k_transcript_protein_bmgc_llmnameid_list = extract_dti_ts_info(dti_transcript_df, bmgc_transcript_df, transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)
print(f"Top {k} Transcript HGNC Names:", top_k_transcript_hgnc_name_list)
print(f"Top {k} Transcript Protein BMGC IDs:", top_k_transcript_protein_bmgc_id_list)
print(f"Top {k} Transcript Protein BMGC LLM Name IDs:", top_k_transcript_protein_bmgc_llmnameid_list)

In [None]:
# Load protein mapping data with BioMedGraphica IDs, Uniprot IDs, and HGNC symbols
bmg_protein_all_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein.csv')
bmg_protein_all_df = bmg_protein_all_df[['BioMedGraphica_Conn_ID', 'Uniprot_ID', 'HGNC_Symbol']].copy()
display(bmg_protein_all_df)

# Rename the columns of the raw protein data using protein_map_dict (defined in an earlier cell)
raw_protein_map_df = raw_protein_df.rename(columns=protein_map_dict)

# Join protein data with the HGNC symbol annotations (Uniprot_Acc == Uniprot_ID; unmatched rows are dropped)
symbol_protein_map_df = pd.merge(raw_protein_map_df, bmg_protein_all_df, left_on='Uniprot_Acc', right_on='Uniprot_ID', how='inner')

# Reorganize columns: protein identifiers first, then every other column in sorted order.
# Note: "every other column" also contains 'BioMedGraphica_Conn_ID', not only sample columns.
identifier_cols = ['Uniprot_ID', 'Uniprot_Acc', 'HGNC_Symbol']
expression_cols = sorted(set(symbol_protein_map_df.columns) - set(identifier_cols))
symbol_protein_map_df = symbol_protein_map_df[identifier_cols + expression_cols]

# Only include samples that actually exist in symbol_protein_map_df columns
available_protein_samples = [sample for sample in remaining_dti_samples if sample in symbol_protein_map_df.columns]
print(f"Total remaining_dti_samples: {len(remaining_dti_samples)}")
print(f"Available samples in symbol_protein_map_df: {len(available_protein_samples)}")
print(f"Missing samples: {len(remaining_dti_samples) - len(available_protein_samples)}")

# Extract protein data for DTI samples with HGNC symbols
# (rebinds dti_protein_df, replacing the frame of the same name built earlier)
dti_protein_df = symbol_protein_map_df[['HGNC_Symbol'] + sorted(available_protein_samples)].copy()

# Handle multiple HGNC symbols per protein (semicolon-separated): one row per symbol,
# each carrying a copy of the original row's sample values
dti_protein_df = dti_protein_df.assign(
    HGNC_Symbol=dti_protein_df['HGNC_Symbol'].str.split(';')
).explode('HGNC_Symbol')

# Clean up HGNC symbols: remove whitespace and filter out empty entries
dti_protein_df['HGNC_Symbol'] = dti_protein_df['HGNC_Symbol'].str.strip()
dti_protein_df = dti_protein_df[
    dti_protein_df['HGNC_Symbol'].notna() & 
    (dti_protein_df['HGNC_Symbol'] != '')
].reset_index(drop=True)

display(dti_protein_df)

def extract_dti_pt_info(dti_protein_df, bmgc_protein_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k=10):
    """Return the top-k proteins of one sample together with their BMGC descriptions.

    Args:
        dti_protein_df: frame with an 'HGNC_Symbol' column and one value column per sample.
        bmgc_protein_df: protein table providing 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol'.
        bmgc_protein_llmnameid_combined_df: protein table with 'BioMedGraphica_Conn_ID',
            'BioMedGraphica_ID' and 'Names_and_IDs'.
        sample_ach_name: name of the sample column to rank by.
        k: number of rows with the largest values to keep.

    Returns:
        (HGNC symbols, protein BMGC IDs, protein 'Names_and_IDs' strings with both
        ' | ' and ';' rewritten as ' or '). All three are the string "non-existed"
        when the sample column is absent from ``dti_protein_df``.
    """
    # Guard clause: unknown sample -> sentinel triple
    if sample_ach_name not in dti_protein_df.columns:
        sentinel = "non-existed"
        return sentinel, sentinel, sentinel

    # k largest values of the sample column, ordered from highest to lowest
    ranked_proteins = (
        dti_protein_df
        .nlargest(k, sample_ach_name)[['HGNC_Symbol', sample_ach_name]]
        .sort_values(by=sample_ach_name, ascending=False)
        .reset_index(drop=True)
    )

    # Attach the protein BioMedGraphica_Conn_ID through the shared HGNC symbol column
    protein_id_lookup = bmgc_protein_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].copy()
    matched_proteins = protein_id_lookup.merge(ranked_proteins, on='HGNC_Symbol', how='inner')

    # Protein descriptions, joined on the shared BioMedGraphica_Conn_ID column
    protein_info = matched_proteins.merge(
        bmgc_protein_llmnameid_combined_df, on='BioMedGraphica_Conn_ID', how='inner'
    ).drop(columns=['BioMedGraphica_ID', sample_ach_name])

    separators = [r' \| ', r';']
    return (
        ranked_proteins['HGNC_Symbol'].tolist(),
        protein_info['BioMedGraphica_Conn_ID'].tolist(),
        protein_info['Names_and_IDs'].replace(separators, ' or ', regex=True).tolist(),
    )

# Example usage: top-10 proteins and their BMGC descriptions for a single sample
sample_ach_name = 'ACH-000008'
k = 10
top_k_protein_hgnc_name_list, top_k_protein_bmgc_id_list, top_k_protein_bmgc_llmnameid_list = extract_dti_pt_info(dti_protein_df, bmgc_protein_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)
print(f"Top {k} Protein HGNC Names:", top_k_protein_hgnc_name_list)
print(f"Top {k} Protein BMGC IDs:", top_k_protein_bmgc_id_list)
print(f"Top {k} Protein BMGC LLM Name IDs:", top_k_protein_bmgc_llmnameid_list)

In [None]:
def extract_drug_protein(selected_sample_drug_bmgc_id, edge_index, 
                    node_index_df, nodeid_index_dict, index_nodeid_dict):
    """Find the protein nodes directly connected to a drug node.

    Args:
        selected_sample_drug_bmgc_id: BMGC ID of the drug node.
        edge_index: array indexed as ``edge_index[0, :]`` (source node indices)
            and ``edge_index[1, :]`` (target node indices).
        node_index_df: node table with 'Index' and 'Type' columns; rows whose
            'Type' equals 'Protein' define the protein nodes.
        nodeid_index_dict: mapping BMGC ID -> node index.
        index_nodeid_dict: mapping node index -> BMGC ID.

    Returns:
        (drug_protein_index, drug_protein_bmgc_id): the sorted, de-duplicated
        protein node indices as plain Python ints (safe for JSON serialization),
        and the matching BMGC IDs in the same order.

    Raises:
        KeyError: if the drug ID is missing from ``nodeid_index_dict``.
    """
    drug_node_index = nodeid_index_dict[selected_sample_drug_bmgc_id]
    source_nodes = edge_index[0, :]
    target_nodes = edge_index[1, :]
    # Neighbors in either direction: sources of edges pointing at the drug,
    # plus targets of edges leaving the drug.
    neighbor_nodes = np.concatenate([
        source_nodes[target_nodes == drug_node_index],
        target_nodes[source_nodes == drug_node_index],
    ])
    # Indices of all nodes typed as protein
    protein_nodes = node_index_df.loc[node_index_df['Type'] == 'Protein', 'Index'].to_numpy()
    # intersect1d returns sorted unique values; cast to Python int so the element
    # type no longer depends on which side of a set intersection was smaller.
    drug_protein_index = [int(node) for node in np.intersect1d(neighbor_nodes, protein_nodes)]
    # Map protein index to BMGC id
    drug_protein_bmgc_id = [index_nodeid_dict[node] for node in drug_protein_index]
    return drug_protein_index, drug_protein_bmgc_id

# Example usage: proteins directly connected to a single drug node
selected_sample_drug_bmgc_id = 'BMGC_DG00001'

# edge_index, nodes_index_data, node_index_dict and index_node_dict come from earlier cells
drug_protein_index, drug_protein_bmgc_id = extract_drug_protein(
    selected_sample_drug_bmgc_id=selected_sample_drug_bmgc_id,
    edge_index=edge_index,
    node_index_df=nodes_index_data,
    nodeid_index_dict=node_index_dict,
    index_nodeid_dict=index_node_dict
)

print("Drug Protein Index:", drug_protein_index)
print("Drug Protein BMGC ID:", drug_protein_bmgc_id)

##### Formulate the drug-related protein JSON

In [None]:
import json
from tqdm import tqdm

# Unique drug IDs present in the drug screen table, in sorted order
dti_drug_ids = sorted(list(set(dti_drug_overlap_df['BMGC_Drug_ID'])))

# Maps drug BMGC ID -> {'drug_protein_index': [...], 'drug_protein_bmgc_id': [...]}
drug_kg_protein_bmgc_dict = {}
# Only the first `max_prints` drugs are printed in full to keep the output short
print_count = 0
max_prints = 5

for i, drug_bmgc_id in enumerate(tqdm(dti_drug_ids, desc="Processing drug-protein relationships")):
    drug_protein_index, drug_protein_bmgc_id = extract_drug_protein(
        selected_sample_drug_bmgc_id=drug_bmgc_id,
        edge_index=edge_index,
        node_index_df=nodes_index_data,
        nodeid_index_dict=node_index_dict,
        index_nodeid_dict=index_node_dict
    )
    # NOTE(review): hgnc_dict and drug_protein_hgnc_list are computed but never used or
    # stored in this cell — confirm whether the HGNC symbols were meant to be saved too.
    hgnc_dict, drug_protein_hgnc_list = bmgc_pt_id_to_hgnc(drug_protein_bmgc_id, bmgc_protein_df)
    
    # Only print first 5 entries, then a single "silent" notice
    if print_count < max_prints:
        print(f"Drug BMGC ID: {drug_bmgc_id}, Protein Index: {drug_protein_index}, Protein BMGC ID: {drug_protein_bmgc_id}")
        print_count += 1
    elif print_count == max_prints:
        print("... (remaining entries processed silently)")
        print_count += 1
    
    # Convert NumPy integers to Python integers for JSON serialization
    drug_kg_protein_bmgc_dict[drug_bmgc_id] = {
        'drug_protein_index': [int(x) for x in drug_protein_index],  # Convert numpy ints to Python ints
        'drug_protein_bmgc_id': drug_protein_bmgc_id
    }
    
    # Save every 10 iterations (the whole dictionary is rewritten each time)
    if (i + 1) % 10 == 0:
        output_path = "./data/DrugQA/drug_kg_protein_relationships.json"
        with open(output_path, "w") as f:
            json.dump(drug_kg_protein_bmgc_dict, f, indent=2)
        print(f"💾 Auto-saved after processing {i + 1}/{len(dti_drug_ids)} drugs")

# Final save after loop completion (covers the drugs processed since the last auto-save)
output_path = "./data/DrugQA/drug_kg_protein_relationships.json"
with open(output_path, "w") as f:
    json.dump(drug_kg_protein_bmgc_dict, f, indent=2)
    
print(f"✅ Final save completed - processed {len(drug_kg_protein_bmgc_dict)} drugs and saved to: {output_path}")

##### Formulate the sample-related JSON

In [None]:
def drug_qa_sample_info(sample_ach_name, selected_sample_disease_bmgc_id, k=10):
    """Collect the omics-derived and knowledge-graph-derived inputs for one sample.

    Progress is printed at every stage. The lookup tables handed to the helper
    functions (``dti_gene_df``, ``bmgc_protein_df``, ``nodes_index_data``, ...) are
    read from the notebook's global namespace, and the graph edge index is
    reloaded from ``./data/DrugQA/edge_index.npy`` on every call.

    Args:
        sample_ach_name: Sample identifier forwarded to the ``extract_dti_*`` helpers.
        selected_sample_disease_bmgc_id: Disease id forwarded to
            ``extract_kg_related_proteins``.
        k: Number of top entries requested from each ``extract_dti_*`` helper.

    Returns:
        An 18-element tuple, in this order:
        gene (HGNC names, protein BMGC ids, protein LLM name-ids),
        transcript (HGNC names, protein BMGC ids, protein LLM name-ids),
        protein (HGNC names, BMGC ids, LLM name-ids),
        disease protein indices, disease protein BMGC ids,
        PPI node indices, PPI node BMGC ids,
        disease protein HGNC dict, disease protein HGNC list,
        PPI HGNC dict, PPI HGNC list,
        relationship texts.
        The three top-k HGNC name lists and the disease protein HGNC list are
        returned as ``[]`` when the helper produced the string ``"non-existed"``.
    """

    def _empty_if_missing(value):
        # "non-existed" appears to be the helpers' not-found sentinel (TODO confirm);
        # swap it for [] so the lists below can be concatenated.
        return [] if value == "non-existed" else value

    # ---- Omics-derived (LLM) inputs ----
    print(f"Sample ACH Name: {sample_ach_name}")

    print(f"Extracting top {k} gene information for {sample_ach_name}...")
    gene_hgnc, gene_pt_ids, gene_pt_llm = extract_dti_gn_info(
        dti_gene_df, bmgc_gene_df, gene_transcript_protein_entity_df,
        bmgc_protein_llmnameid_combined_df, sample_ach_name, k,
    )

    print(f"Extracting top {k} transcript information for {sample_ach_name}...")
    ts_hgnc, ts_pt_ids, ts_pt_llm = extract_dti_ts_info(
        dti_transcript_df, bmgc_transcript_df, transcript_protein_entity_df,
        bmgc_protein_llmnameid_combined_df, sample_ach_name, k,
    )

    print(f"Extracting top {k} protein information for {sample_ach_name}...")
    pt_hgnc, pt_ids, pt_llm = extract_dti_pt_info(
        dti_protein_df, bmgc_protein_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k,
    )

    # ---- Knowledge-graph inputs ----
    kg_edges = np.load('./data/DrugQA/edge_index.npy')
    print(f"Extracting disease-related proteins index and bmgc id for {selected_sample_disease_bmgc_id} ({sample_ach_name}) ...")
    dis_idx, dis_ids, ppi_idx, ppi_ids = extract_kg_related_proteins(
        selected_sample_disease_bmgc_id, kg_edges, nodes_index_data, node_index_dict, index_node_dict,
    )
    print(f"Knowledge Graph Info: Found {len(dis_idx)} disease-related proteins directly connected to {selected_sample_disease_bmgc_id} and {len(ppi_idx)} proteins in their PPI network")

    print("Mapping disease-related proteins to HGNC symbols...")
    dis_hgnc_map, dis_hgnc = bmgc_pt_id_to_hgnc(dis_ids, bmgc_protein_df)
    print("Mapping PPI-related proteins to HGNC symbols...")
    ppi_hgnc_map, ppi_hgnc = bmgc_pt_id_to_hgnc(ppi_ids, bmgc_protein_df)

    # ---- Relationships among the omics and disease proteins ----
    print("Extracting protein relationships from BMGC...")
    gene_hgnc = _empty_if_missing(gene_hgnc)
    ts_hgnc = _empty_if_missing(ts_hgnc)
    pt_hgnc = _empty_if_missing(pt_hgnc)
    dis_hgnc = _empty_if_missing(dis_hgnc)

    # De-duplicate the pooled HGNC symbols before looking up relationships.
    pooled_hgnc = gene_hgnc + ts_hgnc + pt_hgnc + dis_hgnc
    unique_hgnc = list(set(pooled_hgnc))
    _relationships_df, relationship_texts = find_protein_relationships(unique_hgnc, bmgc_protein_df, bmgc_relation_df)
    print(f"Knowledge Graph Info: Found {len(unique_hgnc)} unique proteins and {len(relationship_texts)} relationships between them")

    return (
        gene_hgnc, gene_pt_ids, gene_pt_llm,
        ts_hgnc, ts_pt_ids, ts_pt_llm,
        pt_hgnc, pt_ids, pt_llm,
        dis_idx, dis_ids, ppi_idx, ppi_ids,
        dis_hgnc_map, dis_hgnc, ppi_hgnc_map, ppi_hgnc, relationship_texts,
    )

# Example usage: run the extractor once for a single sample / disease pair.
sample_ach_name, selected_sample_disease_bmgc_id, k = 'ACH-000002', 'BMGC_DS07934', 10
return_tuples = drug_qa_sample_info(
    sample_ach_name=sample_ach_name,
    selected_sample_disease_bmgc_id=selected_sample_disease_bmgc_id,
    k=k,
)

In [None]:
import json
import os
from tqdm import tqdm

# Run parameters
k = 10             # top-k value passed to drug_qa_sample_info and baked into the file name
save_every_n = 10  # checkpoint interval for the processing loop

# Destination of the QA JSON (the folder is created on demand)
output_dir = "./data/DrugQA"
os.makedirs(output_dir, exist_ok=True)
output_filename = f"drug_qa_k{k}_bm{top_bm}.json"
output_path = os.path.join(output_dir, output_filename)

# Resume from a previous run when its JSON is already on disk
if not os.path.exists(output_path):
    multi_sample_qa_json = {}
    print("No existing JSON found, starting fresh")
else:
    with open(output_path, "r") as f:
        multi_sample_qa_json = json.load(f)
    print(f"Loaded existing JSON with {len(multi_sample_qa_json)} processed samples")

# Sample table: keep only rows whose depMapID is in remaining_dti_samples_df and
# rewrite " | " separators in the disease name as " or ".
dti_sample_info_index = (
    pd.read_csv('./data/process_data/dti_combined_samples.csv')
    .merge(remaining_dti_samples_df['depMapID'], how='inner', on='depMapID')
    .reset_index(drop=True)
    .assign(
        BMGC_Disease_name=lambda frame: frame['BMGC_Disease_name'].replace(r' \| ', ' or ', regex=True)
    )
)
# Prepend a 1-based running "Index" column
dti_sample_info_index.insert(0, 'Index', range(1, len(dti_sample_info_index) + 1))
display(dti_sample_info_index)

In [None]:
def _json_default(obj):
    """Fallback encoder for values the json module cannot serialize natively.

    NumPy scalars and arrays (and pandas Series/Index) expose ``tolist()``, which
    yields built-in numbers or nested lists without truncating floats and
    without failing on multi-element arrays. Anything else is rejected with
    TypeError, which is what json expects from a ``default`` hook; returning the
    object unchanged would surface as a misleading "Circular reference" error.
    """
    if hasattr(obj, "tolist"):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _save_checkpoint(payload, path):
    """Write ``payload`` as JSON to ``path`` without clobbering the old file on failure.

    The data is serialized to a sibling ``.tmp`` file first and only swapped in
    with ``os.replace`` once the dump has finished, so an exception or an
    interrupted kernel mid-write leaves the previous checkpoint intact.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(payload, f, indent=2, default=_json_default)
    os.replace(tmp_path, path)


# Samples already present in the loaded JSON are skipped so the run can resume.
processed_samples = set(multi_sample_qa_json.keys())
print(f"Found {len(processed_samples)} already processed samples")

count = len(processed_samples)
total_to_process = len(dti_sample_info_index)
remaining = total_to_process - count
print(f"Total samples to process: {total_to_process}, already processed: {count}, remaining: {remaining}")

# (sample id, error message) for every sample whose extraction raised
failed_samples = []

# Iterate through the sample info dataframe, skipping already processed samples
for _, row in tqdm(dti_sample_info_index.iterrows(), total=total_to_process):
    sample_ach_name = row["depMapID"]

    # Skip if already processed
    if sample_ach_name in processed_samples:
        continue

    # `count` tracks attempted samples (including ones that later fail).
    count += 1
    target_sample_index = row["Index"]
    cell_line_name = row["Name"]
    disease = row["BMGC_Disease_name"]
    disease_bmgc_id = row["BMGC_Disease_ID"]

    print(f"Processing sample {count}/{total_to_process}: {sample_ach_name} ({cell_line_name})")
    print(f"Sample Index: {target_sample_index}")
    print(f"Sample Disease: {disease}")
    print(f"Sample Disease BMGC ID: {disease_bmgc_id}")

    try:
        (top_k_gene_hgnc, top_k_gene_bmgc, top_k_gene_llm,
         top_k_ts_hgnc, top_k_ts_bmgc, top_k_ts_llm,
         top_k_pt_hgnc, top_k_pt_bmgc, top_k_pt_llm,
         dis_pt_idx, dis_pt_bmgc, ppi_idx, ppi_bmgc,
         dis_pt_hgnc_dict, dis_pt_hgnc,
         ppi_hgnc_dict, ppi_hgnc, relationship_texts) = drug_qa_sample_info(sample_ach_name, disease_bmgc_id, k)

        multi_sample_qa_json[sample_ach_name] = {
            "cell_line_name": cell_line_name,
            "sample_index": target_sample_index,
            "disease": disease,
            "disease_bmgc_id": disease_bmgc_id,
            "input": {
                "top_k_gene": {
                    "hgnc_symbols": top_k_gene_hgnc,
                    "protein_bmgc_ids": top_k_gene_bmgc,
                    "protein_llmname_ids": top_k_gene_llm
                },
                "top_k_transcript": {
                    "hgnc_symbols": top_k_ts_hgnc,
                    "protein_bmgc_ids": top_k_ts_bmgc,
                    "protein_llmname_ids": top_k_ts_llm
                },
                "top_k_protein": {
                    "hgnc_symbols": top_k_pt_hgnc,
                    "protein_bmgc_ids": top_k_pt_bmgc,
                    "protein_llmname_ids": top_k_pt_llm
                },
                "knowledge_graph": {
                    "disease_protein": {
                        "bmgc_ids": dis_pt_bmgc,
                        "hgnc_symbols": dis_pt_hgnc,
                        "indices": dis_pt_idx
                    },
                    "ppi_neighbors": {
                        "bmgc_ids": ppi_bmgc,
                        "hgnc_symbols": ppi_hgnc,
                        "indices": ppi_idx
                    },
                    "protein_relationships": relationship_texts,
                }
            }
        }

    except Exception as e:
        # Best-effort batch: log the failure, remember it for the summary, move on.
        print(f"⚠️ Error processing {sample_ach_name}: {e}")
        failed_samples.append((sample_ach_name, str(e)))
        continue

    # Periodic save every N samples
    if count % save_every_n == 0:
        _save_checkpoint(multi_sample_qa_json, output_path)
        print(f"💾 Auto-saved JSON at {count}/{total_to_process} samples to: {output_path}")
        print(f"Last processed sample: {sample_ach_name}")
        processed = len(multi_sample_qa_json)
        remaining = total_to_process - processed
        print(f"Progress: {processed}/{total_to_process} ({processed/total_to_process*100:.1f}%), Remaining: {remaining}")

# Final save after loop
_save_checkpoint(multi_sample_qa_json, output_path)

print(f"✅ Final JSON saved to: {output_path}")
print(f"Total samples processed: {len(multi_sample_qa_json)}")
if failed_samples:
    print(f"⚠️ {len(failed_samples)} sample(s) failed and were not saved: {[name for name, _ in failed_samples]}")