In [None]:
from mini_latent_pd.models.minifold.model.model import MiniFoldModel

In [1]:
import tree

In [1]:
import torch
import torch.nn.functional as F
from pathlib import Path
from mini_latent_pd.models.minifold.model.model import MiniFoldModel
from mini_latent_pd.models.minifold.data.config import model_config
from mini_latent_pd.models.minifold.data.of_data import of_inference
from mini_latent_pd.models.minifold.utils.residue_constants import restype_order_with_x_inverse
from esm.pretrained import load_model_and_alphabet

def load_minifold_model(cache_dir="./minifold_cache", model_size="48L", device="cuda"):
    """
    Loads the MiniFold model, alphabet, and config for interactive use.
    """
    # 1. Paths
    cache = Path(cache_dir).expanduser()
    checkpoint_path = cache / f"minifold_{model_size}.ckpt"
    
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}. Run predict.py once to download it.")

    print(f"Loading checkpoint from {checkpoint_path}...")
    
    # 2. Load Checkpoint & Hyperparams
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    hparams = ckpt["hyper_parameters"]
    
    # 3. Initialize Model
    # We create the config required for feature generation
    config_of = model_config(
        "initial_training",
        train=False,
        low_prec=False,
        long_sequence_inference=False,
    )
    
    model = MiniFoldModel(
        esm_model_name=hparams["esm_model_name"],
        num_blocks=hparams["num_blocks"],
        no_bins=hparams["no_bins"],
        config_of=config_of,
        use_structure_module=True,
        kernels=False, # Important: False for Mac/MPS compatibility
    )

    # 4. Load State Dict
    state_dict = ckpt["state_dict"]
    # Clean keys as done in predict.py
    state_dict = {k: v for k, v in state_dict.items() if "boundaries" not in k}
    state_dict = {k: v for k, v in state_dict.items() if "mid_points" not in k}
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
    
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()
    
    # 5. Load Alphabet (needed for tokenizing sequences)
    _, alphabet = load_model_and_alphabet(hparams["esm_model_name"])
    
    return model, alphabet, config_of.data

In [2]:
class MinifoldCollator:
    def __init__(self, alphabet, config):
        self.alphabet = alphabet
        self.config = config

    def __call__(self, batch):
        """
        Transforms a list of mdCATH dicts into a MiniFold batch.
        """
        seq_tokens_list = []
        masks_list = []
        batch_of_list = []
        gt_coords_list = []
        ids = []

        for item in batch:
            seq_str = item["sequence"]
            ids.append(item["id"])
            
            # --- A. Generate Features ---
            # This replicates predict.py's `prepare_input`
            of_feats = of_inference(seq_str, "predict", self.config)
            
            # Reconstruct clean sequence from features to ensure alignment
            clean_seq = "".join(
                [restype_order_with_x_inverse[x.item()] for x in of_feats["aatype"]]
            )[: of_feats["seq_length"]]
            
            # Encode for ESM
            encoded_seq = torch.tensor(self.alphabet.encode(clean_seq), dtype=torch.long)
            
            # --- B. Collect ---
            seq_tokens_list.append(encoded_seq)
            masks_list.append(of_feats["seq_mask"][:, 0].bool())
            
            # Filter OpenFold features to only what MiniFold needs
            relevant = {"aatype", "seq_mask", "residx_atom37_to_atom14", "atom37_atom_exists"}
            batch_of_list.append({k: v for k, v in of_feats.items() if k in relevant})
            
            # Collect Ground Truth Coords (from mdCATH)
            gt_coords_list.append(torch.tensor(item["coords"]))

        # --- C. Padding ---
        max_len = max(len(s) for s in seq_tokens_list)
        
        # 1. Pad ESM Tokens (1 is usually padding in ESM alphabet)
        padded_seqs = torch.stack([
            F.pad(s, (0, max_len - len(s)), value=1) for s in seq_tokens_list
        ])
        
        # 2. Pad Masks
        padded_masks = torch.stack([
            F.pad(m, (0, max_len - len(m)), value=0) for m in masks_list
        ])
        
        # 3. Pad OpenFold Features
        batched_of = {}
        if batch_of_list:
            for k in batch_of_list[0].keys():
                feats = [d[k] for d in batch_of_list]
                # Pad the first dimension (sequence length)
                batched_of[k] = torch.stack([
                    F.pad(f, [0] * 2 * (len(f.shape) - 1) + [0, max_len - f.shape[0]], value=0) 
                    for f in feats
                ])

        # 4. Pad Ground Truth Coords
        # shape: (Batch, Max_Len, 3)
        padded_coords = torch.nn.utils.rnn.pad_sequence(
            gt_coords_list, batch_first=True, padding_value=0.0
        )

        return {
            "id": ids,
            "seq": padded_seqs,
            "mask": padded_masks,
            "batch_of": batched_of,
            "gt_coords": padded_coords
        }

