# Mixture CDF Flow


## Setup


In [None]:
import copy
import torch

import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim

from models import MixtureCDFFlow
from utils.data import load_1d_flow_data
from utils.training_and_evaluation import train
from utils.plotting import plot_losses, visualize_model_progression

In [None]:
# Seed both NumPy's and PyTorch's global random number generators with the
# same value so that repeated runs of the notebook draw the same numbers.
SEED = 3
for _seed_rng in (np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

# Matplotlib style sheet applied to every figure drawn after this point.
PLOT_STYLE = 'ggplot'
plt.style.use(PLOT_STYLE)

## Load Data


In [None]:
# One tuple per Gaussian used to build the 1-D dataset.
# NOTE(review): presumably (mean, std) pairs -- confirm against
# utils.data.load_1d_flow_data.
GAUSSIAN_PARAMS = [
    (-1, 0.5),
    (0.5, 0.25),
]

# Options handed to the loader helper via `loader_args`.
LOADER_ARGS = {'batch_size': 128, 'shuffle': True}

train_dataloader, test_dataloader = load_1d_flow_data(
    gaussian_params=GAUSSIAN_PARAMS,
    loader_args=LOADER_ARGS,
)

## Create Model

In [None]:
# Number of mixture components in the flow.
NUM_COMPONENTS = 5

# Freshly initialised (untrained) model.
model = MixtureCDFFlow(num_components=NUM_COMPONENTS)

## Training and Evaluation

### Training

In [None]:
# Step size used by the Adam optimizer below.
LEARNING_RATE = 5e-3

# Adam over every parameter exposed by the model.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=LEARNING_RATE,
)

In [None]:
# Number of epochs passed to `train`.
EPOCHS = 30

# Deep-copy the model *before* training so the pre-training state is still
# available afterwards (it is passed as `untrained_model` to the
# visualisation cell further down).
# NOTE(review): re-running this cell after training re-snapshots the already
# trained model -- re-create the model first if a true untrained copy is needed.
model_copy = copy.deepcopy(model)

train_losses, eval_losses = train(
    model,
    train_dataloader,
    test_dataloader,
    optimizer,
    epochs=EPOCHS,
)

In [None]:
# Plot the losses collected during training. The return value is bound to `_`
# so the notebook does not echo it as cell output.
_ = plot_losses(
    train_losses,
    eval_losses,
    'Training Progression',
)

In [None]:

# Visualise the trained model alongside the snapshot taken before training,
# using the training dataloader. The return value is bound to `_` so the
# notebook does not echo it as cell output.
_ = visualize_model_progression(
    train_dataloader,
    trained_model=model,
    untrained_model=model_copy,
)