In [106]:
import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

In [132]:
torch.cuda.is_available

<function torch.cuda.is_available() -> bool>

In [134]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [33]:
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor()
])

In [65]:
data_train = datasets.MNIST(root='./data/train', transform=transform, download=True)
data_test = datasets.MNIST(root='./data/test', transform=transform, train=False, download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/train/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/train/MNIST/raw/train-images-idx3-ubyte.gz to ./data/train/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/train/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/train/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/train/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/train/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/train/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/train/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/train/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/train/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/train/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/test/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/test/MNIST/raw/train-images-idx3-ubyte.gz to ./data/test/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/test/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/test/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/test/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/test/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/test/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/test/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/test/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/test/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/test/MNIST/raw



In [35]:
print(data.train)

True


In [66]:
img0, label0 = data_train[0]
img1, label1 = data_test[0]

print(label0, label1)

5 7


In [126]:
class Encodeur(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super(Encodeur, self).__init__()
        self.latent_size = latent_size
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 2*latent_size)
        )
    
    def forward(self, image):
        batch = image.size(0)
        image = image.view(batch, -1)
        latent_vecteur = self.model(image)
        mu = latent_vecteur[:, :self.latent_size]
        covariance = latent_vecteur[:, self.latent_size:]
        
        return mu, covariance, 


class Decodeur(nn.Module):
    def __init__(self, output_size, hidden_size, latent_size):
        super(Decodeur, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.Sigmoid()
        )
    
    def forward(self, mu, covariance):
        """
        output : batch_size * (28*28)
        """
        z = Normal(mu, torch.exp(covariance))
        latent_vecteur = z.sample()
        
        return self.model(latent_vecteur)

class VAE(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super(VAE, self).__init__()
        self.encodeur = Encodeur(input_size, hidden_size, latent_size)
        self.decodeur = Decodeur(input_size, hidden_size, latent_size)
    
    def forward(self, image):
        mu, covariance = self.encodeur(image)
        res = self.decodeur(mu, covariance)
        
        return mu, covariance, res

In [None]:
batch_size = 512
input_size = 28*28
hidden_size = 128
latent_size = 10
max_iters = 200
lr = 1e-3

train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(data_test, batch_size=batch_size, shuffle=True, drop_last=True)
model = VAE(input_size, hidden_size, latent_size).to(device)
bce = nn.BCELoss()
optim = torch.optim.Adam(model.parameters(), lr=lr)
writer = SummaryWriter(log_dir = './runs')

for iters in range(max_iters):
    sum_loss = 0
    for img,label in train_loader:
        mu, covariance, res = model(img.to(device))
        loss1 = -1/2 * (1 + covariance - mu**2 - torch.exp(covariance)).sum()
        loss2 = bce(res, img.view(batch_size, -1).to(device))
        loss = loss1 + loss2
        sum_loss += loss.item()
        optim.zero_grad()
        loss.backward()
        optim.step()
    
    writer.add_scalar('loss/train', sum_loss, iters)
    print(sum_loss)
    sum_loss = 0
    with torch.no_grad():
        for img,label in test_loader:
            mu, covariance, res = model(img.to(device))
            loss1 = -1/2 * (1 + covariance - mu**2 - torch.exp(covariance)).sum()
            loss2 = bce(res, img.view(batch_size, -1).to(device))
            loss = loss1 + loss2
            sum_loss += loss.item()
    
    writer.add_scalar('loss/test', sum_loss, iters)
    
    if iters%20 == 0:
        with torch.no_grad():
            img = model.decodeur(mu=torch.randn(latent_size).to(device), covariance=torch.randn(latent_size).to(device))
            writer.add_image('imageGénéré', img.view(1,28,28), iters)
        
    

189.4261498749256
44.645963579416275
35.22678515315056
32.47794723510742
31.752962678670883
31.4384223818779
31.256068259477615
31.128083169460297
31.055584996938705
30.992048233747482
30.94728535413742
30.91357684135437
30.88748347759247
30.8677981197834
30.85184234380722
30.83478742837906
30.82668587565422
30.821056753396988
30.814650267362595
30.807825714349747
30.804563492536545
30.833574682474136
30.7927143573761
30.822667211294174
30.78850546479225
30.80947345495224
30.807479232549667
30.788697242736816
30.817652493715286
30.782933682203293
30.787940472364426
30.78425484895706
30.835609525442123
30.778607100248337
30.774868339300156
30.77417430281639
30.81570142507553
30.78007736802101
30.77284526824951
30.77166309952736
30.77048233151436
30.793953716754913
30.772417902946472
30.769355684518814
30.786439150571823
30.769646108150482
30.783532321453094
30.768202543258667
30.76737153530121
30.790520548820496


In [120]:
encodeur = Encodeur(28*28, 128, 10)
decodeur = Decodeur(28*28, 128, 10)
mu, co = encodeur(torch.cat([img0, img1], dim=0))
res = decodeur(mu, co)
print(mu, co.size())
print(res.size())

tensor([[-0.1059,  0.0229, -0.1617,  0.0562,  0.0357,  0.0489, -0.0448, -0.1077,
          0.1297, -0.0115],
        [-0.0162,  0.0474, -0.0532,  0.0369,  0.0137, -0.0117, -0.1122, -0.1225,
          0.0955,  0.0116]], grad_fn=<SliceBackward>) torch.Size([2, 10])
torch.Size([2, 784])


In [113]:
z = Normal(mu, co)
z.sample()

tensor([[ 0.9880, -0.8637,  0.6546, -0.7828,  0.9836, -1.2632, -0.6844, -1.3263,
          0.8261, -1.0049],
        [ 0.6324, -0.7441,  0.1238, -0.7501, -0.7084, -0.7242, -1.5366, -1.6214,
         -0.7658, -0.2188]])

In [89]:
m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)
output.backward()
print(input)
print(target)
print(output)

tensor([ 0.2410, -0.1575, -0.5553], requires_grad=True)
tensor([1., 0., 1.])
tensor(0.7354, grad_fn=<BinaryCrossEntropyBackward>)
