[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/HSE-LAMBDA/DeepGenerativeModels/blob/spring-2021/seminars/seminar-1-autoencoders/0.autoencoders.ipynb)


In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
from tqdm.notebook import tqdm

from sklearn.manifold import TSNE

### Data

In [3]:
from torchvision.datasets import MNIST
from torchvision import transforms

from torch.utils.data import DataLoader

# Upscale the 28x28 MNIST digits to 64x64 and convert them to float tensors in [0, 1].
mnist_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
batch_size = 64

# Seed torch's global RNG so weight init, data shuffling and injected noise are
# reproducible across Restart & Run All (t-SNE has its own, separate randomness).
SEED = 42
torch.manual_seed(SEED)


In [4]:
train_dataset = MNIST('./mnist_root', train=True, transform=mnist_transforms, download=True)
test_dataset = MNIST('./mnist_root', train=False, transform=mnist_transforms, download=True)

# Shuffle the training set each epoch so mini-batches are not drawn in a fixed order;
# the test set keeps a fixed order so evaluation cells see the same batches every run.
# drop_last=True guarantees every batch has exactly batch_size samples.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_root/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_root/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_root/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_root/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_root/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_root/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_root/MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_root/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_root/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


### Simple AE

![Simple autoencoder architecture](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F87a8b660-2da9-40c7-bcf0-c65f36c3f30b%2FUntitled.png?table=block&id=c7daee2c-a8a3-4c2a-b53b-1f072cbe32af&width=3580&userId=007ae5b1-7ba2-466d-8d98-87eb8085b484&cache=v2 "Autoencoder: encoder → latent code → decoder")


In [5]:
class Block(nn.Module):
    """Conv -> BatchNorm -> LeakyReLU(0.2), optionally preceded by 2x bilinear upsampling.

    Args:
        in_features: number of input channels.
        out_features: number of output channels.
        kernel: square kernel size; padding is (kernel - 1) // 2, which keeps the spatial
            size unchanged for odd kernels at stride 1.
        stride: convolution stride (stride=2 halves the spatial size).
        bias: whether the convolution has a bias term.
        upsample: if True, the input is upsampled by a factor of 2 before the convolution.
    """

    def __init__(self, in_features, out_features, kernel, stride=1, bias=False, upsample=False):
        super().__init__()
        self.upsample = upsample

        same_pad = (kernel - 1) // 2
        self.conv = nn.Conv2d(
            in_features, out_features, kernel,
            stride=stride, padding=same_pad, bias=bias,
        )
        self.norm = nn.BatchNorm2d(out_features)
        self.act = nn.LeakyReLU(0.2)

    def forward(self, x):
        h = x
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='bilinear',
                              align_corners=False, recompute_scale_factor=False)
        h = self.conv(h)
        h = self.norm(h)
        return self.act(h)
        
class AutoEncoder(nn.Module):
    """Convolutional autoencoder skeleton (seminar exercise).

    ``encoder`` should map an image batch from the MNIST loaders (64x64 after
    ``mnist_transforms``) to a latent tensor, and ``decoder`` should map it back to an
    image batch of the input's shape (the training cell compares output and input with
    MSE). TODO(student): fill in every ``# your code`` spot.
    """

    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
            # your code
        )

        self.decoder = nn.Sequential(
            # your code
        )

    def forward(self, x):
        """Reconstruct ``x``. The placeholder returns ``x`` unchanged."""
        # your code
        return x

    def get_latent_features(self, x):
        """Return the encoder's latent representation of ``x``."""
        # your code
        # (the original placeholder had no statement here, which is a SyntaxError that
        # prevented this whole cell, including Block, from being defined)
        raise NotImplementedError("AutoEncoder.get_latent_features: implement the encoder pass")






In [6]:
ae = AutoEncoder().to(device)
# Adam at lr=1e-3 with a mean pixel-wise MSE reconstruction loss.
optim = torch.optim.Adam(ae.parameters(), lr=1e-3)
criterion = nn.MSELoss(reduction='mean')

