In [3]:
import torch
from mol2dreams.utils.parser import build_model_from_config

In [4]:
# Precomputed PyG DataBatch objects used to smoke-test the model below.
load_path = "../../data/data/precomputed_batches_small.pt"

# The file stores pickled DataBatch objects, not just tensors, so full unpickling is needed.
# weights_only=False is passed explicitly because torch>=2.6 defaults to weights_only=True,
# which refuses to unpickle these classes. Pickle can execute arbitrary code: only load
# files from a trusted source.
# map_location="cpu" lets batches that were saved on a GPU load on CPU-only machines;
# each batch is moved to the model's device right before its forward pass.
loaded_batches = torch.load(load_path, map_location="cpu", weights_only=False)

print(f"Loaded {len(loaded_batches)} batches from {load_path}.")

Loaded 3 batches from ../../data/data/precomputed_batches_small.pt.


In [5]:
# Show every precomputed batch; the DataBatch repr lists each attribute with its shape.
# Note: `batch` remains bound to the final batch once the loop finishes.
for batch in loaded_batches:
    print(f"{batch}")

DataBatch(x=[1427, 84], edge_index=[2, 2968], edge_attr=[2968, 7], y=[32, 1024], IDENTIFIER=[32], COLLISION_ENERGY=[32, 1], adduct=[32], precursor_mz=[32, 1], batch=[1427], ptr=[33])
DataBatch(x=[1405, 84], edge_index=[2, 2988], edge_attr=[2988, 7], y=[32, 1024], IDENTIFIER=[32], COLLISION_ENERGY=[32, 1], adduct=[32], precursor_mz=[32, 1], batch=[1405], ptr=[33])
DataBatch(x=[1391, 84], edge_index=[2, 2990], edge_attr=[2990, 7], y=[32, 1024], IDENTIFIER=[32], COLLISION_ENERGY=[32, 1], adduct=[32], precursor_mz=[32, 1], batch=[1391], ptr=[33])


In [6]:
# Inspect the dtypes of the tensors the model consumes. Pick the batch explicitly instead of
# relying on a loop variable leaked from another cell, and read `.dtype` from the tensor
# itself (same value as any element's dtype, and it also works for empty tensors).
sample_batch = loaded_batches[-1]
sample_batch.x.dtype, sample_batch.edge_index.dtype, sample_batch.edge_attr.dtype, sample_batch.y.dtype

(torch.float32, torch.int64, torch.float32, torch.float32)

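# Build the network

Assemble the Mol2DreaMS model from a config dictionary: a GCN input layer, a skip-block body, and a bidirectional head.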

In [7]:
# Layer sizes that several config sections must agree on, each defined exactly once.
NODE_FEATURES = 84        # must equal batch.x.shape[1] (84 in the batches above)
GNN_EMBEDDING_SIZE = 128  # output width of the GNN input layer = input width of the body
EMBEDDING_SIZE = 256      # output width of the body = input width of the head
OUTPUT_SIZE = 1024        # must equal batch.y.shape[1] (1024 in the batches above)
NUM_SKIPBLOCKS = 7

config = {
    'input_layer': {
        'type': 'CONV_GNN',
        'params': {
            'node_features': NODE_FEATURES,
            'embedding_size_reduced': GNN_EMBEDDING_SIZE,
        },
    },
    'body_layer': {
        'type': 'SKIPBLOCK_BODY',
        'params': {
            'embedding_size_gnn': GNN_EMBEDDING_SIZE,
            'embedding_size': EMBEDDING_SIZE,
            'num_skipblocks': NUM_SKIPBLOCKS,
            'pooling_fn': 'mean',
        },
    },
    'head_layer': {
        'type': 'BidirectionalHeadLayer',
        'params': {
            'input_size': EMBEDDING_SIZE,
            'output_size': OUTPUT_SIZE,
        },
    },
}


In [9]:
# Seed PyTorch's RNGs (CPU and CUDA) so the randomly initialised weights, and therefore
# the outputs below, are the same on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = build_model_from_config(config)
# nn.Module.to returns the module itself, so this line also displays the architecture.
model.to(device)



Mol2DreaMS(
  (input_layer): CONV_GNN(
    (initial_conv): GCNConv(84, 128)
    (reluinit): ReLU()
    (conv1): GCNConv(128, 128)
    (reluconv1): ReLU()
    (conv2): GCNConv(128, 128)
    (reluconv2): ReLU()
    (conv3): GCNConv(128, 128)
    (reluconv3): ReLU()
    (conv4): GCNConv(128, 128)
    (reluconv4): ReLU()
  )
  (body_layer): SKIPBLOCK_BODY(
    (bottleneck): Linear(in_features=128, out_features=256, bias=True)
    (skipblocks): ModuleList(
      (0-6): 7 x SKIPblock(
        (batchNorm1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU()
        (dropout1): Dropout(p=0.2, inplace=False)
        (hidden1): Linear(in_features=256, out_features=128, bias=True)
        (batchNorm2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU()
        (dropout2): Dropout(p=0.2, inplace=False)
        (hidden2): Linear(in_features=128, out_features=256, bias=True)
      )
    )
    

In [13]:
# Parameter count as reported by the model's own helper.
num_parameters = model.count_parameters()
num_parameters

1368960

In [12]:
# Smoke-test the forward pass on every precomputed batch.
# eval() disables dropout and stops BatchNorm from updating its running statistics, and
# no_grad() skips autograd bookkeeping, so this shape check does not alter the model.
# The model's previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for batch in loaded_batches:
            # Ensure batch data types are correct
            batch.x = batch.x.float()
            batch.edge_index = batch.edge_index.long()
            batch.edge_attr = batch.edge_attr.float()
            batch.y = batch.y.float()

            # Move batch to the same device as the model
            batch = batch.to(device)

            # Forward pass
            output = model(batch)

            # Expect one prediction row per graph, laid out like the targets.
            assert output.shape == batch.y.shape, (
                f"output shape {tuple(output.shape)} does not match "
                f"target shape {tuple(batch.y.shape)}"
            )
            print(f"Output shape: {output.shape}")
finally:
    model.train(was_training)

Output shape: torch.Size([32, 1024])
Output shape: torch.Size([32, 1024])
Output shape: torch.Size([32, 1024])
