# Import Train and Test datasets

In [1]:
from deepchem.data import DiskDataset

train = DiskDataset("aqsoldb_train")
test = DiskDataset("aqsoldb_test")


Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'pytorch_lightning'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


<DiskDataset X.shape: (7870,), y.shape: (7870,), w.shape: (7870,), task_names: [0]>


# Create train and test datasets

In [2]:
from torch_geometric.data import Dataset
from deepchem.data.pytorch_datasets import _TorchDiskDataset


class AqSolDBDataset(Dataset):
  
  def __init__(self, deepchem_dataset: _TorchDiskDataset):
    self.graph_list = [mol.to_pyg_graph() for mol, _, _, _ in deepchem_dataset]
    self.labels = [y for _, y, _, _ in deepchem_dataset]
    self.length = len(self.labels)
    self._indices = None

  def __getitem__(self, i):
    graph = self.graph_list[i]
    label = self.labels[i]
    return graph, label
  
  def __len__(self):
    return self.length
  
  def len(self):
    return len(self)
  
  def get(self, idx):
    graph = self.graph_list[idx]
    label = self.labels[idx]
    return graph, label

train_dataset = AqSolDBDataset(train.make_pytorch_dataset())
test_dataset = AqSolDBDataset(test.make_pytorch_dataset())

train_dataset

AqSolDBDataset(7870)

# Create AqSol model

In [3]:
import torch.nn as nn
import torch_geometric.nn as pyg_nn
import torch
from torch_geometric.nn import global_add_pool
import torch.nn.functional as F


class AqSolModel(nn.Module):
  def __init__(
      self,
      n_features,
      hidden_channels,
      lr=10**-3,
      weight_decay=10**-2.5,
      dropout=0.2
    ):
    super(AqSolModel, self).__init__()

    self.conv1 = pyg_nn.GCNConv(n_features, hidden_channels)
    self.conv2 = pyg_nn.GCNConv(hidden_channels, hidden_channels)
    self.conv3 = pyg_nn.GCNConv(hidden_channels, int(hidden_channels / 2))
    
    self.lin = nn.Linear(int(hidden_channels / 2), int(hidden_channels / 2))
    self.lin2 = nn.Linear(int(hidden_channels / 2), 1)
    
    self.loss = nn.MSELoss()
    self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
    self.dropout = dropout

  def forward(self, mol):
    mol_x, mol_edge_index = mol.x, mol.edge_index
    
    mol_x = self.conv1(mol_x, mol_edge_index)
    mol_x = mol_x.relu()
    mol_x = self.conv2(mol_x, mol_edge_index)
    mol_x = mol_x.relu()
    mol_x = self.conv3(mol_x, mol_edge_index)
    mol_x = mol_x.relu()

    mol_x = global_add_pool(mol_x, mol.batch)

    mol_x = F.dropout(mol_x, p=self.dropout, training=self.training)

    mol_x = self.lin(mol_x)
    mol_x.relu()
    mol_x = self.lin2(mol_x)
    return mol_x

# Create model and train

In [None]:
from torch_geometric.loader import DataLoader
import math
import numpy as np
from tqdm import tqdm


model = AqSolModel(30, 128)

batch_size = min(len(train_dataset), 64)
num_epochs = 50
num_batches = math.ceil(len(train_dataset) / batch_size)
losses = np.zeros(num_epochs)
mean_loss = 0

model.train()

epoch_iter = range(num_epochs)

for epoch in epoch_iter:
  # print("Epoch " + str(epoch + 1) + " of " + str(num_epochs))
  model.optimizer.zero_grad()
  epoch_loss = 0
  for i, batch in enumerate(DataLoader(train_dataset, batch_size=batch_size)):
    graphs, labels = batch
    pred = model(graphs)
    actual = labels.reshape((len(labels), 1))
    # print(pred, actual)
    loss = model.loss(pred, actual)
    # print(loss.item())
    epoch_loss += loss.item()
    loss.backward()
    model.optimizer.step()
  losses[epoch] = epoch_loss / num_batches
  mean_loss += losses[epoch]
  print("loss: " + str(epoch_loss))


# Validate with test set

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error

