In [9]:
import torch
import torch.nn as nn
from Models import TemporalDecoder, TemporalEncoder, DataDiffuser, TransitionNet, SimpleImageDecoder, SimpleImageEncoder, PositionalEncoder, StupidPositionalEncoder
import torch.optim as optim
from torchvision.utils import save_image
import torchvision
from torchvision import datasets, transforms
import numpy as np


def add_noise(x):
    """
    [0, 1] -> [0, 255] -> add noise -> [0, 1]
    """
    noise = x.new().resize_as_(x).uniform_()
    x = x * 255 + noise
    x = x / 256
    return x

def getMNISTDataLoader(bs):
    # MNIST Dataset
    train_dataset = datasets.MNIST(root='./mnist_data/', train=True, download=True, transform=transforms.Compose([
                                      transforms.Resize((32, 32)),
                                      #transforms.ToTensor(),
                                      #add_noise,
                                      ToTensor(),
        AddUniformNoise()
                                  ]))
    test_dataset = datasets.MNIST(root='./mnist_data/', train=False, download=True,
                                  transform=transforms.Compose([
                                      transforms.Resize((32, 32)),
                                      #transforms.ToTensor(),
                                      #add_noise,
                                      ToTensor(),
                                      AddUniformNoise()
                                  ]))

    # Data Loader (Input Pipeline)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

    return train_loader, test_loader

def getCIFAR10DataLoader(bs):
    # MNIST Dataset
    train_dataset = datasets.CIFAR10(root='./cifar10_data/', train=True, download=True, transform=transforms.Compose([
                                      transforms.Resize(32),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      add_noise,
                                      # transforms.ToTensor()
                                  ]))
    test_dataset = datasets.CIFAR10(root='./cifar10_data/', train=False, download=True,
                                  transform=transforms.Compose([
                                      transforms.Resize(32),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      add_noise,
                                      # transforms.ToTensor()
                                  ]))

    # Data Loader (Input Pipeline)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

    return train_loader, test_loader


def logit(x, alpha=1E-6):
    y = alpha + (1.-2*alpha)*x
    return np.log(y) - np.log(1. - y)


def logit_back(x, alpha=1E-6):
    y = torch.sigmoid(x)
    return (y - alpha)/(1.-2*alpha)


class AddUniformNoise(object):
    def __init__(self, alpha=1E-6):
        self.alpha = alpha
    def __call__(self,samples):
        samples = np.array(samples,dtype = np.float32)
        samples += np.random.uniform(size = samples.shape)
        samples = logit(samples/256., self.alpha)
        return samples


class ToTensor(object):
    def __init__(self):
        pass
    def __call__(self,samples):
        samples = torch.from_numpy(np.array(samples,dtype = np.float32)).float()
        return samples


