# Model Factory

> Factory function that builds each supported generative model from a configuration dictionary

In [9]:
#| default_exp model_factory

In [10]:
#| export
#| hide
from orbit_generation.architectures import VAE_CONV5Architecture
from orbit_generation.vae import BetaVAE

In [11]:
# Show where the notebook is running from, to sanity-check relative paths.
import os
print("Current Working Directory:", os.getcwd(), sep=" ")


Current Working Directory: /orbit-generation/nbs


In [12]:
#| export
def get_model(params):
    """Build a generative model from a configuration dictionary.

    Parameters
    ----------
    params : dict
        Model configuration. ``params['model_name']`` selects the model:

        * ``'vae_conv5'`` -- requires ``seq_len``, ``feature_dim`` and
          ``latent_dim``; ``beta``, ``loss_fn``, ``optimizer_cls`` and ``lr``
          are optional and are forwarded to ``BetaVAE`` as ``None`` when
          absent.
        * ``'timeGAN'`` -- requires ``seq_len``, ``feature_dim``,
          ``batch_size`` and ``optimizer`` (a dict with ``'name'`` and
          ``'learning_rate'``). ``module``, ``hidden_dim``, ``n_layers`` and
          ``gamma`` are optional and default to ``"gru"``, ``24``, ``3`` and
          ``1.0`` respectively.

    Returns
    -------
    BetaVAE or TimeGAN
        A ``BetaVAE`` wrapping the ``VAE_CONV5Architecture`` encoder/decoder
        for ``'vae_conv5'``, or a compiled tsgm ``TimeGAN`` for ``'timeGAN'``.

    Raises
    ------
    KeyError
        If ``model_name`` or a required key for the selected model is missing.
    ValueError
        If ``model_name`` is not one of the supported names.
    """
    model_name = params['model_name']

    if model_name == 'vae_conv5':
        # Accessing model configuration from the zoo using parameters from the dictionary
        architecture = VAE_CONV5Architecture(
            seq_len=params['seq_len'],
            feat_dim=params['feature_dim'],
            latent_dim=params['latent_dim'],
        )

        # Extracting encoder and decoder from the architecture
        encoder, decoder = architecture.encoder, architecture.decoder

        # Build the VAE; optional training settings are passed through as None
        # when not configured.
        vae = BetaVAE(
            encoder=encoder,
            decoder=decoder,
            beta=params.get('beta', None),
            loss_fn=params.get('loss_fn', None),
            optimizer_cls=params.get('optimizer_cls', None),
            lr=params.get('lr', None),
        )

        return vae

    elif model_name == 'timeGAN':
        # Imported here (not at module level) so building the VAE does not
        # require tsgm to be installed. Importing the submodule also binds the
        # top-level name `tsgm`.
        import tsgm.models.timeGAN

        model = tsgm.models.timeGAN.TimeGAN(
            seq_len=params['seq_len'],
            module=params.get('module', "gru"),
            hidden_dim=params.get('hidden_dim', 24),
            n_features=params['feature_dim'],
            n_layers=params.get('n_layers', 3),
            batch_size=params['batch_size'],
            gamma=params.get('gamma', 1.0),
        )
        # .compile() sets all optimizers to Adam by default
        # NOTE(review): confirm tsgm's TimeGAN.compile accepts `optimizer` and
        # `learning_rate` keywords in the installed tsgm version.
        model.compile(
            optimizer=params['optimizer']['name'],
            learning_rate=params['optimizer']['learning_rate'],
        )
        return model

    else:
        raise ValueError(f"Unsupported model_name: {model_name}")

In [13]:
#| hide
# Run nbdev's export step so the `#| export` cells are written to the library package.
import nbdev
nbdev.nbdev_export()