# Install libraries

In [None]:
!pip install wandb

In [None]:
import os
from getpass import getpass

# `!export` runs in a short-lived subshell, so its variables never reach the
# kernel process; set them on the kernel's own environment instead.
os.environ["WANDB_BASE_URL"] = "https://api.wandb.ai"

# SECURITY: an API key used to be hardcoded in this cell. Treat that key as
# leaked and revoke it. The key is now taken from the environment if present,
# otherwise requested interactively so it is never stored in the notebook.
if not os.environ.get("WANDB_API_KEY"):
    os.environ["WANDB_API_KEY"] = getpass("W&B API key: ")

In [None]:
# Force a fresh W&B login: the CLI asks for an API key and appends it to the
# netrc file (see the captured output below).
!wandb login --relogin

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


# Imports

In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from dataclasses import dataclass, asdict
from torchvision import transforms
from PIL import Image
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision
import numpy as np
import tqdm
import wandb

# Mount Google Drive and unzip the dataset

In [None]:
# Colab only: mount Google Drive at /content/drive so the dataset archive
# stored there can be read by the next cell.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!unzip "/content/drive/MyDrive/Bachelor's Project/data/original/complete_data.zip"

# Config

In [None]:
@dataclass
class Config:
    """Hyperparameters and runtime settings shared across the notebook.

    Attributes:
        image_size: Side length (pixels) that input images are resized to.
        embedding_size: Dimensionality of the latent vector.
        shape_before_flattening: Shape the decoder reshapes its linear output
            to before the transposed convolutions.
        device: Torch device string that modules and batches are moved to.
        epochs: Number of passes over the training data.
        batch_size: Number of samples per batch.
        lr: Learning rate handed to the optimizer.
    """

    image_size: int
    embedding_size: int
    # A tuple, not an int: the decoder star-unpacks it into `view()` and
    # indexes its first three entries.
    shape_before_flattening: tuple
    device: str
    epochs: int
    batch_size: int
    lr: float

    def dumps(self):
        """Return the config as a flat dict with every value converted to `str`."""
        return {key: str(value) for key, value in asdict(self).items()}

In [None]:
config = Config(
    image_size=512,
    embedding_size=128,
    # Encoder output for a 512 px input: 128 channels, 512 / 8 = 64 per side.
    shape_before_flattening=(128, 64, 64),
    # Fall back to the CPU instead of crashing when no GPU runtime is attached.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    epochs=100,
    batch_size=32,
    lr=1e-5,
)

In [None]:
# Start the W&B run; Config.dumps() stringifies every value, so the run's
# config is recorded as strings.
wandb.init(project="vae_synthetic", config=config.dumps())

