In [1]:
import os
from torch import optim, nn, utils, Tensor
import numpy as np
from PIL import Image
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import pytorch_lightning as pl

In [3]:
import torch
# tensor board to represent models!
# check

print("using torch", torch.__version__)

torch.manual_seed(47)
x = torch.arange(6)
x = x.view(2,3)
print("X", x)

using torch 1.12.1+cu102
X tensor([[0, 1, 2],
        [3, 4, 5]])


In [4]:
def draw_circle(radius, center_x=0.5, center_y=0.5, size=28):
    # draw a circle using coordinates for the center and the radius
    circle = plt.Circle((center_x,center_y), radius, color='k', fill=False)
    fig,ax = plt.subplots(figsize=(1,1))
    ax.add_patch(circle)
    ax.axis('off')
    buf = fig.canvas.print_to_buffer()
    plt.close()
    # converts matplotlib figure into pil image, make it grayscale and resize it
    return np.array(Image.frombuffer('RGBA', buf[1], buf[0]).convert('L').resize((int(size), int(size))))

def gen_circles(n, size=28):
    # generates random coordinates around (0.5, 0.5) as center points
    center_x = np.random.uniform(0.0,0.03, size=n).reshape(-1,1)+.5
    center_y = np.random.uniform(0.0,0.03, size=n).reshape(-1,1)+.5
    # generates random radius sizes between 0.03 and 0.47
    radius = np.random.uniform(0.03,0.47, size=n).reshape(-1,1)
    sizes = np.ones((n,1))*size

    coords = np.concatenate([radius, center_x, center_y, sizes], axis=1)
    # generates circles using draw_circle function
    circles = np.apply_along_axis(func1d=lambda v: draw_circle(*v), axis=1, arr=coords)
    return circles, radius


np.random.seed(42)
# generates 1'000 circles
circles, radius = gen_circles(1000)

circles_ds = utils.data.TensorDataset(torch.as_tensor(circles).unsqueeze(1).float()/255, torch.as_tensor(radius))
circles_dl = utils.data.DataLoader(circles_ds, batch_size=32, shuffle=True, drop_last=True)


In [8]:
def set_seed(self, seed=42):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    np.random.seed(seed)

class Encoder(nn.Module):
    def __init__(self, input_shape, z_size, base_model):
        super().__init__()
        self.input_shape = input_shape
        self.z_size = z_size
        self.base_model = base_model

        # appends the "lin_latent" linear layer to map form "output_size"
        # given by the base model to desired size of the representation ( z_size)
        output_size = self._get_output_size()
        self.lin_latent = nn.Linear(output_size, z_size)

    def _get_output_size(self):
        # builds a dummy batch containing one dummy tensor
        # fully of zeros with the same shape as the inputs
        device = next(self.base_model.parameters()).device.type
        dummy = torch.zeros(1, *self.input_shape, device=device)
        # sends the dummy batch through the base model to get
        # the output size produced by it
        size = self.base_model(dummy).size(1)
        return size

    def forward(self, x):
        # forwards the input through the base model and then the "lin_latent" layer
        # to get the representation (z)
        base_out = self.base_model(x)
        out = self.lin_latent(base_out)
        return out

set_seed(13)

# defined out representation (z) as a vector of size one
z_size = 1
# our images are 1@28x28
input_shape = (1,28,28) # (C,H,W)

base_model = nn.Sequential(
    # (C,H,W) -> C*H*W
    nn.Flatten(),
    # C*H*W -> 2048
    nn.Linear(np.prod(input_shape), 2048),
    nn.LeakyReLU(),
    # 2048 -> 2048
    nn.Linear(2048,2048),
    nn.LeakyReLU()
)

encoder = Encoder(input_shape, z_size, base_model)

In [9]:
x, _ = circles_ds[7]
z = encoder(x)
z

tensor([[-0.1209]], grad_fn=<AddmmBackward0>)

In [10]:
decoder = nn.Sequential(
    # z_size -> 2048
    nn.Linear(z_size, 2048),
    nn.LeakyReLU(),
    # 2048 -> 2048
    nn.Linear(2048, 2048),
    nn.LeakyReLU(),
    # 2048 -> C*H*W
    nn.Linear(2048, np.prod(input_shape)),
    # C*H*W -> (C, H, W)
    nn.Unflatten(1, input_shape)
)

In [11]:
x_tilde = decoder(z)
x_tilde, x_tilde.shape

