In [1]:
import os
import gc
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config
import torchvision.transforms as transforms
from diffusers import AutoencoderKL
from PIL import Image
from diffusers import AutoencoderKL
from tqdm import tqdm

In [2]:
# Custom dataset to load images from triplet directories
class AnimeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with subdirectories, each containing frame1.jpg, frame2.jpg, frame3.jpg.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        # Loop over each subdirectory and add available frame*.jpg files to the list
        for subdir in os.listdir(root_dir):
            full_subdir = os.path.join(root_dir, subdir)
            if os.path.isdir(full_subdir):
                # Use all 3 frames per triplet as independent training samples
                for i in range(1, 4):
                    img_path = os.path.join(full_subdir, f"frame{i}.jpg")
                    if os.path.exists(img_path):
                        self.image_paths.append(img_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

# Function to compute KL divergence loss for the VAE
def kl_divergence(mu, logvar):
    # Standard VAE KL divergence between N(mu, sigma) and N(0,1)
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


In [3]:
torch.cuda.is_available()

True

In [4]:
torch.cuda.device_count()

4

In [5]:
[torch.cuda.device(i) for i in range(torch.cuda.device_count())]

[<torch.cuda.device at 0x7ff042904ec0>,
 <torch.cuda.device at 0x7ff169f24cd0>,
 <torch.cuda.device at 0x7ff0428ac910>,
 <torch.cuda.device at 0x7ff1781f2b10>]

In [6]:
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=3)

In [7]:
# Configuration parameters
train_dir = "/data/ggonsior/atd12k/train"
batch_size = 8
epochs = 10
learning_rate = 1e-4
image_size = (256, 256)

# Define image transformations (resize, convert to tensor, normalize)
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    # Normalize to [-1, 1]; adjust if your model requires different normalization.
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Create dataset and dataloader
train_dataset = AnimeDataset(train_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

In [8]:
train_dataset

<__main__.AnimeDataset at 0x7ff042905550>

In [9]:
# Load the pre-trained VAE ("kl-f8-anime2") from diffusers
print("Loading pre-trained VAE...")
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.to(device)
vae.train()  # Set the model to training mode

Loading pre-trained VAE...


AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (c

In [10]:
# Set up optimizer
optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)

In [None]:
total_loss, total_recon_loss, total_kl_loss = 0.0, 0.0, 0.0

for batch in train_loader:
    batch = batch.to(device)
    optimizer.zero_grad()

    posterior = vae.encode(batch).latent_dist
    z = posterior.sample()
    mu = posterior.mean
    sigma = posterior.logvar

    del posterior
    gc.collect()
    torch.cuda.empty_cache()

    reconstruction = vae.decode(z).sample

    recon_loss = F.mse_loss(reconstruction, batch)
    kl_loss = kl_divergence(mu, sigma) / batch.size(0)
    loss = recon_loss + kl_loss

    # Backpropagation
    loss.backward()
    optimizer.step()

    # Accumulate losses
    total_loss += loss.item() * batch.size(0)
    total_recon_loss += recon_loss.item() * batch.size(0)
    total_kl_loss += kl_loss.item() * batch.size(0)

    # Update batch progress bar
    pbar_epoch.update(1)
    pbar_epoch.set_postfix({"Loss": loss.item(), "KL": kl_loss.item(), "Recon": recon_loss.item()})

    del batch, reconstruction, mu, sigma, z
    gc.collect()
    torch.cuda.empty_cache()

    break

In [None]:
# Training loop with tqdm for progress tracking
for epoch in range(epochs):
    total_loss, total_recon_loss, total_kl_loss = 0.0, 0.0, 0.0

    # Outer tqdm for epoch progress
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}", position=0, leave=True) as pbar_epoch:
        for batch in tqdm(train_loader, desc="Batch Progress", position=1, leave=False):
            batch = batch.to(device)
            optimizer.zero_grad()

            # Forward pass
            posterior = vae.encode(batch).latent_dist
            z = posterior.sample()
            reconstruction = vae.decode(z).sample

            # Compute losses
            recon_loss = F.mse_loss(reconstruction, batch)
            kl_loss = kl_divergence(posterior.mean, posterior.logvar) / batch.size(0)
            loss = recon_loss + kl_loss

            # Backpropagation
            loss.backward()
            optimizer.step()

            # Accumulate losses
            total_loss += loss.item() * batch.size(0)
            total_recon_loss += recon_loss.item() * batch.size(0)
            total_kl_loss += kl_loss.item() * batch.size(0)

            # Update batch progress bar
            pbar_epoch.update(1)
            pbar_epoch.set_postfix({"Loss": loss.item(), "KL": kl_loss.item(), "Recon": recon_loss.item()})

    # Compute epoch losses
    avg_loss = total_loss / len(train_dataset)
    avg_recon = total_recon_loss / len(train_dataset)
    avg_kl = total_kl_loss / len(train_dataset)

    print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}, Recon: {avg_recon:.4f}, KL: {avg_kl:.4f}")

Epoch 1/10:   0%|          | 0/1640 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/1640 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 3 has a total capacity of 31.73 GiB of which 78.19 MiB is free. Including non-PyTorch memory, this process has 31.65 GiB memory in use. Of the allocated memory 29.57 GiB is allocated by PyTorch, and 1.72 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)