# Manual data processing if issues

In [11]:
from numpy.core.fromnumeric import product
from scipy.sparse import data
import torch
import torch.nn.functional as F
from torch_scatter import scatter
from torch_geometric.data import InMemoryDataset, DataLoader # , Data
from torch_geometric.data.data import Data
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from tqdm import tqdm

def process_geometry_file(geometry_file, list = None):
    """ Code mostly lifted from QM9 dataset creation https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/qm9.html 
        Transforms molecules to their atom features and adjacency lists.
    """
    types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
    bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
    limit = 100

    data_list = list if list else []
    full_path = r'data' + geometry_file
    geometries = Chem.SDMolSupplier(full_path, removeHs=False, sanitize=False)

    # get atom and edge features for each geometry
    for i, mol in enumerate(tqdm(geometries)):

        # temp soln cos of split edge memory issues
        if i == limit:
            break
        
        N = mol.GetNumAtoms()
        # get atom positions as matrix w shape [num_nodes, num_dimensions] = [num_atoms, 3]
        atom_data = geometries.GetItemText(i).split('\n')[4:4 + N] 
        atom_positions = [[float(x) for x in line.split()[:3]] for line in atom_data]
        atom_positions = torch.tensor(atom_positions, dtype=torch.float)
        # all the features
        type_idx = []
        atomic_number = []
        aromatic = []
        sp = []
        sp2 = []
        sp3 = []
        num_hs = []

        # atom/node features
        for atom in mol.GetAtoms():
            type_idx.append(types[atom.GetSymbol()])
            atomic_number.append(atom.GetAtomicNum())
            aromatic.append(1 if atom.GetIsAromatic() else 0)
            hybridisation = atom.GetHybridization()
            sp.append(1 if hybridisation == HybridizationType.SP else 0)
            sp2.append(1 if hybridisation == HybridizationType.SP2 else 0)
            sp3.append(1 if hybridisation == HybridizationType.SP3 else 0)
            # !!! should do the features that lucky does: whether bonded, 3d_rbf

        # bond/edge features
        row, col, edge_type = [], [], []
        for bond in mol.GetBonds(): 
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            # edge type for each bond type; *2 because both ways
            edge_type += 2 * [bonds[bond.GetBondType()]]
        # edge_index is graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_type = torch.tensor(edge_type, dtype=torch.long)
        # edge_attr is edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float) 

        # order edges based on combined ascending order
        perm = (edge_index[0] * N + edge_index[1]).argsort() # TODO
        edge_index = edge_index[:, perm]
        edge_type = edge_type[perm]
        edge_attr = edge_attr[perm]

        row, col = edge_index
        z = torch.tensor(atomic_number, dtype=torch.long)
        hs = (z == 1).to(torch.float) # hydrogens
        num_hs = scatter(hs[row], col, dim_size=N).tolist() # scatter helps with one-hot
        
        x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(types))
        x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs], dtype=torch.float).t().contiguous()
        x = torch.cat([x1.to(torch.float), x2], dim=-1)

        data = Data(x=x, z=z, pos=atom_positions, edge_index=edge_index, edge_attr=edge_attr, idx=i)
        
        data_list.append(data)

    return data_list

In [12]:
# concat train r and test r
reactants = []
reactants = process_geometry_file('/raw/train_reactants.sdf', reactants)
reactants = process_geometry_file('/raw/test_reactants.sdf', reactants)

# concat train ts and test ts
ts = []
ts = process_geometry_file('/raw/train_ts.sdf', ts)
ts = process_geometry_file('/raw/test_ts.sdf', ts) 

# concat train p and test p
products = []
products = process_geometry_file('/raw/train_products.sdf', products)
products = process_geometry_file('/raw/test_products.sdf', products) 

assert len(reactants) == len(ts) == len(products)

print(type(reactants[0]), type(ts[0]), type(products[0]))

  1%|▏         | 100/6739 [00:00<00:23, 286.92it/s]
 12%|█▏        | 100/842 [00:00<00:02, 268.14it/s]
  1%|▏         | 100/6739 [00:00<00:16, 414.71it/s]
 12%|█▏        | 100/842 [00:00<00:01, 404.30it/s]
  1%|▏         | 100/6739 [00:00<00:12, 548.55it/s]
 12%|█▏        | 100/842 [00:00<00:02, 314.45it/s]

<class 'torch_geometric.data.data.Data'> <class 'torch_geometric.data.data.Data'> <class 'torch_geometric.data.data.Data'>





