# Toy Experiments

## 1. Imports

In [None]:
import os
import sys
import logging

sys.path.append("..")

import math

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import Adam
from tqdm import trange
from IPython.display import clear_output

from src import distributions

# os.environ["CUDA_VISIBLE_DEVICES"] = ""
# torch.set_num_threads(2)

## 2. Function definitions

### network generators

In [None]:
def make_net(n_inputs, n_outputs, n_layers=3, n_hiddens=100):
    """Build a ReLU multilayer perceptron.

    The network has ``max(n_layers, 1)`` hidden ``Linear -> ReLU`` stages of
    width ``n_hiddens``, followed by a final linear projection to
    ``n_outputs`` features (no output activation).

    Args:
        n_inputs: Number of input features.
        n_outputs: Number of output features.
        n_layers: Number of hidden ``Linear -> ReLU`` stages; values below 1
            still produce one hidden stage.
        n_hiddens: Width of every hidden layer.

    Returns:
        An ``nn.Sequential`` module.
    """
    # Input width of each hidden Linear layer followed by its output width.
    widths = [n_inputs] + [n_hiddens] * max(n_layers, 1)

    modules = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    modules.append(nn.Linear(n_hiddens, n_outputs))

    return nn.Sequential(*modules)


class SDE(nn.Module):
    """Euler-Maruyama simulator of a neural SDE on the time interval [0, 1].

    The drift is ``shift_model`` evaluated on the state concatenated with the
    current time; the diffusion is constant:

        x_{k+1} = x_k + dt * shift_model([x_k, t_k]) + sqrt(epsilon * dt) * z_k

    with ``z_k ~ N(0, I)`` and ``dt = 1 / n_steps``.

    Args:
        shift_model: Module mapping ``[x, t]`` (state features plus one time
            feature) to a drift with the same width as ``x``.
        epsilon: Diffusion variance per unit time (``math.sqrt`` rejects
            negative values).
        n_steps: Number of Euler-Maruyama steps; must be >= 1.

    Raises:
        ValueError: If ``n_steps < 1`` or ``epsilon < 0``.
    """

    def __init__(self, shift_model, epsilon, n_steps):
        super().__init__()
        self.shift_model = shift_model
        self.noise_std = math.sqrt(epsilon)
        self.set_n_steps(n_steps)

    def forward(self, x0):
        """Simulate one trajectory per row of ``x0``.

        Args:
            x0: Initial states, ``(batch, dim)``.

        Returns:
            ``(trajectory, times, shifts)``:
            ``trajectory`` is ``(batch, n_steps + 1, dim)`` and includes ``x0``;
            ``times`` is ``(n_steps + 1,)``;
            ``shifts`` holds the drift used at each step, ``(batch, n_steps, dim)``.
        """
        t = 0.0
        x = x0
        trajectory, times, shifts = [x0], [t], []

        for _ in range(self.n_steps):
            x, t, shift = self._step(x, t)
            trajectory.append(x)
            times.append(t)
            shifts.append(shift)

        return (
            torch.stack(trajectory, dim=1),
            torch.tensor(times, device=x0.device),
            torch.stack(shifts, dim=1),
        )

    def _step(self, x, t):
        """Advance one Euler-Maruyama step; return ``(x_next, t_next, shift)``."""
        shift = self._get_shift(x, t)
        noise = self._sample_noise(x)
        return x + self.delta_t * shift + noise, t + self.delta_t, shift

    def _get_shift(self, x, t):
        """Evaluate the drift network on ``[x, t]`` for every row of ``x``."""
        # Build the time column directly with x's dtype and device, instead of
        # creating a CPU tensor and copying it over on every step.
        t_col = torch.full(
            (x.shape[0], 1), float(t), dtype=x.dtype, device=x.device
        )
        return self.shift_model(torch.cat((x, t_col), dim=-1))

    def _sample_noise(self, x):
        """Draw the Brownian increment ``sqrt(epsilon * dt) * N(0, I)`` shaped like ``x``."""
        # randn_like samples on x's device with x's dtype. This avoids a
        # CPU-side draw followed by a host-to-device copy each step.
        return self.noise_std * math.sqrt(self.delta_t) * torch.randn_like(x)

    def set_n_steps(self, n_steps):
        """Change the number of integration steps (and hence the step size ``1 / n_steps``)."""
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")
        self.n_steps = n_steps
        self.delta_t = 1 / n_steps

### Data and pivotal samplers

In [None]:
from diffusers import DDIMScheduler


