In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import numpy as np

In [2]:
import torchvision
def logit(x, alpha=1E-6):
    y = alpha + (1.-2*alpha)*x
    return np.log(y) - np.log(1. - y)

def logit_back(x, alpha=1E-6):
    y = torch.sigmoid(x)
    return (y - alpha)/(1.-2*alpha)

class AddUniformNoise(object):
    def __init__(self, alpha=1E-6):
        self.alpha = alpha
    def __call__(self,samples):
        samples = np.array(samples,dtype = np.float32)
        samples += np.random.uniform(size = samples.shape)
        samples = logit(samples/256., self.alpha)
        return samples

class ToTensor(object):
    def __init__(self):
        pass
    def __call__(self,samples):
        samples = torch.from_numpy(np.array(samples,dtype = np.float32)).float()
        return samples


In [3]:
bs = 100
# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, download=True, transform=transforms.Compose([
                       AddUniformNoise(),
                       ToTensor()
    #transforms.ToTensor()
                   ]))
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, download=True,
                   transform=transforms.Compose([
                       AddUniformNoise(),
                       ToTensor()
                       #transforms.ToTensor()
                   ]))

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [4]:
# Compute Mean abd std per pixel
x_mean = 0
x_mean2 = 0
for batch_idx, (cur_x, target) in enumerate(train_loader):
    cur_x = cur_x.view(bs, -1).float()
    x_mean += cur_x.mean(0)
    x_mean2 += (cur_x ** 2).mean(0)
x_mean /= batch_idx + 1
x_std = (x_mean2 / (batch_idx + 1) - x_mean ** 2) ** .5
x_mean, x_std
x_std[x_std == 0.] = 1.
x_mean, x_std