In [57]:
class ReactionTriple(Data):
    def __init__(self, r = None, ts = None, p = None):
        super(ReactionTriple, self).__init__()
        self.r = r
        self.ts = ts
        self.p = p

    def __inc__(self, key, value):
        if key == 'r':
            return self.r.edge_index.size(0)
        elif key == 'ts':
            return self.ts.edge_index.size(0)
        elif key == 'p':
            return self.p.edge_index.size(0)
        else:
            return super().__inc__(key, value)

In [58]:
rxns = []
for rxn_id in range(len(reactants)):
    rxn = ReactionTriple(reactants[rxn_id], ts[rxn_id], products[rxn_id])
    rxns.append(rxn)

# Normal data processing from files

In [111]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GAE
from torch_geometric.data import DataLoader
from ts_vae.data_processors.new_pyg_processor import ReactionDataset
from ts_vae.gae import GAE, MolEncoder, InnerProductDecoder
import numpy as np

In [112]:
rxns = ReactionDataset(r'data')

num_rxns = len(rxns)
train_ratio = 0.8
num_train = int(np.floor(train_ratio * num_rxns))

train_loader = DataLoader(rxns[: num_train], batch_size = 3, follow_batch = ['r', 'p'])
test_loader = DataLoader(rxns[num_train:], batch_size = 3, follow_batch = ['r', 'p'])

# batch = next(iter(train_loader))
# batch.p
# batch.r

In [51]:
def edge2adj(z, edge_index, sigmoid = True):
    value = (z[edge_index[0]] * z[edge_index[1]]).sum(dim = 1)
    return torch.sigmoid(value) if sigmoid else value

In [79]:
from torch_geometric.utils import to_dense_adj

node_fs = mol_graph.x
edge_index = mol_graph.edge_index
edge_attr = mol_graph.edge_attr
n_nodes = len(mol_graph.z)
latent_dim = 3

max_num_nodes = 21

adj = to_dense_adj(edge_index, edge_attr = edge_attr)

## Reconstruction loss on adj

In [None]:
def train_gae(gae, opt, x, train_pos_edge_index):
    gae.train()
    opt.zero_grad()
    print("train x shape: ", x.shape)
    z = gae.encode(x, train_pos_edge_index)
    print("train z shape: ", z.shape)
    loss = gae.recon_loss(z, train_pos_edge_index)
    loss.backward()
    opt.step()
    return float(loss)

def test_gae(gae, x, train_pos_edge_index, test_pos_edge_index, test_neg_edge_index):
    gae.eval()
    with torch.no_grad():
        z = gae.encode(x, train_pos_edge_index)
    return gae.test(z, test_pos_edge_index, test_neg_edge_index)

def new_test_gae(gae, x, edge_index):
    # this just does recon loss again
    gae.eval()
    with torch.no_grad():
        print("test x shape: ", x.shape)
        z = gae.encode(x, edge_index)
        print("test z shape: ", z.shape)
    return gae.recon_loss(z, edge_index)

In [106]:
from ts_vae.data_processors.grambow_processor import ReactionDataset

# model data
base_path = r'data/'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# reactant train data
r_dataset = ReactionDataset(base_path, geo_file = 'train_r', dataset_type= 'individual') 
r_data = r_dataset.data
r_x = r_data.x.to(device)

# reactant test data
test_dataset = ReactionDataset(base_path, geo_file = 'test_r', dataset_type= 'individual') 
test_data = test_dataset.data
test_x = test_data.x.to(device)

# reactant encoder
r_num_node_fs = r_data.num_node_features
r_latent_dim = 2
r_ae = GAE(MolEncoder(r_num_node_fs, r_latent_dim))
r_opt = torch.optim.Adam(r_ae.parameters(), lr = 0.01)

In [107]:
r_ae.reset_parameters()

epochs = 10
for epoch in range(1, epochs + 1):

    # value = (z[edge_index[0]] * z[edge_index[1]]).sum(dim = 1)
    loss_train = train_gae(r_ae, r_opt, r_x, r_data.edge_index)
    print("===== Training complete with loss: {:.4f}, now testing ====".format(loss_train))
    loss_test = new_test_gae(r_ae, test_x, test_data.edge_index)
    if epoch % 1 == 0:
        print('===== Epoch: {:03d}, Loss: {:.4f} ===== \n'.format(epoch, loss_test))

