## Encoder Test

In this notebook, we will look into the different encoders (for sequence and structure). 

In [1]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, PretrainedConfig

from Bio import SeqIO

In [2]:
class ESMSeqEncoder(nn.Module):
    def __init__(
        self,
        model_name: str = 'facebook/esm2_t6_8M_UR50D',
        args=None,
        delay_load: bool = False,
        no_pooling: bool = False,   # NEW: return full per-residue embeddings?
    ):
        super().__init__()
        self.is_loaded = False

        self.model_name = model_name
        self.select_layer = getattr(args, 'protein_select_layer', -1)
        self.pooling = getattr(args, 'protein_pooling', 'cls')  # 'cls' or 'mean'
        self.no_pooling = no_pooling  # NEW flag

        if not delay_load:
            self.load_model()

    def load_model(self, device_map=None):
        if self.is_loaded:
            print(f'{self.model_name} is already loaded. Skipping load.')
            return

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
        self.encoder = AutoModel.from_pretrained(
            self.model_name,
            device_map=device_map,
            trust_remote_code=True,
            output_hidden_states=True
        )
        # Freeze encoder weights by default
        self.encoder.requires_grad_(False)

        self.is_loaded = True

    def tokenize(self, sequences):
        return self.tokenizer(
            sequences,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=1024
        )

    @torch.no_grad()
    def forward(self, sequences):
        if not self.is_loaded:
            self.load_model()

        # Tokenize & move to model device
        inputs = self.tokenize(sequences)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Run through ESM, grab hidden states
        outputs = self.encoder(**inputs)
        # hidden_states is a tuple: (layer0, layer1, ..., layerN)
        hidden_states = outputs.hidden_states[self.select_layer]  # (batch, seq_len, hidden_size)

        if self.no_pooling:
            # Return full sequence embeddings
            return hidden_states

        # Otherwise pool to single vector per sequence
        if self.pooling == 'cls':
            # CLS token is at position 0
            features = hidden_states[:, 0, :]
        elif self.pooling == 'mean':
            mask = inputs['attention_mask'].unsqueeze(-1).expand_as(hidden_states)
            sum_emb = torch.sum(hidden_states * mask, dim=1)
            counts = mask.sum(dim=1).clamp(min=1e-9)
            features = sum_emb / counts
        else:
            raise ValueError(f"Unsupported pooling type: {self.pooling}")

        return features
    
    # Load sequence from FASTA
    def load_fasta_sequence(self, fasta_path):
        record = next(SeqIO.parse(fasta_path, "fasta"))
        return str(record.seq)

    @property
    def dtype(self):
        if not self.is_loaded:
            # If not loaded, infer from config (usually fp32)
            return torch.get_default_dtype()
        return self.encoder.dtype

    @property
    def device(self):
        if not self.is_loaded:
            return torch.device('cpu')
        # encoder.device may be a map for multi-GPU; pick first
        dev = next(self.encoder.parameters()).device
        return dev

    @property
    def config(self):
        if self.is_loaded:
            return self.encoder.config
        return PretrainedConfig.from_pretrained(self.model_name)

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def dummy_feature(self):
        """
        Returns a zero tensor matching the shape of the output:
        - (1, seq_len, hidden_size) if sequence_output, else (1, hidden_size)
        Note: seq_len = 1 for dummy by default.
        """
        if self.no_pooling:
            # dummy single residue embedding
            return torch.zeros(1, 1, self.hidden_size, device=self.device, dtype=self.dtype)
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)


In [3]:


# Initialize encoder
encoder = ESMSeqEncoder(
    model_name='facebook/esm2_t6_8M_UR50D',
    no_pooling=True  # Set to True if you want per-residue embeddings
)
encoder.load_model()

# Load test sequence
sequence = encoder.load_fasta_sequence("../asset/demo_seq_str/pdb_1ubq/rcsb_pdb_1UBQ.fasta")

# Run through encoder
with torch.no_grad():
    output = encoder([sequence])  # Note: input is a list of sequences

print("Output shape:", output.shape)  # (1, hidden_size) or (1, seq_len, hidden_size) if no_pooling=True

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


facebook/esm2_t6_8M_UR50D is already loaded. Skipping load.
Output shape: torch.Size([1, 78, 320])


In [4]:
output

