In [1]:
import torch
import matplotlib.pyplot as plt
from torch import nn, Tensor
from tqdm import tqdm
from src.tools import load_dataset
import numpy as np
# import wandb
from src.model import SongUNet

In [2]:
# --- Training schedule ---
# Number of outer rounds executed by train_with_data.
inner_iters = 10000
batch_size = 64

# --- Discrete state space ---
# Number of integer token values each pixel channel can take.
vocab_size = 256
# Number of discrete time steps walked by the sampler in visualize.
T = 10

# --- Image geometry ---
IMG_SIZE = 32
IMG_CHANNELS = 3

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
from src.tools import load_dataset
# Load the two colored-MNIST variants; each call returns train/test samplers
# and train/test loaders.
sampler3, test_sampler3, loader3, test_loader3 = load_dataset(
    'MNIST-colored_3',
    './datasets/MNIST',
    img_size=IMG_SIZE,
    batch_size=batch_size,
    device=device,
)
sampler2, test_sampler2, loader2, test_loader2 = load_dataset(
    'MNIST-colored_2',
    './datasets/MNIST',
    img_size=IMG_SIZE,
    batch_size=batch_size,
    device=device,
)

# Endpoint roles: sampler_0 reads from X_sampler, sampler_1 from Y_sampler.
X_sampler = sampler2
Y_sampler = sampler3

def sampler_0(batch_size: int = 200, device: str = "cpu") -> Tensor:
    """Draw a batch from ``X_sampler`` and quantise it to integer tokens.

    The raw sample is mapped through ``x * 0.5 + 0.5``, clipped to [0, 1],
    scaled by ``vocab_size - 1`` and cast to ``long`` (the cast truncates
    towards zero rather than rounding).

    Args:
        batch_size: Number of samples requested from ``X_sampler``.
        device: Accepted for signature compatibility; not used here.

    Returns:
        Integer tensor with values in ``[0, vocab_size - 1]``.
    """
    raw = X_sampler.sample(batch_size)
    # presumably raw pixel values lie in [-1, 1] — TODO confirm against load_dataset
    unit_range = raw.mul(0.5).add(0.5).clip(0, 1)
    tokens = unit_range * (vocab_size - 1)
    return tokens.long()

def sampler_1(batch_size: int = 200, device: str = "cpu") -> Tensor:
    """Draw a batch from ``Y_sampler`` and quantise it to integer tokens.

    The raw sample is mapped through ``x * 0.5 + 0.5``, clipped to [0, 1],
    scaled by ``vocab_size - 1`` and cast to ``long`` (the cast truncates
    towards zero rather than rounding).

    Args:
        batch_size: Number of samples requested from ``Y_sampler``.
        device: Accepted for signature compatibility; not used here.

    Returns:
        Integer tensor with values in ``[0, vocab_size - 1]``.
    """
    raw = Y_sampler.sample(batch_size)
    # presumably raw pixel values lie in [-1, 1] — TODO confirm against load_dataset
    unit_range = raw.mul(0.5).add(0.5).clip(0, 1)
    tokens = unit_range * (vocab_size - 1)
    return tokens.long()

In [4]:
from typing import Optional, Tuple
import torch.nn.functional as F

# Training

In [5]:
# Network trained with type='f' in train_with_data, plus its Adam optimizer.
model_f = SongUNet(
    img_resolution=32,
    in_channels=3,
    out_channels=3,
    vocab_size=vocab_size,
    model_channels=96,
).to(device)
optim_f = torch.optim.Adam(model_f.parameters(), lr=1e-3)

# Second network with the same configuration, with its own optimizer.
model_b = SongUNet(
    img_resolution=32,
    in_channels=3,
    out_channels=3,
    vocab_size=vocab_size,
    model_channels=96,
).to(device)
optim_b = torch.optim.Adam(model_b.parameters(), lr=1e-3)

# Total number of scalar parameters in one network.
print('denoiser params:', sum(p.numel() for p in model_f.parameters()))

denoiser params: 35445888


In [6]:
# Resume the forward model from the checkpoint that train_with_data writes.
# The checkpoint is a whole pickled module (written with torch.save(model, ...)),
# so it must be loaded with weights_only=False. Unpickling can execute arbitrary
# code: only load checkpoint files you trust.
# map_location lets a checkpoint saved on the GPU be restored on the current device.
model_f = torch.load('model_f.pt', map_location=device, weights_only=False)

# optim_f was constructed over the parameters of the model object that the
# assignment above just replaced. Rebuild it over the loaded parameters,
# otherwise optimizer steps would never touch the model being trained.
optim_f = torch.optim.Adam(model_f.parameters(), lr=1e-3)

  model_f = torch.load('model_f.pt')


