In [1]:
from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
from IPython.display import HTML
from diffusion_utilities import *

In [2]:
transform

Compose(
    ToTensor()
    Normalize(mean=(0.5,), std=(0.5,))
)

In [3]:
from torchvision import datasets
transformerr = transforms.Compose([
    transforms.ToTensor(),  # Convert to tensor
])

# Time mapping

In [4]:
step = list(range(1,1001))
beta = dict()
for t in step:
    beta[t] = round(t * 0.0199 / 999 + (0.02 - 0.0199 / 999 * 1000), 15) # beta = at + b
alpha = dict()
alpha[1] = 1 - beta[1]
for t in step[1:]:
    alpha[t] = alpha[t-1] * (1 - beta[t])


In [5]:
import math
T = dict()
for t in step:
    T[t] = -1/2 * math.log(alpha[t])

In [6]:
T[1000] #collapse time experimental

5.058856771206552

# Speciation time

In [7]:
transformerr

Compose(
    ToTensor()
)

In [8]:
trainset = datasets.FashionMNIST('~/.pytorch/FashionMNIST_data/', train=True, download=True, transform=transformerr)

In [9]:
torch.mean(trainset[0][0])

tensor(0.3814)

In [10]:
mask_1 = trainset.targets == 2
mask_2 = trainset.targets == 7
mask = mask_1 | mask_2
trainset = torch.utils.data.Subset(trainset, torch.where(mask)[0])

In [11]:
dataset = torch.Tensor()

In [12]:
# Initialize an empty list to store tensors
tensor_list = []

# Loop through the dataset and collect the tensors
for x in range(len(trainset)):
    tensor_list.append(trainset[x][0])
    

# Concatenate all tensors at once
dataset = torch.cat(tensor_list, dim=0)

print(dataset.shape)


torch.Size([12000, 28, 28])


In [13]:
reduced = dataset[0:10000]

In [14]:
torch.mean(dataset), torch.mean(reduced)

(tensor(0.2722), tensor(0.2709))

In [15]:
dataset = dataset - torch.mean(dataset)
reduced = reduced - torch.mean(reduced)

In [16]:
row = []
for i in range(len(dataset)):
    row.append(dataset[i].flatten().tolist())

In [17]:
row_r = []
for i in range(len(reduced)):
    row_r.append(reduced[i].flatten().tolist())

In [18]:
torch.Tensor(row).shape, torch.Tensor(row_r).shape

(torch.Size([12000, 784]), torch.Size([10000, 784]))

In [19]:
row = torch.transpose(torch.Tensor(row),0,1)
row_r = torch.transpose(torch.Tensor(row_r),0,1)

In [20]:
C = torch.cov(torch.tensor(row))
C_r = torch.cov(torch.tensor(row_r))

  C = torch.cov(torch.tensor(row))
  C_r = torch.cov(torch.tensor(row_r))


In [21]:
C.shape, C_r.shape

(torch.Size([784, 784]), torch.Size([784, 784]))

In [22]:
with torch.no_grad():
    L, _ = torch.linalg.eigh(C, UPLO="L")
    L_r, _ = torch.linalg.eigh(C_r, UPLO="L")

0 -> 1: 31.77

-1 -> 1: 127.11

In [23]:
L[-1], L_r[-1]

(tensor(31.7777), tensor(31.6263))

In [24]:
t_s = 0.5 * np.log(L[-1])
t_s

tensor(1.7294)

In [25]:
mapped_time = min(T, key=lambda k: abs(T[k]-t_s))
mapped_time

584

# Collapse time

In [175]:
transformerr

Compose(
    ToTensor()
)

In [176]:
# MNIST
from torchvision import datasets, transforms
dataset = datasets.MNIST('~/.pytorch/MNIST_data/', train=True, download=True, transform=transformerr)

In [177]:
def first_n (N, mask):
    n = 0
    for i in range(len(mask)):
        if mask[i]:
            n += 1
        if n == N:
            for j in range(i+1, len(mask)):
                mask[j] = False
            break
    return mask

In [178]:
mask_1 = dataset.targets == 0
first_n(100, mask_1)
mask_2 = dataset.targets == 7
first_n(100, mask_2)
mask = mask_1 | mask_2
dataset = torch.utils.data.Subset(dataset, torch.where(mask)[0])


In [179]:
a = torch.Tensor()
for x in range(len(dataset)):
    a = torch.cat((a, dataset[x][0]), dim=0)

In [180]:
a.shape

torch.Size([200, 28, 28])

In [248]:
@torch.no_grad
def P(t, x):
    #t = T[t]
    d = 28 * 28
    logs = []
    for i in range(a.shape[0]):
        logs.append(np.log(1 / a.shape[0]) - d/2 * np.log( 2 * torch.pi * (1 - torch.e **(-2*t)) ) - 0.5 * torch.norm(x - a[i]*torch.e**(-t))**2 / (1-torch.e**(0-2*t)))
    return logs

In [249]:
P(0.09, a[0])

[tensor(-21.1284),
 tensor(-299.4906),
 tensor(-99.4059),
 tensor(-353.8616),
 tensor(-249.1019),
 tensor(-159.3817),
 tensor(-366.3493),
 tensor(-318.9525),
 tensor(-191.6410),
 tensor(-324.5208),
 tensor(-253.8734),
 tensor(-212.1662),
 tensor(-378.3724),
 tensor(-260.0832),
 tensor(-316.6208),
 tensor(-182.4068),
 tensor(-325.6183),
 tensor(-172.0652),
 tensor(-366.4845),
 tensor(-256.1553),
 tensor(-322.3724),
 tensor(-163.1113),
 tensor(-333.0019),
 tensor(-320.7443),
 tensor(-309.5170),
 tensor(-364.6244),
 tensor(-214.9875),
 tensor(-160.2290),
 tensor(-99.8308),
 tensor(-336.0434),
 tensor(-329.6689),
 tensor(-333.2940),
 tensor(-339.9143),
 tensor(-323.4059),
 tensor(-228.1105),
 tensor(-342.8399),
 tensor(-398.6082),
 tensor(-213.8842),
 tensor(-308.4899),
 tensor(-152.8037),
 tensor(-316.9795),
 tensor(-388.8369),
 tensor(-287.5333),
 tensor(-315.1446),
 tensor(-254.4215),
 tensor(-335.5030),
 tensor(-366.6305),
 tensor(-318.6136),
 tensor(-321.3599),
 tensor(-302.5354),
 te

In [246]:
t =0.09
x = a[0]
d = 28 * 28
lg = np.log(1 / a.shape[0]) - d/2 * np.log( 2 * torch.pi * (1 - torch.e **(-2*t)) ) - 0.5 * torch.norm(x - a[0]*torch.e**(-t))**2 / (1-torch.e**(0-2*t))

In [247]:
lg

tensor(-21.1284)

In [182]:
P(1,a[4])

tensor(0.)

In [98]:
T[150]

0.1189247332798637