In [3]:
import torch
from datasets import IterableDataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

# Import the generator from your new file
# (Ensure mdcath_dataset.py is in the same folder as your notebook)
from mini_latent_pd.data.mdcath_dataset import mdcath_generator

# --- Configuration ---
REPO_ID = "compsciencelab/mdCATH"
SUB_SAMPLE_PROTEINS = 4 
BATCH_SIZE = 1

# --- Padding Collator (Keep this in notebook as it's simple) ---
def padding_collate(batch):
    coords_list = [torch.tensor(item["coords"]) for item in batch]
    coords_padded = pad_sequence(coords_list, batch_first=True, padding_value=0.0)
    lengths = torch.tensor([len(c) for c in coords_list])
    B, Max_L = coords_padded.shape[:2]
    mask = torch.arange(Max_L).expand(B, Max_L) < lengths.unsqueeze(1)

    return {
        "id": [item["id"] for item in batch],
        "coords": coords_padded,
        "mask": mask,
        "temp": torch.tensor([item["temp"] for item in batch]),
        "lengths": lengths
    }

# --- Execution ---

if __name__ == "__main__":
    print("--- 1. Initializing Streaming Dataset ---")
    
    # Now this works because mdcath_generator is an imported function
    ds = IterableDataset.from_generator(
        mdcath_generator, 
        gen_kwargs={"repo_id": REPO_ID, "sub_sample_proteins": SUB_SAMPLE_PROTEINS}
    )
    
    # Optional: Hugging Face dataset mapping
    ds = ds.map(lambda x: {"num_atoms": len(x["coords"])})

    print("--- 2. Running DataLoader with Padding ---")
    dataloader = DataLoader(ds, batch_size=BATCH_SIZE, collate_fn=padding_collate)

    for i, batch in enumerate(dataloader):
        print(f"\nBatch {i+1}")
        print(f"  IDs: {batch['id']}")
        print(f"  Coords Shape: {batch['coords'].shape} (Padded)")
        
        if i >= 4: 
            print("Test complete.")
            break

  from .autonotebook import tqdm as notebook_tqdm


--- 1. Initializing Streaming Dataset ---
--- 2. Running DataLoader with Padding ---
Generator finding files... (using subset of 4)

Batch 1
  IDs: ['2cw7A03']
  Coords Shape: torch.Size([1, 2769, 3]) (Padded)

Batch 2
  IDs: ['2cw7A03']
  Coords Shape: torch.Size([1, 2769, 3]) (Padded)

Batch 3
  IDs: ['2cw7A03']
  Coords Shape: torch.Size([1, 2769, 3]) (Padded)

Batch 4
  IDs: ['2cw7A03']
  Coords Shape: torch.Size([1, 2769, 3]) (Padded)

Batch 5
  IDs: ['2cw7A03']
  Coords Shape: torch.Size([1, 2769, 3]) (Padded)
Test complete.


In [1]:
import urllib.request
from pathlib import Path

# Config
cache_dir = Path("./minifold_cache")
model_size = "12L" # or "48L" if you want the larger one
filename = f"minifold_{model_size}.ckpt"
target_path = cache_dir / filename

# Create dir
cache_dir.mkdir(parents=True, exist_ok=True)