(tensor([-6.5465, -6.5416, -6.5377, -6.5430, -6.5337, -6.5398, -6.5445, -6.5419,
         -6.5440, -6.5391, -6.5383, -6.5458, -6.5421, -6.5434, -6.5438, -6.5419,
         -6.5392, -6.5327, -6.5465, -6.5348, -6.5409, -6.5490, -6.5432, -6.5420,
         -6.5393, -6.5467, -6.5407, -6.5384, -6.5403, -6.5375, -6.5447, -6.5398,
         -6.5451, -6.5348, -6.5407, -6.5413, -6.5399, -6.5366, -6.5314, -6.5305,
         -6.5351, -6.5325, -6.5323, -6.5302, -6.5382, -6.5338, -6.5429, -6.5368,
         -6.5379, -6.5393, -6.5363, -6.5439, -6.5406, -6.5411, -6.5402, -6.5391,
         -6.5400, -6.5335, -6.5355, -6.5420, -6.5443, -6.5364, -6.5362, -6.5363,
         -6.5237, -6.5245, -6.4940, -6.4739, -6.4447, -6.4161, -6.3762, -6.3549,
         -6.3553, -6.3675, -6.3938, -6.4338, -6.4740, -6.5084, -6.5285, -6.5276,
         -6.5386, -6.5391, -6.5449, -6.5364, -6.5437, -6.5379, -6.5365, -6.5389,
         -6.5391, -6.5375, -6.5210, -6.5069, -6.4756, -6.4216, -6.3536, -6.2548,
         -6.1656, -6.0520, -

In [31]:
class dataDiffuser(nn.Module):
    def __init__(self, beta_min=1e-4, beta_max=.02, t_min=1, t_max=1000):
        super(dataDiffuser, self).__init__()
        self.register_buffer('betas', torch.arange(beta_min, beta_max + 1e-10, (beta_max - beta_min) / (t_max - t_min)))
        self.register_buffer('alphas_t', (1 - self.betas))
        self.register_buffer('alphas', self.alphas_t.log().cumsum(0).exp())

    def diffuse(self, x_t0, t, t0=0):
        
        alpha_t0 = 1 * (t0 == 0).float() + (1 - (t0 == 0).float()) * self.alphas[t0-1]

        mu = x_t0*(self.alphas[t]/alpha_t0).sqrt().unsqueeze(1).expand(-1, x_t0.shape[1]).float()
        #mu = x_t0 * self.alphas[t].sqrt().unsqueeze(1).expand(-1, x_t0.shape[1]).float()
        sigma_t = ((self.alphas[t]/alpha_t0) * (1 - alpha_t0) + (1 - self.alphas[t])).sqrt()
        sigma = sigma_t.unsqueeze(1).expand(-1, x_t0.shape[1]).float()
        #sigma = (1 - self.alphas[t].unsqueeze(1).expand(-1, x_t0.shape[1]).float()).sqrt()
        return mu + torch.randn(x_t0.shape).to(x_t0.device) * sigma, sigma_t
    
    def prevMean(self, x_t, x_0, t):
        alphas = self.alphas.unsqueeze(1).expand(-1, x_t.shape[1]).float()
        betas = self.betas.unsqueeze(1).expand(-1, x_t.shape[1]).float()
        alphas_t = self.alphas_t.unsqueeze(1).expand(-1,x_t.shape[1]).float()
        mu = alphas[t - 1].sqrt() * betas[t] * x_0/(1 - alphas[t]) + alphas_t[t].sqrt()*(1 - alphas[t-1])*x_t/(1 - alphas[t])
        sigma = ((1 - self.alphas[t-1])/(1 - self.alphas[t]) * self.betas[t]).sqrt()
        return mu, sigma

class TemporalDecoder(nn.Module):
    def __init__(self, x_dim, z_dim, h_dim, t_dim=1):
        super(TemporalDecoder, self).__init__()
        # decoder part
        self.net = nn.Sequential(nn.Linear(z_dim + t_dim, h_dim), nn.ReLU(),
                                 nn.Linear(h_dim, h_dim), nn.ReLU(),
                                 nn.Linear(h_dim, h_dim), nn.ReLU(),                                 
                                 #nn.Linear(h_dim, h_dim), nn.ReLU(),                               
                                 #nn.Linear(h_dim, h_dim), nn.ReLU(),
                                 nn.Linear(h_dim, x_dim))
        
    def forward(self, z, t):
        return self.net(torch.cat((z, t), 1))

class PositionalEncoder(nn.Module):
    def __init__(self, dim):
        super(PositionalEncoder, self).__init__()
        self.dim = dim
        
    def forward(self, t):
        emb = t/torch.exp(torch.arange(self.dim).float()/self.dim * torch.log(torch.ones(1, self.dim) * 100)).to(t.device)
        return torch.cat((torch.sin(emb), torch.cos(emb)), 1)

class StupidPositionalEncoder(nn.Module):
    def __init__(self, T_MAX):
        super(StupidPositionalEncoder, self).__init__()
        self.T_MAX = T_MAX
        
    def forward(self, t):
        return t.float()/self.T_MAX
    
class Encoder(nn.Module):
    def __init__(self, x_dim, z_dim, h_dim):
            super(Encoder, self).__init__()
            # decoder part
            self.net = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU(),
                                     nn.Linear(h_dim, h_dim), nn.ReLU(),                                     
                                     nn.Linear(h_dim, h_dim), nn.ReLU(),                                  
                                     nn.Linear(h_dim, h_dim), nn.ReLU(),                               
                                     nn.Linear(h_dim, h_dim), nn.ReLU(),
                                     nn.Linear(h_dim, z_dim))
        
    def forward(self, x):
        return self.net(x)
    
class TemporalEncoder(nn.Module):
    def __init__(self, x_dim, z_dim, h_dim, t_dim=1):
            super(TemporalEncoder, self).__init__()
            # decoder part
            self.net = nn.Sequential(nn.Linear(x_dim + t_dim, h_dim), nn.ReLU(),
                                     nn.Linear(h_dim, h_dim), nn.ReLU(),                                     
                                     nn.Linear(h_dim, h_dim), nn.ReLU(),                                  
                                     #nn.Linear(h_dim, h_dim), nn.ReLU(),                               
                                     #nn.Linear(h_dim, h_dim), nn.ReLU(),
                                     nn.Linear(h_dim, z_dim))
        
    def forward(self, x, t):
        return self.net(torch.cat((x, t), 1))
    
class TransitionNet(nn.Module):
    def __init__(self, z_dim, h_dim, t_dim=1):
        super(TransitionNet, self).__init__()
        self.net = nn.Sequential(nn.Linear(z_dim + t_dim, h_dim), nn.ReLU(),
                                 nn.Linear(h_dim, h_dim), nn.ReLU(),
                                 nn.Linear(h_dim, h_dim), nn.ReLU(),                                  
                                 #nn.Linear(h_dim, h_dim), nn.ReLU(),                               
                                 #nn.Linear(h_dim, h_dim), nn.ReLU(),
                                 nn.Linear(h_dim, z_dim))
    def forward(self, z, t):
        return self.net(torch.cat((z, t), 1))

In [32]:
T_MAX = 25
latent_s = 25
t_emb_s = 1
pos_enc = StupidPositionalEncoder(T_MAX)#PositionalEncoder(t_emb_s//2)#
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dec = TemporalDecoder(784, latent_s, 256, t_emb_s).to(dev)
enc = TemporalEncoder(784, latent_s, 256, t_emb_s).to(dev)
trans = TransitionNet(latent_s, 100, t_emb_s).to(dev)
dif = dataDiffuser(beta_min=1e-2, beta_max=1., t_max=T_MAX).to(dev)
sampling_t0 = False
(1 - dif.alphas).sqrt(), (dif.alphas).sqrt()

(tensor([0.1000, 0.2464, 0.3842, 0.5115, 0.6252, 0.7229, 0.8034, 0.8668, 0.9143,
         0.9479, 0.9702, 0.9842, 0.9922, 0.9965, 0.9985, 0.9995, 0.9998, 0.9999,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        device='cuda:0'),
 tensor([9.9499e-01, 9.6916e-01, 9.2324e-01, 8.5929e-01, 7.8049e-01, 6.9096e-01,
         5.9539e-01, 4.9858e-01, 4.0505e-01, 3.1862e-01, 2.4213e-01, 1.7731e-01,
         1.2475e-01, 8.4031e-02, 5.3970e-02, 3.2884e-02, 1.8890e-02, 1.0151e-02,
         5.0500e-03, 2.2934e-03, 9.3160e-04, 3.2772e-04, 9.4130e-05, 1.9118e-05,
         0.0000e+00], device='cuda:0'))

In [33]:
optimizer = optim.Adam(list(dec.parameters()) + list(enc.parameters()) + list(trans.parameters()), lr=.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, threshold=0.001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=True)

In [36]:
def get_X_back(x):
    nb_x = x.shape[0]
    x = x * x_std.to(dev).unsqueeze(0).expand(nb_x, -1) + x_mean.to(dev).unsqueeze(0).expand(nb_x, -1)
    return logit_back(x)

def train(epoch):
    
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        x0 = data.view(data.shape[0], -1).to(dev)
        
        x0 = (x0 - x_mean.to(dev).unsqueeze(0).expand(bs, -1))/x_std.to(dev).unsqueeze(0).expand(bs, -1)
        
        optimizer.zero_grad()
        
        if sampling_t0:
            t0 = torch.randint(0, T_MAX - 1, [x0.shape[0]]).to(dev)
            x_t0, sigma_x_t0 = dif.diffuse(x0, t0, torch.zeros(x0.shape[0]).long().to(dev))
        else:
            t0 = torch.zeros(x0.shape[0]).to(dev).long()
            x_t0 = x0
        
        
        z_t0 = enc(x_t0, pos_enc(t0.float().unsqueeze(1)))
        #z_t0 = z_t0 + torch.randn(z_t0.shape).to(dev) * (1 - dif.alphas[t0]).sqrt().unsqueeze(1).expand(-1, z_t0.shape[1])
        t = torch.torch.distributions.Uniform(t0.float() + 1, torch.ones_like(t0) * T_MAX).sample().long().to(dev)
                
        z_t, sigma_z = dif.diffuse(z_t0, t, t0)
        x_t, sigma_x = dif.diffuse(x_t0, t, t0)
        
        
        mu_x_pred = dec(z_t, pos_enc(t.float().unsqueeze(1)))
        KL_x = ((mu_x_pred - x_t)**2).sum(1) / sigma_x**2
        
        mu_z_pred = trans(z_t, pos_enc(t.float().unsqueeze(1)))
        mu, sigma = dif.prevMean(z_t0, z_t, t)
        KL_z = ((mu - mu_z_pred)**2).sum(1) / sigma**2
        
        loss = KL_x.mean(0) + KL_z.mean(0)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)))
    scheduler.step(train_loss)
    zT = torch.randn(64, latent_s).to(dev)
    z_t = zT
    for t in range(T_MAX - 1, 0, -1):
        t_t = torch.ones(64, 1).to(dev) * t
        if t > 0:
            sigma = ((1 - dif.alphas[t-1])/(1 - dif.alphas[t]) * dif.betas[t]).sqrt()
        else:
            sigma = 0
        z_t = trans(z_t, pos_enc(t_t))  + torch.randn(z_t.shape).to(dev) * sigma 
        if (t - 1) % 1 == 0:
            x_t = dec(z_t, pos_enc(t_t - 1))
            save_image(get_X_back(x_t).view(64, 1, 28, 28), './Samples/Generated/sample_gen_' + str(epoch) + '_' + str(t - 1) + '.png')
            x_t, _ = dif.diffuse(x0, (torch.ones(x0.shape[0]).to(dev) * t - 1).long(), torch.zeros(x0.shape[0]).long().to(dev))
            save_image(get_X_back(x_t).view(x0.shape[0], 1, 28, 28), './Samples/Real/sample_real_' + str(epoch) + '_' + str(t - 1) + '.png')
                
    
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

