In [5]:
import torch
import torch.nn as nn
from Models import TemporalDecoder, TemporalEncoder, DataDiffuser, TransitionNet, SimpleImageDecoder, SimpleImageEncoder, PositionalEncoder, StupidPositionalEncoder
import torch.optim as optim
from torchvision.utils import save_image
import torchvision
from torchvision import datasets, transforms
import numpy as np


def add_noise(x):
    """Dequantize an image tensor with uniform noise.

    [0, 1] -> [0, 255] -> add U[0, 1) noise -> divide by 256 -> [0, 1]

    Args:
        x: Floating-point tensor with values in [0, 1].

    Returns:
        A new tensor with the same shape, dtype and device as ``x``; the
        input tensor is left untouched.
    """
    # ``rand_like`` replaces the legacy ``x.new().resize_as_(x).uniform_()``
    # construction and keeps the dtype/device of ``x``.
    noise = torch.rand_like(x)
    return (x * 255 + noise) / 256

def getMNISTDataLoader(bs):
    """Create the MNIST training and test data loaders.

    Every image is resized to 32x32, converted by this file's ``ToTensor``
    and then transformed by ``AddUniformNoise``.  The dataset is stored in
    (and downloaded to, if missing) ``./mnist_data/``.

    Args:
        bs: Batch size shared by both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple; only the training loader
        shuffles its data.
    """
    def make_pipeline():
        # Same preprocessing for both splits.
        return transforms.Compose([
            transforms.Resize((32, 32)),
            ToTensor(),
            AddUniformNoise(),
        ])

    train_split = datasets.MNIST(root='./mnist_data/', train=True, download=True,
                                 transform=make_pipeline())
    test_split = datasets.MNIST(root='./mnist_data/', train=False, download=True,
                                transform=make_pipeline())

    loaders = (
        torch.utils.data.DataLoader(dataset=train_split, batch_size=bs, shuffle=True),
        torch.utils.data.DataLoader(dataset=test_split, batch_size=bs, shuffle=False),
    )
    return loaders

