In [13]:
from models.spectra_encoder import SpectraEncoder, SpectraEncoderGrowing

# Smoke-test a forward pass of `SpectraEncoder`

In [6]:
import torch

# Initialize the model with required parameters
# Small SpectraEncoder configuration used only for a quick forward-pass smoke test.
model = SpectraEncoder(
    # Sizes of the outputs: the final prediction has width `output_size`, and the
    # per-peak fragment fingerprints have width `magma_modulo` (see shapes printed below).
    output_size=4096,
    magma_modulo=2048,
    # Model width / depth (kept small for testing).
    hidden_size=256,
    top_layers=2,
    refine_layers=0,
    # Peak attention (peak_attn_layers is an important parameter for FormulaTransformer).
    peak_attn_layers=4,
    num_heads=8,
    # Embedding, pooling and regularisation options.
    form_embedder="float",
    set_pooling="intensity",
    spectra_dropout=0.1,
)

In [7]:
from numpy import array
import numpy as np

# Example features for a single MassSpecGym spectrum.
#
# The original `magma_fps` value was copied from NumPy's *truncated* repr
# (rows/columns elided with "..."). That is not array data: the "..." become
# Ellipsis objects in a ragged nested list, which np.array rejects (or turns
# into an object array). Rebuild a placeholder with the structure that was
# visible in the repr instead:
#   - assumed one row per peak (9 peaks, matching `peak_type`) — TODO confirm
#   - visible entries are 0, the final row is all -1
#   - elided entries are unknown and are filled with 0 here
# Width assumed to equal magma_modulo=2048 used by the models below — TODO confirm.
# Note: prepare_batch() in this notebook does not read `magma_fps`.
MAGMA_FP_DIM = 2048
magma_fps_placeholder = np.zeros((9, MAGMA_FP_DIM))
magma_fps_placeholder[-1] = -1.0

spec_features = {'peak_type': array([0, 0, 0, 0, 0, 0, 0, 0, 3]),
 'form_vec': array([[ 6.,  4.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          3.,  0.,  0.,  0.,  0.],
        [14., 15.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,
          3.,  0.,  0.,  0.,  0.],
        [ 7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,
          3.,  0.,  0.,  0.,  0.],
        [13., 12.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          1.,  0.,  0.,  0.,  0.],
        [ 7.,  6.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.],
        [14., 12.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          3.,  0.,  0.,  0.,  0.],
        [ 7.,  7.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,
          3.,  0.,  0.,  0.,  0.],
        [13., 13.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,
          1.,  0.,  0.,  0.,  0.],
        [16., 17.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,
          4.,  0.,  0.,  0.,  0.]]),
 'ion_vec': [0, 0, 0, 0, 0, 0, 0, 0, 0],
 'frag_intens': array([1.        , 0.85716669, 0.5961169 , 0.59105782, 0.49522242,
        0.37701743, 0.28298424, 0.21223818, 1.        ]),
 'name': 'MassSpecGymID0000001',
 'magma_fps': magma_fps_placeholder,
 'magma_aux_loss': True,
 'instrument': 0}

  'magma_fps': array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],


In [None]:
import numpy as np

# Convert the spec_features dictionary to a batch format (adding batch dimension)
def prepare_batch(spec_features):
    """Wrap one spectrum's features into a model batch of size one.

    Args:
        spec_features: mapping with keys 'peak_type', 'form_vec', 'ion_vec',
            'frag_intens' (per-peak sequences/arrays) and 'instrument' (scalar).

    Returns:
        dict of tensors, in this key order:
            'num_peaks'   -- long, shape (1,), number of peaks
            'types'       -- long, 'peak_type' with a leading batch dim
            'form_vec'    -- float32, 'form_vec' with a leading batch dim
            'ion_vec'     -- long, 'ion_vec' with a leading batch dim
            'intens'      -- float32, 'frag_intens' with a leading batch dim
            'instruments' -- long, shape (1,)
    """
    # (batch key, source feature key, dtype) for every per-peak field;
    # each gets a new leading batch axis of size one.
    per_peak_fields = (
        ('types', 'peak_type', torch.long),
        ('form_vec', 'form_vec', torch.float),
        ('ion_vec', 'ion_vec', torch.long),
        ('intens', 'frag_intens', torch.float),
    )

    peak_count = len(spec_features['peak_type'])
    out = {'num_peaks': torch.tensor([peak_count], dtype=torch.long)}

    for batch_key, feature_key, dtype in per_peak_fields:
        out[batch_key] = torch.tensor(spec_features[feature_key], dtype=dtype)[None]

    out['instruments'] = torch.tensor([spec_features['instrument']], dtype=torch.long)
    return out

