In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from IPython.display import Image
from tqdm import tqdm

# Autoencoders

In [None]:
# External figure: structure of a convolutional autoencoder (CAE) for MNIST.
cae_figure_url = "https://analyticsindiamag.com/wp-content/uploads/2020/07/The-structure-of-proposed-Convolutional-AutoEncoders-CAE-for-MNIST-In-the-middle-there.png"
Image(url=cae_figure_url)

In [None]:
# External illustration for the autoencoder section (rendered from its URL).
ae_figure_url = "https://assets-global.website-files.com/5d7b77b063a9066d83e1209c/60bbe71203425680a535a476_pasted%20image%200.png"
Image(url=ae_figure_url)

Typical applications of autoencoders:

* Compression
* Denoising
* Super-resolution
* Inpainting
* Deepfakes

In [None]:
class Autoencoder(nn.Module):
    """Fully connected autoencoder: input -> 128 -> code -> 128 -> input.

    Every linear layer, including the bottleneck and the output layer, is
    followed by a ReLU, so both the code and the reconstruction are non-negative.

    Args:
        input_features: length of each flattened input vector.
        code_size: width of the bottleneck (latent code).
    """

    def __init__(self, input_features, code_size=10):
        super().__init__()
        hidden_width = 128
        # Construction order (and attribute names) matter: they fix parameter
        # registration order and let other code reuse the decoder layers by name.
        self.encoder_hidden_layer = nn.Linear(in_features=input_features, out_features=hidden_width)
        self.encoder_output_layer = nn.Linear(in_features=hidden_width, out_features=code_size)
        self.decoder_hidden_layer = nn.Linear(in_features=code_size, out_features=hidden_width)
        self.decoder_output_layer = nn.Linear(in_features=hidden_width, out_features=input_features)

    def forward(self, x):
        stages = (
            self.encoder_hidden_layer,
            self.encoder_output_layer,
            self.decoder_hidden_layer,
            self.decoder_output_layer,
        )
        for stage in stages:
            x = torch.relu(stage(x))
        return x

In [None]:
# MNIST digits converted by ToTensor into float tensors of shape (1, 28, 28) in [0, 1];
# the files are downloaded into ./datasets on first run.
BATCH_SIZE = 128
# Asking for more workers than there are CPUs triggers a DataLoader warning and can thrash.
NUM_WORKERS = min(16, os.cpu_count() or 1)
# Pinned host memory only helps (and only avoids a warning) when a CUDA device is present.
PIN_MEMORY = torch.cuda.is_available()
mnist_transform = torchvision.transforms.ToTensor()

train_dataset = torchvision.datasets.MNIST(
    root='datasets',
    download=True,
    train=True,
    transform=mnist_transform,
)

train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

valid_dataset = torchvision.datasets.MNIST(
    root='datasets',
    download=True,
    train=False,
    transform=mnist_transform,
)

# Validation data is only evaluated and visualised, so a fixed order keeps those results repeatable.
valid_dataloader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's RNG on all devices so weight init, shuffling and latent sampling
# repeat across runs (up to nondeterministic CUDA kernels).
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Fully connected autoencoder for flattened 28x28 MNIST images with a 10-dimensional code.
model = Autoencoder(input_features=28 * 28, code_size=10)
model = model.to(device)

In [None]:
# Peek at the first four digits of a training batch in a 2x2 grid.
batch, _ = next(iter(train_dataloader))

figure, axes = plt.subplots(2, 2)
# The loop variable keeps the name `index`, so it stays defined after this cell as before.
for index, ax in enumerate(axes.flat):
    ax.imshow(batch[index][0], cmap='gray', interpolation='none')
figure.suptitle('First training batch')
# Lay out once, after every panel has been drawn.
figure.tight_layout()

In [None]:
# Reconstruction produced by the still-untrained autoencoder, as a before-training baseline.
sample = batch[4]
model.eval()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    prediction = model(sample.view(sample.shape[0], -1).to(device))
plt.imshow(prediction.cpu().reshape(28, 28), cmap='gray', interpolation='none')
plt.title('Untrained reconstruction');