In [27]:
def train_with_data(model, optim, type='f'):
    """Train a denoiser on random per-position mixtures of the two endpoints.

    Runs ``inner_iters`` rounds of 100 optimisation steps. In each step:

    * ``t`` is drawn per sample, uniformly from the integers ``1 .. T + 1``;
    * ``x_1`` / ``x_0`` are drawn from ``sampler_1`` / ``sampler_0`` and
      flattened to ``(batch_size, -1)``;
    * every position gets an independent integer threshold in ``1 .. T + 1``;
      where ``threshold < t`` the position is taken from the source endpoint,
      otherwise from the target endpoint;
    * the model's logits are scored with cross-entropy against the target
      endpoint's tokens.

    After each round the mean loss is printed, a sample strip is saved to
    ``samples/<tag>/<round>.png`` and the whole model is saved to
    ``model_<tag>.pt``, where ``<tag>`` is ``'f'`` or ``'b'``.

    Args:
        model: Network called as ``model(x_t, t)``; must expose ``model.d``
            (used as the number of positions per sample).
        optim: Optimizer to step; it must have been built over
            ``model.parameters()`` for the steps to affect ``model``.
        type: ``'f'`` mixes from ``x_1`` towards target ``x_0``; any other
            value is treated as the backward direction (source ``x_0``,
            target ``x_1``).
    """
    import os  # local import: the notebook's import cell does not provide ``os``

    steps_per_round = 100
    is_forward = type == 'f'
    tag = 'f' if is_forward else 'b'
    # Sampling is started from the endpoint opposite to the training target.
    start_sampler = sampler_1 if is_forward else sampler_0

    sample_dir = f'samples/{tag}'
    # savefig does not create missing directories.
    os.makedirs(sample_dir, exist_ok=True)

    for round_idx in range(inner_iters):
        loss_sum = 0
        for _ in range(steps_per_round):
            t = torch.randint(low=1, high=T + 2, size=(batch_size,), device=device)
            x_1 = sampler_1(batch_size).view(batch_size, -1)
            x_0 = sampler_0(batch_size).view(batch_size, -1)

            # Per-position thresholds; a larger t keeps more source positions.
            thresholds = torch.randint(
                low=1, high=T + 2, size=(batch_size, model.d), device=device
            )
            take_source = thresholds < t[:, None]
            if is_forward:
                x_t = torch.where(take_source, x_1, x_0)
                target = x_0
            else:
                x_t = torch.where(take_source, x_0, x_1)
                target = x_1

            x_t = x_t.view(x_1.shape[0], IMG_CHANNELS, IMG_SIZE, IMG_SIZE)
            # Collapse every leading dimension so each row holds one position's logits.
            logits = model(x_t, t).flatten(start_dim=0, end_dim=-2)
            # cross_entropy already averages over positions by default.
            loss = F.cross_entropy(logits, target.flatten(start_dim=0, end_dim=-1))

            optim.zero_grad()
            loss.backward()
            optim.step()
            loss_sum += loss.item()

        print('Loss:', loss_sum / steps_per_round)
        # Visualise and checkpoint the model that was actually trained, under
        # a direction-specific path so the two directions do not overwrite
        # each other's artefacts.
        visualize(start_sampler, model, f'{sample_dir}/{round_idx}.png')
        torch.save(model, f'model_{tag}.pt')

# Sampling

In [28]:
def cost(traj):
    """Accumulate the fraction of entries that change between consecutive states.

    Args:
        traj: Indexable sequence of tensors; neighbouring entries are compared
            element-wise with ``!=``.

    Returns:
        The sum over consecutive pairs of the mean number of differing
        entries. This is the integer ``0`` when ``traj`` has fewer than two
        states, otherwise a scalar tensor.
    """
    total = 0
    for current, following in zip(traj, traj[1:]):
        changed = (current != following).float()
        total = total + changed.mean()

    return total