In [37]:
for i in range(500):
    train(i)

====> Epoch: 0 Average loss: 13.6109
====> Epoch: 1 Average loss: 13.6508
====> Epoch: 2 Average loss: 13.6467
====> Epoch: 3 Average loss: 13.5321
====> Epoch: 4 Average loss: 13.5222
====> Epoch: 5 Average loss: 13.5984
====> Epoch: 6 Average loss: 13.5132
====> Epoch: 7 Average loss: 13.6559
Epoch   156: reducing learning rate of group 0 to 3.9063e-06.
====> Epoch: 8 Average loss: 13.5622
====> Epoch: 9 Average loss: 13.5851
====> Epoch: 10 Average loss: 13.6523
====> Epoch: 11 Average loss: 13.5469
====> Epoch: 12 Average loss: 13.6851
====> Epoch: 13 Average loss: 13.7257
====> Epoch: 14 Average loss: 13.5980
====> Epoch: 15 Average loss: 13.5527
====> Epoch: 16 Average loss: 13.5808
====> Epoch: 17 Average loss: 13.6408
====> Epoch: 18 Average loss: 13.7192
====> Epoch: 19 Average loss: 13.4530
====> Epoch: 20 Average loss: 13.5762
====> Epoch: 21 Average loss: 13.5473
====> Epoch: 22 Average loss: 13.5870


