In [1]:
import pytorch_lightning as pl
from astrochem_embedding.pipeline.data import MaskedStringDataModule
from astrochem_embedding import get_paths
from astrochem_embedding.models import models
import os

In [2]:
# ## Reproducibility
# Seed the global RNGs once, before the model or data module is created.
# `workers=True` makes Lightning install a worker_init_fn on the dataloaders, so
# each DataLoader worker process also gets a distinct, reproducible seed. This
# matters here because NUM_WORKERS is passed to the data module below
# (presumably for its DataLoaders - confirm in MaskedStringDataModule).
pl.seed_everything(215015, workers=True)

# ## Hyperparameters and runtime configuration


def num_dataloader_workers(cpu_count, reserve=2):
    """Return a safe DataLoader worker count that leaves ``reserve`` cores free.

    ``os.cpu_count()`` may return ``None``. On machines with ``reserve`` or
    fewer cores, the naive ``cpu_count - reserve`` would be zero or negative,
    and PyTorch's DataLoader rejects a negative ``num_workers``. The result is
    therefore clamped at 0, where 0 means data is loaded in the main process.

    Args:
        cpu_count: Number of available CPUs, or ``None`` if unknown.
        reserve: Number of cores to keep free for the notebook kernel / OS.

    Returns:
        A non-negative integer worker count.
    """
    return max((cpu_count or 1) - reserve, 0)


BATCH_SIZE = 128
NUM_WORKERS = num_dataloader_workers(os.cpu_count())
EMBEDDING_DIM = 128  # embedding size passed to VICGAE
Z_DIM = 32  # latent size passed to VICGAE
NUM_LAYERS = 1  # layer count passed to VICGAE
LR = 1e-4  # learning rate (passed as `lr=`)

Seed set to 215015


In [3]:
# ## Model, data and logging setup
# Model first: its weight initialisation draws from the RNG seeded above.
model = models.VICGAE(EMBEDDING_DIM, Z_DIM, NUM_LAYERS, lr=LR)

# Data module configured from the batch size / worker count in the config cell.
data = MaskedStringDataModule(BATCH_SIZE, NUM_WORKERS)

# TensorBoard runs are written under tb_logs/VICAstrochemEmbedder/.
# log_graph=True needs an example input on the model; the In/Out size columns
# in the model summary output indicate one is defined.
logger = pl.loggers.TensorBoardLogger(
    save_dir="tb_logs",
    name="VICAstrochemEmbedder",
    log_graph=True,
)

# max_depth=-1: list every nested submodule in the summary table.
summarizer = pl.callbacks.ModelSummary(max_depth=-1)

In [4]:
# ## Training
MAX_EPOCHS = 5

trainer = pl.Trainer(
    logger=logger,
    # The custom ModelSummary above replaces Lightning's default summary callback.
    callbacks=[summarizer],
    max_epochs=MAX_EPOCHS,
)
trainer.fit(model, datamodule=data)

Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

   | Name               | Type              | Params | Mode  | In sizes                    | Out sizes                  
------------------------------------------------------------------------------------------------------------------------------
0  | embedding          | Embedding         | 81.2 K | train | [64, 10]                    | [64, 10, 128]              
1  | encoder            | GRU               | 15.6 K | train | [64, 10, 128]               | [[64, 10, 32], [1, 64, 32]]
2  | decoder            | GRU               | 6.3 K  | train | [[64, 10, 32], [1, 64, 32]] | [[64, 10, 32], [1, 64, 32]]
3  | metric             | BCELoss           | 0      | train | ?                           | ?  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]



In [None]:
# ## Save the trained model
paths = get_paths()
models_dir = paths.get("models")
if models_dir is None:
    # `.get` returns None for a missing key. Fail here with a clear message instead
    # of a confusing "'NoneType' object has no attribute 'joinpath'" below.
    raise KeyError("get_paths() returned no 'models' entry; cannot save the checkpoint")
trainer.save_checkpoint(models_dir.joinpath("VICGAE.ckpt"))