In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import math

from unet import UNetModel
from diffusion import GaussianDiffusion

import tqdm
import matplotlib.pyplot as plt
import matplotlib as mpl
# Make grayscale the default colormap for every image drawn in this notebook
mpl.rc('image', cmap='gray')

## Load Data

In [None]:
# Load MNIST dataset
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs (slowly) instead of crashing on `.to(device)`.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch_size = 128

# ToTensor yields values in [0, 1]; Pad(2) grows the 28x28 digits to 32x32 with
# zeros; Normalize(0.5, 0.5) then maps everything to [-1, 1], so the padded
# border ends up at -1.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Pad(2),
    torchvision.transforms.Normalize(0.5, 0.5),
])
mnist_train = torchvision.datasets.MNIST(root='data/', train=True, transform=transforms, download=True)
data_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)

# Pull a single batch for a quick visual sanity check
batch = next(iter(data_loader))
img, labels = batch

# Show the first four digits with their labels as titles
fig, ax = plt.subplots(1, 4, figsize=(15,15))
for i in range(4):
    ax[i].imshow(img[i,0,:,:].numpy())
    ax[i].set_title(str(labels[i].item()))
plt.show()

## Load Model

In [None]:
# Save/Load model
# (kept for reference: how the checkpoint loaded below can be written)
#torch.save(net.state_dict(), 'models/mnist_unet.pth')
#print('Saved model')

net = UNetModel(image_size=32, in_channels=1, out_channels=1, 
                model_channels=64, num_res_blocks=2, channel_mult=(1,2,3,4),
                attention_resolutions=[8,4], num_heads=4).to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU can still be restored on a CPU-only host (or on a
# different GPU index) instead of failing during deserialization.
net.load_state_dict(torch.load('models/mnist_unet.pth', map_location=device))
# NOTE(review): the network is put in train mode although the cells below only
# use it as a fixed denoiser - confirm this is intended rather than eval().
net.train()
print('Loaded model')

## Inference with hand-crafted conditions

In [None]:
# Visualize t annealing schedule
diffusion = GaussianDiffusion(T=1000, schedule='linear')
steps = 1000

t_vals = []
for i in range(steps):
    remaining = steps - i
    # Cosine wobble whose amplitude is a quarter of the remaining linear ramp
    wobble = remaining//4*math.cos(i/50)
    # Linearly decreasing + cosine, rescaled from `steps` to the diffusion horizon T
    t = (remaining + wobble)/steps*diffusion.T
    # Truncate to an integer timestep and keep it inside [1, T]
    t = np.clip(np.array([t]).astype(int), 1, diffusion.T)
    t_vals.append(t[0])

plt.figure(figsize=(8,5))
plt.plot(range(steps), t_vals, linewidth=2)
plt.xlabel('Steps')
plt.ylabel('$t$')
plt.title('$t$ Annealing Schedule')
plt.show()

In [None]:
# Inference Model
class Model(nn.Module):
    """Wraps one learnable 1x1x32x32 image, initialised from Gaussian noise."""

    def __init__(self):
        super().__init__()
        noise = torch.randn(1, 1, 32, 32)
        # Registered as a parameter so an optimizer can update the pixels directly
        self.img = nn.Parameter(noise, requires_grad=True)

    def encode(self):
        """Return the image parameter itself (not a copy)."""
        return self.img
    
model = Model().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.01)

diffusion = GaussianDiffusion(T=1000, schedule='linear')
# NOTE(review): the denoiser stays in train mode here - confirm this is intended.
net.train()