model.eval()
preds = np.zeros((len(test)))
for i, (x_test, y_test) in enumerate(zip(test.X, test.y)):
    pred = model(x_test.to_pyg_graph()).detach().numpy()
    preds[i] = pred

print(preds, test.y)
print(mean_squared_error(test.y, preds))
print(mean_absolute_error(test.y, preds))

from matplotlib import pyplot as pyt

pyt.hist([preds, test.y], label=["Prediction", "Actual"]);
pyt.legend()

# Validate against train

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error

model.eval()
preds = np.zeros((len(train)))
for i, (x_test, y_test) in enumerate(zip(train.X, train.y)):
    pred = model(x_test.to_pyg_graph()).detach().numpy()
    preds[i] = pred

print(preds, train.y)
print(mean_squared_error(train.y, preds))
print(mean_absolute_error(train.y, preds))

In [None]:
from matplotlib import pyplot as pyt

pyt.hist([preds, train.y], label=["Prediction", "Actual"]);
pyt.legend()

# Hyperparameter Sweep

In [6]:
import wandb
from torch_geometric.loader import DataLoader


class Trainer:
  
  def __init__(self, model, dataset, batch_size):
    self.model = model
    self.dataset = dataset
    self.batch_size = batch_size
    
    self.mean_loss = 0
    self.run_epochs = 0
    
  def train_one_epoch(self):
    self.model.optimizer.zero_grad()
    epoch_loss = 0
    for i, batch in enumerate(DataLoader(self.dataset, batch_size=self.batch_size)):
      graphs, labels = batch
      pred = self.model(graphs)
      actual = labels.reshape((len(labels), 1))
      loss = self.model.loss(pred, actual)
      epoch_loss += loss.item()
      loss.backward()
      self.model.optimizer.step()
    self.mean_loss += epoch_loss
    print("loss: " + str(epoch_loss))
    return epoch_loss
    
  def run(self, num_epochs):
    epoch_loss = 0
    for _ in range(num_epochs):
      epoch_loss = self.train_one_epoch()
    wandb.log({"mse": epoch_loss})


In [None]:
from torch_geometric.loader import DataLoader
import math
import numpy as np
from tqdm import tqdm


model = AqSolModel(30, 128)

batch_size = min(len(train_dataset), 64)
num_epochs = 10
num_batches = math.ceil(len(train_dataset) / batch_size)
losses = np.zeros(num_epochs)
mean_loss = 0

wandb.init(
    # set the wandb project where this run will be logged
    project="SolubilityPredictor",
    
    # track hyperparameters and run metadata
    config={
      "learning_rate": 10**-2,
      "architecture": "ConvGNN",
      "epochs": num_epochs,
    }
)

model.train()

epoch_iter = range(num_epochs)

for epoch in epoch_iter:
  # print("Epoch " + str(epoch + 1) + " of " + str(num_epochs))
  model.optimizer.zero_grad()
  epoch_loss = 0
  for i, batch in enumerate(DataLoader(train_dataset, batch_size=batch_size)):
    graphs, labels = batch
    pred = model(graphs)
    actual = labels.reshape((len(labels), 1))
    # print(pred, actual)
    loss = model.loss(pred, actual)
    # print(loss.item())
    epoch_loss += loss.item()
    loss.backward()
    model.optimizer.step()
  losses[epoch] = epoch_loss / num_batches
  mean_loss += losses[epoch]
  print("loss: " + str(epoch_loss))
  wandb.log({"loss": epoch_loss})


wandb.finish()

In [8]:
from torch_geometric.loader import DataLoader


sweep_config = {
  "name": "sweep",
  "method": "bayes",
  "metric": {
    "goal": "minimize",
    "name": "mse"
  },
  "parameters": {
    "hidden_channels": {
      "min": 30,
      "max": 512
    },
    "num_epochs": {
      "min": 1,
      "max": 5
    },
    "batch_size": {
      "min": 1,
      "max": 128
    },
    "lr": {
      "min": 1e-6,
      "max": 1e-1
    },
    "weight_decay": {
      "min": float(0),
      "max": 1e-5
    },
    "dropout": {
      "min": 0,
      "max": 1
    }
  }
}