====> Epoch: 23 Average loss: 13.6704
====> Epoch: 24 Average loss: 13.6430
====> Epoch: 25 Average loss: 13.6136
====> Epoch: 26 Average loss: 13.5350
====> Epoch: 27 Average loss: 13.5284
====> Epoch: 28 Average loss: 13.5425
====> Epoch: 29 Average loss: 13.6714
Epoch   178: reducing learning rate of group 0 to 1.9531e-06.
====> Epoch: 30 Average loss: 13.5216
====> Epoch: 31 Average loss: 13.6131
====> Epoch: 32 Average loss: 13.5899
====> Epoch: 33 Average loss: 13.5992
====> Epoch: 34 Average loss: 13.5514
====> Epoch: 35 Average loss: 13.5727
====> Epoch: 36 Average loss: 13.6403
====> Epoch: 37 Average loss: 13.5130
====> Epoch: 38 Average loss: 13.5822
====> Epoch: 39 Average loss: 13.5321
====> Epoch: 40 Average loss: 13.5401
Epoch   189: reducing learning rate of group 0 to 9.7656e-07.
====> Epoch: 41 Average loss: 13.6255
====> Epoch: 42 Average loss: 13.6147
====> Epoch: 43 Average loss: 13.5787
====> Epoch: 44 Average loss: 13.6023
====> Epoch: 45 Average loss: 13.6041
==