# Sample the pivotal (intermediate) distributions.
def sample_all_pivotal(
    source_sampler: distributions.Sampler,
    target_sampler: distributions.Sampler,
    batch_size: int = 1024,
    half_steps: int = 1000,
    pivotal_list: "list[int] | None" = None,
) -> list:
    """Sample every pivotal distribution of the source -> target noising bridge.

    A batch is drawn from each sampler. Each batch is then noised step by step
    with a ``DDIMScheduler``. The states reached after ``k`` steps, for every
    ``k`` in ``pivotal_list``, are recorded. The returned path is::

        [source_0, source_k1, ..., source_kmax, target_k(max-1), ..., target_0]

    That is: source pivots in increasing noise order, then target pivots in
    decreasing noise order. The noisiest target pivot is dropped, so
    consecutive entries form the (x_t, x_{t+1}) training pairs.

    Args:
        source_sampler: Object exposing ``sample(n)`` for the source data.
        target_sampler: Object exposing ``sample(n)`` for the target data.
        batch_size: Number of samples per pivot.
        half_steps: ``num_train_timesteps`` of the scheduler. Pivot steps
            larger than this are never reached.
        pivotal_list: Noising step counts at which to record pivots; defaults
            to ``[0, 10, 20, 50]``. The clean samples are always recorded, so
            a ``0`` entry adds nothing extra. The list need not be sorted.

    Returns:
        List of tensors, each shaped like one sampler batch.
    """
    if pivotal_list is None:
        pivotal_list = [0, 10, 20, 50]
    pivot_steps = set(pivotal_list)
    # max() rather than pivotal_list[-1], so an unsorted list still works.
    n_noise_steps = min(half_steps, max(pivot_steps, default=0))

    scheduler = DDIMScheduler(num_train_timesteps=half_steps)

    source: torch.Tensor = source_sampler.sample(batch_size)
    target: torch.Tensor = target_sampler.sample(batch_size)
    source_list = [source]
    target_list = [target]
    for i in range(n_noise_steps):
        timestep = torch.tensor([i], dtype=torch.long)
        # NOTE(review): add_noise is applied to the *already noised* tensor
        # with the cumulative level of timestep i. The result is therefore not
        # the scheduler's marginal q(x_i | x_0). This is kept as-is because
        # trained models depend on this schedule; confirm it is intended.
        source = scheduler.add_noise(source, torch.randn_like(source), timestep)
        target = scheduler.add_noise(target, torch.randn_like(target), timestep)
        if (i + 1) in pivot_steps:
            source_list.append(source)
            target_list.append(target)

    target_list.reverse()
    return source_list + target_list[1:]


def sample_step_t_pivotal(
    source_sampler: distributions.Sampler,
    target_sampler: distributions.Sampler,
    batch_size: int = 1024,
    half_steps: int = 1000,
    pivotal_list: "list[int] | None" = None,
    pivotal_step: int = 0,
):
    """Return the ``(x_t, x_{t+1})`` pivot pair for segment ``pivotal_step``.

    A full pivotal path is sampled with ``sample_all_pivotal``, and two
    consecutive entries are returned. The whole path is sampled even though
    only two entries are used; this keeps the random-number consumption
    identical across segments.

    Args:
        source_sampler: Object exposing ``sample(n)`` for the source data.
        target_sampler: Object exposing ``sample(n)`` for the target data.
        batch_size: Number of samples per pivot.
        half_steps: ``num_train_timesteps`` of the noise scheduler.
        pivotal_list: Noising steps at which pivots are recorded; defaults to
            ``[50, 100, 200, 500]``.
        pivotal_step: Index of the segment, i.e. of ``x_t`` in the path.

    Returns:
        Tuple ``(pivotal_t, pivotal_tadd1)`` of tensors.

    Raises:
        IndexError: If ``pivotal_step`` is not a valid segment index. Negative
            values would otherwise silently wrap around.
    """
    if pivotal_list is None:
        pivotal_list = [50, 100, 200, 500]
    if pivotal_step < 0:
        raise IndexError(f"pivotal_step must be >= 0, got {pivotal_step}")

    pivotal_path = sample_all_pivotal(
        source_sampler, target_sampler, batch_size, half_steps, pivotal_list
    )
    if pivotal_step + 1 >= len(pivotal_path):
        raise IndexError(
            f"pivotal_step {pivotal_step} out of range for a path with "
            f"{len(pivotal_path) - 1} segments"
        )
    return pivotal_path[pivotal_step], pivotal_path[pivotal_step + 1]

### mapping plotters