train x shape:  torch.Size([395, 11])
train z shape:  torch.Size([395, 2])
===== Training complete with loss: 3.6079, now testing ====
test x shape:  torch.Size([394, 11])
test z shape:  torch.Size([394, 2])
===== Epoch: 001, Loss: 3.4965 ===== 

train x shape:  torch.Size([395, 11])
train z shape:  torch.Size([395, 2])
===== Training complete with loss: 3.4361, now testing ====
test x shape:  torch.Size([394, 11])
test z shape:  torch.Size([394, 2])
===== Epoch: 002, Loss: 3.1556 ===== 

train x shape:  torch.Size([395, 11])
train z shape:  torch.Size([395, 2])
===== Training complete with loss: 3.1816, now testing ====
test x shape:  torch.Size([394, 11])
test z shape:  torch.Size([394, 2])
===== Epoch: 003, Loss: 2.9096 ===== 

train x shape:  torch.Size([395, 11])
train z shape:  torch.Size([395, 2])
===== Training complete with loss: 2.9659, now testing ====
test x shape:  torch.Size([394, 11])
test z shape:  torch.Size([394, 2])
===== Epoch: 004, Loss: 2.7989 ===== 

train x shap

In [None]:
class InnerProductDecoder(nn.Module):
    def forward(self, z, edge_index, sigmoid = True):
        value = (z[edge_index[0]] * z[edge_index[1]]).sum(dim = 1)
        return torch.sigmoid(value) if sigmoid else value
    
    def forward_all(self, z, sigmoid = True):
        """ Decode latent variables into probabilistic adj matrix. """
        adj = torch.matmul(z, z.t())
        return torch.sigmoid(adj) if sigmoid else adj

In [92]:
# their model 
# so they take their nodes, edges, edge_attr and actual adj
# adj_pred, z = model(nodes, edges, edge_attr)
# bce, kl = loss(adj_pred, adj_gt)


26

In [140]:
i = torch.cat([edge_index[0].unsqueeze(0), edge_index[1].unsqueeze(0)])
v = torch.ones(i.size(1))

print("edge index: ", edge_index[0].shape, "|| edge index unsqueeze: ", edge_index[0].unsqueeze(0).shape)
print("i shape: ", i.shape, "|| v shape: ", v.shape, "|| num nodes: ", n_nodes)

#i = torch.cat([edges[0].unsqueeze(0), edges[1].unsqueeze(0)])
#v = torch.ones(i.size(1))
#adj_dense = torch.sparse.FloatTensor(i, v, torch.Size([n_nodes, n_nodes])).to_dense()
# i.size(1)

torch.sparse.FloatTensor(i, v, torch.Size([n_nodes, n_nodes]))

edge index:  torch.Size([26]) || edge index unsqueeze:  torch.Size([1, 26])
i shape:  torch.Size([2, 26]) || v shape:  torch.Size([26]) || num nodes:  13


tensor(indices=tensor([[ 0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3,
                         4,  4,  4,  4,  5,  6,  7,  8,  9, 10, 11, 12],
                       [ 1,  5,  0,  2,  4,  6,  1,  3,  7,  8,  2,  4,  9, 10,
                         1,  3, 11, 12,  0,  1,  2,  2,  3,  3,  4,  4]]),
       values=tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
       size=(13, 13), nnz=26, layout=torch.sparse_coo)

In [153]:
from torch_geometric.utils import to_dense_adj

node_fs = mol_graph.x
edge_index = mol_graph.edge_index
edge_attr = mol_graph.edge_attr
num_nodes = len(mol_graph.z)
latent_dim = 3
max_num_nodes = 21

def sparse_to_dense_adj(num_nodes, edge_index):
    # edge_index is sparse_adj matrix (given in coo format for graph connectivity)
    sparse_adj = torch.cat([edge_index[0].unsqueeze(0), edge_index[1].unsqueeze(0)])
    # the values we put in at each tuple; that's why length of sparse_adj
    ones = torch.ones(sparse_adj.size(1)) 
    # FloatTensor() creates sparse coo tensor in torch format, then to_dense()
    dense_adj = torch.sparse.FloatTensor(sparse_adj, ones, torch.Size([num_nodes, num_nodes])).to_dense() # to_dense adds the zeroes needed
    return dense_adj



adj_egnn = sparse_to_dense_adj(num_nodes, edge_index)
# with edge_attr, we get a [1, num_nodes, num_nodes] for each edge_type
adj_pyg = to_dense_adj(edge_index, edge_attr = edge_attr, max_num_nodes = num_nodes)

