In [None]:
import numpy as np
import pandas as pd
from dataclasses import dataclass, asdict
import itertools
import string
from PIL import Image, ImageShow
import matplotlib.pyplot as plt
from plotnine import *
from plotnine_prism import *
from pathlib import Path
import io

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import lightning.pytorch as pl
import einops
from pytorch_lightning.loggers import WandbLogger
import wandb
import git

from gfn_attractors.data.mixture_of_gaussians import GridOfGaussiansDataModule
from gfn_attractors.models.helpers import MLP, PositionalEncoding, SafeEmbedding
from gfn_attractors.models.dynamics import MLPMeanBoundedDynamics
from gfn_attractors.models.discretizer import DiscretizeModule, Discretizer, MLPDiscretizer
from gfn_attractors.models.attractor import RNNAttractorModel, RecurrentMLPAttractorModel
from gfn_attractors.misc import Config
from gfn_attractors.misc import torch_utils as tu, image_utils as iu

In [None]:
# Run identity, CUDA device index, and RNG seed used throughout the notebook.
run_name, device, seed = 'gaussians_gfn', 0, 0

# Output folders live under the root of the enclosing git repository.
repo = git.Repo(search_parent_directories=True)
repo_root = Path(repo.git_dir).parent.expanduser()
checkpoint_dir = repo_root.joinpath('checkpoints', 'infocog')
figures_dir = repo_root.joinpath('figures', 'infocog')
for output_dir in (checkpoint_dir, figures_dir):
    output_dir.mkdir(parents=True, exist_ok=True)

# Display configuration: longer DataFrame previews and inline PIL image viewing.
pd.set_option('display.max_rows', 100)
ImageShow.register(ImageShow.IPythonViewer(), 0)

# Allocate a throwaway scalar on the target device so an unavailable device fails here.
torch.tensor(0., device=device)

# Seed both RNGs used by the notebook.
torch.manual_seed(seed)
np.random.seed(seed)

# Initialize dataset

In [None]:
# Grid-of-Gaussians dataset; labels are returned alongside the samples.
data_module = GridOfGaussiansDataModule(
    batch_size=50,
    n_val_samples=500,
    components_per_dim=4,
    return_labels=True,
)
# Called manually (instead of via a Trainer) so train/val data can be inspected below.
data_module.setup(None)

In [None]:
# One row per training sample: [mixture id, x, y]. The id is stringified so that
# plotnine treats it as a discrete colour rather than a continuous scale.
rows = torch.cat(
    [data_module.train_data.mixture_id.unsqueeze(-1), data_module.train_data.data],
    dim=-1,
).numpy()
df = pd.DataFrame(rows, columns=['index', 'x', 'y'])
df['index'] = df['index'].astype(int).astype(str)

# Component means, each labelled by the rank of its x and y coordinate ("col,row").
df_means = pd.DataFrame(data_module.train_data.means.numpy(), columns=['x', 'y'])
x_labels = {coord: rank for rank, coord in enumerate(sorted(set(df_means.x)))}
y_labels = {coord: rank for rank, coord in enumerate(sorted(set(df_means.y)))}
df_means['x_label'] = df_means.x.map(x_labels)
df_means['y_label'] = df_means.y.map(y_labels)
df_means['label'] = [
    f"{int(x_rank)},{int(y_rank)}"
    for x_rank, y_rank in zip(df_means.x_label, df_means.y_label)
]

# Strip every axis decoration so only the points and the labels remain.
bare_theme = theme(legend_position='none',
                   axis_line=element_blank(),
                   axis_ticks=element_blank(),
                   axis_text=element_blank(),
                   panel_grid=element_blank(),
                   panel_background=element_blank(),
                   plot_margin=0.)

p = (
    ggplot(df, aes(x='x', y='y'))
    + geom_point(aes(color='index'), alpha=.3, size=5)  # samples, coloured by component
    + geom_point(data=df_means, color='black', size=5)  # component means
    + geom_text(data=df_means, mapping=aes(label='label', y='y+.1'), size=24, fontweight='bold')
    + labs(x='', y='')
    + theme_bw()
    + bare_theme
)

p.save(figures_dir / 'gaussians_data.png', dpi=300, width=8, height=8)


