# WK14

## VAEs

### Regression + Un-Regression (Encoder + Decoder)

- Learn a dense (compressed) representation of the patterns in the data
- Instead of using these representations for classification, style transfer, etc., learn how to undo the compression and reconstruct the original input

#### Code:
- https://avandekleut.github.io/vae/

#### Explanation:
- https://towardsdatascience.com/intuitively-understanding-variational-autoencoders-1bfe67eb5daf
- https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73

In [2]:
# Fetch the two helper modules imported further down (image_utils, WK14_utils).
# -q keeps wget quiet; the files are saved next to the notebook.
!wget -q https://github.com/DM-GY-9103-2024F-H/9103-utils/raw/main/src/image_utils.py
!wget -q https://github.com/DM-GY-9103-2024F-H/WK14/raw/main/WK14_utils.py

# Get Clouds or faces or flowers
# -O- streams each archive to stdout, where `tar xz` unpacks it directly (no .tar.gz is kept on disk).
# NOTE(review): presumably these unpack under ./data/image/ (the loading cell reads ./data/image/metfaces) — confirm.
!wget -qO- https://github.com/DM-GY-9103-2024F-H/9103-utils/releases/latest/download/clouds.tar.gz | tar xz
!wget -qO- https://github.com/DM-GY-9103-2024F-H/9103-utils/releases/latest/download/flowers.tar.gz | tar xz
!wget -qO- https://github.com/DM-GY-9103-2024F-H/9103-utils/releases/latest/download/metfaces.tar.gz | tar xz

In [None]:
import torch

from os import listdir, path

from torch import nn, Tensor
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision.models import resnet34, ResNet34_Weights
from torchvision.models import vgg19, VGG19_Weights
from torchvision.transforms import v2

from image_utils import make_image, open_image

from WK14_utils import count_parameters

In [None]:
# Device that holds the dataset tensors (GPU when one is available)
mdevice = "cuda" if torch.cuda.is_available() else "cpu"

class ImageDataset(Dataset):
  """In-memory dataset of square, 3-channel float image tensors.

  All images are transformed once at construction time, stacked into a single
  tensor and moved to `mdevice`, so indexing is just a tensor lookup.

  Args:
    imgs: list of images accepted by the torchvision v2 transforms
      (the notebook passes the objects returned by `open_image`).
    size: side length, in pixels, of the square images produced.
  """
  def __init__(self, imgs, size=128):
    super().__init__()
    self.loader_transform = v2.Compose([
      v2.Resize(size),
      # Resize(int) only fixes the shorter edge; crop the centre so every image
      # ends up size x size and the stack below cannot fail on mixed aspect ratios
      v2.CenterCrop(size),
      v2.ToImage(),
      # Replacement for the deprecated ConvertImageDtype: float values scaled to [0, 1]
      v2.ToDtype(torch.float32, scale=True),
    ])
    self.num_imgs = len(imgs)
    self.imgs = self.loader_transform(imgs)
    # Stack into one (N, C, H, W) tensor and keep the first 3 channels (drops e.g. alpha)
    self.imgs = torch.stack(self.imgs, dim=0)[:, :3]
    self.imgs = self.imgs.to(mdevice)

  def __len__(self):
    """Return the number of images passed to the constructor."""
    return self.num_imgs

  def __getitem__(self, idx):
    """Return the pre-processed image tensor stored at position `idx`."""
    return self.imgs[idx]

In [None]:
img_dir = "./data/image/metfaces"

# Keep only png/jpg files; sorting makes the chosen subset independent of listing order
filenames = sorted(f for f in listdir(img_dir) if f.endswith(("png", "jpg")))

# Open the first 50 files with the course helper
images = [open_image(path.join(img_dir, fname)) for fname in filenames[:50]]

In [None]:
# Pre-process the opened images once, then serve them in shuffled batches of up to 128
ids = ImageDataset(images)
images_dl = DataLoader(
  dataset=ids,
  batch_size=128,
  shuffle=True,
)

In [None]:
# Seed first so the shuffled batch (and the picture shown) is the same on every run
torch.manual_seed(1010)
batch_iter = iter(images_dl)
imgs = next(batch_iter)

# Show the image at index 1 of that batch
to_pil = v2.ToPILImage()
display(to_pil(imgs[1]))

