In [1]:
!nvidia-smi

Tue May 14 10:12:57 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08             Driver Version: 535.161.08   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A10G                    On  | 00000000:00:1D.0 Off |                    0 |
|  0%   24C    P8              15W / 300W |      0MiB / 23028MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
import torch
torch.__version__

'2.3.0+cu121'

In [65]:
from time import time

In [3]:
!pip install torch-geometric -f https://data.pyg.org/whl/torch-2.3.0+cu121.html

Defaulting to user installation because normal site-packages is not writeable
Looking in links: https://data.pyg.org/whl/torch-2.3.0+cu121.html


In [4]:
!pip install torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-2.3.0+cu121.html

Defaulting to user installation because normal site-packages is not writeable
Looking in links: https://data.pyg.org/whl/torch-2.3.0+cu121.html


In [5]:
import os
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T

from dataclasses import asdict
# from pl_bolts.optimizers import LinearWarmupCosineAnnealingLR
from torch_geometric.loader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR

from config import Config
from data.ee import EE_2D
from data.bde import BDE
from data.drugs import Drugs
from data.kraken import Kraken
from happy_config import ConfigLoader
from loaders.samplers import EnsembleSampler, EnsembleMultiBatchSampler
from loaders.multibatch import MultiBatchLoader
from utils.early_stopping import EarlyStopping, generate_checkpoint_filename

from models.model_2d import GIN, GPS, Model2D
from models.models_2d.chemprop import ChemProp, transform_reversely_indexed_data
#from models.models_2d.chytorch2d import Chytorch
import os
import shutil
import math

import pickle

In [6]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
import torch.nn as nn

from torch.optim import AdamW
from torch.optim.lr_scheduler import CyclicLR, CosineAnnealingLR

In [7]:
from models.model_3d import GroupedScaledMAELoss

In [8]:
# --- Reproducibility setup ---------------------------------------------------
import os
import numpy as np

# Release cached CUDA allocations left over from earlier cells.
torch.cuda.empty_cache()

# Single seed shared by every RNG seeded in this cell.
seed = 42

np.random.seed(seed)
torch.manual_seed(seed)

# On a GPU machine, also seed the CUDA generator and switch cuDNN to
# deterministic mode with autotuning (benchmark) disabled.
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
# -----------------------------------------------------------------------------

In [9]:
def r2_score_torch(y_true, y_pred):
    """
    Coefficient of determination (R^2) between targets and predictions.

    Parameters:
        y_true (torch.Tensor): Ground-truth target values.
        y_pred (torch.Tensor): Predicted values, combined element-wise with
            ``y_true`` (the two tensors are subtracted directly).

    Returns:
        float: ``1 - RSS / TSS`` where RSS is the sum of squared residuals and
        TSS the sum of squared deviations of ``y_true`` from its mean.
    """
    residuals = y_true - y_pred
    deviations = y_true - y_true.mean()

    rss = residuals.pow(2).sum()
    tss = deviations.pow(2).sum()

    return (1 - rss / tss).item()