Чтобы проверить осмысленность латентных представлений, визуализируем их с помощью t-SNE (двумерной проекции, а не кластеризации) до и после обучения

In [2]:
tsne = TSNE()
ae.eval()

# Encode the whole test set for inference only: no autograd graph is needed, and
# flatten(1) works for any batch size instead of hard-coding batch_size.
with torch.no_grad():
    latent_mnist = [
        ae.get_latent_features(x.to(device)).to('cpu').flatten(start_dim=1)
        for x, _ in tqdm(test_loader)
    ]
latent_mnist = torch.cat(latent_mnist, dim=0).numpy()

In [3]:
%%time

tsne_mnist = tsne.fit_transform(latent_mnist)
plt.scatter(tsne_mnist[:, 0], tsne_mnist[:, 1])
plt.show()

In [4]:
losses = []
ae.train()

# One pass over the training set: reconstruct each batch and minimise the MSE to the input.
for image, _ in tqdm(train_loader, desc='train loop', leave=True):
    optim.zero_grad()
    image = image.to(device)
    out = ae(image)
    loss = criterion(out, image)
    loss.backward()
    optim.step()
    losses.append(loss.item())

plt.plot(losses)
plt.title('AutoEncoder training loss')
plt.xlabel('iteration')
plt.ylabel('MSE loss')
plt.show()

In [7]:
tsne = TSNE()
ae.eval()

# Encode the test set without tracking gradients, then embed the codes in 2-D with t-SNE.
with torch.no_grad():
    latent_mnist = [
        ae.get_latent_features(x.to(device)).to('cpu').flatten(start_dim=1)
        for x, _ in tqdm(test_loader)
    ]
latent_mnist = torch.cat(latent_mnist, dim=0).numpy()
tsne_mnist = tsne.fit_transform(latent_mnist)

plt.scatter(tsne_mnist[:, 0], tsne_mnist[:, 1])
plt.title('t-SNE of AutoEncoder latent codes (after training)')
plt.show()

In [8]:
test_batch = next(iter(test_loader))
ae.eval()

# Show each test digit next to its reconstruction (inference only, so no autograd graph).
with torch.no_grad():
    for image in test_batch[0]:
        reconstruction = ae(image.unsqueeze(0).to(device)).squeeze().to('cpu')
        plt.subplot(1, 2, 1)
        plt.imshow(image.squeeze(), cmap='gray')
        plt.title('input')
        plt.subplot(1, 2, 2)
        plt.imshow(reconstruction, cmap='gray')
        plt.title('reconstruction')
        plt.show()

### Denoising AE

![Denoising autoencoder scheme](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F08600deb-8305-45d6-9f0c-3effe0608f30%2FUntitled.png?table=block&id=4c9c7452-8f5c-439d-8535-4e091b402d16&width=3580&userId=007ae5b1-7ba2-466d-8d98-87eb8085b484&cache=v2)

In [12]:
class DenoisingBlock(nn.Module):
    """[2x upsample] -> +Gaussian noise -> Dropout -> Conv -> BatchNorm -> LeakyReLU(0.2).

    The Gaussian input noise (std ``noise_std``) and dropout (probability ``dropout``) are
    training-time regularisers. Both are inactive in ``eval()`` mode, so inference
    (latent extraction, reconstructions) is deterministic.

    Args:
        in_features / out_features: input / output channel counts.
        kernel: square kernel size; padding is (kernel - 1) // 2.
        stride: convolution stride.
        bias: whether the convolution has a bias term.
        upsample: if True, upsample the input by 2 before everything else.
        noise_std: std of the additive Gaussian noise used in training (default 0.05).
        dropout: dropout probability (default 0.3).
    """

    def __init__(self, in_features, out_features, kernel, stride=1, bias=False, upsample=False,
                 noise_std=0.05, dropout=0.3):
        super().__init__()
        self.upsample = upsample
        self.noise_std = noise_std
        self.conv = nn.Conv2d(in_features, out_features, kernel, stride=stride, padding=(kernel-1)//2, bias=bias)
        self.norm = nn.BatchNorm2d(out_features)
        self.act = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False, recompute_scale_factor=False)
        # Inject noise only while training (mirroring Dropout); otherwise eval-mode outputs
        # would be random on every call.
        if self.training and self.noise_std > 0:
            x = x + torch.randn_like(x) * self.noise_std
        x = self.dropout(x)
        return self.act(self.norm(self.conv(x)))
        