In [None]:
class VAE(nn.Module):
  """Fully-connected variational autoencoder for flattened images.

  Encoder: image -> hidden (ReLU) -> mean and log-std of a diagonal Gaussian.
  Decoder: latent z -> hidden (ReLU) -> pixels (sigmoid) -> image tensor.

  Args:
    in_features: number of values in one flattened input image.
    hidden_features: width of the hidden layer in both encoder and decoder.
    latent_features: dimensionality of the latent code z.
    image_shape: (channels, height, width) that decode() reshapes its output
      to; its product should equal in_features.
  """
  def __init__(self, in_features, hidden_features=512, latent_features=64, image_shape=(3, 128, 128)):
    super().__init__()
    self.image_shape = tuple(image_shape)
    self.img2hid = nn.Linear(in_features, hidden_features)
    self.hid2mean = nn.Linear(hidden_features, latent_features)
    # Despite the name, this layer outputs log(std); encode() exponentiates it
    self.hid2std = nn.Linear(hidden_features, latent_features)
    self.z2hid = nn.Linear(latent_features, hidden_features)
    self.hid2img = nn.Linear(hidden_features, in_features)
    # KL term of the most recent encode() call (plain 0 until encode() has run)
    self.kl = 0

  def encode(self, x):
    """Sample a latent code for each image in `x` and update `self.kl`.

    Args:
      x: batch of images; everything after dim 0 is flattened.

    Returns:
      Tensor with the same shape as the predicted mean (batch, latent_features),
      drawn as mean + std * eps with eps ~ N(0, 1).
    """
    x = torch.flatten(x, start_dim=1)
    hid = F.relu(self.img2hid(x))
    mean = self.hid2mean(hid)
    log_std = self.hid2std(hid)
    std = torch.exp(log_std)
    # randn_like gives noise with exactly mean's shape, dtype and device.
    # (Normal(Tensor([0]), Tensor([1])).sample(mean.shape) appends a trailing
    # dimension of size 1, which breaks the broadcast with std.)
    eps = torch.randn_like(mean)
    z = mean + std * eps
    # Closed-form KL( N(mean, std^2) || N(0, 1) ), summed over batch and latent dims.
    # Using log_std directly avoids the log(exp(.)) round trip.
    self.kl = 0.5 * (std**2 + mean**2 - 2 * log_std - 1).sum()
    return z

  def decode(self, x):
    """Map latent codes `x` to images with values in (0, 1).

    Returns:
      Tensor of shape (-1, *image_shape); a single un-batched code therefore
      comes back as a batch of one.
    """
    hid = F.relu(self.z2hid(x))
    img = torch.sigmoid(self.hid2img(hid))
    return img.reshape(-1, *self.image_shape)

  def forward(self, x):
    """Encode `x` to a sampled latent code and decode it back to an image batch."""
    z = self.encode(x)
    return self.decode(z)

In [None]:
if torch.cuda.is_available():
  mdevice = "cuda"
else:
  mdevice = "cpu"

# One input feature per value of a 128x128 image with 3 channels
model = VAE(in_features=128*128*3)
model = model.to(mdevice)

learning_rate = 1e-4
optim = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Sanity check: push one (seeded) batch through the untrained model
torch.manual_seed(1010)
imgs = next(iter(images_dl))
y = model(imgs)

print("Input shape:", imgs.shape)
print("Output shape:", y.shape)
print("Parameters:", count_parameters(model))

In [None]:
# Nothing in the loop leaves training mode, so switching once up front is enough
model.train()

for e in range(32):
  for imgs in images_dl:
    optim.zero_grad()
    y = model(imgs)
    # Summed squared pixel error plus the KL term stored by the model's encode step
    rec_loss = (imgs - y).pow(2).sum()
    loss = rec_loss + model.kl
    loss.backward()
    optim.step()

  # Report every 4th epoch, using the values from the last batch of that epoch
  if (e + 1) % 4 == 0:
    print(f"Epoch: {e} loss: {loss.item():.4f} rec loss: {rec_loss.item():.4f} kl: {model.kl.item():.4f}")

In [None]:
with torch.no_grad():
  # Same seed as the earlier preview cells before drawing the batch
  torch.manual_seed(1010)
  imgs = next(iter(images_dl))
  y = model(imgs)
  idx = 3
  # Input first, then the model's reconstruction of it
  for picture in (imgs[idx], y[idx]):
    display(v2.ToPILImage()(picture))

In [None]:
with torch.no_grad():
  # 16 random latent codes of length 64 (the VAE's default latent size)
  z = torch.randn((16,64)).to(mdevice)
  ys = model.decode(z)
  to_pil = v2.ToPILImage()
  # Show each decoded image in turn
  for y in ys:
    display(to_pil(y))

In [None]:
to_pil = v2.ToPILImage()

# The three source images, in the order they enter the arithmetic below
for i in (1, 4, 3):
  display(to_pil(imgs[i]))

with torch.no_grad():
  z = model.encode(imgs)
  # Latent arithmetic: code of image 1 minus code of image 4 plus code of image 3
  z = z[1] - z[4] + z[3]
  y = model.decode(z)
  # decode() reshapes to a batch, so the single result sits at index 0
  display(to_pil(y[0]))

## Possible Next Steps

- Use CNN to encode/decode images
- Look at pre-conditioned VAEs