In [None]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)


def train_step(net=model, opt=optimizer, loader=train_dataloader, loss_fn=criterion) -> float:
    """Train the fully connected autoencoder for one epoch.

    The defaults are bound when this cell runs. Rebinding the global name
    ``optimizer`` later in the notebook therefore cannot silently redirect
    this function to a different optimizer.

    Args:
        net: autoencoder taking flattened images and returning same-shaped output.
        opt: optimizer over ``net``'s parameters.
        loader: iterable of ``(images, labels)`` batches.
        loss_fn: reconstruction loss ``loss_fn(output, target)``.

    Returns:
        Mean per-batch loss over the epoch (0.0 for an empty loader).
    """
    net.train()
    # Put batches wherever the network's weights live.
    net_device = next(net.parameters()).device

    total_loss = 0.0
    batches = 0
    for images, _ in loader:
        # Flatten each image into a vector for the fully connected layers.
        images = images.flatten(start_dim=1).to(net_device)

        opt.zero_grad()
        output = net(images)
        batch_loss = loss_fn(output, images)
        batch_loss.backward()
        opt.step()

        total_loss += batch_loss.item()
        batches += 1

    return total_loss / batches if batches else 0.0

In [None]:
# Train the autoencoder; each printed line is the mean batch loss of one epoch.
epochs = 10

for epoch in range(1, epochs + 1):
    loss = train_step()
    print(f'Avg. loss: {loss:.4f}')

In [None]:
# First batch of validation images; the labels are not needed for reconstruction.
batch = next(iter(valid_dataloader))[0]

In [None]:
# Inputs (top row) next to their reconstructions (bottom row) for a few validation digits.
# The selection is fixed, so re-running the cell always draws the same figure.
n_examples = min(4, len(batch))
originals = batch[:n_examples]

model.eval()
with torch.no_grad():
    flat_originals = originals.flatten(start_dim=1).to(device)
    # Restore the (N, 1, 28, 28) layout of the inputs for plotting.
    reconstructions = model(flat_originals).cpu().view_as(originals)

# squeeze=False keeps `axes` 2-D even when only one example is shown.
figure, axes = plt.subplots(2, n_examples, squeeze=False)
for column in range(n_examples):
    axes[0, column].imshow(originals[column][0], cmap='gray', interpolation='none')
    axes[1, column].imshow(reconstructions[column][0], cmap='gray', interpolation='none')
axes[0, 0].set_ylabel('input')
axes[1, 0].set_ylabel('reconstruction')
figure.tight_layout()

***

In [None]:
class Decoder(nn.Module):
    """Standalone decoder that reuses the trained autoencoder's decoder layers.

    NOTE: the name ``Decoder`` is defined again later in this notebook for the
    VAE. Instances created before that redefinition keep this implementation.

    Args:
        autoencoder: module exposing ``decoder_hidden_layer`` and
            ``decoder_output_layer``. Defaults to the global ``model`` (the
            original behaviour).
    """

    def __init__(self, autoencoder=None):
        super().__init__()
        source = model if autoencoder is None else autoencoder
        # Layers are shared, not copied: further training of `source` also changes this decoder.
        self.decoder_hidden_layer = source.decoder_hidden_layer
        self.decoder_output_layer = source.decoder_output_layer

    def forward(self, x):
        """Decode latent code(s) ``x`` (last dimension = code size) into flattened images.

        ReLU follows each layer, matching the activations the autoencoder used
        on these same layers during training. Decoded codes therefore look like
        its reconstructions rather than the output of a differently-activated net.
        """
        x = torch.relu(self.decoder_hidden_layer(x))
        return torch.relu(self.decoder_output_layer(x))

In [None]:
# Decoder-only view of the trained autoencoder; it shares that model's decoder weights.
h = Decoder().to(device)

In [None]:
# The code vector must be as long as the decoder's input width (the autoencoder's code_size).
h(torch.full((h.decoder_hidden_layer.in_features,), 0.13, device=device)).shape

