In [None]:

import pathlib
import hydra
import os, sys
import torch
sys.path.append('../')

import torch
import wandb
import hydra
import omegaconf
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.warnings import PossibleUserWarning

from dgd import utils
from dgd.datasets.frag_dataset import FragDataset, FragDataModule, FragDatasetInfos
from dgd.analysis.frag_utils import PyGGraphToMolConverter, FragSamplingMetrics
from dgd.datasets.frag_dataset import FRAG_GRAPH_FILE, FRAG_INDEX_FILE, FRAG_EDGE_FILE



In [None]:
# Compose the experiment configuration with Hydra from the ../configs/ directory.
config_dir = pathlib.Path('../configs/')

# Reset the global Hydra state before initializing; presumably this is what
# lets the cell be re-run within a single kernel session.
hydra.core.global_hydra.GlobalHydra.instance().clear()
hydra.initialize(config_path=config_dir)
cfg = hydra.compose(config_name='config.yaml')

# Keep a handle on the dataset sub-config; it is passed to FragDatasetInfos later.
dataset_config = cfg["dataset"]

# This value is forwarded to wandb.init(mode=...) by setup_wandb().
cfg.general.wandb = 'disabled'
print(cfg)


In [None]:

# Sanity check: build the dataset and run its first graph through the converter.
dataset = FragDataset(FRAG_GRAPH_FILE)

# The converter is given the index and edge files located under ../data/.
frag_index_path = '../data/' + FRAG_INDEX_FILE
frag_edge_path = '../data/' + FRAG_EDGE_FILE
converter = PyGGraphToMolConverter(frag_index_path, frag_edge_path)

example_graph = dataset[0]
# Kept as the cell's final expression so the notebook displays the returned value.
converter.graph_to_mol(example_graph, count_non_edge=True)

In [None]:

from dgd.metrics.abstract_metrics import TrainAbstractMetricsDiscrete
from dgd.analysis.visualization import MolecularVisualization, NonMolecularVisualization
from dgd.diffusion.extra_features import DummyExtraFeatures, ExtraFeatures
from diffusion_model_discrete import DiscreteDenoisingDiffusion

# Data module plus the sampling metrics built from its dataloaders.
datamodule = FragDataModule(cfg)
sampling_metrics = FragSamplingMetrics(datamodule.dataloaders, [])

# Dataset-level information, training metrics and the visualization helper.
dataset_infos = FragDatasetInfos(datamodule, dataset_config)
train_metrics = TrainAbstractMetricsDiscrete()
visualization_tools = NonMolecularVisualization()

# Two independent dummy feature objects: one for the generic extra features,
# one for the domain-specific features.
extra_features = DummyExtraFeatures()
domain_features = DummyExtraFeatures()

dataset_infos.compute_input_output_dims(
    datamodule=datamodule,
    extra_features=extra_features,
    domain_features=domain_features,
)

# Everything the model constructor needs besides the config itself.
model_kwargs = dict(
    dataset_infos=dataset_infos,
    train_metrics=train_metrics,
    sampling_metrics=sampling_metrics,
    visualization_tools=visualization_tools,
    extra_features=extra_features,
    domain_features=domain_features,
)

In [None]:
# Build the diffusion model from the composed config and the helper objects
# gathered in `model_kwargs`.
model = DiscreteDenoisingDiffusion(
    cfg=cfg,
    **model_kwargs,
)

In [None]:
def setup_wandb(cfg):
    """Start a Weights & Biases run configured from ``cfg``.

    The run name comes from ``cfg.general.name``, the project name is
    ``graph_ddm_<cfg.dataset.name>``, and the run mode is ``cfg.general.wandb``.
    The fully resolved config is attached to the run, and ``wandb.save`` is
    called with the ``'*.txt'`` pattern.

    Args:
        cfg: The composed OmegaConf configuration.

    Returns:
        The same ``cfg`` object that was passed in, unchanged.
    """
    # Resolve interpolations and fail on missing values before logging the config.
    resolved_cfg = omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    wandb.init(
        name=cfg.general.name,
        project=f'graph_ddm_{cfg.dataset.name}',
        config=resolved_cfg,
        settings=wandb.Settings(_disable_stats=True),
        reinit=True,
        mode=cfg.general.wandb,
    )
    wandb.save('*.txt')
    return cfg


cfg = setup_wandb(cfg)

In [None]:
# Run on GPU only when CUDA is available AND the config requests at least one GPU.
use_gpu = torch.cuda.is_available() and cfg.general.gpus > 0

# NOTE(review): 'ddp' is selected whenever cfg.general.gpus > 1, independently
# of whether CUDA is actually available — confirm this is intended.
trainer = Trainer(
    gradient_clip_val=cfg.train.clip_grad,
    accelerator='gpu' if use_gpu else 'cpu',
    devices=cfg.general.gpus if use_gpu else None,
    # Cap each loop at 20 batches.
    limit_train_batches=20,
    limit_val_batches=20,
    limit_test_batches=20,
    val_check_interval=cfg.general.val_check_interval,
    max_epochs=cfg.train.n_epochs,
    check_val_every_n_epoch=cfg.general.check_val_every_n_epochs,
    # A run named 'debug' switches on fast_dev_run.
    fast_dev_run=cfg.general.name == 'debug',
    strategy='ddp' if cfg.general.gpus > 1 else None,
    enable_progress_bar=False,
    callbacks=[],
    logger=[],
)

In [None]:

# Launch training; `cfg.general.resume` is passed as the checkpoint path.
trainer.fit(
    model,
    datamodule=datamodule,
    ckpt_path=cfg.general.resume,
)
