# Toy Experiments

## 1. Imports

In [None]:
import os
import sys
import json
import math
import warnings

import torch
import numpy as np
from torch.optim import Adam

from tensorboardX import SummaryWriter
from PIL import PngImagePlugin
from matplotlib import pyplot as plt

from tqdm import tqdm
from IPython.display import clear_output

sys.path.append("..")
from src import samplers
from src.enot import integrate, make_net

from src.tools import sde_push, set_random_seed, unfreeze, freeze, fig2tensor

# Raise PIL's PNG text-chunk size limit to 100 MiB (100 * 1024**2 bytes).
# NOTE(review): presumably needed so PNG figures carrying large text/metadata
# chunks are not rejected by PIL -- confirm against the plotting/logging path.
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

%matplotlib inline

## 2. Config

In [None]:
# Global experiment configuration. Every name defined in this cell is a
# notebook-level global that the later cells read directly.

# Seed handed to the project helper `set_random_seed` and embedded in EXP_NAME.
SEED = 0x3060
set_random_seed(SEED)

# Dataset selection. REVERSE=True swaps the source and target samplers.
# Alternative 2-D setup:
# DATASET, REVERSE = 'moons2swissroll2d', False
DATASET, REVERSE = "swissroll3d2mobius", False
# GPU selection.
# NOTE(review): DEVICE_IDS is not referenced by the visible cells; models are
# moved with a plain `.cuda()` -- confirm whether it is still needed.
DEVICE_IDS = [0]
assert torch.cuda.is_available()


# All hyperparameters below are set to the values used for the experiments described in the article.

# Training algorithm settings.
BATCH_SIZE = 512  # 1 for testing
T_ITERS = 10  # number of T updates performed per single D update
# The "+ 1" makes step 2500 part of range(MAX_STEPS), so the final
# LOG_INTERVAL-aligned plot/metrics pass still runs.
MAX_STEPS = 2500 + 1  # 2501 for testing
# Only written to config.json by the visible cells.
EPSILON_SCHEDULER_LAST_ITER = 20000

# Optimizer settings.
D_LR, T_LR = 1e-4, 1e-4
# NOTE(review): BETA_D / BETA_T are not passed to the Adam constructors in the
# visible cells, so the optimizers run with Adam's default betas.
BETA_D, BETA_T = 0.9, 0.9
# Upper bounds given to clip_grad_norm_ for T and D respectively.
T_GRADIENT_MAX_NORM = float(500)
D_GRADIENT_MAX_NORM = float(500)

# SDE network settings.
# EPSILON is passed to SDE(...); the per-step noise std is sqrt(EPSILON).
EPSILON = 0  # [0 , 1, 10]
# NOTE(review): of the settings below only N_STEPS is consumed by the visible
# cells (as the number of SDE steps). The rest are recorded in config.json
# only, and IMAGE_INPUT is not referenced at all.
IMAGE_INPUT = False
PREDICT_SHIFT = True
N_STEPS = 10
UNET_BASE_FACTOR = 128
TIME_DIM = 1
USE_POSITIONAL_ENCODING = True
ONE_STEP_INIT_ITERS = 0
USE_GRADIENT_CHECKPOINT = False
N_LAST_STEPS_WITHOUT_NOISE = 1

# Plot settings: number of points drawn per panel.
PLOT_N_POINTS = 10240

# Log settings: plot and compute metrics every LOG_INTERVAL training steps.
LOG_INTERVAL = 500

# Run identifier and the directory that receives config.json.
EXP_NAME = f"ENOT_Toy_{DATASET}_{SEED}"
OUTPUT_PATH = f"../saved_models/{EXP_NAME}/"

# Create the output directory if needed. `exist_ok=True` replaces the separate
# existence check, so there is no window between "check" and "create" in which
# another process could create the directory and make makedirs fail.
os.makedirs(OUTPUT_PATH, exist_ok=True)

