In [7]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm.auto import tqdm
from scipy.stats import norm
import matplotlib as mpl
mpl.use("TkAgg")
import ddpm
import datasets

In [8]:
sizes = range(6500, 7500, 400)
sizes = [7300]
names = [f"point_1d_medgamma{s}" for s in sizes]

In [9]:
for i in range(len(sizes)):
    print(sizes[i])
    !python ddpm.py --dataset point1d --experiment_name {names[i]} --num_epochs 500 --dataset_size {sizes[i]} --dimension 1 --beta_schedule ours --num_timesteps 1000

7300
MIN GAMMA tensor(2.9135e-05, dtype=torch.float64) LEN 74
Training model...
Epoch 0: 100%|█████████| 228/228 [00:00<00:00, 342.92it/s, loss=0.303, step=227]
100%|███████████████████████████████████████████| 74/74 [00:00<00:00, 91.74it/s]
Epoch 1: 100%|█████████| 228/228 [00:00<00:00, 414.03it/s, loss=0.305, step=455]
100%|█████████████████████████████████████████| 74/74 [00:00<00:00, 1063.75it/s]
Epoch 2: 100%|█████████| 228/228 [00:00<00:00, 411.42it/s, loss=0.133, step=683]
100%|█████████████████████████████████████████| 74/74 [00:00<00:00, 1001.62it/s]
Epoch 3: 100%|█████████| 228/228 [00:00<00:00, 367.27it/s, loss=0.111, step=911]
100%|█████████████████████████████████████████| 74/74 [00:00<00:00, 1021.77it/s]
Epoch 4: 100%|████████| 228/228 [00:00<00:00, 366.68it/s, loss=0.234, step=1139]
100%|██████████████████████████████████████████| 74/74 [00:00<00:00, 951.05it/s]
Epoch 5: 100%|████████| 228/228 [00:00<00:00, 373.90it/s, loss=0.247, step=1367]
100%|████████████████████████