def getCIFAR10DataLoader(bs):
    """Create the CIFAR-10 training and test data loaders.

    Every image is resized to 32, randomly flipped horizontally, converted
    with ``transforms.ToTensor`` and dequantized by ``add_noise``.  The
    dataset is stored in (and downloaded to, if missing) ``./cifar10_data/``.

    Args:
        bs: Batch size shared by both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple; only the training loader
        shuffles its data.
    """
    def build_loader(train):
        # NOTE(review): the random flip is also applied to the test split,
        # exactly as in the original pipeline -- confirm this is intended.
        pipeline = transforms.Compose([
            transforms.Resize(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            add_noise,
        ])
        split = datasets.CIFAR10(root='./cifar10_data/', train=train, download=True,
                                 transform=pipeline)
        return torch.utils.data.DataLoader(dataset=split, batch_size=bs, shuffle=train)

    train_loader = build_loader(True)
    test_loader = build_loader(False)
    return train_loader, test_loader


def logit(x, alpha=1E-6):
    """Return the logit of ``x`` after squeezing it into (alpha, 1 - alpha).

    ``x`` is first mapped to ``alpha + (1 - 2 * alpha) * x`` so that inputs
    of exactly 0 or 1 still yield finite values.

    Args:
        x: Scalar or NumPy array with values in [0, 1].
        alpha: Margin kept away from 0 and 1.

    Returns:
        ``log(y) - log(1 - y)`` for the squeezed value ``y``.
    """
    squeezed = alpha + (1. - 2 * alpha) * x
    log_p = np.log(squeezed)
    log_not_p = np.log(1. - squeezed)
    return log_p - log_not_p


def logit_back(x, alpha=1E-6):
    """Invert ``logit``: map logit-space values back to the unit interval.

    Args:
        x: Tensor of values in logit space.
        alpha: The margin that was used by the forward ``logit`` transform.

    Returns:
        ``(sigmoid(x) - alpha) / (1 - 2 * alpha)`` as a tensor.
    """
    probs = torch.sigmoid(x)
    scale = 1. - 2 * alpha
    return (probs - alpha) / scale


class AddUniformNoise(object):
    """Transform that dequantizes pixel values and maps them to logit space.

    The input is converted to a float32 NumPy array, U[0, 1) noise is added
    element-wise, the result is divided by 256 and passed through ``logit``.

    Args:
        alpha: Margin forwarded to ``logit``.
    """

    def __init__(self, alpha=1E-6):
        self.alpha = alpha

    def __call__(self, samples):
        """Return the noisy, logit-transformed version of ``samples`` (a NumPy array)."""
        pixels = np.array(samples, dtype=np.float32)
        jitter = np.random.uniform(size=pixels.shape)
        pixels += jitter
        return logit(pixels / 256., self.alpha)


class ToTensor(object):
    """Transform that converts array-like input into a float32 tensor.

    Unlike ``torchvision.transforms.ToTensor`` this performs no rescaling
    and no axis reordering: values are copied as-is into a float32 array
    and wrapped in a tensor.
    """

    def __call__(self, samples):
        """Return ``samples`` as a ``torch.float32`` tensor."""
        as_array = np.array(samples, dtype=np.float32)
        tensor = torch.from_numpy(as_array)
        return tensor.float()


class CNNDiffusionModel(nn.Module):
    """Diffusion model acting on a learned latent representation.

    An encoder maps an image to a latent code, ``DataDiffuser`` noises both
    the image and the latent, a decoder predicts the noised image from the
    noised latent, and a transition network predicts the mean of the
    previous latent.  Images are assumed to be 1x32x32 (hard-coded in
    ``img_size``).

    Keyword Args:
        T_MAX (int): Number of diffusion steps.
        latent_s (int): Size of the latent code.
        t_emb_s (int): Size of the time embedding passed to the networks.
        CNN (bool): Use ``SimpleImageEncoder``/``SimpleImageDecoder`` when
            true, ``TemporalEncoder``/``TemporalDecoder`` otherwise.
        beta_min, beta_max (float): Forwarded to ``DataDiffuser``.
        simplified_trans (bool): When true the transition network output
            ``out`` is turned into a mean via
            ``(z_t - beta_t / sqrt(1 - alpha_bar_t) * out) / sqrt(alpha_t)``;
            otherwise the output is used as the mean directly.
        enc_w, enc_l, dec_w, dec_l, trans_w, trans_l (int): Each network
            receives a list holding ``*_l`` copies of ``*_w`` (presumably
            the hidden-layer widths -- confirm against ``Models``).
        device (optional): Device on which the sub-networks are created.
            Defaults to ``'cuda:0'`` when CUDA is available, else ``'cpu'``.
    """

    def __init__(self, **kwargs):
        super(CNNDiffusionModel, self).__init__()
        self.T_MAX = kwargs['T_MAX']
        self.latent_s = kwargs['latent_s']
        self.t_emb_s = kwargs['t_emb_s']
        self.CNN = kwargs['CNN']
        self.register_buffer("beta_min", torch.tensor(kwargs['beta_min']))
        self.register_buffer("beta_max", torch.tensor(kwargs['beta_max']))

        # Resolve the device locally instead of relying on a module-level
        # ``dev`` global that only exists once the training script has run.
        device = kwargs.get('device')
        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = device

        self.img_size = [1, 32, 32]
        self.pos_enc = PositionalEncoder(self.t_emb_s // 2)  # StupidPositionalEncoder(T_MAX)  #
        self.simplified_trans = kwargs['simplified_trans']

        enc_net = [kwargs['enc_w']] * kwargs['enc_l']
        dec_net = [kwargs['dec_w']] * kwargs['dec_l']
        trans_net = [kwargs['trans_w']] * kwargs['trans_l']

        if self.CNN:
            self.enc = SimpleImageEncoder(self.img_size, self.latent_s, enc_net, t_dim=self.t_emb_s).to(device)
            self.dec = SimpleImageDecoder(self.enc.features_dim, self.latent_s, dec_net, t_dim=self.t_emb_s,
                                          out_c=self.img_size[0]).to(device)
        else:
            self.dec = TemporalDecoder(32**2, self.latent_s, dec_net, self.t_emb_s).to(device)
            self.enc = TemporalEncoder(32**2, self.latent_s, enc_net, self.t_emb_s).to(device)

        self.trans = TransitionNet(self.latent_s, trans_net, self.t_emb_s).to(device)
        self.dif = DataDiffuser(beta_min=self.beta_min, beta_max=self.beta_max, t_max=self.T_MAX).to(device)
        # When true, ``loss`` starts from a randomly diffused image instead of x0.
        self.sampling_t0 = False

    def loss(self, x0):
        """Return the scalar training loss for a batch of flattened images.

        Args:
            x0: Tensor whose first dimension is the batch and whose
                remaining elements can be viewed as ``img_size``.

        Returns:
            Batch mean of the reconstruction term plus batch mean of the
            latent transition term.
        """
        # Use the real batch size so a partial last batch works too.
        n = x0.shape[0]
        device = self.device

        if self.sampling_t0:
            t0 = torch.randint(0, self.T_MAX - 1, [n]).to(device)
            x_t0, _ = self.dif.diffuse(x0, t0, torch.zeros(n).long().to(device))
        else:
            t0 = torch.zeros(n).to(device).long()
            x_t0 = x0

        z_t0 = self.enc(x_t0.view(-1, *self.img_size), self.pos_enc(t0.float().unsqueeze(1)))
        # z_t0 = z_t0 + torch.randn(z_t0.shape).to(dev) * (1 - dif.alphas[t0]).sqrt().unsqueeze(1).expand(-1, z_t0.shape[1])

        # Draw an integer step t with t0 < t <= T_MAX - 1.  The clamp guards
        # against float rounding pushing the uniform sample up to T_MAX.
        t = torch.distributions.Uniform(t0.float() + 1, torch.ones_like(t0) * self.T_MAX).sample()
        t = t.long().clamp(max=self.T_MAX - 1).to(device)

        z_t, _ = self.dif.diffuse(z_t0, t, t0)
        x_t, sigma_x = self.dif.diffuse(x_t0, t, t0)

        mu_x_pred = self.dec(z_t, self.pos_enc(t.float().unsqueeze(1)))
        KL_x = ((mu_x_pred - x_t.view(n, *self.img_size)) ** 2).view(n, -1).sum(1) / sigma_x ** 2

        if self.simplified_trans:
            alpha_bar_t = self.dif.alphas[t].unsqueeze(1)
            alpha_t = self.dif.alphas_t[t].unsqueeze(1)
            beta_t = self.dif.betas[t].unsqueeze(1)

            mu_z_pred = (z_t - beta_t/(1-alpha_bar_t).sqrt() * self.trans(z_t, self.pos_enc(t.float().unsqueeze(1))))/alpha_t.sqrt()
        else:
            mu_z_pred = self.trans(z_t, self.pos_enc(t.float().unsqueeze(1)))
        mu, sigma = self.dif.prev_mean(z_t0, z_t, t)

        KL_z = ((mu - mu_z_pred) ** 2).sum(1) / sigma ** 2

        return KL_x.mean(0) + KL_z.mean(0)

    def to(self, device):
        """Move the module to ``device`` and remember it for tensor creation."""
        super().to(device)
        self.device = device
        return self

    def sample(self, nb_samples=1):
        """Generate ``nb_samples`` samples by iterating the latent transition.

        Starts from standard-normal latents, applies the transition for
        ``t = T_MAX - 1, ..., 1`` and decodes the final latent at time 0.

        Args:
            nb_samples: Number of samples to draw.

        Returns:
            Tensor of shape ``(nb_samples, -1)`` with the decoded samples.
        """
        # The latent batch follows ``nb_samples`` (it used to be fixed at 64,
        # which only matched the decoder call when nb_samples == 64).
        z_t = torch.randn(nb_samples, self.latent_s).to(self.device)
        for t in range(self.T_MAX - 1, 0, -1):
            t_t = torch.ones(nb_samples, 1).to(self.device) * t
            # t is always >= 1 inside this loop, so t - 1 is a valid index.
            sigma = ((1 - self.dif.alphas[t - 1]) / (1 - self.dif.alphas[t]) * self.dif.betas[t]).sqrt()
            if self.simplified_trans:
                alpha_bar_t = self.dif.alphas[t]
                alpha_t = self.dif.alphas_t[t]
                beta_t = self.dif.betas[t]
                mu_z_pred = (z_t - beta_t / (1 - alpha_bar_t).sqrt() * self.trans(z_t, self.pos_enc(t_t))) / alpha_t.sqrt()
            else:
                mu_z_pred = self.trans(z_t, self.pos_enc(t_t))
            z_t = mu_z_pred + torch.randn(z_t.shape, device=self.device) * sigma

        x_0 = self.dec(z_t, self.pos_enc(torch.zeros((nb_samples, 1), device=self.device))).view(nb_samples, -1)

        return x_0


import wandb
# Start the Weights & Biases run used for logging. This call sits outside the
# ``__main__`` guard, so it executes whenever this code is run or imported.
wandb.init(project="latent_diffusion", entity="awehenkel")


if __name__ == "__main__":
    import os

    bs = 100
    config = {
        'data': 'MNIST',
        'T_MAX': 20,
        'latent_s': 60,
        't_emb_s': 30,
        'CNN': False,
        'enc_w': 300,
        'enc_l': 4,
        'dec_w': 200,
        'dec_l': 3,
        'trans_w': 200,
        'trans_l': 3,
        "beta_min": 1e-2,
        "beta_max": .95,
        'simplified_trans': True
    }
    wandb.config.update(config)
    train_loader, test_loader = getMNISTDataLoader(bs)
    img_size = [1, 32, 32]

    # Compute the per-pixel mean and std of the (transformed) training data.
    x_mean = 0
    x_mean2 = 0
    for batch_idx, (cur_x, _) in enumerate(train_loader):
        # Flatten with the actual batch size so a partial last batch works.
        cur_x = cur_x.view(cur_x.shape[0], -1).float()
        x_mean += cur_x.mean(0)
        x_mean2 += (cur_x ** 2).mean(0)
    x_mean /= batch_idx + 1
    # Clamp so rounding errors cannot produce a negative variance (NaN std).
    x_std = (x_mean2 / (batch_idx + 1) - x_mean ** 2).clamp(min=0) ** .5
    x_std[x_std == 0.] = 1.

    dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = CNNDiffusionModel(**config).to(dev)

    optimizer = optim.Adam(model.parameters(), lr=.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, threshold=0.001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=True)

    wandb.watch(model)

    # Move the statistics to the device once; the leading axis of size 1
    # broadcasts against any batch size.
    x_mean_dev = x_mean.to(dev).unsqueeze(0)
    x_std_dev = x_std.to(dev).unsqueeze(0)

    # save_image does not create missing directories.
    sample_dir = './Samples/Generated/'
    os.makedirs(sample_dir, exist_ok=True)

    def get_X_back(x):
        """Undo the standardisation and the logit transform of a (batch, pixels) tensor."""
        return logit_back(x * x_std_dev + x_mean_dev)

    #sample = list(train_loader)[0][0][[0]].expand(bs, -1, -1, -1)
    #save_image(get_X_back(sample.to(dev)[[0]].reshape(1, -1)).reshape(1, 3, 32, 32), './Samples/Generated/sample_rel_' + '.png')
    def train(epoch):
        """Run one training epoch, then save and log 64 generated samples.

        Args:
            epoch: Epoch index, used in the log messages and the image name.
        """
        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            #data = sample
            x0 = data.view(data.shape[0], -1).to(dev)

            # Standardise each pixel with the training-set statistics.
            x0 = (x0 - x_mean_dev) / x_std_dev
            optimizer.zero_grad()

            loss = model.loss(x0)

            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader), loss.item() / len(data)))

        # Sampling is only used for visualisation, so no gradients are needed.
        with torch.no_grad():
            samples = get_X_back(model.sample(64)).view(64, *img_size)
        save_image(samples, sample_dir + 'sample_gen_' + str(epoch) + '.png')
        # The plateau scheduler monitors the summed epoch loss.
        scheduler.step(train_loss)
        print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))
        wandb.log({"Train Loss": train_loss / len(train_loader.dataset), "Samples": [wandb.Image(samples)]})