# tensorboardX writer that receives the scalars and figures of this run.
writer = SummaryWriter(f"../logdir/{EXP_NAME}")

In [None]:
# Snapshot of the hyperparameters that identify this run, in the order in
# which they are serialized.
config = {
    "SEED": SEED,
    "DATASET": DATASET,
    "T_ITERS": T_ITERS,
    "D_LR": D_LR,
    "T_LR": T_LR,
    "BATCH_SIZE": BATCH_SIZE,
    "UNET_BASE_FACTOR": UNET_BASE_FACTOR,
    "N_STEPS": N_STEPS,
    "EPSILON": EPSILON,
    "USE_POSITIONAL_ENCODING": USE_POSITIONAL_ENCODING,
    "TIME_DIM": TIME_DIM,
    "ONE_STEP_INIT_ITERS": ONE_STEP_INIT_ITERS,
    "T_GRADIENT_MAX_NORM": T_GRADIENT_MAX_NORM,
    "D_GRADIENT_MAX_NORM": D_GRADIENT_MAX_NORM,
    "PREDICT_SHIFT": PREDICT_SHIFT,
    "USE_GRADIENT_CHECKPOINT": USE_GRADIENT_CHECKPOINT,
    "N_LAST_STEPS_WITHOUT_NOISE": N_LAST_STEPS_WITHOUT_NOISE,
    "EPSILON_SCHEDULER_LAST_ITER": EPSILON_SCHEDULER_LAST_ITER,
}

# Persist the snapshot as pretty-printed JSON inside the run's output directory.
config_path = os.path.join(OUTPUT_PATH, "config.json")
with open(config_path, "w") as json_file:
    json_str = json.dumps(config, indent=4)
    json_file.write(json_str)

## 3. Initial samplers

In [None]:
# Choose the source (X) and target (Y) samplers for the selected toy problem.
# DIM is the dimensionality of the data points and sizes the networks later on.
if DATASET == "swissroll3d2mobius":
    DIM = 3
    X_sampler = samplers.SwissRollSampler(dim=DIM, noise=0.5)
    Y_sampler = samplers.MobiusStripSampler()
elif DATASET == "moons2swissroll2d":
    DIM = 2
    Y_sampler = samplers.SwissRollSampler(dim=DIM, noise=0.5)
    X_sampler = samplers.DoubleMoonSampler()
else:
    # ValueError is still an Exception, so callers catching the old generic
    # Exception keep working. The message now names the offending value.
    raise ValueError(
        f"Unknown DATASET {DATASET!r}; expected 'swissroll3d2mobius' "
        "or 'moons2swissroll2d'"
    )

# Factor applied to the integrated squared shift norm in the training losses.
INTEGRAL_SCALE = 1 / DIM

# Optionally learn the map in the opposite direction.
if REVERSE:
    X_sampler, Y_sampler = Y_sampler, X_sampler

## 4. Training

### models initialization