class CNNDiffusionModel(nn.Module):
    def __init__(self, **kwargs):
        super(CNNDiffusionModel, self).__init__()
        self.T_MAX = kwargs['T_MAX']
        self.latent_s = kwargs['latent_s']
        self.t_emb_s = kwargs['t_emb_s']
        self.CNN = kwargs['CNN']
        self.register_buffer("beta_min", torch.tensor(kwargs['beta_min']))
        self.register_buffer("beta_max", torch.tensor(kwargs['beta_max']))
        self.device = 'cpu'
        self.img_size = [1, 32, 32]
        self.pos_enc = PositionalEncoder(self.t_emb_s // 2)  # StupidPositionalEncoder(T_MAX)  #
        self.simplified_trans = kwargs['simplified_trans']

        enc_net = [kwargs['enc_w']] * kwargs['enc_l']
        dec_net = [kwargs['dec_w']] * kwargs['dec_l']
        trans_net = [kwargs['trans_w']] * kwargs['trans_l']

        if self.CNN:
            self.enc = SimpleImageEncoder(self.img_size, self.latent_s, enc_net, t_dim=self.t_emb_s).to(dev)
            self.dec = SimpleImageDecoder(self.enc.features_dim, self.latent_s, dec_net, t_dim=self.t_emb_s,
                                          out_c=self.img_size[0]).to(dev)
        else:
            self.dec = TemporalDecoder(32**2, self.latent_s, dec_net, self.t_emb_s).to(dev)
            self.enc = TemporalEncoder(32**2, self.latent_s, enc_net, self.t_emb_s).to(dev)

        self.trans = TransitionNet(self.latent_s, trans_net, self.t_emb_s).to(dev)
        self.dif = DataDiffuser(beta_min=self.beta_min, beta_max=self.beta_max, t_max=self.T_MAX).to(dev)
        self.sampling_t0 = False

    def loss(self, x0):
        if self.sampling_t0:
            t0 = torch.randint(0, self.T_MAX - 1, [x0.shape[0]]).to(dev)
            x_t0, sigma_x_t0 = self.dif.diffuse(x0, t0, torch.zeros(x0.shape[0]).long().to(dev))
        else:
            t0 = torch.zeros(x0.shape[0]).to(dev).long()
            x_t0 = x0

        z_t0 = self.enc(x_t0.view(-1, *self.img_size), self.pos_enc(t0.float().unsqueeze(1)))
        # z_t0 = z_t0 + torch.randn(z_t0.shape).to(dev) * (1 - dif.alphas[t0]).sqrt().unsqueeze(1).expand(-1, z_t0.shape[1])
        t = torch.torch.distributions.Uniform(t0.float() + 1, torch.ones_like(t0) * self.T_MAX).sample().long().to(dev)

        z_t, sigma_z = self.dif.diffuse(z_t0, t, t0)
        x_t, sigma_x = self.dif.diffuse(x_t0, t, t0)

        mu_x_pred = self.dec(z_t, self.pos_enc(t.float().unsqueeze(1)))
        KL_x = ((mu_x_pred - x_t.view(bs, *self.img_size)) ** 2).view(bs, -1).sum(1) / sigma_x ** 2

        if self.simplified_trans:
            alpha_bar_t = self.dif.alphas[t].unsqueeze(1)#.expand(-1, self.latent_s)
            alpha_t = self.dif.alphas_t[t].unsqueeze(1)#.expand(-1, self.latent_s)
            beta_t = self.dif.betas[t].unsqueeze(1)#.expand(-1, self.latent_s)

            mu_z_pred = (z_t - beta_t/(1-alpha_bar_t).sqrt() * self.trans(z_t, self.pos_enc(t.float().unsqueeze(1))))/alpha_t.sqrt()
        else:
            mu_z_pred = self.trans(z_t, self.pos_enc(t.float().unsqueeze(1)))
        mu, sigma = self.dif.prev_mean(z_t0, z_t, t)

        KL_z = ((mu - mu_z_pred) ** 2).sum(1) / sigma ** 2

        loss = KL_x.mean(0) + KL_z.mean(0)

        return loss

    def to(self, device):
        super().to(device)
        self.device = device
        return self

    def sample(self, nb_samples=1):
        zT = torch.randn(64, self.latent_s).to(self.device)
        z_t = zT
        for t in range(self.T_MAX - 1, 0, -1):
            t_t = torch.ones(64, 1).to(self.device) * t
            if t > 0:
                sigma = ((1 - self.dif.alphas[t - 1]) / (1 - self.dif.alphas[t]) * self.dif.betas[t]).sqrt()
            else:
                sigma = 0
            if self.simplified_trans:
                alpha_bar_t = self.dif.alphas[t]
                alpha_t = self.dif.alphas_t[t]
                beta_t = self.dif.betas[t]
                mu_z_pred = (z_t - beta_t / (1 - alpha_bar_t).sqrt() * self.trans(z_t, self.pos_enc(t_t))) / alpha_t.sqrt()
            else:
                mu_z_pred = self.trans(z_t, self.pos_enc(t_t))
            z_t = mu_z_pred + torch.randn(z_t.shape, device=self.device) * sigma

        x_0 = self.dec(z_t, self.pos_enc(torch.zeros((nb_samples, 1), device=self.device))).view(nb_samples, -1)

        return x_0


import wandb
wandb.init(project="latent_diffusion", entity="awehenkel")


if __name__ == "__main__":
    bs = 100
    config = {
        'data': 'MNIST',
        'T_MAX': 20,
        'latent_s': 60,
        't_emb_s': 30,
        'CNN': False,
        'enc_w': 300,
        'enc_l': 4,
        'dec_w': 200,
        'dec_l': 3,
        'trans_w': 200,
        'trans_l': 3,
        "beta_min": 1e-4,
        "beta_max": .95,
        'simplified_trans': True
    }
    wandb.config.update(config)
    train_loader, test_loader = getMNISTDataLoader(bs)
    img_size = [1, 32, 32]

    # Compute Mean abd std per pixel
    x_mean = 0
    x_mean2 = 0
    for batch_idx, (cur_x, target) in enumerate(train_loader):
        cur_x = cur_x.view(bs, -1).float()
        x_mean += cur_x.mean(0)
        x_mean2 += (cur_x ** 2).mean(0)
    x_mean /= batch_idx + 1
    x_std = (x_mean2 / (batch_idx + 1) - x_mean ** 2) ** .5
    x_std[x_std == 0.] = 1.

    dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = CNNDiffusionModel(**config).to(dev)

    optimizer = optim.Adam(model.parameters(), lr=.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, threshold=0.001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=True)

    wandb.watch(model)
    def get_X_back(x):
        nb_x = x.shape[0]
        x = x * x_std.to(dev).unsqueeze(0).expand(nb_x, -1) + x_mean.to(dev).unsqueeze(0).expand(nb_x, -1)
        return logit_back(x)


    #sample = list(train_loader)[0][0][[0]].expand(bs, -1, -1, -1)
    #save_image(get_X_back(sample.to(dev)[[0]].reshape(1, -1)).reshape(1, 3, 32, 32), './Samples/Generated/sample_rel_' + '.png')
    def train(epoch):
        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            #data = sample
            x0 = data.view(data.shape[0], -1).to(dev)

            x0 = (x0 - x_mean.to(dev).unsqueeze(0).expand(bs, -1)) / x_std.to(dev).unsqueeze(0).expand(bs, -1)
            optimizer.zero_grad()

            loss = model.loss(x0)

            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader), loss.item() / len(data)))
        samples = get_X_back(model.sample(64)).view(64, *img_size)
        save_image(samples, './Samples/Generated/sample_gen_' + str(epoch) + '.png')
        scheduler.step(train_loss)
        print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))
        wandb.log({"Train Loss": train_loss / len(train_loader.dataset), "Samples": [wandb.Image(samples)]})




