In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell runs so edits to local files are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [1]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm.auto import tqdm

import ddpm
import datasets

In [None]:
sizes = range(100, 5000, 200)
names = [f"point_base{s}" for s in sizes]
for i in range(len(sizes)):
    !python ddpm.py --dataset circle --experiment_name {names[i]} --num_epochs 100 --dataset_size {sizes[i]}

Training model...
Epoch 0: 100%|████████████████| 3/3 [00:00<00:00, 18.88it/s, loss=0.819, step=2]
100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 738.96it/s]
Epoch 1: 100%|███████████████| 3/3 [00:00<00:00, 310.63it/s, loss=0.768, step=5]
100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 931.38it/s]
Epoch 2: 100%|███████████████| 3/3 [00:00<00:00, 345.78it/s, loss=0.461, step=8]
100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 941.90it/s]
Epoch 3: 100%|██████████████| 3/3 [00:00<00:00, 350.02it/s, loss=0.278, step=11]
100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 934.97it/s]
Epoch 4: 100%|███████████████| 3/3 [00:00<00:00, 273.48it/s, loss=0.49, step=14]
100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 943.47it/s]
Epoch 5: 100%|██████████████| 3/3 [00:00<00:00, 351.88it/s, loss=0.191, step=17]
100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 936.40it/s]
Epoch 6: 1

In [3]:
import matplotlib as mpl
# Switch matplotlib to the TkAgg backend.
# NOTE(review): this runs after pyplot has already been imported and needs a
# working Tk installation/display — confirm it is intended here rather than
# the notebook's default backend.
mpl.use('TkAgg')

In [4]:
def plot_ablation(frames_dict, outname):
    """Draw a grid of sample snapshots, one row per entry of ``frames_dict``.

    Every row starts with a text-only cell showing the entry's key, followed
    by ten scatter plots of the frames at indices ``0, step, 2*step, ...``
    where ``step = len(frames) // 10``. Column titles (the frame indices) are
    drawn on the first row only. The figure is saved to ``outname`` and shown.

    Args:
        frames_dict: mapping from row label to an indexable sequence of
            frames; each frame is read as ``frame[:, 0]`` / ``frame[:, 1]``.
        outname: path handed to ``plt.savefig``.
    """
    n_snapshots = 10
    n_rows = len(frames_dict)
    grid_cols = n_snapshots + 1  # extra leading column holds the row label
    big_font = {"size": 30}

    plt.figure(figsize=(3.5*n_snapshots, 3*n_rows + 0.5))

    for row_ix, (label, frames) in enumerate(frames_dict.items()):
        stride = len(frames) // n_snapshots
        row_start = row_ix * grid_cols

        # Label cell: an invisible point gives the text something to anchor to.
        plt.subplot(n_rows, grid_cols, row_start + 1)
        plt.scatter(0, 0, alpha=0)
        plt.text(0, 0, label, fontdict=big_font)
        plt.xlim(-0.25, 2)
        plt.axis("off")

        for col_ix in range(n_snapshots):
            frame_ix = col_ix * stride
            snapshot = frames[frame_ix]

            plt.subplot(n_rows, grid_cols, row_start + col_ix + 2)
            plt.scatter(snapshot[:, 0], snapshot[:, 1], s=5, alpha=0.7)
            if row_ix == 0:
                # Only the first column spells out the word "epoch".
                heading = f"epoch {frame_ix}" if col_ix == 0 else f"{frame_ix}"
                plt.title(heading, fontdict=big_font, pad=30)
            plt.xlim(-3.5, 3.5)
            plt.ylim(-4., 4.75)
            plt.axis("off")

    plt.tight_layout()
    plt.savefig(outname, facecolor="white")
    plt.show()

In [5]:
# Load the saved frames of every fourth run, keyed by its dataset size, and plot them.
frames_dict = {
    str(size): np.load(f'exps/{name}/frames.npy')
    for size, name in zip(sizes[::4], names[::4])
}

plot_ablation(frames_dict, "static/datasets.png")

In [32]:
def calculate_stats(model_path, process_fn=None, eval_batch_size=10000,
                    num_timesteps=50, seed=None):
    """Sample from a trained model and score the final samples.

    Loads ``ddpm.MLP`` weights from ``model_path``, starts from
    ``eval_batch_size`` standard-normal 2-D points, applies
    ``noise_scheduler.step`` for timesteps ``num_timesteps - 1`` down to ``0``
    and returns ``process_fn`` evaluated on the resulting samples.

    Args:
        model_path: path to a state dict saved for ``ddpm.MLP``.
        process_fn: callable taking the final samples as a NumPy array and
            returning the statistic. Defaults to ``process_square``.
        eval_batch_size: number of points to sample.
        num_timesteps: number of steps passed to ``ddpm.NoiseScheduler`` and
            run in reverse.
        seed: if given, passed to ``torch.manual_seed`` before sampling so the
            result is repeatable.

    Returns:
        Whatever ``process_fn`` returns for the final samples.
    """
    # Resolved at call time because process_square is defined after this
    # function in the same cell.
    if process_fn is None:
        process_fn = process_square
    if seed is not None:
        torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ddpm.MLP()
    # map_location lets a checkpoint that was saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    noise_scheduler = ddpm.NoiseScheduler(num_timesteps=num_timesteps)
    sample = torch.randn(eval_batch_size, 2).to(device)
    for t in tqdm(range(num_timesteps - 1, -1, -1)):
        # The model receives the timestep once per sample in the batch.
        t_batch = torch.full((eval_batch_size,), t, dtype=torch.long, device=device)
        with torch.no_grad():
            residual = model(sample, t_batch)
        sample = noise_scheduler.step(residual, t_batch[0], sample)
    return process_fn(sample.cpu().numpy())