VBox(children=(Label(value=' 2.19MB of 2.19MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Train Loss,19.00525
_step,44.0
_runtime,1324.0
_timestamp,1612877396.0


0,1
Train Loss,█▅▅▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██


[34m[1mwandb[0m: wandb version 0.10.18 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [6]:
# Train for 150 epochs; each call to train() runs one pass over the training
# loader, then saves and logs generated samples.
for epoch_idx in range(150):
    train(epoch_idx)

====> Epoch: 0 Average loss: 22.5364
====> Epoch: 1 Average loss: 20.9040
====> Epoch: 2 Average loss: 20.3431
====> Epoch: 3 Average loss: 20.0035
====> Epoch: 4 Average loss: 19.8038
====> Epoch: 5 Average loss: 19.7747
====> Epoch: 6 Average loss: 19.5918
====> Epoch: 7 Average loss: 19.3602
====> Epoch: 8 Average loss: 19.4085
====> Epoch: 9 Average loss: 19.3569
====> Epoch: 10 Average loss: 19.1633
====> Epoch: 11 Average loss: 19.1506
====> Epoch: 12 Average loss: 19.1768
====> Epoch: 13 Average loss: 19.0924
====> Epoch: 14 Average loss: 18.9896
====> Epoch: 15 Average loss: 19.1124
====> Epoch: 16 Average loss: 19.0326
====> Epoch: 17 Average loss: 19.0500
====> Epoch: 18 Average loss: 19.0802
====> Epoch: 19 Average loss: 18.9496
====> Epoch: 20 Average loss: 18.8770
====> Epoch: 21 Average loss: 18.8360
====> Epoch: 22 Average loss: 18.8559
====> Epoch: 23 Average loss: 18.6561


====> Epoch: 24 Average loss: 18.9973
====> Epoch: 25 Average loss: 18.7107
====> Epoch: 26 Average loss: 18.8158
====> Epoch: 27 Average loss: 18.7133
====> Epoch: 28 Average loss: 18.8624
====> Epoch: 29 Average loss: 18.6395
====> Epoch: 30 Average loss: 18.7788
====> Epoch: 31 Average loss: 18.7625
====> Epoch: 32 Average loss: 18.5210
====> Epoch: 33 Average loss: 18.7998
====> Epoch: 34 Average loss: 18.7078
====> Epoch: 35 Average loss: 18.6770
====> Epoch: 36 Average loss: 18.8410
====> Epoch: 37 Average loss: 18.5908
====> Epoch: 38 Average loss: 18.6667
====> Epoch: 39 Average loss: 18.6351
====> Epoch: 40 Average loss: 18.8292
====> Epoch: 41 Average loss: 18.6465
====> Epoch: 42 Average loss: 18.5586
Epoch    44: reducing learning rate of group 0 to 5.0000e-04.
====> Epoch: 43 Average loss: 18.5445
====> Epoch: 44 Average loss: 18.4493
====> Epoch: 45 Average loss: 18.4838
====> Epoch: 46 Average loss: 18.4389


====> Epoch: 47 Average loss: 18.4853
====> Epoch: 48 Average loss: 18.4847
====> Epoch: 49 Average loss: 18.5505
====> Epoch: 50 Average loss: 18.5000
====> Epoch: 51 Average loss: 18.4090
====> Epoch: 52 Average loss: 18.3810
====> Epoch: 53 Average loss: 18.3482
====> Epoch: 54 Average loss: 18.4231
====> Epoch: 55 Average loss: 18.3811
====> Epoch: 56 Average loss: 18.3658
====> Epoch: 57 Average loss: 18.3847
====> Epoch: 58 Average loss: 18.4083
====> Epoch: 59 Average loss: 18.3994
====> Epoch: 60 Average loss: 18.3830
====> Epoch: 61 Average loss: 18.3249
====> Epoch: 62 Average loss: 18.2833
====> Epoch: 63 Average loss: 18.4154
====> Epoch: 64 Average loss: 18.2741
====> Epoch: 65 Average loss: 18.2993
====> Epoch: 66 Average loss: 18.1949
====> Epoch: 67 Average loss: 18.2880
====> Epoch: 68 Average loss: 18.3639
====> Epoch: 69 Average loss: 18.2075
====> Epoch: 70 Average loss: 18.2772


====> Epoch: 71 Average loss: 18.3173
====> Epoch: 72 Average loss: 18.2863
====> Epoch: 73 Average loss: 18.3271
====> Epoch: 74 Average loss: 18.3183
====> Epoch: 75 Average loss: 18.2962
====> Epoch: 76 Average loss: 18.3948
Epoch    78: reducing learning rate of group 0 to 2.5000e-04.
====> Epoch: 77 Average loss: 18.1975
====> Epoch: 78 Average loss: 18.2459
====> Epoch: 79 Average loss: 18.1752
====> Epoch: 80 Average loss: 18.3067
====> Epoch: 81 Average loss: 18.1817
====> Epoch: 82 Average loss: 18.0927
====> Epoch: 83 Average loss: 18.1907
====> Epoch: 84 Average loss: 18.1600
====> Epoch: 85 Average loss: 18.1878
====> Epoch: 86 Average loss: 18.1455
====> Epoch: 87 Average loss: 18.3201
====> Epoch: 88 Average loss: 18.2089
====> Epoch: 89 Average loss: 18.0474
====> Epoch: 90 Average loss: 18.1697
====> Epoch: 91 Average loss: 18.1798
====> Epoch: 92 Average loss: 18.3266
====> Epoch: 93 Average loss: 18.3122
====> Epoch: 94 Average loss: 18.2527


====> Epoch: 95 Average loss: 18.1814
====> Epoch: 96 Average loss: 18.2038
====> Epoch: 97 Average loss: 18.1677
====> Epoch: 98 Average loss: 18.1993
====> Epoch: 99 Average loss: 18.0573
Epoch   101: reducing learning rate of group 0 to 1.2500e-04.
====> Epoch: 100 Average loss: 18.1924
====> Epoch: 101 Average loss: 18.0349
====> Epoch: 102 Average loss: 17.9586
====> Epoch: 103 Average loss: 18.1697
====> Epoch: 104 Average loss: 18.1500
====> Epoch: 105 Average loss: 18.0078
====> Epoch: 106 Average loss: 18.1500
====> Epoch: 107 Average loss: 18.1380
====> Epoch: 108 Average loss: 18.1894
====> Epoch: 109 Average loss: 18.0725
====> Epoch: 110 Average loss: 18.1793
====> Epoch: 111 Average loss: 18.3087
====> Epoch: 112 Average loss: 18.1847
Epoch   114: reducing learning rate of group 0 to 6.2500e-05.
====> Epoch: 113 Average loss: 18.1200
====> Epoch: 114 Average loss: 18.0155
====> Epoch: 115 Average loss: 18.0229
====> Epoch: 116 Average loss: 18.2021
====> Epoch: 117 Averag

====> Epoch: 118 Average loss: 18.1465
====> Epoch: 119 Average loss: 18.0683
====> Epoch: 120 Average loss: 18.0523
====> Epoch: 121 Average loss: 18.0248
====> Epoch: 122 Average loss: 18.1620
====> Epoch: 123 Average loss: 18.1445
Epoch   125: reducing learning rate of group 0 to 3.1250e-05.
====> Epoch: 124 Average loss: 18.1525
====> Epoch: 125 Average loss: 18.1969
====> Epoch: 126 Average loss: 18.0236
====> Epoch: 127 Average loss: 18.1221
====> Epoch: 128 Average loss: 18.1611
====> Epoch: 129 Average loss: 18.1846
====> Epoch: 130 Average loss: 18.1967
====> Epoch: 131 Average loss: 18.1407
====> Epoch: 132 Average loss: 18.1387
====> Epoch: 133 Average loss: 18.2597
====> Epoch: 134 Average loss: 18.0779
Epoch   136: reducing learning rate of group 0 to 1.5625e-05.
====> Epoch: 135 Average loss: 18.0367
====> Epoch: 136 Average loss: 18.0944
====> Epoch: 137 Average loss: 17.9833
====> Epoch: 138 Average loss: 18.1895
====> Epoch: 139 Average loss: 18.2090
====> Epoch: 140 A

====> Epoch: 141 Average loss: 18.1108
====> Epoch: 142 Average loss: 18.1101
====> Epoch: 143 Average loss: 18.0824
====> Epoch: 144 Average loss: 18.0948
====> Epoch: 145 Average loss: 18.0508
Epoch   147: reducing learning rate of group 0 to 7.8125e-06.
====> Epoch: 146 Average loss: 18.0526
====> Epoch: 147 Average loss: 18.0377
====> Epoch: 148 Average loss: 18.1493
====> Epoch: 149 Average loss: 17.9820