# get_dense_graph(): returns self.nodes, self.edges_dense, self.edge_attr_dense, self.adj
# adj = sparse2dense(n_nodes, self.edges); adjust for loops

# compare sparse2dense (egnn) vs to_dense_adj (pyg)

# gcn = GCNConv(num_nodes, latent_dim)
# z = gcn(node_fs, edge_index)


In [151]:
adj_egnn.shape

torch.Size([13, 13])

In [154]:
(adj_pyg == adj_egnn).all()

RuntimeError: The size of tensor a (4) must match the size of tensor b (13) at non-singleton dimension 3

In [160]:
adj_egnn

tensor([[0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [170]:
t = torch.rand(2, 5)
# b = t.view(2, 8)
t

tensor([[0.0877, 0.9180, 0.5504, 0.4741, 0.6862],
        [0.6980, 0.5273, 0.1610, 0.9068, 0.8301]])

In [172]:
t.view(1, 10)

tensor([[0.0877, 0.9180, 0.5504, 0.4741, 0.6862, 0.6980, 0.5273, 0.1610, 0.9068,
         0.8301]])

In [161]:
# adj_pred = adj_pred * (1 - torch.eye(num_nodes).to(self.device)) # removes self_loops

# (1 - torch.eye(num_nodes)) gives [num_nodes, num_nodes] with all 1s except 0 on diag
# * is hadamard product
adj_egnn * (1 - torch.eye(13))

tensor([[0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [173]:
h = 0
print(f'h = {h}')

h = 0


In [109]:
for batch_id, rxn_batch in enumerate(test_loader):
    
    reactants = rxn_batch.r
    print(reactants)

[Data(edge_attr=[30, 4], edge_index=[2, 30], idx=48, pos=[15, 3], x=[15, 11], z=[15]), Data(edge_attr=[26, 4], edge_index=[2, 26], idx=49, pos=[14, 3], x=[14, 11], z=[14])]
[Data(edge_attr=[30, 4], edge_index=[2, 30], idx=50, pos=[14, 3], x=[14, 11], z=[14]), Data(edge_attr=[38, 4], edge_index=[2, 38], idx=51, pos=[17, 3], x=[17, 11], z=[17])]
[Data(edge_attr=[32, 4], edge_index=[2, 32], idx=52, pos=[15, 3], x=[15, 11], z=[15]), Data(edge_attr=[34, 4], edge_index=[2, 34], idx=53, pos=[17, 3], x=[17, 11], z=[17])]
[Data(edge_attr=[30, 4], edge_index=[2, 30], idx=54, pos=[13, 3], x=[13, 11], z=[13]), Data(edge_attr=[36, 4], edge_index=[2, 36], idx=55, pos=[17, 3], x=[17, 11], z=[17])]
[Data(edge_attr=[20, 4], edge_index=[2, 20], idx=56, pos=[10, 3], x=[10, 11], z=[10]), Data(edge_attr=[28, 4], edge_index=[2, 28], idx=57, pos=[15, 3], x=[15, 11], z=[15])]
[Data(edge_attr=[26, 4], edge_index=[2, 26], idx=58, pos=[13, 3], x=[13, 11], z=[13]), Data(edge_attr=[30, 4], edge_index=[2, 30], idx=

In [39]:
def train(epoch, loader):

    model.train()

    train_dict = {'epoch': epoch, 'loss': 0, 'bce': 0, 'adj_err': 0, 'coord_reg': 0}

    # want to create b
    for batch_id, rxn_batch in enumerate(loader):


Data(edge_attr=[28, 4], edge_index=[2, 28], idx=0, pos=[15, 3], x=[15, 11], z=[15])

In [None]:
# coords always same, maybe node and edge features too? need to pad adj matrix

# dataset dims
elements = "HCNO"
num_elements = len(elements)
max_n_atoms = max([r.GetNumAtoms() for r,ts,p in data])
num_coords = 3
num_bond_fs

# want to pad exist features

def prepare_batch(batch_mols):

    # initialise batch
    batch_size = len(batch_mols)
    atom_fs = torch.zeros((batch_size, max_n_atoms, num_elements + 1), dtype = torch.float32) # num_atoms, max_num_atoms, 
    bond_fs = torch.zeros((batch_size, max_n_atoms, max_n_atoms, num_bond_fs), dtype = torch.float32)
    sizes = torch.zeros(batch_size, dtype = torch.float32)
    coords = torch.zeros((batch_size, max_size, num_coords), dtype = torch.float32)
    
    pass

def pad_sequence(sequences: List[torch.Tensor], max_length: int, padding_value=0) -> torch.Tensor:
    # assuming trailing dimensions and type of all the Tensors
    # in sequences are same and fetching those from sequences[0]
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    out_dims = (len(sequences), max_length) + trailing_dims

    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)  # type: ignore
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        # use index notation to prevent duplicate references to the tensor
        out_tensor[i, :length, ...] = tensor

    return out_tensor

# Training

In [13]:
# have to convert the training scheme: no more edge sampling and now use batches
# TODO: create data, opt, model; data and model to device

# simple R->R GAE, then build up
def train_gae(gae, opt, loader):
    # singular batch train loop

    model.train() # set flags

    batch_loss = 0

    # one iteration over different batches
    for i, rxn_batch in enumerate(loader):
        
        reactants = rxn_batch.r
        # pad mols for batch/maybe just pad all with max_num_atoms
        
        # zero gradients
        opt.zero_grad() 

        # encode reactant batch and calculate loss
        z_r = gae.encode(reactants)
        loss = gae.recon_loss(z_r)
        
        # modify gradients
        loss.backward()
        opt.step()
        
        # add batch loss
        batch_loss += loss.item()

        print("Loss: {:.3f}".format(loss.item() / len(reactants)))

    agg['Train_Loss'].append(batch_loss / len(loader.dataset))
    print('===> Epoch: {:03d}, Train Loss: {:.4f}'.format(epoch, agg['Train_Loss'][-1]))

In [None]:

# my train has automatic train-test edge split but this is what i was going for before [mmvae]
def train(epoch, agg):
    model.train()
    b_loss = 0
    for i, dataT in enumerate(train_loader):
        data = unpack_data(dataT, device=device)
        optimizer.zero_grad()
        loss = - objective(model, data, K=args.K)
        loss.backward()
        optimizer.step()
        b_loss += loss.item()
        if args.print_freq > 0 and i % args.print_freq == 0:
            print("iteration {:04d}: loss: {:6.3f}".format(i, loss.item() / args.batch_size))
    agg['train_loss'].append(b_loss / len(train_loader.dataset))
    print('====> Epoch: {:03d} Train loss: {:.4f}'.format(epoch, agg['train_loss'][-1]))

In [None]:
def train_gae(gae, opt, x, train_pos_edge_index):
    gae.train()
    opt.zero_grad()
    z = gae.encode(x, train_pos_edge_index)
    loss = gae.recon_loss(z, train_pos_edge_index)
    loss.backward()
    opt.step()
    return float(loss)

def test_gae(gae, x, train_pos_edge_index, test_pos_edge_index, test_neg_edge_index):
    gae.eval()
    with torch.no_grad():
        z = gae.encode(x, train_pos_edge_index)
    return gae.test(z, test_pos_edge_index, test_neg_edge_index)

r_ae.reset_parameters()

epochs = 100
for epoch in range(1, epochs + 1):
    loss = train_gae(r_ae, r_opt, r_x, r_data.train_pos_edge_index)
    auc, ap = test_gae(r_ae, r_x, r_data.train_pos_edge_index, r_data.test_pos_edge_index, r_data.test_neg_edge_index)
    if epoch % 10 == 0:
        print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

In [None]:
# my train has automatic train-test edge split but this is what i was going for before [mmvae]
def train(epoch, agg):
    model.train()
    b_loss = 0
    for i, dataT in enumerate(train_loader):
        data = unpack_data(dataT, device=device)
        optimizer.zero_grad()
        loss = - objective(model, data, K=args.K)
        loss.backward()
        optimizer.step()
        b_loss += loss.item()
        if args.print_freq > 0 and i % args.print_freq == 0:
            print("iteration {:04d}: loss: {:6.3f}".format(i, loss.item() / args.batch_size))
    agg['train_loss'].append(b_loss / len(train_loader.dataset))
    print('====> Epoch: {:03d} Train loss: {:.4f}'.format(epoch, agg['train_loss'][-1]))


# Loop over epochs
for epoch in range(max_epochs):
    # Training
    for batch, labels in loader:
        # Transfer to GPU if available
        batch, labels = batch.to(device), labels.to(device)
        # Model computations
        [...]