In [141]:
from lightning import LightningModule, seed_everything, LightningDataModule, Callback, Trainer
from hydra_zen import builds, make_config, instantiate, zen, just
from omegaconf import MISSING, DictConfig, OmegaConf

import lightning as pl
import torch
from torch_compose.module import ModuleGraph, DirectedModule

import wandb

from typing import Any, Callable, Optional, Sequence, Tuple, Union

In [142]:

class GenericLitModule(LightningModule):
    """Lightning wrapper around a model, a loss function and an optimizer factory.

    Batches are indexed positionally: ``batch[0]`` is fed to the model and
    ``batch[1]`` is passed to the loss function as the target.

    Args:
        model: Callable invoked as ``model(x)`` from :meth:`forward`.
        loss_fn: Callable invoked as ``loss_fn(prediction, target)``.
        optimizer_partial: Callable invoked as ``optimizer_partial(params=...)``;
            its return value is handed back from :meth:`configure_optimizers`.
    """

    def __init__(
        self,
        model,
        loss_fn,
        optimizer_partial,
    ):
        super().__init__()

        self.model = model
        self.loss_fn = loss_fn
        # this line allows to access init params with 'self.hparams' attribute
        # it also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)
        self.optimizer_partial = optimizer_partial

    def forward(self, x: torch.Tensor):
        """Run the wrapped model on ``x`` and return its output."""
        return self.model(x)

    def training_step(self, batch: Any, batch_idx: int):
        """Compute and log the training loss for one batch.

        Returns:
            The loss produced by ``loss_fn``.
        """
        y_hat = self.forward(batch[0])
        loss = self.loss_fn(y_hat, batch[1])
        self.log("train/loss", loss, on_epoch=True)
        # The original returned an undefined name (`loss_dict`), which raised
        # NameError on the first training batch.
        return loss

    def validation_step(self, batch: Any, batch_idx: int):
        """Compute the validation loss for one batch and log it.

        Predictions and targets are logged alongside the loss only when they
        are scalar-like, because Lightning's ``log``/``log_dict`` reject
        tensors with more than one element.
        """
        # `targets` was previously referenced without ever being defined.
        targets = batch[1]
        y_hat = self.forward(batch[0])
        loss = self.loss_fn(y_hat, targets)

        metrics = {"valid/loss": loss}
        for name, value in (("valid/prediction", y_hat), ("valid/targets", targets)):
            # Non-tensor values are passed through unchanged; multi-element
            # tensors are skipped instead of crashing the validation loop.
            if not torch.is_tensor(value) or value.numel() == 1:
                metrics[name] = value
        self.log_dict(metrics)

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        See examples here:
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        return self.optimizer_partial(
            params=self.parameters(),
        )
        
class TorchComposeLitModule(LightningModule):
    """Lightning wrapper that trains a ``ModuleGraph``.

    The graph is expected to return a mapping from which the loss is read
    under ``loss_key``.

    Args:
        module_graph: Graph whose ``forward(batch)`` result is indexed with
            ``loss_key``.
        optimizer_partial: Callable invoked with this module's parameters;
            its return value is handed back from :meth:`configure_optimizers`.
        loss_key: Key of the loss entry in the graph output.
    """

    def __init__(
        self,
        module_graph: ModuleGraph,
        optimizer_partial: Callable,
        loss_key: str = "combined_loss",
    ):
        super().__init__()
        self.module_graph = module_graph
        self.loss_key = loss_key
        self.optimizer_partial = optimizer_partial

    def forward(self, batch):
        """Run the module graph on ``batch`` and return its output."""
        return self.module_graph.forward(batch)

    def training_step(self, batch, batch_idx):
        """Run the graph on one training batch and return the entry at ``loss_key``."""
        batch = self.forward(batch)
        loss = batch[self.loss_key]
        return loss

    def validation_step(self, batch, batch_idx):
        """Run the graph on one validation batch and return the entry at ``loss_key``."""
        batch = self.forward(batch)
        loss = batch[self.loss_key]
        return loss

    def configure_optimizers(self):
        """Build the optimizer from the stored factory and this module's parameters."""
        # The factory lives on the instance; the bare name `optimizer_partial`
        # used previously is not defined in this scope and raised NameError.
        return self.optimizer_partial(self.parameters())

In [143]:
class DictOutput(DirectedModule):
    """Toy module whose forward pass returns a dict."""

    def forward(self, x):
        """Return ``{'x': x + 1}``."""
        return {'x': x + 1}
    
    
