# Import packages

In [1]:
import pandas as pd
from lightning import pytorch as pl
from sklearn.model_selection import train_test_split

from chemprop import data, featurizers, models, nn

# Change data inputs here

In [2]:
input_path = '../tests/data/spectra/exclusions_no_na.csv' # path to your data .csv file
num_workers = 0  # number of workers for dataloader. 0 means using main process for data loading
smiles_column = 'smiles' # name of the column containing SMILES strings
target_columns = ['400', '402', '404', '406', '408', '410'] # list of names of the columns containing spectra targets

## Load data

In [3]:
df_input = pd.read_csv(input_path)
df_input

Unnamed: 0,smiles,400,402,404,406,408,410
0,O=C(O)c1ccco1,0.001718,0.001717,0.001717,0.001701,0.001677,0.001644
1,O=C(O)c1ccco1,0.001718,0.001717,0.001717,0.001701,0.001677,0.001644
2,O=C(O)c1ccco1,0.001718,0.001717,0.001717,0.001701,0.001677,0.001644
3,O=C(O)c1ccco1,0.001718,0.001717,0.001717,0.001701,0.001677,0.001644
4,O=C(O)c1ccco1,0.001718,0.001717,0.001717,0.001701,0.001677,0.001644
...,...,...,...,...,...,...,...
195,O=C(O)c1ccco1,0.001718,0.001717,0.001717,0.001701,0.001677,0.001644
196,O=C(O)c1ccco1,0.001718,0.001717,0.001717,0.001701,0.001677,0.001644
197,O=C(O)c1ccco1,0.001718,0.001717,0.001717,0.001701,0.001677,0.001644
198,O=C(O)c1ccco1,0.001718,0.001717,0.001717,0.001701,0.001677,0.001644


In [4]:
smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns].values
ys

array([[0.00171802, 0.0017168 , 0.0017168 , 0.00170103, 0.00167736,
        0.00164366],
       [0.00171802, 0.0017168 , 0.0017168 , 0.00170103, 0.00167736,
        0.00164366],
       [0.00171802, 0.0017168 , 0.0017168 , 0.00170103, 0.00167736,
        0.00164366],
       ...,
       [0.00171802, 0.0017168 , 0.0017168 , 0.00170103, 0.00167736,
        0.00164366],
       [0.00171802, 0.0017168 , 0.0017168 , 0.00170103, 0.00167736,
        0.00164366],
       [0.00171802, 0.0017168 , 0.0017168 , 0.00170103, 0.00167736,
        0.00164366]])

## Get molecule datapoints

In [5]:
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]

## Perform data splitting for training, validation, and testing

In [6]:
train_data, val_test_data = train_test_split(all_data, test_size=0.1)
val_data, test_data = train_test_split(val_test_data, test_size=0.5)

# Get molecule datasets

In [7]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

train_dset = data.MoleculeDataset(train_data, featurizer)
scaler = train_dset.normalize_targets()

val_dset = data.MoleculeDataset(val_data, featurizer)
val_dset.normalize_targets(scaler)
test_dset = data.MoleculeDataset(test_data, featurizer)
test_dset.normalize_targets(scaler)

In [8]:
train_loader = data.MolGraphDataLoader(train_dset, num_workers=num_workers)
val_loader = data.MolGraphDataLoader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.MolGraphDataLoader(test_dset, num_workers=num_workers, shuffle=False)

# Construct the model

## Batch Norm

In [9]:
batch_norm = True # normalizes the outputs of the aggregation by re-centering and re-scaling.

## Message passing and aggregation

In [10]:
mp = nn.BondMessagePassing()
agg = nn.MeanAggregation()

## Feed-forward Network (FNN)

Spectral datasets use the ```SpectralFNN``` module.

This includes a spectral activation function with the options ```softplus``` or ```exp```.

In [11]:
ffn = nn.SpectralFFN(
    spectral_activation="softplus"  # "exp", "softplus". None defaults to softplus
)

## Metrics

Spectral FFNs work specifically with the ```SIDMetric``` metric.

In [12]:
metric_list = [nn.metrics.SIDMetric()]

## Construct MPNN

In [13]:
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)
mpnn

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=147, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=433, out_features=300, bias=True)
    (dropout): Dropout(p=0, inplace=False)
    (tau): ReLU()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): SpectralFFN(
    (ffn): MLP(
      (0): Linear(in_features=300, out_features=300, bias=True)
      (1): ReLU()
      (2): Dropout(p=0, inplace=False)
      (3): Linear(in_features=300, out_features=1, bias=True)
      (spectral_activation): Softplus(beta=1, threshold=20)
    )
  )
)

# Training

In [14]:
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,  # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=3,  # number of epochs to train for
)

trainer.fit(mpnn, train_loader, val_loader)
results = trainer.test(mpnn, test_loader)

results

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:639: Checkpoint directory /Users/joelmanu/chemprop/checkpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.

  | Name            | Type               | Params
-------------------------------------------------------
0 | message_passing | BondMessagePassing | 264 K 
1 | agg             | MeanAggregation    | 0     
2 | bn            

Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 2: 100%|██████████| 4/4 [00:00<00:00, 35.12it/s, train/loss=nan.0, val_loss=nan.0]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 4/4 [00:00<00:00, 31.48it/s, train/loss=nan.0, val_loss=nan.0]


/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 73.36it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/sid                    nan
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test/sid': nan}]