# Define Model

In [None]:
@dataclass
class ModelConfig(Config):
    """
    Hyperparameters consumed by the `Model` LightningModule defined in this notebook.

    Fields are grouped by the component that reads them. A few fields are declared but
    not referenced by `Model` in this notebook; they are flagged below.
    """
    score_weight: float = 4.0  # Multiplier on the score term of the log-reward (Model.get_log_reward)
    use_true_attractors: bool = False  # Use the true attractors (data-module means) instead of learning them

    # Dynamics
    max_steps: int = 30  # default trajectory length when num_steps is not given
    max_mean: float = 0.05  # forwarded to MLPMeanBoundedDynamics(max_mean=...)
    attractor_sd: float = 0.04  # std of the Normal centred on the attractor in the score / reward
    t_dependent_forward: bool = True  # the four flags below are forwarded unchanged to the dynamics model
    t_dependent_backward: bool = True
    x_dependent_forward: bool = True
    x_dependent_backward: bool = True

    # Discretizer
    vocab_size: int = 4  # passed to both the discretizer and the attractor model
    w_length: int = 2  # passed to the discretizer as `length`

    # Architectures
    dim_h: int = 256  # hidden width of the dynamics, attractor, g and f_z networks
    dim_t: int = 10  # temporal encoding for the g_model (not dynamics)
    dim_h_discretizer: int = 256
    num_dynamics_layers: int = 3
    num_discretizer_layers: int = 2
    nhead: int = 4 # if using transformer discretizer (not referenced by Model in this notebook)
    dim_feedforward: int = 256 # if using transformer discretizer (not referenced by Model in this notebook)
    
    # Training
    lr: float = 1e-3    # GFN learning rate (dynamics, attractor and g networks in configure_optimizers)
    discretizer_lr: float = 1e-3  # learning rate for the discretizer and f_z networks
    # NOTE(review): lookup_lr is not referenced by Model in this notebook; the attractor
    # model is currently optimised with `lr`. Confirm which is intended.
    lookup_lr: float = 1e-4  # Attractor learning rate
    p_explore: float = 0.1  # Applies for both the dynamics and discretizer models

    # EM switch (see Model.check_and_exit_e_step / Model.training_step)
    num_e_steps: int = 1  # minimum number of E-step updates before the E-step may end
    num_m_steps: int = 1  # M-step updates per M phase; <= 0 means the E-step never ends
    # NOTE(review): annotated as int but the default is the float np.inf.
    max_e_steps: int = np.inf  # once this many E-step updates are reached, the loss checks are skipped
    start_e_steps: int = 0  # stay in the E-step while global_step <= start_e_steps
    e_loss_improvement_threshold: float = 0.0 # if the loss improves by at least this much, keep updating E-step
    # NOTE(review): annotated as float but defaults to None (None disables the loss-goal check).
    e_step_loss_rate: float = None # increase loss threshold for e-step by this much every e_step
    e_step_loss_window: int = 1  # number of most recent E-step losses kept for averaging