In [None]:
def plot_results(source_dataset, target_dataset, mapped_dataset, img_path=None):
    """Scatter-plot source, target and mapped samples side by side.

    Every point is coloured by the polar angle of its first two coordinates,
    using the same rule in all three panels. 2-D data is drawn on planar
    axes, 3-D data on ``projection="3d"`` axes.

    Args:
        source_dataset: Tensor ``(n, 2)`` or ``(n, 3)`` of input samples.
        target_dataset: Tensor ``(n, 2)`` or ``(n, 3)`` of target samples.
        mapped_dataset: Tensor ``(n, 2)`` or ``(n, 3)`` of mapped samples.
        img_path: Optional file path. When given, the figure is saved there
            before ``plt.show()`` is called.

    Raises:
        ValueError: If a dataset's last dimension is neither 2 nor 3.
    """
    fig = plt.figure(figsize=(15, 5))

    datasets = [source_dataset, target_dataset, mapped_dataset]
    titles = ["Input distribution", "Target distribution", "Fitted distribution"]
    for i, (dataset, title) in enumerate(zip(datasets, titles)):
        # Convert once; detach/cpu so grad-tracking or GPU tensors work too.
        points = dataset.detach().cpu().numpy()
        dim = points.shape[-1]
        if dim not in (2, 3):
            raise ValueError(f"plot_results supports 2-D or 3-D data, got dim={dim}")
        x, y = points[:, 0], points[:, 1]

        # Smooth colour ramp over the polar angle: normalise the angle to
        # [0, 1], then pass it through a sinusoid (equal to 0.5 * (1 + cos(angle))).
        angles = np.arctan2(y, x)
        normalized_angles = (angles + np.pi) / (2 * np.pi)
        colors = 0.5 * (1 + np.sin(2 * np.pi * normalized_angles - np.pi / 2))

        scatter_kwargs = dict(c=colors, cmap="rainbow", s=10, edgecolors="none")
        if dim == 2:
            ax = fig.add_subplot(1, 3, i + 1)
            ax.scatter(x, y, **scatter_kwargs)
        else:
            ax = fig.add_subplot(1, 3, i + 1, projection="3d")
            ax.scatter(x, y, points[:, 2], **scatter_kwargs)
        ax.set_title(title)
        ax.set_axis_off()

    plt.tight_layout()
    if img_path:
        fig.savefig(img_path)
    plt.show()