In [66]:
class DataModule(LightningDataModule):
    """Lightning data module that loads one of the molecular datasets and
    serves train/val/test loaders built on the ensemble batch samplers.

    Parameters:
        hparams: Mapping (or object exposing ``__dict__``) with at least
            ``dataset``, ``target``, ``model2d_model``, ``batch_size``,
            ``train_ratio``, ``valid_ratio`` and ``seed``. An optional
            ``num_workers`` entry overrides the default of 20 loader workers.
        dataset: Optional pre-built dataset. When given, nothing is loaded
            from disk and the grouping attributes keep their defaults
            (``variable_name=None``, ``unique_variables=1``,
            ``max_atomic_num=None``).
        multitask: If True, keep every target column; otherwise slice
            ``dataset.y`` down to the column named by ``hparams.target``.

    Call ``split_compute()`` before requesting any dataloader.
    """

    def __init__(self, hparams, dataset=None, multitask = False):
        super().__init__()
        # Accept either a plain mapping or a config-like object.
        if hasattr(hparams, "__dict__"):
            self.hparams.update(hparams.__dict__)
        else:
            self.hparams.update(hparams)
        self._saved_dataloaders = dict()
        self.dataset = dataset
        self.multitask = multitask

        # Defaults are set unconditionally so that a module built from a
        # caller-supplied dataset still has every attribute the loaders and
        # downstream cells read.
        self.variable_name = None
        self.pre_transform = None
        self.unique_variables = 1
        self.max_atomic_num = None

        if self.dataset is None:
            self.dataset = self._build_dataset()
            self.max_atomic_num = self.dataset.data.x[:, 0].max().item() + 1
            print('--done---')

    def _build_dataset(self):
        """Load the dataset named by ``hparams.dataset`` and select the target.

        Sets ``pre_transform``, ``variable_name`` and ``unique_variables`` as a
        side effect and returns the (shuffled) dataset.

        Raises:
            ValueError: if ``hparams.dataset`` is not a supported name.
        """
        # NOTE(review): this overrides any `max_num_conformers` passed in
        # hparams; kept as in the original because the value feeds the dataset
        # constructors below.
        self.hparams.max_num_conformers = 18

        backbone = self.hparams.model2d_model
        if backbone == 'GPS':
            # NOTE(review): reads the notebook-global `config`, which must be
            # defined before a GPS data module is instantiated.
            self.pre_transform = T.AddRandomWalkPE(walk_length=config.model2d.gps.walk_length, attr_name='pe')
        elif backbone == 'ChemProp':
            self.pre_transform = transform_reversely_indexed_data

        name = self.hparams.dataset
        max_conformers = self.hparams.max_num_conformers
        if name == 'Drugs':
            dataset = Drugs('/mnt/data/MARCEL/datasets/Drugs',
                            max_num_conformers=max_conformers,
                            pre_transform=self.pre_transform).shuffle()
        elif name == 'Kraken':
            dataset = Kraken('/mnt/data/MARCEL/datasets/Kraken',
                             max_num_conformers=max_conformers,
                             pre_transform=self.pre_transform).shuffle()
        elif name == 'BDE':
            dataset = BDE('/mnt/data/MARCEL/datasets/BDE',
                          pre_transform=self.pre_transform,
                          max_num_conformers=max_conformers,).shuffle()
            self.variable_name = 'is_ligand'
            self.unique_variables = 2
        elif name == 'tmQMg':
            # NOTE(review): `tmQMg` is not imported in the import cells of this
            # notebook; this branch raises NameError until it is.
            dataset = tmQMg('/mnt/data/MARCEL/datasets/tmQMg',
                            pre_transform=self.pre_transform).shuffle()
            self.unique_variables = 1
        elif name == 'EE':
            # NOTE(review): the import cell only brings in `EE_2D` from
            # data.ee; `EE` must be imported for this branch to work.
            dataset = EE('/mnt/data/MARCEL/datasets/EE', max_num_conformers=max_conformers,
                         pre_transform=self.pre_transform).shuffle()
            self.variable_name = 'config_id'
            self.unique_variables = 2
        else:
            raise ValueError(
                f"Unknown dataset {name!r}; expected one of "
                "'Drugs', 'Kraken', 'BDE', 'tmQMg', 'EE'"
            )

        if self.multitask:
            # All target columns are kept.
            self.hparams.target = 'all'
        else:
            # Keep only the column of the requested target.
            target_id = dataset.descriptors.index(self.hparams.target)
            dataset.y = dataset.y[:, target_id]

        return dataset

    def split_compute(self):
        """Split the dataset into train/valid/test subsets using the ratios
        and seed stored in hparams, and print the subset sizes."""
        split = self.dataset.get_idx_split(train_ratio=self.hparams.train_ratio,
                                           valid_ratio=self.hparams.valid_ratio,
                                           seed=self.hparams.seed)
        self.train_dataset = self.dataset[split['train']]
        self.valid_dataset = self.dataset[split['valid']]
        self.test_dataset = self.dataset[split['test']]

        print(f'{len(self.train_dataset)} training data, {len(self.test_dataset)} test data and {len(self.valid_dataset)} validation data')

    def train_dataloader(self):
        return self._get_dataloader(self.train_dataset, "train")

    def val_dataloader(self):
        return self._get_dataloader(self.valid_dataset, "val")

    def test_dataloader(self):
        return self._get_dataloader(self.test_dataset, "test")

    def _get_dataloader(self, dataset, stage, store_dataloader=True):
        """Return the loader for ``stage`` ("train", "val" or "test").

        Loaders are cached per stage when ``store_dataloader`` is True. Only
        the training loader shuffles. A ``MultiBatchLoader`` is used when the
        dataset has a grouping variable (``self.variable_name``), otherwise a
        plain PyG ``DataLoader``.
        """
        if store_dataloader and stage in self._saved_dataloaders:
            return self._saved_dataloaders[stage]

        shuffle = (stage == "train")
        strategy = 'first'
        # Same default as before (20 workers), now overridable through hparams.
        num_workers = self.hparams.get('num_workers', 20)

        if self.variable_name is None:
            dl = DataLoader(dataset,
                            batch_sampler=EnsembleSampler(dataset,
                                                          batch_size=self.hparams.batch_size,
                                                          strategy=strategy,
                                                          shuffle=shuffle),
                            num_workers=num_workers)
        else:
            dl = MultiBatchLoader(dataset,
                                  batch_sampler=EnsembleMultiBatchSampler(dataset,
                                                                          batch_size=self.hparams.batch_size,
                                                                          strategy=strategy,
                                                                          shuffle=shuffle,
                                                                          variable_name=self.variable_name),
                                  num_workers=num_workers)
        if store_dataloader:
            self._saved_dataloaders[stage] = dl
        return dl

