In [1]:
%load_ext autoreload
%autoreload 2
%cd "python/LatentEvolution"
%ls

/home/hew/python/LatentEvolution
[0m[01;34mcache[0m/  [01;34mdata[0m/  env.txt  [01;34mfigure[0m/  [01;34mframework[0m/  main.py  [01;34mscript[0m/  [01;34mtemp[0m/


In [2]:
from framework.config import parse_config, paths
from framework.utils.lightning.device_utils import seed_everything
from framework.utils.lightning.trainer_utils import get_pl_trainer
from script.task_02_ProteinVAE.ProteinVAE.sequence_data_module import SequenceDataModule
from script.task_02_ProteinVAE.ProteinVAE.sequence_lightning_module import SequenceLightningModule

root_path: /home/hew/python/LatentEvolution
framework_path: /home/hew/python/LatentEvolution/framework


In [3]:
update_dict = {
    'project': 'ProteinVAE',
    'seed': 42,
    'logger': {
        'save_dir': './script/task_02_ProteinVAE/'
    },
    'data': {
        'dataset': 'ACE2_variants_2k',
        # 'dataset': 'ACE2_variants_1000',
        'data_class': 'Protein',
        'mini_set_ratio': None,
        'max_len': 83,
    },
    # VAE v2
    # 'hparams': {
    #     'encoder_params': {
    #         'num_layers': 4,
    #         'embed_dim': 1280,
    #         'attention_heads': 20,
    #         'alphabet': 'ESM-1b',
    #         'token_dropout': False,
    #         'embedding_layer': True,
    #         'lm_head': False,
    #         'return_layer': -1,
    #     },
    #     'encoder_mlp': {
    #         'hiddens': [1280, 512, 256],
    #         'activation': 'ReLU',
    #         'batch_norm': False,
    #         'layer_norm': True,
    #         'bias': True,
    #         'dropout': 0.1,
    #     },
    #     'encoder_mapping': {
    #         'hiddens': [85 * 256, 4],
    #         'activation': 'ReLU',
    #         'batch_norm': True,
    #         'layer_norm': False,
    #         'bias': True,
    #         'dropout': 0.1,
    #     },
    #     'decoder_mapping': {
    #         'hiddens': [4 // 2, 85 * 256],
    #         'activation': 'ReLU',
    #         'batch_norm': True,
    #         'layer_norm': False,
    #         'bias': True,
    #         'dropout': 0.1,
    #     },
    #     'decoder_mlp': {
    #         'hiddens': [256, 512, 1280],
    #         'activation': 'ReLU',
    #         'batch_norm': False,
    #         'layer_norm': True,
    #         'bias': True,
    #         'dropout': 0.1,
    #     },
    #     'decoder_params': {
    #         'num_layers': 4,
    #         'embed_dim': 1280,
    #         'attention_heads': 20,
    #         'alphabet': 'ESM-1b',
    #         'token_dropout': False,
    #         'embedding_layer': False,
    #         'lm_head': False,
    #         'return_layer': -1,
    #     },

    # VAE v3 standard
    'hparams': {
        'encoder_transformer': {
            'num_layers': 4,
            'embed_dim': 128,
            'attention_heads': 16,
            'alphabet': 'ESM-1b',
            'token_dropout': False,
            'embedding_layer': True,
            'lm_head': False,
            'return_layer': -1,
        },
        'encoder_mlp': {
            'hiddens': [128, 64, 32],
            'activation': 'ReLU',
            'batch_norm': False,
            'layer_norm': True,
            'bias': True,
            'dropout': 0.05,
        },
        'decoder_mlp': {
            'hiddens': [32 // 2, 64, 128],
            'activation': 'ReLU',
            'batch_norm': False,
            'layer_norm': True,
            'bias': True,
            'dropout': 0.05,
        },
        'decoder_transformer': {
            'num_layers': 4,
            'embed_dim': 128,
            'attention_heads': 16,
            'alphabet': 'ESM-1b',
            'token_dropout': False,
            'embedding_layer': False,
            'lm_head': False,
            'return_layer': -1,
        },
        'regressor_head': {
            'hiddens': [85 * 32 // 2, 256, 128, 2],  # concat_h as input: L*D/2
            # 'hiddens': [32, 16, 2],  # position_h <cls> after pooling as input: D
            # 'hiddens': [128, 16, 2],  # position_h <cls> before pooling as input: H
            'activation': 'ReLU',
            'batch_norm': True,
            'bias': True,
            'dropout': 0.05,
        },
        # 'regressor_head': None,
        'reparameterization': False,
    },
    'loss': {
        'ce_loss': {'name': 'CrossEntropy', 'args': {}},
        'mse_loss': {'name': 'MSELoss', 'args': {}},
        'mmd_loss': {'name': 'MMDLoss', 'args': {'sigma': 20}},
        'ce_weight': 1.0,
        'mse_weight': 1000.0,
        'reg_weight': 0.1,
    },
    'optimizer': {
        # 'name': 'Adam',
        'name': 'RAdam',
        # 'name': 'AdamW',
        'args': {
            # 'lr': 0.001,
            'lr': 0.0005,
            # 'lr': 0.0002,
            # 'lr': 0.0001,
        }
    },
    'scheduler': {
        # 'name': None,
        # 'args': {},
        'name': 'LinearLR',
        'interval': 'step',
        'frequency': 1,
        'args': {
            'start_factor': 1,
            'end_factor': 0.01,
            'total_iters': 25 * 100 * 5
        },
    },
    'trainer': {
        # 'max_epochs': 50,
        # 'max_epochs': 100,
        'max_epochs': 200,
        'gradient_clip_val': 1.0,
        'gradient_clip_algorithm': 'norm',
        'accumulate_grad_batches': 1,
        'num_sanity_val_steps': 2,
        'val_check_interval': 0.5,
        'enable_checkpointing': True,
        # 'enable_checkpointing': False,
    },
    'early_stop_callback': {
        'monitor': 'valid/loss_epoch',
        'mode': 'min',
    },
    'ckpt_callback': {
        'monitor': 'valid/mse_epoch',
        'filename': 'epoch={epoch:02d}, loss={valid/loss_epoch:.3f}, ce={valid/ce_epoch:.3f}, reg={valid/reg_epoch:.3f}, mse={valid/mse_epoch:.3f}, ddG={valid/ddG_pearsonr_epoch:.3f}, dS={valid/dS_pearsonr_epoch:.3f}',
        'auto_insert_metric_name': False,
        'save_weights_only': True,
        'mode': 'min',
        'save_top_k': 100,
        'save_last': True,
    },
    # 'ckpt_callback': None,
    'train_dataloader': {
        # 'batch_size': 32,
        'batch_size': 64,
        # 'batch_size': 128,
        # 'batch_size': 256,
        'num_workers': 4,
        'drop_last': True,
    },
    'valid_dataloader': {
        # 'batch_size': 32,
        'batch_size': 64,
        # 'batch_size': 128,
        # 'batch_size': 256,
        'num_workers': 4,
    },
    'test_dataloader': {
        # 'batch_size': 32,
        'batch_size': 64,
        # 'batch_size': 128,
        # 'batch_size': 256,
        'num_workers': 4,
    },
    'predict_dataloader': {
        # 'batch_size': 32,
        'batch_size': 64,
        # 'batch_size': 128,
        # 'batch_size': 256,
        'num_workers': 4,
    },
}
dataset_hparams = paths.script + '/task_02_ProteinVAE/ProteinVAE/dataset.yaml'
model_hparams = paths.script + '/task_02_ProteinVAE/ProteinVAE/model.yaml'
framework_hparams = paths.script + '/task_02_ProteinVAE/ProteinVAE/framework.yaml'
args = parse_config([dataset_hparams, model_hparams, framework_hparams, update_dict])

In [4]:
args

{'project': 'ProteinVAE', 'dataset': 'template', 'model': '/home/hew/python/LatentEvolution/framework/config/model/template.yaml', 'seed': 42, 'tokenization': {'alphabet': 'ESM-1b', 'truncation_seq_length': None}, 'train_dataloader': {'batch_size': 64, 'num_workers': 4, 'shuffle': True, 'drop_last': True, 'pin_memory': True, 'persistent_workers': True}, 'valid_dataloader': {'batch_size': 64, 'num_workers': 4, 'shuffle': False, 'pin_memory': True, 'persistent_workers': True}, 'test_dataloader': {'batch_size': 64, 'num_workers': 4, 'shuffle': False, 'pin_memory': True, 'persistent_workers': True}, 'predict_dataloader': {'batch_size': 64, 'num_workers': 4, 'shuffle': False, 'pin_memory': True, 'persistent_workers': True}, 'trainer': {'max_epochs': 200, 'accelerator': 'auto', 'strategy': 'auto', 'devices': 'auto', 'deterministic': False, 'benchmark': True, 'sync_batchnorm': True, 'log_every_n_steps': 1, 'check_val_every_n_epoch': 1, 'fast_dev_run': False, 'num_sanity_val_steps': 2, 'enable

In [5]:
%%time
seed_everything(args.seed)
pl_data_module = SequenceDataModule(args)

Global seed set to 42


CPU times: user 4.45 ms, sys: 62 µs, total: 4.51 ms
Wall time: 3.63 ms


In [6]:
# %%time
# pl_data_module.prepare_data('train')
# pl_data_module.setup('fit')
# pl_data_module.setup('test')

In [7]:
# pl_data_module.args.train_dataloader.batch_size = 128
# pl_data_module.args.valid_dataloader.batch_size = 128
# pl_data_module.args.test_dataloader.batch_size = 128
# pl_data_module.args.predict_dataloader.batch_size = 128
# pl_data_module.args.train_dataloader.num_workers = 4
# pl_data_module.args.valid_dataloader.num_workers = 4
# pl_data_module.args.test_dataloader.num_workers = 4
# pl_data_module.args.predict_dataloader.num_workers = 4

In [8]:
pl_data_module.dataframe

In [9]:
# tokens, ddG, dS = next(iter(pl_data_module.train_dataloader()))
# tokens.shape, ddG.shape, dS.shape

In [10]:
%%time
seed_everything(args.seed)
pl_model = SequenceLightningModule(args)

trainer = get_pl_trainer(args)

Global seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


CPU times: user 330 ms, sys: 44.9 ms, total: 375 ms
Wall time: 394 ms


In [11]:
pl_model.model

ProteinVAE(
  (encoder_transformer): ESMTransformer(
    (embed_tokens): Embedding(33, 128, padding_idx=1)
    (layers): ModuleList(
      (0-3): 4 x TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=128, out_features=128, bias=True)
          (v_proj): Linear(in_features=128, out_features=128, bias=True)
          (q_proj): Linear(in_features=128, out_features=128, bias=True)
          (out_proj): Linear(in_features=128, out_features=128, bias=True)
          (rot_emb): RotaryEmbedding()
        )
        (self_attn_layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=128, out_features=512, bias=True)
        (fc2): Linear(in_features=512, out_features=128, bias=True)
        (final_layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
    )
    (emb_layer_norm_after): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (encoder_mlp): MLP(
    (mlp): Sequential(

In [12]:
pl_model.model.encoder_mlp.mlp[0].weight, pl_model.model.encoder_mlp.mlp[0].weight.requires_grad

(Parameter containing:
 tensor([[ 0.0009,  0.0231,  0.0561,  ...,  0.0883, -0.0723, -0.0706],
         [ 0.0438, -0.0589, -0.0544,  ..., -0.0211,  0.0062,  0.0495],
         [ 0.0537,  0.0459, -0.0724,  ..., -0.0693, -0.0875,  0.0370],
         ...,
         [-0.0514,  0.0383,  0.0226,  ..., -0.0554,  0.0690, -0.0841],
         [-0.0037,  0.0279, -0.0087,  ..., -0.0350,  0.0324,  0.0755],
         [ 0.0561, -0.0160,  0.0218,  ..., -0.0486, -0.0286,  0.0117]],
        requires_grad=True),
 True)

In [13]:
pl_model.model.training, pl_model.training

(True, True)

In [14]:
# gc.collect()
# torch.cuda.empty_cache()

In [15]:
%%time
pl_data_module.prepare_data('train')
trainer.fit(model=pl_model, datamodule=pl_data_module)

<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< load data according to selected index >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>


  0%|          | 0/2404 [00:00<?, ?it/s]

the dataset has not been partitioned, split dataset with specific ratio
dataframe partition values:
partition
train    1661
test      521
valid     222
Name: count, dtype: int64
select the subset for debug, max_len: 83, ratio: 1, number: 2404
len(self.train_dataset) 1661
len(self.valid_dataset) 222
len(self.test_dataset) 521
[len self.train_dataset] 1661
[len self.val_dataset] 222


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type             | Params
----------------------------------------------
0 | model    | ProteinVAE       | 2.0 M 
1 | ce_loss  | CrossEntropyLoss | 0     
2 | mse_loss | MSELoss          | 0     
3 | mmd_loss | MMDLoss          | 0     
----------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
8.056     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

CPU times: user 15min 45s, sys: 8min 7s, total: 23min 53s
Wall time: 9min 58s


In [16]:
pl_model.model.encoder_mlp.mlp[0].weight

Parameter containing:
tensor([[ 0.0009, -0.0014,  0.0628,  ...,  0.0951, -0.0603, -0.0844],
        [ 0.0665, -0.0657, -0.0755,  ..., -0.0259, -0.0124,  0.0302],
        [ 0.0484,  0.0811, -0.0911,  ..., -0.0473, -0.0403,  0.0443],
        ...,
        [-0.0534,  0.0293,  0.0264,  ..., -0.0497,  0.0875, -0.1292],
        [-0.0468,  0.0287,  0.0232,  ...,  0.0093,  0.0049,  0.0844],
        [ 0.0709, -0.0197,  0.0174,  ..., -0.0223, -0.0189,  0.0508]],
       requires_grad=True)

In [17]:
trainer.test(model=pl_model, datamodule=pl_data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


[len self.test_dataset] 521


Testing: 0it [00:00, ?it/s]

[{'test/loss_epoch': 40.74958801269531,
  'test/ce_epoch': 22.543073654174805,
  'test/reg_epoch': 11.927752494812012,
  'test/mse_epoch': 6.278755187988281,
  'test/ddG_pearsonr_epoch': 0.5360563903271662,
  'test/ddG_spearmanr_epoch': 0.4634887474711109,
  'test/dS_pearsonr_epoch': 0.8944658767328253,
  'test/dS_spearmanr_epoch': 0.5946765908169404,
  'test/avg_pearsonr_epoch': 0.7152611335299958,
  'test/avg_spearmanr_epoch': 0.5290826691440256}]

In [18]:
trainer.test(model=pl_model, datamodule=pl_data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


[len self.test_dataset] 521


Testing: 0it [00:00, ?it/s]

[{'test/loss_epoch': 40.69425964355469,
  'test/ce_epoch': 22.543073654174805,
  'test/reg_epoch': 11.87242603302002,
  'test/mse_epoch': 6.278755187988281,
  'test/ddG_pearsonr_epoch': 0.5360563903271662,
  'test/ddG_spearmanr_epoch': 0.4634887474711109,
  'test/dS_pearsonr_epoch': 0.8944658767328253,
  'test/dS_spearmanr_epoch': 0.5946765908169404,
  'test/avg_pearsonr_epoch': 0.7152611335299958,
  'test/avg_spearmanr_epoch': 0.5290826691440256}]

In [19]:
pl_model.model.encoder_mlp.mlp[0].weight

Parameter containing:
tensor([[ 0.0009, -0.0014,  0.0628,  ...,  0.0951, -0.0603, -0.0844],
        [ 0.0665, -0.0657, -0.0755,  ..., -0.0259, -0.0124,  0.0302],
        [ 0.0484,  0.0811, -0.0911,  ..., -0.0473, -0.0403,  0.0443],
        ...,
        [-0.0534,  0.0293,  0.0264,  ..., -0.0497,  0.0875, -0.1292],
        [-0.0468,  0.0287,  0.0232,  ...,  0.0093,  0.0049,  0.0844],
        [ 0.0709, -0.0197,  0.0174,  ..., -0.0223, -0.0189,  0.0508]],
       requires_grad=True)

In [20]:
pl_data_module.prepare_predict_data(predict_data='test')
predictions = trainer.predict(model=pl_model, datamodule=pl_data_module)
predictions

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


len(self.test_dataset) 521
len(self.predict_index) 521
[len self.predict_dataset] 521


Predicting: 0it [00:00, ?it/s]

[{'recon_tokens': tensor([[ 8, 11, 12,  ...,  5,  4, 16],
          [16, 11, 12,  ...,  5,  4, 16],
          [ 8,  8, 12,  ...,  5,  4, 16],
          ...,
          [ 8, 11, 12,  ...,  5,  4, 16],
          [ 8, 11, 12,  ...,  5,  4, 16],
          [ 8, 11, 12,  ...,  5,  4, 16]]),
  'tokens': tensor([[ 8, 11, 12,  ...,  5,  4, 16],
          [ 5, 11, 12,  ...,  5,  4, 16],
          [ 8,  8, 12,  ...,  5,  4, 16],
          ...,
          [ 8, 11, 12,  ...,  5,  4, 16],
          [ 8, 11, 12,  ...,  5,  4, 16],
          [ 8, 11, 12,  ...,  5,  4, 16]]),
  'ddG': tensor([-0.7306, -0.7681, -1.3282, -0.0167, -1.2203, -0.7663, -1.8396, -0.1471,
          -0.1454, -0.8561, -1.4395, -1.9383, -0.5685, -0.4398, -1.1382, -4.4631,
          -1.3262, -0.7398, -1.0461, -1.7806, -0.8554, -1.4865, -1.5717, -0.1668,
          -1.2113, -0.0279, -1.8086, -1.5696, -1.8183, -1.7443, -0.0988, -0.2069,
          -0.2052, -2.4691, -4.8645, -1.2536, -1.0251, -0.9982, -1.8109, -0.2505,
          -0.1961, 