tensor([[[ 0.1612,  0.9960, -0.2067,  ...,  1.1084, -0.2607, -0.5740],
         [-0.0533,  0.7383, -0.3275,  ...,  0.6771, -0.3348,  0.0325],
         [ 0.0191,  0.1796, -0.0121,  ...,  0.2409,  0.1158,  0.0507],
         ...,
         [ 0.8836, -0.0777, -0.1994,  ...,  0.2270,  0.3404,  0.0509],
         [ 0.3051, -0.4682,  0.1557,  ..., -0.3124,  0.4537, -0.8090],
         [ 0.1718,  0.5675,  0.0740,  ...,  0.2063, -0.1933, -0.4174]]])

In [7]:
import torch
import torch.nn as nn
import esm
from esm.inverse_folding.util import extract_coords_from_structure, get_encoder_output

from typing import Optional
from biotite.structure import AtomArray
from biotite.structure.io.pdbx import CIFFile, get_structure as get_structure_cif
from biotite.structure.io.pdb import PDBFile, get_structure as get_structure_pdb
import biotite.structure as struc  # for filter_peptide_backbone

def load_structure(
    file_path: str,
    chain: Optional[str] = None,
    model: int = 1
) -> AtomArray:
    """
    Load a protein structure from .cif/.mmcif or .pdb, select one model & chain,
    then filter to peptide backbone atoms only.
    """
    ext = file_path.split('.')[-1].lower()
    # Read & convert to AtomArray
    if ext in ("cif", "mmcif"):
        cif    = CIFFile.read(file_path)
        struct = get_structure_cif(cif, model=model)
    elif ext == "pdb":
        pdb    = PDBFile.read(file_path)
        struct = get_structure_pdb(pdb, model=model)
    else:
        raise ValueError(f"Unsupported extension '.{ext}'")

    # Optional chain selection
    if chain is not None:
        struct = struct[struct.chain_id == chain]

    # **Filter to peptide backbone (drops waters, side-chains, non-standard residues)**
    backbone_mask = struc.filter_peptide_backbone(struct)
    struct = struct[backbone_mask]

    return struct

class ESMIFEncoder(nn.Module):
    def __init__(
        self,
        model_name: str = "esm_if1_gvp4_t16_142M_UR50",
        args=None,
        delay_load: bool = False,
        no_pooling: bool = False,
    ):
        super().__init__()
        self.model_name = model_name
        self.no_pooling = no_pooling
        self.is_loaded = False

        if not delay_load:
            self.load_model()

    def load_model(self):
        if self.is_loaded:
            print(f"{self.model_name} already loaded. Skipping.")
            return

        # Load the inverse-folding model and its alphabet
        model, alphabet = getattr(esm.pretrained, self.model_name)()
        model = model.eval().requires_grad_(False)
        self.model = model
        self.alphabet = alphabet
        self.is_loaded = True

    @torch.no_grad()
    def forward(self, structure_path: str, chain: str = None):
        if not self.is_loaded:
            self.load_model()

        # 1) Load and filter backbone atoms / select chain
        structure = load_structure(structure_path, chain)

        # 2) Extract (L × 3 × 3) coords tensor + sequence string
        coords, seq = extract_coords_from_structure(structure)

        # 3) Convert coords to torch tensor
        coords_tensor = torch.tensor(coords, dtype=torch.float32)

        # 4) Run the inverse-folding model
        encoder_out = get_encoder_output(self.model, self.alphabet, coords_tensor)
        # embeddings = encoder_out["representations"]  # (L, hidden_size)
        embeddings = encoder_out
        

        if self.no_pooling:
            # Return per-residue (1, L, hidden_size)
            return embeddings.unsqueeze(0)
        else:
            # Mean pool over L residues → (1, hidden_size)
            return embeddings.mean(dim=0, keepdim=True)

    @property
    def device(self):
        if not self.is_loaded:
            return torch.device("cpu")
        return next(self.model.parameters()).device

    @property
    def dtype(self):
        if not self.is_loaded:
            return torch.get_default_dtype()
        return next(self.model.parameters()).dtype

    @property
    def hidden_size(self):
        if not self.is_loaded:
            self.load_model()
        return self.model.embed_dim

    @property
    def dummy_feature(self):
        """
        - If no_pooling=True: returns (1,1,hidden_size)
        - Else: (1,hidden_size)
        """
        if self.no_pooling:
            return torch.zeros(1, 1, self.hidden_size,
                               device=self.device, dtype=self.dtype)
        else:
            return torch.zeros(1, self.hidden_size,
                               device=self.device, dtype=self.dtype)

In [10]:
# Initialize the encoder
encoder = ESMIFEncoder(no_pooling = True)