In [None]:
class SDE(torch.nn.Module):
    """Discretized SDE driven by a learned shift (drift) network.

    Starting from ``x0`` at time 0, ``forward`` applies ``n_steps`` updates of
    the form ``x <- x + delta_t * shift(x, t) + sqrt(epsilon * delta_t) * z``
    with ``z`` standard normal and ``delta_t = 1 / n_steps``, so the rollout
    ends at time 1.

    Args:
        shift_model: Module called on ``cat([x, t], dim=-1)``, i.e. the state
            with the current time appended as one extra trailing feature.
            Its output is added to ``x`` and so must be broadcastable to it.
        epsilon: Non-negative noise intensity; the noise std is
            ``sqrt(epsilon)`` (``math.sqrt`` raises ``ValueError`` if negative).
        n_steps: Number of integration steps.
    """

    def __init__(self, shift_model, epsilon, n_steps):
        super().__init__()
        self.shift_model = shift_model
        self.noise_std = math.sqrt(epsilon)
        # Single place that keeps n_steps and delta_t consistent.
        self.set_n_steps(n_steps)

    def forward(self, x0):
        """Roll the SDE out from ``x0``.

        Args:
            x0: Initial states, one row per sample (2-D: the time column is
                concatenated along the last dimension).

        Returns:
            A tuple ``(trajectory, times, shifts)``:
            ``trajectory`` stacks the ``n_steps + 1`` states along dim 1
            (initial state included), ``times`` is a 1-D CPU tensor with the
            ``n_steps + 1`` time stamps, and ``shifts`` stacks the ``n_steps``
            shift-model outputs along dim 1.
        """
        t0 = 0
        trajectory = [x0]
        times = [t0]
        shifts = []

        x, t = x0, t0

        for _ in range(self.n_steps):
            x, t, shift = self._step(x, t)

            trajectory.append(x)
            times.append(t)
            shifts.append(shift)

        return (
            torch.stack(trajectory, dim=1),
            torch.tensor(times),
            torch.stack(shifts, dim=1),
        )

    def _step(self, x, t):
        """Advance one step; return ``(next_x, next_t, shift_used)``."""
        shift = self._get_shift(x, t)
        noise = self._sample_noise(x)
        return x + self.delta_t * shift + noise, t + self.delta_t, shift

    def _get_shift(self, x, t):
        """Evaluate the shift network on ``x`` with scalar time ``t`` appended."""
        # Build the time column directly on x's device and in x's floating
        # dtype. This avoids a host-side int64/float32 tensor followed by a
        # device copy and dtype promotion, and keeps full precision for
        # float64 inputs.
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        t_column = torch.full(
            (x.shape[0], 1), float(t), dtype=dtype, device=x.device
        )
        inp = torch.cat((x, t_column), dim=-1)
        return self.shift_model(inp)

    def _sample_noise(self, x):
        """Draw the per-step Gaussian noise with std ``noise_std * sqrt(delta_t)``."""
        return (
            self.noise_std
            * math.sqrt(self.delta_t)
            * (torch.randn(x.shape, device=x.device))
        )

    def set_n_steps(self, n_steps):
        """Change the number of steps and update ``delta_t`` to match."""
        self.n_steps = n_steps
        self.delta_t = 1 / n_steps

In [None]:
# Transport map T: the network built by make_net receives the state plus one
# time feature (DIM + 1 inputs) and outputs a DIM-sized shift. It is wrapped
# in the SDE rollout, and the whole module is moved to the GPU in one go.
T = SDE(
    make_net(n_inputs=DIM + 1, n_outputs=DIM, n_layers=3, n_hiddens=100),
    EPSILON,
    N_STEPS,
).cuda()

# Potential / critic D: maps a DIM-sized point to a single scalar.
D = make_net(n_inputs=DIM, n_outputs=1, n_layers=3, n_hiddens=100).cuda()

# One Adam optimizer per network, each with its own learning rate.
# NOTE(review): BETA_T / BETA_D from the config cell are not passed here, so
# Adam's default betas are used -- confirm this is intended.
T_opt = Adam(T.parameters(), lr=T_LR)
D_opt = Adam(D.parameters(), lr=D_LR)

### Plotter

In [None]:
@torch.no_grad()
def map_dataset(SDE, X0: torch.Tensor, batch_size=32):
    """Push a whole dataset through ``SDE`` in mini-batches, without gradients.

    Args:
        SDE: Model forwarded to the project helper ``sde_push``.
        X0: Input samples; batches are taken along dimension 0.
        batch_size: Maximum number of samples pushed per call.

    Returns:
        The per-batch results of ``sde_push(..., return_type="XN")``
        concatenated along dimension 0, in input order, one result batch per
        input batch.
        NOTE(review): "XN" presumably selects the final state of the rollout --
        confirm in ``src.tools.sde_push``.
    """
    # Slicing clamps at the end of the tensor, so the final (possibly shorter)
    # batch is already covered by this loop. It must not be mapped a second
    # time, otherwise the output would contain the tail samples twice.
    mapped_batches = [
        sde_push(SDE, X0[start : start + batch_size], return_type="XN")
        for start in range(0, X0.shape[0], batch_size)
    ]

    return torch.cat(mapped_batches)


