# Model Factory

> Factory helpers that build each supported generative model (VAE variants) from a single configuration dictionary.

In [None]:
#| default_exp model_factory

In [None]:
#| export
#| hide
from orbit_generation.architectures import get_conv5_vae_components, get_conv5_legit_tsgm_vae_components, get_inception_time_vae_components
from orbit_generation.vae import BetaVAE, InceptionTimeVAE

import torch

In [None]:
#| export
def _build_vae(vae_cls, encoder, decoder, params, model_kwargs):
    """Wrap an encoder/decoder pair in `vae_cls` using the shared training hyper-parameters.

    Reads from `model_kwargs`: 'beta' (default 1.0), 'loss_fn' (default None),
    'optimizer_cls' (default `torch.optim.Adam`) and 'lr' (falls back to
    `params['lr']`, or None when that is absent too).
    """
    return vae_cls(
        encoder=encoder,
        decoder=decoder,
        beta=model_kwargs.get('beta', 1.0),
        loss_fn=model_kwargs.get('loss_fn', None),
        optimizer_cls=model_kwargs.get('optimizer_cls', torch.optim.Adam),
        lr=model_kwargs.get('lr', params.get('lr')),
    )


def get_model(params):
    """Build the generative model selected by `params['model_name']`.

    Supported names:
      * 'vae_inception_time' -> `InceptionTimeVAE` built from InceptionTime
        components (the whole `model_kwargs` dict is forwarded to the
        component builder).
      * 'vae_conv5_legit' / 'vae_conv5_1' -> `BetaVAE` built from the
        corresponding Conv5 components, using
        `model_kwargs['dropout_rate']` (default 0.1).

    Args:
        params: configuration dict. 'model_name' is always required;
            'seq_len', 'feature_dim' and 'latent_dim' are required for every
            supported model. Optional keys: 'model_kwargs' (dict, or None
            for "no overrides") and 'lr' (fallback learning rate when
            `model_kwargs` has no 'lr').

    Returns:
        The constructed VAE instance.

    Raises:
        KeyError: if a required key is missing from `params`.
        ValueError: if the model name is unknown.
    """
    model_name = params['model_name']
    # A config may set model_kwargs explicitly to None; treat that like "absent"
    # so the `.get(...)` look-ups below do not fail with AttributeError.
    model_kwargs = params.get('model_kwargs')
    if model_kwargs is None:
        model_kwargs = {}

    # Checked before the generic 'vae' prefix test, which would also match it.
    if model_name == 'vae_inception_time':
        encoder, decoder = get_inception_time_vae_components(
            seq_len=params['seq_len'],
            feat_dim=params['feature_dim'],
            latent_dim=params['latent_dim'],
            model_kwargs=model_kwargs,
        )
        return _build_vae(InceptionTimeVAE, encoder, decoder, params, model_kwargs)

    if model_name.startswith('vae'):
        # Both Conv5 variants share the same builder signature; select one first.
        if model_name == 'vae_conv5_legit':
            build_components = get_conv5_legit_tsgm_vae_components
        elif model_name == 'vae_conv5_1':
            build_components = get_conv5_vae_components
        else:
            raise ValueError(f"Unknown VAE model: {model_name}")

        encoder, decoder = build_components(
            seq_len=params['seq_len'],
            feat_dim=params['feature_dim'],
            latent_dim=params['latent_dim'],
            dropout_rate=model_kwargs.get('dropout_rate', 0.1),
        )
        return _build_vae(BetaVAE, encoder, decoder, params, model_kwargs)

    raise ValueError(f"Model name '{model_name}' is not recognized or not supported yet.")

In [None]:
#| hide
# Regenerate the Python module from this notebook's `#| export` cells
# (target module set by the `#| default_exp` directive at the top).
import nbdev
nbdev.nbdev_export()