# Download if missing
if not target_path.exists():
    print(f"Downloading {filename}...")
    url = f"https://huggingface.co/jwohlwend/minifold/resolve/main/minifold_{model_size}_final.ckpt"
    urllib.request.urlretrieve(url, str(target_path))
    print("Download complete.")
else:
    print(f"Found checkpoint at {target_path}")

Downloading minifold_12L.ckpt...
Download complete.


In [4]:
def crop_batch(batch, max_len=600):
    """
    Hard-crops a batch to a maximum length to prevent OOM/Buffer errors.
    """
    current_len = batch["seq"].shape[1]
    if current_len <= max_len:
        return batch
    
    print(f"⚠️ Cropping batch from {current_len} to {max_len} for testing...")
    
    # Crop sequence and mask
    new_batch = {
        "seq": batch["seq"][:, :max_len],
        "mask": batch["mask"][:, :max_len],
        "batch_of": {},
        "id": batch["id"] # Keep IDs
    }
    
    # Crop OpenFold features (careful with dimensions)
    for k, v in batch["batch_of"].items():
        if v.shape[1] == current_len:
             # Assuming dim 1 is sequence length (Batch, Len, ...)
            new_batch["batch_of"][k] = v[:, :max_len, ...]
        else:
            # Some features might not map directly to len (check specific keys if needed)
            new_batch["batch_of"][k] = v
            
    return new_batch

In [6]:
from mini_latent_pd.data.mdcath_dataset import mdcath_generator

# 1. Setup Device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# 2. Load Model
# Ensure you point cache_dir to where predict.py downloaded weights
model, alphabet, config = load_minifold_model(cache_dir="./minifold_cache", model_size="12L", device=device)

# 3. Prepare Dataset
# (Assuming mdcath_generator is defined in a previous cell)
from torch.utils.data import DataLoader
from datasets import IterableDataset

ds = IterableDataset.from_generator(
        mdcath_generator, 
        gen_kwargs={"repo_id": REPO_ID, "sub_sample_proteins": SUB_SAMPLE_PROTEINS}
    )
collator = MinifoldCollator(alphabet, config)

# 4. Create DataLoader
loader = DataLoader(ds, batch_size=2, collate_fn=collator)

# 5. Run a Test Batch
print("Fetching batch...")
batch = next(iter(loader))

# Move to device
batch_gpu = {
    "seq": batch["seq"].to(device),
    "mask": batch["mask"].to(device),
    "batch_of": {k: v.to(device) for k, v in batch["batch_of"].items()}
}

# batch_gpu = crop_batch(batch_gpu, max_len=512)

print(f"Running inference for protein IDs: {batch['id']}")
with torch.no_grad():
    # Run MiniFold!
    output = model(batch_gpu)

print("\nSuccess!")
print("Latent Pair Representation shape:", output["pair"].shape) # (B, L, L, 128)
print("Predicted Coordinates shape:", output["final_atom_positions"].shape) # (B, L, 37, 3)

Using device: mps
Loading checkpoint from minifold_cache/minifold_12L.ckpt...
Fetching batch...
Generator finding files... (using subset of 4)
Running inference for protein IDs: ['1l1cA00', '1l1cA00']

Success!
Latent Pair Representation shape: torch.Size([2, 883, 883, 128])
Predicted Coordinates shape: torch.Size([2, 883, 37, 3])


  return data[ranges]


In [10]:
for key, value in batch.items():
    print(f"{key}: {value.shape if isinstance(value, torch.Tensor) else type(value)}")

id: <class 'list'>
seq: torch.Size([2, 883])
mask: torch.Size([2, 883])
batch_of: <class 'dict'>
gt_coords: torch.Size([2, 883, 3])


1. The Input Batch

These are the tensors fed into the model.

- id: ['1ABC', '2XYZ']

A list of the protein identifiers for tracking which file corresponds to which output.
- seq: (2, 883)

What it is: The ESM token indices for the amino acid sequence.

Details: Integers representing amino acids. The padding value (usually 1) is used for any protein shorter than 883 residues.
- mask: (2, 883)