In [29]:
def visualize(sampler, model, name='samples/test.png'):
    """Run the discrete sampler from ``t = T`` down to ``t = 0`` and save a strip.

    Ten samples are drawn from ``sampler``. At each step the model's softmax
    output ``p1`` is blended with the current one-hot state, giving the
    transition distribution ``one_hot + (p1 - one_hot) / t``, and a new state
    is sampled from it. The first sample of every intermediate state
    (``T + 1`` frames in total) is drawn side by side and written to ``name``.
    The figure is closed after saving, so nothing is displayed inline.

    Args:
        sampler: Callable ``sampler(n)`` returning an integer tensor of
            tokens; frames are permuted with ``(0, 2, 3, 1)``, so a 4-D
            channel-first layout is assumed.
        model: Network called as ``model(x_t, t_batch)``; its output must be
            broadcastable against ``one_hot(x_t, vocab_size)``.
        name: Output image path; missing parent directories are created.
    """
    import os  # local import: the notebook's import cell does not provide ``os``

    n_samples = 10

    out_dir = os.path.dirname(name)
    if out_dir:
        # savefig does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)

    x_t = sampler(n_samples)
    t = T
    results = [(x_t, t)]

    # Pure inference: avoid building an autograd graph for every step.
    with torch.no_grad():
        while t > 0:
            t_batch = torch.ones(n_samples).to(device) * t
            p1 = torch.softmax(model(x_t, t_batch), dim=-1)
            one_hot_x_t = nn.functional.one_hot(x_t, vocab_size).float()
            u = (p1 - one_hot_x_t) / (t / (T + 1))
            # Rounding can leave entries a hair below zero, which Categorical's
            # argument validation rejects; clamp before sampling.
            probs = (one_hot_x_t + u / (T + 1)).clamp(min=0)
            x_t = torch.distributions.Categorical(probs=probs).sample()
            t -= 1
            results.append((x_t.long(), t))

    # squeeze=False keeps ``axes`` two-dimensional even for a single frame.
    fig, axes = plt.subplots(
        1, len(results), figsize=(15, 3), sharex=True, sharey=True, squeeze=False
    )
    for ax, (frame, step) in zip(axes[0], results):
        ax.imshow(frame.permute(0, 2, 3, 1)[0].detach().cpu())
        ax.set_title(f't={step}')

    fig.tight_layout()
    fig.savefig(name)  # Saves as PNG
    # Close this figure explicitly so repeated calls from the training loop
    # do not accumulate open figures.
    plt.close(fig)

In [30]:
# Write a sampling strip for model_f, started from sampler_1 samples, to the
# default output path of visualize ('samples/test.png').
visualize(sampler_1, model_f)
# visualize(sampler_1, model_b)

In [33]:
# Train model_f in the 'f' direction. The loop runs for inner_iters rounds;
# the recorded output below was stopped manually (KeyboardInterrupt).
train_with_data(model_f, optim_f, 'f')

Loss: 0.8918605768680572
Loss: 0.8898259776830674
Loss: 0.8905714559555054


KeyboardInterrupt: 

In [None]:
# Draw 10 quantised samples via sampler_0 and inspect their shape.
x_0 = sampler_0(10)
x_0.shape

In [None]:
# NOTE(review): `prior` is not defined anywhere in the visible notebook, and
# `sample_trajectory` is not used elsewhere in it — this cell presumably depends
# on state or an API from an earlier revision; confirm before running.
# Takes the second-to-last entry of whatever sample_trajectory returns.
pred_x_start = model_f.sample_trajectory(x_0, prior)[-2]
# Show the first input sample, moved to CPU with channels last for imshow.
plt.imshow(x_0.to('cpu').permute(0,2,3,1)[0].numpy().clip(0,255))

In [None]:
# Show the first predicted sample, moved to CPU with channels last for imshow.
plt.imshow(pred_x_start.to('cpu').permute(0,2,3,1)[0].numpy().clip(0,255))

In [None]:
# Inspect the shape of the prediction.
pred_x_start.shape

In [None]:
def sample_from_model(x, model, batch_size):
    """Integrate the discrete sampler from ``t = 0`` towards ``t = 1``.

    At every step the model's softmax output ``p1`` defines a rate
    ``(p1 - one_hot(x)) / (1 - t)``; the next state is sampled from
    ``one_hot(x) + h * rate`` with step size ``h = min(0.01, 1 - t)``.
    Iteration stops once ``t`` reaches ``1 - 1e-3``. No gradients are tracked.

    NOTE(review): time here is a continuous value in [0, 1), whereas
    train_with_data feeds the model integer steps in 1..T+1 — confirm that
    this function matches the model being sampled.

    Args:
        x: Integer tensor holding the current token state.
        model: Network called as ``model(x, t_batch)``.
        batch_size: Length of the time vector passed to the model.

    Returns:
        The token tensor after the final sampling step.
    """
    max_step = 0.01
    stop_time = 1 - 1e-3

    state = x
    t = 0.0
    with torch.no_grad():
        while t < stop_time:
            t_batch = torch.ones(batch_size).to(device) * t
            p1 = torch.softmax(model(state, t_batch), dim=-1)

            h = min(max_step, 1.0 - t)
            current_one_hot = nn.functional.one_hot(state, vocab_size).float()
            rate = (p1 - current_one_hot) / (1.0 - t)

            step_probs = current_one_hot + h * rate
            state = torch.distributions.Categorical(probs=step_probs).sample()
            t += h
    return state