In [1]:
import hydra
from pytorch_lightning import Trainer, callbacks, loggers, seed_everything
from scaling_model.models.utils import PredictionWriter
from scaling_model.models.painn_lightning import PaiNNforQM9
from scaling_model.data.data_module import (
    QM9DataModule,
    BaselineDataModule,
)

from lightning.pytorch.profilers import PyTorchProfiler
from torch.profiler import ProfilerActivity
import pytorch_lightning as pl
import torch
from lightning.pytorch.profilers import AdvancedProfiler
import os
import yaml
import pandas as pd
import numpy as np

In [2]:
# Inline copy of the experiment configuration as YAML text. It is parsed with
# yaml.safe_load and its `${...}` references are expanded by
# resolve_placeholders in later cells.
# NOTE(review): `sampler_id: 05` is unquoted; the resolved config displayed
# further down shows sampler_id == 5 and data_dir == 'norm_alpha_random_5/'.
# If the directory is meant to end in `_05/`, the value needs quoting ("05") —
# confirm against the directory that was used when the checkpoint was trained.
# NOTE(review): lr / threshold / min_lr / min_delta are written without a
# decimal point (e.g. 1e-4); presumably safe_load leaves them as strings and
# they only become floats in resolve_placeholders' numeric conversion — verify.
yaml_string = """seed: 42069420
experiment: new_dens_logic

sampler: 
  sampler_type: random
  sampling_prob: 0.2

sampler_id: 05

data:
  target: 0
  data_dir: norm_alpha_${sampler.sampler_type}_${sampler_id}/
  max_protein_size: 1000
  batch_size_train: 8
  batch_size_inference: 8
  num_workers: 4
  splits: [0.8, 0.1, 0.1]
  seed: 42069420
  subset_size: null
  random_data: False


lightning_model:
  ema_decay: 0.9
  painn_kwargs:
    num_message_passing_layers: 3
    num_features: 128
    num_rbf_features: 20
    num_unique_atoms: 100
    cutoff_dist: 5.0
  prediction_kwargs:
    num_features: 128
    num_layers: 2
  optimizer_kwargs:
    weight_decay: 0.01
    lr: 1e-4
  lr_scheduler_kwargs:
    mode: min
    factor: 0.5
    patience: 5
    threshold: 1e-6
    threshold_mode: rel
    cooldown: 2
    min_lr: 1e-6

early_stopping:
    monitor: val_loss
    patience: 30
    min_delta: 1e-6

trainer:
  max_epochs: 52
  max_time: 00:10:00:00
  deterministic: true


  """

In [3]:
# Parse the YAML text into plain Python containers. `${...}` placeholders are
# still literal text at this point; they are expanded in a later cell.
config = yaml.safe_load(yaml_string)


In [4]:
from scaling_model.models.utils import PredictionWriter
import re