class DenoisingAutoEncoder(nn.Module):
    """Denoising autoencoder skeleton (seminar exercise).

    It is trained to map a corrupted image (``add_noise_to_input``) back to the clean one.
    ``encoder``/``decoder`` are presumably built from ``DenoisingBlock`` layers — TODO(student):
    fill in every ``# your code`` spot.
    """

    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
            # your code
        )

        self.decoder = nn.Sequential(
            # your code
        )

    def forward(self, x):
        """Reconstruct ``x``. The placeholder returns ``x`` unchanged."""
        # your code
        return x

    def get_latent_features(self, x):
        """Return the encoder's latent representation of ``x``."""
        # your code
        # (an empty body here was a SyntaxError that broke the whole cell)
        raise NotImplementedError("DenoisingAutoEncoder.get_latent_features: implement the encoder pass")
    
def add_noise_to_input(x):
    """Return a corrupted copy of the image tensor ``x`` (seminar exercise).

    It is called both on a whole training batch that is already on ``device`` and on a
    single CPU test image, so the implementation should be shape- and device-agnostic
    (e.g. build the noise with ``torch.randn_like(x)``).
    """
    # your code
    # (an empty body here was a SyntaxError that broke the whole cell)
    raise NotImplementedError("add_noise_to_input: corrupt x with noise")

In [13]:
dae = DenoisingAutoEncoder().to(device)
# Same optimiser/loss setup as the plain autoencoder: Adam at 1e-3 and mean MSE.
optim = torch.optim.Adam(dae.parameters(), lr=1e-3)
criterion = nn.MSELoss(reduction='mean')

In [9]:
tsne = TSNE()
dae.eval()

# Encode the test set without tracking gradients, then embed the codes in 2-D with t-SNE.
with torch.no_grad():
    latent_mnist = [
        dae.get_latent_features(x.to(device)).to('cpu').flatten(start_dim=1)
        for x, _ in tqdm(test_loader)
    ]
latent_mnist = torch.cat(latent_mnist, dim=0).numpy()
tsne_mnist = tsne.fit_transform(latent_mnist)
plt.scatter(tsne_mnist[:, 0], tsne_mnist[:, 1])
plt.title('t-SNE of DenoisingAutoEncoder latent codes (before training)')
plt.show()


In [10]:
losses = []
dae.train()

# One pass over the training set: the model sees a corrupted batch and is scored
# against the clean original.
for image, _ in tqdm(train_loader, desc='train loop', leave=True):
    optim.zero_grad()
    image = image.to(device)
    out = dae(add_noise_to_input(image))
    loss = criterion(out, image)
    loss.backward()
    optim.step()
    losses.append(loss.item())

plt.plot(losses)
plt.title('DenoisingAutoEncoder training loss')
plt.xlabel('iteration')
plt.ylabel('MSE loss (vs. clean image)')
plt.show()

In [11]:
tsne = TSNE()
dae.eval()

# Encode the test set without tracking gradients, then embed the codes in 2-D with t-SNE.
with torch.no_grad():
    latent_mnist = [
        dae.get_latent_features(x.to(device)).to('cpu').flatten(start_dim=1)
        for x, _ in tqdm(test_loader)
    ]
latent_mnist = torch.cat(latent_mnist, dim=0).numpy()
tsne_mnist = tsne.fit_transform(latent_mnist)

plt.scatter(tsne_mnist[:, 0], tsne_mnist[:, 1])
plt.title('t-SNE of DenoisingAutoEncoder latent codes (after training)')
plt.show()

