In [1]:
!nvidia-smi

Tue May 14 14:11:42 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08             Driver Version: 535.161.08   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A10G                    On  | 00000000:00:1D.0 Off |                    0 |
|  0%   28C    P8              16W / 300W |      0MiB / 23028MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
import torch
torch.__version__

'2.3.0+cu121'

In [3]:
!pip install torch-geometric -f https://data.pyg.org/whl/torch-2.3.0+cu121.html

Defaulting to user installation because normal site-packages is not writeable
Looking in links: https://data.pyg.org/whl/torch-2.3.0+cu121.html


In [4]:
!pip install torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-2.3.0+cu121.html

Defaulting to user installation because normal site-packages is not writeable
Looking in links: https://data.pyg.org/whl/torch-2.3.0+cu121.html


In [5]:
!pip install ase

Defaulting to user installation because normal site-packages is not writeable


In [6]:
#!pip install --upgrade pytorch-lightning
#!pip install transformers
#!pip install lightning-bolts
#!pip install lightning-utilities

In [7]:
import os
import torch
import torch.nn.functional as F

from dataclasses import asdict
from torch.nn.utils import clip_grad_norm_
# from pl_bolts.optimizers import LinearWarmupCosineAnnealingLR
from torch_geometric.loader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR

from config import Config
from data.ee import EE
from data.bde import BDE
from data.drugs import Drugs
from data.kraken import Kraken
from data.tmqmg import tmQMg
from happy_config import ConfigLoader
from loaders.samplers import EnsembleSampler, EnsembleMultiBatchSampler
from loaders.multibatch import MultiBatchLoader
from utils.early_stopping import EarlyStopping, generate_checkpoint_filename

from models.model_3d import Model3D
from models.models_3d.chiro import ChIRo
from models.models_3d.painn import PaiNN
from models.models_3d.schnet import SchNet
from models.models_3d.gemnet import GemNetT
from models.models_3d.dimenet import DimeNet, DimeNetPlusPlus
from models.models_3d.clofnet import ClofNet
from models.models_3d.leftnet import LEFTNet
from models.models_3d.chytorch_discrete import ChytorchDiscrete
from models.model_4d import Model4D, SumPooling, MeanPooling, TransformerPooling, DeepSets, SelfAttentionPooling
from loaders.utils import reorder_molecule_idx

import pickle
from time import time

In [8]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
import torch.nn as nn

from torch.optim import AdamW
from torch.optim.lr_scheduler import CyclicLR, CosineAnnealingLR

In [9]:
from models.model_3d import GroupedScaledMAELoss

In [10]:
###################################################
import os
import numpy as np
#os.environ['CUDA_VISIBLE_DEVICES'] = gpus
torch.cuda.empty_cache()
#print(os.environ['CUDA_VISIBLE_DEVICES'])

seed = 42
torch.manual_seed(seed)

# If you are using a GPU, you should also set the seed for CUDA operations
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

np.random.seed(seed)
################################################## 

In [11]:
def r2_score_torch(y_true, y_pred):
    """
    Compute the R-squared score.

    Parameters:
        y_true (torch.Tensor): The true target values.
        y_pred (torch.Tensor): The predicted target values.

    Returns:
        float: The R-squared score.
    """
    # Calculate the mean of the true target values
    y_mean = torch.mean(y_true)
    # Calculate the total sum of squares (TSS)
    tss = torch.sum((y_true - y_mean) ** 2)
    # Calculate the residual sum of squares (RSS)
    rss = torch.sum((y_true - y_pred) ** 2)
    # Calculate R-squared score
    r2 = 1 - rss / tss

    return r2.item()