# Specify the path to your PDB or mmCIF file
structure_path = "../asset/demo_seq_str/pdb_1ubq/1ubq.cif"

# Optionally, specify the chain of interest
chain = "A"

# Get the protein embeddings
embeddings = encoder(structure_path, chain)

  F.pad(torch.tensor(cd), (0, 0, 0, 0, 1, 1), value=np.inf)


In [11]:
embeddings.size()

torch.Size([1, 76, 512])

In [5]:
# import torch
# import torch.nn as nn
# import esm
# from esm.inverse_folding.util import load_structure, extract_coords_from_structure, get_encoder_output

# class ESMIFEncoder(nn.Module):
#     def __init__(
#         self,
#         model_name: str = "esm_if1_gvp4_t16_142M_UR50",
#         args=None,
#         delay_load: bool = False,
#         no_pooling: bool = False,
#     ):
#         super().__init__()
#         self.model_name = model_name
#         self.no_pooling = no_pooling
#         self.is_loaded = False

#         if not delay_load:
#             self.load_model()

#     def load_model(self):
#         if self.is_loaded:
#             print(f"{self.model_name} already loaded. Skipping.")
#             return

#         # Load the inverse-folding model and its alphabet
#         model, alphabet = getattr(esm.pretrained, self.model_name)()
#         model = model.eval().requires_grad_(False)
#         self.model = model
#         self.alphabet = alphabet
#         self.is_loaded = True

#     @torch.no_grad()
#     def forward(self, structure_path: str, chain: str = None):
#         if not self.is_loaded:
#             self.load_model()

#         # 1) Load and filter backbone atoms / select chain
#         structure = load_structure(structure_path, chain)

#         # 2) Extract (L × 3 × 3) coords tensor + sequence string
#         coords, seq = extract_coords_from_structure(structure)

#         # 3) Convert coords to torch tensor
#         coords_tensor = torch.tensor(coords, dtype=torch.float32)

#         # 4) Run the inverse-folding model
#         encoder_out = get_encoder_output(self.model, self.alphabet, coords_tensor)
#         # embeddings = encoder_out["representations"]  # (L, hidden_size)
#         embeddings = encoder_out
        

#         if self.no_pooling:
#             # Return per-residue (1, L, hidden_size)
#             return embeddings.unsqueeze(0)
#         else:
#             # Mean pool over L residues → (1, hidden_size)
#             return embeddings.mean(dim=0, keepdim=True)

#     @property
#     def device(self):
#         if not self.is_loaded:
#             return torch.device("cpu")
#         return next(self.model.parameters()).device

#     @property
#     def dtype(self):
#         if not self.is_loaded:
#             return torch.get_default_dtype()
#         return next(self.model.parameters()).dtype

#     @property
#     def hidden_size(self):
#         if not self.is_loaded:
#             self.load_model()
#         return self.model.embed_dim

#     @property
#     def dummy_feature(self):
#         """
#         - If no_pooling=True: returns (1,1,hidden_size)
#         - Else: (1,hidden_size)
#         """
#         if self.no_pooling:
#             return torch.zeros(1, 1, self.hidden_size,
#                                device=self.device, dtype=self.dtype)
#         else:
#             return torch.zeros(1, self.hidden_size,
#                                device=self.device, dtype=self.dtype)

In [65]:
esm.inverse_folding.util.get_encoder_output?