In [5]:
def resolve_placeholders(config, base_dict):
    """Expand ``${dotted.path}`` placeholders in ``config`` and coerce numeric strings.

    The work is done in two passes over a *copy* of ``config`` (the input is
    not modified):

    1. Every string is scanned for ``${a.b.c}`` placeholders. Each one is
       looked up in ``base_dict`` by following the dot-separated keys and is
       replaced by ``str()`` of the value found. Substituted text is not
       scanned again, so a placeholder that points at another placeholder is
       inserted verbatim.
    2. Every string that parses as a number is replaced by that number:
       ``int`` is tried first, then ``float`` (so ``"12"`` -> ``12``,
       ``"-3"`` -> ``-3``, ``"1e-4"`` -> ``0.0001``). Note that ``float``
       also accepts spellings such as ``"nan"`` and ``"inf"``. Strings that
       are not numbers are left as they are.

    Args:
        config: Mapping whose values may be strings, dicts, lists or other
            scalars. Dicts and lists are processed recursively.
        base_dict: Mapping used as the root for placeholder lookups. Pass the
            same object as ``config`` to resolve references within one
            document.

    Returns:
        A new dict with the same keys as ``config``, placeholders expanded and
        numeric strings converted.

    Raises:
        ValueError: If a placeholder path cannot be followed, either because a
            key is missing or because the path runs through a value that is
            not a mapping.
    """

    def lookup(path):
        # Follow the dotted path from the root of base_dict.
        node = base_dict
        try:
            for part in path.split('.'):
                node = node[part]
        except (KeyError, TypeError, IndexError):
            # TypeError/IndexError cover paths that descend into scalars,
            # strings or lists; callers get one exception type for every
            # "cannot resolve" situation.
            raise ValueError(f"Unable to resolve placeholder {path}") from None
        return node

    def resolve_value(value):
        if isinstance(value, str):
            # Placeholders are collected once, from the original text, and
            # then replaced one after another.
            for match in re.findall(r'\$\{([^}]+)\}', value):
                value = value.replace(f'${{{match}}}', str(lookup(match)))
            return value
        if isinstance(value, dict):
            return {key: resolve_value(item) for key, item in value.items()}
        if isinstance(value, list):
            return [resolve_value(item) for item in value]
        return value

    def coerce_numeric(node):
        if isinstance(node, dict):
            return {key: coerce_numeric(item) for key, item in node.items()}
        if isinstance(node, list):
            return [coerce_numeric(item) for item in node]
        if isinstance(node, str):
            # Let int()/float() decide what is numeric instead of
            # str.isdigit(), which is True for characters int() rejects and
            # False for signed integers.
            for cast in (int, float):
                try:
                    return cast(node)
                except ValueError:
                    continue
        return node

    return coerce_numeric(resolve_value(dict(config)))

In [6]:
# Resolve placeholders in the config.
# The same dict is passed as both arguments so that `${...}` paths are looked
# up from the root of the document that is being rewritten.
resolved_config = resolve_placeholders(config, config)

In [7]:
resolved_config