In [10]:
def calculate_stats(model_path, dataset='point', score='model'):
    model = ddpm.MLP(input_dim=1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    path = model_path
    model.load_state_dict(torch.load(path))
    model.to(device)
    model.eval()
    eval_batch_size = 100000
    num_scheduler_timesteps = 1000
    plot_step = 1
    noise_scheduler = ddpm.NoiseScheduler(num_timesteps=num_scheduler_timesteps, beta_schedule='ours')
    num_timesteps = len(noise_scheduler.betas)
    curr_vars = torch.sqrt(1 - torch.exp(-2 * noise_scheduler.times))
    sample = torch.randn(eval_batch_size*2, 1).to(device)[eval_batch_size:, :]
    print(sample.shape)
    timesteps = list(range(num_timesteps))[::-1]
    samples = []
    steps = []
    for i, t in enumerate(tqdm(timesteps)):
        t = torch.from_numpy(np.repeat(t, eval_batch_size)).long().to(device)
        with torch.no_grad():
            variance = torch.sqrt(1 - noise_scheduler.alphas_cumprod[t])
            v = curr_vars[t].cpu().numpy()
            if score=='model':
                # print("RATIO", sample.squeeze().cpu().numpy() / v)
                # print("DIFF", torch.mean(model(sample, t) - sample / variance[0]))
                residual = model(sample, t)
            elif score == 'true':
                # print(variance[0])
                residual =  sample / variance[0]
            else:
                if score > 1 or score < 0:
                    raise ValueError("INVALID SCORE")
                residual = score * model(sample, t) + (1-score) * (sample / variance[0])
            # residual = (0.003*residual1 + 0.997*residual2)
        sample = noise_scheduler.step(residual, t[0], sample)
        # print(np.median(sample.squeeze().cpu().numpy()))
        if (i + 1) % plot_step == 0:
            sample_cpu = sample.cpu()
            samples.append(sample_cpu.numpy())
            steps.append(i + 1)
    if dataset == 'ret':
        return samples
    elif dataset == 'point':
        m = process_point(samples[-1], mode='median')
        print("SCORE AT MEAN", model(torch.tensor([[m]], device=device), torch.tensor([0], device=device)))
        return process_point(samples[-1], mode='median')
    else:
        raise ValueError("INVALID DATASET")
def process_point(samples, mode='median'):
    return np.mean(samples.squeeze())
    # print("MEAN", )
    # return np.median(samples.squeeze())
    # centered = samples.squeeze() - np.mean(samples.squeeze())
    # print(np.mean(samples.squeeze()))
    # return np.median(centered)

In [11]:
s = []
for i in range(len(sizes)):
    print(sizes[i], names[i])
    s.append(np.abs(calculate_stats(f"exps/{names[i]}/model.pth", dataset='point', score='model')))
print(s)
plt.clf()
plt.scatter(sizes, np.array(s))
plt.yscale('log')
plt.title("median error vs dataset size")
plt.ylabel("median error")
plt.xlabel("dataset size")
plt.savefig(f'static/devs_point.png', bbox_inches='tight')
plt.show()

7300 point_1d_medgamma7300
MIN GAMMA tensor(2.9135e-05, dtype=torch.float64) LEN 74
torch.Size([100000, 1])


  0%|          | 0/74 [00:00<?, ?it/s]

SCORE AT MEAN tensor([[-0.0123]], device='cuda:0', dtype=torch.float64,
       grad_fn=<AddmmBackward0>)
[3.53062912977824e-07]


In [15]:
model_path = "exps/point_1d_smallgamma7300/model.pth"
model = ddpm.MLP(input_dim=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load(model_path))
model.to(device)
model.eval()

MLP(
  (time_mlp): PositionalEmbedding(
    (layer): SinusoidalEmbedding()
  )
  (input_mlp1): PositionalEmbedding(
    (layer): SinusoidalEmbedding()
  )
  (joint_mlp): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): GELU(approximate='none')
    (2): Block(
      (ff): Linear(in_features=128, out_features=128, bias=True)
      (act): GELU(approximate='none')
    )
    (3): Block(
      (ff): Linear(in_features=128, out_features=128, bias=True)
      (act): GELU(approximate='none')
    )
    (4): Block(
      (ff): Linear(in_features=128, out_features=128, bias=True)
      (act): GELU(approximate='none')
    )
    (5): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [16]:
noise_scheduler = ddpm.NoiseScheduler(num_timesteps=1000, beta_schedule='ours')
curr_stds = torch.sqrt(1 - torch.exp(-2 * noise_scheduler.times))
t = 0
v = curr_stds[t].item()
print(v)
x_scale = np.linspace(-v * 5, v * 5, 1000)
# x_scale = np.linspace(-10, 10, 1000)
inputs = torch.tensor(x_scale, device=device).unsqueeze(1)
times = torch.ones(len(inputs)).to(device) * t
model_residuals = model(inputs, times)
true_residuals = inputs / v

MIN GAMMA tensor(2.9135e-05, dtype=torch.float64) LEN 74
2.9134887396820698e-05


In [17]:
plt.plot(x_scale, model_residuals.data.cpu().numpy(), label='model')
plt.plot(x_scale, true_residuals.data.cpu().numpy(), label='true')
for y in (np.arange(11)-5)*v:
    plt.axvline(y, alpha=1 if y == 0 else 0.2)
plt.legend()
plt.savefig("score.png")
plt.show()

In [14]:
errors = []
for t in range(50):
    print(t)
    v = curr_vars[t].cpu().numpy()
    x_range = np.linspace(-v*5, v*5, 1000)
    diff = x_range[1] - x_range[0]
    l2 = 0
    difc = 0
    pc = 0
    tot = 0
    for i in x_range:
        v = curr_vars[t]
        pdf = norm.pdf(i, 0, v.item())
        model_val = model(torch.tensor([[i]], device=device, dtype=torch.float32), torch.ones(1, device=device, dtype=torch.float32)*t)
        true_val = torch.tensor([[i]], device=device) / torch.sqrt(1 - torch.exp(-2 * noise_scheduler.times[t]))
        error = (model_val.data.cpu().numpy() - true_val.data.cpu().numpy())[0]
        l2 += (error**2)*diff*pdf
        difc += diff
        pc += pdf
        tot += diff*pdf
    print(l2, difc, tot, v.item()*2)
    errors.append(l2*v.item()*v.item())
    # break

0


NameError: name 'curr_vars' is not defined

In [96]:
x_range = np.linspace(-10, 10, 1000)
diff = x_range[1] - x_range[0]
tot = 0
for i in x_range:
    tot += diff * norm.pdf(i, 0, 10)
print(tot)

0.6831737563750486


In [29]:
model_path = "exps/point_1d_smallgamma7300/model.pth"
model = ddpm.MLP(input_dim=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load(model_path))
model.to(device)
model.eval()
scores = []
inputs = torch.linspace(-10, 10, 1000).unsqueeze(1).to(device)
noise_scheduler = ddpm.NoiseScheduler(num_timesteps=1000, beta_schedule="ours")
print(noise_scheduler.betas.shape)
times = torch.ones(len(inputs), device=device)*9
residuals = model(inputs, times)

MIN GAMMA tensor(3.1610e-06, dtype=torch.float64) LEN 84
torch.Size([84])