def plot_results(source_dataset, target_dataset, mapped_dataset):
    """Draw source, target and mapped point clouds side by side.

    Each dataset is shown in its own panel ("Input", "Target", "Trans") as a
    2-D or 3-D scatter plot, depending on the size of its last dimension.
    Points are colored by their polar angle in the (x, y) plane.

    Args:
        source_dataset: Points to plot; must support ``.shape`` and
            ``.numpy()`` (e.g. a CPU tensor) with last dimension 2 or 3.
        target_dataset: Same requirements as ``source_dataset``.
        mapped_dataset: Same requirements as ``source_dataset``.

    Returns:
        ``(fig, fig.axes)``: the matplotlib figure and its list of axes.

    Raises:
        ValueError: If any dataset's last dimension is neither 2 nor 3.
    """
    datasets = [source_dataset, target_dataset, mapped_dataset]
    titles = ["Input", "Target", "Trans"]

    # Validate before creating the figure, so an unsupported input fails with
    # a clear message instead of an unbound `ax` and a half-built figure.
    for dataset, title in zip(datasets, titles):
        dim = dataset.shape[-1]
        if dim not in (2, 3):
            raise ValueError(
                f"plot_results supports 2-D or 3-D points, but the {title!r} "
                f"dataset has last dimension {dim}"
            )

    fig = plt.figure(figsize=(1.5 * 3, 1.5 * 1), dpi=150)

    for i, (dataset, title) in enumerate(zip(datasets, titles)):
        points = dataset.numpy()  # convert once and reuse for every coordinate
        x, y = points[:, 0], points[:, 1]

        angles = np.arctan2(y, x)
        # Normalize angles to [0, 1].
        normalized_angles = (angles + np.pi) / (2 * np.pi)
        # Smooth periodic color value in [0, 1], so the coloring has no jump
        # where the angle wraps around.
        colors = 0.5 * (1 + np.sin(2 * np.pi * normalized_angles - np.pi / 2))

        if points.shape[-1] == 2:
            ax = fig.add_subplot(1, 3, i + 1)
            coords = (x, y)
        else:
            ax = fig.add_subplot(1, 3, i + 1, projection="3d")
            coords = (x, y, points[:, 2])

        ax.scatter(
            *coords,
            c=colors,  # angle-based color value
            cmap="rainbow",
            s=1,  # point size
            edgecolors="none",  # no point borders
        )
        ax.set_title(title)
        ax.grid()
        ax.set_axis_off()

    fig.tight_layout()

    return fig, fig.axes

### Main training cycle

In [None]:
# Alternating optimization of the transport map T and the potential D:
# each outer step performs T_ITERS updates of T followed by one update of D.
# Every LOG_INTERVAL steps the current map is plotted and logged.
progress_bar = tqdm(total=MAX_STEPS)
# Number of completed steps; used to restore the bar position when the
# progress bar is recreated after the cell output is cleared.
CONTINUE = 0