{'seed': 42069420,
 'experiment': 'new_dens_logic',
 'sampler': {'sampler_type': 'random', 'sampling_prob': 0.2},
 'sampler_id': 5,
 'data': {'target': 0,
  'data_dir': 'norm_alpha_random_5/',
  'max_protein_size': 1000,
  'batch_size_train': 8,
  'batch_size_inference': 8,
  'num_workers': 4,
  'splits': [0.8, 0.1, 0.1],
  'seed': 42069420,
  'subset_size': None,
  'random_data': False},
 'lightning_model': {'ema_decay': 0.9,
  'painn_kwargs': {'num_message_passing_layers': 3,
   'num_features': 128,
   'num_rbf_features': 20,
   'num_unique_atoms': 100,
   'cutoff_dist': 5.0},
  'prediction_kwargs': {'num_features': 128, 'num_layers': 2},
  'optimizer_kwargs': {'weight_decay': 0.01, 'lr': 0.0001},
  'lr_scheduler_kwargs': {'mode': 'min',
   'factor': 0.5,
   'patience': 5,
   'threshold': 1e-06,
   'threshold_mode': 'rel',
   'cooldown': 2,
   'min_lr': 1e-06}},
 'early_stopping': {'monitor': 'val_loss', 'patience': 30, 'min_delta': 1e-06},
 'trainer': {'max_epochs': 52,
  'max_time'

In [8]:
from torch.utils.data import DataLoader
from torch_geometric.data import Data, Batch
import time


In [9]:
class PaiNNforQM9_sub(PaiNNforQM9):
    """PaiNNforQM9 variant whose predict step also returns targets and sample ids.

    The class overrides ``__init__`` and ``predict_step`` only; everything else
    is inherited from ``PaiNNforQM9``.

    NOTE(review): the parent constructor is called without arguments, so the
    keyword dictionaries received here are only stored as attributes and are
    not forwarded to ``PaiNNforQM9.__init__``. Whether the parent reads these
    attributes later, or builds its network from its own defaults, cannot be
    seen from this notebook — confirm against ``PaiNNforQM9``.
    """

    def __init__(
        self,
        ema_decay=0.9,
        painn_kwargs=None,
        prediction_kwargs=None,
        optimizer_kwargs=None,
        lr_scheduler_kwargs=None,
    ):
        """Store the configuration on the instance and register hyperparameters.

        Args:
            ema_decay: Stored as ``self.ema_decay``.
            painn_kwargs: Dict stored as ``self.painn_kwargs``; ``None`` means
                an empty dict.
            prediction_kwargs: Dict stored as ``self.prediction_kwargs``;
                ``None`` means an empty dict.
            optimizer_kwargs: Dict stored as ``self.optimizer_kwargs``;
                ``None`` means an empty dict.
            lr_scheduler_kwargs: Dict stored as ``self.lr_scheduler_kwargs``;
                ``None`` means an empty dict.
        """
        # Give every instance its own dicts. Literal `{}` defaults would be
        # created once and shared by all instances constructed without these
        # arguments. The parameter names themselves are rebound, and this is
        # done before any other call, so that frame-inspecting helpers such as
        # save_hyperparameters() see dicts rather than None (presumably that
        # is how the arguments are collected — verify against Lightning).
        painn_kwargs = {} if painn_kwargs is None else painn_kwargs
        prediction_kwargs = {} if prediction_kwargs is None else prediction_kwargs
        optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
        lr_scheduler_kwargs = {} if lr_scheduler_kwargs is None else lr_scheduler_kwargs

        super().__init__()
        self.ema_decay = ema_decay
        self.painn_kwargs = painn_kwargs
        self.prediction_kwargs = prediction_kwargs
        self.optimizer_kwargs = optimizer_kwargs
        self.lr_scheduler_kwargs = lr_scheduler_kwargs

        self.ema_val_loss = None
        # Example input made of three tensors with 17 entries each: integer
        # codes, 3-component float rows and all-zero indexes. Presumably these
        # are atomic numbers, atom positions and graph indexes, in the order
        # predict_step feeds them to the model — TODO confirm.
        self.example_input_array = (
            torch.tensor([8, 6, 6, 6, 6, 6, 6, 8, 6, 1, 1, 1, 1, 1, 1, 1, 1]),
            torch.tensor(
                [
                    [-1.9367, -1.9987, 0.1342],
                    [-1.7873, -0.8125, 0.3215],
                    [-0.5070, -0.1171, 0.2496],
                    [0.0975, 1.0183, 1.1337],
                    [1.4150, 0.2444, 1.2140],
                    [0.9229, -0.6731, 0.0458],
                    [1.3280, -0.0667, -1.2942],
                    [0.7616, 1.2588, -1.2733],
                    [-0.2487, 1.2681, -0.3002],
                    [-2.6507, -0.1535, 0.5738],
                    [-0.3977, 1.5453, 1.9420],
                    [2.3332, 0.8142, 1.0389],
                    [1.4985, -0.3062, 2.1528],
                    [1.0744, -1.7460, 0.1400],
                    [2.4107, 0.0298, -1.4204],
                    [0.9160, -0.6353, -2.1399],
                    [-1.0754, 1.9297, -0.5332],
                ]
            ),
            torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        )
        # Wall-clock timestamp taken at construction.
        self.init_time = time.time()
        self.save_hyperparameters()

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Run the model on one batch and return predictions with their labels.

        Args:
            batch: Object providing ``z``, ``pos``, ``batch``, ``y`` and
                ``name`` attributes.
            batch_idx: Unused here.
            dataloader_idx: Unused here.

        Returns:
            Dict with ``"preds"`` (the model output), ``"target"``
            (``batch.y``) and ``"ids"`` (``batch.name``).
        """
        input_ = {
            "atoms": batch.z,
            "atom_positions": batch.pos,
            "graph_indexes": batch.batch,
        }
        y_hat = self(**input_)

        return {"preds": y_hat, "target":batch.y,"ids": batch.name}

In [10]:
# Build the data module, trainer and model used by the prediction cell below.
cfg = resolved_config
# NOTE(review): seeding is disabled here although the config sets
# trainer.deterministic to true — confirm this is intended.
# seed_everything(int(cfg['seed']))
# NOTE(review): `cb` is never passed to Trainer below, so neither callback is
# active in this notebook.
cb = [
    callbacks.LearningRateMonitor(),
    PredictionWriter(dataloaders=["test"]),
]

# The data module receives the sampler and data sections of the config plus
# the PaiNN cutoff distance.
dm = BaselineDataModule(
    **cfg['sampler'], **cfg['data'], cutoff=cfg['lightning_model']['painn_kwargs']['cutoff_dist']
)

# Checkpoint whose state is restored by fit() here and by predict() later.
checkpoint_path = "logs/2024-05-22_02-03-22/HPC_3D_runs_no_D/5elfhdkj/checkpoints/manual_removals_manual_removals_density05_time_2024-05-22_02-03-22_epoch=52_val_loss=0.261209.ckpt"
# model = PaiNNforQM9.load_from_checkpoint(checkpoint_path, **cfg['lightning_model'])

# Only the keys of the trainer section (max_epochs, max_time, deterministic)
# are passed to the Trainer.
trainer = Trainer(
    **cfg['trainer'],
)
# The model instance created inline here is not kept. The recorded output of
# this cell reports "`Trainer.fit` stopped: `max_epochs=52` reached."
# NOTE(review): presumably this call exists only to set up the data module and
# attach the trainer — confirm whether it is still needed.
trainer.fit(model=PaiNNforQM9_sub(**cfg['lightning_model']), datamodule=dm, ckpt_path=checkpoint_path)

# Fresh instance; the prediction cell passes ckpt_path, so the checkpoint
# weights are restored there rather than here.
model = PaiNNforQM9_sub(**cfg['lightning_model'])


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/mangus/miniconda3/envs/scaling_model/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
Processing...
100%|██████████| 9786/9786 [02:24<00:00, 67.69it/s] 
Done!
Restoring states from the checkpoint path at logs/2024-05-22_02-03-22/HPC_3D_runs_no_D/5elfhdkj/checkpoints/manual_removals_manual_removals_density05_time_2024-05-22_02-03-22_epoch=52_val_loss=0.261209.ckpt
/home/mangus

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=52` reached.


In [11]:

# Run inference on the test dataloader. `ckpt_path` makes the trainer restore
# the checkpoint before predicting. The result is printed in full.
test_loaders = [dm.test_dataloader()]
predictions = trainer.predict(model, dataloaders=test_loaders, ckpt_path=checkpoint_path)
print(predictions)

Restoring states from the checkpoint path at logs/2024-05-22_02-03-22/HPC_3D_runs_no_D/5elfhdkj/checkpoints/manual_removals_manual_removals_density05_time_2024-05-22_02-03-22_epoch=52_val_loss=0.261209.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at logs/2024-05-22_02-03-22/HPC_3D_runs_no_D/5elfhdkj/checkpoints/manual_removals_manual_removals_density05_time_2024-05-22_02-03-22_epoch=52_val_loss=0.261209.ckpt


Predicting: |          | 0/? [00:00<?, ?it/s]

[{'preds': tensor([[-1.2815],
        [-1.3212],
        [-1.3277],
        [-0.8175],
        [-2.0735],
        [-1.2581],
        [-1.0590],
        [-0.9408]]), 'target': tensor([[-1.4704],
        [-0.9368],
        [-0.5923],
        [ 0.4428],
        [ 0.2518],
        [-1.0295],
        [ 1.4884],
        [ 0.9760]]), 'ids': ['2VSD', '6KNA', '2KSM', '2N0Z', '3JZO', '2L5S', '6VER', '2L4A']}, {'preds': tensor([[-1.6741],
        [-1.0898],
        [-1.0923],
        [-1.3857],
        [-1.9195],
        [-0.6086],
        [-0.7512],
        [-1.4352]]), 'target': tensor([[ 0.7017],
        [ 0.8324],
        [-0.6174],
        [-1.4450],
        [-1.4579],
        [ 0.0189],
        [ 0.8566],
        [ 0.3307]]), 'ids': ['3FIA', '7WJP', '5YV7', '1ZP8', '2DIC', '1T8J', '2CPB', '4DKK']}, {'preds': tensor([[-1.8407],
        [-0.5417],
        [-0.6723],
        [-0.8589],
        [-1.5573],
        [-0.8441],
        [-1.1260],
        [-0.1717]]), 'target': tensor([[-0.4576],
  

In [12]:
# Empty accumulators for per-batch predictions, targets and sample ids.
preds_list, targets_list, ids_list = [], [], []

In [13]:
# Collect the per-batch outputs. The lists are rebuilt from scratch in this
# cell so that running it a second time cannot append the same batches again
# and silently duplicate every row of the table.
preds_list = [batch['preds'].detach().cpu().numpy() for batch in predictions]
targets_list = [batch['target'].detach().cpu().numpy() for batch in predictions]
ids_list = [sample_id for batch in predictions for sample_id in batch['ids']]

# Join the batches along the first axis and flatten to one value per row.
preds_array = np.concatenate(preds_list).flatten()
targets_array = np.concatenate(targets_list).flatten()

# Every id must line up with exactly one prediction and one target.
assert len(ids_list) == len(preds_array) == len(targets_array), (
    "ids, predictions and targets have different lengths"
)

# One row per test sample.
df = pd.DataFrame({
    'id': ids_list,
    'prediction': preds_array,
    'target': targets_array
})

In [14]:
df

Unnamed: 0,id,prediction,target
0,2VSD,-1.281526,-1.470442
1,6KNA,-1.321176,-0.936807
2,2KSM,-1.327682,-0.592305
3,2N0Z,-0.817452,0.442788
4,3JZO,-2.073529,0.251838
...,...,...,...
974,8SKX,-1.043148,0.594206
975,2JUU,-0.748641,-0.369630
976,1Z4H,-0.541139,0.525586
977,1SB6,-0.800840,-0.218667


In [15]:
# Mean squared error between the prediction and target columns of `df`.
squared_error = (df["prediction"] - df["target"]) ** 2
np.mean(squared_error)

2.0424614

In [32]:
# Remove every row whose prediction equals the column maximum.
# NOTE(review): this cell carries execution count 32 while its neighbours are
# 15 and 16, so it was run out of order; it also rebinds `df`, so each re-run
# removes the next-largest prediction as well.
is_max_prediction = df.prediction == df.prediction.max()
df = df.drop(df.loc[is_max_prediction].index)

In [16]:
# Mean squared error over the current rows of `df`.
# NOTE(review): the recorded output equals the value shown before the
# drop-maximum cell, so this cell was presumably executed before that drop —
# re-run in order to confirm.
squared_error = (df["prediction"] - df["target"]) ** 2
np.mean(squared_error)

2.0424614

In [25]:
# Load the predictions CSV stored under the run's log directory and show it.
predictions_csv = "logs/2024-05-22_02-03-22/predictions.csv"
preds = pd.read_csv(predictions_csv)
preds

Unnamed: 0,pred_00,target_00,test,train,val
0,1.229491,1.097629,0.0,1.0,0.0
1,-0.634972,-1.310279,0.0,1.0,0.0
2,-0.537803,-0.342958,0.0,1.0,0.0
3,0.126474,-0.202467,0.0,1.0,0.0
4,0.176872,0.485918,0.0,1.0,0.0
...,...,...,...,...,...
9791,-0.197983,-0.386382,1.0,0.0,0.0
9792,1.329955,2.045383,1.0,0.0,0.0
9793,-0.027363,-0.240892,1.0,0.0,0.0
9794,-0.031765,0.253347,1.0,0.0,0.0


In [26]:
# Keep only the rows whose `test` column equals 1.
is_test_row = preds["test"] == 1
preds = preds.loc[is_test_row]

In [27]:
# Mean squared error between the pred_00 and target_00 columns of `preds`.
# NOTE(review): the recorded result (about 3.3e31) is many orders of magnitude
# larger than the value computed from this notebook's own predictions; the CSV
# presumably contains a few extreme pred_00 values — inspect before trusting it.
csv_squared_error = (preds["pred_00"] - preds["target_00"]) ** 2
np.mean(csv_squared_error)

3.302982435694951e+31