(tensor([[[[ 1.9079e-01, -4.3900e-02, -4.9170e-02,  5.2142e-02, -8.0119e-02,
            -1.6324e-01,  3.8319e-02,  6.2965e-02, -3.7442e-02, -3.6085e-02,
             2.2930e-02, -1.2089e-01,  2.0558e-01,  1.3671e-01,  1.4607e-03,
             1.1066e-02, -1.3429e-01, -3.7842e-02,  6.1736e-02, -3.0216e-02,
            -7.4171e-02, -1.6376e-02, -6.4663e-02, -1.5638e-01, -9.6260e-02,
             5.3312e-02,  6.6354e-02, -2.6916e-02],
           [ 1.8874e-01,  9.7503e-02, -1.3948e-01, -1.2955e-01, -1.2210e-02,
             5.6815e-02, -7.5753e-02,  5.3484e-02,  6.4153e-02, -1.6740e-01,
            -5.0190e-02,  6.2855e-02,  9.7707e-02, -2.2777e-02, -1.1442e-01,
             1.6079e-01, -1.4634e-01,  2.0068e-01, -2.7668e-02,  6.3487e-02,
            -9.6758e-02,  3.0803e-02, -7.0600e-02,  1.1162e-01,  8.6569e-02,
            -1.5205e-02,  1.9754e-01,  1.1691e-01],
           [-7.3312e-02,  9.0082e-02, -4.4267e-02,  1.1471e-01,  6.9005e-02,
             1.5236e-03, -4.2527e-02,  1.4255e-01

In [13]:
class AutoEncoder(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.enc = encoder
        self.dec = decoder

    def forward(self, x):
        # when encoder met decoder
        enc_out = self.enc(x)
        return self.dec(enc_out)

model_ae = AutoEncoder(encoder, decoder)

In [14]:
set_seed(13)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_ae.to(device)
loss_fn = nn.MSELoss()
optim = optim.Adam(model_ae.parameters(), 0.0003)

num_epochs = 10

train_losses = []

for epoch in range(1, num_epochs+1):
    batch_losses = []
    for i, (x,_) in enumerate(circles_dl):
        model_ae.train()
        x = x.to(device)

        # step1 compute model's predicted output - forward pass
        yhat = model_ae(x)
        # step2 compute loss
        loss = loss_fn(yhat, x)
        # step3 compute gradients
        loss.backward()
        # step4 update param using gradients and learning rate
        optim.step()
        optim.zero_grad()

        batch_losses.append(np.array([loss.data.item()]))

    # avg over batches
    train_losses.append(np.array(batch_losses).mean(axis=0))

    print(f'Epoch {epoch:03d} | Loss >> {train_losses[-1][0]:.4f}')

Epoch 001 | Loss >> 0.1388
Epoch 002 | Loss >> 0.0062
Epoch 003 | Loss >> 0.0049
Epoch 004 | Loss >> 0.0048
Epoch 005 | Loss >> 0.0048
Epoch 006 | Loss >> 0.0048
Epoch 007 | Loss >> 0.0048
Epoch 008 | Loss >> 0.0048
Epoch 009 | Loss >> 0.0047
Epoch 010 | Loss >> 0.0045


In [5]:
def set_seed(self, seed=42):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    np.random.seed(seed)

class EncoderVar(nn.Module):
    def __init__(self, input_shape, z_size, base_model):
        super().__init__()
        self.z_size = z_size
        self.input_shape = input_shape
        self.base_model = base_model
        output_size = self.get_output_size()
        self.lin_mu = nn.Linear(output_size, z_size)
        self.lin_var = nn.Linear(output_size, z_size)

    def get_output_size(self):
        device = next(self.base_model.parameters()).device.type
        size = self.base_model(torch.zeros(1, *self.input_shape, device=device)).size(1)
        return size

    def kl_loss(self):
        kl_loss = -0.5*(1 + self.log_var - self.mu**2 - torch.exp(self.log_var))
        return kl_loss

    def forward(self, x):
        base_out = self.base_model(x)

        # now encoder produces means (mu) using the lin_mu output layer
        # and log variances (log_var) using hte lin_var output layer
        # compute the standard deviation (std) from the log variance
        self.mu = self.lin_mu(base_out)
        self.log_var = self.lin_var(base_out)
        std = torch.exp(self.log_var/2)

        # internal random input (epsilon)
        eps = torch.randn_like(self.mu)
        # and z vector
        z = self.mu + eps * std

        return z

In [8]:
set_seed(13)
# defined out representation (z) as a vector of size one
z_size = 1
# our images are 1@28x28
input_shape = (1,28,28) # (C,H,W)
base_model = nn.Sequential(
    # (C,H,W) -> C*H*W
    nn.Flatten(),
    # C*H*W -> 2048
    nn.Linear(np.prod(input_shape), 2048),
    nn.LeakyReLU(),
    # 2048 -> 2048
    nn.Linear(2048,2048),
    nn.LeakyReLU()
)


encoder_var = EncoderVar(input_shape, z_size, base_model)

decoder_var = nn.Sequential(
    # z_size -> 2048
    nn.Linear(z_size, 2048),
    nn.LeakyReLU(),
    # 2048 -> 2048
    nn.Linear(2048, 2048),
    nn.LeakyReLU(),
    # 2048 -> C*H*W
    nn.Linear(2048, np.prod(input_shape)),
    # C*H*W -> (C, H, W)
    nn.Unflatten(1, input_shape)
)

In [9]:
class LitVAutoEncoder(pl.LightningModule):
    def __init__(self, encoderV, decoderV):
        super().__init__()
        self.encoderV = encoderV
        self.decoderV = decoderV

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoderV(x)
        x_hat = self.decoderV(z)
        lossV = nn.functional.mse_loss(x_hat, x)
        return lossV

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


In [10]:
model_vae = LitVAutoEncoder(encoder_var, decoder_var)

trainer = pl.Trainer(limit_train_batches=100, max_epochs=100)
trainer.fit(model=model_vae, train_dataloaders=circles_dl)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name     | Type       | Params
----------------------------------------
0 | encoderV | EncoderVar | 5.8 M 
1 | decoderV | Sequential | 5.8 M 
----------------------------------------
11.6 M    Trainable params
0         Non-trainable params
11.6 M    Total params
46.460    Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  lossV = nn.functional.mse_loss(x_hat, x)


RuntimeError: The size of tensor a (28) must match the size of tensor b (784) at non-singleton dimension 3

In [20]:
x, y = next(iter(circles_dl))
zs = encoder_var(x)
reconstructed = decoder_var(zs)

loss_fn_raw = nn.MSELoss(reduction='none')
raw_mse = loss_fn_raw(reconstructed, x)
raw_mse.shape

torch.Size([32, 1, 28, 28])

In [23]:
raw_mse.sum(), nn.MSELoss(reduction='sum')(reconstructed, x)

(tensor(24936.3203, grad_fn=<SumBackward0>),
 tensor(24936.3203, grad_fn=<MseLossBackward0>))

In [24]:
sum_over_pixels = raw_mse.sum(dim=[1, 2, 3])
sum_over_pixels.mean()

tensor(779.2599, grad_fn=<MeanBackward0>)

In [25]:
raw_kl = encoder_var.kl_loss()
raw_kl.shape

torch.Size([32, 1])

In [26]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_vae.to(device)
loss_fn = nn.MSELoss(reduction='none')
optim = torch.optim.Adam(model_vae.parameters(), 0.0003)

num_epochs = 30

train_losses = []

reconstruction_loss_factor = 1

for epoch in range(1, num_epochs+1):
    batch_losses = []
    for i, (x, _) in enumerate(circles_dl):
        model_vae.train()
        x = x.to(device)

        # step1 compute predicted output - forward pass
        yhat = model_vae(x)
        # step2 compute loss
        # reduce (sum) over pixels (dim=[1, 2, 3])
        # then reduce (sum) over batch (dim=0)
        loss = loss_fn(yhat, x).sum(dim=[1, 2, 3]).sum(dim=0)
        # reduce (sum) over z (dim=1)
        # then reduce (sum) over batch (dim=0)
        kl_loss = model_vae.enc.kl_loss().sum(dim=1).sum(dim=0)
        # adding kl loss to original mse loss
        total_loss = reconstruction_loss_factor * loss + kl_loss
        # step3 compute gradients
        total_loss.backward()
        # step4 update param using gradients and learning rate
        optim.step()
        optim.zero_grad()

        batch_losses.append(np.array([total_loss.data.item(), loss.data.item(), kl_loss.data.item()]))

    # avg over batches
    train_losses.append(np.array(batch_losses).mean(axis=0))

    print(f'Epoch {epoch:03d} | Loss >> {train_losses[-1][0]:.4f}/ {train_losses[-1][1]:.4f}/ {train_losses[-1][2]:.4f}')



Epoch 001 | Loss >> 2915.1128/ 2876.9089/ 38.2040
Epoch 002 | Loss >> 186.8791/ 145.8311/ 41.0480
Epoch 003 | Loss >> 160.9378/ 133.0548/ 27.8830
Epoch 004 | Loss >> 143.0788/ 128.7050/ 14.3739
Epoch 005 | Loss >> 134.0556/ 128.5861/ 5.4695
Epoch 006 | Loss >> 130.6097/ 129.6743/ 0.9354
Epoch 007 | Loss >> 127.0032/ 126.0474/ 0.9557
Epoch 008 | Loss >> 125.9513/ 125.6287/ 0.3226
Epoch 009 | Loss >> 127.9104/ 127.6522/ 0.2581
Epoch 010 | Loss >> 127.8795/ 127.5232/ 0.3563
Epoch 011 | Loss >> 126.9872/ 126.6768/ 0.3105
Epoch 012 | Loss >> 127.4638/ 127.3409/ 0.1230
Epoch 013 | Loss >> 127.4611/ 127.1741/ 0.2869
Epoch 014 | Loss >> 127.8259/ 127.6131/ 0.2128
Epoch 015 | Loss >> 126.3635/ 126.1700/ 0.1935
Epoch 016 | Loss >> 131.2211/ 130.9681/ 0.2530
Epoch 017 | Loss >> 130.1056/ 129.8042/ 0.3014
Epoch 018 | Loss >> 129.5030/ 129.0640/ 0.4390
Epoch 019 | Loss >> 129.6932/ 129.3315/ 0.3617
Epoch 020 | Loss >> 127.4312/ 127.0978/ 0.3334
Epoch 021 | Loss >> 129.0741/ 128.9009/ 0.1733
Epoch 0