steps = 1000
bar = tqdm.tqdm(range(steps))
losses = []
update_every = 50
for i in bar:
    sample_img = model.encode()
    remaining = steps - i

    # Select t: linearly decreasing + cosine, truncated to an int and kept in [1, T]
    t = (remaining + remaining//4*math.cos(i/50))/steps*diffusion.T
    t = np.clip(np.array([t]).astype(int), 1, diffusion.T)

    # Denoise
    # presumably diffusion.sample returns the noised image and the noise that
    # was added - confirm against diffusion.py
    xt, epsilon = diffusion.sample(sample_img, t)
    t = torch.from_numpy(t).float().view(sample_img.shape[0])
    epsilon_pred = net(xt.float(), t.to(device))

    # Hand-crafted conditions (only the last one is active; the commented lines
    # are alternative terms that can be swapped in)
    sample_img_clipped = sample_img.clamp(-1, 1)
    #vertical_similarity = F.mse_loss(sample_img_clipped, torchvision.transforms.functional.hflip(model.encode()))
    #horizontal_similarity = F.mse_loss(sample_img_clipped, torchvision.transforms.functional.vflip(model.encode()))
    #vertical_dissimilarity = -F.mse_loss(sample_img_clipped, torchvision.transforms.functional.hflip(sample_img_clipped))
    flipped = torchvision.transforms.functional.vflip(sample_img_clipped)
    # Negated MSE: minimising it rewards the image for differing from its vertical flip
    horizontal_dissimilarity = -F.mse_loss(sample_img_clipped, flipped)

    # Denoising loss + aux loss (the aux weight decays linearly from 0.01 towards 0)
    aux_weight = 0.01*remaining/steps
    loss = F.mse_loss(epsilon_pred, epsilon) + aux_weight*horizontal_dissimilarity

    # Update (only the image parameter is stepped by `opt`)
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Progress bar shows the mean loss since the previous report
    losses.append(loss.item())
    if i % update_every == 0:
        bar.set_postfix({'Loss': np.mean(losses)})
        losses = []

    # Visualize sample at the first step and then after every 100th step
    if i == 0 or (i+1) % 100 == 0:
        with torch.no_grad():
            fig, ax = plt.subplots(1, 1, figsize=(5,5))
            snapshot = model.encode()[0].detach().cpu().numpy()
            ax.imshow(snapshot.transpose([1,2,0]), vmin=-1, vmax=1)
            ax.set_title('Inferred sample')
            plt.show()

## Inference with a learned condition

In [None]:
# Train classifier
class Classifier(nn.Module):
    """Small CNN that maps a 1-channel 32x32 image to one sigmoid score in [0, 1]."""

    def __init__(self):
        super().__init__()

        # Three stride-2 3x3 convolutions halve the spatial size each time
        # (32 -> 16 -> 8 -> 4), which is what the 4*4*32 input of `out` relies on.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)
        self.out = nn.Linear(4*4*32, 1)

    def forward(self, x):
        """Return scores of shape (batch, 1) for a batch of images `x`."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        logits = self.out(torch.flatten(x, 1))
        return torch.sigmoid(logits)
    
# Train network
class_net = Classifier().to(device)
class_opt = torch.optim.Adam(class_net.parameters(), lr=1e-4)

target_label = 3 # Digit to distinguish
epochs = 5
update_every = 100
for e in range(epochs):
    print(f'Epoch [{e+1}/{epochs}]')
    
    losses = []
    batch_bar = tqdm.tqdm(data_loader)
    for i, batch in enumerate(batch_bar):
        img, labels = batch
        
        # Binary target: 0 for the target digit, 1 for every other digit, so a
        # LOW classifier output means "looks like target_label".
        labels = (labels != target_label).float().to(device)
        
        # Pass through network
        out = class_net(img.float().to(device))
        
        # Compute loss and backprop; squeeze turns (batch, 1) scores into (batch,)
        # so they line up with the label vector
        loss = F.binary_cross_entropy(out.squeeze(-1), labels)
        
        class_opt.zero_grad()
        loss.backward()
        class_opt.step()
        
        # Progress bar shows the mean loss since the previous report
        losses.append(loss.item())
        if i % update_every == 0:
            batch_bar.set_postfix({'Loss': np.mean(losses)})
            losses = []
            
    # Report what accumulated since the last in-loop update. The window is
    # empty when the final batch index is a multiple of update_every; skip it
    # then so np.mean([]) does not produce nan plus a RuntimeWarning.
    if losses:
        batch_bar.set_postfix({'Loss': np.mean(losses)})
    losses = []
    
    # Show the first image of the epoch's last batch with the score it received
    plt.figure(figsize=(5,5))
    plt.imshow(img.numpy()[0,0,:,:])
    plt.title(f'Score {out[0].item():.3f}')
    plt.show()

In [None]:
# Inference Model
class Model(nn.Module):
    """Wraps one learnable 1x1x32x32 image, initialised from Gaussian noise."""

    def __init__(self):
        super().__init__()
        noise = torch.randn(1, 1, 32, 32)
        # Registered as a parameter so an optimizer can update the pixels directly
        self.img = nn.Parameter(noise, requires_grad=True)

    def encode(self):
        """Return the image parameter itself (not a copy)."""
        return self.img
    
model = Model().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.01)

diffusion = GaussianDiffusion(T=1000, schedule='linear')
# NOTE(review): both networks stay in train mode here - confirm this is intended.
net.train()
class_net.train()

steps = 1000
bar = tqdm.tqdm(range(steps))
losses = []
update_every = 50
for i in bar:
    sample_img = model.encode()
    remaining = steps - i

    # Select t: linearly decreasing + cosine, truncated to an int and kept in [1, T]
    t = (remaining + remaining//4*math.cos(i/50))/steps*diffusion.T
    t = np.clip(np.array([t]).astype(int), 1, diffusion.T)

    # Denoise
    # presumably diffusion.sample returns the noised image and the noise that
    # was added - confirm against diffusion.py
    xt, epsilon = diffusion.sample(sample_img, t)
    t = torch.from_numpy(t).float().view(sample_img.shape[0])
    epsilon_pred = net(xt.float(), t.to(device))

    # Learned condition: the classifier was trained with label 0 for the target
    # digit, so minimising its mean output steers the image towards that digit
    sample_img_clipped = sample_img.clamp(-1, 1)
    class_loss = class_net(sample_img_clipped).mean()

    # Denoising loss + aux loss (the aux weight decays linearly from 0.01 towards 0)
    aux_weight = 0.01*remaining/steps
    loss = F.mse_loss(epsilon_pred, epsilon) + aux_weight*class_loss

    # Update (only the image parameter is stepped by `opt`)
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Progress bar shows the mean loss since the previous report
    losses.append(loss.item())
    if i % update_every == 0:
        bar.set_postfix({'Loss': np.mean(losses)})
        losses = []

    # Visualize sample at the first step and then after every 100th step
    if i == 0 or (i+1) % 100 == 0:
        with torch.no_grad():
            fig, ax = plt.subplots(1, 1, figsize=(5,5))
            snapshot = model.encode()[0].detach().cpu().numpy()
            ax.imshow(snapshot.transpose([1,2,0]), vmin=-1, vmax=1)
            ax.set_title('Inferred sample')
            plt.show()

## Inference with hand-crafted and learned conditions

In [None]:
# Inference Model
class Model(nn.Module):
    """Wraps one learnable 1x1x32x32 image, initialised from Gaussian noise."""

    def __init__(self):
        super().__init__()
        noise = torch.randn(1, 1, 32, 32)
        # Registered as a parameter so an optimizer can update the pixels directly
        self.img = nn.Parameter(noise, requires_grad=True)

    def encode(self):
        """Return the image parameter itself (not a copy)."""
        return self.img
    
model = Model().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.01)

diffusion = GaussianDiffusion(T=1000, schedule='linear')
# NOTE(review): both networks stay in train mode here - confirm this is intended.
net.train()
class_net.train()

steps = 1000
bar = tqdm.tqdm(range(steps))
losses = []
update_every = 50
for i in bar:
    sample_img = model.encode()
    remaining = steps - i

    # Select t: linearly decreasing + cosine, truncated to an int and kept in [1, T]
    t = (remaining + remaining//4*math.cos(i/50))/steps*diffusion.T
    t = np.clip(np.array([t]).astype(int), 1, diffusion.T)

    # Denoise
    # presumably diffusion.sample returns the noised image and the noise that
    # was added - confirm against diffusion.py
    xt, epsilon = diffusion.sample(sample_img, t)
    t = torch.from_numpy(t).float().view(sample_img.shape[0])
    epsilon_pred = net(xt.float(), t.to(device))

    # Conditions
    sample_img_clipped = sample_img.clamp(-1, 1)
    # Hand-crafted (only horizontal_similarity is active; the commented lines
    # are alternative terms that can be swapped in)
    #vertical_similarity = F.mse_loss(sample_img_clipped, torchvision.transforms.functional.hflip(model.encode()))
    # NOTE(review): the flipped operand is the unclipped image while the first
    # operand is clipped - confirm the asymmetry is intended.
    mirrored = torchvision.transforms.functional.vflip(model.encode())
    horizontal_similarity = F.mse_loss(sample_img_clipped, mirrored)
    #vertical_dissimilarity = -F.mse_loss(sample_img_clipped, torchvision.transforms.functional.hflip(sample_img_clipped))
    #horizontal_dissimilarity = -F.mse_loss(sample_img_clipped, torchvision.transforms.functional.vflip(sample_img_clipped))
    # Learned: the classifier was trained with label 0 for the target digit, so
    # minimising its mean output steers the image towards that digit
    class_loss = class_net(sample_img_clipped).mean()

    # Denoising loss + aux loss (both aux terms share a weight that decays
    # linearly from 0.01 towards 0)
    aux_weight = 0.01*remaining/steps
    loss = F.mse_loss(epsilon_pred, epsilon) + aux_weight*class_loss + aux_weight*horizontal_similarity

    # Update (only the image parameter is stepped by `opt`)
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Progress bar shows the mean loss since the previous report
    losses.append(loss.item())
    if i % update_every == 0:
        bar.set_postfix({'Loss': np.mean(losses)})
        losses = []

    # Visualize sample at the first step and then after every 100th step
    if i == 0 or (i+1) % 100 == 0:
        with torch.no_grad():
            fig, ax = plt.subplots(1, 1, figsize=(5,5))
            snapshot = model.encode()[0].detach().cpu().numpy()
            ax.imshow(snapshot.transpose([1,2,0]), vmin=-1, vmax=1)
            ax.set_title('Inferred sample')
            plt.show()