In [1]:
!nvidia-smi

Sat May 11 07:29:50 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08             Driver Version: 535.161.08   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A10G                    On  | 00000000:00:1D.0 Off |                    0 |
|  0%   22C    P8              19W / 300W |      0MiB / 23028MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
import torch
torch.__version__

'2.3.0+cu121'

In [3]:
!pip install torch-geometric -f https://data.pyg.org/whl/torch-2.3.0+cu121.html

Defaulting to user installation because normal site-packages is not writeable
Looking in links: https://data.pyg.org/whl/torch-2.3.0+cu121.html


In [4]:
!pip install torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-2.3.0+cu121.html

Defaulting to user installation because normal site-packages is not writeable
Looking in links: https://data.pyg.org/whl/torch-2.3.0+cu121.html


In [5]:
!pip install ase

Defaulting to user installation because normal site-packages is not writeable


In [6]:
#!pip install --upgrade pytorch-lightning
#!pip install transformers
#!pip install lightning-bolts
#!pip install lightning-utilities

In [7]:
import os
import torch
import torch.nn.functional as F

from dataclasses import asdict
from torch.nn.utils import clip_grad_norm_
# from pl_bolts.optimizers import LinearWarmupCosineAnnealingLR
from torch_geometric.loader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR

from config import Config
from data.ee import EE
from data.bde import BDE
from data.drugs import Drugs
from data.kraken import Kraken
from happy_config import ConfigLoader
from loaders.samplers import EnsembleSampler, EnsembleMultiBatchSampler
from loaders.multibatch import MultiBatchLoader
from utils.early_stopping import EarlyStopping, generate_checkpoint_filename

from models.model_3d import Model3D
from models.models_3d.chiro import ChIRo
from models.models_3d.painn import PaiNN
from models.models_3d.schnet import SchNet
from models.models_3d.gemnet import GemNetT
from models.models_3d.dimenet import DimeNet, DimeNetPlusPlus
from models.models_3d.clofnet import ClofNet
from models.models_3d.leftnet import LEFTNet
#from models.models_3d.chytorch_discrete import ChytorchDiscrete
#from models.models_3d.chytorch_conformer import ChytorchConformer

import pickle
from time import time

In [8]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
import torch.nn as nn

from torch.optim import AdamW
from torch.optim.lr_scheduler import CyclicLR, CosineAnnealingLR

In [9]:
###################################################
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = gpus
torch.cuda.empty_cache()
#print(os.environ['CUDA_VISIBLE_DEVICES'])

seed = 42
torch.manual_seed(seed)

# If you are using a GPU, you should also set the seed for CUDA operations
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

#np.random.seed(seed)
################################################## 

In [10]:
def r2_score_torch(y_true, y_pred):
    """
    Compute the R-squared score.

    Parameters:
        y_true (torch.Tensor): The true target values.
        y_pred (torch.Tensor): The predicted target values.

    Returns:
        float: The R-squared score.
    """
    # Calculate the mean of the true target values
    y_mean = torch.mean(y_true)
    # Calculate the total sum of squares (TSS)
    tss = torch.sum((y_true - y_mean) ** 2)
    # Calculate the residual sum of squares (RSS)
    rss = torch.sum((y_true - y_pred) ** 2)
    # Calculate R-squared score
    r2 = 1 - rss / tss

    return r2.item()