In [12]:
test_batch = next(iter(test_loader))
dae.eval()

# For each test digit show: clean input | corrupted input | model output for that SAME
# corrupted input. The denoising demo only makes sense if the model sees the noisy image
# that is displayed, not the clean one.
with torch.no_grad():
    for image in test_batch[0]:
        noisy_image = add_noise_to_input(image)
        denoised = dae(noisy_image.unsqueeze(0).to(device)).squeeze().to('cpu')
        plt.subplot(1, 3, 1)
        plt.imshow(image.squeeze(), cmap='gray')
        plt.title('clean')
        plt.subplot(1, 3, 2)
        plt.imshow(noisy_image.squeeze(), cmap='gray')
        plt.title('noisy')
        plt.subplot(1, 3, 3)
        plt.imshow(denoised, cmap='gray')
        plt.title('denoised')
        plt.show()

##### Denoising AE generation

In [18]:
test_batch = next(iter(test_loader))
test_images = test_batch[0].to(device)
dae.eval()

# Latent codes of one test batch; inference only, so no autograd graph is recorded.
with torch.no_grad():
    images_latent = dae.get_latent_features(test_images).to('cpu')

In [13]:
# One overlaid histogram per test image of its flattened latent activations.
for latent_code in images_latent:
    plt.hist(latent_code.detach().view(-1))
plt.show()

In [14]:
# Reference: the same histograms for standard-normal noise shaped like each latent code.
for latent_code in images_latent:
    gaussian = torch.randn_like(latent_code)
    plt.hist(gaussian.detach().view(-1))
plt.show()

In [15]:
# Decode Gaussian noise (std 0.8) shaped like a batch of latent codes and show 10 samples.
noise_std = 0.8
test_noise = torch.randn_like(images_latent) * noise_std
sampled_mnist = dae.decoder(test_noise.to(device)).to('cpu')

for panel in range(1, 11):
    plt.subplot(2, 5, panel)
    plt.imshow(sampled_mnist[panel - 1].squeeze().detach())


### Sparse AE

##### with KL regularization

In [6]:
class OldBlock(nn.Module):
    """Conv -> Sigmoid block (no normalisation), optionally preceded by 2x bilinear upsampling.

    The sigmoid keeps activations in (0, 1), so the KL sparsity penalty can treat them as
    Bernoulli probabilities.
    """

    def __init__(self, in_features, out_features, kernel, stride=1, bias=False, upsample=False):
        super().__init__()
        self.upsample = upsample
        pad = (kernel - 1) // 2
        self.conv = nn.Conv2d(in_features, out_features, kernel,
                              stride=stride, padding=pad, bias=bias)
        self.act = nn.Sigmoid()

    def forward(self, x):
        if not self.upsample:
            return self.act(self.conv(x))
        upsampled = F.interpolate(x, scale_factor=2, mode='bilinear',
                                  align_corners=False, recompute_scale_factor=False)
        return self.act(self.conv(upsampled))


class SparseAutoEncoder(nn.Module):
    """Sparse autoencoder skeleton for the KL-regularised experiment (seminar exercise).

    The KL ``calculate_sparse_loss`` walks ``encoder``/``decoder`` layer by layer and
    applies ``torch.sigmoid`` only to the last layer of each stack. Presumably the hidden
    layers are sigmoid blocks (e.g. ``OldBlock``) and the final layers have no activation —
    TODO(student): fill in every ``# your code`` spot.
    """

    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
            # your code
        )

        self.decoder = nn.Sequential(
            # your code
        )

    def forward(self, x):
        """Reconstruct ``x``. The placeholder returns ``x`` unchanged."""
        # your code
        return x

    def get_latent_features(self, x):
        """Return the encoder's latent representation of ``x``."""
        # your code
        # (an empty body here was a SyntaxError that broke the whole cell)
        raise NotImplementedError("SparseAutoEncoder.get_latent_features: implement the encoder pass")