class TupleOutput(DirectedModule):
    """Toy module whose forward pass returns a 2-tuple."""

    def forward(self, x):
        """Return the input unchanged together with its square."""
        return (x, x ** 2)
    

In [144]:
class LitDataModule(LightningDataModule):
    """Data module that holds a single dataset.

    The ``setup`` / ``train_dataloader`` / ``val_dataloader`` hooks are
    placeholders: each one does nothing and returns ``None``.

    Args:
        dataset: Object stored on the instance as ``self.dataset``.
    """

    def __init__(self, dataset):
        super().__init__()

        # Record the constructor arguments without forwarding them to the logger.
        self.save_hyperparameters(logger=False)
        self.dataset = dataset

    def setup(self, stage: Optional[str] = None):
        """Placeholder hook; performs no work."""
        return None

    def train_dataloader(self):
        """Placeholder hook; no training loader is built yet."""
        return None

    def val_dataloader(self):
        """Placeholder hook; no validation loader is built yet."""
        return None

class DummyDataset(torch.utils.data.Dataset):
    """Random in-memory dataset yielding ``(image_like, vector)`` pairs."""

    def __init__(self):
        n_samples = 100
        # e.g. 64x64 RGB images, one per sample
        self.data_list1 = torch.randn(n_samples, 3, 64, 64)
        # e.g. vectors of length 10, one per sample
        self.data_list2 = torch.randn(n_samples, 10)

    def __len__(self):
        """Number of samples, taken from the first tensor."""
        return len(self.data_list1)

    def __getitem__(self, idx):
        """Return the entries of both tensors at position ``idx``."""
        return self.data_list1[idx], self.data_list2[idx]

In [145]:
# `Builds_Adam` was not defined anywhere in this notebook, so this cell only
# worked with leftover kernel state. Define it here as a partial Adam config
# with the full signature populated, matching the recorded output of this cell.
Builds_Adam = builds(torch.optim.Adam, zen_partial=True, populate_full_signature=True)
instantiate(Builds_Adam())

functools.partial(<class 'torch.optim.adam.Adam'>, lr=0.001, betas=[0.9, 0.999], eps=1e-08, weight_decay=0, amsgrad=False, foreach=None, maximize=False, capturable=False, differentiable=False, fused=None)

In [148]:
# Structured config for the Lightning module: a two-node ModuleGraph plus a
# partially-applied Adam optimizer.
# NOTE(review): presumably `input_keys` / `output_keys` name the batch entries
# each node reads and writes ('x0' -> 'x1' -> 'x2', 'x3') -- confirm against
# torch_compose.
Builds_LitModule = builds(
    TorchComposeLitModule,
    module_graph=builds(
        ModuleGraph,
        modules={
            'm0': builds(DictOutput, input_keys=['x0'], output_keys={'x': 'x1'}),
            'm1': builds(TupleOutput, input_keys=['x1'], output_keys=['x2', 'x3']),
        },
    ),
    optimizer_partial=builds(
        torch.optim.Adam,
        zen_partial=True,
        populate_full_signature=True,
    ),
    populate_full_signature=True,
    hydra_recursive=True,
)

In [149]:
# Structured config for the data module, wrapping a DummyDataset config.
Builds_LitDataModule = builds(
    LitDataModule,
    dataset=builds(DummyDataset, populate_full_signature=True),
    populate_full_signature=True,
)

# Structured config for the Lightning Trainer with its full signature exposed.
Builds_Trainer = builds(Trainer, populate_full_signature=True)

In [150]:
# Partial config for `wandb.init`: instantiating it yields a callable that
# still has to be called to start the run.
builds_wandb_run = builds(
    wandb.init,
    zen_partial=True,
    populate_full_signature=True,
)

In [151]:
def pre_seed(cfg):
    """Seed via ``seed_everything`` using ``cfg.random_seed``."""
    seed = cfg.random_seed
    seed_everything(seed)

