In [1]:
import torch
import yaml
# NOTE(review): `simple_decorator` is not referenced by any cell visible in this
# notebook; this looks like an accidental IDE auto-import that adds a dependency
# on `fire` — confirm and remove if unused.
from fire.test_components import simple_decorator

# Project training components used by the manual-setup cells.
from mol2dreams.trainer.trainer import Trainer
from mol2dreams.trainer.loss import MSELoss

# Config-driven builders. Of these, the visible cells call
# build_model_from_config and build_trainer_from_config.
from mol2dreams.utils.parser import (
    build_model_from_config,
    build_trainer_from_config,
    build_loss_from_config,
    build_optimizer_from_config,
    build_data_loaders_from_config
)


In [29]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Layer specification passed to build_model_from_config: one entry per stage,
# each naming a layer type and the keyword parameters for it.
# NOTE(review): 128 and 256 each appear twice (output of one stage / input of
# the next) — presumably these must stay in agreement; confirm before changing one.
config = dict(
    input_layer=dict(
        type='CONV_GNN',
        params=dict(node_features=84, embedding_size_reduced=128),
    ),
    body_layer=dict(
        type='SKIPBLOCK_BODY',
        params=dict(
            embedding_size_gnn=128,
            embedding_size=256,
            num_skipblocks=7,
            pooling_fn='mean',
        ),
    ),
    head_layer=dict(
        type='BidirectionalHeadLayer',
        params=dict(input_size=256, output_size=1024),
    ),
)

In [30]:
# Build the model from the layer spec in `config` and move it to `device`.
model = build_model_from_config(config)
model.to(device)

# Report the parameter count returned by the model's own count_parameters()
# helper (the message below labels it as the trainable-parameter count).
num_params = model.count_parameters()
print(f"The model has {num_params} trainable parameters.")

The model has 1368960 trainable parameters.


## Training (manual setup)

In [2]:
from mol2dreams.datasets.SimpleDataset import SimpleDataset
from torch.utils.data import Dataset, DataLoader

In [24]:
# Read the precomputed batches from disk and wrap them in the project's dataset.
# NOTE(review): torch.load unpickles arbitrary Python objects — only point this
# at files you trust.
load_path = "../../data/data/precomputed_batches_small.pt"
loaded_data = torch.load(load_path)
simple_dataset = SimpleDataset(loaded_data)

# Serve the dataset in fixed order (no shuffling) from the main process, using
# the dataset's own collate function to assemble each batch.
batch_size = 32
simple_loader = DataLoader(
    dataset=simple_dataset,
    batch_size=batch_size,
    collate_fn=simple_dataset.collate_fn,
    shuffle=False,
    num_workers=0,
)

print(f"Loaded {len(simple_loader)} batches from {load_path}.")

Loaded 4 batches from ../../data/data/precomputed_batches_small.pt.


In [22]:
# Objective and optimiser for the manually configured run.
loss_fn = MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# NOTE(review): the same loader is passed for both training and validation, so
# the validation metrics are computed on the training data.
trainer = Trainer(
    # what to train and how
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device=device,
    # data
    train_loader=simple_loader,
    val_loader=simple_loader,
    # schedule, logging and checkpointing
    epochs=50,
    validate_every=5,
    save_every=5,
    save_best_only=True,
    log_dir='../../data/logs/mol2dreams',
)



In [23]:
# Run the training loop of the manually configured trainer; the captured
# output shows a per-epoch loss / cosine-similarity line and periodic
# validation and checkpoint messages.
trainer.train()

# NOTE(review): this Trainer was constructed without an explicit test loader —
# confirm which data test() evaluates on.
trainer.test()

Epoch [1/50], Loss: 0.6944, Cosine Sim: 0.1222
Epoch [2/50], Loss: 0.6203, Cosine Sim: 0.3176
Epoch [3/50], Loss: 0.5948, Cosine Sim: 0.3716
Epoch [4/50], Loss: 0.5842, Cosine Sim: 0.3952
Epoch [5/50], Loss: 0.5744, Cosine Sim: 0.4070
Validation Loss: 0.6217, Validation Cosine Sim: 0.3800
Best model saved at epoch 5 with validation loss 0.6217
Model checkpoint saved at ../../data/logs/mol2dreams/model_epoch_5.pt
Epoch [6/50], Loss: 0.5716, Cosine Sim: 0.4083
Epoch [7/50], Loss: 0.5709, Cosine Sim: 0.4137
Epoch [8/50], Loss: 0.5574, Cosine Sim: 0.4370
Epoch [9/50], Loss: 0.5556, Cosine Sim: 0.4413
Epoch [10/50], Loss: 0.5486, Cosine Sim: 0.4546
Validation Loss: 0.6104, Validation Cosine Sim: 0.3489
Best model saved at epoch 10 with validation loss 0.6104
Model checkpoint saved at ../../data/logs/mol2dreams/model_epoch_10.pt
Epoch [11/50], Loss: 0.5443, Cosine Sim: 0.4603
Epoch [12/50], Loss: 0.5508, Cosine Sim: 0.4471
Epoch [13/50], Loss: 0.5466, Cosine Sim: 0.4551
Epoch [14/50], Loss: 

## Train from a YAML config

In [2]:
# Load the training configuration for the simple-dataset run.
# The path is relative to the notebook's working directory, like the other
# paths in this notebook ("../../data/..."), rather than a machine-specific
# absolute path that only exists on one developer's laptop.
config_path = "../../mol2dreams/configs/local_config_simple_dataset.yaml"

# safe_load only builds plain Python objects; the explicit encoding keeps the
# read independent of the platform's default locale.
with open(config_path, encoding="utf-8") as stream:
    config = yaml.safe_load(stream)

