In [1]:
# Work from the project root so the relative dataset path used later ("data/qm9") resolves.
# NOTE(review): assumes the repo is checked out at ~/qsar_flash — adjust on other machines.
%cd ~/qsar_flash/

/home/nm_rostislav/qsar_flash


In [2]:
import os

import flash
from torchmetrics.regression.mae import MeanAbsoluteError

from qsar_flash import MolecularGraphRegressor, GraphRegressionData, QM9Property

In [4]:
import torch
from torch import nn
from torch_geometric.nn.inits import glorot_orthogonal
from torch.nn import functional as F

# 1. Load the data
# NOTE(review): property_index=4 selects which QM9 target is regressed; confirm
# which property that is in QM9Property before reporting results.
dataset = QM9Property(property_index=4, root="data/qm9")

datamodule = GraphRegressionData.from_datasets(
    train_dataset=dataset,
    # NOTE(review): the train/val split is presumably random and no seed is set in
    # the visible cells, so the split may differ between runs — seed if needed.
    val_split=0.2,
    batch_size=64,
    # Never request more worker processes than this machine has CPUs: a fixed 64
    # oversubscribes smaller hosts. os.cpu_count() may return None, hence `or 1`.
    num_workers=min(64, os.cpu_count() or 1),
    pin_memory=True,
)

# 2. Build the task
# DimeNet++ hyperparameters, passed to MolecularGraphRegressor as backbone_kwargs.
# out_channels is also used as the input width of the regression head below.
backbone_kwargs = dict(
    hidden_channels=128,
    out_channels=128,
    int_emb_size=64,
    out_emb_channels=256,
    num_blocks=6,
    basis_emb_size=8,
    num_spherical=7,
    num_radial=6,
    cutoff=10.0,
    envelope_exponent=5,
    num_before_skip=1,
    num_after_skip=2,
    num_output_layers=3,
)

class DimenetMolecularGraphHead(torch.nn.Module):
    """Two-layer MLP regression head producing one scalar per input row.

    Computes ``Linear -> SiLU -> Dropout -> Linear`` and maps
    ``hidden_channels`` features to a single output.

    Args:
        hidden_channels: Width of the input features and of the hidden layer.
        dropout: Dropout probability applied after the hidden activation; only
            active while the module is in training mode.
    """

    def __init__(self, hidden_channels: int, dropout: float = 0.5):
        super().__init__()
        self.lin1 = nn.Linear(hidden_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, 1)
        self.dropout = dropout
        self.act = nn.SiLU()
        # Apply the DimeNet-style initialization. Without this call the layers
        # silently keep PyTorch's default nn.Linear init.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize both layers: scaled orthogonal Glorot weights, zero biases."""
        glorot_orthogonal(self.lin1.weight, scale=2.0)
        self.lin1.bias.data.fill_(0)
        glorot_orthogonal(self.lin2.weight, scale=2.0)
        self.lin2.bias.data.fill_(0)

    def forward(self, x):
        """Return the regression output for ``x``.

        Assumes ``x`` has shape (num_graphs, hidden_channels), i.e. the pooled
        backbone output — TODO confirm. The result has a trailing dimension of 1.
        """
        x = self.act(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin2(x)

# DimeNet++ backbone with sum ("add") pooling, the custom MLP head and AdamW.
# The head's input width is taken from the backbone's out_channels (presumably
# the pooled backbone output feeds the head — confirm in MolecularGraphRegressor).
model = MolecularGraphRegressor(
    backbone="DimeNet++",
    metrics=MeanAbsoluteError(),
    learning_rate=1e-3,
    pooling_fn="add",
    optimizer="AdamW",
    backbone_kwargs=backbone_kwargs,
    head=DimenetMolecularGraphHead(backbone_kwargs["out_channels"]),
)

In [6]:
# 3. Create the trainer and fit the model.
# fast_dev_run=True limits every loop to a single batch (see the log below), so
# this cell is a smoke test rather than the full max_epochs=50 run. gpus=None
# trains on the CPU even though a GPU is available. Change both for real training.
trainer = flash.Trainer(
    fast_dev_run=True,
    gpus=None,
    max_epochs=50,
    gradient_clip_val=10.0,
)
trainer.fit(model, datamodule=datamodule)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Running in fast_dev_run mode: will run a full train, val, test and prediction loop using 1 batch(es).
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_predict_batches=1)` was configured so 1 batch will be used.
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..

  | Name          | Type                      | Params
------------------------------------------------------------
0 | train_metrics | ModuleDict                | 0     
1 | val_metrics   | ModuleDict                | 0     
2 | test_metrics  | ModuleDict                | 0     
3 | backbone      | DimeNetPlusPlusBackbone 

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]