In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch_geometric as tg
from torch_geometric.loader import DataLoader
from pymongo import MongoClient
from chespex.encoding import Encoder, Decoder, Autoencoder, graph_loss, prediction_to_molecules

In [8]:
# ---------------------------------------------------------------------------
# Hyperparameters selected by `level` (index into each per-level tuple).
# NOTE(review): n_bead_classes and n_bead_types only list two levels, while
# every other entry lists three -- level = 2 would raise IndexError here.
# ---------------------------------------------------------------------------
level = 0
n_bead_classes = (5, 13)[level]
n_bead_types = (15, 45)[level]
latent_dim = (4, 5, 5)[level]
number_of_molecules = (89960, 6742680, 136870880)[level]
embedding_dim = (128, 256, 512)[level]
encoder_dim = (128, 64, 64)[level]
node_hidden_dims = (
    [1024, 1024],
    [1024, 1024, 1024],
    [2048, 1024, 1024],
)[level]
encoded_node_dims = (
    [128, 32],
    [256, 32],
    [512, 32],
)[level]
log_steps = (10, 100, 1000)[level]
batches_per_query = (3, 20, 20)[level]

# ---------------------------------------------------------------------------
# Parameters shared by all levels.
# ---------------------------------------------------------------------------
n_beads = 4
n_bead_sizes = 3
# Width of one encoded node feature vector, summed over its four parts.
feature_dim = (
    n_bead_classes  # bead class
    + n_bead_sizes  # bead size
    + 2             # charge
    + 1             # octanol-water transfer free energy
)
print(f'Feature dimension: {feature_dim}')

Feature dimension: 11


In [9]:
def as_torch_geometric_data(db_molecule):
    """Convert one molecule document into a ``torch_geometric`` ``Data`` graph.

    Args:
        db_molecule: Mapping that provides ``"node_features"`` (nested numeric
            sequence) and ``"edge_index"`` (nested sequence of node indices).

    Returns:
        ``tg.data.Data`` whose ``x`` is a float32 tensor of the node features
        and whose ``edge_index`` is an int64 tensor.
    """
    node_features = torch.tensor(db_molecule["node_features"], dtype=torch.float32)
    # Build the indices as int64 directly: going through ``torch.Tensor``
    # (float32) first silently rounds integers above 2**24.
    edge_index = torch.tensor(db_molecule["edge_index"], dtype=torch.int64)
    return tg.data.Data(x=node_features, edge_index=edge_index)

def transform(batch):
    """Expand the raw node columns of ``batch`` into the model's feature layout.

    The raw ``batch.x`` columns are read as: 0 = bead size index, 1 = bead
    class index, 2 = charge index, 3 and beyond = octanol-water transfer free
    energy. The result is the concatenation
    ``[class one-hot | size one-hot | charge (2 columns) | free energy / 10]``.

    Args:
        batch: Object with a 2-D tensor attribute ``x`` and a ``cuda()`` method.

    Returns:
        Whatever ``batch.cuda()`` returns, after ``batch.x`` has been replaced.

    Note:
        ``batch.x`` is overwritten in place, so calling this twice on the same
        batch would re-interpret already-encoded columns. A CUDA device is
        required because of the final ``batch.cuda()`` call.
    """
    raw = batch.x
    node_type = nn.functional.one_hot(raw[:, 1].to(torch.int64), n_bead_classes)
    # Use the shared constant rather than a literal 3 so the width always
    # agrees with the ``feature_dim`` computed from ``n_bead_sizes``.
    node_size = nn.functional.one_hot(raw[:, 0].to(torch.int64), n_bead_sizes)
    # One-hot over three charge indices, keeping only columns 0 and 2:
    # index 1 therefore encodes as an all-zero pair.
    node_charge = nn.functional.one_hot(raw[:, 2].to(torch.int64), 3)[:, ::2]
    node_oco_w_tfe = raw[:, 3:] / 10  # Octanol-water transfer free energy, scaled down
    batch.x = torch.cat([node_type, node_size, node_charge, node_oco_w_tfe], dim=1)
    return batch.cuda()

class MoleculeDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that streams randomly sampled molecules from MongoDB.

    Each iteration opens its own client to ``mongodb://localhost:27017``, reads
    the ``level-{level}`` collection of the ``molecules-4`` database and issues
    ``int(number_of_molecules / BATCH_SIZE / batches_per_query)`` ``$sample``
    queries of ``BATCH_SIZE * batches_per_query`` documents each. Every
    document is converted with ``as_torch_geometric_data`` before being
    yielded.

    NOTE(review): when wrapped in a ``DataLoader`` with ``num_workers > 0``,
    every worker process iterates its own copy of this dataset, so one pass
    over the loader yields ``num_workers`` times the amount computed above.
    Confirm this is intended before relying on "epoch" counts.
    """
    BATCH_SIZE = 8192*2

    def __iter__(self):
        client = MongoClient("mongodb://localhost:27017")
        try:
            database = client.get_database("molecules-4")
            collection = database.get_collection(f"level-{level}")
            n_queries = int(number_of_molecules / MoleculeDataset.BATCH_SIZE / batches_per_query)
            sample_stage = {"$sample": {"size": MoleculeDataset.BATCH_SIZE * batches_per_query}}
            for _ in range(n_queries):
                for db_molecule in collection.aggregate([sample_stage]):
                    yield as_torch_geometric_data(db_molecule)
        finally:
            # Runs on normal exhaustion, on errors, and when the consumer stops
            # early (generator close), so the connection is never leaked.
            client.close()

In [10]:
# Graph encoder: node features -> latent vector of size `latent_dim`.
encoder = Encoder(
    node_feature_dim=feature_dim,
    embedding_dim=embedding_dim,
    encoder_dim=encoder_dim,
    message_passing_steps=3,
    latent_dim=latent_dim,
)

# Decoder: latent vector -> node features and adjacency for `n_beads` nodes.
decoder = Decoder(
    n_nodes=n_beads,
    node_feature_dim=feature_dim,
    latent_dim=latent_dim,
    node_hidden_dims=node_hidden_dims,
    encoded_node_dims=encoded_node_dims,
    edge_prelayer_dims=[512, 512],
    edge_hidden_dims=[512, 512],
)

# The combined model lives on the GPU.
model = Autoencoder(encoder, decoder).to('cuda')
params_count = sum(
    parameter.numel()
    for parameter in model.parameters()
    if parameter.requires_grad
)
print(f"Model has {params_count} trainable parameters")

dataset = MoleculeDataset()
dataloader = DataLoader(
    dataset,
    batch_size=MoleculeDataset.BATCH_SIZE,
    num_workers=8,
)

# Adam with two parameter groups: only the decoder is weight-decayed (1e-4);
# the encoder falls back to the optimizer-wide default of 0.
optimizer = optim.Adam(
    [
        {'params': model.encoder.parameters()},
        {'params': model.decoder.parameters(), 'weight_decay': 0.0001},
    ],
    lr=1e-3,
    weight_decay=0,
)

Model has 2349786 trainable parameters


In [11]:
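# Optional: uncomment to replace the freshly initialised model with a
# previously saved one. torch.load unpickles the file, so only load
# checkpoints from a trusted source.
#model = torch.load(f"model-4-{n_bead_types}-try1.pt")
#print(model)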

In [12]:
# Smoke test of the model input and output: run the first batch from the
# dataloader through `transform`, the model and `graph_loss`, print the loss,
# then stop.
# NOTE(review): `n_bead_classes` is passed to `graph_loss` here, but the
# training cell calls `graph_loss` without it -- confirm against the
# `graph_loss` signature which of the two calls is intended.
for batch in dataloader:
	# Encode the raw node columns and move the batch to the GPU.
	batch = transform(batch)
	predicted_nodes, predicted_adjacency_matrix = model(batch)
	loss = graph_loss(batch, predicted_nodes, predicted_adjacency_matrix, n_bead_classes)
	print(f"Loss: {loss.item()}")
	# Only the first batch is needed for this check.
	break

Loss: 7.451411247253418


In [13]:
# Training curves accumulated by the training cell. They live in their own
# cell so that re-running the training cell extends them instead of
# resetting them; re-running this cell discards them.
loss_history, accuracy_history = [], []

In [None]:
# Train the autoencoder for up to 1000 passes over the dataloader. The loop can
# be stopped with a kernel interrupt; the model is saved in either case (and
# also if another exception escapes), because the save sits in `finally`.
#
# Every AVERAGING_WINDOW optimizer steps the mean loss / accuracy of that
# window is appended to `loss_history` / `accuracy_history`, and roughly every
# `log_steps` optimizer steps the latest entry is appended to the log file.
#
# NOTE(review): `graph_loss` is called without `n_bead_classes` here, unlike
# the smoke-test cell -- confirm against the `graph_loss` signature.
AVERAGING_WINDOW = 5
# At least 1, so the modulo below cannot divide by zero if log_steps < window.
entries_per_log = max(1, log_steps // AVERAGING_WINDOW)
log_path = f'train-{n_bead_types}.log'
try:
    model.train()
    temp_loss_history = []
    temp_accuracy_history = []
    for epoch in tqdm(range(1000)):
        for bidx, batch in enumerate(dataloader):
            batch = transform(batch)
            optimizer.zero_grad(set_to_none=True)
            nodes, adjacency_matrix, latent_space = model(batch, return_latent_space=True)
            # Small penalty on the mean per-sample norm of the latent vectors.
            latent_space_loss = 0.0001 * torch.linalg.vector_norm(latent_space, dim=1).mean()
            loss, accuracy = graph_loss(batch, nodes, adjacency_matrix, return_accuracy=True)
            total_loss = loss + latent_space_loss
            total_loss.backward()
            optimizer.step()
            temp_loss_history.append(total_loss.item())
            temp_accuracy_history.append(accuracy.tolist())
            if len(temp_loss_history) < AVERAGING_WINDOW:
                continue
            # Window is full: fold it into one history entry and start over.
            loss_history.append(np.mean(temp_loss_history))
            # nanmean: ignore NaN accuracy components within the window.
            accuracy_history.append(np.nanmean(temp_accuracy_history, axis=0))
            temp_loss_history = []
            temp_accuracy_history = []
            accuracy_text = ', '.join(
                f'{n}: {an:.3f}'
                for n, an in zip(['edges', 'class', 'size', 'charge'], accuracy_history[-1])
            )
            # Step is reported 1-based so the console agrees with the log file.
            print(f'Training epoch {epoch+1}, step {bidx+1}, LOSS: {loss_history[-1]:.4f}, ACCURACY: {accuracy_text}', end='\r')
            if len(loss_history) % entries_per_log == 0:
                with open(log_path, 'a') as log_file:
                    log_file.write(f'Epoch {epoch+1}, STEP: {bidx+1}, LOSS: {loss_history[-1]:.4f}, ACCURACY: {accuracy_text}\n')
except KeyboardInterrupt:
    print('Training interrupted')
finally:
    # Pickles the whole model object (not just a state_dict).
    torch.save(model, f"model-4-{n_bead_types}.pt")
    print('Model saved')
    with open(log_path, 'a') as log_file:
        log_file.write('Model saved\n')

In [None]:
# Plot the recorded training curves: loss on the left y-axis and the accuracy
# components on a twin right y-axis, all sharing one legend.
fig, loss_ax = plt.subplots(figsize=(8, 3))
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
l, = loss_ax.plot(loss_history, color=colors[0])
lines = [l]
# Each history entry is the mean over 5 consecutive training batches, so the
# x-axis counts those entries rather than epochs.
loss_ax.set_xlabel('Logged step (mean over 5 batches)')
loss_ax.set_ylabel('Loss')
twx = loss_ax.twinx()
# One curve per accuracy component (columns of accuracy_history).
for i, acc in enumerate(np.array(accuracy_history).T):
    l, = twx.plot(acc, color=colors[i+1])
    lines.append(l)
twx.set_ylabel('Accuracy')
twx.grid(alpha=0.3)
twx.legend(lines, ['Loss', 'Edges', 'Class', 'Size', 'Charge'], loc=(0.77,0.14))
plt.show()

In [10]:
# Manually save the whole model object; this is the same file path that the
# training cell writes in its `finally` block, so it overwrites that file.
torch.save(model, f'model-4-{n_bead_types}.pt')