@torch.no_grad()
def draw_sub_mapping(
    plot_n_samples,
    source_sampler,
    target_sampler,
    half_steps,
    pivotal_list,
    pivotal_step,
    batch_size,
    sde: torch.nn.Module,
    img_path: "str | None" = None,
):
    """Plot how one segment SDE maps pivot ``pivotal_step`` towards the next one.

    Steps:
    1. Clear the notebook output.
    2. Draw ``plot_n_samples`` pivot pairs with ``sample_step_t_pivotal``.
    3. Push the first element of each pair through ``sde``.
    4. Plot the source / target / mapped clouds with ``plot_results``.

    Args:
        plot_n_samples: Number of points to plot per panel.
        source_sampler: Sampler for the source data.
        target_sampler: Sampler for the target data.
        half_steps: Scheduler timesteps (forwarded to the pivot sampler).
        pivotal_list: Pivot steps (forwarded to the pivot sampler).
        pivotal_step: Segment index whose pair is drawn.
        batch_size: Samples drawn per pivot call.
        sde: Segment SDE; its first parameter's device is used for sampling.
        img_path: Optional path passed to ``plt.savefig``.
    """
    clear_output()
    device = next(sde.parameters()).device
    source_batches, target_batches, mapped_batches = [], [], []
    for _ in range(plot_n_samples // batch_size + 1):
        source, target = sample_step_t_pivotal(
            source_sampler,
            target_sampler,
            batch_size,
            half_steps,
            pivotal_list,
            pivotal_step,
        )
        source = source.to(device)
        target = target.to(device)

        trajectory, _, _ = sde(source)
        mapped_batches.append(trajectory[:, -1, :])
        source_batches.append(source)
        target_batches.append(target)

    # Truncate after concatenation so exactly plot_n_samples points are kept.
    # Slicing the list would cut batches, not samples.
    source_dataset = torch.cat(source_batches)[:plot_n_samples].cpu()
    target_dataset = torch.cat(target_batches)[:plot_n_samples].cpu()
    mapped_dataset = torch.cat(mapped_batches)[:plot_n_samples].cpu()

    plot_results(source_dataset, target_dataset, mapped_dataset)
    # NOTE(review): plot_results calls plt.show() itself. With the Jupyter
    # inline backend the figure may already be closed here, so this savefig
    # could write an empty canvas. Consider saving inside plot_results.
    if img_path:
        plt.savefig(img_path)
    plt.show()


@torch.no_grad()
def draw_linked_mapping(
    plot_n_samples,
    source_sampler,
    target_sampler,
    batch_size,
    half_steps,
    pivotal_list,
    SDEs: list[torch.nn.Module],
    img_path: "str | None" = None,
):
    """Plot the composition of all segment SDEs applied to clean source samples.

    Each clean source batch is pushed through ``SDEs[0]``, then ``SDEs[1]``,
    and so on. The final state of the last SDE is plotted against fresh
    source and target samples.

    Args:
        plot_n_samples: Number of points to plot per panel.
        source_sampler: Sampler for the source data.
        target_sampler: Sampler for the target data.
        batch_size: Samples drawn per sampler call.
        half_steps: Unused here; kept for call-site compatibility.
        pivotal_list: Unused here; kept for call-site compatibility.
        SDEs: Segment SDEs in path order. The first one's device is used.
        img_path: Optional path passed to ``plt.savefig``.
    """
    device = next(SDEs[0].parameters()).device
    source_batches, target_batches, mapped_batches = [], [], []
    for _ in range(plot_n_samples // batch_size + 1):
        # `sample` (not `sampler`): the misspelt attribute raised AttributeError.
        source = source_sampler.sample(batch_size).to(device)
        target = target_sampler.sample(batch_size).to(device)

        # TODO: visualise every intermediate pivot of the linked mapping.
        x = source
        for sde in SDEs:
            trajectory, _, _ = sde(x)
            x = trajectory[:, -1, :]

        mapped_batches.append(x)
        source_batches.append(source)
        target_batches.append(target)

    # Truncate after concatenation so exactly plot_n_samples points are kept.
    source_dataset = torch.cat(source_batches)[:plot_n_samples].cpu()
    target_dataset = torch.cat(target_batches)[:plot_n_samples].cpu()
    mapped_dataset = torch.cat(mapped_batches)[:plot_n_samples].cpu()

    plot_results(source_dataset, target_dataset, mapped_dataset)
    # NOTE(review): plot_results calls plt.show() itself. With the Jupyter
    # inline backend the figure may already be closed here, so this savefig
    # could write an empty canvas.
    if img_path:
        plt.savefig(img_path)
    plt.show()


# Sample the pivotal distributions and plot them.
def sample_and_draw_all_pivotal(
    source_sampler: distributions.Sampler,
    target_sampler: distributions.Sampler,
    batch_size: int = 1024,
    half_steps: int = 1000,
    pivotal_list: "list[int] | None" = None,
) -> None:
    """Sample and plot every pivotal distribution, including both noisiest ones.

    The noising follows the same per-step ``DDIMScheduler.add_noise`` loop as
    the training sampler. The figure has two rows:
    - top row: source pivots in increasing noise order;
    - bottom row: target pivots in decreasing noise order.

    Args:
        source_sampler: Object exposing ``sample(n)`` for the source data.
        target_sampler: Object exposing ``sample(n)`` for the target data.
        batch_size: Number of samples per pivot.
        half_steps: ``num_train_timesteps`` of the scheduler.
        pivotal_list: Noising steps at which pivots are recorded; defaults to
            ``[50, 100, 200, 500]``. Clean samples are always included.

    Raises:
        ValueError: If the samples are neither 2-D nor 3-D.
    """
    if pivotal_list is None:
        pivotal_list = [50, 100, 200, 500]
    pivot_steps = set(pivotal_list)

    scheduler = DDIMScheduler(num_train_timesteps=half_steps)

    source: torch.Tensor = source_sampler.sample(batch_size)
    target: torch.Tensor = target_sampler.sample(batch_size)
    source_list = [source]
    target_list = [target]
    for i in range(min(half_steps, max(pivot_steps, default=0))):
        timestep = torch.tensor([i], dtype=torch.long)
        source = scheduler.add_noise(source, torch.randn_like(source), timestep)
        target = scheduler.add_noise(target, torch.randn_like(target), timestep)
        if (i + 1) in pivot_steps:
            source_list.append(source)
            target_list.append(target)

    target_list.reverse()
    pivotal_path = source_list + target_list

    # One column per recorded source pivot. Deriving this from the path
    # (not len(pivotal_list)) avoids an out-of-range subplot index when
    # pivotal_list has no 0 entry or contains unreachable steps.
    ncols = len(source_list)
    fig = plt.figure(figsize=(5 * ncols, 10))
    for i, pivotal in enumerate(pivotal_path):
        points = pivotal.detach().cpu().numpy()
        dim = points.shape[-1]
        if dim not in (2, 3):
            raise ValueError(f"only 2-D or 3-D samples can be plotted, got dim={dim}")
        x, y = points[:, 0], points[:, 1]

        # Smooth colour ramp over the polar angle of the first two coordinates.
        angles = np.arctan2(y, x)
        normalized_angles = (angles + np.pi) / (2 * np.pi)
        colors = 0.5 * (1 + np.sin(2 * np.pi * normalized_angles - np.pi / 2))

        scatter_kwargs = dict(c=colors, cmap="rainbow", s=10, edgecolors="none")
        if dim == 3:
            ax = fig.add_subplot(2, ncols, i + 1, projection="3d")
            ax.scatter(x, y, points[:, 2], **scatter_kwargs)
        else:
            ax = fig.add_subplot(2, ncols, i + 1)
            ax.scatter(x, y, **scatter_kwargs)
        ax.set_title(f"x{i} distribution")
        ax.grid(False)
        ax.set_axis_off()

    plt.tight_layout()
    plt.show()

### trainers

In [None]:
def integrate(values, times):
    """Left Riemann sum of ``values`` over the time grid ``times``.

    Args:
        values: Tensor whose dim 1 indexes time intervals, e.g.
            ``(batch, n)``; column ``k`` is weighted by ``times[k+1] - times[k]``.
        times: Grid points, ``(n + 1,)``.

    Returns:
        ``sum_k values[:, k] * (times[k + 1] - times[k])``, reduced over dim 1.
    """
    step_sizes = times[1:] - times[:-1]
    weighted = values * step_sizes.unsqueeze(0)
    return torch.sum(weighted, dim=1)


def train_linked_mapping(
    source_sampler: distributions.Sampler,
    target_sampler: distributions.Sampler,
    SDEs: list[SDE],
    SDE_OPTs: list[torch.optim.Optimizer],
    BETA_NETs: list[torch.nn.Module],
    BETA_NET_OPTs: list[torch.optim.Optimizer],
    iterations: int,
    inner_iterations: int,
    half_steps: int = 1000,
    pivotal_list: "list[int] | None" = None,
    batch_size: int = 1024,
    device: torch.device = torch.device("cpu"),
    log_dir: str = "./logs_plist_without_g2g",
):
    """Train one ENOT mapping per consecutive pair of pivotal distributions.

    Segments ``t = 0 .. T-1`` (pivot ``t`` -> pivot ``t + 1``) are trained in
    turn. Each outer iteration does two things:

    1. One step on the potential ``BETA_NETs[t]``, minimising
       ``mean(-integral - beta(target) + beta(mapped))``.
    2. ``inner_iterations`` steps on the SDE ``SDEs[t]``, minimising
       ``mean(integral + beta(target) - beta(mapped))``.

    Here ``integral`` is the time-integrated squared drift norm. Every 100
    outer iterations a sub-mapping plot is written. At the end, the linked
    mapping through all SDEs is plotted.

    Args:
        source_sampler: Sampler for the source data.
        target_sampler: Sampler for the target data.
        SDEs: Per-segment SDE models.
        SDE_OPTs: Optimisers for ``SDEs``.
        BETA_NETs: Per-segment potential networks.
        BETA_NET_OPTs: Optimisers for ``BETA_NETs``.
        iterations: Outer (potential) updates per segment.
        inner_iterations: SDE updates per potential update.
        half_steps: Scheduler timesteps used for pivot sampling.
        pivotal_list: Pivot steps; defaults to ``[0, 10, 20, 50]``.
        batch_size: Samples per training batch.
        device: Device the batches are moved to.
        log_dir: Directory for the plots; created if missing.

    Returns:
        ``(SDEs, BETA_NETs, NORMs)``. ``NORMs[t]`` lists the detached squared
        drift norms, ``(batch, n_steps)``, from each outer iteration of
        segment ``t``.
    """
    if pivotal_list is None:
        pivotal_list = [0, 10, 20, 50]
    # Number of segments. This matches the pivotal path length minus one only
    # when pivotal_list contains 0; otherwise the path has one more segment.
    T = len(pivotal_list) * 2 - 2
    NORMs: list[list] = []
    integral_scale = 1
    # savefig does not create directories.
    os.makedirs(log_dir, exist_ok=True)

    def _sample_pair(step):
        """Draw a (pivot ``step``, pivot ``step + 1``) batch pair on ``device``."""
        source, target = sample_step_t_pivotal(
            source_sampler,
            target_sampler,
            batch_size,
            half_steps,
            pivotal_list,
            step,
        )
        return source.to(device), target.to(device)

    # 1. Train ENOT(X_t, X_{t+1}) for every segment in turn.
    for t in range(T):
        norms = []
        # a. Select the mapping network, the potential and their optimisers.
        sde, beta_net = SDEs[t], BETA_NETs[t]
        sde_opt, beta_net_opt = SDE_OPTs[t], BETA_NET_OPTs[t]

        for iteration in trange(iterations):
            if iteration % 100 == 0:
                draw_sub_mapping(
                    1024,
                    source_sampler,
                    target_sampler,
                    half_steps,
                    pivotal_list,
                    t,
                    batch_size,
                    sde,
                    img_path=os.path.join(
                        log_dir, f"sub_mapping_{t}to{t+1}_{iteration}iter.png"
                    ),
                )

            # b. Potential step. The SDE is only simulated here, so no graph
            # through it is needed. The beta gradients are unchanged, and the
            # stored norms no longer keep autograd graphs alive.
            source, target = _sample_pair(t)
            with torch.no_grad():
                trajectory, times, shifts = sde(source)
                target_predicted = trajectory[:, -1, :]
                norm = torch.norm(shifts, p=2, dim=-1) ** 2
                integral = integral_scale * integrate(norm, times)
            norms.append(norm)

            loss_beta = (
                -integral - beta_net(target) + beta_net(target_predicted)
            ).mean()
            beta_net_opt.zero_grad()
            loss_beta.backward()
            beta_net_opt.step()

            # c. SDE steps against the current potential.
            for _ in range(inner_iterations):
                source, target = _sample_pair(t)
                trajectory, times, shifts = sde(source)
                target_predicted = trajectory[:, -1, :]

                norm = torch.norm(shifts, p=2, dim=-1) ** 2
                integral = integral_scale * integrate(norm, times)

                loss_sde = (
                    integral + beta_net(target) - beta_net(target_predicted)
                ).mean()
                sde_opt.zero_grad()
                loss_sde.backward()
                sde_opt.step()

        NORMs.append(norms)

    draw_linked_mapping(
        1024,
        source_sampler,
        target_sampler,
        batch_size,
        half_steps,
        pivotal_list,
        SDEs,
        img_path=os.path.join(log_dir, "linked_mapping.png"),
    )
    return SDEs, BETA_NETs, NORMs

## 3. Config

`source_data_type` and `target_data_type` select the source and target distributions; the supported values are `"swiss_roll"`, `"8_gaussians"`, `"moons"` and `"Mobius"`.

In [None]:
# Device. To use a GPU when available, replace the line below with:
#   device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cpu")

# Reproducibility.
SEED = 0xBADBEEF
torch.manual_seed(SEED)
np.random.seed(SEED)

# Noising schedule: number of scheduler timesteps, and the noising steps at
# which pivotal distributions are recorded (the clean samples are always included).
half_steps = 1000
pivotal_list = [0, 10, 20, 50]

# Optimisation.
batch_size = 512
iterations = 2000  # outer (potential) updates per segment
inner_iterations = 10  # SDE updates per potential update

# SDE / optimiser.
epsilon = 0.1  # diffusion variance per unit time
lr = 1e-4
n_steps = 10  # Euler-Maruyama steps on [0, 1]

# Datasets: "swiss_roll", "8_gaussians", "moons" or "Mobius".
source_data_type = "swiss_roll"
target_data_type = "Mobius"
dim = 3  # sample dimensionality (network input/output width)

## 4. Train

### Initialize data samplers

In [None]:
# Prepare the data samplers.
_SAMPLER_FACTORIES = {
    "swiss_roll": lambda: distributions.SwissRollSampler(device="cpu", dim=dim),
    "8_gaussians": lambda: distributions.Mix8GaussiansSampler(
        std=0.1, r=math.sqrt(2), device="cpu"
    ),
    "moons": lambda: distributions.MoonsSampler(device="cpu"),
    "Mobius": lambda: distributions.MobiusStripSampler(device="cpu"),
}


def _make_sampler(data_type):
    """Return a new CPU sampler for ``data_type``.

    Raises:
        ValueError: For an unknown ``data_type``. This fails early, instead of
            leaving the sampler variable undefined until a later NameError.
    """
    try:
        factory = _SAMPLER_FACTORIES[data_type]
    except KeyError:
        raise ValueError(
            f"unknown data type {data_type!r}; expected one of "
            f"{sorted(_SAMPLER_FACTORIES)}"
        ) from None
    return factory()


source_sampler = _make_sampler(source_data_type)
target_sampler = _make_sampler(target_data_type)

### show sample data and pivotals

In [None]:
# Visualise the clean samples and every pivotal distribution, using a longer
# pivot list than the training configuration.
sample_and_draw_all_pivotal(
    source_sampler=source_sampler,
    target_sampler=target_sampler,
    batch_size=5000,
    pivotal_list=[0, 10, 20, 50, 100],
)

### Initialize models

In [None]:
# Build one (SDE, potential) pair, each with its own Adam optimiser, per
# segment between consecutive pivotal distributions.
SDEs, BETA_NETs = [], []
SDE_OPTs, BETA_NET_OPTs = [], []

n_segments = 2 * len(pivotal_list) - 2
for segment in range(n_segments):
    # The drift network sees the state plus one time feature.
    sde = SDE(
        make_net(n_inputs=dim + 1, n_outputs=dim, n_layers=3, n_hiddens=100).to(device),
        epsilon,
        n_steps,
    ).to(device)
    beta_net = make_net(n_inputs=dim, n_outputs=1, n_layers=3, n_hiddens=100).to(device)

    SDEs.append(sde)
    BETA_NETs.append(beta_net)
    SDE_OPTs.append(Adam(sde.parameters(), lr=lr))
    BETA_NET_OPTs.append(Adam(beta_net.parameters(), lr=lr))

### training

In [None]:
# Train all segment networks. The returned SDE / potential lists are the same
# list objects that are passed in.
SDEs, BETA_NETs, NORMs = train_linked_mapping(
    source_sampler=source_sampler,
    target_sampler=target_sampler,
    SDEs=SDEs,
    SDE_OPTs=SDE_OPTs,
    BETA_NETs=BETA_NETs,
    BETA_NET_OPTs=BETA_NET_OPTs,
    iterations=iterations,
    inner_iterations=inner_iterations,
    half_steps=half_steps,
    pivotal_list=pivotal_list,
    batch_size=batch_size,
    device=device,
)

## 5. Saving model

In [None]:
from glob import glob


# Save every segment SDE under a directory named after the pivot list.
path = f"../toy_sde_models/pivotal_{'_'.join(map(str, pivotal_list))}"
print(f"[Info] saving {path = }")
# torch.save does not create missing directories.
os.makedirs(path, exist_ok=True)
for i, sde in enumerate(SDEs):
    torch.save(
        sde.state_dict(),
        os.path.join(
            path, f"sde{i}_{source_data_type}_{target_data_type}_{epsilon}_{n_steps}.pt"
        ),
    )


def load_sdes(path: str) -> list[SDE]:
    """Load the segment SDEs stored in ``path`` for the current configuration.

    Files are named
    ``sde{i}_{source_data_type}_{target_data_type}_{epsilon}_{n_steps}.pt``,
    as in the save loop above. They are read for ``i = 0, 1, ...`` until the
    first missing index, so unrelated ``.pt`` files in the directory are
    ignored. The module-level ``dim``, ``epsilon``, ``n_steps`` and data
    types must match the ones used when saving.

    Returns:
        The loaded SDEs, on CPU, in segment order (empty if none are found).
    """
    suffix = f"{source_data_type}_{target_data_type}_{epsilon}_{n_steps}.pt"
    loaded = []
    i = 0
    while True:
        sde_path = os.path.join(path, f"sde{i}_{suffix}")
        if not os.path.exists(sde_path):
            break
        print(sde_path)
        # Use the configured dimensionality: a hard-coded 2-D network cannot
        # load weights trained with dim != 2.
        sde_shift_model = make_net(
            n_inputs=dim + 1, n_outputs=dim, n_layers=3, n_hiddens=100
        )
        sde = SDE(shift_model=sde_shift_model, epsilon=epsilon, n_steps=n_steps)
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        sde.load_state_dict(torch.load(sde_path, map_location="cpu"))
        loaded.append(sde)
        i += 1

    return loaded

## 6. TODO: Current model plots

In [None]:
# TODO: polish this cell.
# Map fresh source samples through the whole chain of trained SDEs, then
# compare the result with the source and target distributions.
plot_n_samples = 1024
n_batches = plot_n_samples // batch_size + 1

with torch.no_grad():
    source_batches, mapped_batches = [], []
    for _ in range(n_batches):
        x = source_sampler.sample(batch_size).to(device)
        source_batches.append(x)
        for segment_sde in SDEs:
            x = segment_sde(x)[0][:, -1, :]
        mapped_batches.append(x)

source_dataset = torch.cat(source_batches, dim=0)[:plot_n_samples].cpu()
mapped_dataset = torch.cat(mapped_batches, dim=0)[:plot_n_samples].cpu()
target_dataset = torch.cat(
    [target_sampler.sample(batch_size) for _ in range(n_batches)],
    dim=0,
)[:plot_n_samples].cpu()

# plot_results needs all three datasets and calls plt.show() itself.
plot_results(source_dataset, target_dataset, mapped_dataset)

In [None]:
# Plot a few sample paths of the global `sde` (after the cells above,
# presumably the last segment's model) over the first two coordinates.
# NOTE(review): the last segment is trained on noised target pivots, not on
# clean source samples; confirm that this input is what you want to inspect.
sde_device = next(sde.parameters()).device
with torch.no_grad():
    # Move the inputs to the model's device so this also works on GPU.
    tr = sde(source_sampler.sample(64).to(sde_device))[0].cpu()

for i in range(10):
    plt.plot(
        tr[i, :, 0],
        tr[i, :, 1],
        "-o",
        markeredgecolor="black",
        linewidth=4,
        markersize=4,
    )

target_dataset = target_sampler.sample(750).cpu().numpy()
plt.scatter(
    target_dataset[:, 0],
    target_dataset[:, 1],
    c="orange",
    s=20,
    edgecolors="black",
    label="target data",
)
plt.title("Trajectories")

plt.grid()
plt.xlim([-2.5, 2.5])
plt.ylim([-2.5, 2.5])
plt.legend()

plt.show()

## 7. TODO: Final plots for all models

### Models loading

In [None]:
# Legacy 2-D models saved as `sde_{target}_{epsilon}_{n_steps}.pt`. This is a
# different naming scheme from the save cell above.
epsilons = [0, 0.01, 0.1]
path = "../toy_sde_models/"

source_sampler = distributions.StandardNormalSampler(dim=2, device="cpu")

if target_data_type == "swiss_roll":
    target_sampler = distributions.SwissRollSampler(device="cpu")
elif target_data_type == "8_gaussians":
    target_sampler = distributions.Mix8GaussiansSampler(
        std=0.1, r=math.sqrt(2), device="cpu"
    )
else:
    # Fail early instead of silently keeping a sampler that does not match
    # the 2-D legacy models.
    raise ValueError(
        f"legacy models exist only for 'swiss_roll' and '8_gaussians', "
        f"got target_data_type={target_data_type!r}"
    )

sdes = []
# `eps` (not `epsilon`) so the global config value is not overwritten.
for eps in epsilons:
    model_path = os.path.join(path, f"sde_{target_data_type}_{eps}_{n_steps}.pt")
    print(model_path)
    sde_shift_model = make_net(n_inputs=2 + 1, n_outputs=2, n_layers=3, n_hiddens=100)
    sde = SDE(shift_model=sde_shift_model, epsilon=eps, n_steps=n_steps)
    sde.load_state_dict(torch.load(model_path, map_location="cpu"))
    sdes.append(sde)

In [None]:
def _draw_samples(sampler, n, chunk):
    """Concatenate ``n // chunk + 1`` batches from ``sampler``; keep the first ``n`` rows on CPU."""
    batches = [sampler.sample(chunk) for _ in range(n // chunk + 1)]
    return torch.cat(batches, dim=0)[:n].cpu()


plot_n_samples = 600

# Source is drawn before target, so the random-number consumption order
# stays the same.
source_dataset = _draw_samples(source_sampler, plot_n_samples, batch_size)
target_dataset = _draw_samples(target_sampler, plot_n_samples, batch_size)

### Plot for all models

In [None]:
torch.manual_seed(SEED)
np.random.seed(SEED)

fig, axes = plt.subplots(2, 4, figsize=(25, 10), dpi=450)
source_np = source_dataset.numpy()
target_np = target_dataset.numpy()

axes[0, 0].scatter(source_np[:, 0], source_np[:, 1], c="g", s=48, edgecolors="black")
axes[0, 0].set_title("Input distribution")

axes[1, 0].scatter(
    target_np[:, 0], target_np[:, 1], c="orange", s=48, edgecolors="black"
)
axes[1, 0].set_title("Target distribution")

# `eps` (not `epsilon`) so the global config value is not overwritten.
for col, (sde, eps) in enumerate(zip(sdes, epsilons), start=1):
    # No `map_dataset` helper is defined here, so map the fixed source samples
    # through the SDE directly. no_grad allows the later .numpy() calls.
    with torch.no_grad():
        mapped_dataset = sde(source_dataset[:plot_n_samples])[0][:, -1, :]
        tr = sde(source_dataset[:20])[0]

    mapped_np = mapped_dataset.numpy()
    axes[0, col].scatter(
        mapped_np[:, 0], mapped_np[:, 1], c="yellow", s=48, edgecolors="black"
    )
    # Raw strings avoid the invalid "\e" escape warning; the text is unchanged.
    axes[0, col].set_title(r"Fitted distribution $\epsilon=$ " + f"{eps}")

    for i in range(20):
        axes[1, col].plot(
            tr[i, :, 0].numpy(),
            tr[i, :, 1].numpy(),
            "-o",
            markeredgecolor="black",
            c="green",
            linewidth=3,
            markersize=4,
            markeredgewidth=0.5,
        )

    axes[1, col].scatter(
        target_np[:, 0],
        target_np[:, 1],
        c="orange",
        s=48,
        edgecolors="black",
        label="target data",
    )
    axes[1, col].set_title(r"Trajectories $\epsilon=$ " + f"{eps}")

for ax in axes.flat:
    ax.grid()
    ax.set_xlim([-2.5, 2.5])
    ax.set_ylim([-2.5, 2.5])

plt.savefig(f"../pics/{target_data_type}_results.jpg")

### Pictures for the article

In [None]:
torch.manual_seed(SEED)
np.random.seed(SEED)

source_np = source_dataset.numpy()
target_np = target_dataset.numpy()

# Figure 1: input and target distributions.
fig, axes = plt.subplots(2, 1, figsize=(3.125 * 0.9, 6.25 * 0.9), dpi=450)
axes[0].scatter(
    source_np[:, 0],
    source_np[:, 1],
    c="g",
    s=48,
    edgecolors="black",
    label="Input distribution",
)
axes[0].legend()
axes[0].grid()

axes[1].scatter(
    target_np[:, 0],
    target_np[:, 1],
    c="orange",
    s=48,
    edgecolors="black",
    label="Target distribution",
)
axes[1].legend()
axes[1].grid()
fig.tight_layout(pad=0.1)

for ax in axes:
    ax.set_xlim([-2.5, 2.5])
    ax.set_ylim([-2.5, 2.5])

plt.savefig(f"../pics/{target_data_type}_results_input_and_target.jpg")

# One figure per model: fitted distribution on top, trajectories below.
# `eps` (not `epsilon`) so the global config value is not overwritten.
for sde, eps in zip(sdes, epsilons):
    fig, axes = plt.subplots(2, 1, figsize=(3.125 * 0.9, 6.25 * 0.9), dpi=450)

    # No `map_dataset` helper is defined here; map the fixed source samples
    # directly. no_grad allows the .numpy() calls below.
    with torch.no_grad():
        mapped_dataset = sde(source_dataset[:plot_n_samples])[0][:, -1, :]
        tr = sde(source_dataset[:20])[0]

    mapped_np = mapped_dataset.numpy()
    axes[0].scatter(
        mapped_np[:, 0],
        mapped_np[:, 1],
        c="yellow",
        s=48,
        edgecolors="black",
        label="Fitted distribution",
    )
    axes[0].legend()

    n_traj = 20 if eps < 1 else 5
    for i in range(n_traj):
        # Label only the first path, so the legend shows a single entry.
        axes[1].plot(
            tr[i, :, 0].numpy(),
            tr[i, :, 1].numpy(),
            "-o",
            markeredgecolor="black",
            c="green",
            linewidth=3,
            markersize=4,
            markeredgewidth=0.5,
            label="Trajectories" if i == 0 else None,
        )

    axes[1].scatter(
        target_np[:, 0],
        target_np[:, 1],
        c="orange",
        s=48,
        edgecolors="black",
        label="Target data",
    )
    axes[1].legend()

    for ax in axes:
        ax.grid()
        ax.set_xlim([-2.5, 2.5])
        ax.set_ylim([-2.5, 2.5])

    fig.tight_layout(pad=0.1)

    plt.savefig(f"../pics/{target_data_type}_results_{eps}.jpg")