In [None]:
# Decode a code whose entries all equal 0.15; its length must match the decoder's input width.
ae_code_size = h.decoder_hidden_layer.in_features
with torch.no_grad():
    sample = h(torch.full((ae_code_size,), 0.15, device=device)).cpu().reshape(28, 28)

plt.imshow(sample, cmap='gray', interpolation='none')

In [None]:
# Switch matplotlib to the interactive "notebook" backend, presumably so the
# FuncAnimation further down can update its figure in place.
# NOTE(review): this backend targets the classic Jupyter Notebook; in JupyterLab
# `%matplotlib widget` (requires ipympl) is the usual replacement — confirm for your setup.
# The switch also applies to every plot drawn after this cell.
%matplotlib notebook

In [None]:
from matplotlib.animation import FuncAnimation

In [None]:
# Sweep a constant latent code from 0 towards 1 and watch how the decoded digit changes.
ae_code_width = h.decoder_hidden_layer.in_features


def decode_constant(value):
    """Decode a latent code whose entries all equal ``value`` into a 28x28 CPU tensor."""
    with torch.no_grad():
        code = torch.full((ae_code_width,), value, device=device)
        return h(code).cpu().reshape(28, 28)


fig = plt.figure()
ax = fig.gca()
img = ax.imshow(decode_constant(1e-5), cmap='gray', interpolation='none')


def animate(frame_num):
    """Show the decoding of the constant code ``frame_num / 1000``."""
    p = frame_num / 1000.
    img.set_data(decode_constant(p))
    # set_data keeps the colour limits computed for the first frame; rescale to this frame.
    img.autoscale()
    return (img,)


# Keep the `anim` reference alive: matplotlib stops an animation once it is garbage-collected.
anim = FuncAnimation(fig, animate, frames=1000, interval=1)

***

# Variational Autoencoders (VAE)

In [None]:
# External illustration for the VAE section (rendered from its URL).
vae_figure_url = "https://miro.medium.com/max/1400/1*ohh8pBpSsMl3LmN0USxrLg.png"
Image(url=vae_figure_url)