sweep_id = wandb.sweep(
  sweep_config, project="SolubilityPredictor"
)

def tune_hyperparameters(config=None):
  with wandb.init(config=config):
    config = wandb.config
    model = AqSolModel(
      30,
      config.hidden_channels,
      lr=config.lr,
      weight_decay=config.weight_decay,
      dropout=config.dropout
    )
    trainer = Trainer(model, train_dataset, config.batch_size)
    trainer.run(config.num_epochs)

wandb_config = {
  "architecture": "ConvGNN",
}


wandb.agent(
  sweep_id,
  function=tune_hyperparameters,
  project="SolubilityPredictor",
  count=10
)

Create sweep with ID: 5d4kipr4
Sweep URL: https://wandb.ai/alexb02/SolubilityPredictor/sweeps/5d4kipr4
<IPython.core.display.HTML object>
VBox(children=(Label(value='0.022 MB of 0.022 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max=1.0)))
<IPython.core.display.HTML object>
<IPython.core.display.HTML object>


[34m[1mwandb[0m: Agent Starting Run: oxo9th0s with config:
[34m[1mwandb[0m: 	batch_size: 113
[34m[1mwandb[0m: 	hidden_channels: 116
[34m[1mwandb[0m: 	lr: 0.0031676241864073535
[34m[1mwandb[0m: 	num_epochs: 6
[34m[1mwandb[0m: 	weight_decay: 1.6646060646142023e-06


loss: 27.930533476173878
loss: 11.686980627477169
loss: 6.548539321869612
loss: 3.1140586379915476
loss: 10.998607411980629
loss: 9.385794211179018


0,1
mse,▁

0,1
mse,9.38579


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 7fhkqvjp with config:
[34m[1mwandb[0m: 	batch_size: 95
[34m[1mwandb[0m: 	hidden_channels: 211
[34m[1mwandb[0m: 	lr: 0.040101871318269235
[34m[1mwandb[0m: 	num_epochs: 2
[34m[1mwandb[0m: 	weight_decay: 3.392825869283121e-06


loss: 2055631.0120208263
loss: 1510492.9855852127


0,1
mse,▁

0,1
mse,1510492.98559


[34m[1mwandb[0m: Agent Starting Run: m1uhdck4 with config:
[34m[1mwandb[0m: 	batch_size: 109
[34m[1mwandb[0m: 	hidden_channels: 162
[34m[1mwandb[0m: 	lr: 0.01053808620683595
[34m[1mwandb[0m: 	num_epochs: 6
[34m[1mwandb[0m: 	weight_decay: 1.8582413567087952e-06


loss: 450.2283548489213
loss: 1236.7111067809165
loss: 950.5597934592515
loss: 374.46377059072256
loss: 38.10427113994956
loss: 99.83652552217245


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
mse,▁

0,1
mse,99.83653


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: yngz29d8 with config:
[34m[1mwandb[0m: 	batch_size: 100
[34m[1mwandb[0m: 	hidden_channels: 178
[34m[1mwandb[0m: 	lr: 0.0010457027203548805
[34m[1mwandb[0m: 	num_epochs: 10
[34m[1mwandb[0m: 	weight_decay: 2.0832709551664566e-06


loss: 33.79737573862076
loss: 9.003723001107574
loss: 3.569805594161153
loss: 2.231181161478162
loss: 2.3150185886770487
loss: 2.6742548514157534
loss: 3.3498885985463858
loss: 4.4622168354690075
loss: 5.339343395084143
loss: 7.036055373027921


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
mse,▁

0,1
mse,7.03606


[34m[1mwandb[0m: Agent Starting Run: lcdqm3xj with config:
[34m[1mwandb[0m: 	batch_size: 125
[34m[1mwandb[0m: 	hidden_channels: 63
[34m[1mwandb[0m: 	lr: 1.2115364751017272e-05
[34m[1mwandb[0m: 	num_epochs: 10
[34m[1mwandb[0m: 	weight_decay: 2.8428852222208036e-06


loss: 20.281343162059784
loss: 16.07087917625904
loss: 14.63545024394989
loss: 14.07811926305294
loss: 13.094731986522675
loss: 12.773168861865997
loss: 11.086216732859612
loss: 10.825394049286842
loss: 11.985680758953094
loss: 11.055353865027428


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
mse,▁

0,1
mse,11.05535


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: bbj3i1mc with config:
[34m[1mwandb[0m: 	batch_size: 63
[34m[1mwandb[0m: 	hidden_channels: 80
[34m[1mwandb[0m: 	lr: 0.01157399225633972
[34m[1mwandb[0m: 	num_epochs: 10
[34m[1mwandb[0m: 	weight_decay: 7.521045339086619e-07


loss: 898.6232204698026
loss: 2675.329891268164
loss: 7615.614904217422
loss: 9908.509663179517
loss: 16220.835028681904
loss: 25024.04684876278
loss: 16401.842674925923
loss: 10212.786762665957
loss: 7128.391537282616
loss: 7244.327060490847


0,1
mse,▁

0,1
mse,7244.32706


[34m[1mwandb[0m: Agent Starting Run: pu6x36b1 with config:
[34m[1mwandb[0m: 	batch_size: 123
[34m[1mwandb[0m: 	hidden_channels: 238
[34m[1mwandb[0m: 	lr: 0.004974840365876775
[34m[1mwandb[0m: 	num_epochs: 10
[34m[1mwandb[0m: 	weight_decay: 5.931625080055226e-07


loss: 52.83966972120106
loss: 40.7907604817301
loss: 119.51049136556685
loss: 71.76808876916766
loss: 109.18145343102515
loss: 9.470042414963245
loss: 1.7924700882285833
loss: 1.7761276569217443
loss: 2.5602665543556213
loss: 5.010594744235277


0,1
mse,▁

0,1
mse,5.01059


[34m[1mwandb[0m: Agent Starting Run: fiuc45iy with config:
[34m[1mwandb[0m: 	batch_size: 119
[34m[1mwandb[0m: 	hidden_channels: 53
[34m[1mwandb[0m: 	lr: 0.030434141472876
[34m[1mwandb[0m: 	num_epochs: 9
[34m[1mwandb[0m: 	weight_decay: 9.580083820874832e-09


loss: 1199.912275614217
loss: 5069.117190737277
loss: 3925.954480431974
loss: 2369.7580522298813
loss: 2214.271826542914
loss: 134.56758574396372
loss: 252.44085431843996
loss: 375.89999897405505
loss: 107.74666135013103


VBox(children=(Label(value='0.001 MB of 0.023 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.034508…

0,1
mse,▁

0,1
mse,107.74666


[34m[1mwandb[0m: Agent Starting Run: ojhqyvz3 with config:
[34m[1mwandb[0m: 	batch_size: 109
[34m[1mwandb[0m: 	hidden_channels: 102
[34m[1mwandb[0m: 	lr: 0.008125221253178022
[34m[1mwandb[0m: 	num_epochs: 9
[34m[1mwandb[0m: 	weight_decay: 7.012894083377397e-07


loss: 27.765543213114142
loss: 486.75018855929375
loss: 344.2917249780148
loss: 255.42646943219006
loss: 162.3536927383393
loss: 20.022428223863244
loss: 76.62445028312504
loss: 36.92621922492981
loss: 34.292299365624785


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
mse,▁

0,1
mse,34.2923


[34m[1mwandb[0m: Agent Starting Run: 3cjdfpxd with config:
[34m[1mwandb[0m: 	batch_size: 19
[34m[1mwandb[0m: 	hidden_channels: 443
[34m[1mwandb[0m: 	lr: 0.00108399110375668
[34m[1mwandb[0m: 	num_epochs: 10
[34m[1mwandb[0m: 	weight_decay: 1.5145354410457838e-07


loss: 591.1104813748971
loss: 60.31645524781197
loss: 194.32123770192266
loss: 43.25601611658931
loss: 46.76056456193328
loss: 43.130606949329376
loss: 50.614115483127534
loss: 106.65383637323976
loss: 457.7980485474691
loss: 1234.213883771561


0,1
mse,▁

0,1
mse,1234.21388
