In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt

from d3pm_sc.ct_sched_cond import ScheduleCondition
from d3pm_sc.sedd import SEDD
from d3pm_sc.masking_diffusion import MaskingDiffusion
from d3pm_sc.d3pm_classic import D3PMClassic
from d3pm_sc.unet import KingmaUNet, UNet, SimpleUNet
from d3pm_sc.dit import DiT_Llama
from d3pm_sc import utils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10

N = 128
# Settings for the forward (noising) generator built by utils.get_inf_gens.
forward_kwargs = dict(type="gaussian", normalized=True, bandwidth=1 / 10)

from scipy import linalg, special
from scipy import optimize

# Generator matrix L and a uniform initial distribution p0 over its len(L) states.
L = utils.get_inf_gens(forward_kwargs, N)
p0 = torch.ones(len(L)) / len(L)
ent_p0 = -torch.xlogy(p0, p0).sum()  # entropy H(p0); normalizes mi below
# Disabled alternative: weight the generator rows by p0, i.e. L = p0[:, None] * L

# Eigendecomposition L = V diag(evals) V^{-1}. Eigenvalues with a positive real
# part are reset to 0 (presumably round-off, so exp(tL) cannot blow up).
evals, V = torch.linalg.eig(L)
evals.masked_fill_(evals.real > 0, 0)
V_inv = torch.linalg.inv(V)
# Negated second-largest real eigenvalue: the decay rate of the slowest
# non-stationary mode (the largest is presumably ~0 after the reset above).
second_eval = -torch.real(evals).sort().values[-2]
def mi(t, t_shift=1):
    """Return the normalized mutual information between x_0 and x_t, minus 1, plus ``t_shift``.

    For each time ``t[i]`` this builds K_t = V diag(exp(t * evals)) V^{-1},
    the (eigenvalue-clamped) matrix exponential exp(t L). It then forms the joint
    p(x0, xt) = p0[x0] * K_t[x0, xt] and computes -H(x0 | xt) / H(p0). When K_t is
    row-stochastic, this equals I(x0; xt) / H(x0) - 1.

    Args:
        t: time(s) at which to evaluate. Accepts a scalar, sequence, NumPy array
            or tensor. Tensor inputs stay on the autograd graph, which ``beta``
            relies on for ``torch.func.grad``.
        t_shift: value added to the result. With the default 1, the result is
            the normalized MI itself. With ``t_shift=s``, the result is zero
            exactly where the normalized MI equals ``1 - s``; this is the root
            that ``beta_int`` searches for.

    Returns:
        Real tensor with one entry per time (shape ``(1,)`` for a scalar ``t``).
    """
    # as_tensor (unlike torch.tensor) does not copy/detach tensor inputs, so
    # gradients with respect to t can flow through this function.
    t = torch.atleast_1d(torch.as_tensor(t))
    # exp(t * lambda_k) for every time and eigenvalue: shape (T, 1, len(L)).
    evals_skew = torch.exp(t[:, None, None] * evals[None, None, :])
    # (V * exp(t lambda)) @ V_inv == V diag(exp(t lambda)) V^{-1}; rows weighted by p0.
    p = torch.real(p0[None, :, None] * ((V[None, :, :] * evals_skew) @ V_inv))
    # Round-off in the eigendecomposition can produce tiny negative probabilities.
    p = torch.where(p < 0, 0, p)
    marginal = p.sum(-2)  # distribution of x_t
    mi_m1 = (torch.xlogy(p, p).sum(-1) - torch.xlogy(marginal, marginal)).sum(-1) / ent_p0
    return mi_m1 + t_shift

def beta_int(ts, xtol=2e-5, upper=None):
    """Invert the normalized-MI curve: for each t, find s >= 0 with mi(s) == 1 - t.

    Args:
        ts: iterable of target times. Each should lie strictly inside (0, 1) so
            that ``mi(s, t)`` changes sign on the bracket ``[0, upper]``
            (it equals ``t`` at s = 0).
        xtol: absolute tolerance passed to ``optimize.brentq``.
        upper: right end of the root bracket. Defaults to ``20 / second_eval``,
            i.e. 20 decay times of the slowest non-stationary eigenmode.

    Returns:
        Tensor holding one root per entry of ``ts``.
    """
    if upper is None:
        upper = 20 / second_eval
    out = []
    for t in ts:
        # mi returns a 1-element tensor; .item() gives brentq a plain float
        # instead of relying on deprecated size-1-array-to-float conversion.
        root = optimize.brentq(lambda s: mi(s * torch.ones(1), t).item(),
                               0, upper, xtol=xtol)
        out.append(root)
    return torch.tensor(out)

def alpha(ts):
    """Return exp(-beta_int(ts)): exponential of the negated integrated rate at each time."""
    cumulative_rate = beta_int(ts)
    return np.exp(-cumulative_rate)

def beta(ts):
    """Return the rate d/dt beta_int(t) at each time in ``ts``.

    Differentiating MI(beta_int(t)) = 1 - t with respect to t gives
    beta(t) = -1 / MI'(beta_int(t)), where MI is ``mi`` with its default shift.

    NOTE: this needs ``mi`` to keep its input on the autograd graph. If ``mi``
    copies/detaches ``t`` (e.g. via ``torch.tensor``), the gradient is lost.
    """
    # Reuse the shared root-finder rather than duplicating its brentq loop.
    int_ts = beta_int(ts)
    # Each output of mi depends only on its own time, so the gradient of the
    # sum is the elementwise derivative dMI/ds.
    dmi_ds = torch.func.grad(lambda s: mi(s).sum())(int_ts)
    return -1 / dmi_ds

# Times strictly inside (0, 1): MI reaches 0 only asymptotically, so t = 1
# would give brentq no sign change on its bracket.
ts = np.linspace(0.0001, 0.9999, 2000)
int_ts = beta_int(ts)  # integrated rate s(t) with normalized MI(s(t)) = 1 - t
mis = mi(int_ts)  # should trace the line 1 - t
plt.figure(figsize=[3, 3])
plt.plot(ts, mis, color='black', label='MI')
# Equal to alpha(ts), but reuses int_ts instead of re-running every root solve.
plt.plot(ts, np.exp(-int_ts), color='red', label='alpha')
log_betas = torch.log(beta(ts))
# Min-max rescale log(beta) into [0, 1] so it can share the MI axis.
log_betas = log_betas - log_betas.min()
plt.plot(ts, log_betas / log_betas.max(), color='blue', label='log beta (rescaled)')
plt.legend()
plt.ylabel("MI")
plt.xlabel("time")
plt.ylim(0, 1.1)
plt.xlim(0, 1)

In [None]:
# Alternative to the per-time brentq loop: solve every time jointly by driving
# mi(x, ts) = MI(x) - (1 - ts) to zero in a least-squares sense, starting at x = 1.
# No `jac` is supplied, so SciPy estimates the gradient by finite differences.
optimize.minimize(
    fun=lambda x: (mi(x, ts) ** 2).sum().numpy(),
    x0=torch.ones(len(ts)),
    method='CG',
)