In [7]:
sae = SparseAutoEncoder().to(device)
# Adam at lr=1e-3; the reconstruction term is mean MSE (the sparsity term is added in training).
optim = torch.optim.Adam(sae.parameters(), lr=1e-3)
criterion = nn.MSELoss(reduction='mean')

Здесь вам придётся написать функцию, которая считает KL-дивергенцию между двумя распределениями Бернулли:

$\text{KL}(B(p) \,\|\, B(\hat{p})) = p \log \frac{p}{\hat{p}} + (1 - p) \log \frac{1 - p}{1 - \hat{p}}$



In [8]:
FIXED_SPARSE_CONST = 0.05  # target activation rate p for the sparsity penalty


def kl_divergence(x):
    """Sparsity penalty KL(B(p) || B(p_hat)) for a layer's sigmoid activations (exercise).

    ``x`` is flattened to (batch, -1). ``p`` is the constant target FIXED_SPARSE_CONST, and
    ``p_hat`` is the mean activation you compute. Fill in the two ``# your code`` spots;
    until then the function raises NotImplementedError instead of failing to parse
    (the original ``p_hat = # your code`` was a SyntaxError).
    """
    x = x.view(x.shape[0], -1)
    # here we need probability distributions;
    # fortunately the activations already went through a sigmoid, so they lie in (0, 1)

    # here you want to calculate \hat{p} - mean value of your layer output
    p_hat = None  # your code
    if p_hat is None:
        raise NotImplementedError("kl_divergence: compute p_hat, the mean layer activation")
    p = torch.ones_like(p_hat) * FIXED_SPARSE_CONST

    # calculate KL(B(p) || B(p_hat)) here (formula in the markdown cell above);
    # consider clamping p_hat away from 0 and 1 so the logarithms stay finite
    kl = None  # your code
    if kl is None:
        raise NotImplementedError("kl_divergence: compute KL(B(p) || B(p_hat))")
    return kl

def calculate_sparse_loss(sae, image):
    """Sum of KL sparsity penalties over every encoder and decoder layer of ``sae``.

    Re-runs ``image`` through the network layer by layer. Outputs of all but the last layer
    of each stack are penalised as-is (they are expected to already be sigmoid outputs).
    The last layer of each stack is penalised on ``torch.sigmoid`` of its output, while its
    raw output is what flows on to the next stage.
    """
    total = 0
    h = image
    for stack in (sae.encoder, sae.decoder):
        for layer in stack[:-1]:
            h = layer(h)
            total += kl_divergence(h)
        # the final layer of each stack has no sigmoid of its own
        h = stack[-1](h)
        total += kl_divergence(torch.sigmoid(h))
    return total


In [16]:
losses = []
sae.train()

SPARSITY_WEIGHT = 0.001  # weight of the KL sparsity term relative to the MSE term

# One pass over the training set. calculate_sparse_loss runs its own layer-by-layer pass
# through the network, so each step performs two forward passes.
for image, _ in tqdm(train_loader, desc='train loop', leave=True):
    optim.zero_grad()
    image = image.to(device)
    out = sae(image)
    loss = criterion(out, image) + SPARSITY_WEIGHT * calculate_sparse_loss(sae, image)
    loss.backward()
    optim.step()
    losses.append(loss.item())

plt.plot(losses)
plt.title('Sparse AE (KL) training loss')
plt.xlabel('iteration')
plt.ylabel('MSE + weighted KL penalty')
plt.show()

In [17]:
tsne = TSNE()
sae.eval()

# Encode the test set without tracking gradients, then embed the codes in 2-D with t-SNE.
with torch.no_grad():
    latent_mnist = [
        sae.get_latent_features(x.to(device)).to('cpu').flatten(start_dim=1)
        for x, _ in tqdm(test_loader)
    ]
latent_mnist = torch.cat(latent_mnist, dim=0).numpy()
tsne_mnist = tsne.fit_transform(latent_mnist)

plt.scatter(tsne_mnist[:, 0], tsne_mnist[:, 1])
plt.title('t-SNE of Sparse AE (KL) latent codes (after training)')
plt.show()

