In [1]:
import pytorch_lightning as pl
from astrochem_embedding.pipeline.data import MaskedStringDataModule
from astrochem_embedding import get_paths
from astrochem_embedding.models import models
from astrochem_embedding import VICGAE
import os

In [2]:
# Reference model built via the package's `VICGAE.from_pretrained()` helper; later
# cells compare it against the checkpoint reloaded from `paths["models"]`.
# NOTE(review): the cell output shows warnings raised from an `__import__` call and a
# `torch.load(io.BytesIO(b))` call inside dependencies — presumably deprecation /
# `weights_only` warnings from torch; confirm and silence or fix upstream.
model = VICGAE.from_pretrained()

  __import__(module, level=0)
  return torch.load(io.BytesIO(b))


In [3]:
# Seed every RNG Lightning knows about (python, numpy, torch) for reproducibility.
SEED = 215015
pl.seed_everything(SEED)

BATCH_SIZE = 128
# Leave two cores free for the main process when possible. `os.cpu_count()` may
# return None, and on machines with <= 2 cores `os.cpu_count() - 2` would be 0 or
# negative; PyTorch's DataLoader (presumably what MaskedStringDataModule uses)
# rejects negative num_workers, so clamp at 0 (= load in the main process).
NUM_WORKERS = max(0, (os.cpu_count() or 1) - 2)
EMBEDDING_DIM = 128  # token embedding size, i.e. Embedding(vocab, 128) in the model repr
Z_DIM = 32  # GRU hidden size; also the length of the vector from embed_smiles
NUM_LAYERS = 1
LR = 1e-4

paths = get_paths()


Seed set to 215015


In [4]:
def train_model(ckpt_name: str = "VICGAE.ckpt", max_epochs: int = 5):
    """Train a fresh VICGAE from scratch and save its checkpoint.

    Hyperparameters come from the configuration cell (EMBEDDING_DIM, Z_DIM,
    NUM_LAYERS, LR, BATCH_SIZE, NUM_WORKERS). Training is logged to TensorBoard
    under ``tb_logs/VICAstrochemEmbedder``.

    Parameters
    ----------
    ckpt_name : str
        File name of the checkpoint, written inside ``paths["models"]``.
        The default is the same file the reload cell below loads, so training
        with the default overwrites it.
    max_epochs : int
        Number of training epochs (default 5, the previously hard-coded value).

    Returns
    -------
    The path object produced by ``paths["models"].joinpath(ckpt_name)``,
    i.e. where the checkpoint was saved.
    """
    # Resolve the destination before training so a missing "models" entry
    # fails immediately instead of after every epoch has run.
    models_dir = paths.get("models")
    if models_dir is None:
        raise KeyError("get_paths() has no 'models' entry; cannot save checkpoint")
    ckpt_path = models_dir.joinpath(ckpt_name)

    # Local name differs from the notebook-global `model` to avoid confusion.
    vicgae = models.VICGAE(EMBEDDING_DIM, Z_DIM, NUM_LAYERS, lr=LR)
    data = MaskedStringDataModule(BATCH_SIZE, NUM_WORKERS)

    logger = pl.loggers.TensorBoardLogger(
        "tb_logs", name="VICAstrochemEmbedder", log_graph=True
    )
    # max_depth=-1 includes every nested submodule in the printed summary.
    summarizer = pl.callbacks.ModelSummary(max_depth=-1)

    trainer = pl.Trainer(max_epochs=max_epochs, callbacks=[summarizer], logger=logger)
    trainer.fit(vicgae, data)

    trainer.save_checkpoint(ckpt_path)
    return ckpt_path

# train_model('VICGAE_test.ckpt')

In [5]:
# Reload the checkpoint stored as "VICGAE.ckpt" in the models directory — the same
# file `train_model()` writes when called with its default `ckpt_name`.
vicgae_ckpt_path = paths.get("models").joinpath("VICGAE.ckpt")
loaded_model = models.VICGAE.load_from_checkpoint(vicgae_ckpt_path)

In [6]:
# Display both module reprs together to confirm the reloaded checkpoint has the
# same architecture as the pretrained reference model.
loaded_model, model

(VICGAE(
   (embedding): Embedding(634, 128)
   (encoder): GRU(128, 32, batch_first=True)
   (decoder): GRU(32, 32, batch_first=True)
   (metric): BCELoss()
   (output): Sequential(
     (0): Linear(in_features=32, out_features=634, bias=True)
     (1): Softmax(dim=-1)
   )
   (vic_reg): VICRegularization(
     (variance): VarianceHinge()
     (covariance): CovarianceLoss()
     (invariance): MSELoss()
   )
 ),
 VICGAE(
   (embedding): Embedding(634, 128)
   (encoder): GRU(128, 32, batch_first=True)
   (decoder): GRU(32, 32, batch_first=True)
   (metric): BCELoss()
   (output): Sequential(
     (0): Linear(in_features=32, out_features=634, bias=True)
     (1): Softmax(dim=-1)
   )
   (vic_reg): VICRegularization(
     (variance): VarianceHinge()
     (covariance): CovarianceLoss()
     (invariance): MSELoss()
   )
 ))

In [7]:
# Reference embedding of benzene ("c1ccccc1", aromatic SMILES) from the pretrained model.
pretrained_embedding = model.embed_smiles("c1ccccc1")
pretrained_embedding

tensor([[-3.7251e-02, -1.9092e-02,  1.1982e-02, -3.8894e-02, -3.3427e-02,
         -1.4961e-02, -1.0720e-03, -4.0079e-02,  5.4368e-02, -2.9073e-02,
         -6.8230e-03,  6.2277e-02, -1.5995e-02, -6.1956e-03,  3.1204e-02,
          3.6043e-02,  2.9236e-02, -1.0343e-03, -7.5238e-05, -6.6223e-03,
          3.0475e-02, -7.7275e-03,  3.3575e-03,  3.5621e-02, -3.2590e-02,
          4.3446e-03, -5.8487e-03,  2.0878e-02,  3.1586e-02, -4.0533e-02,
          4.3696e-02, -2.0791e-02]])

In [8]:
# Embed the same SMILES with the reloaded checkpoint and verify it matches the
# pretrained model programmatically, rather than by eyeballing two printed tensors.
loaded_embedding = loaded_model.embed_smiles("c1ccccc1")
assert loaded_embedding.allclose(model.embed_smiles("c1ccccc1")), (
    "checkpoint embedding differs from the pretrained model's"
)
loaded_embedding

tensor([[-3.7251e-02, -1.9092e-02,  1.1982e-02, -3.8894e-02, -3.3427e-02,
         -1.4961e-02, -1.0720e-03, -4.0079e-02,  5.4368e-02, -2.9073e-02,
         -6.8230e-03,  6.2277e-02, -1.5995e-02, -6.1956e-03,  3.1204e-02,
          3.6043e-02,  2.9236e-02, -1.0343e-03, -7.5238e-05, -6.6223e-03,
          3.0475e-02, -7.7275e-03,  3.3575e-03,  3.5621e-02, -3.2590e-02,
          4.3446e-03, -5.8487e-03,  2.0878e-02,  3.1586e-02, -4.0533e-02,
          4.3696e-02, -2.0791e-02]])