In [11]:
class DataModule(LightningDataModule):
    def __init__(self, hparams, dataset=None):
        super().__init__()
        self.hparams.update(hparams.__dict__) if hasattr(hparams, "__dict__") else self.hparams.update(hparams)
        self._saved_dataloaders = dict()
        self.dataset = dataset
        
        if self.dataset is None:
            self.variable_name = None
            unique_variables = 1

            if self.hparams.dataset == 'Drugs':
                dataset = Drugs('/mnt/data/MARCEL/datasets/Drugs', max_num_conformers=self.hparams.max_num_conformers).shuffle()
            elif self.hparams.dataset == 'Kraken':
                dataset = Kraken('/mnt/data/MARCEL/datasets/Kraken', max_num_conformers=self.hparams.max_num_conformers).shuffle()
            elif self.hparams.dataset == 'BDE':
                dataset = BDE('/mnt/data/MARCEL/datasets/BDE').shuffle()
                self.variable_name = 'is_ligand'
                unique_variables = 2
            elif self.hparams.dataset == 'EE':
                dataset = EE('/mnt/data/MARCEL/datasets/EE', max_num_conformers=self.hparams.max_num_conformers).shuffle()
                self.variable_name = 'config_id'
                unique_variables = 2
            
            #autoscaling
            target_id = dataset.descriptors.index(self.hparams.target)
            labels = dataset.y[:, target_id]
            mean = dataset.y.mean(dim=0, keepdim=True)
            std = dataset.y.std(dim=0, keepdim=True)
            dataset.y = ((dataset.y - mean) / std).to('cuda')
            mean = mean.to('cuda')
            std = std.to('cuda')
            #data.dataset.data.y = data.dataset.y
            
            self.dataset = dataset
            self.max_atomic_num = self.dataset.data.x[:, 0].max().item() + 1
            self.unique_variables = unique_variables
            print('--done---')

    def split_compute(self):

        split = self.dataset.get_idx_split(train_ratio=self.hparams.train_ratio, 
                                      valid_ratio=self.hparams.valid_ratio, 
                                      seed=self.hparams.seed)
        self.train_dataset = self.dataset[split['train']]
        self.valid_dataset = self.dataset[split['valid']]
        self.test_dataset = self.dataset[split['test']]

        print(f'{len(self.train_dataset)} training data, {len(self.test_dataset)} test data and {len(self.valid_dataset)} validation data')

    def train_dataloader(self):
        return self._get_dataloader(self.train_dataset, "train")

    def val_dataloader(self):
        return self._get_dataloader(self.valid_dataset, "val")

    def test_dataloader(self):
        return self._get_dataloader(self.test_dataset, "test")
    
    def _get_dataloader(self, dataset, stage, store_dataloader=True):
        store_dataloader = store_dataloader
        
        if stage in self._saved_dataloaders and store_dataloader:
            return self._saved_dataloaders[stage]

        if self.hparams.model3d_augmentation:
            strategy = 'random'
        else:
            strategy = 'first'
            
        if stage == "train":
            shuffle=True                              
        else:
            shuffle=False
            if stage == "train"=='test':
                strategy = 'first'

        if self.variable_name is None:
            dl = DataLoader(dataset, batch_sampler=EnsembleSampler(dataset, 
                                                                   batch_size=self.hparams.batch_size, 
                                                                   strategy=strategy, 
                                                                   shuffle=shuffle),
                           num_workers=20)
        else:
            dl = MultiBatchLoader(dataset, batch_sampler=EnsembleMultiBatchSampler(dataset, 
                                                                                   batch_size=self.hparams.batch_size, 
                                                                                   strategy=strategy, 
                                                                                   shuffle=shuffle, 
                                                                                   variable_name=self.variable_name),
                                 )
        if store_dataloader:
            self._saved_dataloaders[stage] = dl
        return dl

In [12]:
import math
def compute_pnorm(model: nn.Module) -> float:
    """
    Computes the norm of the parameters of a model.
    :param model: A PyTorch model.
    :return: The norm of the parameters of the model.
    """
    return math.sqrt(sum([p.norm().item() ** 2 for p in model.parameters() if p.requires_grad]))


def compute_gnorm(model: nn.Module) -> float:
    """
    Computes the norm of the gradients of a model.
    :param model: A PyTorch model.
    :return: The norm of the gradients of the model.
    """
    return math.sqrt(sum([p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None]))