In [18]:
test_batch = next(iter(test_loader))
sae.eval()

# Show each test digit next to its reconstruction (inference only, so no autograd graph).
with torch.no_grad():
    for image in test_batch[0]:
        reconstruction = sae(image.unsqueeze(0).to(device)).squeeze().to('cpu')
        plt.subplot(1, 2, 1)
        plt.imshow(image.squeeze(), cmap='gray')
        plt.title('input')
        plt.subplot(1, 2, 2)
        plt.imshow(reconstruction, cmap='gray')
        plt.title('reconstruction')
        plt.show()

In [19]:
test_batch = next(iter(test_loader))
test_images = test_batch[0].to(device)
sae.eval()

# Latent codes of one test batch; inference only, so no autograd graph is recorded.
with torch.no_grad():
    images_latent = sae.get_latent_features(test_images).to('cpu')

Раз мы использовали KL-лосс, мы ожидаем, что все активации будут сдвинуты к нулю

In [20]:
# Raw latent activations: one overlaid histogram per test image.
for latent_code in images_latent:
    plt.hist(latent_code.detach().view(-1))
plt.show()

In [21]:
# The same activations after a sigmoid, i.e. on the (0, 1) scale the KL penalty acts on.
for latent_code in images_latent:
    squashed = latent_code.sigmoid()
    plt.hist(squashed.view(-1).detach())
plt.show()

In [22]:
# Reference: sigmoid of standard-normal noise shaped like each latent code.
for latent_code in images_latent:
    squashed_noise = torch.randn_like(latent_code).sigmoid()
    plt.hist(squashed_noise.view(-1).detach())
plt.show()

In [23]:
# Decode wide Gaussian noise (std 10, mean -5) shaped like a latent batch; show 10 samples.
noise_scale, noise_shift = 10, 5
test_noise = torch.randn_like(images_latent) * noise_scale - noise_shift
sampled_mnist = sae.decoder(test_noise.to(device)).to('cpu')

for panel in range(1, 11):
    plt.subplot(2, 5, panel)
    plt.imshow(sampled_mnist[panel - 1].squeeze().detach())


### Sparse AE

##### with L1 regularization

In [17]:
class SparseAutoEncoder(nn.Module):
    """Sparse autoencoder skeleton for the L1-regularised experiment (seminar exercise).

    This cell redefines the KL-variant class of the same name. The L1
    ``calculate_sparse_loss`` unrolls every layer except the last of each stack via
    ``.conv``/``.norm``/``.act``, so those layers must be ``Block``-like —
    TODO(student): fill in every ``# your code`` spot.
    """

    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
            # your code
        )

        self.decoder = nn.Sequential(
            # your code
        )

    def forward(self, x):
        """Reconstruct ``x``. The placeholder returns ``x`` unchanged."""
        # your code
        return x

    def get_latent_features(self, x):
        """Return the encoder's latent representation of ``x``."""
        # your code
        # (an empty body here was a SyntaxError that broke the whole cell)
        raise NotImplementedError("SparseAutoEncoder.get_latent_features: implement the encoder pass")

In [18]:
sae = SparseAutoEncoder().to(device)
# Fresh model and optimiser for the L1 experiment: Adam at 1e-3 plus mean MSE.
optim = torch.optim.Adam(sae.parameters(), lr=1e-3)
criterion = nn.MSELoss(reduction='mean')

In [19]:
def l1_loss(x):
    """L1 sparsity penalty of the activation tensor ``x`` (seminar exercise).

    Must return a scalar tensor so it can be accumulated and back-propagated.
    The original ``return # your code`` was a SyntaxError; until it is filled in,
    this raises NotImplementedError.
    """
    # your code
    raise NotImplementedError("l1_loss: return the L1 penalty of x")