[34m[1mwandb[0m: Currently logged in as: [33maroba18[0m ([33mdupiti[0m). Use [1m`wandb login --relogin`[0m to force relogin


# Preprocessing

In [None]:
class DataProcessor:
    """Applies the resize -> grayscale -> tensor pipeline to input images."""

    def __init__(self, config):
        """Store `config` and build the transform pipeline once.

        Args:
            config: Object exposing `image_size`, the target side length.
        """
        self.config = config
        self.transformer = self.__get_image_transformer()

    def __get_image_transformer(self):
        """Build the torchvision transform pipeline for this processor."""
        # Read the size from this instance's config. The previous code used the
        # notebook-level `config` global, silently ignoring the constructor
        # argument (and failing if no such global existed).
        size = self.config.image_size
        return transforms.Compose([
            transforms.Resize((size, size)),
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])

    def process(self, input_image):
        """Run `input_image` through the pipeline and return the result."""
        return self.transformer(input_image)

# Model Architecture

In [None]:
def vae_gaussian_kl_loss(mu, logvar):
    """KL divergence between N(mu, exp(logvar)) and N(0, I), averaged over the batch.

    Closed form from Appendix B of Kingma and Welling, "Auto-Encoding
    Variational Bayes" (ICLR 2014), https://arxiv.org/abs/1312.6114.
    The per-dimension terms are summed over dim 1, then averaged over dim 0.
    """
    per_dimension = 1 + logvar - mu.pow(2) - logvar.exp()
    per_sample = -0.5 * per_dimension.sum(dim=1)
    return per_sample.mean()

def reconstruction_loss(x_reconstructed, x):
    """Mean squared error between the reconstruction and the target.

    Both tensors are flattened to (batch, -1) before being compared; the
    result is the mean over all elements (nn.MSELoss default reduction).
    """
    batch_size = x_reconstructed.size(0)
    flat_reconstruction = x_reconstructed.view(batch_size, -1)
    flat_target = x.view(x.size(0), -1)
    mse = nn.MSELoss()
    return mse(flat_reconstruction, flat_target)

def vae_loss(y_pred, y_true, recon_weight=5_000):
    """Total VAE loss: weighted reconstruction error plus KL divergence.

    Args:
        y_pred: `(mu, logvar, reconstruction)` tuple, in that order.
        y_true: Target the reconstruction is compared against.
        recon_weight: Multiplier applied to the reconstruction term before the
            KL term is added. Defaults to the previously hard-coded 5_000, so
            existing two-argument calls behave exactly as before.

    Returns:
        `recon_weight * reconstruction_loss + kl_loss`.
    """
    mu, logvar, recon_x = y_pred
    recon_loss = reconstruction_loss(recon_x, y_true)
    kld_loss = vae_gaussian_kl_loss(mu, logvar)
    return recon_weight * recon_loss + kld_loss

In [None]:
class Sampling(nn.Module):
    """Draws a latent sample via the reparameterization trick."""

    def forward(self, z_mean, z_log_var):
        """Return `z_mean + exp(0.5 * z_log_var) * epsilon` with epsilon ~ N(0, 1).

        Args:
            z_mean: Mean of the latent distribution.
            z_log_var: Log-variance of the latent distribution; must broadcast
                against `z_mean`.
        """
        # randn_like creates the noise directly with z_mean's shape, dtype and
        # device. This avoids sampling on the CPU and copying to the device on
        # every forward pass, and removes the restriction to 2-D inputs.
        epsilon = torch.randn_like(z_mean)
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder mapping a 1-channel square image to a latent sample.

    Three stride-2 convolutions (each followed by batch norm and ReLU) feed
    two linear heads producing the latent mean and log-variance.
    """

    def __init__(self, image_size, embedding_dim):
        """
        Args:
            image_size: Side length of the (square) input images.
            embedding_dim: Dimensionality of the latent vector.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(1, 128, 3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 128, 3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 128, 3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.flatten = nn.Flatten()

        # A 3x3 / stride-2 / padding-1 convolution maps a side of n to
        # ceil(n / 2). `image_size // 8` only equals three such halvings when
        # image_size is divisible by 8, so compute the real side length.
        feature_side = image_size
        for _ in range(3):
            feature_side = (feature_side + 1) // 2
        flat_features = 128 * feature_side * feature_side

        self.fc_mean = nn.Linear(flat_features, embedding_dim)
        self.fc_log_var = nn.Linear(flat_features, embedding_dim)

        self.sampling = Sampling()

    def forward(self, x):
        """Encode `x` and return `(z_mean, z_log_var, z)`, with `z` sampled."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.flatten(x)
        z_mean = self.fc_mean(x)
        z_log_var = self.fc_log_var(x)
        z = self.sampling(z_mean, z_log_var)
        return z_mean, z_log_var, z


In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to a 1-channel image.

    A linear layer expands the latent vector, the result is reshaped to
    `shape_before_flattening`, and three stride-2 transposed convolutions each
    double the spatial size. The output passes through a sigmoid.
    """

    def __init__(self, embedding_dim, shape_before_flattening):
        """
        Args:
            embedding_dim: Dimensionality of the latent vector.
            shape_before_flattening: `(channels, height, width)` of the feature
                map fed to the first transposed convolution.
        """
        super().__init__()
        # Keep the shape as data rather than capturing it in a lambda: a lambda
        # attribute makes the module impossible to pickle.
        self.shape_before_flattening = tuple(shape_before_flattening)
        channels, height, width = self.shape_before_flattening

        self.fc = nn.Linear(embedding_dim, channels * height * width)
        # The first layer's input channels follow the configured shape instead
        # of a hard-coded 128 (identical for the notebook's (128, 64, 64)).
        self.deconv1 = nn.ConvTranspose2d(channels, 128, 3, stride=2, padding=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.deconv3 = nn.ConvTranspose2d(128, 1, 3, stride=2, padding=1, output_padding=1)

    def reshape(self, x):
        """View `x` as `(-1, *shape_before_flattening)`."""
        return x.view(-1, *self.shape_before_flattening)

    def forward(self, x):
        """Decode latent vectors `x` into images with values in [0, 1]."""
        x = self.fc(x)
        x = self.reshape(x)
        x = F.relu(self.bn1(self.deconv1(x)))
        x = F.relu(self.bn2(self.deconv2(x)))
        x = torch.sigmoid(self.deconv3(x))  # sigmoid keeps outputs in [0, 1]
        return x


In [None]:
class VAE(nn.Module):
    """Variational autoencoder composed of a given encoder and decoder."""

    def __init__(self, encoder, decoder):
        """Register `encoder` and `decoder` as submodules."""
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        """Encode `x`, decode the sampled latent, and return
        `(z_mean, z_log_var, reconstruction)`."""
        latent_mean, latent_log_var, latent_sample = self.encoder(x)
        return latent_mean, latent_log_var, self.decoder(latent_sample)

# Dataset and Dataloader

In [None]:
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import glob
import os

In [None]:
# Directory that is searched for *.jpg training images.
# NOTE(review): presumably created by the unzip cell above — verify against the archive layout.
image_dir = '/content/redfin_images'

In [None]:
class ImageFolderDataset(Dataset):
    """Dataset over the `*.jpg` files found directly inside one directory."""

    def __init__(self, image_dir, config):
        """
        Args:
            image_dir: Directory to search for `*.jpg` files (non-recursive).
            config: Passed to `DataProcessor` to configure preprocessing.
        """
        self.image_dir = image_dir
        self.processor = DataProcessor(config)
        # glob returns files in arbitrary filesystem order; sorting makes the
        # index -> file mapping (and any seeded split built on it) reproducible.
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, '*.jpg')))

    def __len__(self):
        """Number of images found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at `idx` and return it preprocessed."""
        # The context manager closes the underlying file as soon as the image
        # has been processed instead of relying on garbage collection.
        with Image.open(self.image_paths[idx]) as image:
            return self.processor.process(image)

# Define variables

## Visualization Helper

In [None]:
def show_images(images, title="Images"):
    """Display a batch of images as a single grid with five images per row.

    Args:
        images: Batch of image tensors accepted by `torchvision.utils.make_grid`.
        title: Title drawn above the grid.
    """
    grid = torchvision.utils.make_grid(images, nrow=5, normalize=True)
    # matplotlib converts its input to a NumPy array, which fails for CUDA
    # tensors and tensors that require grad; detach and move to the CPU first.
    # permute turns channels-first (C, H, W) into the (H, W, C) imshow expects.
    grid = grid.permute(1, 2, 0).detach().cpu()
    plt.figure(figsize=(15, 15))
    plt.imshow(grid)
    plt.title(title)
    plt.axis('off')
    plt.show()

## Training Parameters

In [None]:
# Build both halves on the target device first. VAE only stores references to
# them, so `model` shares these already-moved parameters.
encoder = Encoder(config.image_size, config.embedding_size).to(config.device)
decoder = Decoder(config.embedding_size, config.shape_before_flattening).to(config.device)
model = VAE(encoder, decoder)

In [None]:
# Index every *.jpg in image_dir; each image is opened and preprocessed only when accessed.
dataset = ImageFolderDataset(image_dir=image_dir, config=config)

In [None]:
# Hold out 1% of the images for validation.
num_images = len(dataset)
# Fail early with a clear message instead of a cryptic DataLoader error later.
assert num_images > 0, f"no .jpg images found in {image_dir}"
val_split = int(np.floor(0.01 * num_images))
train_split = num_images - val_split
# Seed the split so the same images end up in the validation set on every run.
train_dataset, val_dataset = random_split(
    dataset,
    [train_split, val_split],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
shuffle = True     # reshuffle the training set every epoch
num_workers = 2    # background worker processes used for loading

# `shuffle` and `num_workers` were previously defined but never passed on, so
# all loading ran in the main process.
train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=shuffle, num_workers=num_workers)
# Validation keeps a fixed order so results are comparable between evaluations.
val_dataloader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=num_workers)

In [None]:
# `model` registers the encoder and then the decoder as submodules, so
# model.parameters() yields the encoder's parameters followed by the decoder's.
optimizer = optim.Adam(model.parameters(), lr=config.lr)
# Each scheduler.step() multiplies the learning rate by 0.95.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

## Training Loop

In [None]:
# Number of training batches between loss reports; a validation pass runs at each report.
print_steps = 200

In [None]:
# `import tqdm` alone does not guarantee that the `tqdm.notebook` submodule is
# loaded; import it explicitly so the progress bars below always resolve.
import tqdm.notebook

train_losses = []
val_losses = []
model_path = '/content/model.pth'


def _log_checkpoint(epoch):
    """Save the current weights to `model_path` and upload them to W&B as `model_epoch_<epoch>`."""
    torch.save(model.state_dict(), model_path)
    artifact = wandb.Artifact('model_epoch_' + str(epoch), type='model')
    artifact.add_file(model_path)
    wandb.log_artifact(artifact)


def _validate():
    """Evaluate `model` on `val_dataloader` in eval mode without gradients.

    Returns:
        `(avg_loss, first_batch, first_recon)`: the mean per-batch loss, the
        first validation batch, and the model's reconstruction of that batch.
        The model is left in eval mode; the caller switches back to training.
    """
    model.eval()
    loss_sum = 0.0
    first_batch, first_recon = None, None
    with torch.no_grad():
        for val_data in val_dataloader:
            val_data = val_data.to(config.device)
            val_pred = model(val_data)
            loss_sum += vae_loss(val_pred, val_data).item()
            if first_batch is None:
                # Reuse this forward pass for the logged examples instead of
                # running the first batch through the model a second time.
                first_batch, first_recon = val_data, val_pred[2]
    return loss_sum / len(val_dataloader), first_batch, first_recon


for epoch in tqdm.notebook.tqdm(range(config.epochs), desc='Epoch'):
    model.train()
    running_loss = 0.0

    # At the start of epoch k the weights reflect k completed epochs.
    if epoch % 5 == 0:
        _log_checkpoint(epoch)

    for batch_idx, data in tqdm.notebook.tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False, desc='Train Batch'):
        data = data.to(config.device)
        optimizer.zero_grad()
        pred = model(data)
        loss = vae_loss(pred, data)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if (batch_idx + 1) % print_steps != 0:
            continue

        # Average training loss over the last `print_steps` batches.
        avg_train_loss = running_loss / print_steps
        train_losses.append(avg_train_loss)

        avg_val_loss, val_batch, recon_images = _validate()
        val_losses.append(avg_val_loss)

        # One log call keeps all metrics on the same W&B step. The examples
        # come from the fixed first validation batch (reconstructed in eval
        # mode), so they are comparable between reports.
        wandb.log({
            "training_loss": avg_train_loss,
            "val_loss": avg_val_loss,
            "examples": [
                wandb.Image(val_batch[0], caption="original"),
                wandb.Image(recon_images[0], caption="predicted"),
            ],
        })

        print(f"Epoch [{epoch+1}/{config.epochs}], Step [{batch_idx+1}/{len(train_dataloader)}], "
              f"Avg Train Loss: {avg_train_loss:.4f}, Avg Val Loss: {avg_val_loss:.4f}")

        # Reset running loss and switch back to training mode
        running_loss = 0.0
        model.train()

    # Decay the learning rate once per epoch; without this call the scheduler
    # created alongside the optimizer never has any effect.
    scheduler.step()

# The in-loop checkpoints stop before the last epoch finishes; save the final weights too.
_log_checkpoint(config.epochs)


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Train Batch:   0%|          | 0/723 [00:00<?, ?it/s]

Epoch [1/100], Step [200/723], Avg Train Loss: 821.0845, Avg Val Loss: 425.2524
Epoch [1/100], Step [400/723], Avg Train Loss: 527.4630, Avg Val Loss: 448.3988
Epoch [1/100], Step [600/723], Avg Train Loss: 338.1154, Avg Val Loss: 282.3483


Train Batch:   0%|          | 0/723 [00:00<?, ?it/s]

Epoch [2/100], Step [200/723], Avg Train Loss: 298.4787, Avg Val Loss: 238.9663
Epoch [2/100], Step [400/723], Avg Train Loss: 263.9883, Avg Val Loss: 224.1928
Epoch [2/100], Step [600/723], Avg Train Loss: 956.8742, Avg Val Loss: 230.2114


Train Batch:   0%|          | 0/723 [00:00<?, ?it/s]

Epoch [3/100], Step [200/723], Avg Train Loss: 213.2930, Avg Val Loss: 195.1843
Epoch [3/100], Step [400/723], Avg Train Loss: 200.0919, Avg Val Loss: 186.3053
Epoch [3/100], Step [600/723], Avg Train Loss: 189.9062, Avg Val Loss: 180.5246


Train Batch:   0%|          | 0/723 [00:00<?, ?it/s]

Epoch [4/100], Step [200/723], Avg Train Loss: 183.5033, Avg Val Loss: 171.3447
Epoch [4/100], Step [400/723], Avg Train Loss: 179.4137, Avg Val Loss: 174.6257
Epoch [4/100], Step [600/723], Avg Train Loss: 179.8701, Avg Val Loss: 167.5737


Train Batch:   0%|          | 0/723 [00:00<?, ?it/s]

Epoch [5/100], Step [200/723], Avg Train Loss: 176.6008, Avg Val Loss: 166.9488
Epoch [5/100], Step [400/723], Avg Train Loss: 172.8564, Avg Val Loss: 165.2418
Epoch [5/100], Step [600/723], Avg Train Loss: 173.0259, Avg Val Loss: 167.0895


Train Batch:   0%|          | 0/723 [00:00<?, ?it/s]

Epoch [6/100], Step [200/723], Avg Train Loss: 171.5945, Avg Val Loss: 170.1729
Epoch [6/100], Step [400/723], Avg Train Loss: 170.7028, Avg Val Loss: 163.2735
Epoch [6/100], Step [600/723], Avg Train Loss: 170.2426, Avg Val Loss: 162.3817


Train Batch:   0%|          | 0/723 [00:00<?, ?it/s]

Epoch [7/100], Step [200/723], Avg Train Loss: 169.3066, Avg Val Loss: 164.5463
Epoch [7/100], Step [400/723], Avg Train Loss: 168.1267, Avg Val Loss: 161.8022
Epoch [7/100], Step [600/723], Avg Train Loss: 168.0057, Avg Val Loss: 163.4158


Train Batch:   0%|          | 0/723 [00:00<?, ?it/s]

Epoch [8/100], Step [200/723], Avg Train Loss: 168.0993, Avg Val Loss: 161.7508
Epoch [8/100], Step [400/723], Avg Train Loss: 167.1786, Avg Val Loss: 159.8101
Epoch [8/100], Step [600/723], Avg Train Loss: 165.1433, Avg Val Loss: 159.1689


Train Batch:   0%|          | 0/723 [00:00<?, ?it/s]

Epoch [9/100], Step [200/723], Avg Train Loss: 165.0502, Avg Val Loss: 159.4012
Epoch [9/100], Step [400/723], Avg Train Loss: 165.3089, Avg Val Loss: 156.2823
Epoch [9/100], Step [600/723], Avg Train Loss: 163.8555, Avg Val Loss: 160.7358


Train Batch:   0%|          | 0/723 [00:00<?, ?it/s]

Epoch [10/100], Step [200/723], Avg Train Loss: 164.4315, Avg Val Loss: 159.4444
Epoch [10/100], Step [400/723], Avg Train Loss: 162.2987, Avg Val Loss: 153.8335
Epoch [10/100], Step [600/723], Avg Train Loss: 162.7925, Avg Val Loss: 154.8332


Train Batch:   0%|          | 0/723 [00:00<?, ?it/s]

Epoch [11/100], Step [200/723], Avg Train Loss: 161.7581, Avg Val Loss: 156.6472
Epoch [11/100], Step [400/723], Avg Train Loss: 161.4868, Avg Val Loss: 157.1423
Epoch [11/100], Step [600/723], Avg Train Loss: 160.5475, Avg Val Loss: 153.0179


Train Batch:   0%|          | 0/723 [00:00<?, ?it/s]

Epoch [12/100], Step [200/723], Avg Train Loss: 159.5105, Avg Val Loss: 155.2147


In [None]:
# Close the active W&B run now that training has stopped.
wandb.finish()

In [None]:
# Each point is one report from the training loop (one every `print_steps` batches).
fig, ax = plt.subplots()
ax.plot(train_losses, label='Training Loss')
ax.plot(val_losses, label='Validation Loss')
ax.set_xlabel('Print Steps')
ax.set_ylabel('Loss')
ax.legend()
plt.show()