### Experiment Tracking with W&B

- config: store hp and metadata for each run
- wandb.init
- wandb.watch: log model gradients and params over time (helps detect bugs e.g. weird grad behaviour)
- wandb.log: log stuff we care about
- wandb.save: save online

use with block in context manager syntax

In [None]:
import wandb
wandb.login()

In [None]:
config = dict(
    epochs = 50,
    val_ratio = 0,
    test_ratio = 0.2
)

In [None]:
def make(base_path, val_ratio, test_ratio, encode_data_name, decode_data_name, latent_dim):
    # TODO: make edges to device here on when called on
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # dataset to encode
    encode_dataset = ReactionDataset(base_path, geo_file = encode_data_name, dataset_type= 'individual')
    encode_data = encode_dataset.data
    encode_data.train_mask = encode_data.val_mask = encode_data.test_mask = encode_data.y = None
    encode_data = train_test_split_edges(data = encode_data, val_ratio = val_ratio, test_ratio = test_ratio)
    encode_x = encode_data.x.to(device)
    encode_train_pos_edge_index = encode_data.train_pos_edge_index.to(device)

    # dataset to decode
    decode_dataset = ReactionDataset(base_path, geo_file = decode_data_name, dataset_type= 'individual')
    decode_data = decode_dataset.data
    decode_data.train_mask = decode_data.val_mask = decode_data.test_mask = decode_data.y = None
    decode_data = train_test_split_edges(data = decode_data, val_ratio = val_ratio, test_ratio = test_ratio)
    decode_x = decode_data.x.to(device)
    decode_train_pos_edge_index = decode_data.train_pos_edge_index.to(device)

    # model creation
    gae = GAE(MolEncoder(encode_data.num_node_features, latent_dim))
    opt = torch.optim.Adam(gae.parameters(), lr = 0.01)

    return gae, opt, encode_data, decode_data

In [None]:
def model_pipeline(hps):

    # start wandb
    with wandb.init(project="test", config=hps):
        
        # access hps through wandb.config so logging matches execution
        config = wandb.config

        # model data
        
        val_ratio = 0
        test_ratio = 0.2
        
        # make model, data, opt problem
        ts_r_gae, ts_r_opt, r_data, ts_data = make(r'data/', 0, 0.2, 'train_r', 'train_ts', 2)

### Testing GAEs

In [1]:
from ts_vae.gae import EGNN, EGNN_NEC, EGNN_AE
from ts_vae.layers import GCL_PYG
from ts_vae.data_processors.new_pyg_processor import ReactionDataset

import torch
import torch.nn as nn
from torch_geometric.data import DataLoader

import numpy as np

In [2]:
# remove processed files

import os
import glob

files = glob.glob(r'data/processed/*')
for f in files:
    os.remove(f)

In [3]:
rxns = ReactionDataset(r'data')

num_rxns = len(rxns)
train_ratio = 0.8
num_train = int(np.floor(train_ratio * num_rxns))

train_loader = DataLoader(rxns[: num_train], batch_size = 2, follow_batch = ['r', 'p'])
test_loader = DataLoader(rxns[num_train:], batch_size = 2, follow_batch = ['r', 'p'])

  0%|          | 0/6739 [00:00<?, ?it/s]

Processing...


  0%|          | 30/6739 [00:00<01:22, 81.14it/s]
  4%|▎         | 30/842 [00:00<00:02, 331.70it/s]
  0%|          | 30/6739 [00:00<00:09, 692.98it/s]
  4%|▎         | 30/842 [00:00<00:01, 416.85it/s]
  0%|          | 30/6739 [00:00<00:16, 417.30it/s]
  4%|▎         | 30/842 [00:00<00:01, 536.06it/s]


Done!


In [13]:
batch = next(iter(test_loader))
reactants = batch.r

Data(edge_attr=[30, 4], edge_index=[2, 30], idx=48, pos=[15, 3], x=[15, 11], z=[15])

In [92]:
# prop_indices = [mol['edge_index'] for mol in batch.r]