====> Epoch: 47 Average loss: 13.6582
====> Epoch: 48 Average loss: 13.6679
====> Epoch: 49 Average loss: 13.5462
====> Epoch: 50 Average loss: 13.5880
====> Epoch: 51 Average loss: 13.6004
Epoch   200: reducing learning rate of group 0 to 4.8828e-07.
====> Epoch: 52 Average loss: 13.5515
====> Epoch: 53 Average loss: 13.6047
====> Epoch: 54 Average loss: 13.6065
====> Epoch: 55 Average loss: 13.6075
====> Epoch: 56 Average loss: 13.5237
====> Epoch: 57 Average loss: 13.5277
====> Epoch: 58 Average loss: 13.6851
====> Epoch: 59 Average loss: 13.6180
====> Epoch: 60 Average loss: 13.5526
====> Epoch: 61 Average loss: 13.5731
====> Epoch: 62 Average loss: 13.5742
Epoch   211: reducing learning rate of group 0 to 2.4414e-07.
====> Epoch: 63 Average loss: 13.6006
====> Epoch: 64 Average loss: 13.5771
====> Epoch: 65 Average loss: 13.6631
====> Epoch: 66 Average loss: 13.5780
====> Epoch: 67 Average loss: 13.5933
====> Epoch: 68 Average loss: 13.6660
====> Epoch: 69 Average loss: 13.5278


====> Epoch: 70 Average loss: 13.5441
====> Epoch: 71 Average loss: 13.5879
====> Epoch: 72 Average loss: 13.6121
====> Epoch: 73 Average loss: 13.6628
Epoch   222: reducing learning rate of group 0 to 1.2207e-07.
====> Epoch: 74 Average loss: 13.5688
====> Epoch: 75 Average loss: 13.6274
====> Epoch: 76 Average loss: 13.5520
====> Epoch: 77 Average loss: 13.6196
====> Epoch: 78 Average loss: 13.6275
====> Epoch: 79 Average loss: 13.5822
====> Epoch: 80 Average loss: 13.5301
====> Epoch: 81 Average loss: 13.5634
====> Epoch: 82 Average loss: 13.6013
====> Epoch: 83 Average loss: 13.6657
====> Epoch: 84 Average loss: 13.6488
Epoch   233: reducing learning rate of group 0 to 6.1035e-08.
====> Epoch: 85 Average loss: 13.5676
====> Epoch: 86 Average loss: 13.5037
====> Epoch: 87 Average loss: 13.5516
====> Epoch: 88 Average loss: 13.5205
====> Epoch: 89 Average loss: 13.4930
====> Epoch: 90 Average loss: 13.5835
====> Epoch: 91 Average loss: 13.5110
====> Epoch: 92 Average loss: 13.5703
==

====> Epoch: 94 Average loss: 13.5381
====> Epoch: 95 Average loss: 13.6529
Epoch   244: reducing learning rate of group 0 to 3.0518e-08.
====> Epoch: 96 Average loss: 13.6821
====> Epoch: 97 Average loss: 13.5916
====> Epoch: 98 Average loss: 13.6517
====> Epoch: 99 Average loss: 13.6300
====> Epoch: 100 Average loss: 13.5952
====> Epoch: 101 Average loss: 13.5110
====> Epoch: 102 Average loss: 13.5458
====> Epoch: 103 Average loss: 13.6430
====> Epoch: 104 Average loss: 13.5967
====> Epoch: 105 Average loss: 13.5385
====> Epoch: 106 Average loss: 13.5796
Epoch   255: reducing learning rate of group 0 to 1.5259e-08.
====> Epoch: 107 Average loss: 13.7268
====> Epoch: 108 Average loss: 13.6856
====> Epoch: 109 Average loss: 13.5165
====> Epoch: 110 Average loss: 13.6072
====> Epoch: 111 Average loss: 13.6026
====> Epoch: 112 Average loss: 13.5789
====> Epoch: 113 Average loss: 13.6123
====> Epoch: 114 Average loss: 13.5300
====> Epoch: 115 Average loss: 13.5836
====> Epoch: 116 Average

====> Epoch: 117 Average loss: 13.4861
====> Epoch: 118 Average loss: 13.5228
====> Epoch: 119 Average loss: 13.5364
====> Epoch: 120 Average loss: 13.6371
====> Epoch: 121 Average loss: 13.6376
====> Epoch: 122 Average loss: 13.5368
====> Epoch: 123 Average loss: 13.6456
====> Epoch: 124 Average loss: 13.6212
====> Epoch: 125 Average loss: 13.5552
====> Epoch: 126 Average loss: 13.6714
====> Epoch: 127 Average loss: 13.4861
====> Epoch: 128 Average loss: 13.5944
====> Epoch: 129 Average loss: 13.5632
====> Epoch: 130 Average loss: 13.5972
====> Epoch: 131 Average loss: 13.5322
====> Epoch: 132 Average loss: 13.6344
====> Epoch: 133 Average loss: 13.5635
====> Epoch: 134 Average loss: 13.5186
====> Epoch: 135 Average loss: 13.5575
====> Epoch: 136 Average loss: 13.5848
====> Epoch: 137 Average loss: 13.5556
====> Epoch: 138 Average loss: 13.6047
====> Epoch: 139 Average loss: 13.6143