# Build a batch of size one from the single example spectrum above.
batch = prepare_batch(spec_features)

In [9]:
# Display the prepared batch tensors (leading batch dimension of size one).
batch 

{'num_peaks': tensor([9]),
 'types': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 3]]),
 'form_vec': tensor([[[ 6.,  4.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  3.,
            0.,  0.,  0.,  0.],
          [14., 15.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  3.,
            0.,  0.,  0.,  0.],
          [ 7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  3.,
            0.,  0.,  0.,  0.],
          [13., 12.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,
            0.,  0.,  0.,  0.],
          [ 7.,  6.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.],
          [14., 12.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  3.,
            0.,  0.,  0.,  0.],
          [ 7.,  7.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  3.,
            0.,  0.,  0.,  0.],
          [13., 13.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  1.,
            0.,  0.,  0.,  0.],
        

In [10]:
# Set model to evaluation mode (disables dropout so the pass is deterministic).
model.eval()

# Run forward pass without tracking gradients.
with torch.no_grad():
    output, aux_outputs = model(batch)

# Print shapes to verify
print(f"Output shape: {output.shape}")
print(f"Predicted fragment fingerprints shape: {aux_outputs['pred_frag_fps'].shape}")
print(f"Hidden state shape: {aux_outputs['h0'].shape}")

# Check values
print(f"Output min/max: {output.min().item():.4f}/{output.max().item():.4f}")

# Sanity checks, so Restart & Run All fails loudly instead of only printing:
# every output keeps the batch dimension, and the per-peak fragment predictions
# line up with the (batch, num_peaks) layout of the input peak types.
batch_size = batch['types'].shape[0]
assert output.shape[0] == batch_size
assert aux_outputs['h0'].shape[0] == batch_size
assert aux_outputs['pred_frag_fps'].shape[:2] == batch['types'].shape
assert torch.isfinite(output).all(), "non-finite values in encoder output"

Output shape: torch.Size([1, 4096])
Predicted fragment fingerprints shape: torch.Size([1, 9, 2048])
Hidden state shape: torch.Size([1, 256])
Output min/max: 0.2793/0.7018


# Smoke-test a forward pass of `SpectraEncoderGrowing`

In [16]:
# Small SpectraEncoderGrowing configuration for a forward-pass smoke test.
model2 = SpectraEncoderGrowing(
    # Sizes of the outputs.
    output_size=4096,
    magma_modulo=2048,
    # Model width / depth.
    hidden_size=256,
    top_layers=2,
    refine_layers=3,  # number of splits in FPGrowingModule
    # Peak attention.
    peak_attn_layers=4,
    num_heads=8,
    # Embedding, pooling and regularisation options.
    form_embedder="float",
    set_pooling="intensity",
    spectra_dropout=0.1,
)

In [17]:
# Set model to evaluation mode
model2.eval()

# Run forward pass
with torch.no_grad():
    output, aux_outputs = model2(batch)
    
# Print shapes to verify
print(f"Output shape: {output.shape}")
print(f"Predicted fragment fingerprints shape: {aux_outputs['pred_frag_fps'].shape}")
print(f"Hidden state shape: {aux_outputs['h0'].shape}")

# Check intermediate predictions (specific to SpectraEncoderGrowing)
for i, intermediate in enumerate(aux_outputs["int_preds"]):
    print(f"Intermediate prediction {i} shape: {intermediate.shape}")

# Check values
print(f"Output min/max: {output.min().item():.4f}/{output.max().item():.4f}")

Output shape: torch.Size([1, 4096])
Predicted fragment fingerprints shape: torch.Size([1, 9, 2048])
Hidden state shape: torch.Size([1, 256])
Intermediate prediction 0 shape: torch.Size([1, 512])
Intermediate prediction 1 shape: torch.Size([1, 1024])
Intermediate prediction 2 shape: torch.Size([1, 2048])
Output min/max: 0.0496/0.4883


In [None]:
import torch
import os

# Build the model with the same hyperparameters as the saved checkpoint.
# These are the values under which encoder_msg.pt loads with no missing or
# unexpected keys in the checkpoint-loading cells further down (its layers are
# 512 wide). The small smoke-test configuration (hidden_size=256,
# form_embedder="float", ...) does not match and fails with size mismatches.
# TODO(review): confirm against the encoder's training config.
model = SpectraEncoderGrowing(
    inten_transform='float',
    inten_prob=0.1,
    remove_prob=0.5,
    peak_attn_layers=2,
    num_heads=8,
    pairwise_featurization=True,
    embed_instrument=False,
    cls_type='ms1',
    set_pooling='cls',
    spec_features='peakformula',
    mol_features='fingerprint',
    form_embedder='pos-cos',
    output_size=4096,
    hidden_size=512,
    spectra_dropout=0.1,
    top_layers=1,
    refine_layers=4,
)

# Path to the checkpoint, relative to the notebook's working directory.
checkpoint_path = "services/DiffMS/checkpoints/model_checkpoints/encoder_msg.pt"

# Fail loudly: carrying on would leave a randomly initialised model in use.
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# Accept either a dict wrapping the weights under 'state_dict' or a bare state_dict.
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint
print(f"Loaded checkpoint with {len(state_dict)} keys")

# Strip a leading DataParallel/DDP "module." prefix only; str.replace would
# also rewrite any "module." that appears in the middle of a key.
ddp_prefix = 'module.'
state_dict = {
    (key[len(ddp_prefix):] if key.startswith(ddp_prefix) else key): value
    for key, value in state_dict.items()
}

try:
    # With strict=False, missing/unexpected keys are reported in the return
    # value, but shape mismatches still raise RuntimeError.
    load_result = model.load_state_dict(state_dict, strict=False)
except RuntimeError as e:
    print(f"Error loading weights: {e}")
    raise

missing_keys = load_result.missing_keys
unexpected_keys = load_result.unexpected_keys

if missing_keys or unexpected_keys:
    print("Model weights loaded with key mismatches (strict=False).")
else:
    print("Model weights loaded successfully!")

if missing_keys:
    print(f"Warning: {len(missing_keys)} keys in model were not found in checkpoint:")
    print(missing_keys[:5], "..." if len(missing_keys) > 5 else "")

if unexpected_keys:
    print(f"Warning: {len(unexpected_keys)} keys in checkpoint were not used:")
    print(unexpected_keys[:5], "..." if len(unexpected_keys) > 5 else "")

# Switch to inference mode (disables dropout).
model.eval()

# Load the checkpoint weights

In [26]:
import torch
import os

# Initialize your model first (with the same parameters as the saved model)
# SpectraEncoderGrowing configuration used with encoder_msg.pt; per the recorded
# output below, the checkpoint loads without missing/unexpected-key warnings.
model = SpectraEncoderGrowing(
    # Sizes and depth.
    output_size=4096,
    hidden_size=512,
    top_layers=1,
    refine_layers=4,
    # Peak attention / pooling.
    peak_attn_layers=2,
    num_heads=8,
    pairwise_featurization=True,
    set_pooling='cls',
    cls_type='ms1',
    embed_instrument=False,
    # Input / target feature settings.
    form_embedder='pos-cos',
    spec_features='peakformula',
    mol_features='fingerprint',
    inten_transform='float',
    # Probabilities and dropout rate.
    inten_prob=0.1,
    remove_prob=0.5,
    spectra_dropout=0.1,
)

In [27]:
# Path to the checkpoint.
# NOTE(review): absolute, machine-specific path; prefer a path relative to the
# project root (or a constant in a config cell) so the notebook runs elsewhere.
checkpoint_path = "/home/i_golov/Spectrum_Structure_prediction/Spectrum-to-Molecular/services/DiffMS/checkpoints/model_checkpoints/encoder_msg.pt"

# Stop here if the file is missing. Only printing and carrying on would make
# the summary line below (and the next cell) hit an undefined -- or a stale,
# previously loaded -- `state_dict`.
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# Accept either a dict wrapping the weights under 'state_dict' or a bare state_dict.
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint

# Print some information about the loaded state_dict
print(f"Loaded checkpoint with {len(state_dict)} keys")

Loaded checkpoint with 59 keys


In [28]:
# Load the checkpoint tensors into the model.
#
# Strip a leading DataParallel/DDP "module." prefix only (str.replace would also
# rewrite any "module." inside a key). Build a new dict rather than overwriting
# `state_dict`, so re-running this cell does not depend on its previous run.
ddp_prefix = 'module.'
encoder_state_dict = {
    (key[len(ddp_prefix):] if key.startswith(ddp_prefix) else key): value
    for key, value in state_dict.items()
}

try:
    # With strict=False, missing/unexpected keys are reported in the return
    # value, but shape mismatches still raise RuntimeError. Re-raise rather than
    # silently continuing with randomly initialised weights.
    load_result = model.load_state_dict(encoder_state_dict, strict=False)
except RuntimeError as e:
    print(f"Error loading weights: {e}")
    raise

missing_keys = load_result.missing_keys
unexpected_keys = load_result.unexpected_keys

if missing_keys or unexpected_keys:
    print("Model weights loaded with key mismatches (strict=False).")
else:
    print("Model weights loaded successfully!")

if missing_keys:
    print(f"Warning: {len(missing_keys)} keys in model were not found in checkpoint:")
    print(missing_keys[:5], "..." if len(missing_keys) > 5 else "")

if unexpected_keys:
    print(f"Warning: {len(unexpected_keys)} keys in checkpoint were not used:")
    print(unexpected_keys[:5], "..." if len(unexpected_keys) > 5 else "")

# Switch to inference mode (disables dropout); the model repr is displayed.
model.eval()

Model weights loaded successfully!


SpectraEncoderGrowing(
  (spectra_encoder): ModuleList(
    (0): FormulaTransformer(
      (form_embedder_mod): FourierFeaturizerPosCos()
      (intermediate_layer): MLPBlocks(
        (activation): ReLU()
        (dropout_layer): Dropout(p=0.1, inplace=False)
        (input_layer): Linear(in_features=343, out_features=512, bias=True)
        (layers): ModuleList(
          (0): Linear(in_features=512, out_features=512, bias=True)
        )
      )
      (pairwise_featurizer): MLPBlocks(
        (activation): ReLU()
        (dropout_layer): Dropout(p=0.1, inplace=False)
        (input_layer): Linear(in_features=162, out_features=512, bias=True)
        (layers): ModuleList(
          (0): Linear(in_features=512, out_features=512, bias=True)
        )
      )
      (peak_attn_layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
    

In [29]:
# Forward pass of the checkpoint-loaded model on the example batch, without
# gradient tracking. model.eval() is called in the weight-loading cell, so
# dropout is off here.
with torch.no_grad():
    output, aux_outputs = model(batch)

In [33]:
# List the auxiliary outputs returned alongside the main prediction.
aux_outputs.keys()

dict_keys(['pred_frag_fps', 'int_preds', 'h0'])