from torch_geometric.data.dataloader import DataListLoader

# batch.r

# dl = DataListLoader(rxns[0:10], batch_size = 2, follow_batch )

[Data(edge_attr=[30, 4], edge_index=[2, 30], idx=48, pos=[15, 3], x=[15, 11], z=[15]),
 Data(edge_attr=[26, 4], edge_index=[2, 26], idx=49, pos=[14, 3], x=[14, 11], z=[14])]

In [67]:
def identity_collate(data_list):
    return data_list

class CustomDataLoader(torch.utils.data.DataLoader):
    def __init__(self, dataset, batch_size = 1, shuffle = False, collate_fn = identity_collate, **kwargs):
        super(CustomDataLoader, self).__init__(dataset, batch_size, shuffle, collate_fn = identity_collate, **kwargs)
        # change to collate_fn = CustomCollater(follow_batch, exclude_keys)
    

class CustomCollater(object):
    def __init__(self, follow_batch, exclude_keys):
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys
    
    def collate(self, batch):
        elem 



In [90]:
rxns.data.r

[Data(edge_attr=[26, 4], edge_index=[2, 26], idx=0, pos=[15, 3], x=[15, 11], z=[15]),
 Data(edge_attr=[24, 4], edge_index=[2, 24], idx=1, pos=[13, 3], x=[13, 11], z=[13]),
 Data(edge_attr=[20, 4], edge_index=[2, 20], idx=2, pos=[10, 3], x=[10, 11], z=[10]),
 Data(edge_attr=[16, 4], edge_index=[2, 16], idx=3, pos=[9, 3], x=[9, 11], z=[9]),
 Data(edge_attr=[22, 4], edge_index=[2, 22], idx=4, pos=[11, 3], x=[11, 11], z=[11]),
 Data(edge_attr=[26, 4], edge_index=[2, 26], idx=5, pos=[14, 3], x=[14, 11], z=[14]),
 Data(edge_attr=[30, 4], edge_index=[2, 30], idx=6, pos=[15, 3], x=[15, 11], z=[15]),
 Data(edge_attr=[30, 4], edge_index=[2, 30], idx=7, pos=[15, 3], x=[15, 11], z=[15]),
 Data(edge_attr=[28, 4], edge_index=[2, 28], idx=8, pos=[15, 3], x=[15, 11], z=[15]),
 Data(edge_attr=[26, 4], edge_index=[2, 26], idx=9, pos=[14, 3], x=[14, 11], z=[14]),
 Data(edge_attr=[20, 4], edge_index=[2, 20], idx=10, pos=[11, 3], x=[11, 11], z=[11]),
 Data(edge_attr=[30, 4], edge_index=[2, 30], idx=11, pos

In [71]:
from torch_geometric.utils import to_dense_adj, to_dense_batch

# do for node_feats, edge_index, edge_attr, 

# to_dense_batch()

max_num_atoms = int(max(test_loader.dataset.data.num_atoms))



node_feats = torch.cat([r.x for r in batch.r])
batch_idx = batch.r_batch

print(type(node_feats), type(batch_idx))

to_dense_batch(node_feats, batch_idx)#, fill_value = 0, max_num_nodes = max_num_atoms)

<class 'torch.Tensor'> <class 'torch.Tensor'>


RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "C:\Users\Avish\miniconda3\envs\3d-rdkit\lib\site-packages\torch_scatter\scatter.py", line 31, in scatter_add
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    return scatter_sum(src, index, dim, out, dim_size)
           ~~~~~~~~~~~ <--- HERE
  File "C:\Users\Avish\miniconda3\envs\3d-rdkit\lib\site-packages\torch_scatter\scatter.py", line 12, in scatter_sum
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    index = broadcast(index, src, dim)
            ~~~~~~~~~ <--- HERE
    if out is None:
        size = list(src.size())
  File "C:\Users\Avish\miniconda3\envs\3d-rdkit\lib\site-packages\torch_scatter\utils.py", line 13, in broadcast
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    src = src.expand_as(other)
          ~~~~~~~~~~~~~ <--- HERE
    return src
RuntimeError: The expanded size of the tensor (29) must match the existing size (2) at non-singleton dimension 0.  Target sizes: [29].  Tensor sizes: [2]


In [40]:
# batch_mol_property(prop_indices)

## get sequence lengths
lengths = torch.tensor([ t.shape[0] for t in prop_indices ])
## padd
batch = torch.nn.utils.rnn.pad_sequence(prop_indices)
## compute mask
mask = (batch != 0).to(device)
batch, lengths, mask

RuntimeError: The expanded size of the tensor (30) must match the existing size (26) at non-singleton dimension 1.  Target sizes: [2, 30].  Tensor sizes: [2, 26]

In [17]:
for key in batch.r[0].keys:
    for mol in batch.r:
        print(mol[key])

tensor([[0., 0., 0., 1., 0., 8., 0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 2.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 2.],
        [0., 0., 1., 0., 0., 7., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 2.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 2.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])
tensor([[0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 3.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 2.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0.

In [19]:
from torch.nn.utils.rnn import pad_sequence

def batch_mol_property(prop):
    return pad_sequence(prop, batch_first = True, padding_value = 0)

new_batch = {data_prop: batch_mol_property([mol[data_prop] for mol in batch.r]) for data_prop in batch.r[0].keys}

RuntimeError: The expanded size of the tensor (30) must match the existing size (26) at non-singleton dimension 1.  Target sizes: [2, 30].  Tensor sizes: [2, 26]

In [None]:
def collate_fn(rxn_batch):
    rxn_batch = 

In [8]:
from torch.utils.data import DataLoader

std_test_loader = DataLoader(rxns[0:10], batch_size = 2, collate_fn = collate_fn)


OtherReactionTriple(edge_attr_p=[28, 4], edge_attr_r=[26, 4], edge_attr_ts=[28, 4], edge_index_p=[2, 28], edge_index_r=[2, 26], edge_index_ts=[2, 28], idx=[1], num_atoms=[1], p=Data(edge_attr=[28, 4], edge_index=[2, 28], idx=0, pos=[15, 3], x=[15, 11], z=[15]), pos_p=[15, 3], pos_r=[15, 3], pos_ts=[15, 3], r=Data(edge_attr=[26, 4], edge_index=[2, 26], idx=0, pos=[15, 3], x=[15, 11], z=[15]), ts=Data(edge_attr=[28, 4], edge_index=[2, 28], idx=0, pos=[15, 3], x=[15, 11], z=[15]), x_p=[15, 11], x_r=[15, 11], x_ts=[15, 11])

In [4]:
# TODO: make batching proper

batch = next(iter(test_loader))
print(batch.r_batch)

TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'

In [None]:
def collate_fn(rxn_batch):
    # takes in batch as list of datapoints RT(r, ts, p) and returns collated batch as dict of pytorch tensors

    
    
    
    reactants = rxn_batch.r
    #products = rxn_batch.p

    # max num edges and nodes
    max_num_edges = max([r.edge_attr.size(0) for r in reactants])
    max_num_nodes = max([r.z.size(0) for r in reactants])

    rxn_batch['atom_mask']

    pass

In [11]:
from torch_geometric.data import Data

class PairData(Data):
    def __init__(self, edge_index_s, x_s, edge_index_t, x_t):
        super(PairData, self).__init__()
        self.edge_index_s = edge_index_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.x_t = x_t

    def __inc__(self, key, value):
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        else:
            return super().__inc__(key, value)


In [17]:
edge_index_s = torch.tensor([
    [0, 0, 0, 0],
    [1, 2, 3, 4],
])
x_s = torch.randn(5, 16)  # 5 nodes.
edge_index_t = torch.tensor([
    [0, 0, 0],
    [1, 2, 3],
])
x_t = torch.randn(4, 16)  # 4 nodes.

data = PairData(edge_index_s, x_s, edge_index_t, x_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2, follow_batch = ['x_s'])
batch = next(iter(loader))

print(batch)
print(batch.edge_index_s)

Batch(edge_index_s=[2, 8], edge_index_t=[2, 6], x_s=[10, 16], x_s_batch=[10], x_t=[8, 16])
tensor([[0, 0, 0, 0, 5, 5, 5, 5],
        [1, 2, 3, 4, 6, 7, 8, 9]])


tensor([0, 1])


In [3]:
for key in test_loader.dataset.r_data.keys:
    print(key, ": ", test_loader.dataset.r_data[key].shape)

print(test_loader.dataset.r_data['x'].size(0), test_loader.dataset.r_data['x'].size(1))

x :  torch.Size([789, 11])
edge_index :  torch.Size([2, 1550])
edge_attr :  torch.Size([1550, 4])
pos :  torch.Size([789, 3])
z :  torch.Size([789])
idx :  torch.Size([60])
789 11


In [5]:
in_nf = test_loader.dataset.r_data['x'].size(1) # = out_nf?
h_nf = 5
emb_nf = 2 

egnn_ae = EGNN_AE(h_nf = h_nf, emb_nf = emb_nf, num_node_fs = in_nf)
opt = torch.optim.Adam(egnn_ae.parameters(), lr = 1e-3)

In [None]:
def train_egnn_ae(gae, opt):
    
    # lr_scheduler.step()

    # simple results dict for now
    res = {'loss': 0, 'counter': 0, 'loss_arr': []}

    for i, rxn_batch in enumerate(test_loader):
        
        gae.train()
        opt.zero_grad()
    
        # from batch get data info: node_feats, edge_index, edge_attr
        # pass into model and get out ... adj?

        # calc loss
        loss = 0 # = bce(adj_pred, adj_gt)
        loss.backward() 
        opt.step()

        # add results to experiment dataclass: loss, epoch,  batch_size
        # temp: use dict
        res['loss'] += loss.item() * batch
    
    # return final loss i.e. return experimentlog final loss
    pass


In [None]:
def train_gae(gae, opt, x, train_pos_edge_index):
    gae.train()
    opt.zero_grad()
    print("train x shape: ", x.shape)
    z = gae.encode(x, train_pos_edge_index)
    print("train z shape: ", z.shape)
    loss = gae.recon_loss(z, train_pos_edge_index)
    loss.backward()
    opt.step()
    return float(loss)

def test_gae(gae, x, train_pos_edge_index, test_pos_edge_index, test_neg_edge_index):
    gae.eval()
    with torch.no_grad():
        z = gae.encode(x, train_pos_edge_index)
    return gae.test(z, test_pos_edge_index, test_neg_edge_index)

def new_test_gae(gae, x, edge_index):
    # this just does recon loss again
    gae.eval()
    with torch.no_grad():
        print("test x shape: ", x.shape)
        z = gae.encode(x, edge_index)
        print("test z shape: ", z.shape)
    return gae.recon_loss(z, edge_index)

r_ae.reset_parameters()

epochs = 10
for epoch in range(1, epochs + 1):

    # value = (z[edge_index[0]] * z[edge_index[1]]).sum(dim = 1)
    loss_train = train_gae(r_ae, r_opt, r_x, r_data.edge_index)
    print("===== Training complete with loss: {:.4f}, now testing ====".format(loss_train))
    loss_test = new_test_gae(r_ae, test_x, test_data.edge_index)
    if epoch % 1 == 0:
        print('===== Epoch: {:03d}, Loss: {:.4f} ===== \n'.format(epoch, loss_test))

Batch(num_atoms=[3], p=[3], p_batch=[3], r=[3], r_batch=[3], ts=[3])
Batch(num_atoms=[3], p=[3], p_batch=[3], r=[3], r_batch=[3], ts=[3])


In [None]:
# test_loader.dataset.data.r

baseline_losses = []
epoch_losses = []


# for ep in range(2):
for batch_id, rxn_batch in enumerate(test_loader):
    reactants = rxn_batch.r
    


    