In [2]:
# Change into ~/qsar_flash so that relative paths used later in this notebook
# (e.g. the QM9 dataset root "data/qm9") resolve from that directory.
%cd ~/qsar_flash/

/home/nm_rostislav/qsar_flash


In [3]:
import os
import random

import flash
import numpy as np
from torchmetrics.regression.mae import MeanAbsoluteError

from qsar_flash import GraphRegressionData, MolecularGraphRegressor, QM9Property

In [6]:
import torch
from torch import nn
from torch_geometric.nn.inits import glorot_orthogonal
from torch.nn import functional as F

# Seed every RNG up front so that the train/val partition produced by
# `val_split` and the model weight initialisation (done further down, after
# this point) are reproducible across kernel restarts. numpy and `random` are
# seeded too because the library's split helper may draw from either.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): property_index=4 presumably selects the HOMO-LUMO gap in PyG's
# QM9 target ordering — TODO confirm against QM9Property.
dataset = QM9Property(property_index=4, root="data/qm9")

# 64 workers oversubscribes any machine with fewer cores (PyTorch's DataLoader
# warns about this); cap the count at the CPUs actually available.
NUM_WORKERS = min(64, os.cpu_count() or 1)

datamodule = GraphRegressionData.from_datasets(
    train_dataset=dataset,
    val_split=0.2,  # hold out 20% of the training set for validation
    batch_size=64,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# 2. Build the task
# Hyper-parameters for the "DimeNet" backbone; this dict is handed to
# MolecularGraphRegressor via `backbone_kwargs=` below, and "out_channels"
# also sizes the regression head's input.
backbone_kwargs = {
    "hidden_channels": 128,
    "out_channels": 128,
    "num_blocks": 6,
    "num_bilinear": 8,
    "num_spherical": 7,
    "num_radial": 6,
    "cutoff": 5.0,
    "envelope_exponent": 5,
    "num_before_skip": 1,
    "num_after_skip": 2,
    "num_output_layers": 1,
}

class DimenetMolecularGraphHead(torch.nn.Module):
    """Two-layer MLP regression head producing one scalar per input row.

    Computes ``lin2(dropout(SiLU(lin1(x))))``.

    Args:
        hidden_channels: Width of the incoming features. It must match the
            backbone's ``out_channels`` because ``lin1`` is square.
        dropout: Dropout probability applied between the two layers. It is
            only active while the module is in training mode.
    """

    def __init__(self, hidden_channels: int, dropout: float = 0.5):
        super().__init__()
        self.lin1 = nn.Linear(hidden_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, 1)
        self.dropout = dropout
        self.act = nn.SiLU()
        # Apply the custom initialisation at construction time. Nothing visible
        # in this notebook calls reset_parameters(), so without this call the
        # layers silently keep PyTorch's default nn.Linear initialisation.
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise both linear layers.

        Weights use Glorot-orthogonal init with scale 2.0 and biases are zeroed.
        This matches the scheme PyG's DimeNet uses for its output-block linears.
        """
        for lin in (self.lin1, self.lin2):
            glorot_orthogonal(lin.weight, scale=2.0)
            lin.bias.data.fill_(0)

    def forward(self, x):
        """Map features ``x`` to predictions.

        Args:
            x: Feature tensor whose last dimension is ``hidden_channels``.
                Presumably these are pooled per-graph embeddings of shape
                (num_graphs, hidden_channels); TODO confirm against
                MolecularGraphRegressor.

        Returns:
            A tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        x = self.act(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin2(x)

model = MolecularGraphRegressor(
    # Backbone selection and its hyper-parameters.
    backbone="DimeNet",
    backbone_kwargs=backbone_kwargs,
    pooling_fn="add",
    # Evaluation metric; it is constructed before the head, as in the original call.
    metrics=MeanAbsoluteError(),
    # The head's input width has to equal the backbone's output width.
    head=DimenetMolecularGraphHead(backbone_kwargs["out_channels"]),
    # Optimisation settings.
    optimizer="AdamW",
    learning_rate=0.001,
)

  rank_zero_warn(
  rank_zero_warn(


In [7]:
# 3. Create the trainer and fit the model
# 3. Create the trainer and fit the model
# `gpus=` is deprecated since PyTorch Lightning 1.7 and removed in 2.0, so the
# device is selected via `accelerator`/`devices` instead. The hard-coded GPU
# index 1 does not exist on single-GPU or CPU-only machines; in that case fall
# back to CPU rather than crashing.
GPU_INDEX = 1
if torch.cuda.device_count() > GPU_INDEX:
    accelerator, devices = "gpu", [GPU_INDEX]
else:
    accelerator, devices = "cpu", 1

trainer = flash.Trainer(
    max_epochs=50,
    accelerator=accelerator,
    devices=devices,
    gradient_clip_val=10.0,
)
trainer.fit(model, datamodule=datamodule)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name          | Type                      | Params
------------------------------------------------------------
0 | train_metrics | ModuleDict                | 0     
1 | val_metrics   | ModuleDict                | 0     
2 | test_metrics  | ModuleDict                | 0     
3 | backbone      | DimeNetBackbone           | 2.0 M 
4 | head          | DimenetMolecularGraphHead | 16.6 K
------------------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
7.997     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
