In [273]:
# We want to build a Transformer model that will eventually emit pathways.
# For that, we download all GSEA pathways.

In [173]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import math 
from collections import Counter
%matplotlib inline

In [187]:
# Function to parse a .gmt file into a {gene set name: [genes]} dictionary
def parse_gmt_file(filepath):
    """Read a tab-delimited .gmt file and return its gene sets.

    Every non-blank line is split on tabs: the first field is used as the
    gene set name, the second field (description) is skipped, and all
    remaining fields are stored as the members of the set.

    Args:
        filepath: Path of the .gmt file to read.

    Returns:
        dict mapping gene set name -> list of member fields, in file order.
        A line with fewer than three fields yields an empty list. If a name
        appears on several lines, the last line wins.
    """
    genesets = {}
    with open(filepath, 'r') as file:
        for line in file:
            stripped = line.strip()
            if not stripped:
                # A blank line would otherwise be stored as a gene set named ''.
                continue
            parts = stripped.split('\t')  # Assuming tab-delimited .gmt file
            genesets[parts[0]] = parts[2:]  # Skip the gene set name and description
    return genesets

In [188]:
file_path = 'db/filtered.genesets.gmt' # Replace with your .gmt file path
genesets = parse_gmt_file(file_path)

# Some entries carry extra comma-separated data after the gene name (presumably
# an expression value); keep only the part before the first comma, then close
# every gene set with the "<|end|>" marker token.
for name, members in genesets.items():
    genesets[name] = [member.split(",")[0] for member in members] + ["<|end|>"]

In [189]:
# Step 1: Pre-process the database to create a mapping dictionary
def create_gene_mapping(db):
    """Build a dictionary from each symbol to its alternative names.

    For every row of ``db`` the value in the 'symbol' column becomes a key.
    Its value is a list with the '|'-separated entries of 'alias_symbol'
    followed by those of 'prev_symbol'; cells that are NA contribute nothing.
    A symbol occurring in several rows keeps only the list of its last row.

    Args:
        db: DataFrame with 'symbol', 'alias_symbol' and 'prev_symbol' columns.

    Returns:
        dict mapping symbol -> list of alias / previous names.
    """
    gene_mapping = {}

    for _, record in db.iterrows():
        alternatives = []
        # Aliases first, previous symbols second.
        for column in ('alias_symbol', 'prev_symbol'):
            cell = record[column]
            if pd.notna(cell):
                alternatives.extend(str(cell).split('|'))
        gene_mapping[record['symbol']] = alternatives

    return gene_mapping