In [152]:
def start_run(config):
    """Instantiate the components described by ``config`` and hand them to ``train``.

    The previous version called ``train`` with keyword names it does not
    accept (``model``, ``datamodule``, ``optim``, ``callbacks``, ``seed``,
    ``run``) and passed un-instantiated config nodes. This version follows
    ``train``'s actual signature and the field names used by this notebook's
    config (``lit_module``, ``lit_datamodule``, ``trainer``, ``wandb_run``).

    Seeding is not performed here; call ``pre_seed`` beforehand if needed.

    Args:
        config: Config exposing ``lit_module``, ``lit_datamodule``, ``trainer``
            and ``wandb_run`` entries that ``instantiate`` can build.

    Returns:
        Whatever ``train`` returns.
    """
    # Plain-container snapshot of the whole config, recorded with the run.
    run_config = OmegaConf.to_container(OmegaConf.structured(config))
    run_partial = instantiate(config.wandb_run)

    def open_run():
        # `train` opens the run itself, so give it a zero-argument factory
        # with the config snapshot already bound instead of opening it here.
        return run_partial(config=run_config)

    return train(
        lit_module=instantiate(config.lit_module),
        lit_datamodule=instantiate(config.lit_datamodule),
        trainer=instantiate(config.trainer),
        wandb_run=open_run,
    )

In [180]:
def train(lit_module, lit_datamodule, trainer, wandb_run):
    """Open a run via ``wandb_run()`` and call ``lit_datamodule.setup()`` inside it.

    ``lit_module`` and ``trainer`` are accepted but not used yet.

    Args:
        lit_module: Currently unused.
        lit_datamodule: Object whose ``setup()`` is invoked.
        trainer: Currently unused.
        wandb_run: Zero-argument callable returning a context manager.
    """
    active_run = wandb_run()
    with active_run:
        lit_datamodule.setup()

    # TODO: fetch a batch from the train dataloader, run a forward pass, and
    # call trainer.fit(...) once the data module provides dataloaders.

In [181]:
def launch_wandb_run(cfg):
    """Copy every top-level entry of ``cfg`` except ``wandb_run`` into ``cfg.wandb_run.config``.

    NOTE(review): the stored values are the config entries themselves; the
    recorded output of this notebook shows them reaching ``wandb.init`` as
    instantiated objects -- confirm that is intended.
    """
    run_payload = {}
    for key, value in cfg.items():
        if key == 'wandb_run':
            continue
        run_payload[key] = value
    cfg.wandb_run.config = run_payload

In [182]:
# Top-level config: one entry per `train` argument, plus the random seed
# read by `pre_seed`.
config = make_config(
    lit_module=Builds_LitModule,
    lit_datamodule=Builds_LitDataModule,
    trainer=Builds_Trainer,
    wandb_run=builds_wandb_run,
    random_seed=just(1),
)

In [183]:
# Wrap `train` with hydra-zen: the pre-call hooks receive the config first
# (wandb config attachment, then seeding), and `random_seed` is excluded from
# the arguments extracted for `train`.
zen_train_func = zen(
    train,
    pre_call=[launch_wandb_run, pre_seed],
    exclude=['random_seed'],
)

In [184]:
# Run the wrapped training entry point on the notebook's config.
# The recorded output below ends with a wandb "An unexpected error occurred".
zen_train_func(config)

Global seed set to 1
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Traceback (most recent call last):
  File "/home/naka/.cache/pypoetry/virtualenvs/torch-compose-LagqJUxr-py3.10/lib/python3.10/site-packages/wandb/sdk/wandb_init.py", line 1147, in init
    run = wi.init()
  File "/home/naka/.cache/pypoetry/virtualenvs/torch-compose-LagqJUxr-py3.10/lib/python3.10/site-packages/wandb/sdk/wandb_init.py", line 611, in init
    run = Run(
  File "/home/naka/.cache/pypoetry/virtualenvs/torch-compose-LagqJUxr-py3.10/lib/python3.10/site-packages/wandb/sdk/wandb_run.py", line 537, in __init__
    self._init(
  File "/home/naka/.cache/pypoetry/virtualenvs/torch-compose-LagqJUxr-py3.10/lib/python3.10/site-packages/wandb/sdk/wandb_run.py", line

functools.partial(<function init at 0x7ff3581f7a30>, job_type=None, dir=None, config={'lit_module': TorchComposeLitModule(
  (module_graph): ModuleGraph(
    (m0): DictOutput()
    (m1): TupleOutput()
  )
), 'lit_datamodule': <__main__.LitDataModule object at 0x7ff3551afee0>, 'trainer': <lightning.pytorch.trainer.trainer.Trainer object at 0x7ff3551ac6a0>, 'random_seed': 1}, project=None, entity=None, reinit=None, tags=None, group=None, name=None, notes=None, magic=None, config_exclude_keys=None, config_include_keys=None, anonymous=None, mode=None, allow_val_change=None, resume=None, force=None, tensorboard=None, sync_tensorboard=None, monitor_gym=None, save_code=None, id=None, settings=None)
Problem at: /tmp/ipykernel_86365/877047569.py 8 train


Error: An unexpected error occurred