In [1]:
import os
from pathlib import Path

# Run the notebook from `src/` so the project-local packages (`utils`, `models`)
# are importable by their top-level names.
# The guard checks the *name* of the current directory, which keeps the cell
# idempotent. A check for a `src` entry in the directory listing would be true
# again once already inside `src/`, and a re-run would then try `../../src`
# from there.
_cwd = Path.cwd()
if _cwd.name != "src":
    if (_cwd / "src").is_dir():
        # Launched from the repository root, which contains `src/` directly.
        os.chdir(_cwd / "src")
    else:
        # Launched from a notebook directory two levels below the repository root.
        os.chdir("../../src")
os.listdir()

['models',
 'main.py',
 'utils',
 '__pycache__',
 '__init__.py',
 'run_pcm_mnist_experiment.py']

In [2]:
import os

# Record and show the kernel's working directory (expected to be `src/` at this point).
current_directory = os.getcwd()
print(f"Current Directory: {current_directory}")

Current Directory: /home/prz/bioml/mag/D4L-Hackaton/src


In [3]:
from pathlib import Path

# Final component of the working directory, i.e. the folder the notebook now runs from.
cwd_path = Path(current_directory)
cwd_path.name

'src'

In [4]:
import torch
from pathlib import Path
import pytorch_lightning as pl

# Plotting
import matplotlib
import matplotlib.pyplot as plt

# import scienceplots

# plt.style.use("science")

# Neptune
from pytorch_lightning.loggers import NeptuneLogger

# Conditional MNIST dataset
from utils.config import load_config_from_path
from utils.data.pcm.mnist_cond_trans_dataset import (
    ConditionalMNIST,
    get_ConditionalMnistDataloader,
)

# Typing
from utils.common_types import Batch
from typing import Callable, Dict, Any, Tuple, List

# Other
import inspect
from functools import partial
from argparse import Namespace

# Paths
from utils.paths import CONFIG_PATH_DATA, CONFIG_PATH_MODELS

# Chain model
from models.components.chain import Chain

# Neptune
import neptune

# Plotting & Callbacks
from src.utils.evaluation.plots import (
    plot_original_vs_reconstructed,
    plot_images_with_conditions,
    wrap_with_first_batch,
    plot_latent,
    NeptunePlotLogCallback,
    plot_latent_with_pca_umap,
)

ModuleNotFoundError: No module named 'utils'

In [3]:
# Never hardcode credentials in a notebook: cell sources and outputs get committed
# and shared. Export NEPTUNE_API_TOKEN before starting Jupyter, or type it at the
# prompt below (getpass does not echo the input). The token that used to be
# pasted here has been exposed and should be revoked/rotated.
import getpass
import os

if not os.environ.get("NEPTUNE_API_TOKEN"):
    os.environ["NEPTUNE_API_TOKEN"] = getpass.getpass("Neptune API token: ")

env: NEPTUNE_API_TOKEN=<redacted>


In [4]:
# Neptune project in "<workspace>/<project>" form (matches the run URLs, e.g. .../multimodal/vaes/e/VAES-66).
project_name: str = "multimodal/vaes"

In [5]:
# Access the NEPTUNE_API_TOKEN environment variable
neptune_api_token = os.getenv("NEPTUNE_API_TOKEN")

# Check if the token is available
if neptune_api_token is None:
    raise ValueError("NEPTUNE_API_TOKEN environment variable is not set.")

In [6]:
# Batch size shared by the train and validation loaders, defined once so the two cannot drift apart.
BATCH_SIZE = 128

train_data_cfg_file_path = CONFIG_PATH_DATA / "pcm-mnist-02-train.yaml"
test_data_cfg_file_path = CONFIG_PATH_DATA / "pcm-mnist-02-test.yaml"

# Training split, reshuffled by the loader (shuffle=True).
# For a quick debug run the dataset was previously wrapped as
# `torch.utils.data.Subset(cmnist_train, list(range(1000)))` before building the loader.
train_data_cfg = load_config_from_path(file_path=train_data_cfg_file_path)
cmnist_train = ConditionalMNIST(cfg=train_data_cfg)
cmnist_train_dataloader = get_ConditionalMnistDataloader(
    cmnist=cmnist_train, batch_size=BATCH_SIZE, shuffle=True
)

# Validation split in a fixed order (shuffle=False). Presumably this keeps the
# "first batch" used by the plotting helpers identical across epochs; confirm.
test_data_cfg = load_config_from_path(file_path=test_data_cfg_file_path)
cmnist_val = ConditionalMNIST(cfg=test_data_cfg)
cmnist_val_dataloader = get_ConditionalMnistDataloader(
    cmnist=cmnist_val, batch_size=BATCH_SIZE, shuffle=False
)

In [7]:
# Preview the raw validation data. The processing function is the identity, so the
# grid shows dataset images together with their condition tokens/values as loaded.
plot_images_with_conditions_wrapped_with_wrap_with_first_batch = wrap_with_first_batch(
    plot_images_with_conditions,
    imgs_name="img",
    conditions_name="condition_token_ids",
    condition_values_name="condition_values",
    disp_batch_size=10,
    disp_n_latent_samples=1,
    filename_comp="filename",
    disp_img_size=2,
    y_title_shift=0.91,
)

# The wrapped plotter returns a dict keyed by `filename_comp` (here {"filename": Figure}).
fig = plot_images_with_conditions_wrapped_with_wrap_with_first_batch(
    dataloader=cmnist_val_dataloader,
    processing_function=lambda batch: batch,
)

fig

{'filename': <Figure size 200x2000 with 10 Axes>}

### Plotting Callbacks

In [8]:
# Grid plotter for images produced by the model's "sample-prior" command,
# displayed next to their conditions.
# NOTE(review): filename_comp "embeddings_plot" looks copied from the embeddings
# plot; confirm it is the intended key for prior samples.
plot_prior_sampled_imgs_with_conditions_wrapped = wrap_with_first_batch(
    plot_images_with_conditions,
    filename_comp="embeddings_plot",
    imgs_name="img",
    conditions_name="condition_token_ids",
    condition_values_name="condition_values",
    disp_batch_size=10,
    disp_n_latent_samples=16,
    disp_img_size=2,
    y_title_shift=0.91,
)

# Log those figures to Neptune under validation_plots/sample_prior.
plot_sample_prior_callback = NeptunePlotLogCallback(
    plotting_function_taking_dataloader=plot_prior_sampled_imgs_with_conditions_wrapped,
    command_name="sample-prior",
    neptune_plot_log_path="validation_plots/sample_prior",
    plot_file_base_name="sample_prior",
    command_dynamic_kwargs={},
)

In [9]:
# Reconstruction plot: original images (`img_org`) side by side with the images
# returned by the model's "encode-decode" command (`img`), for the first batch
# of the dataloader the callback is given.
plot_original_vs_reconstructed_wrapped = wrap_with_first_batch(
    plot_original_vs_reconstructed,
    org_imgs_name="img_org",
    reconstructed_imgs_name="img",
    num_images=10,
    wspace=0.25,
    hspace=0.25,
    filename_comp="org_vs_reconstr",
)

plot_reconstruction_callback = NeptunePlotLogCallback(
    plotting_function_taking_dataloader=plot_original_vs_reconstructed_wrapped,
    command_name="encode-decode",
    neptune_plot_log_path="validation_plots/reconstructed",
    # Named after the log path, like the other callbacks. "embedding" (copied from
    # the embeddings callback) mislabeled these plots and shared a base name with it.
    plot_file_base_name="reconstructed",
    command_dynamic_kwargs={},
)

In [10]:
# Latent-space scatter plots from the model's "encode" command, coloured by each of
# the first five condition values (all treated as categorical).
latent_plotter = partial(
    plot_latent,
    data_name="img",
    condition_value_name="condition_values",
    filename_comp="latent",
    condition_value_idxs=list(range(5)),
    are_conditions_categorical=[True] * 5,
)

plot_embeddings_callback = NeptunePlotLogCallback(
    plotting_function_taking_dataloader=latent_plotter,
    command_dynamic_kwargs={},
    command_name="encode",
    neptune_plot_log_path="validation_plots/embeddings",
    plot_file_base_name="embedding",
)

In [11]:
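# Reference: signature of the imported `plot_latent_with_pca_umap` (not yet wired into a callback).
# def plot_latent_with_pca_umap(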
#     processing_function: Callable,
#     dataloader: torch.utils.data.DataLoader,
#     data_name: str,
#     condition_value_name: str,
#     filename_comp: str,
#     num_batches: None | int = None,
#     plot_dims: Tuple[int] = (0, 1),
#     figsize: Tuple[float, float] = (6, 6),
#     n_components: int = 2,
#     umap_n_neighbors: int = 15,
#     umap_min_dist: float = 0.1,
# ) -> List[matplotlib.figure.Figure]:

## Model

In [12]:
# Model config; this path is also uploaded to Neptune later for provenance.
model_cfg_file_path = CONFIG_PATH_MODELS / "pcm-04.yaml"
chain_cfg = load_config_from_path(file_path=model_cfg_file_path)

# Build the Chain model; the bare name on the last line displays its module tree.
chainae = Chain(cfg=chain_cfg)
chainae

Chain(
  (_chain): ModuleDict(
    (encoder): BlockStack(
      (blocks): Sequential(
        (0): Block(
          (layer): ModuleList(
            (0): Linear(in_features=784, out_features=512, bias=True)
            (1): ReLU()
          )
        )
        (1): Block(
          (layer): ModuleList(
            (0): Linear(in_features=512, out_features=100, bias=True)
            (1): ReLU()
          )
        )
        (2): Linear(in_features=100, out_features=32, bias=True)
      )
    )
    (posterior): GaussianPosterior()
    (condition_embedding): ConditionEmbeddingTransformer(
      (_condition_embedding_module): DiscreteValuedConditionEmbedding(
        (_condition_embeddings): Embedding(6, 128, padding_idx=0)
        (_category_embeddings): Embedding(27, 128, padding_idx=0)
      )
      (_transformer_encoder): TransformerEncoder(
        (layers): ModuleList(
          (0-3): 4 x TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj)

In [13]:
# Scratch check of how dict keys render inside a string; nothing else in the
# notebook uses it, so the cell can be deleted.
d = dict(g=5)
"keys = " + str(d.keys())

"keys = dict_keys(['g'])"

In [13]:
# Create a Neptune logger
neptune_logger = NeptuneLogger(
    api_key=neptune_api_token,
    project=project_name,
    name="cpiwae-3-09-24",
)

trainer = pl.Trainer(
    max_epochs=500,
    logger=neptune_logger,
    check_val_every_n_epoch=10,
    callbacks=[
        plot_sample_prior_callback,
        # plot_embeddings_callback,
        plot_reconstruction_callback,
    ],
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [14]:
# Attach the exact YAML configs used for this run to the Neptune experiment, so the
# run's metadata records which model and data configs produced it.
if trainer.logger is not None and hasattr(trainer.logger, "experiment"):
    config_uploads = {
        "config/model_config.yaml": model_cfg_file_path,
        "config/train_data_config.yaml": train_data_cfg_file_path,
        "config/test_data_config.yaml": test_data_cfg_file_path,
    }
    for neptune_key, cfg_path in config_uploads.items():
        trainer.logger.experiment[neptune_key].upload(
            neptune.types.File(str(cfg_path))
        )



[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/multimodal/vaes/e/VAES-66


In [15]:
# Train the Chain model; the validation loader drives the validation loop and the
# plotting callbacks registered on the trainer.
trainer.fit(
    chainae,
    train_dataloaders=cmnist_train_dataloader,
    val_dataloaders=cmnist_val_dataloader,
)

/home/prz/bioml/.venv/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py:316: The lr scheduler dict contains the key(s) ['monitor', 'strict'], but the keys will be ignored. You need to call `lr_scheduler.step()` manually in manual optimization.

  | Name   | Type       | Params | Mode 
----------------------------------------------
0 | _chain | ModuleDict | 1.5 M  | train
----------------------------------------------
1.5 M     Trainable params
0         Non-trainable params
1.5 M     Total params
5.843     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/prz/bioml/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/prz/bioml/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

/home/prz/bioml/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [16]:
# Stop the Neptune experiment after training ends
neptune_logger.experiment.stop()

[neptune] [info   ] Shutting down background jobs, please wait a moment...
[neptune] [info   ] Done!
[neptune] [info   ] Waiting for the remaining 1 operations to synchronize with Neptune. Do not kill this process.
[neptune] [info   ] All 1 operations synced, thanks for waiting!
[neptune] [info   ] Explore the metadata in the Neptune app: https://app.neptune.ai/multimodal/vaes/e/VAES-66/metadata


In [None]:
# run_id = "VAES-40"  # Replace with your run ID
# run = neptune.init_run(
#     project=project_name, api_token=neptune_api_token, with_id=run_id, mode="read-only"
# )

In [None]:
# run.get_structure()

In [None]:
# run["training/model/checkpoints/epoch=9-step=240"].download()