What it is: A boolean mask indicating where real data exists versus padding.

        Values: 1 (True) for real residues, 0 (False) for padded regions. You must use this to mask your loss function during training, or the model will try to fold the padding zeros.

- batch_of: <dict>

        What it is: Pre-computed geometric features needed by the structure module (e.g., rigid body frame definitions for each amino acid type).

        Key contents: aatype (amino acid identity), rigid_group_default_frame (local frames for side chains).

- gt_coords: (2, 883, 3)

        What it is: Your Ground Truth coordinates from mdCATH.

        Note: Since the last dimension is 3, this contains only one atom per residue (likely Alpha Carbon, CA). You will use this to calculate the loss against the model's predicted CA atoms.

2. The Model Output

These are the tensors returned by model(batch).
A. The Latent Spaces 

    pair: (2, 883, 883, 128)

        What it is: The Geometry-Aware Latent Space (sz​).

        Significance: This is the output of the "Folding Trunk" (the MiniFormer). It represents the relationship between every pair of residues (L×L).

        For your project: This is the tensor you should encode/diffuse/modify. It contains the rich evolutionary and structural constraints before they are realized into 3D coordinates.

    single: (2, 883, 1024)

        What it is: The Single Latent Representation (ss​).

        Significance: A projection of the ESM embeddings mixed with structural info. It serves as the primary input to the Structure Module to generate the backbone frames.

B. The Structural Outputs (The Decoder)

These are generated by the StructureModule from the latent spaces.

    final_atom_positions: (2, 883, 37, 3)

        What it is: The predicted 3D coordinates for all heavy atoms.

        Dimensions:

            37: The standard mapping for up to 37 heavy atoms (N, CA, C, O, CB, ...).

            3: X, Y, Z coordinates (in Angstroms).

        Note: To calculate loss against your gt_coords (CA only), you need to extract the CA atom (index 1) from this tensor: pred_ca = output['final_atom_positions'][:, :, 1, :].

    final_atom_mask: (2, 883, 37)

        What it is: Tells you which of the 37 atoms actually exist for a given amino acid (e.g., Glycine has no Beta Carbon, so its index would be 0 here).

    final_affine_tensor: (2, 883, 4, 4)

        What it is: The predicted rotation and translation matrices (frames) for every residue's backbone. The Structure Module predicts these frames first, then places the atoms into them.

C. Probabilistic Outputs & Confidence

    preds: (2, 883, 883, 64)

        What it is: The Distogram Logits.

        Details: It predicts the probability of the distance between two residues falling into one of 64 bins (e.g., bin 0 = 2-4 Angstroms, bin 63 = >22 Angstroms).

    plddt: (2, 883)

        What it is: The per-residue confidence score (0-100). High values mean the model is confident in that region's structure.

    lddt_logits: (2, 883, 50)

        What it is: The raw logits used to calculate the plddt score.

    sm: <dict>

        What it is: Internal debug state of the Structure Module (e.g., side-chain torsion angles). You can usually ignore this unless you are debugging specific side-chain rotations.

In [8]:
for key, value in output.items():
    print(f"{key}: {value.shape if isinstance(value, torch.Tensor) else type(value)}")

preds: torch.Size([2, 883, 883, 64])
pair: torch.Size([2, 883, 883, 128])
single: torch.Size([2, 883, 1024])
sm: <class 'dict'>
final_atom_positions: torch.Size([2, 883, 37, 3])
final_atom_mask: torch.Size([2, 883, 37])
final_affine_tensor: torch.Size([2, 883, 4, 4])
lddt_logits: torch.Size([2, 883, 50])
plddt: torch.Size([2, 883])


In [19]:
batch['batch_of']

