# Variational Auto Encoder
The logic for this has been taken from the original paper and Lecture 13 of CS231n

## Model

In [6]:
import torch
import torch.nn as nn
from datasets import tqdm

In [7]:
class VariationalAutoencoder(nn.Module):

    def __init__(self, input_dim, hidden_dim = 128*128*3, latent_dim = 32*32*3):
        super(VariationalAutoencoder, self).__init__()

        #Encoding
        self.img2hid = nn.Linear(input_dim, hidden_dim)
        self.hid2mu = nn.Linear(hidden_dim, latent_dim)
        self.hid2var = nn.Linear(hidden_dim, latent_dim)

        #Decoding
        self.z2hid = nn.Linear(latent_dim, hidden_dim)
        self.hid2img = nn.Linear(hidden_dim, input_dim)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):

        h = self.relu(self.img2hid(x))
        mu = self.hid2mu(h)
        sigma = self.hid2var(h)

        return mu, sigma

    def decode(self, z):

        h = self.relu(self.z2hid(z))

        return self.sigmoid(self.hid2img(h))

    def forward(self, x):

        mu, sigma = self.encode(x)
        epsilon = torch.randn_like(sigma)
        z_reparametarized = mu + epsilon * sigma
        reconstructed = self.decode(z_reparametarized)

        return reconstructed, mu, torch.log(sigma)

In [None]:
if __name__ == '__main__':
    x = torch.randn(10, 256 * 256)
    vae = VariationalAutoencoder(256*256)
    x_reconstructed, mu, sigma = vae(x)
    print(x_reconstructed.shape)
    print(mu.shape)
    print(sigma.shape)

## Loading dataset

This block is mostly AI generated.

In [8]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path

In [17]:
class create_image_dataset:
    def __init__(self, root_dir, batch_size=32, shuffle=True, num_workers=0):
        """
        Args:
            root_dir (string): Directory with all images
            batch_size (int): Mini-batch size
            shuffle (bool): Shuffle dataset order
            num_workers (int): Parallel data loading threads
        """
        if not os.path.isdir(root_dir):
            raise ValueError(f"Directory '{root_dir}' does not exist")
        
        # Define transformations: Resize + Grayscale (1 channel) -> Tensor
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor()
        ])
        
        # Create dataset instance
        self.dataset = self.ImageFolderDataset(root_dir, transform=self.transform)
        
        # Create DataLoader for batching
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers
        )
    
    def __iter__(self):
        """Iterator over batches"""
        return iter(self.dataloader)
    
    def __len__(self):
        """Number of batches per epoch"""
        return len(self.dataloader)
    
    class ImageFolderDataset:
        """Internal dataset class handling image loading"""
        def __init__(self, root_dir, transform=None):
            self.root_dir = root_dir
            self.transform = transform
            self.image_files = [
                f for f in os.listdir(root_dir) 
                if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            ]
        
        def __len__(self):
            return len(self.image_files)
        
        def __getitem__(self, idx):
            img_path = os.path.join(self.root_dir, self.image_files[idx])
            with Image.open(img_path) as img:
                if self.transform:
                    return self.transform(img)
            return img  # Fallback (unlikely)

## Training

In [10]:
from tqdm import tqdm

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataloader =  create_image_dataset(root_dir="/kaggle/input/dataset/CSGO_model/dataset", batch_size=16, shuffle=True, num_workers=4)
model = VariationalAutoencoder(256*256, 128*128, 32*32).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

In [12]:
num_epochs = 1

In [19]:
if __name__ == "__main__":


    # torch.save(dataset_info, "/kaggle/input/dataset/dataset_info.pth")

    for epoch in range(num_epochs):
        for data in dataloader:

            for datium in data:
                
                data = datium.to(device).reshape(1,-1)
                print(data.shape)
                output, mu, logvar = model(data)
    
    
                reconstruction_loss = criterion(output, data)
                kl_loss = -torch.sum(1 + 2*logvar - mu.pow(2) - logvar.exp()/2)
    
                net_loss = reconstruction_loss + kl_loss
    
                optimizer.zero_grad()
                net_loss.backward()
                optimizer.step()
                iter.set_postfix(loss=net_loss.item())
    


TypeError: 'tqdm' object is not callable

## Predicitons