VBox(children=(Label(value=' 2.71MB of 2.71MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Train Loss,18.32189
_step,74.0
_runtime,1836.0
_timestamp,1612890275.0


0,1
Train Loss,█▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▂▁▁▁▁▁▁▁▁
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.18 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [10]:
for i in range(150):
    train(i)

====> Epoch: 0 Average loss: 26.5507
====> Epoch: 1 Average loss: 24.3633
====> Epoch: 2 Average loss: 23.4536
====> Epoch: 3 Average loss: 23.3545
====> Epoch: 4 Average loss: 22.6123
====> Epoch: 5 Average loss: 22.6356
====> Epoch: 6 Average loss: 22.6136
====> Epoch: 7 Average loss: 22.4249
====> Epoch: 8 Average loss: 22.1911
====> Epoch: 9 Average loss: 22.2719
====> Epoch: 10 Average loss: 22.0822
====> Epoch: 11 Average loss: 21.9481
====> Epoch: 12 Average loss: 22.0816
====> Epoch: 13 Average loss: 22.0170
====> Epoch: 14 Average loss: 21.6598
====> Epoch: 15 Average loss: 22.0272
====> Epoch: 16 Average loss: 21.8616
====> Epoch: 17 Average loss: 21.8325
====> Epoch: 18 Average loss: 21.8477
====> Epoch: 19 Average loss: 21.6802
====> Epoch: 20 Average loss: 21.7897
====> Epoch: 21 Average loss: 21.5523
====> Epoch: 22 Average loss: 21.7817
====> Epoch: 23 Average loss: 21.5163


====> Epoch: 24 Average loss: 21.6126
====> Epoch: 25 Average loss: 21.5007
====> Epoch: 26 Average loss: 21.6930
====> Epoch: 27 Average loss: 21.4504
====> Epoch: 28 Average loss: 21.7593
====> Epoch: 29 Average loss: 21.3713
====> Epoch: 30 Average loss: 21.6203
====> Epoch: 31 Average loss: 21.3668
====> Epoch: 32 Average loss: 21.7013
====> Epoch: 33 Average loss: 21.3850
====> Epoch: 34 Average loss: 21.4176
====> Epoch: 35 Average loss: 21.2590
====> Epoch: 36 Average loss: 21.4432
====> Epoch: 37 Average loss: 21.5564
====> Epoch: 38 Average loss: 21.2739
====> Epoch: 39 Average loss: 21.2656
====> Epoch: 40 Average loss: 21.1435
====> Epoch: 41 Average loss: 21.1821
====> Epoch: 42 Average loss: 21.2217
====> Epoch: 43 Average loss: 21.1919
====> Epoch: 44 Average loss: 21.2158
====> Epoch: 45 Average loss: 21.2879
====> Epoch: 46 Average loss: 21.4879


====> Epoch: 47 Average loss: 21.2045
====> Epoch: 48 Average loss: 21.1820
====> Epoch: 49 Average loss: 21.3136
====> Epoch: 50 Average loss: 20.9593
====> Epoch: 51 Average loss: 21.2073
====> Epoch: 52 Average loss: 21.2231
====> Epoch: 53 Average loss: 21.0874
====> Epoch: 54 Average loss: 21.1265
====> Epoch: 55 Average loss: 21.3099
====> Epoch: 56 Average loss: 21.1703
====> Epoch: 57 Average loss: 21.0285
====> Epoch: 58 Average loss: 21.0322
====> Epoch: 59 Average loss: 21.1485
====> Epoch: 60 Average loss: 21.1512
Epoch    62: reducing learning rate of group 0 to 5.0000e-04.
====> Epoch: 61 Average loss: 21.2624
====> Epoch: 62 Average loss: 20.9032
====> Epoch: 63 Average loss: 20.9617
====> Epoch: 64 Average loss: 21.0579
====> Epoch: 65 Average loss: 20.8959
====> Epoch: 66 Average loss: 21.0176
====> Epoch: 67 Average loss: 20.7911
====> Epoch: 68 Average loss: 20.6982
====> Epoch: 69 Average loss: 21.0411
====> Epoch: 70 Average loss: 20.9309


====> Epoch: 71 Average loss: 20.7182
====> Epoch: 72 Average loss: 20.9817
====> Epoch: 73 Average loss: 20.8917
====> Epoch: 74 Average loss: 20.7809
====> Epoch: 75 Average loss: 20.9465
====> Epoch: 76 Average loss: 20.7439
====> Epoch: 77 Average loss: 20.8823
====> Epoch: 78 Average loss: 21.0301
Epoch    80: reducing learning rate of group 0 to 2.5000e-04.
====> Epoch: 79 Average loss: 20.9396
====> Epoch: 80 Average loss: 21.0493
====> Epoch: 81 Average loss: 20.6840
====> Epoch: 82 Average loss: 20.7288
====> Epoch: 83 Average loss: 20.9738


KeyboardInterrupt: 