In [13]:
class ModelLM(LightningModule):
    def __init__(self, max_atomic_num=None, whole_dataset = None, unique_variables=1, **kwargs):
        super().__init__()
        #self.kwargs.update(kwargs.__dict__) if hasattr(kwargs, "__dict__") else self.kwargs.update(kwargs)
        print(kwargs.get('model3d'))
        if kwargs.get('model3d').model == 'SchNet':
            model_factory = lambda: SchNet(max_atomic_num=max_atomic_num, 
                                           **asdict(kwargs.get('model3d').schnet))
        elif kwargs.get('model3d').model == 'DimeNet':
            model_factory = lambda: DimeNet(max_atomic_num=max_atomic_num, 
                                            **asdict(kwargs.get('model3d').dimenet))
        elif kwargs.get('model3d').model == 'DimeNet++':
            model_factory = lambda: DimeNetPlusPlus(max_atomic_num=max_atomic_num, 
                                                    **asdict(kwargs.get('model3d').dimenetplusplus))
        elif kwargs.get('model3d').model == 'GemNet':
            model_factory = lambda: GemNetT(max_atomic_num=max_atomic_num, 
                                            **asdict(kwargs.get('model3d').gemnet))
        elif kwargs.get('model3d').model == 'ChIRo':
            model_factory = lambda: ChIRo(**asdict(kwargs.get('model3d').chiro))
            
        elif kwargs.get('model3d').model == 'PaiNN':
            model_factory = lambda: PaiNN(max_atomic_num=max_atomic_num, 
                                          **asdict(kwargs.get('model3d').painn))
        elif kwargs.get('model3d').model == 'ClofNet':
            model_factory = lambda: ClofNet(max_atomic_num=max_atomic_num, 
                                            **asdict(kwargs.get('model3d').clofnet))
        elif kwargs.get('model3d').model == 'LEFTNet':
            model_factory = lambda: LEFTNet(max_atomic_num=max_atomic_num, 
                                            **asdict(kwargs.get('model3d').leftnet))
        elif kwargs.get('model3d').model == 'ChytorchDiscrete':
            model_factory = lambda: ChytorchDiscrete(max_neighbors=max_atomic_num, 
                                                     **asdict(kwargs.get('model3d').chytorch_discrete))
        elif kwargs.get('model3d').model == 'ChytorchConformer':
            model_factory = lambda: ChytorchConformer(**asdict(kwargs.get('model3d').chytorch_conformer))
            
        elif kwargs.get('model3d').model == 'ChytorchRotary':
            model_factory = lambda: ChytorchRotary(max_neighbors=max_atomic_num, 
                                                   **asdict(kwargs.get('model3d').chytorch_rotary))
        self.device_= torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = Model3D(model_factory, 
                           hidden_dim=kwargs.get('hidden_dim'), 
                           out_dim=1,
                           unique_variables=unique_variables, 
                           device='cuda').to('cuda')
        
        self.loss_fn = nn.MSELoss() #LOGITS #GroupedScaledMAELoss(torch.ones(4, dtype=torch.long))
        
        self.lr = kwargs.get('learning_rate')
        self.wd = kwargs.get('learning_rate')
        self.whole_dataset = whole_dataset

        self._reset_losses_dict()
        self._reset_inference_results()
        self.save_hyperparameters(ignore=["cosine_annealing_lr","linear_warmup_cosine_annealing_lr",
                                          "model1d","model2d","model3d","model4d","modelfprf","one_cycle_lr",
                                          "reduce_lr_on_plateau","whole_dataset","device","scheduler"])

    def forward(self, batch):
        out = self.net(batch)
        return out
        
    def configure_optimizers(self):
        o = AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)
        s = CyclicLR(o, self.lr, 2e-4, 1000, mode='triangular', cycle_momentum=False)
        # instantiate the WeakMethod in the lr scheduler object into the custom scale function attribute
        #s._scale_fn_custom = s._scale_fn_ref()
        # remove the reference so there are no more WeakMethod references in the object
        #s._scale_fn_ref = None
        return [o], [{'scheduler': s, 'interval': 'step', 'name': 'lr_scheduler'}]

    def training_step(self, batch, batch_idx):
        pnorm = compute_pnorm(self.net)
        gnorm = compute_gnorm(self.net)
        self.log(f'(training) pnorm', pnorm)
        self.log(f'(training) gnorm', gnorm)
        return self.step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, "test")
        
    def step(self, batch, stage):
        start = time()
        if type(batch) is not list:
            batch = [batch]
        molecule_idx = batch[0].molecule_idx.to('cuda')
        dataset = self.whole_dataset.y.to('cuda')
        targets = dataset[molecule_idx].squeeze()

        with torch.set_grad_enabled(stage == "train"):
            pred = self(batch)
            loss = self.loss_fn(pred.squeeze(), targets)
            
            if stage == "test":
                self.inference_results['y_pred'].append(pred.squeeze())
                self.inference_results['y_true'].append(targets.squeeze())
                return None

            r2=r2_score_torch(targets.cpu(),pred.squeeze().cpu().detach())

            self.logging_info[f'{stage}_loss'].append(loss.item())
            self.logging_info[f'{stage}_r2'].append(r2)
            self.logging_info[f'{stage}_time'].append(time()-start)
            
            if stage == 'train':
                self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
                self.log(f'{stage}_step_loss', loss.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
            #model.save_checkpoint('checkpoint.pth')
            return loss
        
    def on_validation_epoch_end(self):
        if not self.trainer.sanity_checking:                    
            result_dict = {
                "epoch": float(self.current_epoch),
                "train_epoch_loss": torch.tensor(self.logging_info["train_loss"]).mean().item(),
                "train_epoch_r2": torch.tensor(self.logging_info["train_r2"]).mean().item(),
                "val_epoch_loss": torch.tensor(self.logging_info["val_loss"]).mean().item(),
                "val_epoch_r2": torch.tensor(self.logging_info["val_r2"]).mean().item(),
                "train_epoch_time": sum(self.logging_info["train_time"]),
                "val_epoch_time": sum(self.logging_info["val_time"]),
                }
            self.log_dict(result_dict, logger=True, sync_dist=True)
            
        self._reset_losses_dict()
    
    def on_test_epoch_end(self) -> None:
        for key in self.inference_results.keys():
            self.inference_results[key] = torch.cat(self.inference_results[key], dim=0)
    
    def _reset_losses_dict(self):
        self.logging_info = {
            "train_loss": [],
            "train_r2": [], 
            "train_mse": [], 
            "val_loss": [],
            "val_r2": [],
            "train_sample_size": [], 
            "val_sample_size": [],
            "train_time": [],
            "val_time": [],
        }
        
    def _reset_inference_results(self):
        self.inference_results = {'y_pred': [],
                                  'y_true': []}

In [14]:
dataname = 'BDE'
target = 'BindingEnergy'
modeltype = 'PaiNN'

writer = SummaryWriter(log_dir='/mnt/code/logs/')
#loader = ConfigLoader(model=Config, config='params/params_1d.json')
#config = loader()

config = Config
config.dataset = dataname
config.target = target
config.device = 'cuda:0'


#config.hidden_dim =128
#config.dropout = 0.5

#config.train_ratio = 0.7
#config.valid_ratio = 0.1
#config.batch_size = 256
#config.patience = 200

#config.learning_rate = 0.00001
#config.weight_decay = 1e-4
#config.scheduler = None
#config.reduce_lr_on_plateau = ReduceLROnPlateau()

#config.num_epochs = 2000
#config.cosine_annealing_lr = CosineAnnealingLR()
#config.linear_warmup_cosine_annealing_lr = LinearWarmupCosineAnnealingLR()


#not used in 1d

#config.one_cycle_lr = OneCycleLR()
#config.seed = 123
#activation = 'relu'

######3DMODEL
config.model3d.model=modeltype
config.model3d.augmentation = True

#config.model3d.schnet = SchNet()
#config.model3d.dimenet = DimeNet()
#config.model3d.dimenetplusplus = DimeNetPlusPlus()
#config.model3d.gemnet = GemNet()
#config.model3d.painn = PaiNN()
#config.model3d.clofnet = ClofNet()
#config.model3d.leftnet= LEFTNet()
#config.model3d.chytorch_discrete = ChytorchDiscrete()
#config.model3d.chytorch_conformer = ChytorchConformer()
#config.model3d.chytorch_rotary = ChytorchRotary()

In [15]:
def config_to_dict(config_class):
    config_dict = {}
    for attr_name in dir(config_class):
        if not attr_name.startswith("__") and not callable(getattr(config_class, attr_name)):
            #if attr_name=='model1d':
             #   attr_val = getattr(config_class, attr_name)
             #   for subattr, subvalue in zip(['model', 'input_type', 'embedding_dim', 'num_layers', 'num_heads'],
             #                                [attr_val.model, attr_val.input_type, attr_val.embedding_dim, 
             #                                 attr_val.num_layers, attr_val.num_heads]):
             #       config_dict[f'{attr_name}_{subattr}'] = subvalue
            #elif attr_name in ['model2d','model3d','model4d','modelfprf',
            #                   'linear_warmup_cosine_annealing_lr','cosine_annealing_lr',
            #                   'reduce_lr_on_plateau','one_cycle_lr']:
            #    pass
            #else:
            config_dict[attr_name] = getattr(config_class, attr_name)
    return config_dict
config_dict = config_to_dict(config)

In [16]:
subkeys = ["dataset",
           "max_num_conformers",
           "target",
           "train_ratio",
           "valid_ratio",
           "seed",
           "model3d", #.augmentation"
           "batch_size"]
config_dict_datamodule = {}
for k,v in config_dict.items():
    if k in subkeys:
        if k=="model3d":
            config_dict_datamodule[f'{k}_augmentation']=config_dict[k].augmentation
        else:
            config_dict_datamodule[k]=v

In [17]:
config_dict_datamodule

{'batch_size': 256,
 'dataset': 'BDE',
 'max_num_conformers': 20,
 'model3d_augmentation': True,
 'seed': 123,
 'target': 'BindingEnergy',
 'train_ratio': 0.7,
 'valid_ratio': 0.1}

In [18]:
data = DataModule(config_dict_datamodule)
data.prepare_data()
data.split_compute()



--done---
79637 training data, 22623 test data and 11838 validation data


In [20]:
model = ModelLM(max_atomic_num=data.max_atomic_num, 
                whole_dataset = data.dataset, 
                unique_variables=data.unique_variables, **config_dict)

Model3D(model='PaiNN', augmentation=True, schnet=SchNet(hidden_dim=128, num_filters=5, num_interactions=6, num_gaussians=50, cutoff=10, readout='mean', dipole=False), dimenet=DimeNet(hidden_channels=128, out_channels=128, num_blocks=6, num_bilinear=8, num_spherical=7, num_radial=6), dimenetplusplus=DimeNetPlusPlus(hidden_channels=128, out_channels=128, num_blocks=4, int_emb_size=64, basis_emb_size=8, out_emb_channels=256, num_spherical=7, num_radial=6), gemnet=GemNet(num_spherical=7, num_radial=6, num_blocks=4, emb_size_atom=128, emb_size_edge=128, emb_size_trip=64, emb_size_rbf=16, emb_size_cbf=16, emb_size_bil_trip=64, num_before_skip=1, num_after_skip=1, num_concat=1, num_atoms=1, num_atom=2, bond_feat_dim=0), painn=PaiNN(hidden_dim=128, num_interactions=6, num_rbf=64, cutoff=12.0, readout='add', shared_interactions=False, shared_filters=False), clofnet=ClofNet(cutoff=6.5, num_layers=6, hidden_channels=128, num_radial=32), leftnet=LEFTNet(cutoff=6.5, num_layers=6, hidden_channels=12

In [21]:
#model._log_hyperparams = False

In [22]:
print(f'#PARAMS = {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

#PARAMS = 2494977


In [23]:
model

ModelLM(
  (net): Model3D(
    (models): ModuleList(
      (0-1): 2 x PaiNN(
        (cutoff_fn): CosineCutoff()
        (radial_basis): GaussianRBF()
        (embedding): Embedding(79, 128, padding_idx=0)
        (filter_net): Dense(
          in_features=64, out_features=2304, bias=True
          (activation): Identity()
        )
        (interactions): ModuleList(
          (0-5): 6 x PaiNNInteraction(
            (interatomic_context_net): Sequential(
              (0): Dense(in_features=128, out_features=128, bias=True)
              (1): Dense(
                in_features=128, out_features=384, bias=True
                (activation): Identity()
              )
            )
          )
        )
        (mixing): ModuleList(
          (0-5): 6 x PaiNNMixing(
            (intraatomic_context_net): Sequential(
              (0): Dense(in_features=256, out_features=128, bias=True)
              (1): Dense(
                in_features=128, out_features=384, bias=True
               

In [25]:
dir_name = f"tmp_3d_v4"

dir_load_model = None
log_dir_folder = '/mnt/code/logs/'
log_dir_folder = os.path.join(log_dir_folder, dir_name)
if os.path.exists(log_dir_folder):
    if os.path.exists(os.path.join(log_dir_folder, "last.ckpt")):
        dir_load_model = os.path.join(log_dir_folder, "last.ckpt")
    csv_path = os.path.join(log_dir_folder, "metrics.csv")
    while os.path.exists(csv_path):
        csv_path = csv_path + '.bak'
    if os.path.exists(os.path.join(log_dir_folder, "metrics.csv")):
        os.rename(os.path.join(log_dir_folder, "metrics.csv"), csv_path)

In [27]:
metric_to_monitor = "val_epoch_loss"

checkpoint_callback = ModelCheckpoint(
    dirpath=log_dir_folder,
    monitor=metric_to_monitor,
    mode = 'min',
    save_top_k=1,
    save_last=True,
    every_n_epochs=5,
    save_weights_only=True,
    verbose=True,
    filename="best-model-{epoch}-{val_epoch_loss:.4f}",
)


early_stopping = early_stop_callback = EarlyStopping(
        monitor=metric_to_monitor,  # The metric you want to monitor
        patience=config.patience,  # Number of epochs with no improvement after which training will be stopped
        verbose=True,
        mode='min'  # Minimizing the validation loss
    )

tb_logger = TensorBoardLogger(log_dir_folder, name="tensorbord")#, version="", default_hp_metric=False)
csv_logger = CSVLogger(log_dir_folder, name="", version="")

model_params = dict(
    devices=1, #args['ngpus'],
    accelerator='gpu', #args['accelerator'],
    default_root_dir=log_dir_folder, #args['log_dir'],
    logger=[tb_logger, csv_logger],
    enable_progress_bar=True)


model_params.update(dict(
    max_epochs=config.num_epochs,#1000,
    callbacks=[checkpoint_callback, early_stopping],
    #enable_checkpointing=False,
    gradient_clip_val=10,#args['clip_norm'],
    #precision="16-mixed",
))

In [28]:
torch.set_float32_matmul_precision('medium')

In [None]:
trainer = pl.Trainer(**model_params)
trainer.fit(model, datamodule=data, ckpt_path=dir_load_model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /mnt/code/logs/tmp_3d_v4/tensorbord
/usr/local/lib/python3.10/dist-packages/lightning_fabric/loggers/csv_logs.py:268: Experiment logs directory /mnt/code/logs/tmp_3d_v4/ exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /mnt/code/logs/tmp_3d_v4 exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params
------------------------------------
0 | net     | Model3D | 2.5 M 
1 | loss_fn | MSELoss | 0     
------------------------------------
2.5 M     Trainable params
0         Non-trainable params
2.5 M     Total params
9.980     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=191` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=191` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_epoch_loss improved. New best score: 37.467


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_epoch_loss improved by 21.977 >= min_delta = 0.0. New best score: 15.490


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_epoch_loss improved by 13.343 >= min_delta = 0.0. New best score: 2.147


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_epoch_loss improved by 0.839 >= min_delta = 0.0. New best score: 1.308


In [77]:
dir(checkpoint_callback)

['CHECKPOINT_EQUALS_CHAR',
 'CHECKPOINT_JOIN_CHAR',
 'CHECKPOINT_NAME_LAST',
 'FILE_EXTENSION',
 'STARTING_VERSION',
 '_ModelCheckpoint__init_ckpt_dir',
 '_ModelCheckpoint__init_monitor_mode',
 '_ModelCheckpoint__init_triggers',
 '_ModelCheckpoint__resolve_ckpt_dir',
 '_ModelCheckpoint__validate_init_configuration',
 '_ModelCheckpoint__warn_if_dir_not_empty',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_enable_version_counter',
 '_every_n_epochs',
 '_every_n_train_steps',
 '_find_last_checkpoints',
 '_format_checkpoint_name',
 '_fs',
 '_generate_state_key',
 '_get_metric_interpolated_filepath_name',
 '_last_checkpoint_saved',
 '_last_global_step_saved',
 '_last_ti

In [78]:
monitor_candidates = checkpoint_callback._monitor_candidates(trainer)

In [79]:
#checkpoint_callback.on_train_epoch_end(trainer, trainer.lightning_module)

In [80]:
#checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates)

In [81]:
filepath = checkpoint_callback.format_checkpoint_name(monitor_candidates, checkpoint_callback.CHECKPOINT_NAME_LAST)

In [82]:
filepath

'/mnt/code/logs/tmp_3d_v3/last.ckpt'

In [83]:
checkpoint_callback.save_weights_only

True

In [85]:
#checkpoint_callback._save_checkpoint(trainer, filepath)

In [87]:
#trainer.save_checkpoint(filepath, checkpoint_callback.save_weights_only)

In [88]:
checkpoint = trainer._checkpoint_connector.dump_checkpoint(True)

In [89]:
checkpoint

{'epoch': 0,
 'global_step': 17,
 'pytorch-lightning_version': '2.2.4',
 'state_dict': OrderedDict([('net.models.0.cutoff_fn.cutoff', tensor([12.])),
              ('net.models.0.radial_basis.widths',
               tensor([0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905,
                       0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905,
                       0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905,
                       0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905,
                       0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905,
                       0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905,
                       0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905, 0.1905,
                       0.1905])),
              ('net.models.0.radial_basis.offsets',
               tensor([ 0.0000,  0.1905,  0.3810

In [90]:
import torch
import io

def identify_non_picklable_objects(checkpoint):
    non_picklable_objects = []
    for key, value in checkpoint.items():
        try:
            # Attempt to pickle the object
            torch.save(value, io.BytesIO())
        except Exception as e:
            # If an exception occurs, it means the object is not picklable
            non_picklable_objects.append((key, type(value)))
            print(f"Object at key '{key}' of type '{type(value)}' cannot be pickled.")
            print(f"Exception: {e}")
    return non_picklable_objects

# Identify non-picklable objects
non_picklable_objects = identify_non_picklable_objects(checkpoint)

# Print out non-picklable objects
print("\nNon-picklable objects:")
for key, obj_type in non_picklable_objects:
    print(f"- Key: '{key}', Type: '{obj_type}'")

Object at key 'datamodule_hyper_parameters' of type '<class 'dict'>' cannot be pickled.
Exception: cannot pickle 'getset_descriptor' object

Non-picklable objects:
- Key: 'datamodule_hyper_parameters', Type: '<class 'dict'>'


In [91]:
for key, obj_type in checkpoint.items():
    print(f"- Key: '{key}', Type: '{type(obj_type)}'")

- Key: 'epoch', Type: '<class 'int'>'
- Key: 'global_step', Type: '<class 'int'>'
- Key: 'pytorch-lightning_version', Type: '<class 'str'>'
- Key: 'state_dict', Type: '<class 'collections.OrderedDict'>'
- Key: 'loops', Type: '<class 'dict'>'
- Key: 'hparams_name', Type: '<class 'str'>'
- Key: 'hyper_parameters', Type: '<class 'dict'>'
- Key: 'datamodule_hyper_parameters', Type: '<class 'dict'>'


In [92]:
checkpoint['datamodule_hyper_parameters']

{'__module__': 'config',
 '__annotations__': {'dataset': str,
  'target': str,
  'max_num_molecules': int,
  'max_num_conformers': int,
  'train_ratio': float,
  'valid_ratio': float,
  'batch_size': int,
  'hidden_dim': int,
  'num_epochs': int,
  'patience': int,
  'activation': str,
  'seed': int,
  'device': str,
  'dropout': float,
  'scheduler': str,
  'reduce_lr_on_plateau': config.ReduceLROnPlateau,
  'cosine_annealing_lr': config.CosineAnnealingLR,
  'linear_warmup_cosine_annealing_lr': config.LinearWarmupCosineAnnealingLR,
  'one_cycle_lr': config.OneCycleLR,
  'learning_rate': float,
  'weight_decay': float,
  'modelfprf': configs.model_fp_rf.ModelFPRF,
  'model1d': configs.model_1d.Model1D,
  'model2d': configs.model_2d.Model2D,
  'model3d': configs.model_3d.Model3D,
  'model4d': configs.model_4d.Model4D},
 'dataset': 'BDE',
 'target': 'BindingEnergy',
 'max_num_molecules': None,
 'max_num_conformers': 20,
 'train_ratio': 0.7,
 'valid_ratio': 0.1,
 'batch_size': 256,
 'hidd

In [None]:
for epoch in range(config.num_epochs):
    loss = train(train_loader)
    if scheduler is not None:
        scheduler.step(loss)
    valid_error = eval(valid_loader)

    early_stopping(valid_error, model)
    if early_stopping.counter == 0:
        test_error = eval(test_loader)
    if early_stopping.early_stop:
        print('Early stopping...')
        break

    writer.add_scalar(f'Loss_{config.model1d.model}/{config.model1d.input_type}'
                      f'/{config.dataset}/{config.target}/train', loss, epoch)
    writer.add_scalar(f'Loss_{config.model1d.model}/{config.model1d.input_type}'
                      f'/{config.dataset}/{config.target}/valid', valid_error, epoch)
    writer.add_scalar(f'Loss_{config.model1d.model}/{config.model1d.input_type}'
                      f'/{config.dataset}/{config.target}/test', test_error, epoch)
    print(f'Progress: {epoch}/{config.num_epochs}/{loss:.5f}/{valid_error:.5f}/{test_error:.5f}')
    res_dict[dataname][algo][d]['loss'].append(loss)
    res_dict[dataname][algo][d]['valid_error'].append(valid_error)
    res_dict[dataname][algo][d]['test_error'].append(test_error)

model.load_state_dict(torch.load(checkpoint_path))
test_error = eval(test_loader)
test_r2 = eval_r2(test_loader)
print(f'Best validation error: {-early_stopping.best_score:.7f}')
print(f'Test error: {test_error:.7f}')
print(f'Test r2: {test_r2:.7f}')

res_dict[dataname][algo][d]['Test error'] = test_error
res_dict[dataname][algo][d]['Test r2'] = test_r2
res_dict[dataname][algo][d]['checkpoint']=checkpoint_path


#os.remove(checkpoint_path)
writer.close()

import pickle
with open("res_dict_1d.pkl", "wb") as f:
pickle.dump(res_dict, f)