In [1]:
import torch
import yaml
from mol2dreams.trainer.trainer import Trainer
from mol2dreams.trainer.loss import MSELoss

from mol2dreams.utils.parser import (
    build_model_from_config,
    build_trainer_from_config,
    build_loss_from_config,
    build_optimizer_from_config,
    build_data_loaders_from_config
)


In [2]:
# Device configuration: use the GPU when PyTorch can see one, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's RNG (on all devices) so weight initialisation and training
# are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Model architecture, defined inline here (nothing is read from disk).
# Consecutive layers use matching widths (128, then 256); presumably each
# layer's input size must equal the previous layer's output — TODO confirm
# against the layer implementations.
# NOTE: `config` is rebound further down to the full experiment config, which
# nests this same architecture under the 'model' key.
config = {
    'input_layer': {
        'type': 'CONV_GNN',
        'params': {
            'node_features': 84,
            'embedding_size_reduced': 128
        }
    },
    'body_layer': {
        'type': 'SKIPBLOCK_BODY',
        'params': {
            'embedding_size_gnn': 128,
            'embedding_size': 256,
            'num_skipblocks': 7,
            'pooling_fn': 'mean'
        }
    },
    'head_layer': {
        'type': 'BidirectionalHeadLayer',
        'params': {
            'input_size': 256,
            'output_size': 1024
        }
    }
}

In [3]:
# Build model
# Instantiate the network described by `config` and move it to `device`.
# NOTE(review): the return value of `.to` is discarded; that is fine only if
# `model` is a torch nn.Module (whose `.to` moves parameters in place) — confirm.
model = build_model_from_config(config)
model.to(device)

# Print number of parameters
# `count_parameters` is a project helper; per the message below it counts
# trainable parameters (1368960 for this config in the recorded output).
num_params = model.count_parameters()
print(f"The model has {num_params} trainable parameters.")

The model has 1368960 trainable parameters.




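## Train part: manual `Trainer` setup

Load the precomputed batches and assemble the loss, optimizer and `Trainer` by hand. The same batch file serves as both training and validation data, so this run is a smoke test rather than a real evaluation.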

In [4]:
# Small precomputed batch file (3 batches in the recorded output).
load_path = "../../data/data/precomputed_batches_small.pt"

# The file stores pickled Python objects (batches exposing `.x`, `.edge_index`,
# `.edge_attr`, `.y`), not just tensors. PyTorch >= 2.6 defaults to
# `weights_only=True` and would refuse to load it, so opt out explicitly.
# This unpickles arbitrary objects: only load files you trust.
loaded_batches = torch.load(load_path, weights_only=False)

print(f"Loaded {len(loaded_batches)} batches from {load_path}.")

Loaded 3 batches from ../../data/data/precomputed_batches_small.pt.


In [5]:
# Hyperparameters for this manual training run.
LEARNING_RATE = 1e-3
NUM_EPOCHS = 50
EVAL_INTERVAL = 5  # validate and checkpoint every N epochs

loss_fn = MSELoss()
# NOTE: `model` is not rebuilt here, so re-running this cell continues from
# whatever weights earlier training left in it.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# NOTE: the same batches are used as training *and* validation data, so the
# validation loss (and the "best" checkpoint selected from it) is not a
# held-out estimate — acceptable for a smoke test only.
# NOTE(review): no test loader is passed, yet a later cell calls
# `trainer.test()`; confirm what data the Trainer evaluates in that case.
trainer = Trainer(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_loader=loaded_batches,
    val_loader=loaded_batches,
    device=device,
    log_dir='../../data/logs/mol2dreams',
    epochs=NUM_EPOCHS,
    validate_every=EVAL_INTERVAL,
    save_every=EVAL_INTERVAL,
    save_best_only=True,
)



In [6]:
# Sanity-check the dtypes of the `x`, `edge_index`, `edge_attr` and `y`
# tensors in every precomputed batch. Reading `.dtype` on the tensor itself
# (rather than on its first element, `t[0][0]`) gives the same answer and
# also works when a tensor is empty.
for batch in loaded_batches:
    print(batch.x.dtype, batch.edge_index.dtype, batch.edge_attr.dtype, batch.y.dtype)

torch.float32 torch.int64 torch.float32 torch.float32
torch.float32 torch.int64 torch.float32 torch.float32
torch.float32 torch.int64 torch.float32 torch.float32


In [7]:
# Run the training loop configured in the Trainer cell above, then evaluate.
# NOTE(review): the recorded log prints "Epoch [k/k-1]" — the denominator
# looks like the previous epoch index rather than the total epoch count;
# check the progress format string inside Trainer.train.
trainer.train()

# NOTE(review): this Trainer was constructed without a test loader, and the
# recorded output is truncated before any test result appears; confirm which
# data `test()` uses here.
trainer.test()

Epoch [1/0], Loss: 0.6985, Cosine Sim: 0.0959
Epoch [2/1], Loss: 0.6316, Cosine Sim: 0.2801
Epoch [3/2], Loss: 0.6022, Cosine Sim: 0.3551
Epoch [4/3], Loss: 0.5843, Cosine Sim: 0.3931
Epoch [5/4], Loss: 0.5694, Cosine Sim: 0.4143
Validation Loss: 0.6276, Validation Cosine Sim: 0.3687
Best model saved at epoch 5 with validation loss 0.6276
Model checkpoint saved at ../../data/logs/mol2dreams/model_epoch_5.pt
Epoch [6/5], Loss: 0.5560, Cosine Sim: 0.4357
Epoch [7/6], Loss: 0.5487, Cosine Sim: 0.4476
Epoch [8/7], Loss: 0.5406, Cosine Sim: 0.4614
Epoch [9/8], Loss: 0.5343, Cosine Sim: 0.4714
Epoch [10/9], Loss: 0.5287, Cosine Sim: 0.4805
Validation Loss: 0.5964, Validation Cosine Sim: 0.3650
Best model saved at epoch 10 with validation loss 0.5964
Model checkpoint saved at ../../data/logs/mol2dreams/model_epoch_10.pt
Epoch [11/10], Loss: 0.5237, Cosine Sim: 0.4885
Epoch [12/11], Loss: 0.5198, Cosine Sim: 0.4930
Epoch [13/12], Loss: 0.5167, Cosine Sim: 0.4988
Epoch [14/13], Loss: 0.5119, Co

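## Train from a YAML-style config

The same setup expressed as a single nested config dict — the structure that can be dumped to (or read from) a YAML file — and handed to `build_trainer_from_config`.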

In [17]:
# Complete experiment description consumed by `build_trainer_from_config`.
# It is built from three named sections that are assembled into `config` at
# the bottom; the result equals one big nested literal.

# Architecture: the same layer stack as in the manual-setup section.
model_section = {
    'input_layer': {
        'type': 'CONV_GNN',
        'params': {'node_features': 84, 'embedding_size_reduced': 128},
    },
    'body_layer': {
        'type': 'SKIPBLOCK_BODY',
        'params': {
            'embedding_size_gnn': 128,
            'embedding_size': 256,
            'num_skipblocks': 7,
            'pooling_fn': 'mean',
        },
    },
    'head_layer': {
        'type': 'BidirectionalHeadLayer',
        'params': {'input_size': 256, 'output_size': 1024},
    },
}

# All three loaders read the same small batch file. Each loader still gets
# its own dict object, so a YAML dump will not emit anchors/aliases.
small_batches_path = '../../data/data/precomputed_batches_small.pt'

training_section = {
    'num_epochs': 50,
    'validate_every': 5,
    'save_every': 5,
    'save_best_only': True,
    'device': 'cpu',
    'log_dir': '../../data/logs',
    'loss_function': {'type': 'MSELoss', 'params': {}},
    'optimizer': {'type': 'Adam', 'params': {'lr': 0.001}},
    'train_loader': {'path': small_batches_path},
    'val_loader': {'path': small_batches_path},
    'test_loader': {'path': small_batches_path},
}

# Data-processing options: which bond- and atom-level features are enabled.
data_processing_section = {
    'bond_config': {
        'features': {
            'bond_type': True,
            'conjugated': True,
            'in_ring': True,
            'stereochemistry': False,
        },
    },
    'atom_config': {
        'features': {
            'atom_symbol': True,
            'total_valence': True,
            'aromatic': True,
            'hybridization': True,
            'formal_charge': True,
            'default_valence': True,
            'ring_size': True,
            'hydrogen_count': True,
        },
        'feature_attributes': {
            'atom_symbol': {'top_n_atoms': 42, 'include_other': True},
        },
    },
}

config = {
    'model': model_section,
    'training': training_section,
    'data_processing': data_processing_section,
}


In [13]:
# Optional: export the experiment config above as the example YAML file.
# Off by default so Restart & Run All never overwrites the file on disk;
# set EXPORT_CONFIG = True to regenerate it.
EXPORT_CONFIG = False
CONFIG_YAML_PATH = '../configs/example_config.yaml'

if EXPORT_CONFIG:
    with open(CONFIG_YAML_PATH, 'w') as yaml_file:
        yaml.dump(config, yaml_file, default_flow_style=False)

In [18]:
# Build a fresh Trainer from the nested `config` defined above. Presumably this
# constructs a new model, loss, optimizer and loaders from the 'model' and
# 'training' sections — TODO confirm in mol2dreams.utils.parser.
# Rebinds `trainer`, replacing the manually assembled one from the previous section.
trainer = build_trainer_from_config(config)

In [19]:
# Train with the settings in config['training'] (presumably 50 epochs,
# validating and checkpointing every 5). Per the recorded output, checkpoints
# are written to a timestamped subdirectory of log_dir.
trainer.train()

Epoch [1/0], Loss: 0.6985, Cosine Sim: 0.0959
Epoch [2/1], Loss: 0.6316, Cosine Sim: 0.2801
Epoch [3/2], Loss: 0.6022, Cosine Sim: 0.3551
Epoch [4/3], Loss: 0.5843, Cosine Sim: 0.3931
Epoch [5/4], Loss: 0.5694, Cosine Sim: 0.4143
Validation Loss: 0.6276, Validation Cosine Sim: 0.3687
Best model saved at epoch 5 with validation loss 0.6276
Model checkpoint saved at ../../data/logs/20241008_182041_mol2dreams/model_epoch_5.pt
Epoch [6/5], Loss: 0.5560, Cosine Sim: 0.4357
Epoch [7/6], Loss: 0.5487, Cosine Sim: 0.4476
Epoch [8/7], Loss: 0.5406, Cosine Sim: 0.4614
Epoch [9/8], Loss: 0.5343, Cosine Sim: 0.4714
Epoch [10/9], Loss: 0.5287, Cosine Sim: 0.4805
Validation Loss: 0.5964, Validation Cosine Sim: 0.3650
Best model saved at epoch 10 with validation loss 0.5964
Model checkpoint saved at ../../data/logs/20241008_182041_mol2dreams/model_epoch_10.pt
Epoch [11/10], Loss: 0.5237, Cosine Sim: 0.4885
Epoch [12/11], Loss: 0.5198, Cosine Sim: 0.4930
Epoch [13/12], Loss: 0.5167, Cosine Sim: 0.4988

In [20]:
# Final evaluation; presumably uses config['training']['test_loader'], which
# points at the same file as the train/val loaders — so this is not a
# held-out score.
trainer.test()

Test Loss: 0.5002, Test Cosine Sim: 0.5200
