In [None]:
# Work from the repository root so relative paths such as "data/qm9" (used by the
# dataset cell) resolve, and presumably so `qsar_flash` is importable from the cwd.
# NOTE(review): hard-coded to ~/qsar_flash/ — adjust if the repo lives elsewhere.
%cd ~/qsar_flash/

In [None]:
import flash
from torchmetrics import MeanAbsoluteError

from qsar_flash import MolecularGraphRegressor, GraphRegressionData, QM9Property

In [None]:
# Show the backbone names registered on the task; the "SchNet" backbone used
# when building the model is presumably one of them.
MolecularGraphRegressor.available_backbones()

In [None]:
# Show the registered LR-scheduler names; the "onecyclelr" key passed to the
# model's `lr_scheduler` argument presumably comes from this registry.
MolecularGraphRegressor.available_lr_schedulers()

In [None]:
import math

from pytorch_lightning import seed_everything
from torch.nn import functional as F

# Hyperparameters shared by the DataLoader, the LR schedule and the Trainer.
# NOTE: MAX_EPOCHS must match the `max_epochs` passed to `flash.Trainer` in the
# training cell, otherwise the OneCycleLR step budget below is wrong.
SEED = 42
BATCH_SIZE = 32
MAX_EPOCHS = 50
LEARNING_RATE = 1e-4

# Seed Python/NumPy/torch RNGs so the random train/val split and the weight
# initialisation are reproducible across kernel restarts.
seed_everything(SEED)

# 1. Create the DataModule
# property_index selects which QM9 target is regressed
# (index 4 is presumably the HOMO-LUMO gap in PyG ordering — TODO confirm).
dataset = QM9Property(property_index=4, root="data/qm9")

datamodule = GraphRegressionData.from_datasets(
    train_dataset=dataset,
    val_split=0.2,  # hold out 20% of the data for validation
    batch_size=BATCH_SIZE,
    num_workers=64,  # NOTE(review): assumes a host with many CPU cores
    pin_memory=True,
)

# 2. Build the task
backbone_kwargs = dict(
    hidden_channels=128,
    num_filters=128,
    num_interactions=6,
    num_gaussians=100,
    cutoff=10.,
    max_num_neighbors=32,
)

# OneCycleLR is stepped once per optimizer step (interval='step' below) and raises
# a ValueError once it is stepped more than `total_steps` times, so the budget must
# cover every batch of every epoch, not the number of training samples. ceil() gives
# an upper bound on batches per epoch, so the schedule never runs out.
steps_per_epoch = math.ceil(len(datamodule.train_dataset) / BATCH_SIZE)  # type: ignore
lr_scheduler_kwargs = dict(
    max_lr=LEARNING_RATE,
    total_steps=steps_per_epoch * MAX_EPOCHS,
)

# Step the scheduler after every optimizer step rather than once per epoch.
lr_scheduler_pl_kwargs = dict(
    interval='step'
)

# L1 loss so the training objective matches the MAE metric that is reported.
model = MolecularGraphRegressor(
    backbone="SchNet",
    backbone_kwargs=backbone_kwargs,
    pooling_fn="add",
    loss_fn=F.l1_loss,
    metrics=MeanAbsoluteError(),
    optimizer="Adam",
    learning_rate=LEARNING_RATE,
    lr_scheduler=("onecyclelr", lr_scheduler_kwargs, lr_scheduler_pl_kwargs),
)

In [None]:
# Finish any W&B run still open from an earlier execution of the training cell,
# presumably so the WandbLogger created next starts a new run instead of reusing
# the old one — NOTE(review): confirm against the installed wandb/Lightning versions.
import wandb
wandb.finish()

In [None]:
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

# Experiment tracking on Weights & Biases.
wandb_logger = WandbLogger(
    project="graph-drug",
    entity="inno-materials-ai",
    name="schnet-test",
)
# Record the learning rate at every optimizer step, the same granularity at which
# the model's LR scheduler is configured to step (interval='step').
lr_monitor = LearningRateMonitor(logging_interval='step')

# 3. Create the trainer and fit the model.
# NOTE(review): max_epochs must agree with the epoch count assumed when sizing the
# OneCycleLR `total_steps`. `gpus=[1]` pins training to CUDA device 1; newer
# Lightning releases deprecate `gpus` in favour of accelerator="gpu", devices=[1].
trainer = flash.Trainer(
    max_epochs=50,
    gpus=[1],
    logger=wandb_logger,
    callbacks=[lr_monitor],
)
trainer.fit(model, datamodule=datamodule)