In [67]:
import math
def compute_pnorm(model: nn.Module) -> float:
    """
    Global L2 norm over all trainable parameters of a model.
    :param model: A PyTorch model.
    :return: Square root of the summed squared norms of every parameter that
             has ``requires_grad`` set.
    """
    squared_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        squared_total += param.norm().item() ** 2
    return math.sqrt(squared_total)


def compute_gnorm(model: nn.Module) -> float:
    """
    Global L2 norm over all currently populated gradients of a model.
    :param model: A PyTorch model.
    :return: Square root of the summed squared gradient norms; parameters
             whose ``grad`` is None are skipped.
    """
    squared_total = 0
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue
        squared_total += grad.norm().item() ** 2
    return math.sqrt(squared_total)

In [68]:
class ModelLM(LightningModule):
    """LightningModule that trains a ``Model2D`` backbone on molecular targets.

    Parameters:
        max_atomic_num: Accepted for interface compatibility; not used here.
        whole_dataset: Dataset whose ``y`` tensor is indexed by each batch's
            ``molecule_idx`` to obtain the targets.
        unique_variables: Forwarded to ``Model2D``.
        multitask: If True, every target column is predicted using integer
            "token" prompts and ``GroupedScaledMAELoss``; otherwise a single
            target is trained with ``nn.MSELoss``.
        **kwargs: Flattened config. Reads ``model2d`` (backbone selection and
            its sub-config), ``hidden_dim``, ``activation``, ``dropout``,
            ``learning_rate`` and ``weight_decay``.

    Raises:
        ValueError: if ``kwargs['model2d'].model`` is not a supported backbone.
    """

    def __init__(self, max_atomic_num=None, whole_dataset = None, unique_variables=1, multitask = False, **kwargs):
        super().__init__()

        model2d_cfg = kwargs.get('model2d')
        backbone = model2d_cfg.model

        if backbone in ('GIN', 'GIN-VN'):
            # The two GIN variants differ only in the virtual-node flag.
            model2d_cfg.gin.virtual_node = (backbone == 'GIN-VN')
            model_factory = lambda: GIN(hidden_dim=kwargs.get('hidden_dim'),
                                        act=kwargs.get('activation'), **asdict(model2d_cfg.gin))
        elif backbone == 'GPS':
            model_factory = lambda: GPS(hidden_dim=kwargs.get('hidden_dim'),
                                        dropout=kwargs.get('dropout'),
                                        act=kwargs.get('activation'), **asdict(model2d_cfg.gps))
        elif backbone == 'ChemProp':
            model_factory = lambda: ChemProp(hidden_dim=kwargs.get('hidden_dim'),
                                             act=kwargs.get('activation'),
                                             **asdict(model2d_cfg.chemprop))
        elif backbone == 'Chytorch':
            # NOTE(review): the `Chytorch` import is commented out in the
            # import cell, so this factory raises NameError when invoked.
            model_factory = lambda: Chytorch(**asdict(model2d_cfg.chytorch2d))
        else:
            raise ValueError(
                f"Unsupported 2D model {backbone!r}; expected one of "
                "'GIN', 'GIN-VN', 'GPS', 'ChemProp', 'Chytorch'"
            )

        # Fall back to CPU when no GPU is present instead of hard-coding 'cuda'.
        self.device_ = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.multitask = multitask
        self.net = Model2D(model_factory,
                           kwargs.get('hidden_dim'),
                           out_dim=1,
                           dropout=kwargs.get('dropout'),
                           device=self.device_.type,
                           unique_variables=unique_variables).to(self.device_)

        self.whole_dataset = whole_dataset

        if self.multitask:
            self.loss_fn = GroupedScaledMAELoss(torch.ones(self.whole_dataset.y.shape[1],
                                                           dtype=torch.long))
        else:
            self.loss_fn = nn.MSELoss()

        self.lr = kwargs.get('learning_rate')
        self.wd = kwargs.get('weight_decay')

        self._reset_losses_dict()
        self._reset_inference_results()
        self.save_hyperparameters(ignore=["cosine_annealing_lr","linear_warmup_cosine_annealing_lr",
                                          "model1d","model2d","model3d","model4d","modelfprf","one_cycle_lr",
                                          "reduce_lr_on_plateau","whole_dataset","device","scheduler"])

    def forward(self, batch):
        return self.net(batch)

    def configure_optimizers(self):
        """AdamW with a per-step triangular CyclicLR schedule."""
        o = AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)
        # NOTE(review): base_lr is the configured learning rate and max_lr is
        # fixed at 2e-4; with the default learning_rate of 1e-3 shown in the
        # config dump, base_lr > max_lr. Values kept as-is; confirm intent.
        s = CyclicLR(o, base_lr=self.lr, max_lr=2e-4, step_size_up=1000,
                     mode='triangular', cycle_momentum=False)
        return [o], [{'scheduler': s, 'interval': 'step', 'name': 'lr_scheduler'}]

    def training_step(self, batch, batch_idx):
        # Parameter/gradient norms are logged before this step's forward pass.
        pnorm = compute_pnorm(self.net)
        gnorm = compute_gnorm(self.net)
        self.log('(training) pnorm', pnorm)
        self.log('(training) gnorm', gnorm)
        return self.step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, "test")

    def step(self, batch, stage):
        """Shared train/val/test step.

        Looks up targets via ``batch[0].molecule_idx``, runs the forward pass
        (gradients enabled only for ``stage == "train"``) and computes the
        loss. For "train"/"val" the loss, R^2 and wall time are appended to
        ``self.logging_info`` and the loss is returned. For "test" predictions
        and targets are appended to ``self.inference_results`` and None is
        returned.
        """
        start = time()
        if type(batch) is not list:
            batch = [batch]

        # Index the target tensor where it lives and move only the selected
        # rows to the compute device, rather than copying the whole tensor on
        # every step.
        all_targets = self.whole_dataset.y
        molecule_idx = batch[0].molecule_idx.to(all_targets.device)
        targets = all_targets[molecule_idx].to(self.device_).squeeze()

        with torch.set_grad_enabled(stage == "train"):
            if self.multitask:
                num_tasks = self.whole_dataset.y.shape[1]
                batch_multi = batch.copy()
                if num_tasks == 1:
                    print('only one property-switch to singletask')
                else:
                    # Replicate each batch once per additional task, offsetting
                    # the `batch` index vector so graph ids stay distinct.
                    for cnt, bat_i in enumerate(batch_multi):
                        for i in range(num_tasks - 1):
                            bat_tmp = batch_multi[cnt].batch
                            batch_multi[cnt] = batch_multi[cnt].concat(batch[cnt])
                            batch_multi[cnt].batch = torch.hstack([bat_tmp, batch[cnt].batch + (bat_tmp.max()+1)])

                targets_flat = targets.flatten()
                # One task id per flattened target: 0..num_tasks-1 repeated.
                # Uses this module's own dataset rather than a notebook global.
                prompts = torch.arange(num_tasks,
                                       dtype=torch.int32,
                                       device=targets.device).repeat(targets.shape[0])
                for cnt, bat_i in enumerate(batch_multi):
                    batch_multi[cnt].tokens = prompts
                pred = self(batch_multi)
                loss = self.loss_fn(pred.squeeze(), targets_flat, prompts)
                if stage == "test":
                    # reshape(-1) keeps results 1-D even for a single-element
                    # batch, so they can be concatenated at epoch end.
                    self.inference_results['token'].append(prompts.reshape(-1))
                    self.inference_results['y_pred'].append(pred.reshape(-1))
                    self.inference_results['y_true'].append(targets_flat.reshape(-1))
                    return None
                r2 = r2_score_torch(targets_flat.cpu(), pred.squeeze().cpu().detach())
            else:
                pred = self(batch)
                loss = self.loss_fn(pred.squeeze(), targets)

                if stage == "test":
                    self.inference_results['y_pred'].append(pred.reshape(-1))
                    self.inference_results['y_true'].append(targets.reshape(-1))
                    return None

                r2 = r2_score_torch(targets.cpu(), pred.squeeze().cpu().detach())

            self.logging_info[f'{stage}_loss'].append(loss.item())
            self.logging_info[f'{stage}_r2'].append(r2)
            self.logging_info[f'{stage}_time'].append(time()-start)

            if stage == 'train':
                self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
                self.log(f'{stage}_step_loss', loss.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
            return loss

    def on_validation_epoch_end(self):
        """Log epoch-level means of the collected train/val metrics, then
        reset the accumulators (nothing is logged during sanity checking)."""
        if not self.trainer.sanity_checking:
            result_dict = {
                "epoch": float(self.current_epoch),
                "train_epoch_loss": torch.tensor(self.logging_info["train_loss"]).mean().item(),
                "train_epoch_r2": torch.tensor(self.logging_info["train_r2"]).mean().item(),
                "val_epoch_loss": torch.tensor(self.logging_info["val_loss"]).mean().item(),
                "val_epoch_r2": torch.tensor(self.logging_info["val_r2"]).mean().item(),
                "train_epoch_time": sum(self.logging_info["train_time"]),
                "val_epoch_time": sum(self.logging_info["val_time"]),
                }
            self.log_dict(result_dict, logger=True, sync_dist=True)

        self._reset_losses_dict()

    def on_test_epoch_end(self) -> None:
        """Concatenate the per-batch test results into one tensor per key.

        Keys with nothing collected (e.g. 'token' in single-task mode) are
        left as empty lists, since ``torch.cat`` rejects an empty sequence.
        """
        for key, chunks in list(self.inference_results.items()):
            if not isinstance(chunks, list) or len(chunks) == 0:
                continue
            self.inference_results[key] = torch.cat(chunks, dim=0)

    def _reset_losses_dict(self):
        self.logging_info = {
            "train_loss": [],
            "train_r2": [],
            "train_mse": [],
            "val_loss": [],
            "val_r2": [],
            "train_sample_size": [],
            "val_sample_size": [],
            "train_time": [],
            "val_time": [],
        }

    def _reset_inference_results(self):
        self.inference_results = {'token': [],
                                  'y_pred': [],
                                  'y_true': []}

In [69]:
# --- Experiment selection ----------------------------------------------------
dataname = 'Drugs'   # one of: 'BDE', 'Drugs', 'Kraken', 'tmQMg'
# Candidate targets noted by the original author (presumably grouped per
# dataset in the order above -- confirm against each dataset's `descriptors`):
#   ['BindingEnergy'], ['ip', 'ea', 'chi'],
#   ['sterimol_B5', 'sterimol_L', 'sterimol_burB5', 'sterimol_burL'],
#   'tzvp_dipole_moment'
target = 'ea'
modeltype = 'GIN'    # one of: 'GIN', 'GIN-VN', 'GPS', 'ChemProp', 'Chytorch'

# NOTE: `config` is the Config class itself (not an instance); the assignments
# below therefore set class-level attributes.
config = Config
config.dataset = dataname
config.target = target
config.device = 'cuda:0'
config.batch_size = 64  # 256

# Select the 2D backbone. ModelLM and DataModule read `config.model2d.model`,
# so `modeltype` must be written here to have any effect on the model built.
config.model2d.model = modeltype

# NOTE(review): retained from the original cell; nothing visible in this
# notebook reads `config.model3d`, and `modeltype` names a 2D model.
config.model3d.model = modeltype
config.model3d.augmentation = True

# Optional overrides (left at the Config defaults):
#   config.hidden_dim = 128
#   config.dropout = 0.5
#   config.train_ratio = 0.7
#   config.valid_ratio = 0.1
#   config.patience = 200
#   config.learning_rate = 0.00001
#   config.weight_decay = 1e-4
#   config.num_epochs = 2000
#   config.seed = 123

In [70]:
def config_to_dict(config_class):
    """Collect the data attributes of ``config_class`` into a plain dict.

    Every name reported by ``dir()`` is kept unless it starts with a double
    underscore or its value is callable.
    """
    return {
        name: getattr(config_class, name)
        for name in dir(config_class)
        if not (name.startswith("__") or callable(getattr(config_class, name)))
    }
# Flat snapshot of the config attributes; used below to derive the DataModule
# hparams and passed as keyword arguments to ModelLM.
config_dict = config_to_dict(config)

In [71]:
config_dict#['model2d'].model

{'activation': 'relu',
 'batch_size': 64,
 'cosine_annealing_lr': CosineAnnealingLR(eta_min=1e-06),
 'dataset': 'Drugs',
 'device': 'cuda:0',
 'dropout': 0.5,
 'hidden_dim': 128,
 'learning_rate': 0.001,
 'linear_warmup_cosine_annealing_lr': LinearWarmupCosineAnnealingLR(warmup_steps=200, max_epochs=2000),
 'max_num_conformers': 20,
 'max_num_molecules': None,
 'model1d': Model1D(model='LSTM', input_type='SMILES', embedding_dim=128, num_layers=4, num_heads=4),
 'model2d': Model2D(model='GIN', gin=GIN(num_layers=6, virtual_node=False), gps=GPS(num_layers=6, walk_length=20, num_heads=4), chemprop=ChemProp(num_layers=6)),
 'model3d': Model3D(model='GIN', augmentation=True, schnet=SchNet(hidden_dim=128, num_filters=5, num_interactions=6, num_gaussians=50, cutoff=10, readout='mean', dipole=False), dimenet=DimeNet(hidden_channels=128, out_channels=128, num_blocks=6, num_bilinear=8, num_spherical=7, num_radial=6), dimenetplusplus=DimeNetPlusPlus(hidden_channels=128, out_channels=128, num_bloc

In [72]:
# Config entries forwarded to the DataModule. The nested "model2d" entry is
# reduced to its model name and stored under "model2d_model".
subkeys = [
    "dataset",
    "max_num_conformers",
    "target",
    "train_ratio",
    "valid_ratio",
    "seed",
    "model2d",
    "batch_size",
]

config_dict_datamodule = {}
for key, value in config_dict.items():
    if key not in subkeys:
        continue
    if key == "model2d":
        config_dict_datamodule[f'{key}_model'] = value.model
    else:
        config_dict_datamodule[key] = value

In [73]:
config_dict_datamodule

{'batch_size': 64,
 'dataset': 'Drugs',
 'max_num_conformers': 20,
 'model2d_model': 'GIN',
 'seed': 123,
 'target': 'ea',
 'train_ratio': 0.7,
 'valid_ratio': 0.1}

In [74]:
# Build the single-task data module (no dataset is passed, so it loads the one
# named in the hparams), then create the train/valid/test subsets.
data = DataModule(config_dict_datamodule, multitask = False)
data.prepare_data()
data.split_compute()

--done---
376106 training data, 107537 test data and 53939 validation data


In [75]:
data.dataset.data

Data(x=[16427415, 9], edge_index=[2, 35401392], edge_attr=[35401392, 3], pos=[16427415, 3], name=[537582], id=[537582], smiles=[537582], y=[537582, 4], molecule_idx=[537582])

In [76]:
data.dataset.data

Data(x=[16427415, 9], edge_index=[2, 35401392], edge_attr=[35401392, 3], pos=[16427415, 3], name=[537582], id=[537582], smiles=[537582], y=[537582, 4], molecule_idx=[537582])

In [77]:
#z.shape, hgs.shape, pos.shape, bat.shape

In [78]:
config_dict['model2d'].model

'GIN'

In [79]:
# Single-task model; targets are looked up in `data.dataset` during each step,
# and the remaining settings come from the flattened config.
model = ModelLM(max_atomic_num=data.max_atomic_num, 
                whole_dataset = data.dataset, 
                unique_variables=data.unique_variables, 
                multitask = False, **config_dict)

In [80]:
#model._log_hyperparams = False

In [81]:
# Number of trainable parameters (only those with requires_grad set).
print(f'#PARAMS = {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

#PARAMS = 240513


In [82]:
model

ModelLM(
  (net): Model2D(
    (models): ModuleList(
      (0): GIN(
        (act): ReLU()
        (conv): ModuleList(
          (0-5): 6 x GINEConv(nn=Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
            (3): Linear(in_features=128, out_features=128, bias=True)
            (4): ReLU()
          ))
        )
        (atom_encoder): AtomEncoder(
          (atom_embedding_list): ModuleList(
            (0): Embedding(119, 128)
            (1): Embedding(7, 128)
            (2-3): 2 x Embedding(12, 128)
            (4): Embedding(10, 128)
            (5-6): 2 x Embedding(6, 128)
            (7-8): 2 x Embedding(2, 128)
          )
        )
        (bond_encoder): BondEncoder(
          (bond_embedding_list): ModuleList(
            (0): Embedding(5, 128)
            (1): Embedding(6, 128)
            (2): Embedding(2, 128)
    

In [83]:
# Per-experiment log directory under the shared log root.
dir_name = f"tmp_2D_{dataname}_{target}_{modeltype}_multitask_v0"
log_dir_folder = os.path.join('/mnt/code/logs/', dir_name)

# Checkpoint to resume from; stays None unless a previous run left last.ckpt.
dir_load_model = None

if os.path.exists(log_dir_folder):
    last_ckpt = os.path.join(log_dir_folder, "last.ckpt")
    if os.path.exists(last_ckpt):
        dir_load_model = last_ckpt

    # Move an existing metrics.csv aside: append '.bak' until the name is
    # free, then rename the current file to that name.
    metrics_csv = os.path.join(log_dir_folder, "metrics.csv")
    csv_path = metrics_csv
    while os.path.exists(csv_path):
        csv_path = csv_path + '.bak'
    if os.path.exists(metrics_csv):
        os.rename(metrics_csv, csv_path)

In [84]:
# Metric logged by ModelLM.on_validation_epoch_end; drives both checkpointing
# and early stopping.
metric_to_monitor = "val_epoch_loss"

# Keep the best model (lowest monitored value) plus `last.ckpt`.
# Full checkpoints are saved (save_weights_only=False) because the training
# cell resumes from `last.ckpt` via `ckpt_path`, which needs optimizer and
# scheduler state in addition to the weights. A weights-only `last.ckpt` left
# by an earlier run still cannot be resumed from.
checkpoint_callback = ModelCheckpoint(
    dirpath=log_dir_folder,
    monitor=metric_to_monitor,
    mode='min',
    save_top_k=1,
    save_last=True,
    every_n_epochs=5,
    save_weights_only=False,
    verbose=True,
    filename="best-model-{epoch}-{val_epoch_loss:.4f}",
)

# Stop when the monitored metric has not improved for `config.patience`
# validation checks. Both names are kept for compatibility with later cells.
early_stopping = early_stop_callback = EarlyStopping(
    monitor=metric_to_monitor,
    patience=config.patience,
    verbose=True,
    mode='min',
)

tb_logger = TensorBoardLogger(log_dir_folder, name="tensorbord")
csv_logger = CSVLogger(log_dir_folder, name="", version="")

# Keyword arguments for pl.Trainer.
model_params = dict(
    devices=1,
    accelerator='gpu',
    default_root_dir=log_dir_folder,
    logger=[tb_logger, csv_logger],
    enable_progress_bar=True,
    max_epochs=config.num_epochs,
    callbacks=[checkpoint_callback, early_stopping],
    gradient_clip_val=10,
)

In [85]:
# Set PyTorch's global float32 matmul precision mode to 'medium' before training.
torch.set_float32_matmul_precision('medium')

In [None]:
# Train. `dir_load_model` is the path to a previous run's last.ckpt when one
# was found in the log directory, otherwise None (fresh start).
trainer = pl.Trainer(**model_params)
trainer.fit(model, datamodule=data, ckpt_path=dir_load_model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params
------------------------------------
0 | net     | Model2D | 240 K 
1 | loss_fn | MSELoss | 0     
------------------------------------
240 K     Trainable params
0         Non-trainable params
240 K     Total params
0.962     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_epoch_loss improved. New best score: 0.868


Validation: 0it [00:00, ?it/s]

Metric val_epoch_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.859


Validation: 0it [00:00, ?it/s]

Metric val_epoch_loss improved by 0.012 >= min_delta = 0.0. New best score: 0.848


Validation: 0it [00:00, ?it/s]

Metric val_epoch_loss improved by 0.014 >= min_delta = 0.0. New best score: 0.833


Validation: 0it [00:00, ?it/s]

Epoch 4, global step 4110: 'val_epoch_loss' reached 0.85116 (best 0.85116), saving model to '/mnt/code/logs/tmp_2D_Drugs_ea_GIN_multitask_v0/best-model-epoch=4-val_epoch_loss=0.8512.ckpt' as top 1
