In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [61]:
features = 16
class LinearVAE(nn.Module):
    def __init__(self):
        super(LinearVAE, self).__init__()
        
        # Encoder
        self.enc_fc1 = torch.nn.Linear(in_features=28*28, out_features=512)
        self.enc_fc2 = torch.nn.Linear(in_features=512, out_features=features*2)
        
        # Decoder
        self.dec_fc1 = torch.nn.Linear(in_features=features, out_features=512)
        self.dec_fc2 = torch.nn.Linear(in_features=512, out_features=28*28)
        
        
    def reparameterize(self, mu, log_var):
        """
        Recieves log variance and mu. Outputs a sample.
        """
        var = torch.exp(0.5 * log_var) # TODO: why do we have the 0.5 term?
        eps = torch.randn_like(var) # Mean 0, Var 1
        sample = mu + eps * var # Gives us samples with mean MU and std VAR
        return sample
    
    
    def forward(self, x):
        x = F.relu(self.enc_fc1(x))
        x = F.relu(self.enc_fc2(x)).view(-1, 2, features)
        
        # We now have some latent representation with shape [N, feature * 2]
        # We will reshape it to shape [-1, 2, feature], so it's now [N, 2, features]
        # Where [:, 0, :] is the mean, and [:, 1, :] is the variance
        mean = x[:, 0, :]
        log_var = x[:, 1, :]
        
        reparameterized = self.reparameterize(mean, log_var)
        reconstructed = F.relu(self.dec_fc1(reparameterized))
        reconstructed = torch.sigmoid(self.dec_fc2(reconstructed))
        return reconstructed, mean, log_var
        
        

In [62]:
import torch
import torchvision
import torch.optim as optim
import argparse
import matplotlib
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image
matplotlib.style.use('ggplot')

In [63]:
# leanring parameters
epochs = 10
batch_size = 64
lr = 0.0001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [64]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

# train and validation data
train_data = datasets.MNIST(
    root='../input/data',
    train=True,
    download=True,
    transform=transform
)
val_data = datasets.MNIST(
    root='../input/data',
    train=False,
    download=True,
    transform=transform
)

In [65]:
train_loader = DataLoader (
    train_data, 
    batch_size=batch_size,
    shuffle=True)

test_loader = DataLoader (
    val_data, 
    batch_size=batch_size,
    shuffle=False)

In [66]:
model = LinearVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
bce_loss = nn.BCELoss(reduction="sum")

In [67]:
def compute_loss(reconstruction, gt, bce_loss, mu, logvar):
    # Reconstruction loss - ask the prediction to be like the GT (TODO: why not L1/L2? )
    bce_loss = bce_loss(reconstruction, gt)
    
    deviation_from_normal_regualrizer =  -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return bce_loss + deviation_from_normal_regualrizer


In [68]:
def fit(model, dataloader):
    model.train()
    running_loss = 0.0
    for i, data in tqdm(enumerate(dataloader), total=int(len(train_data)/dataloader.batch_size)):
        data, _ = data
        data = data.to(device)
        data = data.view(data.size(0), -1)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(data)
        loss = compute_loss(reconstruction, data, bce_loss, mu, logvar)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
    train_loss = running_loss/len(dataloader.dataset)
    return train_loss
    

In [73]:
def validate(model, dataloader):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=int(len(val_data)/dataloader.batch_size)):
            data, _ = data
            data = data.to(device)
            data = data.view(data.size(0), -1)
            reconstruction, mu, logvar = model(data)
            loss = compute_loss(reconstruction, data, bce_loss, mu, logvar)
            running_loss += loss.item()
        
            # save the last batch input and output of every epoch
            if i == int(len(val_data)/dataloader.batch_size) - 1:
                num_rows = 8
                both = torch.cat((data.view(batch_size, 1, 28, 28)[:8], 
                                  reconstruction.view(batch_size, 1, 28, 28)[:8]))
                save_image(both.cpu(), f"../outputs/output{epoch}.png", nrow=num_rows)
    val_loss = running_loss/len(dataloader.dataset)
    return val_loss

In [76]:
train_loss = []
val_loss = []
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = fit(model, train_loader)
    val_epoch_loss = validate(model, test_loader)
    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {val_epoch_loss:.4f}")

  2%|▏         | 15/937 [00:00<00:06, 145.32it/s]

Epoch 1 of 10


938it [00:08, 106.04it/s]                         
157it [00:01, 111.19it/s]                         
  1%|          | 9/937 [00:00<00:10, 86.77it/s]

Train Loss: 160.3874
Val Loss: 159.3884
Epoch 2 of 10


938it [00:08, 106.00it/s]                         
157it [00:01, 136.48it/s]                         
  1%|▏         | 13/937 [00:00<00:08, 111.39it/s]

Train Loss: 159.1175
Val Loss: 158.2902
Epoch 3 of 10


938it [00:09, 102.36it/s]                         
157it [00:01, 152.44it/s]                         
  0%|          | 2/937 [00:00<00:47, 19.83it/s]

Train Loss: 158.3716
Val Loss: 157.3566
Epoch 4 of 10


938it [00:08, 105.55it/s]                         
157it [00:01, 111.97it/s]                         
  1%|▏         | 13/937 [00:00<00:07, 120.26it/s]

Train Loss: 157.6253
Val Loss: 156.6747
Epoch 5 of 10


938it [00:08, 108.22it/s]                         
157it [00:01, 139.34it/s]                         
  1%|          | 8/937 [00:00<00:11, 77.70it/s]

Train Loss: 156.9492
Val Loss: 156.3417
Epoch 6 of 10


938it [00:09, 103.70it/s]                         
157it [00:01, 134.69it/s]                         
  1%|          | 9/937 [00:00<00:11, 84.35it/s]

Train Loss: 156.5453
Val Loss: 155.7874
Epoch 7 of 10


938it [00:08, 108.15it/s]                         
157it [00:01, 129.59it/s]                         
  1%|          | 10/937 [00:00<00:09, 95.70it/s]

Train Loss: 155.9827
Val Loss: 155.3387
Epoch 8 of 10


938it [00:08, 104.63it/s]                         
157it [00:01, 118.22it/s]                         
  1%|▏         | 12/937 [00:00<00:08, 105.99it/s]

Train Loss: 155.7166
Val Loss: 154.7301
Epoch 9 of 10


938it [00:08, 104.45it/s]                         
157it [00:01, 119.31it/s]                         
  1%|▏         | 14/937 [00:00<00:06, 139.36it/s]

Train Loss: 155.2923
Val Loss: 154.6789
Epoch 10 of 10


938it [00:08, 109.35it/s]                         
157it [00:01, 128.63it/s]                         

Train Loss: 155.0199
Val Loss: 154.7058