for step in range(MAX_STEPS):
    # --- T phase ---
    # NOTE(review): freeze/unfreeze come from src.tools; presumably they toggle
    # requires_grad on the module's parameters -- confirm.
    unfreeze(T)
    freeze(D)

    for t_iter in range(T_ITERS):
        T_opt.zero_grad()
        X0, X1 = X_sampler.sample(BATCH_SIZE), Y_sampler.sample(BATCH_SIZE)

        trajectory, times, shifts = T(X0)
        XN = trajectory[:, -1, :]  # last state of the rollout
        # Squared L2 norm of each shift, taken over the last dimension.
        norm = torch.norm(shifts, p=2, dim=-1) ** 2
        # NOTE(review): `integrate` is a project helper; presumably it
        # integrates `norm` over `times` per sample -- confirm in src.enot.
        integral = INTEGRAL_SCALE * integrate(norm, times)
        T_loss = (integral + D(X1) - D(XN)).mean()
        # NOTE(review): logged with the outer `step`, so each step receives
        # T_ITERS values under the same global step.
        writer.add_scalar("T_loss", T_loss.item(), step)
        T_loss.backward()
        # Clip gradients in place; the returned norm is not used afterwards
        # in this cell.
        T_gradient_norm = torch.nn.utils.clip_grad_norm_(
            T.parameters(), max_norm=T_GRADIENT_MAX_NORM
        )
        T_opt.step()

    # --- D phase ---
    freeze(T)
    unfreeze(D)

    D_opt.zero_grad()
    X0, X1 = X_sampler.sample(BATCH_SIZE), Y_sampler.sample(BATCH_SIZE)
    trajectory, times, shifts = T(X0)
    XN = trajectory[:, -1, :]
    # flatten(start_dim=2) merges all dimensions from index 2 onward before
    # the norm; for a 3-D `shifts` this matches the computation in the T phase.
    norm = torch.norm(shifts.flatten(start_dim=2), p=2, dim=-1) ** 2
    integral = INTEGRAL_SCALE * integrate(norm, times)

    # Exact negation of the T objective. The integral term is computed from
    # T's outputs only and does not involve D.
    D_loss = (-integral - D(X1) + D(XN)).mean()
    writer.add_scalar("D_loss", D_loss.item(), step)
    D_loss.backward()
    # Clip gradients in place; the returned norm is not used afterwards in
    # this cell.
    D_gradient_norm = torch.nn.utils.clip_grad_norm_(
        D.parameters(), max_norm=D_GRADIENT_MAX_NORM
    )
    D_opt.step()

    CONTINUE = step + 1
    progress_bar.update(1)

    # --- periodic visualization and metrics ---
    if step % LOG_INTERVAL == 0:
        # clear_output wipes the cell output (including the bar), so the bar
        # is closed and recreated at the current position.
        progress_bar.close()
        clear_output(wait=True)
        progress_bar = tqdm(total=MAX_STEPS, initial=CONTINUE)
        print("Plotting")

        # Draw one batch more than needed, then cut to exactly PLOT_N_POINTS.
        original_dataset = torch.cat(
            [
                X_sampler.sample(BATCH_SIZE)
                for i in range(PLOT_N_POINTS // BATCH_SIZE + 1)
            ],
            dim=0,
        )[:PLOT_N_POINTS]

        # Map the source samples with the current T (map_dataset runs under
        # no_grad) and move everything to the CPU for plotting.
        transfered_dataset = map_dataset(T, original_dataset, BATCH_SIZE).cpu()
        original_dataset = original_dataset.cpu()
        target_dataset = torch.cat(
            [
                Y_sampler.sample(BATCH_SIZE)
                for i in range(PLOT_N_POINTS // BATCH_SIZE + 1)
            ],
            dim=0,
        )[:PLOT_N_POINTS].cpu()

        fig, axes = plot_results(original_dataset, target_dataset, transfered_dataset)
        writer.add_image("trans", fig2tensor(fig), step)
        plt.show()
        plt.close()
        print("Computing L1 MSE")
        # NOTE(review): target samples and mapped samples are drawn
        # independently, so these are elementwise errors between unpaired
        # points -- confirm that this is the intended metric.
        l1 = torch.nn.functional.l1_loss(target_dataset, transfered_dataset)
        mse = torch.nn.functional.mse_loss(target_dataset, transfered_dataset)
        writer.add_scalar("Metrics/L1", l1, step)
        writer.add_scalar("Metrics/MSE", mse, step)