In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import math

import zipfile
import os
# os.chdir('..')

import shutil
from sklearn.model_selection import train_test_split
import os
import matplotlib.pyplot as plt

from data.load_data import prepare_dog_dataset
from src.custom_vae import VAEDog

import torchvision
from torchvision import transforms

## Pobranie i przygotowanie danych

In [None]:
# Provenance: the archive was downloaded with
#   !gdown 1KXRTB_q4uub_XOHecpsQjE4Kmv76sZbV -O data/all-dogs.zip
# On Linux it can be unpacked with `!unzip -q data/all-dogs.zip -d data/all-dogs`;
# Python's zipfile is used here instead so the cell also works on Windows.
# NOTE(review): the split cell reads images from data/all-dogs, so extracting into
# data/ assumes the archive contains a top-level all-dogs/ folder — confirm.
zip_path = "data/all-dogs.zip"
extract_to = "data/"

os.makedirs(extract_to, exist_ok=True)
with zipfile.ZipFile(zip_path, mode="r") as dogs_archive:
    dogs_archive.extractall(path=extract_to)

In [None]:
def split_dataset(source_dir, train_dir, test_dir, test_size=0.2, random_state=42,
                  extensions=('.png', '.jpg', '.jpeg')):
    """Copy the images in ``source_dir`` into separate train and test directories.

    The list of image file names is partitioned with scikit-learn's
    ``train_test_split`` and each file is copied (not moved) into its split.

    Args:
        source_dir: Directory holding the images; only its top level is scanned.
        train_dir: Destination for training images; created if missing.
        test_dir: Destination for test images; created if missing.
        test_size: Passed unchanged to ``train_test_split`` (fraction or count).
        random_state: Seed passed to ``train_test_split``.
        extensions: Tuple of file-name suffixes (case-insensitive) treated as images.

    Raises:
        ValueError: If ``source_dir`` contains no matching image files.

    NOTE: existing files in the destination folders are never removed, so
    re-running with a different ``test_size``/``random_state`` leaves stale
    copies behind (and can leak images between train and test). Clear the
    destination folders first in that case.
    """
    suffixes = tuple(ext.lower() for ext in extensions)

    # os.listdir returns entries in filesystem-dependent order; sorting makes
    # the split reproducible across machines for a fixed random_state.
    image_files = sorted(
        name for name in os.listdir(source_dir)
        if name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(source_dir, name))
    )
    if not image_files:
        raise ValueError(
            f"No image files with extensions {suffixes} found in {source_dir!r}")

    train_files, test_files = train_test_split(
        image_files, test_size=test_size, random_state=random_state)

    for destination, files in ((train_dir, train_files), (test_dir, test_files)):
        os.makedirs(destination, exist_ok=True)
        for name in files:
            shutil.copy(os.path.join(source_dir, name), os.path.join(destination, name))


# Copy data/all-dogs into train/test folders. Both destinations end in a "dogs"
# sub-folder because ImageFolder (loaded from './data/train' below) expects the
# images grouped into per-class directories.
source_dir, train_dir, test_dir = "data/all-dogs", "data/train/dogs", "data/test/dogs"
split_dataset(source_dir=source_dir, train_dir=train_dir, test_dir=test_dir)

