In [1]:
# Show the GPU(s), driver version and CUDA version visible on this machine.
!nvidia-smi

Tue May 14 08:58:29 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08             Driver Version: 535.161.08   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A10G                    On  | 00000000:00:1D.0 Off |                    0 |
|  0%   28C    P8              30W / 300W |      0MiB / 23028MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
!pip install torch-geometric -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
!pip install torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
!pip install --upgrade transformers

Defaulting to user installation because normal site-packages is not writeable
Looking in links: https://data.pyg.org/whl/torch-2.3.0+cu121.html
Defaulting to user installation because normal site-packages is not writeable
Looking in links: https://data.pyg.org/whl/torch-2.3.0+cu121.html
Defaulting to user installation because normal site-packages is not writeable


In [3]:
import os
import torch
import torch.nn.functional as F

from dataclasses import asdict
from transformers import RobertaTokenizer
from torch.utils.data import Dataset, DataLoader, random_split
#from pl_bolts.optimizers import LinearWarmupCosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR

from config import Config
from data.ee import EE
from data.bde import BDE
from data.drugs import Drugs
from data.kraken import Kraken
from happy_config import ConfigLoader
from utils.early_stopping import EarlyStopping, generate_checkpoint_filename

from models.model_1d import LSTM, Transformer
from models.models_1d.utils import construct_fingerprint, construct_smiles, concatenate_smiles
#from train_1d import *
from time import time

In [4]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
import torch.nn as nn

from torch.optim import AdamW
from torch.optim.lr_scheduler import CyclicLR, CosineAnnealingLR


In [5]:
# --- Reproducibility -------------------------------------------------------
# Seed every RNG the run may touch (Python, NumPy, torch CPU + all CUDA devices)
# and make cuDNN pick deterministic kernels.
import random

import numpy as np

torch.cuda.empty_cache()

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # also seeds the CUDA generators

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [6]:
class Molecules(Dataset):
    """Map-style dataset over pre-tokenised SMILES or per-molecule fingerprints.

    Parameters
    ----------
    smiles_ids : Tensor
        Token ids, indexed by sample along the first dimension.
    attention_masks : Tensor
        Attention masks aligned with ``smiles_ids``.
    labels : Tensor
        Regression targets; ``labels.shape[0]`` defines the dataset length.
    fingerprint : optional
        Per-sample fingerprints; only read when ``input_type`` is not SMILES.
    input_type : str
        ``'SMILES'`` (compared case-insensitively) makes ``__getitem__`` return
        ``(ids, mask, y)``; any other value returns ``(fingerprint, y)`` with the
        fingerprint converted to a ``torch.long`` tensor.
    """

    def __init__(self, smiles_ids, attention_masks, labels, fingerprint=None, input_type='smiles'):
        self.smiles_ids = smiles_ids
        self.attention_masks = attention_masks
        self.labels = labels
        self.fingerprint = fingerprint
        self.input_type = input_type

    def __len__(self):
        return self.labels.shape[0]

    def _is_smiles(self):
        # Case-insensitive so the lowercase default 'smiles' selects the SMILES
        # branch instead of falling through to a (possibly None) fingerprint.
        return str(self.input_type).upper() == 'SMILES'

    def __getitem__(self, index):
        # Labels are cloned so callers can modify y without touching the dataset.
        y = self.labels[index].clone()
        if self._is_smiles():
            return self.smiles_ids[index], self.attention_masks[index], y
        return torch.tensor(self.fingerprint[index], dtype=torch.long), y

def r2_score_torch(y_true, y_pred):
    """
    Compute the R-squared (coefficient of determination) score.

    When both tensors hold the same number of elements they are flattened
    first, so predictions shaped ``(N, 1)`` are compared element-wise with
    targets shaped ``(N,)`` instead of silently broadcasting to ``(N, N)``.
    Otherwise ordinary broadcasting applies (e.g. a scalar prediction).

    Parameters:
        y_true (torch.Tensor): The true target values.
        y_pred (torch.Tensor): The predicted target values.

    Returns:
        float: The R-squared score. Returns ``nan`` for empty input. When the
        targets are constant (zero total sum of squares, e.g. a batch of one),
        R^2 is undefined; following scikit-learn's ``force_finite`` convention,
        1.0 is returned for a perfect prediction and 0.0 otherwise.
    """
    if y_pred.numel() == y_true.numel():
        y_true = y_true.reshape(-1)
        y_pred = y_pred.reshape(-1)
    if y_true.numel() == 0:
        return float('nan')

    y_mean = torch.mean(y_true)
    # Total sum of squares (TSS) and residual sum of squares (RSS).
    tss = torch.sum((y_true - y_mean) ** 2)
    rss = torch.sum((y_true - y_pred) ** 2)
    if tss == 0:
        return 1.0 if rss == 0 else 0.0
    return (1 - rss / tss).item()