def calculate_sparse_loss(sae, image):
    """Sum of L1 sparsity penalties over every encoder and decoder layer of ``sae``.

    Every layer except the last of each stack is treated as a ``Block`` and unrolled as
    [2x upsample] -> conv -> (penalty on the conv output) -> norm -> act, mirroring
    ``Block.forward``. The last layer of each stack is called as a whole, and its output is
    penalised directly.
    """
    loss = 0
    x = image
    for stack in (sae.encoder, sae.decoder):
        for block in stack[:-1]:
            # Block.forward upsamples before its conv when upsample=True (decoder blocks);
            # skipping that step would compute the penalty at the wrong resolution.
            if getattr(block, 'upsample', False):
                x = F.interpolate(x, scale_factor=2, mode='bilinear',
                                  align_corners=False, recompute_scale_factor=False)
            x = block.conv(x)
            loss += l1_loss(x)
            x = block.act(block.norm(x))
        x = stack[-1](x)
        loss += l1_loss(x)
    return loss


In [24]:
losses = []
sae.train()

SPARSITY_WEIGHT = 0.001  # weight of the L1 sparsity term relative to the MSE term

# One pass over the training set. calculate_sparse_loss runs its own layer-by-layer pass
# through the network, so each step performs two forward passes.
for image, _ in tqdm(train_loader, desc='train loop', leave=True):
    optim.zero_grad()
    image = image.to(device)
    out = sae(image)
    loss = criterion(out, image) + SPARSITY_WEIGHT * calculate_sparse_loss(sae, image)
    loss.backward()
    optim.step()
    losses.append(loss.item())

plt.plot(losses)
plt.title('Sparse AE (L1) training loss')
plt.xlabel('iteration')
plt.ylabel('MSE + weighted L1 penalty')
plt.show()

In [25]:
tsne = TSNE()
sae.eval()

# Encode the test set without tracking gradients, then embed the codes in 2-D with t-SNE.
with torch.no_grad():
    latent_mnist = [
        sae.get_latent_features(x.to(device)).to('cpu').flatten(start_dim=1)
        for x, _ in tqdm(test_loader)
    ]
latent_mnist = torch.cat(latent_mnist, dim=0).numpy()
tsne_mnist = tsne.fit_transform(latent_mnist)

plt.scatter(tsne_mnist[:, 0], tsne_mnist[:, 1])
plt.title('t-SNE of Sparse AE (L1) latent codes (after training)')
plt.show()

In [26]:
test_batch = next(iter(test_loader))
sae.eval()

# Show each test digit next to its reconstruction (inference only, so no autograd graph).
with torch.no_grad():
    for image in test_batch[0]:
        reconstruction = sae(image.unsqueeze(0).to(device)).squeeze().to('cpu')
        plt.subplot(1, 2, 1)
        plt.imshow(image.squeeze(), cmap='gray')
        plt.title('input')
        plt.subplot(1, 2, 2)
        plt.imshow(reconstruction, cmap='gray')
        plt.title('reconstruction')
        plt.show()

In [23]:
test_batch = next(iter(test_loader))
test_images = test_batch[0].to(device)
sae.eval()

# Latent codes of one test batch; inference only, so no autograd graph is recorded.
with torch.no_grad():
    images_latent = sae.get_latent_features(test_images).to('cpu')

Раз мы использовали L1-лосс, мы ожидаем, что выход энкодера будет сильно сдвинут к нулю

In [27]:
# Encoder outputs: one overlaid histogram per test image.
for latent_code in images_latent:
    plt.hist(latent_code.detach().view(-1))
plt.show()

In [28]:
# Reference: standard-normal noise shaped like each latent code.
for latent_code in images_latent:
    gaussian = torch.randn_like(latent_code)
    plt.hist(gaussian.detach().view(-1))
plt.show()

In [29]:
# Decode narrow Gaussian noise (std 0.1) shaped like a latent batch; show 10 samples.
noise_std = 0.1
test_noise = torch.randn_like(images_latent) * noise_std
sampled_mnist = sae.decoder(test_noise.to(device)).to('cpu')

for panel in range(1, 11):
    plt.subplot(2, 5, panel)
    plt.imshow(sampled_mnist[panel - 1].squeeze().detach())
