In [1]:
%reload_ext autoreload
%autoreload 2

# Imports

In [2]:
from kret_notebook import *  # NOTE import first
from kret_matplotlib.mpl_nb_imports import *
from kret_np_pd.np_pd_nb_imports import *
from kret_sklearn.sklearn_nb_imports import *
from kret_torch_utils.torch_nb_imports import *
from kret_lightning.lightning_nb_imports import *
from kret_tqdm.tqdm_nb_imports import *
from kret_type_hints.types_nb_imports import *
from kret_utils.utils_nb_imports import *

# from kret_wandb.wandb_nb_imports import *  # NOTE this is slow to import

Loaded environment variables from /Users/Akseldkw/coding/kretsinger/.env
[kret_matplotlib.mpl_nb_imports] Imported kret_matplotlib.mpl_nb_imports in 2.0006 seconds
[kret_np_pd.np_pd_nb_imports] Imported kret_np_pd.np_pd_nb_imports in 1.1344 seconds
[kret_sklearn.sklearn_nb_imports] Imported kret_sklearn.sklearn_nb_imports in 0.5408 seconds
[kret_torch_utils.torch_nb_imports] Imported kret_torch_utils.torch_nb_imports in 2.5645 seconds
[kret_lightning.lightning_nb_imports] Imported kret_lightning.lightning_nb_imports in 0.1031 seconds
[kret_tqdm.tqdm_nb_imports] Imported kret_tqdm.tqdm_nb_imports in 0.0000 seconds
[kret_type_hints.types_nb_imports] Imported kret_type_hints.types_nb_imports in 0.0003 seconds
[kret_utils.utils_nb_imports] Imported kret_utils.utils_nb_imports in 0.0002 seconds


In [3]:
from lightning.pytorch.loggers import CSVLogger

# Load Data

In [4]:
from mnist_data import MNISTDataModule

mnist_data_module = MNISTDataModule(DATA_DIR / "MNIST")

# Implementation

In [5]:
from kret_lightning.abc_lightning import HPasKwargs


class Kret_AutoEncoder(BaseLightningNN):
    _criterion: nn.Module = nn.MSELoss()

    def __init__(self, embedding_dim: tuple[int, int], **kwargs: t.Unpack[HPasKwargs]):
        super().__init__(**kwargs)
        # define any number of nn.Modules (or use your current ones)
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(*embedding_dim))
        self.decoder = nn.Sequential(nn.Linear(embedding_dim[1], embedding_dim[0]), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # print(f"Input shape: {x.shape}")
        x = x.view(x.size(0), -1)  # flatten
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

    # endregion
    # region Training / Validation Steps
    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """

        loss = ...
        return loss
        """
        x, _ = batch
        x = x.view(x.size(0), -1)
        outputs = self(x)
        loss = self.get_loss(outputs, x)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """
        val_loss = ...
        self.log('val_loss', val_loss)
        """
        x, _ = batch
        x = x.view(x.size(0), -1)
        outputs = self(x)
        val_loss = self.get_loss(outputs, x)
        self.log("val_loss", val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    # endregion

In [6]:
base = Kret_AutoEncoder((64, 3))
auto_enc = base

In [7]:
base.save_load_logging_dict

{'save_dir': PosixPath('/Users/Akseldkw/coding/data_kretsinger/lightning_logs'),
 'name': 'Kret_AutoEncoder',
 'version': 'v_000'}

In [8]:
logger = CSVLogger(**base.save_load_logging_dict)

In [9]:
static_args = TrainerStaticDefaults.TRAINER_QUICK_ITER
dynamic_args = TrainerDynamicDefaults.trainer_dynamic_defaults(auto_enc, mnist_data_module)
trainer_args = static_args | dynamic_args

In [10]:
trainer_args

{'min_epochs': 5,
 'max_epochs': 5,
 'check_val_every_n_epoch': 1,
 'log_every_n_steps': 10,
 'limit_train_batches': 0.1,
 'limit_val_batches': 0.1,
 'limit_test_batches': 0.1,
 'logger': <lightning.pytorch.loggers.csv_logs.CSVLogger at 0x16cc06720>,
 'callbacks': [<lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint at 0x3056526c0>],
 'default_root_dir': PosixPath('/Users/Akseldkw/coding/data_kretsinger/lightning_logs/Kret_AutoEncoder/v_000')}

In [11]:
logger = trainer_args["logger"]
logger._save_dir, type(logger)

('/Users/Akseldkw/coding/data_kretsinger/lightning_logs',
 lightning.pytorch.loggers.csv_logs.CSVLogger)

In [12]:
trainer = L.Trainer(**trainer_args)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores


## Train

In [13]:
raise ValueError("Stop here")

ValueError: Stop here

In [14]:
trainer.fit(model=auto_enc, datamodule=mnist_data_module)

Output()

Epoch 0, global step 86: 'val_loss' reached 0.61594 (best 0.61594), saving model to '/Users/Akseldkw/coding/data_kretsinger/lightning_logs/Kret_AutoEncoder/v_000/checkpoints/best.ckpt' as top 1


Epoch 1, global step 172: 'val_loss' reached 0.53020 (best 0.53020), saving model to '/Users/Akseldkw/coding/data_kretsinger/lightning_logs/Kret_AutoEncoder/v_000/checkpoints/best.ckpt' as top 1


Epoch 2, global step 258: 'val_loss' reached 0.52180 (best 0.52180), saving model to '/Users/Akseldkw/coding/data_kretsinger/lightning_logs/Kret_AutoEncoder/v_000/checkpoints/best.ckpt' as top 1


Epoch 3, global step 344: 'val_loss' reached 0.50724 (best 0.50724), saving model to '/Users/Akseldkw/coding/data_kretsinger/lightning_logs/Kret_AutoEncoder/v_000/checkpoints/best.ckpt' as top 1


Epoch 4, global step 430: 'val_loss' reached 0.48235 (best 0.48235), saving model to '/Users/Akseldkw/coding/data_kretsinger/lightning_logs/Kret_AutoEncoder/v_000/checkpoints/best.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=5` reached.


## Load

In [18]:
auto_enc.__class__.__name__

'Kret_AutoEncoder'

In [19]:
checkpoint_path = Kret_AutoEncoder.ckpt_file_name()
checkpoint_path

PosixPath('/Users/Akseldkw/coding/data_kretsinger/lightning_logs/Kret_AutoEncoder/v_000/checkpoints/best.ckpt')

In [None]:
auto_enc_saved = Kret_AutoEncoder.load_from_checkpoint(checkpoint_path)

In [24]:
auto_enc_saved.hparams

"embedding_dim": (64, 3)
"gamma":         0.5
"l1_penalty":    0.0
"l2_penalty":    0.0
"lr":            0.001
"patience":      10
"stepsize":      12

In [25]:
auto_enc_saved.hparams_initial

"embedding_dim": (64, 3)
"gamma":         0.5
"l1_penalty":    0.0
"l2_penalty":    0.0
"lr":            0.001
"patience":      10
"stepsize":      12

# Sandbox