In [None]:
# Data loading: every image is resized to 56x56, converted to a tensor and
# normalised per channel with (x - 0.5) / 0.5, i.e. from [0, 1] to [-1, 1].
# NOTE(review): the reconstruction loss compares the VAE output against these
# [-1, 1] targets — confirm the decoder's output range matches (src/custom_vae.py).
transform = transforms.Compose(
    [
        transforms.Resize((56, 56)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
batch_size = 64
dataset = torchvision.datasets.ImageFolder(root='./data/train', transform=transform)
# Class labels produced by ImageFolder are ignored during training.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
from src.custom_vae import VAE, VAEDog

# Training configuration.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG (all devices) so weight initialisation and DataLoader
# shuffling are reproducible; 42 matches split_dataset's default random_state.
SEED = 42
torch.manual_seed(SEED)

epochs = 100
lr = 3e-4
beta = .00025  # KL divergence weight
# Must equal the DataLoader's batch size (batch_size, set in the data-loading
# cell); a separate hard-coded value would make effective_batch_size wrong.
bs = batch_size

acc_steps = 100  # batches whose gradients are accumulated per optimizer step
effective_batch_size = bs * acc_steps  # samples contributing to one optimizer step

vae = VAE()
vae.to(device)

## Trenowanie modelu

In [None]:
from torch import optim 

def fit(model: nn.Module, optimizer: optim.Optimizer,
        dataloader: torch.utils.data.DataLoader,
        acc_steps: int = 1, epochs: int = 100,
        beta: float = .00025,
        save: bool = True,
        path_to_save: str = '',
        verbose: int = 10,
        device: "torch.device | None" = None) -> list:
    """Train a VAE with gradient accumulation and return the mean loss of each epoch.

    Per batch the objective is ``MSE(reconstructed, images) + beta * KL``, where KL
    is the closed-form divergence between N(mean, exp(log_variance)) and N(0, I)
    summed over every element (latent dims *and* batch), while the MSE term is a
    per-element mean, so ``beta`` also absorbs that scale difference.

    Assumes ``model(images)`` returns ``(reconstructed, encoded)``, with
    ``reconstructed`` shaped like ``images`` and ``encoded`` holding the mean and
    log-variance concatenated along dim 1 (it is split in half here) — TODO
    confirm against the model definition.

    Args:
        model: The VAE to train; it is put in train mode each epoch.
        optimizer: Optimizer over ``model``'s parameters.
        dataloader: Sized iterable yielding ``(images, labels)``; labels are ignored.
        acc_steps: Number of batches whose gradients are averaged before each
            optimizer step. A trailing partial window at the end of an epoch is
            flushed with its own step, so gradients never leak across epochs.
        epochs: Number of passes over ``dataloader``.
        beta: Weight of the KL term.
        save: If True, write ``model.state_dict()`` to ``<path_to_save>/final.pth``
            after every epoch (overwriting the previous file).
        path_to_save: Checkpoint directory, created if missing; ``''`` means the
            current working directory.
        verbose: Print a progress line every ``verbose`` batches; 0 disables it.
        device: Device batches are moved to. Defaults to the device of the model's
            first parameter (CPU when the model has no parameters).

    Returns:
        List of the mean (unscaled) per-batch loss for each epoch.

    Raises:
        ValueError: If ``acc_steps`` < 1, or ``dataloader`` is empty while
            ``epochs`` > 0.
    """
    if acc_steps < 1:
        raise ValueError(f"acc_steps must be >= 1, got {acc_steps}")
    num_batches = len(dataloader)
    if num_batches == 0 and epochs > 0:
        raise ValueError("dataloader yields no batches")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    criterion = nn.MSELoss()
    train_losses = []

    # Discard any gradients accumulated before this call.
    optimizer.zero_grad()

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for i, (images, _) in enumerate(dataloader):
            images = images.to(device)

            # Forward pass
            reconstructed, encoded = model(images)

            # Compute loss
            recon_loss = criterion(reconstructed, images)
            mean, log_variance = torch.chunk(encoded, 2, dim=1)
            kl_div = -0.5 * torch.sum(1 + log_variance - mean.pow(2) - log_variance.exp())
            loss = recon_loss + beta * kl_div

            # Average gradients over the batches that share one optimizer step;
            # the last window of an epoch may hold fewer than acc_steps batches.
            window_start = (i // acc_steps) * acc_steps
            window_size = min(acc_steps, num_batches - window_start)
            (loss / window_size).backward()

            if (i + 1) % acc_steps == 0 or (i + 1) == num_batches:
                optimizer.step()
                optimizer.zero_grad()

            train_loss += loss.item()

            if verbose != 0 and (i + 1) % verbose == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{num_batches}], '
                      f'Loss: {loss.item():.4f}, Recon Loss: {recon_loss.item():.4f}, KL Div: {kl_div.item():.4f}')

        # Average epoch loss
        train_losses.append(train_loss / num_batches)

        if save:
            # os.makedirs('') raises, so only create a directory when one is given.
            if path_to_save:
                os.makedirs(path_to_save, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(path_to_save, 'final.pth'))

    return train_losses

In [None]:
# Adam over every VAE parameter, using the learning rate from the config cell.
optimizer = optim.Adam(params=vae.parameters(), lr=lr)

In [None]:
# Train for the configured number of epochs and keep the per-epoch loss history.
# Checkpoints go to /kaggle/working when running on Kaggle, otherwise to the
# current directory (avoids creating /kaggle/... on a local machine).
CHECKPOINT_DIR = '/kaggle/working' if os.path.isdir('/kaggle/working') else '.'

train_losses = fit(
    vae, optimizer, dataloader,
    acc_steps=acc_steps,
    epochs=epochs,
    beta=beta,
    path_to_save=CHECKPOINT_DIR,
    verbose=200,
)

In [None]:
# Save model and optimizer state together so training can be resumed, plus the
# configured epoch count so a resumed run knows where to continue from.
# NOTE(review): the file name says "20" but the configured `epochs` is 100;
# rename if this is meant to be the final checkpoint (kept as-is in case other
# code loads it by this name).
torch.save(
    {
        'epoch': epochs,
        'model_state_dict': vae.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    },
    'checkpoint_20.pth',
)