In [7]:
class DataModule(LightningDataModule):
    """LightningDataModule serving 1D molecular inputs (SMILES tokens or fingerprints).

    If ``dataset`` is None, the constructor builds it from the MARCEL benchmark
    named by ``hparams.dataset`` ('Drugs', 'Kraken', 'BDE' or 'EE'):

    * the ``hparams.target`` column is z-scored with the mean/std of the whole
      loaded dataset, exposed as ``label_mean`` / ``label_std``. NOTE: these
      statistics include the molecules that later land in validation/test;
    * SMILES strings are built (concatenated over ``is_ligand`` / ``config_id``
      for BDE / EE) and tokenised with the ``seyonec/PubChem10M_SMILES_BPE_450k``
      RoBERTa tokenizer;
    * ``vocab_size``, ``tokenizer`` and ``smiles_ids`` are stored for building the model.

    Call :meth:`split_compute` before requesting any dataloader.

    Args:
        hparams: Mapping or attribute object providing ``dataset``, ``target``,
            ``model1d_input_type`` ('SMILES' or 'Fingerprint'), ``max_num_conformers``,
            ``train_ratio``, ``valid_ratio``, ``batch_size`` and optionally
            ``num_workers`` (defaults to 20).
        dataset: Optional pre-built dataset; when given, all loading/tokenisation is skipped.

    Raises:
        ValueError: for an unknown ``hparams.dataset`` or ``hparams.model1d_input_type``.
    """

    # Root folder containing one sub-directory per MARCEL benchmark.
    DATA_ROOT = '/mnt/data/MARCEL/datasets'

    def __init__(self, hparams, dataset=None):
        super().__init__()
        # Accept either an attribute-style config object or a plain mapping.
        if hasattr(hparams, "__dict__"):
            self.hparams.update(hparams.__dict__)
        else:
            self.hparams.update(hparams)
        self._saved_dataloaders = dict()
        self.dataset = dataset

        if self.dataset is None:
            self._build_dataset()

    def _load_raw_dataset(self):
        """Instantiate and shuffle the raw benchmark named by ``hparams.dataset``.

        Returns ``(dataset, variable_name)``, where ``variable_name`` is the
        attribute passed to ``concatenate_smiles`` (None for Drugs / Kraken).
        """
        name = self.hparams.dataset
        root = os.path.join(self.DATA_ROOT, name)
        if name == 'Drugs':
            return Drugs(root, max_num_conformers=self.hparams.max_num_conformers).shuffle(), None
        if name == 'Kraken':
            return Kraken(root, max_num_conformers=self.hparams.max_num_conformers).shuffle(), None
        if name == 'BDE':
            return BDE(root).shuffle(), 'is_ligand'
        if name == 'EE':
            return EE(root, max_num_conformers=self.hparams.max_num_conformers).shuffle(), 'config_id'
        raise ValueError(f"Unknown dataset {name!r}; expected one of 'Drugs', 'Kraken', 'BDE', 'EE'")

    def _build_dataset(self):
        """Load, standardise, tokenise and wrap the benchmark into a ``Molecules`` dataset."""
        input_type = self.hparams.model1d_input_type
        # Validate before the (slow) dataset load; any other value would later fail
        # with an opaque AttributeError on ``fingerprint.shape``.
        if input_type not in ('SMILES', 'Fingerprint'):
            raise ValueError(f"Unknown model1d_input_type {input_type!r}; expected 'SMILES' or 'Fingerprint'")

        raw, variable_name = self._load_raw_dataset()

        # Autoscaling: z-score the selected target column.
        target_id = raw.descriptors.index(self.hparams.target)
        labels = raw.y[:, target_id]
        self.label_mean = labels.mean(dim=0).item()
        self.label_std = labels.std(dim=0).item()
        labels = (labels - self.label_mean) / self.label_std

        if variable_name is not None:
            smiles = concatenate_smiles(raw, variable_name)
        else:
            smiles = construct_smiles(raw)
        fingerprint = construct_fingerprint(smiles) if input_type == 'Fingerprint' else None

        tokenizer = RobertaTokenizer.from_pretrained('seyonec/PubChem10M_SMILES_BPE_450k')
        encoded = tokenizer(smiles, return_tensors='pt', padding='longest')
        smiles_ids, attention_masks = encoded['input_ids'], encoded['attention_mask']

        self.dataset = Molecules(smiles_ids, attention_masks, labels, fingerprint, input_type=input_type)
        self.vocab_size = tokenizer.vocab_size if input_type == 'SMILES' else fingerprint.shape[1]
        self.tokenizer = tokenizer
        self.smiles_ids = smiles_ids

        print('--done---')

    def split_compute(self):
        """Randomly split ``self.dataset`` into train / validation / test subsets.

        Train and validation sizes are ``int(ratio * len)``; the test split takes
        the remainder, so the three lengths always sum to the dataset size. The
        split uses torch's global RNG.
        """
        n = len(self.dataset)
        train_len = int(self.hparams.train_ratio * n)
        valid_len = int(self.hparams.valid_ratio * n)
        test_len = n - train_len - valid_len

        self.train_dataset, self.valid_dataset, self.test_dataset = random_split(
            self.dataset, lengths=[train_len, valid_len, test_len])
        # Loaders cached for a previous split would otherwise keep serving stale subsets.
        self._saved_dataloaders.clear()
        print(f'{len(self.train_dataset)} training data, {len(self.test_dataset)} test data and {len(self.valid_dataset)} validation data')

    def train_dataloader(self):
        return self._get_dataloader(self.train_dataset, "train")

    def val_dataloader(self):
        return self._get_dataloader(self.valid_dataset, "val")

    def test_dataloader(self):
        return self._get_dataloader(self.test_dataset, "test")

    def _get_dataloader(self, dataset, stage, store_dataloader=True):
        """Build, or return the cached, DataLoader for ``stage``.

        Only the ``"train"`` loader shuffles. The loader is cached per stage
        unless ``store_dataloader`` is False.
        """
        if store_dataloader and stage in self._saved_dataloaders:
            return self._saved_dataloaders[stage]

        dl = DataLoader(dataset=dataset, batch_size=self.hparams.batch_size,
                        shuffle=(stage == "train"),
                        num_workers=self.hparams.get("num_workers", 20))

        if store_dataloader:
            self._saved_dataloaders[stage] = dl
        return dl