====> Epoch: 140 Average loss: 13.4537
====> Epoch: 141 Average loss: 13.6076
====> Epoch: 142 Average loss: 13.7688
====> Epoch: 143 Average loss: 13.5463
====> Epoch: 144 Average loss: 13.5685
====> Epoch: 145 Average loss: 13.6614


KeyboardInterrupt: 

In [None]:
torch.arange(10).float()/10

In [39]:
import os
import torch
import torch.utils.data as data
from os.path import join
from PIL import Image, ImageOps
import random
import torchvision.transforms as transforms


this_root = os.path.abspath(os.path.dirname(__file__))


def load_image(file_path, input_height=128, input_width=None, output_height=128, output_width=None,
               crop_height=None, crop_width=None, is_random_crop=True, is_mirror=False, is_gray=False):
    if input_width is None:
        input_width = input_height
    if output_width is None:
        output_width = output_height
    if crop_width is None:
        crop_width = crop_height

    img = Image.open(file_path)
    if is_gray is False and img.mode is not 'RGB':
        img = img.convert('RGB')
    if is_gray and img.mode is not 'L':
        img = img.convert('L')

    if is_mirror and random.randint(0, 1) is 0:
        img = ImageOps.mirror(img)

    if input_height is not None:
        img = img.resize((input_width, input_height), Image.BICUBIC)

    if crop_height is not None:
        [w, h] = img.size
        if is_random_crop:
            cx1 = random.randint(0, w - crop_width)
            cx2 = w - crop_width - cx1
            cy1 = random.randint(0, h - crop_height)
            cy2 = h - crop_height - cy1
        else:
            cx2 = cx1 = int(round((w - crop_width) / 2.))
            cy2 = cy1 = int(round((h - crop_height) / 2.))
        img = ImageOps.crop(img, (cx1, cy1, cx2, cy2))

    img = img.resize((output_height, output_width), Image.BICUBIC)
    return img


def load_fake_image(img, input_height, input_width, output_height, output_width):
    fake_image = torch.load(img)
    return fake_image


def get_list_filenames(root_path):
    list = []
    for root, dirs, files in os.walk(root_path):
        for file in files:
            if not file.endswith(".jpg"):
                continue
            path = os.path.join(root, file).replace(root_path, '')
            list.append(path)


class Dataset(data.Dataset):
    def __init__(self, root_path, filename='1000_fake_tensor_cifar_10', dataset_type='celeba', input_height=128,
                 crop_height=None, crop_width=None, is_random_crop=False, is_mirror=True,
                 is_gray=False):
        """
        :param root_path: Path to the directory of the dataset
        :param filename: Name of the file
        :param dataset_type: Which dataset we are referring to
        :param input_height: Height of the image. Default set to 128
        :param crop_height:
        :param crop_width:
        :param is_random_crop:
        :param is_mirror:
        :param is_gray:
        """
        super(Dataset, self).__init__()
        self.dataset_type = dataset_type
        self.root_path = root_path
        self.input_height = input_height
        self.is_random_crop = is_random_crop
        self.is_mirror = is_mirror
        self.crop_height = crop_height
        self.crop_width = crop_width
        self.filename = filename
        self.is_gray = is_gray

        if dataset_type is 'celeba':
            self.image_filenames = get_list_filenames(root_path)

            self.input_transform = transforms.Compose([

                transforms.Resize([self.input_height, self.input_height]),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])])

            # db = datasets.ImageFolder(root, transform=transform)
            indice = list(range(0, 5000))
            try_sampler = data.SubsetRandomSampler(indice)

    def __getitem__(self, index):
        if self.dataset_type is 'celeba':
            img = load_image(join(self.root_path, self.image_filenames[index]),
                             self.input_height, self.input_width, self.output_height, self.output_width,
                             self.crop_height, self.crop_width, self.is_random_crop, self.is_mirror, self.is_gray)

            img = self.input_transform(img)

    def __len__(self):
        return len(self.image_filenames)