In [12]:
class DataModule(LightningDataModule):
    def __init__(self, hparams, dataset=None, multitask = False):
        super().__init__()
        self.hparams.update(hparams.__dict__) if hasattr(hparams, "__dict__") else self.hparams.update(hparams)
        self._saved_dataloaders = dict()
        self.dataset = dataset
        self.multitask = multitask
        
        if self.dataset is None:
            self.variable_name = None
            unique_variables = 1

            if self.hparams.dataset == 'Drugs':
                dataset = Drugs('/mnt/data/MARCEL/datasets/Drugs', max_num_conformers=self.hparams.max_num_conformers).shuffle()
            elif self.hparams.dataset == 'Kraken':
                dataset = Kraken('/mnt/data/MARCEL/datasets/Kraken', max_num_conformers=self.hparams.max_num_conformers).shuffle()
            elif self.hparams.dataset == 'BDE':
                dataset = BDE('/mnt/data/MARCEL/datasets/BDE').shuffle()
                self.variable_name = 'is_ligand'
                unique_variables = 2
            elif self.hparams.dataset == 'tmQMg':
                dataset = tmQMg('/mnt/data/MARCEL/datasets/tmQMg').shuffle()
                unique_variables = 1
            elif self.hparams.dataset == 'EE':
                dataset = EE('/mnt/data/MARCEL/datasets/EE', max_num_conformers=self.hparams.max_num_conformers).shuffle()
                self.variable_name = 'config_id'
                unique_variables = 2

            if self.multitask:
                self.hparams.target = 'all'
                pass
            else:
                #autoscaling
                target_id = dataset.descriptors.index(self.hparams.target)
                dataset.y = dataset.y[:, target_id]
                #mean = dataset.y.mean(dim=0, keepdim=True)
                #std = dataset.y.std(dim=0, keepdim=True)
                #dataset.y = ((dataset.y - mean) / std).to('cuda')
                #mean = mean.to('cuda')
                #std = std.to('cuda')
            
                #data.dataset.data.y = data.dataset.y
            
            self.dataset = dataset
            self.max_atomic_num = self.dataset.data.x[:, 0].max().item() + 1
            self.unique_variables = unique_variables
            print('--done---')

    def split_compute(self):

        split = self.dataset.get_idx_split(train_ratio=self.hparams.train_ratio, 
                                      valid_ratio=self.hparams.valid_ratio, 
                                      seed=self.hparams.seed)
        self.train_dataset = self.dataset[split['train']]
        self.valid_dataset = self.dataset[split['valid']]
        self.test_dataset = self.dataset[split['test']]

        print(f'{len(self.train_dataset)} training data, {len(self.test_dataset)} test data and {len(self.valid_dataset)} validation data')

    def train_dataloader(self):
        return self._get_dataloader(self.train_dataset, "train")

    def val_dataloader(self):
        return self._get_dataloader(self.valid_dataset, "val")

    def test_dataloader(self):
        return self._get_dataloader(self.test_dataset, "test")
    
    def _get_dataloader(self, dataset, stage, store_dataloader=True):
        store_dataloader = store_dataloader
        
        if stage in self._saved_dataloaders and store_dataloader:
            return self._saved_dataloaders[stage]

        strategy = 'all'
            
        if stage == "train":
            shuffle=True                              
        else:
            shuffle=False

        if self.variable_name is None:
            dl = DataLoader(dataset, batch_sampler=EnsembleSampler(dataset, 
                                                                   batch_size=self.hparams.batch_size, 
                                                                   strategy=strategy, 
                                                                   shuffle=shuffle),
                           num_workers=20)
        else:
            dl = MultiBatchLoader(dataset, batch_sampler=EnsembleMultiBatchSampler(dataset, 
                                                                                   batch_size=self.hparams.batch_size, 
                                                                                   strategy=strategy, 
                                                                                   shuffle=shuffle, 
                                                                                   variable_name=self.variable_name),
                                 num_workers=20)
        if store_dataloader:
            self._saved_dataloaders[stage] = dl
        return dl

In [13]:
import math
def compute_pnorm(model: nn.Module) -> float:
    """
    Computes the norm of the parameters of a model.
    :param model: A PyTorch model.
    :return: The norm of the parameters of the model.
    """
    return math.sqrt(sum([p.norm().item() ** 2 for p in model.parameters() if p.requires_grad]))


def compute_gnorm(model: nn.Module) -> float:
    """
    Computes the norm of the gradients of a model.
    :param model: A PyTorch model.
    :return: The norm of the gradients of the model.
    """
    return math.sqrt(sum([p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None]))