def process_circle(samples, r=3, trim_frac=0.01):
    """Root-mean-square deviation of the samples' norms from radius ``r``.

    The signed deviations ``||sample|| - r`` are sorted and the
    ``round(n * trim_frac)`` smallest and largest values are dropped before
    the RMS is taken, so a few stray points do not dominate the statistic.
    (Previously the trimmed slice was computed but discarded, so no trimming
    took place.)

    Args:
        samples: array-like of shape (n, d) with n >= 1; one point per row.
        r: target radius.
        trim_frac: fraction of points removed from EACH end of the sorted
            deviations; must satisfy ``0 <= trim_frac < 0.5``. Pass ``0`` to
            disable trimming.

    Returns:
        The RMS of the remaining deviations as a NumPy scalar.

    Raises:
        ValueError: if ``samples`` is not a non-empty 2-D array or
            ``trim_frac`` is out of range.
    """
    points = np.asarray(samples)
    if points.ndim != 2 or points.shape[0] == 0:
        raise ValueError("samples must be a non-empty 2-D array of points")
    if not 0 <= trim_frac < 0.5:
        raise ValueError("trim_frac must be in [0, 0.5)")

    deviations = np.sort(np.linalg.norm(points, axis=1) - r)
    count = len(deviations)
    # Never trim so much that nothing is left (matters only for tiny inputs).
    n_trim = min(round(count * trim_frac), (count - 1) // 2)
    # Guard n_trim == 0: deviations[0:-0] would be an empty slice.
    if n_trim > 0:
        deviations = deviations[n_trim:-n_trim]
    return np.sqrt(np.mean(deviations ** 2))
def process_square(samples, r=3, stat="median"):
    """Summarise how far the samples lie from the lines ``x = ±r`` and ``y = ±r``.

    For each point the distance is the smallest of ``|x - r|``, ``|x + r|``,
    ``|y - r|`` and ``|y + r|``.

    NOTE(review): this measures distance to the four infinite lines, not to
    the finite edges of the square, so a point far outside the square but on
    the extension of a side scores 0 — confirm that is intended.

    Args:
        samples: array-like of shape (n, >=2); columns 0 and 1 are x and y.
        r: half side length of the square.
        stat: ``"median"`` (default) returns the element at index ``n // 2``
            of the sorted distances (the upper median for even ``n``);
            ``"rms"`` returns the root-mean-square distance.

    Returns:
        The chosen statistic as a NumPy scalar.

    Raises:
        IndexError: if ``samples`` is empty.
        ValueError: if ``stat`` is not ``"median"`` or ``"rms"``.
    """
    points = np.asarray(samples)
    x, y = points[:, 0], points[:, 1]
    dists = np.minimum.reduce(
        [np.abs(x - r), np.abs(x + r), np.abs(y - r), np.abs(y + r)]
    )
    if stat == "median":
        return np.sort(dists)[len(dists) // 2]
    if stat == "rms":
        return np.sqrt(np.mean(dists ** 2))
    raise ValueError(f"unknown stat {stat!r}; expected 'median' or 'rms'")

In [33]:
from matplotlib import rc

# Serif 10pt text and large axis labels for the figures drawn from here on.
rc('font', family='serif', size='10')
rc('axes', labelsize='large')

# One statistic per trained model, in the same order as `sizes`.
s = [calculate_stats(f"exps/{name}/model.pth") for name in names]
print(s, sizes)

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

[0.2349987030029297, 0.03278517723083496, 0.0159759521484375, 0.018033981323242188, 0.0075490474700927734, 0.008200407028198242, 0.006604433059692383, 0.006675243377685547, 0.005632162094116211, 0.005910634994506836, 0.0041961669921875, 0.005381107330322266, 0.0060765743255615234, 0.0024771690368652344, 0.0042111873626708984, 0.0025255680084228516, 0.004868030548095703, 0.003567218780517578, 0.0037832260131835938, 0.003083944320678711, 0.002004861831665039, 0.004245281219482422, 0.004563570022583008, 0.0027701854705810547, 0.0030028820037841797] range(100, 5000, 200)


In [34]:
# Plot log(1/gamma) against dataset size, where gamma is the per-model
# statistic collected in `s`. Previously `log_s` was computed but the raw `s`
# was plotted on a log axis under a "log(1/gamma)" label; the plotted quantity
# now matches the axis labels and the title.
log_s = [math.log(1 / gamma) for gamma in s]

fig, ax = plt.subplots()
ax.scatter(sizes, log_s)
ax.set_title("log(1/gamma) as a function of dataset size")
ax.set_ylabel("log(1/gamma)")
ax.set_xlabel("dataset size")
fig.savefig('static/devs.png', bbox_inches='tight')
plt.show()

In [26]:
# Build n points on the perimeter of the square [-1, 1]^2: a quarter of them
# on each side, stacked in the order right, left, top, bottom.
n = 8000
rng = np.random.default_rng(42)
per_side = n // 4

# Vertical sides: x pinned to +1 / -1, y drawn uniformly from [-1, 1).
X1 = np.column_stack((np.full(per_side, 1.0), rng.uniform(-1, 1, per_side)))
X2 = np.column_stack((np.full(per_side, -1.0), rng.uniform(-1, 1, per_side)))

# Horizontal sides: y pinned to +1 / -1, x drawn uniformly from [-1, 1).
X3 = np.column_stack((rng.uniform(-1, 1, per_side), np.full(per_side, 1.0)))
X4 = np.column_stack((rng.uniform(-1, 1, per_side), np.full(per_side, -1.0)))

X = np.vstack((X1, X2, X3, X4))

In [27]:
# Scatter the first column of X against the second.
xs, ys = X[:, 0], X[:, 1]
plt.scatter(xs, ys)
plt.show()