In [8]:
import math
def compute_pnorm(model: nn.Module) -> float:
    """
    Global L2 norm of all trainable parameters, treated as one flat vector.

    :param model: A PyTorch model.
    :return: sqrt(sum of ||p||_2^2) over parameters with ``requires_grad``; 0.0 if there are none.
    """
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    squared_norms = [param.norm().item() ** 2 for param in trainable]
    return math.sqrt(sum(squared_norms))


def compute_gnorm(model: nn.Module) -> float:
    """
    Global L2 norm of the gradients currently stored on the model's parameters.

    Parameters whose ``.grad`` is None (e.g. before any backward pass) are skipped.

    :param model: A PyTorch model.
    :return: sqrt(sum of ||p.grad||_2^2) as a Python float; 0.0 if no gradients exist.
    """
    accumulated = 0
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue
        accumulated += grad.norm().item() ** 2
    return math.sqrt(accumulated)

In [9]:
class Model1D(LightningModule):
    """LightningModule that regresses a single scalar target from 1D molecular input.

    The backbone is chosen by ``hparams['model1d'].model``: ``'LSTM'`` or
    ``'Transformer'``. Inputs are either tokenised SMILES
    (``input_type == 'SMILES'``: batches of ``(input_ids, attention_mask, y)``)
    or fingerprints (batches of ``(fingerprints, y)``). Training uses MSE loss.
    Per-batch loss, R^2 and wall time are accumulated and logged as epoch
    averages at the end of each validation epoch. Test predictions and targets
    are collected in ``self.inference_results``.

    Args:
        vocab_size: Embedding vocabulary size passed to the backbone.
        tokenizer: Tokenizer whose ``pad_token_id`` is used as the padding index.
            Not stored in the saved hyper-parameters, so pass it again to
            ``load_from_checkpoint``.
        smiles_ids: Token-id tensor. Only ``smiles_ids.shape[1]`` (sequence length)
            is read, and only for the Transformer. Not stored in the saved
            hyper-parameters either.
        **hparams: Flat hyper-parameter mapping. Reads ``model1d`` (an object with
            ``model``, ``input_type``, ``num_layers`` and, for the Transformer,
            ``embedding_dim`` / ``num_heads``), ``hidden_dim``, ``dropout``,
            ``learning_rate`` and ``weight_decay``.

    Raises:
        ValueError: if ``model1d.model`` is neither ``'LSTM'`` nor ``'Transformer'``.
    """

    def __init__(self, vocab_size=None, tokenizer=None, smiles_ids=None, **hparams):
        super().__init__()
        model1d = hparams.get('model1d')

        if model1d.model == 'LSTM':
            self.net = LSTM(
                vocab_size, hparams.get('hidden_dim'), hparams.get('hidden_dim'), 1,
                model1d.num_layers, hparams.get('dropout'), padding_idx=tokenizer.pad_token_id)
        elif model1d.model == 'Transformer':
            self.net = Transformer(
                vocab_size, model1d.embedding_dim, smiles_ids.shape[1],
                model1d.num_heads, hparams.get('hidden_dim'), 1,
                model1d.num_layers, hparams.get('dropout'), padding_idx=tokenizer.pad_token_id)
        else:
            raise ValueError(f"Unsupported model1d.model {model1d.model!r}; expected 'LSTM' or 'Transformer'")

        self.loss_fn = nn.MSELoss()
        self.device_ = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_type = model1d.input_type
        self.lr = hparams.get('learning_rate')
        self.wd = hparams.get('weight_decay')

        self._reset_losses_dict()
        self._reset_inference_results()
        # Nested config objects are not plain hyper-parameters. tokenizer and
        # smiles_ids (the full token-id tensor) would otherwise be pickled into
        # every checkpoint.
        self.save_hyperparameters(ignore=["cosine_annealing_lr", "linear_warmup_cosine_annealing_lr",
                                          "model1d", "model2d", "model3d", "model4d", "modelfprf", "one_cycle_lr",
                                          "reduce_lr_on_plateau", "whole_dataset", "device", "scheduler",
                                          "tokenizer", "smiles_ids"])

    def forward(self, batch):
        """Run the backbone on a batch; returns ``(predictions, targets)``."""
        if self.input_type == 'SMILES':
            input_ids, attention_mask, y = batch
            if isinstance(self.net, Transformer):
                out = self.net(input_ids, attention_mask)
            else:
                out = self.net(input_ids)
        else:
            fingerprints, y = batch
            out = self.net(fingerprints)
        return out, y

    def configure_optimizers(self):
        """AdamW with a per-step triangular CyclicLR schedule.

        CyclicLR positional args are base_lr=self.lr, max_lr=2e-4 and
        step_size_up=1000. NOTE(review): with learning_rate=1e-3 (as in the
        config dump) base_lr exceeds max_lr, so the "max" is actually the low
        end of the cycle. TODO confirm this is intended.
        """
        o = AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)
        s = CyclicLR(o, self.lr, 2e-4, 1000, mode='triangular', cycle_momentum=False)
        return [o], [{'scheduler': s, 'interval': 'step', 'name': 'lr_scheduler'}]

    def training_step(self, batch, batch_idx):
        # Parameter norm before this step's optimizer update.
        self.log(f'(training) pnorm', compute_pnorm(self.net))
        return self.step(batch, "train")

    def on_before_optimizer_step(self, optimizer, *args):
        # Lightning calls this after backward and before gradient clipping /
        # optimizer.step, so these are the current batch's raw gradients. Inside
        # training_step they would still be the previous batch's clipped ones.
        # *args absorbs the extra optimizer_idx passed by Lightning 1.x.
        self.log(f'(training) gnorm', compute_gnorm(self.net))

    def validation_step(self, batch, batch_idx):
        return self.step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, "test")

    def step(self, batch, stage):
        """Shared train/val/test step: compute loss, record metrics, return the loss (None for test)."""
        start = time()

        with torch.set_grad_enabled(stage == "train"):
            pred, targets = self(batch)
            # Flatten to 1-D: squeeze() turns a batch of one into a 0-d tensor, which
            # broadcasts in the loss and cannot be concatenated at test time.
            pred = pred.reshape(-1)
            targets = targets.reshape(-1)
            loss = self.loss_fn(pred, targets)

            if stage == "test":
                self.inference_results['y_pred'].append(pred.detach())
                self.inference_results['y_true'].append(targets.detach())
                return None

            r2 = r2_score_torch(targets.detach().cpu(), pred.detach().cpu())

            self.logging_info[f'{stage}_loss'].append(loss.item())
            self.logging_info[f'{stage}_r2'].append(r2)
            self.logging_info[f'{stage}_time'].append(time() - start)

            if stage == 'train':
                self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
                self.log(f'{stage}_step_loss', loss.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)

            return loss

    def on_validation_epoch_end(self):
        # Logs epoch averages of the per-batch train/val metrics gathered since the
        # last reset (skipped during the sanity check), then clears the buffers.
        if not self.trainer.sanity_checking:
            result_dict = {
                "epoch": float(self.current_epoch),
                "train_epoch_loss": torch.tensor(self.logging_info["train_loss"]).mean().item(),
                "train_epoch_r2": torch.tensor(self.logging_info["train_r2"]).mean().item(),
                "val_epoch_loss": torch.tensor(self.logging_info["val_loss"]).mean().item(),
                "val_epoch_r2": torch.tensor(self.logging_info["val_r2"]).mean().item(),
                "train_epoch_time": sum(self.logging_info["train_time"]),
                "val_epoch_time": sum(self.logging_info["val_time"]),
                }
            self.log_dict(result_dict, logger=True, sync_dist=True)

        self._reset_losses_dict()

    def on_test_epoch_start(self) -> None:
        # Start each test run with empty buffers, so a second trainer.test() call does
        # not try to concatenate tensors that were already concatenated.
        self._reset_inference_results()

    def on_test_epoch_end(self) -> None:
        # Turn the per-batch lists into single 1-D tensors (empty if no test batches ran).
        for key in list(self.inference_results.keys()):
            chunks = self.inference_results[key]
            self.inference_results[key] = torch.cat(chunks, dim=0) if chunks else torch.empty(0)

    def _reset_losses_dict(self):
        self.logging_info = {
            "train_loss": [],
            "train_r2": [],
            "train_mse": [],
            "val_loss": [],
            "val_r2": [],
            "train_sample_size": [],
            "val_sample_size": [],
            "train_time": [],
            "val_time": [],
        }

    def _reset_inference_results(self):
        self.inference_results = {'y_pred': [],
                                  'y_true': []}