{'aatype': tensor([[[12],
          [12],
          [12],
          ...,
          [19],
          [19],
          [19]],
 
         [[12],
          [12],
          [12],
          ...,
          [19],
          [19],
          [19]]]),
 'seq_mask': tensor([[[1.],
          [1.],
          [1.],
          ...,
          [1.],
          [1.],
          [1.]],
 
         [[1.],
          [1.],
          [1.],
          ...,
          [1.],
          [1.],
          [1.]]]),
 'residx_atom37_to_atom14': tensor([[[[0],
           [1],
           [2],
           ...,
           [0],
           [0],
           [0]],
 
          [[0],
           [1],
           [2],
           ...,
           [0],
           [0],
           [0]],
 
          [[0],
           [1],
           [2],
           ...,
           [0],
           [0],
           [0]],
 
          ...,
 
          [[0],
           [1],
           [2],
           ...,
           [0],
           [0],
           [0]],
 
          [[0],
  

In [12]:
output['final_atom_positions'][:, :, 1, :]

tensor([[[-96.3977,   5.6547,  -6.8404],
         [-95.3969,   4.2691,  -7.1769],
         [-94.0278,   5.1603,  -7.7622],
         ...,
         [ 32.4375,   2.9705,   4.2263],
         [ 32.9936,  -0.4333,   3.6020],
         [ 33.3603,   0.5361,   0.8055]],

        [[-22.7918,   4.1711,   1.4760],
         [-23.7921,   2.8411,   3.4759],
         [-22.2749,   1.5842,   1.8703],
         ...,
         [  0.2235,  -0.8644,   0.6441],
         [  0.7852,  -1.4505,   1.4023],
         [  1.0849,  -1.3420,  -0.4275]]], device='mps:0')

In [18]:
batch['gt_coords'].to(device)-output['final_atom_positions'][:, :, 1, :]

tensor([[[ 69.3777,  47.4753,  42.9504],
         [ 68.8169,  48.2709,  42.4469],
         [ 66.2578,  48.5597,  43.3122],
         ...,
         [-87.0375,  58.2095,  30.0537],
         [-86.8236,  60.1133,  31.2680],
         [-86.4503,  60.6439,  34.6745]],

        [[ -4.5482,  57.4589,  48.6540],
         [ -3.9879,  58.8489,  45.6341],
         [ -4.5651,  61.0158,  48.4797],
         ...,
         [-21.0835,  45.5444,  63.6059],
         [-20.2252,  45.0405,  63.1377],
         [-22.1349,  44.2520,  64.9275]]], device='mps:0')

In [None]:
import torch
import torch.nn.functional as F
from pathlib import Path
from mini_latent_pd.models.minifold.model.model import MiniFoldModel
from mini_latent_pd.models.minifold.data.config import model_config
from mini_latent_pd.models.minifold.data.of_data import of_inference
from mini_latent_pd.models.minifold.utils.residue_constants import restype_order_with_x_inverse
from esm.pretrained import load_model_and_alphabet

def load_minifold_model(cache_dir="./minifold_cache", model_size="48L", device="cuda"):
    """
    Loads the MiniFold model, alphabet, and config for interactive use.
    """
    # 1. Paths
    cache = Path(cache_dir).expanduser()
    checkpoint_path = cache / f"minifold_{model_size}.ckpt"
    
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}. Run predict.py once to download it.")

    print(f"Loading checkpoint from {checkpoint_path}...")
    
    # 2. Load Checkpoint & Hyperparams
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    hparams = ckpt["hyper_parameters"]
    
    # 3. Initialize Model
    # We create the config required for feature generation
    config_of = model_config(
        "initial_training",
        train=False,
        low_prec=False,
        long_sequence_inference=False,
    )
    
    model = MiniFoldModel(
        esm_model_name=hparams["esm_model_name"],
        num_blocks=hparams["num_blocks"],
        no_bins=hparams["no_bins"],
        config_of=config_of,
        use_structure_module=True,
        kernels=False, # Important: False for Mac/MPS compatibility
    )

    # 4. Load State Dict
    state_dict = ckpt["state_dict"]
    # Clean keys as done in predict.py
    state_dict = {k: v for k, v in state_dict.items() if "boundaries" not in k}
    state_dict = {k: v for k, v in state_dict.items() if "mid_points" not in k}
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
    
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()
    
    # 5. Load Alphabet (needed for tokenizing sequences)
    _, alphabet = load_model_and_alphabet(hparams["esm_model_name"])
    
    return model, alphabet, config_of.data