In [3]:
# Build the trainer from the YAML-derived `config` dict loaded in the previous cell.
trainer = build_trainer_from_config(config)



In [4]:
# Train with the config-built trainer; the captured output shows per-epoch
# metrics plus validation and checkpoint messages every 5 epochs.
trainer.train()

Epoch [1/50], Loss: 0.6932, Cosine Sim: 0.1314
Epoch [2/50], Loss: 0.6206, Cosine Sim: 0.3201
Epoch [3/50], Loss: 0.5794, Cosine Sim: 0.3763
Epoch [4/50], Loss: 0.5777, Cosine Sim: 0.3846
Epoch [5/50], Loss: 0.5990, Cosine Sim: 0.3817
Validation Loss: 0.6345, Validation Cosine Sim: 0.3684
Best model saved at epoch 5 with validation loss 0.6345
Model checkpoint saved at ../../data/logs/mol2dreams/20241010_100301_mol2dreams/model_epoch_5.pt
Epoch [6/50], Loss: 0.5369, Cosine Sim: 0.4357
Epoch [7/50], Loss: 0.5789, Cosine Sim: 0.4225
Epoch [8/50], Loss: 0.5908, Cosine Sim: 0.3988
Epoch [9/50], Loss: 0.5379, Cosine Sim: 0.4458
Epoch [10/50], Loss: 0.5477, Cosine Sim: 0.4480
Validation Loss: 0.5989, Validation Cosine Sim: 0.3793
Best model saved at epoch 10 with validation loss 0.5989
Model checkpoint saved at ../../data/logs/mol2dreams/20241010_100301_mol2dreams/model_epoch_10.pt
Epoch [11/50], Loss: 0.5266, Cosine Sim: 0.4531
Epoch [12/50], Loss: 0.5716, Cosine Sim: 0.4424
Epoch [13/50], 

In [5]:
# Evaluate the trained model; prints the test loss and test cosine similarity.
trainer.test()

Test Loss: 0.5126, Test Cosine Sim: 0.5104


### Train with cosine similarity loss

In [6]:
# Load the training configuration for the cosine-similarity-loss run.
# The path is relative to the notebook's working directory, like the other
# paths in this notebook ("../../data/..."), rather than a machine-specific
# absolute path that only exists on one developer's laptop.
config_path = "../../mol2dreams/configs/local_config_train_cosin_loss.yaml"

# safe_load only builds plain Python objects; the explicit encoding keeps the
# read independent of the platform's default locale.
with open(config_path, encoding="utf-8") as stream:
    config = yaml.safe_load(stream)

In [7]:
# Build a fresh trainer from the cosine-loss `config` loaded in the previous
# cell; this rebinds `trainer`, replacing the one from the earlier section.
trainer = build_trainer_from_config(config)

In [8]:
# Train with the cosine-loss configuration. In the captured output the reported
# loss and cosine similarity sum to 1 on every line (e.g. 0.8503 + 0.1497).
trainer.train()

Epoch [1/50], Loss: 0.8503, Cosine Sim: 0.1497
Epoch [2/50], Loss: 0.6596, Cosine Sim: 0.3404
Epoch [3/50], Loss: 0.6014, Cosine Sim: 0.3986
Epoch [4/50], Loss: 0.5931, Cosine Sim: 0.4069
Epoch [5/50], Loss: 0.5943, Cosine Sim: 0.4057
Validation Loss: 0.6047, Validation Cosine Sim: 0.3953
Best model saved at epoch 5 with validation loss 0.6047
Model checkpoint saved at ../../data/logs/mol2dreams/20241010_115915_mol2dreams/model_epoch_5.pt
Epoch [6/50], Loss: 0.5368, Cosine Sim: 0.4632
Epoch [7/50], Loss: 0.5538, Cosine Sim: 0.4462
Epoch [8/50], Loss: 0.5796, Cosine Sim: 0.4204
Epoch [9/50], Loss: 0.5302, Cosine Sim: 0.4698
Epoch [10/50], Loss: 0.5233, Cosine Sim: 0.4767
Validation Loss: 0.6136, Validation Cosine Sim: 0.3864
Model checkpoint saved at ../../data/logs/mol2dreams/20241010_115915_mol2dreams/model_epoch_10.pt
Epoch [11/50], Loss: 0.5120, Cosine Sim: 0.4880
Epoch [12/50], Loss: 0.5319, Cosine Sim: 0.4681
Epoch [13/50], Loss: 0.5338, Cosine Sim: 0.4662
Epoch [14/50], Loss: 0.5

In [9]:
# Evaluate the cosine-loss model; prints the test loss and test cosine similarity.
trainer.test()

Test Loss: 0.5206, Test Cosine Sim: 0.4794


## Load model from config

In [25]:
# Load the configuration that points the trainer at previously trained weights.
# The path is relative to the notebook's working directory, like the other
# paths in this notebook ("../../data/..."), rather than a machine-specific
# absolute path that only exists on one developer's laptop.
config_path = "../../mol2dreams/configs/local_config_load_trained_model_simple_dataset.yaml"

# safe_load only builds plain Python objects; the explicit encoding keeps the
# read independent of the platform's default locale.
with open(config_path, encoding="utf-8") as stream:
    config = yaml.safe_load(stream)

In [26]:
# Build the trainer from the load-trained-model config; the captured output
# reports pretrained weights being loaded from a best_model.pt checkpoint.
trainer = build_trainer_from_config(config)

Loaded pretrained weights from /Users/macbook/CODE/mol2DreaMS/data/logs/mol2dreams/best_model.pt


In [27]:
# Evaluate the restored model without further training; prints the test loss
# and test cosine similarity.
trainer.test()

Test Loss: 0.5387, Test Cosine Sim: 0.4683