In [10]:
# --- Run selection ---------------------------------------------------------
dataname = 'BDE'
target = 'BindingEnergy'
input_type = 'SMILES'   # 'SMILES' or 'Fingerprint'
modeltype = 'LSTM'      # 'LSTM' or 'Transformer'

# TensorBoard writer; only the legacy manual-training loop at the end writes to it.
writer = SummaryWriter(log_dir='/mnt/code/logs/')

# Alternative to the in-code defaults (happy_config):
#   loader = ConfigLoader(model=Config, config='params/params_1d.json'); config = loader()

# NOTE(review): `config` aliases the Config class itself (no instance is created),
# so these assignments rewrite class-level state for the whole kernel session.
config = Config
config.dataset, config.target, config.device = dataname, target, 'cuda:0'
config.model1d.input_type, config.model1d.model = input_type, modeltype

# --- Previously tried overrides (kept for reference) ------------------------
# config.max_num_conformers = 20
# config.model1d.num_layers = 4
# Transformer only:
#   config.model1d.embedding_dim = 128
#   config.model1d.num_heads = 4
# config.hidden_dim = 128
# config.dropout = 0.5
# config.train_ratio = 0.7
# config.valid_ratio = 0.1
# config.batch_size = 256
# config.patience = 200
# config.learning_rate = 0.00001
# config.weight_decay = 1e-4
# config.scheduler = None
# config.num_epochs = 2000
# NOTE(review): in this namespace ReduceLROnPlateau / CosineAnnealingLR are torch's
# scheduler classes, which need an optimizer; these presumably meant Config's own types.
#   config.reduce_lr_on_plateau = ReduceLROnPlateau()
#   config.cosine_annealing_lr = CosineAnnealingLR()
#   config.linear_warmup_cosine_annealing_lr = LinearWarmupCosineAnnealingLR()
# Not used by the 1D models:
#   config.one_cycle_lr = OneCycleLR()
#   config.seed = 123
#   activation = 'relu'