if __name__ == '__main__':
    trainset = Dataset(this_root, dataset_type='fake_generated')
    trainloader = data.DataLoader(trainset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    # transform = transforms.Compose([ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    # trainloader = data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)

    # torch_file = os.path.join(this_root, '4_fake_tensor_cifar_10')
    #
    # fake = torch.load(torch_file)

    # print(len(fake))



NameError: name '__file__' is not defined

In [40]:
class Res_Block(nn.Module):
    """
    A single Res Block
    """

    def __init__(self, in_channels=64, out_channels=64, avg=False, upsample=False, ngpu=1):  # groups=1, scale=1.0
        super(Res_Block, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.avg = avg
        self.avgpool = nn.AvgPool2d(2)
        self.upsample = upsample
        # self.upsample_layer = nn.Upsample(scale_factor=2, mode='nearest') #was deprecated
        self.upsample_layer = Interpolate(scale_factor=2, mode='nearest')
        self.addon = nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)
        self.ngpu = ngpu
        self.layers = [self.conv1, self.bn, self.relu, self.conv2]
        if in_channels > out_channels:
            self.sample = 1
        elif in_channels == out_channels:
            self.sample = 0
        else:
            self.sample = -1
            if self.upsample:
                self.layers = [self.upsample_layer, self.conv1, self.bn, self.relu, self.conv2]

    def forward(self, input):  # for encoder and generator
        if self.sample == 0:
            if self.upsample:
                input = self.upsample_layer(input)
            residual = input
            if self.ngpu == 0:
                output = self.relu(self.bn(self.conv1(input)))
                output = self.conv2(output)
                output += residual
                output = self.relu(self.bn(output))
            else:
                gpu_ids = range(self.ngpu)
                self.net = nn.Sequential(*self.layers)
                output = nn.parallel.data_parallel(self.net, input, gpu_ids)
            if self.avg:
                output = self.avgpool(output)


        elif self.sample == -1:  # for encoder, out_ch should be in_ch * 2
            identity = self.addon(input)
            output = self.relu(self.bn(self.conv1(input)))
            output = self.conv2(output)
            output += identity
            if self.avg == True:
                output = self.avgpool(output)

        else:  # for generator, out_ch should be in_ch/2
            if self.upsample:
                input = self.upsample_layer(input)
            identity = self.addon(input)
            output = self.relu(self.bn(self.conv1(input)))
            output = self.conv2(output)
            output += identity
            output = self.relu(self.bn(output))

        return output


class Interpolate(nn.Module):
    """
    Wrapper interpolate function
    """
    def __init__(self, scale_factor, mode):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode)
        return x


class Intro_enc(nn.Module):
    """
    Encoder model
    """
    def __init__(self, num_col=3, img_dim=256, z_dim=512, ngpu=1):  # groups=1, scale=1.0
        super(Intro_enc, self).__init__()
        self.dim = img_dim
        self.nc = num_col
        self.c_dim = self.dim // 8
        self.layers = [nn.Conv2d(self.nc, self.c_dim, 5, 1, 2, bias=False),
                       nn.BatchNorm2d(self.c_dim),
                       nn.LeakyReLU(0.2),
                       nn.AvgPool2d(2)]
        self.zdim = self.dim * 2
        self.fc = nn.Linear(z_dim * 4 * 4, 2 * z_dim)
        self.ngpu = ngpu

        if self.dim == 256:  # 32, 64, 128, 256, 512, 512
            # 32 * 128 * 128
            self.layers.extend([Res_Block(32, 64, avg=True, ngpu=ngpu),  # 64 * 64 * 64
                                Res_Block(64, 128, avg=True, ngpu=ngpu),  # 128 * 32 * 32
                                Res_Block(128, 256, avg=True, ngpu=ngpu),  # 256 * 16 * 16
                                Res_Block(256, 512, avg=True, ngpu=ngpu),  # 512 * 8 * 8
                                Res_Block(512, 512, avg=True, ngpu=ngpu),
                                Res_Block(512, 512, ngpu=ngpu)])  # 512 * 4 * 4

        elif self.dim == 128:  # 16, 32, 64, 128, 256, 256
            # I assume the channel sequence start from 16 for 128*128 image(as in 1024*1024)
            # instead of 32 in 256*256, so that it can have similar number of Res-block
            # (while 5 for 128*128，6 for 256*256, 8 for 1024*1024)
            # 16 * 64 * 64
            '''
            self.net.add_model('res64', Res_Block(16, 32, avg=True))# 32 * 32 * 32
            self.net.add_model('res64', Res_Block(32, 64, avg=True))# 64 * 16 * 16
            self.net.add_model('res128', Res_Block(64, 128, avg=True))# 128 * 8 * 8
            self.net.add_model('res256', Res_Block(128, 256, avg=True))# 256 * 4 * 4
            '''
            self.layers.extend([
                Res_Block(16, 32, avg=True, ngpu=ngpu),
                Res_Block(32, 64, avg=True, ngpu=ngpu),
                Res_Block(64, 128, avg=True, ngpu=ngpu),
                Res_Block(128, 256, avg=True, ngpu=ngpu),
                Res_Block(256, 256, ngpu=ngpu)
            ])

        self.net = nn.Sequential(*self.layers)

    def forward(self, input):
        if self.ngpu == 0:
            output = self.net(input)
            output = output.view(output.size(0), -1)
            output = self.fc(output)
        else:
            gpu_ids = range(self.ngpu)
            output = nn.parallel.data_parallel(self.net, input, gpu_ids)
            output = output.view(output.size(0), -1)  # reshape
            output = nn.parallel.data_parallel(self.fc, output, gpu_ids)

        mean, logvar = output.chunk(2, dim=1)  # although dunno why

        return mean, logvar