# Function to harmonize genes in genesets_dict
def harmonize_genesets(genesets_dict, gene_mapping, reversed_gene_mapping):
    """Replace alias / previous gene names by their current symbol.

    Args:
        genesets_dict: dict mapping gene set name -> list of gene tokens.
        gene_mapping: container of current symbols; only membership is tested.
        reversed_gene_mapping: dict mapping alias or previous name -> current
            symbol.

    Returns:
        A new dict with the same keys as ``genesets_dict``. Each token is
        handled as follows:
          * present in ``gene_mapping``: kept unchanged;
          * otherwise present in ``reversed_gene_mapping``: replaced by the
            mapped symbol, or dropped when that symbol is already in the
            output list for this gene set;
          * otherwise: kept unchanged.

    Progress is printed about every 10% of the gene sets and once at the end.
    """
    harmonized_genesets_dict = {}
    # Size of the argument, not of the notebook-level `genesets` variable.
    total_genesets = len(genesets_dict)
    # At least 1, so fewer than 10 gene sets does not end in a modulo by zero.
    report_every = max(1, total_genesets // 10)

    for i, (geneset, genes) in enumerate(genesets_dict.items(), start=1):
        harmonized_genes = []
        seen = set()  # mirrors harmonized_genes for constant-time duplicate checks
        for gene in genes:
            if gene in gene_mapping:
                # Already a current symbol
                target = gene
            elif gene in reversed_gene_mapping:
                target = reversed_gene_mapping[gene]
                if target in seen:  # Avoid duplicates
                    continue
            else:
                # No match or mapping: keep the token as it is
                target = gene
            harmonized_genes.append(target)
            seen.add(target)
        harmonized_genesets_dict[geneset] = harmonized_genes

        if i % report_every == 0 or i == total_genesets:  # Update every 10%
            print(f"Processed {i}/{total_genesets} pathways...")

    return harmonized_genesets_dict

# Invert gene_mapping so that an alias or previous name looks up its symbol.
# NOTE(review): gene_mapping is read here but is only assigned in a later cell
# (gene_mapping = create_gene_mapping(db)); that cell has to be run first.
# When several symbols list the same alias, the symbol iterated last wins.
reversed_gene_mapping = {
    alias: harmonized
    for harmonized, aliases in gene_mapping.items()
    for alias in aliases
}


In [190]:
# Pre-process the database to create the mapping
# NOTE(review): `db` is not defined in any cell shown here. create_gene_mapping
# reads its 'symbol', 'alias_symbol' and 'prev_symbol' columns, so it is
# presumably a DataFrame with those columns - confirm where it is loaded.
gene_mapping = create_gene_mapping(db)

In [191]:
# Map alias / previous names in every gene set to current symbols.
# harmonize_genesets prints the progress lines shown below this cell.
genesets_harmonized = harmonize_genesets(genesets, gene_mapping, reversed_gene_mapping)

Processed 35111/351117 pathways...
Processed 70222/351117 pathways...
Processed 105333/351117 pathways...
Processed 140444/351117 pathways...
Processed 175555/351117 pathways...
Processed 210666/351117 pathways...
Processed 245777/351117 pathways...
Processed 280888/351117 pathways...
Processed 315999/351117 pathways...
Processed 351110/351117 pathways...
Processed 351117/351117 pathways...


In [377]:
# Fraction (0-1) of the tokens of each pathway that are keys of gene_mapping.
# The loop runs over genesets_harmonized itself: updated_genesets_harmonized has
# the same keys but is only built in a later cell, so iterating it here would
# fail when the notebook is run top to bottom.
harmonized_gene_perc_per_pathway = {}
for key, genes in genesets_harmonized.items():
    # Count the tokens that are current symbols (keys of gene_mapping)
    harmonized_genes_count = sum(1 for gene in genes if gene in gene_mapping)
    # An empty gene list gets 0.0 instead of raising ZeroDivisionError
    harmonized_gene_perc_per_pathway[key] = (
        harmonized_genes_count / len(genes) if genes else 0.0
    )

In [379]:
# AFTER GENE NOMENCLATURE
# Concatenate the gene lists of all pathways into one flat token list
data = []
for pathway_genes in genesets_harmonized.values():
    data.extend(pathway_genes)
# Tally how often each token occurs
tokens = Counter(data)

len(tokens)

126551

In [367]:
# Keep the tokens that occur more than 5 times across all pathways
filtered_genes = [gene for gene in tokens if tokens[gene] > 5]

In [373]:
# Convert filtered_genes to a set for faster lookup
# (a set membership test does not scan the whole collection the way a list does)
filtered_genes_set = set(filtered_genes)

In [374]:
# Per pathway, keep only the tokens that passed the frequency filter.
# Membership is tested against the set, not the list, to keep lookups fast.
updated_genesets_harmonized = {
    geneset: [gene for gene in genes if gene in filtered_genes_set]
    for geneset, genes in genesets_harmonized.items()
}

In [397]:
# Step 2: Function to compute the percentage of harmonized genes in each geneset
def perc_harmonized_genes_in_genesets(genesets_dict, map, ignore=()):
    """Return, per gene set, the percentage of tokens that are found in ``map``.

    Args:
        genesets_dict: dict mapping gene set name -> list of tokens.
        map: container of harmonized symbols; only membership is tested.
            (The name shadows the builtin ``map``; it is kept so existing
            callers keep working.)
        ignore: tokens to leave out of both the count and the denominator,
            e.g. a sequence-end marker. Empty by default, which counts every
            token.

    Returns:
        dict mapping gene set name -> percentage in the range 0-100. A gene
        set with no tokens left to count gets 0.0.
    """
    count_dict = {}
    for geneset, genes in genesets_dict.items():
        counted = [gene for gene in genes if gene not in ignore]
        if not counted:
            # Nothing to count: report 0.0 rather than dividing by zero
            count_dict[geneset] = 0.0
            continue

        # Number of tokens that are harmonized symbols
        harmonized_genes_count = sum(1 for gene in counted if gene in map)

        # Store the percentage for this geneset
        count_dict[geneset] = harmonized_genes_count/len(counted)*100

    return count_dict

# NOTE(review): the gene lists still carry the "<|end|>" marker appended when
# the gene sets were loaded (assuming it survived the frequency filter). Unless
# gene_mapping happens to contain that token it counts as non-harmonized, so
# every percentage is lower than the share of real genes - confirm this is
# intended before thresholding on these values.
perc = perc_harmonized_genes_in_genesets(updated_genesets_harmonized, gene_mapping)

In [411]:
# AFTER GENE NOMENCLATURE and TOKEN filtering
# Concatenate the filtered gene lists of all pathways into one flat token list
data = []
for pathway_genes in updated_genesets_harmonized.values():
    data.extend(pathway_genes)
# Tally how often each token occurs
tokens = Counter(data)

len(tokens)

55340

In [405]:
# Keep the pathways whose harmonized percentage is at least 90.
# selected_keys maps pathway name -> percentage; selected_genesets maps the
# same names to their filtered gene lists.
selected_keys = {}
selected_genesets = {}
for pathway, percentage in perc.items():
    if percentage >= 90:
        selected_keys[pathway] = percentage
        if pathway in updated_genesets_harmonized:
            selected_genesets[pathway] = updated_genesets_harmonized[pathway]

In [412]:
# AFTER GENE NOMENCLATURE, TOKEN filtering and removal of low-percentage pathways
# Concatenate the gene lists of the selected pathways into one flat token list
data = []
for pathway_genes in selected_genesets.values():
    data.extend(pathway_genes)
# Tally how often each token occurs
tokens = Counter(data)

len(tokens)

44902

In [413]:
# Number of pathways before vs. after keeping only those with percentage >= 90
print(f"Before sample size: {len(updated_genesets_harmonized)}\nAfter sample size: {len(selected_genesets)}")

Before sample size: 351117
After sample size: 263730


In [414]:
# Vocabulary lookup tables between tokens (the keys of `tokens`) and integer
# ids, numbered in the iteration order of `tokens`
itos = dict(enumerate(tokens))
stoi = {token: index for index, token in itos.items()}
def encode(s):
    """Encoder: turn a sequence of tokens into the list of their integer ids.

    Ids are looked up in the module-level ``stoi`` table; a token missing
    from it raises KeyError.
    """
    return list(map(stoi.__getitem__, s))
def decode(l, sep=" "):
    """Decoder: turn a list of integer ids back into a string of tokens.

    Ids are looked up in the module-level ``itos`` table; an id missing from
    it raises KeyError.

    Args:
        l: iterable of integer ids.
        sep: string placed between consecutive tokens. The tokens are whole
            names rather than single characters, so they are separated by a
            space by default to stay distinguishable; pass ``sep=""`` to
            concatenate them directly.

    Returns:
        The decoded tokens joined by ``sep``.
    """
    return sep.join(itos[i] for i in l)

In [415]:
# Split the flat token list: the first 95% for training, the rest for validation
n = len(data)
split_at = int(n*0.95)
train_data, val_data = data[:split_at], data[split_at:]

# Convert both splits to integer ids
train_ids, val_ids = encode(train_data), encode(val_data)

# Report the vocabulary size and the size of each split
vocab_size = len(tokens)
print(f"vocab size: {vocab_size:,}")
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

vocab size: 44,902
train has 47,316,670 tokens
val has 2,490,352 tokens


In [529]:
# Collect the selected pathways whose name contains the substring "GSE149609"
filtered_dict = {key: value for key, value in selected_genesets.items() if "GSE149609" in key}

{}

In [552]:
# Number of "<|end|>" markers in the validation split (one is appended per pathway)
val_data.count('<|end|>')

10019

In [556]:
import random
import pickle  # imported here because the only other import is in a later cell

# NOTE: despite its name this is not a random sample. It is the last 1000
# pathways of selected_genesets in insertion order, so the file content is
# deterministic for a given selected_genesets.
random_validation_set = dict(list(selected_genesets.items())[-1000:])

# Save the held-out pathways for later use
with open('random_val_1000.pkl', 'wb') as f:
    pickle.dump(random_validation_set, f)

In [419]:
import os
import pickle

# __file__ is not defined inside a notebook kernel; fall back to '' so the
# files are written relative to the current working directory.
out_dir = os.path.dirname(__file__) if '__file__' in globals() else ''

# Ids are stored as uint16; refuse to export if the largest id would not fit
# instead of letting the cast corrupt the data.
max_id = int(np.iinfo(np.uint16).max)
if vocab_size - 1 > max_id:
    raise ValueError(
        f"vocab_size {vocab_size:,} does not fit in uint16 (max id {max_id:,})"
    )

# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(out_dir, 'train.bin'))
val_ids.tofile(os.path.join(out_dir, 'val.bin'))

# save the meta information as well, to help us encode/decode later
meta = {
    'vocab_size': vocab_size,
    'itos': itos,
    'stoi': stoi,
}
with open(os.path.join(out_dir, 'meta.pkl'), 'wb') as f:
    pickle.dump(meta, f)