In [32]:
class ModelLM(LightningModule):
    def __init__(self, whole_dataset = None, unique_variables=1, multitask = False, **kwargs):
        super().__init__()
        #self.kwargs.update(kwargs.__dict__) if hasattr(kwargs, "__dict__") else self.kwargs.update(kwargs)
        #print(kwargs.get('model4d'))

        max_atomic_num = whole_dataset.data.x[:, 0].max().item() + 1
        
        if kwargs.get('model4d').model == 'SchNet':
            graph_model_factory = lambda: SchNet(max_atomic_num=max_atomic_num, 
                                           **asdict(kwargs.get('model4d').schnet))
        elif kwargs.get('model4d').model == 'DimeNet':
            graph_model_factory = lambda: DimeNet(max_atomic_num=max_atomic_num, 
                                            **asdict(kwargs.get('model4d').dimenet))
        elif kwargs.get('model4d').model == 'DimeNet++':
            graph_model_factory = lambda: DimeNetPlusPlus(max_atomic_num=max_atomic_num, 
                                                    **asdict(kwargs.get('model4d').dimenetplusplus))
        elif kwargs.get('model4d').model == 'GemNet':
            graph_model_factory = lambda: GemNetT(max_atomic_num=max_atomic_num, 
                                            **asdict(kwargs.get('model4d').gemnet))
        elif kwargs.get('model4d').model == 'ChIRo':
            graph_model_factory = lambda: ChIRo(**asdict(kwargs.get('model4d').chiro))
            
        elif kwargs.get('model4d').model == 'PaiNN':
            graph_model_factory = lambda: PaiNN(max_atomic_num=max_atomic_num, 
                                          **asdict(kwargs.get('model4d').painn))
        elif kwargs.get('model4d').model == 'ClofNet':
            graph_model_factory = lambda: ClofNet(max_atomic_num=max_atomic_num, 
                                            **asdict(kwargs.get('model4d').clofnet))
        elif kwargs.get('model4d').model == 'LEFTNet':
            graph_model_factory = lambda: LEFTNet(max_atomic_num=max_atomic_num, 
                                            **asdict(kwargs.get('model4d').leftnet))
        elif kwargs.get('model4d').model == 'ChytorchDiscrete':
            graph_model_factory = lambda: ChytorchDiscrete(max_neighbors=max_atomic_num, 
                                                     **asdict(kwargs.get('model4d').chytorch_discrete))
        elif kwargs.get('model4d').model == 'ChytorchConformer':
            graph_model_factory = lambda: ChytorchConformer(**asdict(kwargs.get('model4d').chytorch_conformer))
            
        elif kwargs.get('model4d').model == 'ChytorchRotary':
            graph_model_factory = lambda: ChytorchRotary(max_neighbors=max_atomic_num, 
                                                   **asdict(kwargs.get('model4d').chytorch_rotary))

        if kwargs.get('model4d').set_encoder == 'Sum':
            set_model_factory = lambda: SumPooling()
        elif kwargs.get('model4d').set_encoder == 'Mean':
            set_model_factory = lambda: MeanPooling()
        elif kwargs.get('model4d').set_encoder == 'DeepSets':
            set_model_factory = lambda: DeepSets(hidden_dim=kwargs.get('hidden_dim'))
        elif kwargs.get('model4d').set_encoder == 'Attention':
            set_model_factory = lambda: SelfAttentionPooling(hidden_dim=kwargs.get('hidden_dim'))
        elif kwargs.get('model4d').set_encoder == 'Transformer':
            set_model_factory = lambda: TransformerPooling(
                hidden_dim=kwargs.get('hidden_dim'), **asdict(kwargs.get('model4d').transformer))

        
        self.device_= torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.multitask = multitask
        
        self.net = Model4D(hidden_dim=kwargs.get('hidden_dim'), out_dim=1,
                           graph_model_factory=graph_model_factory,
                           set_model_factory=set_model_factory,
                           unique_variables=unique_variables, 
                           device='cuda',
                           multitask=self.multitask).to('cuda')
        
        self.whole_dataset = whole_dataset

        if self.multitask:
            self.loss_fn = GroupedScaledMAELoss(torch.ones(self.whole_dataset.y.shape[1], 
                                                           dtype=torch.long))
        else:
            self.loss_fn = nn.MSELoss() #LOGITS #GroupedScaledMAELoss(torch.ones(4, dtype=torch.long))
        
        self.lr = kwargs.get('learning_rate')
        self.wd = kwargs.get('learning_rate')
        
        self._reset_losses_dict()
        self._reset_inference_results()
        self.save_hyperparameters(ignore=["cosine_annealing_lr","linear_warmup_cosine_annealing_lr",
                                          "model1d","model2d","model3d","model4d","modelfprf","one_cycle_lr",
                                          "reduce_lr_on_plateau","whole_dataset","device","scheduler"])

    def forward(self, batch, molecule_indices):
        out = self.net(batch, molecule_indices)
        return out
        
    def configure_optimizers(self):
        o = AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)
        s = CyclicLR(o, self.lr, 2e-4, 1000, mode='triangular', cycle_momentum=False)
        # instantiate the WeakMethod in the lr scheduler object into the custom scale function attribute
        #s._scale_fn_custom = s._scale_fn_ref()
        # remove the reference so there are no more WeakMethod references in the object
        #s._scale_fn_ref = None
        return [o], [{'scheduler': s, 'interval': 'step', 'name': 'lr_scheduler'}]

    def training_step(self, batch, batch_idx):
        pnorm = compute_pnorm(self.net)
        gnorm = compute_gnorm(self.net)
        self.log(f'(training) pnorm', pnorm)
        self.log(f'(training) gnorm', gnorm)
        return self.step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, "test")
        
    def step(self, batch, stage):
        start = time()
        if type(batch) is not list:
            batch = [batch]
        molecule_indices = [reorder_molecule_idx(batch[i].molecule_idx).to('cuda') for i in range(len(batch))]
        #molecule_idx = batch[0].molecule_idx.to('cuda')

        unique_raw_molecule_idx = torch.unique_consecutive(batch[0].molecule_idx)
        dataset = self.whole_dataset.y.to('cuda')
        targets = dataset[unique_raw_molecule_idx].squeeze()

        with torch.set_grad_enabled(stage == "train"):
            if self.multitask:
                batch_multi = batch.copy()
                if self.whole_dataset.y.shape[1]==1:
                    print('only one property-switch to singletask')
                else:
                    for cnt, bat_i in enumerate(batch_multi):
                        for i in range(self.whole_dataset.y.shape[1]-1):
                            bat_tmp = batch_multi[cnt].batch#.detach().clone()
                            batch_multi[cnt] = batch_multi[cnt].concat(batch[cnt])
                            batch_multi[cnt].batch = torch.hstack([bat_tmp, batch[cnt].batch + (bat_tmp.max()+1)])
                
                targets_flat = targets.flatten()
                prompts = torch.tensor([i for i in range(data.dataset.y.shape[1])]*targets.shape[0],
                                      dtype=torch.int32,
                                      device=targets.device)
                for cnt, bat_i in enumerate(batch_multi):
                    batch_multi[cnt].tokens = prompts
                
                pred = self(batch_multi, molecule_indices)
                loss = self.loss_fn(pred.squeeze(), targets_flat, prompts)
                if stage == "test":
                    self.inference_results['token'].append(prompts.squeeze())
                    self.inference_results['y_pred'].append(pred.squeeze())
                    self.inference_results['y_true'].append(targets_flat.squeeze())
                    return None
                r2=r2_score_torch(targets_flat.cpu(),pred.squeeze().cpu().detach())
            else:
                pred = self(batch, molecule_indices)
                loss = self.loss_fn(pred.squeeze(), targets)
            
                if stage == "test":
                    self.inference_results['y_pred'].append(pred.squeeze())
                    self.inference_results['y_true'].append(targets.squeeze())
                    return None

                r2=r2_score_torch(targets.cpu(),pred.squeeze().cpu().detach())

            self.logging_info[f'{stage}_loss'].append(loss.item())
            self.logging_info[f'{stage}_r2'].append(r2)
            self.logging_info[f'{stage}_time'].append(time()-start)
            
            if stage == 'train':
                self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
                self.log(f'{stage}_step_loss', loss.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
            #model.save_checkpoint('checkpoint.pth')
            return loss
        
    def on_validation_epoch_end(self):
        if not self.trainer.sanity_checking:                    
            result_dict = {
                "epoch": float(self.current_epoch),
                "train_epoch_loss": torch.tensor(self.logging_info["train_loss"]).mean().item(),
                "train_epoch_r2": torch.tensor(self.logging_info["train_r2"]).mean().item(),
                "val_epoch_loss": torch.tensor(self.logging_info["val_loss"]).mean().item(),
                "val_epoch_r2": torch.tensor(self.logging_info["val_r2"]).mean().item(),
                "train_epoch_time": sum(self.logging_info["train_time"]),
                "val_epoch_time": sum(self.logging_info["val_time"]),
                }
            self.log_dict(result_dict, logger=True, sync_dist=True)
            
        self._reset_losses_dict()
    
    def on_test_epoch_end(self) -> None:
        for key in self.inference_results.keys():
            self.inference_results[key] = torch.cat(self.inference_results[key], dim=0)
    
    def _reset_losses_dict(self):
        self.logging_info = {
            "train_loss": [],
            "train_r2": [], 
            "train_mse": [], 
            "val_loss": [],
            "val_r2": [],
            "train_sample_size": [], 
            "val_sample_size": [],
            "train_time": [],
            "val_time": [],
        }
        
    def _reset_inference_results(self):
        self.inference_results = {'token': [],
                                  'y_pred': [],
                                  'y_true': []}

In [33]:
def config_to_dict(config_class):
    config_dict = {}
    for attr_name in dir(config_class):
        if not attr_name.startswith("__") and not callable(getattr(config_class, attr_name)):
            config_dict[attr_name] = getattr(config_class, attr_name)
    return config_dict

In [16]:
for dataname,targets in zip(['BDE','Drugs','Kraken','tmQMg'],
                           [['BindingEnergy'],['ip', 'ea', 'chi'],['sterimol_B5', 'sterimol_L', 'sterimol_burB5', 'sterimol_burL'],['tzvp_dipole_moment']]):
    for modeltype in ['ChytorchDiscrete', 'ChIRo','SchNet','GemNet','PaiNN','ClofNet','DimeNet++']:#,'GIN','GIN-VN','GPS','ChemProp']: #,'Chytorch'
        print(dataname)
        for target in targets:
            print(target)
            print(modeltype)
            config = Config
            config.dataset = dataname
            config.target = target
            config.device = 'cuda:0'
            
            
            config.model4d.model=modeltype
            config.model4d.augmentation = True
            
            config_dict = config_to_dict(config)
    
            subkeys = ["dataset",
                       "max_num_conformers",
                       "target",
                       "train_ratio",
                       "valid_ratio",
                       "seed",
                       "model3d", #.augmentation"
                       "batch_size"]
            config_dict_datamodule = {}
            for k,v in config_dict.items():
                if k in subkeys:
                    if k=="model4d":
                        config_dict_datamodule[f'{k}_augmentation']=config_dict[k].augmentation
                    else:
                        config_dict_datamodule[k]=v
            data = DataModule(config_dict_datamodule, multitask = False)
            data.prepare_data()
            data.split_compute()
            model = ModelLM(max_atomic_num=data.max_atomic_num, 
                            whole_dataset = data.dataset, 
                            unique_variables=data.unique_variables, 
                            multitask = False, **config_dict)
            print(f'#PARAMS = {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

BDE
BindingEnergy
ChytorchDiscrete




--done---
79637 training data, 22623 test data and 11838 validation data
#PARAMS = 3395776
BDE
BindingEnergy
ChIRo
--done---
79322 training data, 23377 test data and 11399 validation data
#PARAMS = 1834246
BDE
BindingEnergy
SchNet
--done---
80246 training data, 22803 test data and 11049 validation data
#PARAMS = 470364
BDE
BindingEnergy
GemNet
--done---
79657 training data, 23180 test data and 11261 validation data
#PARAMS = 3945344
BDE
BindingEnergy
PaiNN
--done---
78886 training data, 23442 test data and 11770 validation data
#PARAMS = 2660352
BDE
BindingEnergy
ClofNet
--done---
80455 training data, 22592 test data and 11051 validation data
#PARAMS = 1250178
BDE
BindingEnergy
DimeNet++
--done---
79794 training data, 23252 test data and 11052 validation data
#PARAMS = 3994124
Drugs
ip
ChytorchDiscrete
--done---
378811 training data, 108258 test data and 54294 validation data
#PARAMS = 1689952
ea
ChytorchDiscrete
--done---
378590 training data, 108507 test data and 54266 validation dat

In [50]:
dataname = 'Drugs'  #['BDE','Drugs','Kraken','tmQMg],
target = 'ea' #[['BindingEnergy'],['ip', 'ea', 'chi'],['sterimol_B5', 'sterimol_L', 'sterimol_burB5', 'sterimol_burL']]'tzvp_dipole_moment'
modeltype = 'SchNet'  #['SchNet','GemNet','PaiNN','ClofNet','LEFTNet','DimeNet++','ChytorchDiscrete', 'ChIRo']  #oom: 'GemNet'

writer = SummaryWriter(log_dir='/mnt/code/logs/')
#loader = ConfigLoader(model=Config, config='params/params_1d.json')
#config = loader()

config = Config
config.dataset = dataname
config.target = target
config.device = 'cuda:0'


#config.hidden_dim =128
#config.dropout = 0.5

#config.train_ratio = 0.7
#config.valid_ratio = 0.1
config.batch_size = 64 #256
#config.patience = 200

#config.learning_rate = 0.00001
#config.weight_decay = 1e-4
#config.scheduler = None
#config.reduce_lr_on_plateau = ReduceLROnPlateau()

#config.num_epochs = 2000
#config.cosine_annealing_lr = CosineAnnealingLR()
#config.linear_warmup_cosine_annealing_lr = LinearWarmupCosineAnnealingLR()


#not used in 1d

#config.one_cycle_lr = OneCycleLR()
#config.seed = 123
#activation = 'relu'

######3DMODEL
config.model3d.model=modeltype
config.model3d.augmentation = True

#config.model3d.schnet = SchNet()
#config.model3d.dimenet = DimeNet()
#config.model3d.dimenetplusplus = DimeNetPlusPlus()
#config.model3d.gemnet = GemNet()
#config.model3d.painn = PaiNN()
#config.model3d.clofnet = ClofNet()
#config.model3d.leftnet= LEFTNet()
#config.model3d.chytorch_discrete = ChytorchDiscrete()
#config.model3d.chytorch_conformer = ChytorchConformer()
#config.model3d.chytorch_rotary = ChytorchRotary()

In [51]:
config_dict = config_to_dict(config)

In [52]:
subkeys = ["dataset",
           "max_num_conformers",
           "target",
           "train_ratio",
           "valid_ratio",
           "seed",
           "model4d", #.augmentation"
           "batch_size"]
config_dict_datamodule = {}
for k,v in config_dict.items():
    if k in subkeys:
        if k=="model4d":
            config_dict_datamodule[f'{k}_augmentation']=config_dict[k].augmentation
        else:
            config_dict_datamodule[k]=v

In [53]:
config_dict_datamodule

{'batch_size': 64,
 'dataset': 'Drugs',
 'max_num_conformers': 20,
 'model4d_augmentation': True,
 'seed': 123,
 'target': 'ea',
 'train_ratio': 0.7,
 'valid_ratio': 0.1}

In [54]:
data = DataModule(config_dict_datamodule, multitask = False)
data.prepare_data()
data.split_compute()

--done---
378766 training data, 108580 test data and 54017 validation data


In [55]:
data.dataset.data

Data(x=[16545178, 9], edge_index=[2, 35655306], edge_attr=[35655306, 3], pos=[16545178, 3], name=[541363], id=[541363], smiles=[541363], y=[541363, 4], molecule_idx=[541363])

In [56]:
#z.shape, hgs.shape, pos.shape, bat.shape

In [57]:
model = ModelLM(max_atomic_num=data.max_atomic_num, 
                whole_dataset = data.dataset, 
                unique_variables=data.unique_variables, 
                multitask = False, **config_dict)

In [58]:
#model._log_hyperparams = False

In [59]:
print(f'#PARAMS = {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

#PARAMS = 1989126


In [60]:
model

ModelLM(
  (net): Model4D(
    (graph_encoders): ModuleList(
      (0): DimeNetPlusPlus(
        (rbf): BesselBasisLayer(
          (envelope): Envelope()
        )
        (sbf): SphericalBasisLayer(
          (envelope): Envelope()
        )
        (emb): EmbeddingBlock(
          (emb): Embedding(17, 128)
          (lin_rbf): Linear(in_features=6, out_features=128, bias=True)
          (lin): Linear(in_features=384, out_features=128, bias=True)
        )
        (output_blocks): ModuleList(
          (0-4): 5 x OutputPPBlock(
            (lin_rbf): Linear(in_features=6, out_features=128, bias=False)
            (lin_up): Linear(in_features=128, out_features=256, bias=False)
            (lins): ModuleList(
              (0-2): 3 x Linear(in_features=256, out_features=256, bias=True)
            )
            (lin): Linear(in_features=256, out_features=128, bias=False)
          )
        )
        (interaction_blocks): ModuleList(
          (0-3): 4 x InteractionPPBlock(
           

In [61]:
dir_name = f"tmp_4D_singletask_{dataname}_{target}_{modeltype}_multitask_v0"

dir_load_model = None
log_dir_folder = '/mnt/code/logs/'
log_dir_folder = os.path.join(log_dir_folder, dir_name)
if os.path.exists(log_dir_folder):
    if os.path.exists(os.path.join(log_dir_folder, "last.ckpt")):
        dir_load_model = os.path.join(log_dir_folder, "last.ckpt")
    csv_path = os.path.join(log_dir_folder, "metrics.csv")
    while os.path.exists(csv_path):
        csv_path = csv_path + '.bak'
    if os.path.exists(os.path.join(log_dir_folder, "metrics.csv")):
        os.rename(os.path.join(log_dir_folder, "metrics.csv"), csv_path)

In [62]:
metric_to_monitor = "val_epoch_loss"

checkpoint_callback = ModelCheckpoint(
    dirpath=log_dir_folder,
    monitor=metric_to_monitor,
    mode = 'min',
    save_top_k=1,
    save_last=True,
    every_n_epochs=5,
    save_weights_only=True,
    verbose=True,
    filename="best-model-{epoch}-{val_epoch_loss:.4f}",
)


early_stopping = early_stop_callback = EarlyStopping(
        monitor=metric_to_monitor,  # The metric you want to monitor
        patience=config.patience,  # Number of epochs with no improvement after which training will be stopped
        verbose=True,
        mode='min'  # Minimizing the validation loss
    )

tb_logger = TensorBoardLogger(log_dir_folder, name="tensorbord")#, version="", default_hp_metric=False)
csv_logger = CSVLogger(log_dir_folder, name="", version="")

model_params = dict(
    devices=1, #args['ngpus'],
    accelerator='gpu', #args['accelerator'],
    default_root_dir=log_dir_folder, #args['log_dir'],
    logger=[tb_logger, csv_logger],
    enable_progress_bar=True)


model_params.update(dict(
    max_epochs=config.num_epochs,#1000,
    callbacks=[checkpoint_callback, early_stopping],
    #enable_checkpointing=False,
    gradient_clip_val=10,#args['clip_norm'],
    #precision="16-mixed",
))

In [63]:
torch.set_float32_matmul_precision('medium')

In [64]:
trainer = pl.Trainer(**model_params)
trainer.fit(model, datamodule=data, ckpt_path=dir_load_model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /mnt/code/logs/tmp_4D_singletask_Drugs_ea_SchNet_multitask_v0/tensorbord
/usr/local/lib/python3.10/dist-packages/lightning_fabric/loggers/csv_logs.py:268: Experiment logs directory /mnt/code/logs/tmp_4D_singletask_Drugs_ea_SchNet_multitask_v0/ exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /mnt/code/logs/tmp_4D_singletask_Drugs_ea_SchNet_multitask_v0 exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params
------------------------------------
0 | net     | Model4D | 2.0 M 
1 | loss_fn | MSELoss | 0     
------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [65]:
data.dataset.descriptors

['energy', 'ip', 'ea', 'chi']