In [1]:
%load_ext autoreload
%autoreload 2
%cd "python/LatentEvolution"
%ls

/home/hew/python/LatentEvolution
[0m[01;34mcache[0m/  [01;34mdata[0m/  [01;34mfigure[0m/  [01;34mframework[0m/  main.py  [01;34mscript[0m/  [01;34mtemp[0m/


In [2]:
import gc

import torch

from framework.config import parse_config, paths
from framework.utils.lightning.device_utils import seed_everything
from framework.utils.lightning.trainer_utils import get_pl_trainer
from script.task_02_ProteinVAE.ProteinVAE.sequence_data_module import SequenceDataModule
from script.task_02_ProteinVAE.ProteinVAE.sequence_lightning_module import SequenceLightningModule

root_path: /home/hew/python/LatentEvolution
framework_path: /home/hew/python/LatentEvolution/framework


In [3]:
update_dict = {
    'project': 'ProteinVAE',
    'seed': 42,
    'logger': {
        'save_dir': './script/task_02_ProteinVAE/'
    },
    'data': {
        'dataset': 'ACE2_variants_2k',
        # 'dataset': 'ACE2_variants_1000',
        'data_class': 'Protein',
        'mini_set_ratio': None,
        'max_len': 83,
    },
    # VAE v2
    # 'hparams': {
    #     'encoder_params': {
    #         'num_layers': 4,
    #         'embed_dim': 1280,
    #         'attention_heads': 20,
    #         'alphabet': 'ESM-1b',
    #         'token_dropout': False,
    #         'embedding_layer': True,
    #         'lm_head': False,
    #         'return_layer': -1,
    #     },
    #     'encoder_mlp': {
    #         'hiddens': [1280, 512, 256],
    #         'activation': 'ReLU',
    #         'batch_norm': False,
    #         'layer_norm': True,
    #         'bias': True,
    #         'dropout': 0.1,
    #     },
    #     'encoder_mapping': {
    #         'hiddens': [85 * 256, 4],
    #         'activation': 'ReLU',
    #         'batch_norm': True,
    #         'layer_norm': False,
    #         'bias': True,
    #         'dropout': 0.1,
    #     },
    #     'decoder_mapping': {
    #         'hiddens': [4 // 2, 85 * 256],
    #         'activation': 'ReLU',
    #         'batch_norm': True,
    #         'layer_norm': False,
    #         'bias': True,
    #         'dropout': 0.1,
    #     },
    #     'decoder_mlp': {
    #         'hiddens': [256, 512, 1280],
    #         'activation': 'ReLU',
    #         'batch_norm': False,
    #         'layer_norm': True,
    #         'bias': True,
    #         'dropout': 0.1,
    #     },
    #     'decoder_params': {
    #         'num_layers': 4,
    #         'embed_dim': 1280,
    #         'attention_heads': 20,
    #         'alphabet': 'ESM-1b',
    #         'token_dropout': False,
    #         'embedding_layer': False,
    #         'lm_head': False,
    #         'return_layer': -1,
    #     },

    # VAE v3 standard
    'hparams': {
        'encoder_transformer': {
            'num_layers': 4,
            'embed_dim': 128,
            'attention_heads': 16,
            'alphabet': 'ESM-1b',
            'token_dropout': False,
            'embedding_layer': True,
            'lm_head': False,
            'return_layer': -1,
        },
        'encoder_mlp': {
            'hiddens': [128, 64, 32],
            'activation': 'ReLU',
            'batch_norm': False,
            'layer_norm': True,
            'bias': True,
            'dropout': 0.05,
        },
        'decoder_mlp': {
            'hiddens': [32 // 2, 64, 128],
            'activation': 'ReLU',
            'batch_norm': False,
            'layer_norm': True,
            'bias': True,
            'dropout': 0.05,
        },
        'decoder_transformer': {
            'num_layers': 4,
            'embed_dim': 128,
            'attention_heads': 16,
            'alphabet': 'ESM-1b',
            'token_dropout': False,
            'embedding_layer': False,
            'lm_head': False,
            'return_layer': -1,
        },
        'regressor_head': {
            'hiddens': [85 * 32 // 2, 256, 128, 2],  # concat_h as input: L*D/2
            # 'hiddens': [32, 16, 2],  # position_h <cls> after pooling as input: D
            # 'hiddens': [128, 16, 2],  # position_h <cls> before pooling as input: H
            'activation': 'ReLU',
            'batch_norm': True,
            'bias': True,
            'dropout': 0.05,
        },
        # 'regressor_head': None,
        'reparameterization': False,
    },
    'loss': {
        'ce_loss': {'name': 'CrossEntropy', 'args': {}},
        'mse_loss': {'name': 'MSELoss', 'args': {}},
        'mmd_loss': {'name': 'MMDLoss', 'args': {'sigma': 20}},
        'ce_weight': 1.0,
        'mse_weight': 1000.0,
        'reg_weight': 0.1,
    },
    'optimizer': {
        # 'name': 'Adam',
        'name': 'RAdam',
        # 'name': 'AdamW',
        'args': {
            # 'lr': 0.001,
            'lr': 0.0005,
            # 'lr': 0.0002,
            # 'lr': 0.0001,
        }
    },
    'scheduler': {
        # 'name': None,
        # 'args': {},
        'name': 'LinearLR',
        'interval': 'step',
        'frequency': 1,
        'args': {
            'start_factor': 1,
            'end_factor': 0.01,
            'total_iters': 25 * 100 * 5
        },
    },
    'trainer': {
        # 'max_epochs': 50,
        'max_epochs': 100,
        'gradient_clip_val': 1.0,
        'gradient_clip_algorithm': 'norm',
        'accumulate_grad_batches': 1,
        'num_sanity_val_steps': 2,
        'val_check_interval': 0.5,
        'enable_checkpointing': True,
        # 'enable_checkpointing': False,
    },
    'early_stop_callback': {
        'monitor': 'valid/loss_epoch',
        'mode': 'min',
    },
    'ckpt_callback': {
        'monitor': 'valid/loss_epoch',
        'filename': 'epoch={epoch:02d}, loss={valid/loss_epoch:.3f}, ce={valid/ce_epoch:.3f}, reg={valid/reg_epoch:.3f}, mse={valid/mse_epoch:.3f}, ddG={valid/ddG_pearsonr_epoch:.3f}, dS={valid/dS_pearsonr_epoch:.3f}',
        'auto_insert_metric_name': False,
        'save_weights_only': True,
        'mode': 'min',
        'save_top_k': 50,
        'save_last': True,
    },
    # 'ckpt_callback': None,
    'train_dataloader': {
        # 'batch_size': 32,
        'batch_size': 64,
        # 'batch_size': 128,
        # 'batch_size': 256,
        'num_workers': 4,
        'drop_last': True,
    },
    'valid_dataloader': {
        # 'batch_size': 32,
        'batch_size': 64,
        # 'batch_size': 128,
        # 'batch_size': 256,
        'num_workers': 4,
    },
    'test_dataloader': {
        # 'batch_size': 32,
        'batch_size': 64,
        # 'batch_size': 128,
        # 'batch_size': 256,
        'num_workers': 4,
    },
    'predict_dataloader': {
        # 'batch_size': 32,
        'batch_size': 64,
        # 'batch_size': 128,
        # 'batch_size': 256,
        'num_workers': 4,
    },
}
dataset_hparams = paths.script + '/task_02_ProteinVAE/ProteinVAE/dataset.yaml'
model_hparams = paths.script + '/task_02_ProteinVAE/ProteinVAE/model.yaml'
framework_hparams = paths.script + '/task_02_ProteinVAE/ProteinVAE/framework.yaml'
args = parse_config([dataset_hparams, model_hparams, framework_hparams, update_dict])

In [4]:
args

{'project': 'ProteinVAE', 'dataset': 'template', 'model': '/home/hew/python/LatentEvolution/framework/config/model/template.yaml', 'seed': 42, 'tokenization': {'alphabet': 'ESM-1b', 'truncation_seq_length': None}, 'train_dataloader': {'batch_size': 64, 'num_workers': 4, 'shuffle': True, 'drop_last': True, 'pin_memory': True, 'persistent_workers': True}, 'valid_dataloader': {'batch_size': 64, 'num_workers': 4, 'shuffle': False, 'pin_memory': True, 'persistent_workers': True}, 'test_dataloader': {'batch_size': 64, 'num_workers': 4, 'shuffle': False, 'pin_memory': True, 'persistent_workers': True}, 'predict_dataloader': {'batch_size': 64, 'num_workers': 4, 'shuffle': False, 'pin_memory': True, 'persistent_workers': True}, 'trainer': {'max_epochs': 100, 'accelerator': 'auto', 'strategy': 'auto', 'devices': 'auto', 'deterministic': False, 'benchmark': True, 'sync_batchnorm': True, 'log_every_n_steps': 1, 'check_val_every_n_epoch': 1, 'fast_dev_run': False, 'num_sanity_val_steps': 2, 'enable

In [5]:
%%time
seed_everything(args.seed)
pl_data_module = SequenceDataModule(args)

Global seed set to 42


CPU times: user 4.42 ms, sys: 60 µs, total: 4.48 ms
Wall time: 3.47 ms


In [6]:
%%time
pl_data_module.prepare_data('train')
pl_data_module.setup('fit')
pl_data_module.setup('test')

<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< load data according to selected index >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>


  0%|          | 0/2404 [00:00<?, ?it/s]

the dataset has not been partitioned, split dataset with specific ratio
dataframe partition values:
train    1661
test      521
valid     222
Name: partition, dtype: int64
select the subset for debug, max_len: 83, ratio: 1, number: 2404
len(self.train_dataset) 1661
len(self.valid_dataset) 222
len(self.test_dataset) 521
[len self.train_dataset] 1661
[len self.val_dataset] 222
[len self.test_dataset] 521
CPU times: user 734 ms, sys: 300 ms, total: 1.03 s
Wall time: 1.09 s


In [7]:
# pl_data_module.args.train_dataloader.batch_size = 128
# pl_data_module.args.valid_dataloader.batch_size = 128
# pl_data_module.args.test_dataloader.batch_size = 128
# pl_data_module.args.predict_dataloader.batch_size = 128
# pl_data_module.args.train_dataloader.num_workers = 4
# pl_data_module.args.valid_dataloader.num_workers = 4
# pl_data_module.args.test_dataloader.num_workers = 4
# pl_data_module.args.predict_dataloader.num_workers = 4

In [8]:
pl_data_module.dataframe

Unnamed: 0,index,name,partition,length,sequence,structure,graph,dS,ddG,bins
0,0,0,train,83,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,,,0.017,-1.0838,"(-1.091, -1.08]"
1,1,1,train,83,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,,,0.017,-0.0154,"(-0.0185, -0.0134]"
2,10,10,train,83,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,,,0.017,-0.8987,"(-0.899, -0.889]"
3,100,100,train,83,STIEEQAKTFLDKFNHDAEDLFYQSFLASWNYNTNITEENVQNMNN...,,,0.017,-1.1936,"(-1.197, -1.189]"
4,1000,1000,train,83,SDIEEQAKTFLDKFNHEAEDLFYQSSLAYWNYNTNITEENVQNMGN...,,,0.017,-2.5357,"(-2.576, -2.523]"
...,...,...,...,...,...,...,...,...,...,...
2399,995,995,test,83,STIEEQAKTFLDKFNHEAEDLFYQSDLARWNYNTNITEENVQNMNN...,,,0.017,-0.4275,"(-0.438, -0.425]"
2400,996,996,test,83,STIEEQAKTFLDKFNHEAEDLFYQSSLASWWYNTNITEENVQNMNN...,,,0.017,-1.6738,"(-1.686, -1.673]"
2401,997,997,test,83,STIEEQAKTFLDKFNHEAEDLFYQMSLASWNYNTNITEENVQNMNN...,,,0.017,-1.9267,"(-1.943, -1.923]"
2402,998,998,test,83,SDIEEQAKMFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,,,0.017,-1.3416,"(-1.35, -1.332]"


In [9]:
tokens, ddG, dS = next(iter(pl_data_module.train_dataloader()))
tokens.shape, ddG.shape, dS.shape

(torch.Size([64, 85]), torch.Size([64]), torch.Size([64]))

In [10]:
%%time
seed_everything(args.seed)
pl_model = SequenceLightningModule(args)

trainer = get_pl_trainer(args)

Global seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


CPU times: user 285 ms, sys: 60 ms, total: 345 ms
Wall time: 345 ms


In [11]:
pl_model.model

ProteinVAE(
  (encoder_transformer): ESMTransformer(
    (embed_tokens): Embedding(33, 128, padding_idx=1)
    (layers): ModuleList(
      (0): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=128, out_features=128, bias=True)
          (v_proj): Linear(in_features=128, out_features=128, bias=True)
          (q_proj): Linear(in_features=128, out_features=128, bias=True)
          (out_proj): Linear(in_features=128, out_features=128, bias=True)
          (rot_emb): RotaryEmbedding()
        )
        (self_attn_layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=128, out_features=512, bias=True)
        (fc2): Linear(in_features=512, out_features=128, bias=True)
        (final_layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (1): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=128, out_features=128, bias=True)

In [12]:
gc.collect()
torch.cuda.empty_cache()

In [13]:
%%time
pl_data_module.prepare_data('train')
trainer.fit(model=pl_model, datamodule=pl_data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type             | Params
----------------------------------------------
0 | model    | ProteinVAE       | 2.0 M 
1 | ce_loss  | CrossEntropyLoss | 0     
2 | mse_loss | MSELoss          | 0     
3 | mmd_loss | MMDLoss          | 0     
----------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
8.056     Total estimated model params size (MB)


len(self.train_dataset) 1661
len(self.valid_dataset) 222
len(self.test_dataset) 521
[len self.train_dataset] 1661
[len self.val_dataset] 222


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


CPU times: user 10min 10s, sys: 4min 33s, total: 14min 43s
Wall time: 6min 46s


In [14]:
trainer.test(model=pl_model, datamodule=pl_data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


[len self.test_dataset] 521


Testing: 0it [00:00, ?it/s]

[{'test/loss_epoch': 42.53229904174805,
  'test/ce_epoch': 19.594350814819336,
  'test/reg_epoch': 11.896059036254883,
  'test/mse_epoch': 11.041886329650879,
  'test/ddG_pearsonr_epoch': 0.5305525255945294,
  'test/ddG_spearmanr_epoch': 0.47339832121591047,
  'test/dS_pearsonr_epoch': 0.8595295487667789,
  'test/dS_spearmanr_epoch': 0.4247684171266864,
  'test/avg_pearsonr_epoch': 0.6950410371806541,
  'test/avg_spearmanr_epoch': 0.44908336917129843}]