In [None]:
class Model(pl.LightningModule):
    """
    Simplest possible model for Gaussians task.
    z0 is the input.
    Reward is negative distance between attractor and input, so no score model is used.

    Training alternates between two phases (see `training_step`):
      * E-step: optimise the dynamics, discretizer and flow networks with the GFN losses
        (`get_gfn_loss`).
      * M-step: optimise the attractor model so decoded attractors explain the inputs
        (`get_attractor_loss`).
    """

    def __init__(self, config: ModelConfig, data_module: GridOfGaussiansDataModule):
        """
        config: hyperparameters for all sub-networks and the E/M schedule
        data_module: kept for the true attractor means and for validation plots
        """
        super().__init__()
        self.config = config
        self.data_module = data_module

        self.temporal_encoding = PositionalEncoding(config.dim_t) if config.dim_t is not None else None
        self.dynamics_model = MLPMeanBoundedDynamics(dim_x=2, dim_z=2, 
                                                     dim_h=config.dim_h,
                                                     num_layers=config.num_dynamics_layers,
                                                     max_mean=config.max_mean,
                                                     max_sd=1,
                                                     t_dependent_forward=config.t_dependent_forward,
                                                     t_dependent_backward=config.t_dependent_backward,
                                                     x_dependent_forward=config.x_dependent_forward,
                                                     x_dependent_backward=config.x_dependent_backward)
        
        self.discretizer = MLPDiscretizer(vocab_size=config.vocab_size, 
                                          length=config.w_length,
                                          dim_input=2,
                                          dim_h=config.dim_h_discretizer,
                                          num_layers=config.num_discretizer_layers)
        self.attractor_model = RecurrentMLPAttractorModel(vocab_size=config.vocab_size, dim_z=2, dim_h=config.dim_h, 
                                                          residual=False, num_layers=1)
        
        # Both flow networks take [z0, z_t] (2 + 2 features). g additionally receives the
        # temporal encoding; with dim_t=None no encoding is applied, so it adds no features.
        dim_t = config.dim_t if config.dim_t is not None else 0
        self.g_model = MLP(4 + dim_t, 1, hidden_dim=config.dim_h, n_layers=2, nonlinearity=nn.ReLU()) # Flow correction model
        self.f_z_model = MLP(4, 1, hidden_dim=config.dim_h, n_layers=2, nonlinearity=nn.ReLU()) # Flow for discretizer

        # E/M schedule state.
        self.e_step = True                  # True while in the E-step
        self.num_mode_updates = 0           # updates performed in the current phase
        self.num_m_steps = 0                # total M-step updates so far (logged)
        self.e_step_loss_goal = np.inf      # loss goal used when e_step_loss_rate is set
        self.e_step_losses = np.zeros(config.e_step_loss_window)  # most recent E-step losses, newest first
        
    def configure_optimizers(self):
        """
        Single Adam optimizer with two parameter groups:
        dynamics + attractor + g networks at `config.lr`, and
        discretizer + f_z networks at `config.discretizer_lr`.
        """
        # NOTE(review): config.lookup_lr is documented as the attractor learning rate but is
        # not used here; the attractor model trains at config.lr. Confirm this is intended.
        return torch.optim.Adam([{'params': [*self.dynamics_model.parameters(),
                                             *self.attractor_model.parameters(),
                                             *self.g_model.parameters()],
                                  'lr': self.config.lr},
                                 {'params': [*self.discretizer.parameters(), 
                                             *self.f_z_model.parameters()],
                                  'lr': self.config.discretizer_lr}])
        
    def get_score(self, z0, z_hat):
        """
        Log-density of z0 under an isotropic Normal centred at z_hat with std `config.attractor_sd`.

        z0: (..., dim_z)
        z_hat: (..., dim_z)
        returns: (...)
        """
        return Normal(z_hat, self.config.attractor_sd).log_prob(z0).sum(-1)

    def get_log_reward(self, z_traj, z_hat, z0=None):
        """
        log R = score_weight * log N(z0; z_hat, sd) + log N(z_t; z_hat, sd), for every step t.

        z_traj: (batch_size, num_steps, dim_z)
        z_hat: (batch_size, num_steps, dim_z)
        z0: (batch_size, dim_z); defaults to the first state of z_traj
        returns: (log_reward with shape (batch_size, num_steps), dict of scalar metrics)
        """
        if z0 is None:
            z0 = z_traj[:,0]
        z0 = einops.repeat(z0, 'n z -> n t z', t=z_traj.shape[1])
        score = self.get_score(z0, z_hat)
        sd_zhat = self.config.attractor_sd
        logp_zhat = Normal(z_hat, sd_zhat).log_prob(z_traj).sum(-1)
        log_reward = self.config.score_weight * score + logp_zhat
        metrics = {
            'reward/score': score.mean().item(),
            'reward/logp_zhat': logp_zhat.mean().item(),
            'reward/total': log_reward.mean().item()
        }
        return log_reward, metrics

    def sample_forward_trajectory(self, x, num_steps=None, p_explore=0.):
        """
        Roll the dynamics forward starting at z0 = x, conditioned on x.

        x: (batch_size, 2)
        num_steps: defaults to `config.max_steps`
        returns: (batch_size, num_steps, dim_z)
        """
        if num_steps is None:
            num_steps = self.config.max_steps
        return self.dynamics_model.sample_trajectory(x, x, num_steps=num_steps, forward=True, p_explore=p_explore, explore_mean=True)
    
    def sample_backward_trajectory(self, z, x, num_steps=None, p_explore=0.):
        """
        Roll the dynamics backward starting at z, conditioned on x.

        z: (batch_size, 2)
        x: (batch_size, 2)
        num_steps: defaults to `config.max_steps`
        returns: (batch_size, num_steps, dim_z)
        """
        if num_steps is None:
            num_steps = self.config.max_steps
        return self.dynamics_model.sample_trajectory(z, x, num_steps=num_steps, forward=False, p_explore=p_explore)

    def sample_w(self, z, argmax=False, p_explore=0.):
        """
        Sample a discrete code w for every state in z.

        z: tensor with shape (batch_size, dim_z) or (batch_size, num_steps, dim_z)
        returns:
            if z has shape (batch_size, dim_z):
                w: (batch_size, max_w_length)
                logpw: (batch_size,)
            if z has shape (batch_size, num_steps, dim_z):
                w: (batch_size, num_steps, max_w_length)
                logpw: (batch_size, num_steps)
        """
        if z.ndim == 2:
            return self.discretizer.sample(z, argmax=argmax, p_explore=p_explore)

        # Fold the step axis into the batch axis for the discretizer, then unfold again.
        batch_size, num_steps = z.shape[:2]
        z = z.flatten(0, 1)
        w, logpw = self.discretizer.sample(z, argmax=argmax, p_explore=p_explore)
        w = w.view(batch_size, num_steps, -1)
        logpw = logpw.view(batch_size, num_steps)
        return w, logpw

    def get_z_hat(self, w):
        """
        Decode codes into attractor locations.

        w: (..., max_w_length)
        returns: (..., dim_z)
        """
        if self.config.use_true_attractors:
            shape = w.shape[:-1]
            w = w.view(-1, w.shape[-1])
            # Drop the last token and shift by 2 to recover grid indices; presumably the
            # first two token ids are reserved and the last position is a terminator —
            # TODO confirm against the discretizer. The factor 4 is the grid width
            # (components_per_dim of the data module used in this notebook).
            index = w[:,:-1] - 2
            z_hat = self.data_module.train_data.means.to(w.device)[index[:,0] * 4 + index[:,1]]
            return z_hat.view(*shape, -1)
        return self.attractor_model(w)
    
    def get_discretizer_loss(self, z_traj, logpw, log_reward):
        """
        Squared residual of logF(z0, z_t) + log p(w | z_t) - log R, averaged over batch and steps.

        z_traj: (batch_size, num_steps, dim_z)
        logpw: (batch_size, num_steps)
        log_reward: (batch_size, num_steps)
        returns: (scalar loss, dict of scalar metrics)
        """
        z0 = einops.repeat(z_traj[:,0], 'n z -> n t z', t=z_traj.shape[1])
        f = self.f_z_model(torch.cat([z0, z_traj], dim=-1))
        loss = f + logpw - log_reward
        loss = loss.pow(2).mean()
        metrics = {
            'discretizer/loss': loss.item(),
            'discretizer/logF': f.mean().item(),
            'discretizer/logpw': logpw.mean().item(),
            'discretizer/log_reward': log_reward.mean().item(),
        }
        return loss, metrics
    
    def get_dynamics_loss(self, z_traj, logpf, logpb, logpw, log_reward):
        """
        Squared per-transition balance residual for the dynamics, averaged over batch and steps.

        z_traj: (batch_size, num_steps, dim_z)
        logpf, logpb: forward / backward transition log-probs, one per transition
        logpw: (batch_size, num_steps)
        log_reward: (batch_size, num_steps)
        returns: (scalar loss, dict of scalar metrics)
        """
        z0 = einops.repeat(z_traj[:,0], 'n z -> n t z', t=z_traj.shape[1] - 1)
        z = torch.cat([z0, z_traj[:,:-1]], dim=-1)
        if self.temporal_encoding is not None:
            # presumably appends dim_t time features to each state — TODO confirm
            z = self.temporal_encoding(z)
        g = self.g_model(z)
        # Append one zero along the last axis so g has an entry for the final step.
        # assumes the MLP output's last axis is the step axis here — TODO confirm
        g = F.pad(g, (0, 1))

        loss = g[:,1:] - g[:,:-1]
        loss += -logpw[:,1:] + logpw[:,:-1]
        loss += log_reward[:,1:] - log_reward[:,:-1]
        loss += logpb - logpf
        loss = loss.pow(2).mean()

        metrics = {
            'dynamics/loss': loss.item(),
            'dynamics/logpf': logpf.mean().item(),
            'dynamics/logpb': logpb.mean().item(),
            'dynamics/logpw': logpw.mean().item(),
            'dynamics/log_reward': log_reward.mean().item(),
            'dynamics/g': g.mean().item(),
        }
        return loss, metrics

    def get_gfn_loss(self, x):
        """
        E-step objective: dynamics loss + discretizer loss on a freshly sampled forward trajectory.

        x: (batch_size, 2)
        returns: (scalar loss, dict of scalar metrics)
        """
        z_traj = self.sample_forward_trajectory(x, p_explore=self.config.p_explore)
        # sample_w has no `mode` argument and ModelConfig has no `w_traj_mode` field, so
        # a code is sampled for every step of the trajectory.
        w, logpw = self.sample_w(z_traj, p_explore=self.config.p_explore)
        with torch.no_grad():
            z_hat = self.get_z_hat(w)

        logpf_traj, logpb_traj, mu_f, sd_f, mu_b, sd_b = self.dynamics_model.log_prob(z_traj, x, return_params=True)  
        with torch.no_grad():
            log_reward, metrics = self.get_log_reward(z_traj, z_hat)
        # The dynamics loss sees a detached logpw so it does not push gradients into the discretizer.
        dynamics_loss, dynamics_metrics = self.get_dynamics_loss(z_traj, logpf_traj, logpb_traj, logpw.detach(), log_reward)
        discretizer_loss, discretizer_metrics = self.get_discretizer_loss(z_traj, logpw, log_reward)

        loss = dynamics_loss + discretizer_loss
        metrics.update(dynamics_metrics)
        metrics.update(discretizer_metrics)

        metrics['dynamics/mu_f'] = mu_f.norm(-1).mean().item()
        metrics['dynamics/sd_f'] = sd_f.norm(-1).mean().item()
        metrics['dynamics/mu_b'] = mu_b.norm(-1).mean().item()
        metrics['dynamics/sd_b'] = sd_b.norm(-1).mean().item()
        return loss, metrics

    def get_attractor_loss(self, x):
        """
        M-step objective: negative score of x under the attractor decoded from the argmax code
        of the final trajectory state, plus a small L2 penalty on all decodable attractors.

        x: (batch_size, 2)
        returns: (scalar loss, dict of scalar metrics)
        """
        # Trajectory and code are sampled without gradients; only get_z_hat is trained here.
        with torch.no_grad():
            z_traj = self.sample_forward_trajectory(x)
            w, _ = self.sample_w(z_traj[:,-1], argmax=True)
        z_hat = self.get_z_hat(w)
        p_x_zhat = -self.get_score(x, z_hat).mean()

        # regularization
        w = self.discretizer.get_all_w(device=x.device)
        z_hat = self.get_z_hat(w)
        zhat_reg = (z_hat**2).mean()
        loss = p_x_zhat + 1e-4 * zhat_reg
        
        return loss, {'attractor_loss': loss.item(),
                      'p_x_zhat': p_x_zhat.item(),
                      'zhat_reg': zhat_reg.item()}
    
    def check_and_exit_e_step(self, loss, record_loss: bool):
        """
        Returns whether or not the model should exit E-step.

        Counts the update, optionally records `loss` in the rolling window (newest at index 0),
        then applies the checks in order. On exit the E-step state is reset and `e_step`
        is set to False.

        loss: scalar tensor for the current E-step update
        record_loss: whether to push `loss` into the rolling window
        """
        self.num_mode_updates += 1
        if record_loss:
            self.e_step_losses = np.roll(self.e_step_losses, 1)
            self.e_step_losses[0] = loss.item()
        avg_loss = self.e_step_losses.mean()

        if self.config.num_m_steps <= 0:
            return False
        elif self.global_step <= self.config.start_e_steps:
            return False
        elif self.num_mode_updates < self.config.num_e_steps:
            return False
        elif self.num_mode_updates >= self.config.max_e_steps:
            # Hard cap reached: skip the loss-based checks and fall through to exit.
            pass
        # elif self.config.e_step_max_loss is not None and avg_loss > self.config.e_step_max_loss:
        #     return False
        elif self.config.e_step_loss_rate is not None and avg_loss > self.e_step_loss_goal:
            self.e_step_loss_goal = self.e_step_loss_goal * self.config.e_step_loss_rate
            return False
        elif self.config.e_loss_improvement_threshold > 0:
            # The first half of the window holds the newer losses, the second half the older ones.
            # NOTE(review): the ratio is old/new, so any improvement keeps the E-step running;
            # confirm this matches the intent of e_loss_improvement_threshold.
            avg_loss_1 = self.e_step_losses[:len(self.e_step_losses)//2].mean()
            avg_loss_2 = self.e_step_losses[len(self.e_step_losses)//2:].mean()
            if 1 - (avg_loss_2 / avg_loss_1) < self.config.e_loss_improvement_threshold:
                return False

        self.num_mode_updates = 0
        self.e_step = False
        self.e_step_loss_goal = avg_loss.item()
        self.e_step_losses = np.zeros(self.config.e_step_loss_window)
        return True

    def training_step(self, batch, batch_idx):
        """
        Run one E-step or M-step update depending on the current phase and log its metrics.

        batch: dict with key 'data' holding the inputs
        returns: the scalar loss to optimise
        """
        x = batch['data']

        if self.e_step or self.config.use_true_attractors: # E-step
            loss, metrics = self.get_gfn_loss(x)
            self.check_and_exit_e_step(loss, True)
            metrics['training/e_step_loss'] = loss.item()
            
        else: # M-step
            loss, metrics = self.get_attractor_loss(x)
            self.num_mode_updates += 1
            self.num_m_steps += 1
            if self.num_mode_updates >= self.config.num_m_steps:
                self.num_mode_updates = 0
                self.e_step = True
            metrics['training/m_step_loss'] = loss.item()

        metrics['training/num_m_steps'] = self.num_m_steps
        metrics['training/e_step_loss_goal'] = self.e_step_loss_goal
        self.log_metrics(**metrics)

        # debugging
        self._last_x = x
        self._last_metrics = metrics
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """
        On the first validation batch (W&B logger only), log a GIF of forward trajectories
        for the whole validation set. `batch` itself is not used.
        """
        if isinstance(self.logger, WandbLogger) and batch_idx == 0:
            # Use the module's own data module rather than a notebook-level global.
            x = self.data_module.val_data.data.to(self.device)
            z_traj = self.sample_forward_trajectory(x)
            self.log_gif('forward', self.plot_steps(z_traj))

    @torch.no_grad()
    def plot_steps(self, z_traj):
        """
        Render one scatter plot per trajectory step, coloured by the sampled code, with the
        decoded attractor of every possible code overlaid as text.

        z_traj: (batch_size, num_steps, dim_z)
        returns: list of images, one per step
        """
        w_attractors = self.discretizer.get_all_w(device=z_traj.device)
        z_attractors = self.get_z_hat(w_attractors)
        batch_size, num_steps = z_traj.shape[:2]
        index = pd.MultiIndex.from_product([range(batch_size), range(num_steps)], names=['xid', 'step'])
        df_traj = pd.DataFrame(z_traj.cpu().numpy().reshape(-1, z_traj.shape[-1]), 
                            index=index, columns=[f'z{i}' for i in range(z_traj.shape[-1])]).reset_index()

        # Gradients are already disabled by the decorator.
        w, _ = self.sample_w(z_traj)
        w = tu.to_strings(w.flatten(0, 1), min_value=2, chars=string.ascii_letters)
        df_traj['w'] = w

        df_attractors = pd.DataFrame(z_attractors.cpu().numpy(), 
                                    columns=[f'z{i}' for i in range(z_attractors.shape[-1])])
        df_attractors['w'] = tu.to_strings(w_attractors, min_value=2, chars=string.ascii_letters)

        # Fix one colour per code so colours stay stable across frames.
        w_unique = df_traj.w.unique()
        colors = {w: c for w, c in zip(w_unique, prism_color_pal('colors')(len(w_unique)))}

        def plot_traj_step(step):
            p = (ggplot(df_traj[df_traj.step == step], aes(x='z0', y='z1'))
            + geom_point(aes(color='w'), alpha=.4, size=1.5)
            + geom_text(aes(label='w'), data=df_attractors, size=12, fontweight='bold')
            + scale_color_manual(values=colors)
            + coord_cartesian(xlim=(-1.5, 1.5), ylim=(-1.5, 1.5))
            + labs(title=f"Step {step}")
            + theme_bw()
            + theme(legend_position='none')
            )
            return p

        return [iu.plot_to_image(plot_traj_step(step)) for step in range(num_steps)]

    def log_metrics(self, **kwargs):
        """Log the given scalars under their own keys and show them in the progress bar."""
        self.log_dict(kwargs, prog_bar=True)

    def log_figure(self, key, p):
        """
        Render plot `p` to an in-memory PNG and log it as an image through the logger.

        key: name under which the image is logged
        p: plot object exposing `save(buffer, format=..., verbose=...)`
        """
        img_buf = io.BytesIO()
        p.save(img_buf, format='png', verbose = False)
        # Copy so the image stays valid after the buffer is closed.
        im = Image.open(img_buf).copy()
        plt.close()
        img_buf.close()
        self.logger.log_image(key=key, images=[im], caption=[f'{key} ({self.global_step})'], commit=True)

    def log_gif(self, key, images):
        """
        Log a sequence of images to W&B as a 1 fps GIF.

        key: name under which the video is logged
        images: frames convertible to (H, W, C) arrays
        """
        # Frames are stacked and moved to channel-first layout: (T, C, H, W).
        imarrays = np.array([np.transpose(np.array(im), (2, 0, 1)) for im in images])
        wandb.log({key: wandb.Video(imarrays, fps=1, format='gif', caption=f"{key} ({self.global_step})")}, commit=True)

In [None]:
# Configuration for this run. Only fields declared on ModelConfig are passed; the
# former `brownian_bridge`, `max_sd`, `fixed_sd` and `w_traj_mode` arguments are not
# fields of the dataclass and made the constructor raise a TypeError.
config = ModelConfig(score_weight=2,
                     use_true_attractors=False,
                     dim_h=256,
                     dim_h_discretizer=256,
                     dim_t=10,
                     max_steps=20,
                     t_dependent_forward=False,
                     max_mean=.05,
                     num_dynamics_layers=3,
                     num_discretizer_layers=2,
                     lr=1e-4,
                     num_e_steps=50,
                     num_m_steps=1,
                    #  e_step_loss_rate=1.0001,
                     e_step_loss_window=50,
                     e_loss_improvement_threshold=0.005,
                     start_e_steps=1000)
model = Model(config, data_module)

# Smoke test on CPU: run both losses once, without gradients, so shape or argument
# errors surface before a long training run starts.
x = data_module.train_data.data[:250].cpu()
with torch.no_grad():
    loss, metrics = model.get_attractor_loss(x)
    loss, metrics = model.get_gfn_loss(x)

# Training

In [None]:
# config.lr = 1e-5
# config.discretizer_lr = 1e-4
# config.lookup_lr = 1e-4
# num_e_steps = 500

In [None]:
# Training schedule.
max_epochs = 5000
check_val_every_n_epoch = 50

# Close any W&B run left open by a previous execution of this cell before starting a new one.
wandb.finish()
logger = WandbLogger(
    project='gaussians_gfn',
    name=run_name,
    entity='andrewnam',
    config=asdict(config),
)

trainer = pl.Trainer(
    devices=[device],
    max_epochs=max_epochs,
    check_val_every_n_epoch=check_val_every_n_epoch,
    log_every_n_steps=31,
    gradient_clip_val=1.0,
    logger=logger,
    enable_progress_bar=True,
)
trainer.fit(model, train_dataloaders=data_module)

In [None]:
# Persist the trained weights; the model is moved to CPU first (in place).
checkpoint_path = checkpoint_dir / 'gfn.pt'
torch.save(model.cpu().state_dict(), checkpoint_path)