class Intro_gen(nn.Module):
    """
    Generator model
    """
    def __init__(self, img_dim=256, num_col=3, z_dim=512, ngpu=1):
        super(Intro_gen, self).__init__()
        self.dim = img_dim
        self.nc = num_col
        self.z_dim = z_dim
        self.fc = nn.Linear(self.z_dim, self.z_dim * 4 * 4)
        self.relu = nn.ReLU(True)
        self.ngpu = ngpu

        if self.z_dim == 512:
            self.layers = [
                Res_Block(512, 512, ngpu=ngpu),
                Res_Block(512, 512, upsample=True, ngpu=ngpu),
                Res_Block(512, 256, upsample=True, ngpu=ngpu),
                Res_Block(256, 128, upsample=True, ngpu=ngpu),
                Res_Block(128, 64, upsample=True, ngpu=ngpu),
                Res_Block(64, 32, upsample=True, ngpu=ngpu),
                Res_Block(32, 32, upsample=True, ngpu=ngpu),
                nn.Conv2d(32, num_col, 5, 1, 2)
            ]

        elif self.z_dim == 256:
            self.layers = [
                Res_Block(256, 256, ngpu=ngpu),
                Res_Block(256, 128, upsample=True, ngpu=ngpu),
                Res_Block(128, 64, upsample=True, ngpu=ngpu),
                Res_Block(64, 32, upsample=True, ngpu=ngpu),
                Res_Block(32, 16, upsample=True, ngpu=ngpu),
                Res_Block(16, 16, upsample=True, ngpu=ngpu),
                nn.Conv2d(16, num_col, 5, 1, 2)
            ]

        self.net = nn.Sequential(*self.layers)

    def forward(self, input):
        # input: latent vector
        input = self.relu(self.fc(input))
        input = input.view(-1, self.z_dim, 4, 4)
        if self.ngpu == 0:
            output = self.net(input)
        else:
            gpu_ids = range(self.ngpu)
            output = nn.parallel.data_parallel(self.net, input, gpu_ids)

        return output

In [None]:
net = Intro_enc()

In [None]:
T_MAX = 25
latent_s = 25
t_emb_s = 1
pos_enc = StupidPositionalEncoder(T_MAX)#PositionalEncoder(t_emb_s//2)#
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dec = TemporalDecoder(784, latent_s, 256, t_emb_s).to(dev)
enc = TemporalEncoder(784, latent_s, 256, t_emb_s).to(dev)
trans = TransitionNet(latent_s, 100, t_emb_s).to(dev)
dif = dataDiffuser(beta_min=1e-2, beta_max=1., t_max=T_MAX).to(dev)
sampling_t0 = False
(1 - dif.alphas).sqrt(), (dif.alphas).sqrt()