# Embedding Novel Species

This notebook will create the files you need to embed a novel species that wasn't included in the training data.

To start, you will need to download the ESM2 protein embeddings and the reference proteome for the species.

You can find precalculated ESM2 protein embeddings for many species [here](https://drive.google.com/drive/folders/1_Dz7HS5N3GoOAG6MdhsXWY1nwLoN13DJ?usp=drive_link).

You can download reference proteomes from [Ensembl](https://useast.ensembl.org/info/about/species.html).

If there is no protein embedding for the species you are interested in, you can request one via GitHub or email, or you can create it yourself by following the instructions [here](https://github.com/snap-stanford/SATURN/tree/main/protein_embeddings).

In [1]:
import numpy as np
import pickle as pkl
import pandas as pd

In [2]:
SPECIES_NAME = "chicken" # shorthand name for this species, used in arguments and file names

# Path to the species proteome
SPECIES_PROTEIN_FASTA_PATH = "../../../SATURN/protein_embeddings/data/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.fa"

# Path to the ESM2 Embeddings
SPECIES_PROTEIN_EMBEDDINGS_PATH = "../model_files/protein_embeddings/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.gene_symbol_to_embedding_ESM2.pt"

# Assembly name; this must match the primary_assembly name in the FASTA headers
ASSEMBLY_NAME = "bGalGal1.mat.broiler.GRCg7b"
# NCBI Taxonomy ID, used as the random seed. Please set it so that anyone else who
# embeds the same species gets the same randomly generated chromosome tokens.
TAXONOMY_ID = 9031

You can view the FASTA format below. Please confirm that the primary_assembly name in the headers matches `ASSEMBLY_NAME`.

In [3]:
!head {SPECIES_PROTEIN_FASTA_PATH}

>ENSGALP00010000002.1 pep primary_assembly:bGalGal1.mat.broiler.GRCg7b:MT:2824:3798:1 gene:ENSGALG00010000007.1 transcript:ENSGALT00010000007.1 gene_biotype:protein_coding transcript_biotype:protein_coding gene_symbol:ND1 description:NADH dehydrogenase subunit 1 [Source:NCBI gene (formerly Entrezgene);Acc:63549479]
MTLPTLTNLLIMTLSYILPILIAVAFLTLVERKILSYMQARKGPNIVGPFGLLQPVADGV
KLFIKEPIRPSTSSPFLFIITPILALLLALTIWVPLPLPFPLADLNLGLLFLLAMSSLTV
YSLLWSGWASNSKYALIGALRAVAQTISYEVTLAIILLSTIMLSGNYTLSTLAITQEPIY
LIFSAWPLAMMWYISTLAETNRAPFDLTEGESELVSGFNVEYAAGPFAMFFLAEYANIML
MNTLTTVLFLNPSFLNLPPELFPIALATKTLLLSSSFLWIRASYPRFRYDQLMHLLWKNF
LPLTLALCLWHTSMPISYAGLPPI
>ENSGALP00010000003.1 pep primary_assembly:bGalGal1.mat.broiler.GRCg7b:MT:4015:5053:1 gene:ENSGALG00010000011.1 transcript:ENSGALT00010000011.1 gene_biotype:protein_coding transcript_biotype:protein_coding gene_symbol:ND2 description:NADH dehydrogenase subunit 2 [Source:NCBI gene (formerly Entrezgene);Acc:63549482]
MNPHAKLICTVSLIMGTSITISSNHWIL
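
To make the header layout concrete, here is a minimal sketch of how one header line breaks down into the fields used later in this notebook. The helper name `parse_ensembl_header` is hypothetical, and the sketch simplifies the notebook's logic: it accepts whichever of the `chromosome:`, `primary_assembly:`, or `scaffold:` tokens appears first instead of trying them in a fixed order. The header is the first one shown above, with the description shortened.

```python
def parse_ensembl_header(header: str) -> dict:
    """Extract assembly, chromosome, start, and gene symbol from one FASTA header."""
    fields = {}
    for token in header.lstrip(">").split():
        key, sep, value = token.partition(":")
        if not sep:
            continue  # plain words such as "pep" carry no key:value pair
        if key == "gene_symbol":
            fields["gene_symbol"] = value  # value may itself contain colons
        elif key in ("chromosome", "primary_assembly", "scaffold") and "chromosome" not in fields:
            # Location token layout: assembly:chromosome:start:end:strand
            assembly, chrom, start = value.split(":")[:3]
            fields.update(assembly=assembly, chromosome=chrom, start=int(start))
    return fields


header = (
    ">ENSGALP00010000002.1 pep "
    "primary_assembly:bGalGal1.mat.broiler.GRCg7b:MT:2824:3798:1 "
    "gene:ENSGALG00010000007.1 transcript:ENSGALT00010000007.1 "
    "gene_biotype:protein_coding transcript_biotype:protein_coding "
    "gene_symbol:ND1 description:NADH dehydrogenase subunit 1"
)
parsed = parse_ensembl_header(header)
print(parsed)
```

If `parsed["assembly"]` differs from `ASSEMBLY_NAME`, the assembly name set in the configuration cell probably needs updating.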

In [4]:
species_to_paths = {
    SPECIES_NAME: SPECIES_PROTEIN_FASTA_PATH,
}

species_to_ids = {
    SPECIES_NAME: ASSEMBLY_NAME,
}

In [5]:
all_pos_def = []

missing_genes = {}
for species in species_to_ids.keys():
    missing_genes[species] = []
    proteome_path = species_to_paths[species]
    species_id = species_to_ids[species]

    with open(proteome_path) as f:
        proteome_lines = f.readlines()

    gene_symbol_to_location = {}
    gene_symbol_to_chrom = {}

    for line in proteome_lines:
        if line.startswith(">"):
            split_line = line.split()
            gene_symbol = [token for token in split_line if token.startswith("gene_symbol")]
            if len(gene_symbol) > 0:
                gene_symbol = gene_symbol[0].split(":")
                
                if len(gene_symbol) == 2:
                    gene_symbol = gene_symbol[1]
                elif len(gene_symbol) > 2:
                    gene_symbol = ":".join(gene_symbol[1:]) # fix for annoying zebrafish gene names with colons in them
                else:
                    # a gene_symbol token without a value is unexpected; fail loudly
                    raise ValueError(f"Malformed gene_symbol token in header: {line.strip()}")
                
                
                chrom = None
                
                chrom_arr = [token for token in split_line if token.startswith("chromosome:")]
                if len(chrom_arr) > 0:
                    chrom = chrom_arr[0].replace("chromosome:", "")
                else:
                    chrom_arr = [token for token in split_line if token.startswith("primary_assembly:")]
                    if len(chrom_arr) > 0:
                        chrom = chrom_arr[0].replace("primary_assembly:", "")
                    else:
                        chrom_arr = [token for token in split_line if token.startswith("scaffold:")] 
                        if len(chrom_arr) > 0:
                            chrom = chrom_arr[0].replace("scaffold:", "")
                if chrom is not None:
                    gene_symbol_to_location[gene_symbol] = chrom.split(":")[2]
                    gene_symbol_to_chrom[gene_symbol] = chrom.split(":")[1]
                else:
                    missing_genes[species].append(gene_symbol)
                    

    positional_df = pd.DataFrame()
    positional_df["gene_symbol"] = [gn.upper() for gn in list(gene_symbol_to_chrom.keys())]
    positional_df["chromosome"] = list(gene_symbol_to_chrom.values())
    positional_df["start"] = list(gene_symbol_to_location.values())
    positional_df = positional_df.sort_values(["chromosome", "start"])
    #positional_df = positional_df.set_index("gene_symbol")
    positional_df["species"] = species
    all_pos_def.append(positional_df)

In [6]:
master_pos_def = pd.concat(all_pos_def)
master_pos_def

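|      | gene_symbol | chromosome | start     | species |
|------|-------------|------------|-----------|---------|
| 2327 | GCC1        | 1          | 1006145   | chicken |
| 2502 | NCAM2       | 1          | 100828671 | chicken |
| 3084 | ENS-2       | 1          | 101147482 | chicken |
| 2331 | DENND6B     | 1          | 1012031   | chicken |
| 3973 | MRPL39      | 1          | 102578362 | chicken |
| ...  | ...         | ...        | ...       | ...     |
| 4722 | CA9         | Z          | 9779343   | chicken |
| 4738 | ARHGEF39    | Z          | 9835547   | chicken |
| 3885 | MRPL17      | Z          | 9850679   | chicken |
| 4172 | CCBE1       | Z          | 9852827   | chicken |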
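
The row order in this output may look surprising: 1006145 comes before 100828671. The parsing loop stores `start` as a string, so `sort_values` compares the values character by character. The sketch below reproduces the effect with the first five start values from the table, using only the standard library. Whether the row order matters to downstream code is not stated in this notebook, so treat this as an explanation of the output, not as a required fix.

```python
# Start coordinates from the first rows of the table, kept as strings
# the way the parsing loop stores them.
starts = ["1006145", "100828671", "101147482", "1012031", "102578362"]

lexicographic = sorted(starts)     # string comparison, character by character
numeric = sorted(starts, key=int)  # comparison by genomic coordinate

print(lexicographic)
print(numeric)
```

In pandas, converting the column with `positional_df["start"].astype(int)` before sorting would give the numeric order.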


In [7]:
master_pos_def["species"].value_counts() # double check how many genes are mapped

chicken    13271
Name: species, dtype: int64

In [8]:
for k, v in missing_genes.items():
    print(f"{k}: {len(v)}") # are any genes missing?

chicken: 0


In [9]:
# Count genes per chromosome
for species in species_to_ids.keys():
    print("*********")
    print(species)
    display(master_pos_def[master_pos_def["species"] == species]["chromosome"].value_counts().head(50))
    print("*********")

*********
chicken


1                    1785
2                    1169
3                    1067
4                     953
5                     817
Z                     629
6                     458
8                     450
7                     442
9                     382
10                    366
14                    359
11                    327
15                    326
13                    306
20                    298
12                    293
19                    278
18                    274
17                    260
26                    237
28                    237
27                    235
21                    226
23                    214
25                    176
34                    155
24                    149
22                    142
16                     54
30                     52
38                     49
31                     14
MT                     13
39                     10
JAENSK010000484.1       7
35                      6
JAENSK010000592.1       6
W           

*********


In [10]:
master_pos_def.to_csv(f"{SPECIES_NAME}_to_chrom_pos.csv", index=False) # Save the DF

In [11]:
# The chromosome file path will be:
print(f"{SPECIES_NAME}_to_chrom_pos.csv")

chicken_to_chrom_pos.csv


In [12]:
# Use SPECIES_NAME explicitly instead of the loop variable left over from an earlier cell
N_UNIQ_CHROM = len(master_pos_def[master_pos_def["species"] == SPECIES_NAME]["chromosome"].unique())
N_UNIQ_CHROM

66

# Generate token file

In [13]:
import torch
import pickle
token_dim = 5120

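The next cell creates the token file. Please note the `CHROM_TOKEN_OFFSET` value it prints; you will need to pass it as an argument when you embed the new species.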

In [14]:
species_to_offsets = {}

# Read the first four rows of the existing token file so that the special
# vocab tokens are the same no matter which seed is used below
all_pe = torch.load("../model_files/all_tokens.torch")[0:4]

offset = len(all_pe) # special tokens at the top!

# Load the gene symbol -> ESM2 embedding mapping for this species
PE = torch.load(SPECIES_PROTEIN_EMBEDDINGS_PATH)

pe_stacked = torch.stack(list(PE.values()))
all_pe = torch.vstack((all_pe, pe_stacked))
species_to_offsets[SPECIES_NAME] = offset

print("CHROM_TOKEN_OFFSET:", all_pe.shape[0])
torch.manual_seed(TAXONOMY_ID) # same taxonomy ID -> same chromosome tokens
CHROM_TENSORS = torch.normal(mean=0, std=1, size=(N_UNIQ_CHROM, token_dim))
# N_UNIQ_CHROM is the number of unique chromosome labels counted above for this species
all_pe = torch.vstack(
    (all_pe, CHROM_TENSORS))  # Add the chrom tensors to the end
all_pe.requires_grad = False


torch.save(all_pe, f"{SPECIES_NAME}_pe_tokens.torch")

with open(f"{SPECIES_NAME}_offsets.pkl", "wb+") as f:
    pickle.dump(species_to_offsets, f)
print("Saved PE, offsets file")

CHROM_TOKEN_OFFSET: 13275
Saved PE, offsets file
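
The reason for seeding with `TAXONOMY_ID` can be illustrated without torch. The sketch below swaps in Python's standard `random.Random` for `torch.manual_seed` and `torch.normal`, so the numbers differ from what the cell above produces, but the principle is the same: two people who use the same taxonomy ID draw identical chromosome tokens, and a different ID gives different ones. The function name `pseudo_chromosome_tokens` is hypothetical; 9606 is the NCBI Taxonomy ID for human.

```python
import random


def pseudo_chromosome_tokens(taxonomy_id: int, n_chrom: int, dim: int) -> list:
    """Draw n_chrom vectors of length dim from N(0, 1), seeded by the taxonomy ID."""
    rng = random.Random(taxonomy_id)  # stand-in for torch.manual_seed(TAXONOMY_ID)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_chrom)]


a = pseudo_chromosome_tokens(9031, n_chrom=3, dim=4)  # chicken
b = pseudo_chromosome_tokens(9031, n_chrom=3, dim=4)  # chicken again, e.g. another user
c = pseudo_chromosome_tokens(9606, n_chrom=3, dim=4)  # human

print(a == b)
print(a == c)
```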


In [15]:
all_pe.shape

torch.Size([13341, 5120])
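
The numbers printed above follow directly from how the token file is stacked: special tokens first, then one row per protein embedding, then one row per chromosome label. A minimal sketch of that arithmetic, using the values from this notebook (the `TokenLayout` class is a hypothetical helper, not part of UCE):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TokenLayout:
    n_special: int      # rows copied from the existing token file
    n_proteins: int     # one row per protein embedding
    n_chromosomes: int  # one row per unique chromosome label

    @property
    def protein_offset(self) -> int:
        # Value stored in the offsets pickle for this species
        return self.n_special

    @property
    def chrom_token_offset(self) -> int:
        # First row that holds a chromosome token
        return self.n_special + self.n_proteins

    @property
    def total_rows(self) -> int:
        return self.chrom_token_offset + self.n_chromosomes


layout = TokenLayout(n_special=4, n_proteins=13271, n_chromosomes=66)
print(layout.protein_offset, layout.chrom_token_offset, layout.total_rows)
# prints: 4 13275 13341
```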

In [17]:
print(f"{SPECIES_NAME}_offsets.pkl")

chicken_offsets.pkl


In [18]:
SPECIES_PROTEIN_EMBEDDINGS_PATH

'../model_files/protein_embeddings/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.gene_symbol_to_embedding_ESM2.pt'

# Example evaluation of new species

**Note: when you evaluate a new species, you need to change some arguments and modify some files:**

You will need to modify the CSV `model_files/new_species_protein_embeddings.csv` to include the new protein embeddings file you downloaded.

In that file, add a row for the new species in the format:
`species name,full path to protein embedding file`

Please also add the same species name and embedding path as an entry in the dictionary created on line 247 of `data_proc/data_utils.py`.
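
The CSV row can be added by hand, or with a few lines of Python. The sketch below assumes the file is a plain two-column CSV in exactly the format described above; the helper name `add_species_row` and the embedding path are placeholders, and the demonstration writes to a temporary file, not to the real `model_files/new_species_protein_embeddings.csv`.

```python
import csv
import tempfile
from pathlib import Path


def add_species_row(csv_path: Path, species: str, embedding_path: str) -> bool:
    """Append `species,embedding_path`; return False if the species is already listed."""
    text = csv_path.read_text() if csv_path.exists() else ""
    listed = {row[0] for row in csv.reader(text.splitlines()) if row}
    if species in listed:
        return False
    with csv_path.open("a", newline="") as f:
        if text and not text.endswith("\n"):
            f.write("\n")  # keep the new row on its own line
        csv.writer(f, lineterminator="\n").writerow([species, embedding_path])
    return True


# Demonstration on a throwaway file; the embedding path is a placeholder.
demo_path = "/path/to/chicken_ESM2_embeddings.pt"
with tempfile.TemporaryDirectory() as tmp:
    demo_csv = Path(tmp) / "new_species_protein_embeddings.csv"
    first = add_species_row(demo_csv, "chicken", demo_path)
    second = add_species_row(demo_csv, "chicken", demo_path)  # no duplicate row
    rows = demo_csv.read_text().splitlines()

print(first, second, rows)
```

The dictionary in `data_proc/data_utils.py` still has to be edited separately.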

When you embed the new species, you will need to pass the newly created files and values as arguments:
- `CHROM_TOKEN_OFFSET`: This tells UCE where the rows corresponding to chromosome tokens start.
- `spec_chrom_csv_path`: This is the new CSV, created by this notebook, which maps genes to chromosomes and genomic positions.
- `token_file`: This is the new token file, which will work just for this species. The embeddings generated will still be universal though!
- `offset_pkl_path`: This is the new pickle file, created by this notebook, which stores the row offset in the token file where this species' protein embedding tokens begin.


```
accelerate launch eval_single_anndata.py chicken_heart.h5ad \
    --species=chicken \
    --CHROM_TOKEN_OFFSET=13275 \
    --spec_chrom_csv_path=data_proc/chicken_to_chrom_pos.csv \
    --token_file=data_proc/chicken_pe_tokens.torch \
    --offset_pkl_path=data_proc/chicken_offsets.pkl \
    --dir=... \
    --multi_gpu=True
```
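
To avoid copying values by hand, the command above can be assembled from the names used in this notebook. This is a sketch under two assumptions: the generated files have been moved into `data_proc/`, as in the example command, and `CHROM_TOKEN_OFFSET` is set to the value printed by the token-file cell. The `--dir` argument is left out because its value is elided in the example; add it for your own setup.

```python
import shlex

SPECIES_NAME = "chicken"
CHROM_TOKEN_OFFSET = 13275  # value printed by the token-file cell

command = [
    "accelerate", "launch", "eval_single_anndata.py", "chicken_heart.h5ad",
    f"--species={SPECIES_NAME}",
    f"--CHROM_TOKEN_OFFSET={CHROM_TOKEN_OFFSET}",
    f"--spec_chrom_csv_path=data_proc/{SPECIES_NAME}_to_chrom_pos.csv",
    f"--token_file=data_proc/{SPECIES_NAME}_pe_tokens.torch",
    f"--offset_pkl_path=data_proc/{SPECIES_NAME}_offsets.pkl",
    "--multi_gpu=True",
    # --dir is intentionally omitted here; supply your own value
]

# shlex.join quotes arguments where needed, so the result is safe to paste into a shell
print(shlex.join(command))
```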