In [11]:
def config_to_dict(config_class):
    """Collect the non-callable attributes of a config class into a dict.

    Names starting with a double underscore are skipped. Every other name
    reported by ``dir()`` whose value is not callable is kept as-is, so nested
    config objects stay nested. Keys follow ``dir()``'s sorted order.
    """
    candidates = ((name, getattr(config_class, name))
                  for name in dir(config_class) if not name.startswith("__"))
    return {name: value for name, value in candidates if not callable(value)}
# Snapshot the Config attributes; model1d/model2d/... remain nested objects here.
config_dict = config_to_dict(config)

In [12]:
# Inspect the collected configuration.
config_dict

{'activation': 'relu',
 'batch_size': 256,
 'cosine_annealing_lr': CosineAnnealingLR(eta_min=1e-06),
 'dataset': 'BDE',
 'device': 'cuda:0',
 'dropout': 0.5,
 'hidden_dim': 128,
 'learning_rate': 0.001,
 'linear_warmup_cosine_annealing_lr': LinearWarmupCosineAnnealingLR(warmup_steps=200, max_epochs=2000),
 'max_num_conformers': 20,
 'max_num_molecules': None,
 'model1d': Model1D(model='LSTM', input_type='SMILES', embedding_dim=128, num_layers=4, num_heads=4),
 'model2d': Model2D(model='GIN', gin=GIN(num_layers=6, virtual_node=False), gps=GPS(num_layers=6, walk_length=20, num_heads=4), chemprop=ChemProp(num_layers=6)),
 'model3d': Model3D(model='PaiNN', augmentation=True, schnet=SchNet(hidden_dim=128, num_filters=5, num_interactions=6, num_gaussians=50, cutoff=10, readout='mean', dipole=False), dimenet=DimeNet(hidden_channels=128, out_channels=128, num_blocks=6, num_bilinear=8, num_spherical=7, num_radial=6), dimenetplusplus=DimeNetPlusPlus(hidden_channels=128, out_channels=128, num_blo

In [13]:
# Hyper-parameters the DataModule needs. From the nested model1d config only
# input_type is kept, flattened under the key "model1d_input_type".
subkeys = ["dataset",
           "max_num_conformers",
           "target",
           "train_ratio",
           "valid_ratio",
           "seed",
           "model1d",
           "batch_size"]

config_dict_datamodule = {}
for key, value in config_dict.items():
    if key not in subkeys:
        continue
    if key == "model1d":
        config_dict_datamodule[f'{key}_input_type'] = value.input_type
    else:
        config_dict_datamodule[key] = value

In [14]:
# Constructing the DataModule loads and tokenises the whole dataset.
data = DataModule(config_dict_datamodule)
data.prepare_data()  # not overridden by DataModule, so this calls the inherited hook
# Random train/valid/test split; must run before any dataloader is requested.
data.split_compute()



--done---
4140 training data, 1184 test data and 591 validation data


In [15]:
def config_to_dict(config_class):
    """Flatten a config class into a dict of plain hyper-parameters.

    * ``model1d`` is expanded into ``model1d_model``, ``model1d_input_type``,
      ``model1d_embedding_dim``, ``model1d_num_layers`` and ``model1d_num_heads``.
    * The other nested model configs and the scheduler configs are dropped.
    * Double-underscore names and callables are skipped; everything else is copied.

    NOTE: this redefines ``config_to_dict`` from an earlier cell. After this
    cell runs, later calls use this flattening version.
    """
    skipped = {'model2d', 'model3d', 'model4d', 'modelfprf',
               'linear_warmup_cosine_annealing_lr', 'cosine_annealing_lr',
               'reduce_lr_on_plateau', 'one_cycle_lr'}
    model1d_fields = ('model', 'input_type', 'embedding_dim', 'num_layers', 'num_heads')

    flat = {}
    for name in dir(config_class):
        if name.startswith("__"):
            continue
        value = getattr(config_class, name)
        if callable(value) or name in skipped:
            continue
        if name == 'model1d':
            flat.update({f'{name}_{field}': getattr(value, field) for field in model1d_fields})
        else:
            flat[name] = value
    return flat
#config_dict = config_to_dict(config)

In [16]:
# Still the dict from the first config_to_dict call; the flattening call above is commented out.
config_dict

{'activation': 'relu',
 'batch_size': 256,
 'cosine_annealing_lr': CosineAnnealingLR(eta_min=1e-06),
 'dataset': 'BDE',
 'device': 'cuda:0',
 'dropout': 0.5,
 'hidden_dim': 128,
 'learning_rate': 0.001,
 'linear_warmup_cosine_annealing_lr': LinearWarmupCosineAnnealingLR(warmup_steps=200, max_epochs=2000),
 'max_num_conformers': 20,
 'max_num_molecules': None,
 'model1d': Model1D(model='LSTM', input_type='SMILES', embedding_dim=128, num_layers=4, num_heads=4),
 'model2d': Model2D(model='GIN', gin=GIN(num_layers=6, virtual_node=False), gps=GPS(num_layers=6, walk_length=20, num_heads=4), chemprop=ChemProp(num_layers=6)),
 'model3d': Model3D(model='PaiNN', augmentation=True, schnet=SchNet(hidden_dim=128, num_filters=5, num_interactions=6, num_gaussians=50, cutoff=10, readout='mean', dipole=False), dimenet=DimeNet(hidden_channels=128, out_channels=128, num_blocks=6, num_bilinear=8, num_spherical=7, num_radial=6), dimenetplusplus=DimeNetPlusPlus(hidden_channels=128, out_channels=128, num_blo

In [17]:
# vocab_size / tokenizer / smiles_ids come from the DataModule; all other hyper-parameters from config_dict.
model = Model1D(vocab_size=data.vocab_size, tokenizer=data.tokenizer, smiles_ids=data.smiles_ids, **config_dict)

In [18]:
# Number of trainable parameters.
print(f'#PARAMS = {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

#PARAMS = 1542785


In [19]:
# Show the module tree.
model

Model1D(
  (net): LSTM(
    (embedding): Embedding(7924, 128, padding_idx=1)
    (lstm): LSTM(128, 128, num_layers=4, batch_first=True, dropout=0.5)
    (fc): Linear(in_features=128, out_features=1, bias=True)
  )
  (loss_fn): MSELoss()
)

In [20]:
# Per-run log directory. If it already exists, resume from its last.ckpt and move
# the previous metrics.csv aside (metrics.csv.bak, .bak.bak, ...) so that
# CSVLogger starts a fresh file.
dir_name = "tmp_2"
log_dir_folder = os.path.join('/mnt/code/logs/', dir_name)
dir_load_model = None

if os.path.exists(log_dir_folder):
    last_ckpt = os.path.join(log_dir_folder, "last.ckpt")
    if os.path.exists(last_ckpt):
        dir_load_model = last_ckpt

    current_csv = os.path.join(log_dir_folder, "metrics.csv")
    csv_path = current_csv
    while os.path.exists(csv_path):
        csv_path += '.bak'
    if os.path.exists(current_csv):
        os.rename(current_csv, csv_path)

In [21]:
# Checkpointing and early stopping both track the epoch-level validation loss
# that Model1D logs at the end of each validation epoch.
metric_to_monitor = "val_epoch_loss"

checkpoint_callback = ModelCheckpoint(
    dirpath=log_dir_folder,
    filename="best-model-{epoch}-{val_epoch_loss:.4f}",
    monitor=metric_to_monitor,
    mode='min',
    save_top_k=1,       # keep only the best checkpoint...
    save_last=True,     # ...plus last.ckpt, which the log-dir cell resumes from
    every_n_epochs=5,
)

# This is Lightning's EarlyStopping: its import comes after utils.early_stopping's, so it shadows it.
early_stop_callback = EarlyStopping(
    monitor=metric_to_monitor,
    patience=config.patience,   # epochs without improvement before stopping
    verbose=True,
    mode='min',
)
early_stopping = early_stop_callback

tb_logger = TensorBoardLogger(log_dir_folder, name="tensorbord")
csv_logger = CSVLogger(log_dir_folder, name="", version="")

model_params = dict(
    devices=1,
    accelerator='gpu',
    default_root_dir=log_dir_folder,
    logger=[tb_logger, csv_logger],
    enable_progress_bar=True,
    max_epochs=config.num_epochs,
    callbacks=[checkpoint_callback, early_stopping],
    gradient_clip_val=10,
)

In [22]:
# Trade float32 matmul precision for speed ('medium' lets PyTorch use reduced-precision internal math where supported).
torch.set_float32_matmul_precision('medium')

In [23]:
# Resumes from last.ckpt when the log-dir cell found one (dir_load_model); otherwise trains from scratch.
trainer = pl.Trainer(**model_params)#, profiler=profiler)
trainer.fit(model, datamodule=data, ckpt_path=dir_load_model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /mnt/code/logs/tmp_2/tensorbord
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params
------------------------------------
0 | net     | LSTM    | 1.5 M 
1 | loss_fn | MSELoss | 0     
------------------------------------
1.5 M     Trainable params
0         Non-trainable params
1.5 M     Total params
6.171     Total estimated model params size (MB)
  rank_zero_warn(


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_epoch_loss improved. New best score: 7.141


Validation: 0it [00:00, ?it/s]

Metric val_epoch_loss improved by 6.140 >= min_delta = 0.0. New best score: 1.001
  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


KeyboardInterrupt: 

In [53]:
from pytorch_lightning.utilities.parsing import is_picklable

In [40]:
# Debug: mean train loss accumulated since the last validation-epoch reset.
torch.tensor(model.logging_info["train_loss"]).mean().item()

20.360118865966797

In [51]:
# Debug: same check, re-run later in the session.
torch.tensor(model.logging_info["train_loss"]).mean().item()

4.66825008392334

In [None]:
# Legacy manual-training path (not used by the Lightning run above).
# The module-level `EarlyStopping` is Lightning's callback (its import comes after
# utils.early_stopping's and shadows it), and that class has no `path` argument.
# Import the project class explicitly under its own name.
from utils.early_stopping import EarlyStopping as ManualEarlyStopping

checkpoint_path = generate_checkpoint_filename()
early_stopping = ManualEarlyStopping(patience=config.patience, path=checkpoint_path)
print(f'Checkpoint path: {checkpoint_path}')

In [None]:
# Legacy manual-training path: plain Adam plus an optional scheduler selected by config.scheduler.
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

if config.scheduler == 'ReduceLROnPlateau':
    scheduler = ReduceLROnPlateau(
        optimizer, verbose=True, **asdict(config.reduce_lr_on_plateau))
elif config.scheduler == 'CosineAnnealingLR':
    # T_max = num_epochs: one cosine decay over the whole run when stepped once per epoch.
    scheduler = CosineAnnealingLR(
        optimizer, T_max=config.num_epochs, verbose=True, **asdict(config.cosine_annealing_lr))
elif config.scheduler == 'LinearWarmupCosineAnnealingLR':
    # NOTE(review): LinearWarmupCosineAnnealingLR is not in scope (its pl_bolts import is
    # commented out in the imports cell), so this branch raises NameError as written.
    scheduler = LinearWarmupCosineAnnealingLR(
        optimizer, **asdict(config.linear_warmup_cosine_annealing_lr))
else:
    scheduler = None

# Not read again in any visible cell.
best_val_error = None

In [None]:
import pickle

# Legacy manual-training loop (pre-Lightning).
# NOTE(review): train, eval, eval_r2, train_loader / valid_loader / test_loader,
# res_dict, algo and d are not defined in this notebook (the `from train_1d import *`
# line in the imports cell is commented out). Presumably they came from train_1d.
# `eval` here must be a project helper that shadows Python's builtin.
for epoch in range(config.num_epochs):
    loss = train(train_loader)
    if scheduler is not None:
        # Only ReduceLROnPlateau.step() takes a metric. The epoch-based schedulers
        # would treat a positional argument as the (deprecated) epoch index.
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(loss)
        else:
            scheduler.step()
    valid_error = eval(valid_loader)

    early_stopping(valid_error, model)
    # Presumably counter == 0 means validation just improved, so refresh the test error
    # for the new best model.
    if early_stopping.counter == 0:
        test_error = eval(test_loader)
    if early_stopping.early_stop:
        print('Early stopping...')
        break

    tag = (f'Loss_{config.model1d.model}/{config.model1d.input_type}'
           f'/{config.dataset}/{config.target}')
    writer.add_scalar(f'{tag}/train', loss, epoch)
    writer.add_scalar(f'{tag}/valid', valid_error, epoch)
    writer.add_scalar(f'{tag}/test', test_error, epoch)
    print(f'Progress: {epoch}/{config.num_epochs}/{loss:.5f}/{valid_error:.5f}/{test_error:.5f}')
    res_dict[dataname][algo][d]['loss'].append(loss)
    res_dict[dataname][algo][d]['valid_error'].append(valid_error)
    res_dict[dataname][algo][d]['test_error'].append(test_error)

# Reload the best checkpoint written by early stopping and evaluate it.
model.load_state_dict(torch.load(checkpoint_path))
test_error = eval(test_loader)
test_r2 = eval_r2(test_loader)
print(f'Best validation error: {-early_stopping.best_score:.7f}')
print(f'Test error: {test_error:.7f}')
print(f'Test r2: {test_r2:.7f}')

res_dict[dataname][algo][d]['Test error'] = test_error
res_dict[dataname][algo][d]['Test r2'] = test_r2
res_dict[dataname][algo][d]['checkpoint'] = checkpoint_path

#os.remove(checkpoint_path)
writer.close()

with open("res_dict_1d.pkl", "wb") as f:
    pickle.dump(res_dict, f)