# Explore the embeddings of the entire model when given sequence and structure: can we extract fixed-length embeddings?

This is an intended, user-facing use case, so it should be easily available from the provided API.

In [25]:
from esm.utils.structure.protein_chain import ProteinChain
from esm.models.esm3 import ESM3
from esm.sdk.api import (
    ESMProtein,
    ESMProteinTensor,
    GenerationConfig,
    ForwardConfig
)

import torch

In [3]:
# Load the "esm3_sm_open_v1" checkpoint and place it on the CPU.
# The cell output below shows the checkpoint files being fetched.
model = ESM3.from_pretrained("esm3_sm_open_v1", device=torch.device("cpu"))

Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

## We do not want to do any embeddings, so this should either be an output of `forward` or we will have to construct our own set of calls to the torch modules

In [4]:
pdb_id = "1ITU"  # PDB ID corresponding to Renal Dipeptidase
chain_id = "A"  # Chain ID corresponding to Renal Dipeptidase in the PDB structure

# Keep the raw chain under its own name instead of rebinding `example` to a
# different type: the conversion below can then be re-run safely, because its
# input is never overwritten by its own output.
protein_chain = ProteinChain.from_rcsb(pdb_id, chain_id)

# `example` is the ESMProtein that the later cells pass to `model.encode`.
example = ESMProtein.from_protein_chain(protein_chain, with_annotations=True)

In [11]:
# Forward configuration requesting that embeddings be returned.
# NOTE(review): nothing visible in this notebook passes `config` to the model —
# the forward call below does not use it; confirm whether it is still needed.
config = ForwardConfig(return_embeddings=True)

In [23]:
# Encode the protein with the model. The forward-pass cell reads the
# `sequence`, `structure`, `secondary_structure`, `sasa` and `coordinates`
# attributes of the result.
prompt = model.encode(example)

In [32]:
# Single forward pass over all encoded tracks.
# - torch.no_grad(): this is inference only, so do not record an autograd graph
#   (without it the returned tensors carry a grad_fn and hold extra memory).
# - model(...) rather than model.forward(...): calling the module instance is the
#   supported entry point and runs any registered hooks.
# - unsqueeze(0) prepends a dimension of size 1 to each track
#   (presumably the batch dimension expected by the model — confirm).
with torch.no_grad():
    outs = model(
        sequence_tokens=prompt.sequence.unsqueeze(0),
        structure_tokens=prompt.structure.unsqueeze(0),
        ss8_tokens=prompt.secondary_structure.unsqueeze(0),
        sasa_tokens=prompt.sasa.unsqueeze(0),
        structure_coords=prompt.coordinates.unsqueeze(0),
    )

In [33]:
# Display the raw forward output object; its repr prints the logit tensors, so the output is long.
outs

ESMOutput(sequence_logits=tensor([[[-19.8999, -19.8482, -19.7283,  ..., -19.8965, -19.8223, -19.8756],
         [-19.7770, -19.7875, -19.7118,  ..., -19.6097, -19.6936, -19.7686],
         [-20.8810, -20.8687, -20.8972,  ..., -20.6740, -20.8086, -20.8622],
         ...,
         [-21.9760, -21.9503, -21.9068,  ..., -21.9945, -21.8641, -21.9197],
         [-21.9072, -21.8573, -21.8436,  ..., -21.8125, -21.7358, -21.8285],
         [-20.7629, -20.8298, -20.7420,  ..., -20.8609, -20.7464, -20.7786]]],
       grad_fn=<ViewBackward0>), structure_logits=tensor([[[20.5455, 17.6710, 22.3964,  ..., 17.8930, 15.7483, 18.1989],
         [14.2900, 18.6756, 18.8012,  ..., 19.9908,  5.7742, 13.6581],
         [23.2597, 22.9823, 25.4071,  ..., 24.7140, 22.0978, 23.7403],
         ...,
         [21.6033, 20.2033, 26.9223,  ..., 19.4511, 15.9058, 27.8980],
         [25.1001, 18.6467, 29.0939,  ..., 19.5094, 13.6320, 22.2143],
         [25.2722, 22.8813, 25.9067,  ..., 23.1652, 19.3094, 17.3553]]],
    

In [39]:
# Shape of the embeddings returned by the forward pass. The leading dimension of 1
# comes from the unsqueeze(0) applied to each input track.
# NOTE(review): the remaining axes look like (sequence length, hidden size) — confirm.
# If so, this is one vector per position, and a fixed-length embedding would still
# require pooling over the length axis.
outs.embeddings.shape

torch.Size([1, 371, 1536])

## `embeddings` is in the base output! Good on them.