In [None]:
class VariationalEncoder(nn.Module):
    """Convolutional encoder mapping (N, 1, 28, 28) images to sampled latent vectors.

    The conv stack reduces 28x28 -> 14x14 -> 7x7 -> 3x3 (hence ``3*3*32`` features).
    After each forward pass, ``self.kl`` holds KL(N(mu, sigma^2) || N(0, I)),
    summed over the batch and latent dimensions.

    Args:
        latent_dims: size of the latent vector.
    """

    def __init__(self, latent_dims):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(8, 16, 3, stride=2, padding=1)
        self.batch2 = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(16, 32, 3, stride=2, padding=0)
        self.linear1 = nn.Linear(3*3*32, 128)
        self.linear2 = nn.Linear(128, latent_dims)  # mean head (mu)
        self.linear3 = nn.Linear(128, latent_dims)  # log-std head (log sigma)

        # Kept for backward compatibility. Sampling uses torch.randn_like instead,
        # so the noise is created on the same device as mu.
        self.N = torch.distributions.Normal(0, 1)
        self.kl = 0

    def forward(self, x):
        # Follow the module's own device (not a global) so CPU and GPU both work.
        x = x.to(self.conv1.weight.device)
        x = F.relu(self.conv1(x))
        x = F.relu(self.batch2(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        mu = self.linear2(x)
        log_sigma = self.linear3(x)
        sigma = torch.exp(log_sigma)
        # Reparameterisation trick: z = mu + sigma * eps, eps ~ N(0, I).
        z = mu + sigma * torch.randn_like(mu)
        # Closed-form KL to the standard normal: 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2).
        # Using log_sigma directly avoids log(exp(.)) round-off.
        self.kl = 0.5 * (sigma**2 + mu**2 - 1 - 2 * log_sigma).sum()
        return z

In [None]:
class Decoder(nn.Module):
    """Convolutional VAE decoder: latent vector -> (N, 1, 28, 28) image squashed into [0, 1].

    NOTE: this shadows the earlier ``Decoder`` defined for the fully connected autoencoder.

    Args:
        latent_dims: size of the incoming latent vector.
    """

    def __init__(self, latent_dims):
        super().__init__()

        # latent -> 288 features, later reshaped into 32 feature maps of 3x3.
        self.decoder_lin = nn.Sequential(
            nn.Linear(latent_dims, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 32 * 3 * 3),
            nn.ReLU(inplace=True),
        )

        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

        # Transposed convolutions upsample 3x3 -> 7x7 -> 14x14 -> 28x28.
        upsampling_layers = [
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1),
        ]
        self.decoder_conv = nn.Sequential(*upsampling_layers)

    def forward(self, x):
        feature_maps = self.unflatten(self.decoder_lin(x))
        return torch.sigmoid(self.decoder_conv(feature_maps))

In [None]:
class VariationalAutoencoder(nn.Module):
    """Variational encoder followed by the convolutional decoder.

    After a forward pass, ``self.encoder.kl`` holds the KL term for that batch.

    Args:
        latent_dims: size of the latent vector shared by encoder and decoder.
    """

    def __init__(self, latent_dims):
        super().__init__()
        self.encoder = VariationalEncoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        return self.decoder(self.encoder(x))

In [None]:
# Place the VAE on `device` so that inputs and sampled latents created there match its weights.
vae = VariationalAutoencoder(latent_dims=4).to(device)

In [None]:
# Adam with a small weight decay for the VAE. Note: this rebinds the global name `optimizer`.
optimizer = torch.optim.Adam(
    vae.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

In [None]:
def train_epoch(net=vae, opt=optimizer, loader=train_dataloader):
    """Train the VAE for one epoch and return the average loss per training sample.

    The loss per batch is the summed squared reconstruction error plus the
    encoder's KL term. Defaults are bound when this cell runs.

    Args:
        net: VAE whose ``encoder.kl`` is refreshed by each forward pass.
        opt: optimizer over ``net``'s parameters.
        loader: iterable of ``(images, labels)`` batches with a ``.dataset``.

    Returns:
        Total loss divided by the number of samples (0.0 for an empty dataset).
    """
    net.train()
    # x and x_hat must share a device, so move batches to where the weights are.
    net_device = next(net.parameters()).device
    train_loss = 0.0
    for x, _ in loader:
        x = x.to(net_device)
        x_hat = net(x)
        loss = ((x - x_hat) ** 2).sum() + net.encoder.kl
        opt.zero_grad()
        loss.backward()
        opt.step()
        train_loss += loss.item()

    n_samples = len(loader.dataset)
    return train_loss / n_samples if n_samples else 0.0

In [None]:
# Train the VAE; each printed line is the average per-sample loss of one epoch.
num_epochs = 3

for epoch_number in range(1, num_epochs + 1):
    loss = train_epoch()
    print(f'Avg. loss: {loss:.4f}')

In [None]:
def show_image(img, ax=None):
    """Display a channels-first (C, H, W) image tensor, e.g. the output of ``make_grid``.

    Args:
        img: CPU tensor laid out as (C, H, W).
        ax: Axes to draw on; defaults to the current pyplot axes (the original behaviour).
    """
    npimg = img.numpy()
    target = plt.gca() if ax is None else ax
    target.imshow(np.transpose(npimg, (1, 2, 0)))


vae.eval()

# Generate new digits by decoding latent vectors drawn from the N(0, I) prior.
n_samples = 100  # shown as a 10 x 10 grid
# Match the decoder's device and latent width instead of hard-coding them.
vae_device = next(vae.parameters()).device
latent_dims = vae.decoder.decoder_lin[0].in_features

with torch.no_grad():
    latent = torch.randn(n_samples, latent_dims, device=vae_device)
    img_recon = vae.decoder(latent).cpu()

fig, ax = plt.subplots(figsize=(20, 8.5))
show_image(torchvision.utils.make_grid(img_recon, nrow=10, padding=5), ax=ax)
plt.show()

# Generative Adversarial Networks (GAN)

In [None]:
# External illustration for the GAN section (rendered from its URL).
gan_figure_url = "https://miro.medium.com/max/1400/1*M_YipQF_oC6owsU1VVrfhg.jpeg"
Image(url=gan_figure_url)