[0;31mSignature:[0m [0mesm[0m[0;34m.[0m[0minverse_folding[0m[0;34m.[0m[0mutil[0m[0;34m.[0m[0mget_encoder_output[0m[0;34m([0m[0mmodel[0m[0;34m,[0m [0malphabet[0m[0;34m,[0m [0mcoords[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mFile:[0m      ~/anaconda3/envs/torch_env/lib/python3.12/site-packages/esm/inverse_folding/util.py
[0;31mType:[0m      function

In [None]:
# import pickle
# from typing import Optional
# from biotite.structure import AtomArray
# from biotite.structure.io.pdbx import CIFFile, get_structure as get_structure_cif
# from biotite.structure.io.pdb import PDBFile, get_structure as get_structure_pdb
# import biotite.structure as struc  # for filter_peptide_backbone

# atom_names = [
#     "N", "CA", "C", "O", "CB",
#     "CG", "CG1", "CG2", "CD", "CD1", "CD2",
#     "CE", "CE1", "CE2", "CE3", "CZ", "CZ2", "CZ3", "CH2",
#     "ND1", "ND2", "NE", "NE1", "NE2", "NH1", "NH2",
#     "NZ", "OD1", "OD2", "OE1", "OE2", "OG", "OG1", "OH",
#     "SD", "SG"
# ]

# def load_structure(
#     file_path: str,
#     chain: Optional[str] = None,
#     model: int = 1
# ) -> AtomArray:
#     """
#     Load a protein structure from .cif/.mmcif or .pdb, select one model & chain,
#     then filter to peptide backbone atoms only.
#     """
#     ext = file_path.split('.')[-1].lower()
#     if ext == "pkl":
#     with open(file_path, "rb") as f:
#         data = pickle.load(f)

#     coords_all_atoms = data["atom_positions"]        # (L, 37, 3)
#     atom_mask = data["atom_mask"]                    # (L, 37)
#     aatype = data["aatype"]                          # (L,)
#     res_idx = data["residue_index"]                  # (L,)
#     chain_idx = data["chain_index"]                  # (L,)
#     modeled_idx = data.get("modeled_idx", np.arange(len(aatype)))

#     # Fixed order of 37 standard atom names
#     fixed_atom_names = [
#         "N", "CA", "C", "O", "CB",
#         "CG", "CG1", "CG2", "CD", "CD1", "CD2",
#         "CE", "CE1", "CE2", "CE3", "CZ", "CZ2", "CZ3", "CH2",
#         "ND1", "ND2", "NE", "NE1", "NE2", "NH1", "NH2",
#         "NZ", "OD1", "OD2", "OE1", "OE2", "OG", "OG1", "OH",
#         "SD", "SG"
#     ]

#     # Collect atom-level data
#     atom_name_list = []
#     coord_list = []
#     res_id_list = []
#     chain_id_list = []

#     for i in range(coords_all_atoms.shape[0]):
#         res_atoms = coords_all_atoms[i]          # (37, 3)
#         res_mask = atom_mask[i] > 0.0            # (37,)

#         for j in range(37):
#             if not res_mask[j]:
#                 continue
#             atom_name_list.append(fixed_atom_names[j])
#             coord_list.append(res_atoms[j])
#             res_id_list.append(res_idx[i])
#             chain_id_list.append(chr(65 + chain_idx[i]))  # assumes A-Z chains

#     # Build AtomArray
#     coords = np.array(coord_list)
#     atom_names = np.array(atom_name_list)
#     res_ids = np.array(res_id_list)
#     chain_ids = np.array(chain_id_list)

#     atoms = AtomArray(len(coords))
#     atoms.coord = coords
#     atoms.atom_name = atom_names
#     atoms.res_id = res_ids
#     atoms.chain_id = chain_ids
#     atoms.element = guess_element(atom_names)

#     # Optional chain filtering
#     if chain is not None:
#         atoms = atoms[atoms.chain_id == chain]

#     # Keep only backbone atoms (N, CA, C)
#     backbone_mask = filter_peptide_backbone(atoms)
#     return atoms[backbone_mask]
#     elif ext in ("cif", "mmcif"):
#         cif = CIFFile.read(file_path)
#         struct = get_structure_cif(cif, model=model)
#     elif ext == "pdb":
#         pdb    = PDBFile.read(file_path)
#         struct = get_structure_pdb(pdb, model=model)
#     else:
#         raise ValueError(f"Unsupported extension '.{ext}'")

#     # Optional chain selection
#     if chain is not None:
#         struct = struct[struct.chain_id == chain]

#     # **Filter to peptide backbone (drops waters, side-chains, non-standard residues)**
#     backbone_mask = struc.filter_peptide_backbone(struct)
#     struct = struct[backbone_mask]

#     return struct

In [28]:
# import esm
# from esm.data import Alphabet
# from esm.inverse_folding.util import extract_coords_from_structure
# # from esm.inverse_folding.util import extract_coords_from_complex

# from biotite.structure import filter_amino_acids

# # patch esm so hub loader finds Alphabet
# esm.Alphabet = Alphabet

# class ESMIFEncoder(nn.Module):
#     def __init__(
#         self,
#         model_name: str = "esm_if1_gvp4_t16_142M_UR50",
#         args=None,
#         delay_load: bool = False,filter_peptide_backbone
#         no_pooling: bool = False,
#     ):
#         super().__init__()
#         self.model_name = model_name
#         self.no_pooling = no_pooling
#         self.is_loaded = False
#         if not delay_load:
#             self._load_model()

#     def _load_model(self):
#         if self.is_loaded:
#             return
#         model, alphabet = torch.hub.load(
#             "facebookresearch/esm:main",
#             self.model_name
#         )
#         self.model = model.eval().requires_grad_(False)
#         self.alphabet = alphabet
#         self.is_loaded = True

#     # @torch.no_grad()
#     # def forward(self, structure: AtomArray):
#     #     """
#     #     Expects a loaded AtomArray (via load_structure), not a path.
#     #     """
#     #     if not self.is_loaded:
#     #         self._load_model()

#     #     # Extract coordinates and sequence for the complex structure
#     #     coords, native_seqs = extract_coords_from_structure(structure)

#     #     # Convert coords to a tensor
#     #     coords = torch.tensor(coords, dtype=torch.float32)
#     #     coords = coords.to(self.device)

#     #     # Directly call the model without passing the additional arguments
#     #     out = self.model({"coords": coords})

#     #     # Extract embeddings
#     #     emb = out["representations"]

#     #     # Return either pooled or per-residue embeddings
#     #     if self.no_pooling:
#     #         return emb.unsqueeze(0)  # Add batch dimension
#     #     return emb.mean(dim=0, keepdim=True)  # Pool embeddings

#     @torch.no_grad()
#     def forward(self, structure: AtomArray):
#         """
#         Expects a loaded AtomArray (via load_structure), not a path.
#         """
#         if not self.is_loaded:
#             self._load_model()

#         # Extract coordinates and sequence from the structure
#         coords, seq = extract_coords_from_structure(structure)

#         # Convert coords to a tensor
#         coords = torch.tensor(coords, dtype=torch.float32)
#         coords = coords.to(self.device)

#         # Create default "empty" values for the missing arguments
#         padding_mask = torch.zeros((coords.shape[0],), dtype=torch.bool, device=self.device)  # Dummy padding mask
#         confidence = torch.ones((coords.shape[0],), dtype=torch.float32, device=self.device)  # Dummy confidence
#         prev_output_tokens = torch.zeros((coords.shape[0],), dtype=torch.long, device=self.device)  # Dummy previous tokens

#         # Call the model with the required dummy values
#         out = self.model({
#             "coords": coords,
#             "padding_mask": padding_mask,
#             "confidence": confidence,
#             "prev_output_tokens": prev_output_tokens
#         })

#         # Extract embeddings
#         emb = out["representations"]

#         # Return either pooled or per-residue embeddings
#         if self.no_pooling:
#             return emb.unsqueeze(0)  # Add batch dimension
#         return emb.mean(dim=0, keepdim=True)  # Pool embeddings

#     @property
#     def device(self):
#         return (next(self.model.parameters()).device
#                 if self.is_loaded else torch.device("cpu"))

#     @property
#     def dtype(self):
#         return (next(self.model.parameters()).dtype
#                 if self.is_loaded else torch.get_default_dtype())

#     @property
#     def hidden_size(self):
#         if not self.is_loaded:
#             self._load_model()
#         return self.model.embed_dim

#     @property
#     def dummy_feature(self):
#         if self.no_pooling:
#             return torch.zeros(1, 1, self.hidden_size,
#                                device=self.device, dtype=self.dtype)
#         return torch.zeros(1, self.hidden_size,
#                            device=self.device, dtype=self.dtype)

In [29]:
# from biotite.structure import filter_peptide_backbone
import esm.inverse_folding


In [17]:
# from esm.inverse_folding.model import InverseFoldingModel
# from esm.inverse_folding.util  import load_coords, extract_coords_from_pdb



AttributeError: module 'biotite.structure.io.pdbx' has no attribute 'PDBxFile'

In [55]:

# # 3.1) Load your structure from CIF
# pdb_file = "../asset/demo_seq_str/pdb_1ubq/1ubq.cif"


# # structure = load_structure(pdb_file, chain="A", model=1)

# # Assuming 'structure' is your AtomArray or AtomArrayStack
# # structure = filter_amino_acids(structure)
# # 3.2) Initialize and run the encoder
# encoder = ESMIFEncoder(no_pooling=False)
# embeddings = encoder(pdb_file)

# # 3.3) Inspect output
# print("Embeddings shape:", embeddings.shape)
# # → (1, L, hidden_size), where L = number of residues in chain A


  F.pad(torch.tensor(cd), (0, 0, 0, 0, 1, 1), value=np.inf)


IndexError: too many indices for tensor of dimension 2

In [None]:
print(dir(esm))

['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '__version__', 'axial_attention', 'constants', 'data', 'inverse_folding', 'modules', 'multihead_attention', 'rotary_embedding']


In [30]:
# import pickle

# Path to the .pkl file
file_path = '../asset/demo_seq_str/pdb_1ubq/1ubq.pkl'

# Open the file and load the content
with open(file_path, 'rb') as file:
    pkl_data = pickle.load(file)

# Now, `data` contains the object that was saved in the .pkl file
print(pkl_data.keys())



dict_keys(['atom_positions', 'aatype', 'atom_mask', 'residue_index', 'chain_index', 'b_factors', 'bb_mask', 'bb_positions', 'modeled_idx'])


In [31]:
pkl_data["atom_positions"]

array([[[-3.10289063e+00, -4.57898263e+00, -1.29023926e+01],
        [-4.17689003e+00, -3.59598283e+00, -1.26743927e+01],
        [-3.52989067e+00, -2.36998299e+00, -1.19853928e+01],
        ...,
        [-0.00000000e+00, -0.00000000e+00, -0.00000000e+00],
        [-0.00000000e+00, -0.00000000e+00, -0.00000000e+00],
        [-0.00000000e+00, -0.00000000e+00, -0.00000000e+00]],

       [[-4.10789169e+00, -1.23898247e+00, -1.22583928e+01],
        [-3.59289040e+00,  1.20169763e-02, -1.16183927e+01],
        [-4.34289040e+00,  2.44017327e-01, -1.03143925e+01],
        ...,
        [-0.00000000e+00, -0.00000000e+00, -0.00000000e+00],
        [-0.00000000e+00, -0.00000000e+00, -0.00000000e+00],
        [-0.00000000e+00, -0.00000000e+00, -0.00000000e+00]],

       [[-3.59388985e+00,  6.47017205e-01, -9.29939266e+00],
        [-4.20789017e+00,  1.04901763e+00, -8.01939245e+00],
        [-3.56089081e+00,  2.41901656e+00, -7.65439268e+00],
        ...,
        [-0.00000000e+00, -0.00000000e+00,

In [32]:
import numpy as np

In [33]:

# # Extract relevant data from the .pkl file (assuming these are already NumPy arrays)
# atom_positions = pkl_data['atom_positions']  # Expected shape: (n_residue, 37, 3)
# aatype = pkl_data['aatype']  # Amino acid types as integers
# atom_mask = pkl_data['atom_mask']  # Mask for atoms
# bb_mask = pkl_data['bb_mask']  # Backbone mask
# bb_positions = pkl_data['bb_positions']  # Backbone positions
# residue_index = pkl_data['residue_index']  # Residue indices
# chain_index = pkl_data['chain_index']  # Chain indices
# modeled_idx = pkl_data['modeled_idx']  # Modeled atom indices

# # Reshape atom_positions to be 2D: (num_atoms, 3)
# num_residues = atom_positions.shape[0]
# num_atoms_per_residue = atom_positions.shape[1]  # 37
# reshaped_atom_positions = atom_positions.reshape(-1, 3)  # Flatten the residues into a single array of atoms

# # Initialize AtomArray with the correct length (total number of atoms)
# num_atoms = reshaped_atom_positions.shape[0]
# structure = AtomArray(num_atoms)

# # Directly assign the reshaped NumPy arrays to the AtomArray attributes
# structure.coord = reshaped_atom_positions  # Now it's a 2D array (num_atoms, 3)
# structure.aatype = np.repeat(aatype, num_atoms_per_residue)  # Repeat amino acid types for each atom
# structure.atom_mask = np.repeat(atom_mask, num_atoms_per_residue)  # Repeat atom mask
# structure.bb_mask = np.repeat(bb_mask, num_atoms_per_residue)  # Repeat backbone mask
# structure.bb_positions = np.repeat(bb_positions, num_atoms_per_residue, axis=0)  # Repeat backbone positions
# structure.residue_index = np.repeat(residue_index, num_atoms_per_residue)  # Repeat residue index
# structure.chain_index = np.repeat(chain_index, num_atoms_per_residue)  # Repeat chain index
# structure.modeled_idx = np.repeat(modeled_idx, num_atoms_per_residue)  # Repeat modeled atom indices



In [34]:
# Now, structure is an instance of AtomArray and can be passed to the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the encoder
encoder = ESMIFEncoder(no_pooling=True)

# Forward pass through the model (using the structure)
with torch.no_grad():
    embeddings = encoder(file_path)

print("Embeddings shape:", embeddings.shape)

IndexError: list index out of range