# How close are the codon and amino-acid embeddings of the same sequence?

## Start with visualization
Let's look at t-SNE projections of (1) the codon embeddings, (2) the amino-acid embeddings, and (3) both sets of embeddings combined.

### First, generate embeddings for the codon and amino-acid sequences

In [None]:
# TODO: Implement JIT compile model: https://huggingface.co/docs/transformers/torchscript
def generate_embeddings(model, dataloader):
    """Embed every sequence in ``dataloader`` using the model's last hidden layer.

    Runs ``model`` on each batch under ``torch.no_grad()`` with
    ``output_hidden_states=True`` and keeps the final entry of
    ``outputs.hidden_states``. For each sequence, only positions ``1`` through
    ``seq_len - 2`` are kept. Here ``seq_len`` is that row's sum of
    ``attention_mask``. This drops any padding, the first position and the
    last attended position. Those two positions are presumably the tokenizer's
    start/end special tokens -- confirm against the tokenizer.

    NOTE(review): if the collator applies random MLM masking (as
    ``DataCollatorForLanguageModeling`` does by default), these embeddings
    describe the *masked* input -- confirm the collator settings for inference.

    Args:
        model: Hugging Face-style model exposing ``.device`` and accepting
            ``output_hidden_states=True``. This function does not call
            ``model.eval()``; callers are expected to do so.
        dataloader: Iterable of mapping-like batches (``BatchEncoding`` or
            ``dict``) that include an ``"attention_mask"`` tensor.

    Returns:
        NumPy array of per-sequence embeddings in dataloader order. Each
        per-sequence embedding is 2-D (presumably ``(tokens, hidden_size)``).
        If all sequences yield the same shape, the result is a stacked array.
        Otherwise it is a 1-D ``object`` array whose i-th element is the
        embedding of sequence i.
    """
    embeddings = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Move tensor entries to the model's device; works for both
            # BatchEncoding and plain-dict batches.
            batch = {
                key: value.to(model.device) if isinstance(value, torch.Tensor) else value
                for key, value in batch.items()
            }
            outputs = model(**batch, output_hidden_states=True)
            last_hidden_states = outputs.hidden_states[-1]
            # Attended (non-padding) token count per sequence. The mask must
            # come from the batch itself; it is not a free variable.
            seq_lengths = batch["attention_mask"].sum(axis=1).tolist()
            for seq_len, elem in zip(seq_lengths, last_hidden_states):
                # Skip the first position and the last attended position
                # (presumably special tokens) plus any trailing padding.
                embedding = elem[1 : seq_len - 1, :].cpu().numpy()
                embeddings.append(embedding)

    # Sequences usually differ in length, and np.array() on a ragged list raises
    # ValueError in NumPy >= 1.24. Only stack when every shape agrees.
    if len({embedding.shape for embedding in embeddings}) <= 1:
        return np.array(embeddings)
    # Fill element-by-element: np.array(..., dtype=object) can still try to
    # broadcast sub-arrays whose leading dimensions happen to match.
    ragged = np.empty(len(embeddings), dtype=object)
    for index, embedding in enumerate(embeddings):
        ragged[index] = embedding
    return ragged

In [3]:
from dataclasses import dataclass
from torch.utils.data import DataLoader
from transformers import (
    PreTrainedTokenizerFast,
    BatchEncoding,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    EsmForMaskedLM,
    EsmTokenizer,
)

from train import GenSLMColatorForLanguageModeling, FastaDataset

@dataclass
class GenSLMInferenceConfig:
    """Default settings for embedding a FASTA file with an ESM masked-LM.

    Every field has a default, so ``GenSLMInferenceConfig()`` is usable as-is.
    Fields can be overridden by keyword or positionally, in declaration order.
    """

    # Hugging Face model id or local checkpoint directory for
    # ``EsmForMaskedLM.from_pretrained``. Another ESM checkpoint
    # (e.g. "esm1_t34_670M_UR50S") may be substituted here.
    model_path: str = "facebook/esm2_t6_8M_UR50D"
    # Tokenizer directory/id for ``EsmTokenizer.from_pretrained``.
    tokenizer_path: str = "tokenizer_esm_genslm"
    # FASTA file read by ``FastaDataset``. The .ffn extension suggests
    # nucleotide coding sequences -- confirm.
    data_path: str = (
        "/lambda_stor/homes/khippe/genslm_foundation/genome_data/"
        "mdh_sc23/fasta/mdh_natural_sequences.ffn"
    )
    # Number of sequences per DataLoader batch.
    per_device_eval_batch_size: int = 128

# Default inference settings (model, tokenizer, data path, batch size).
config = GenSLMInferenceConfig()

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load tokenizer and masked-LM, then move the model to `device` in eval mode.
# nn.Module.to() and .eval() both return the module itself, so the chain
# leaves `model` bound to the same (moved, eval-mode) object.
tokenizer = EsmTokenizer.from_pretrained(config.tokenizer_path)
model = EsmForMaskedLM.from_pretrained(config.model_path).to(device).eval()

# The dataset is asked for both the codon and the amino-acid views...
dataset = FastaDataset(
    file_path=config.data_path,
    return_codon=True,
    return_aminoacid=True,
)

# ...but only the codon view is collated/tokenized for this run.
data_collator = GenSLMColatorForLanguageModeling(
    return_codon=True,
    return_aminoacid=False,
    tokenizer=tokenizer,
)

dataloader = DataLoader(
    dataset,
    collate_fn=data_collator,
    batch_size=config.per_device_eval_batch_size,
    pin_memory=True,
    num_workers=1,
)

embeddings = generate_embeddings(model, dataloader)


## Investigate whether the codon and amino-acid (AA) embeddings of the same sequence are among each other's top-k nearest neighbors
1. Embed codon sequences
2. Embed AA sequences
3. Find the top-k nearest neighbors of each embedding
4. For each codon sequence, check whether the AA embedding of the same sequence is among its top-k nearest neighbors, and count how often this happens