<a href="https://colab.research.google.com/github/K3dA2/VQ-VAE/blob/main/train_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive: kaggle.json is read from it, and checkpoints / TensorBoard logs are written to it.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
! pip install kaggle
! mkdir ~/.kaggle

mkdir: cannot create directory ‘/root/.kaggle’: File exists


In [3]:
!cp /content/drive/MyDrive/kaggle.json ~/.kaggle/kaggle.json
! kaggle datasets download scribbless/another-anime-face-dataset
! unzip another-anime-face-dataset.zip

Dataset URL: https://www.kaggle.com/datasets/scribbless/another-anime-face-dataset
License(s): GPL-2.0
another-anime-face-dataset.zip: Skipping, found more recently modified local copy (use --force to force download)
Archive:  another-anime-face-dataset.zip
replace animefaces256cleaner/10004131_result.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [4]:
!git clone https://github.com/K3dA2/VQ-VAE.git

fatal: destination path 'VQ-VAE' already exists and is not an empty directory.


In [5]:
import sys

# Make the cloned repo importable (for `from model import ...` / `from utils import ...`).
# Guarded so re-running the cell does not append duplicate entries.
if '/content/VQ-VAE/' not in sys.path:
    sys.path.append('/content/VQ-VAE/')

In [6]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import datetime
import os
import torch.nn.utils as utils
from model import VQVAE
from utils import get_data_loader,count_parameters
import uuid
import os
import random
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

out shape: torch.Size([2, 3, 128, 128])
loss shape: 0.5168724656105042
torch.Size([1, 3, 64, 64])


In [7]:
class CustomDataset(Dataset):
    """Flat directory of images exposed as ``(image, 0)`` pairs (the label is a dummy 0).

    Args:
        root_dir: directory whose regular files are the images; subdirectories are ignored.
        transform: optional callable applied to each RGB ``PIL.Image``.
        extensions: case-insensitive filename suffixes to keep. Files with other
            suffixes are skipped so stray non-image files don't crash loading.
            Pass ``None`` to keep every regular file.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp',
                      '.ppm', '.pgm', '.tif', '.tiff')

    def __init__(self, root_dir, transform=None, extensions=IMG_EXTENSIONS):
        self.root_dir = root_dir
        self.transform = transform
        suffixes = None if extensions is None else tuple(ext.lower() for ext in extensions)
        # Sorted because os.listdir order is arbitrary; this keeps index -> file stable across runs.
        self.image_paths = [
            os.path.join(root_dir, fname)
            for fname in sorted(os.listdir(root_dir))
            if os.path.isfile(os.path.join(root_dir, fname))
            and (suffixes is None or fname.lower().endswith(suffixes))
        ]

    def __len__(self):
        """Number of image files found at construction time."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, and return ``(image, 0)``."""
        img_path = self.image_paths[idx]
        # The context manager closes the underlying file; convert() loads the pixels
        # into a new image first, so the result stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Dummy label so the dataset fits the usual (input, target) DataLoader shape.
        return image, 0

def get_data_loader(path, batch_size, num_samples=None, shuffle=True,
                    image_size=64, seed=None, num_workers=0):
    """Build a DataLoader over (optionally a subset of) the images in ``path``.

    NOTE: this definition shadows the ``get_data_loader`` imported from the repo's
    ``utils`` module in the imports cell.

    Args:
        path: image directory, read via ``CustomDataset``.
        batch_size: DataLoader batch size.
        num_samples: number of images to use; ``None`` or a value above the dataset
            size means all images.
        shuffle: if True, pick a random subset and shuffle every epoch; otherwise
            take the first ``num_samples`` files in order.
        image_size: images are resized to ``(image_size, image_size)``.
        seed: optional int. If given, both subset selection and DataLoader shuffling
            are reproducible. If None, the global ``random`` / torch RNGs are used,
            as before.
        num_workers: forwarded to ``DataLoader``.

    Returns:
        ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.

    Raises:
        ValueError: if ``num_samples`` is negative.
    """
    if num_samples is not None and num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        # Presumably the per-channel mean/std of this dataset -- TODO confirm how they were computed.
        transforms.Normalize((0.7002, 0.6099, 0.6036), (0.2195, 0.2234, 0.2097)),
    ])

    full_dataset = CustomDataset(root_dir=path, transform=transform)
    total = len(full_dataset)

    if num_samples is None or num_samples > total:
        num_samples = total
    print("data length:", total)

    if shuffle:
        rng = random if seed is None else random.Random(seed)
        indices = rng.sample(range(total), num_samples)
    else:
        indices = list(range(num_samples))

    subset_dataset = Subset(full_dataset, indices)

    generator = None if seed is None else torch.Generator().manual_seed(seed)
    return DataLoader(subset_dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers, generator=generator)

In [8]:
# Training images (extracted by the unzip step) and the TensorBoard log directory on Drive.
path = '/content/animefaces256cleaner'
batch_size = 32

data_loader = get_data_loader(
    path,
    batch_size,
    shuffle=True,
    num_samples=None,
)
writer = SummaryWriter(
    log_dir='/content/drive/MyDrive/VQ-VAE Weights/logs',
)

data length: 92219


In [9]:
def _log_losses(tb_writer, step, total, reconstruction, vq):
    """Write the three per-batch loss scalars (plain floats) to TensorBoard at ``step``."""
    tb_writer.add_scalar('Loss/Total', total, step)
    tb_writer.add_scalar('Loss/Reconstruction', reconstruction, step)
    tb_writer.add_scalar('Loss/VectorQuantization', vq, step)


def _sample_image(model, sample_args):
    """Call ``model.inference(*sample_args)`` and return its last element as an HWC numpy array.

    Runs under ``torch.no_grad()`` with the model in eval mode, then restores the
    previous train/eval mode so the surrounding training loop is unaffected.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pred_images = model.inference(*sample_args)
    finally:
        model.train(was_training)
    # Assumes pred_images[-1] is a (C, H, W) tensor -- TODO confirm against model.py.
    return np.transpose(pred_images[-1].cpu().numpy(), (1, 2, 0))


def _save_image(image, sample_dir, epoch):
    """Write ``image`` (HWC array) to a uniquely named PNG in ``sample_dir`` and return its path."""
    os.makedirs(sample_dir, exist_ok=True)
    full_path = os.path.join(sample_dir, f'epoch{epoch:04d}-{uuid.uuid4()}.png')
    # imsave rejects float RGB outside [0, 1]; clip the same way plt.imshow does for display.
    plt.imsave(full_path, np.clip(image, 0.0, 1.0))
    return full_path


def _save_checkpoint(model, optimizer, epoch, model_path):
    """Overwrite ``model_path`` with the epoch index and model/optimizer state dicts."""
    os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, model_path)


def training_loop(n_epochs, optimizer, model, loss_fn, device, data_loader,
                  max_grad_norm=1.0, epoch_start=0,
                  save_img=False, show_img=True,
                  summary_writer=None, clip_grad=False,
                  checkpoint_dir='/content/drive/MyDrive/VQ-VAE Weights/',
                  model_filename='waifu-vqvae.pth',
                  loss_log_path='waifu-vqvae_epoch-loss.txt',
                  sample_dir=None, sample_args=(1, 14, 14),
                  checkpoint_every=20, embedding_log_every=10):
    """Train ``model`` on reconstruction + VQ loss for epochs ``epoch_start .. n_epochs - 1``.

    For each batch, ``model(imgs)`` returns ``(outputs, vq_loss)``, and the optimized loss is
    ``loss_fn(outputs, imgs) + vq_loss``. Per-batch losses are logged to TensorBoard.
    The mean epoch loss is appended to ``loss_log_path`` and printed.

    Every ``checkpoint_every`` epochs, and on the final epoch, the loop optionally shows
    and/or saves one sample, then overwrites the checkpoint at
    ``checkpoint_dir/model_filename``.

    Args:
        n_epochs: exclusive upper bound on the epoch index.
        optimizer, model, loss_fn, device: the usual training objects.
        data_loader: yields ``(images, labels)``; the labels are ignored.
        max_grad_norm: gradient-norm clipping threshold, applied only when ``clip_grad``.
        epoch_start: first epoch index. When resuming, use ``checkpoint['epoch'] + 1``
            (the saved epoch has already completed).
        save_img: save a sample PNG into ``sample_dir``.
        show_img: display a sample inline.
        summary_writer: TensorBoard ``SummaryWriter``. Defaults to the notebook-global ``writer``.
        clip_grad: enable gradient clipping with ``max_grad_norm``.
        checkpoint_dir, model_filename: where the checkpoint is written.
        loss_log_path: text file that receives one mean loss per epoch.
        sample_dir: where samples go. Defaults to ``checkpoint_dir/samples``, i.e. NOT the
            training-image folder, so generated images never become training data.
        sample_args: positional arguments forwarded to ``model.inference``. ``(1, 14, 14)``
            is the original hard-coded value; its meaning is defined by model.py.
        checkpoint_every: epoch interval for sampling and checkpointing.
        embedding_log_every: batch interval for logging the codebook projector.
            A falsy value disables this logging, which is heavy I/O.
    """
    tb_writer = summary_writer if summary_writer is not None else writer
    model_path = os.path.join(checkpoint_dir, model_filename)
    if sample_dir is None:
        sample_dir = os.path.join(checkpoint_dir, 'samples')
    num_batches = len(data_loader)
    codebook_labels = None  # built lazily on the first embedding log

    model.train()
    for epoch in range(epoch_start, n_epochs):
        loss_train = 0.0

        progress_bar = tqdm(data_loader, desc=f'Epoch {epoch}', unit=' batch')
        for batch_idx, (imgs, _) in enumerate(progress_bar):
            step = epoch * num_batches + batch_idx
            imgs = imgs.to(device)

            outputs, vq_loss = model(imgs)
            mse_loss = loss_fn(outputs, imgs)
            loss = mse_loss + vq_loss

            optimizer.zero_grad()
            loss.backward()
            if clip_grad:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

            loss_value = loss.item()
            loss_train += loss_value
            _log_losses(tb_writer, step, loss_value, mse_loss.item(), vq_loss.item())

            if embedding_log_every and batch_idx % embedding_log_every == 0:
                if codebook_labels is None:
                    codebook_labels = [f"embedding_{i}" for i in range(model.num_embeddings)]
                tb_writer.add_embedding(
                    model.codebook.weight.detach(),
                    metadata=codebook_labels,
                    global_step=step,
                    tag='Codebook',
                )

            progress_bar.set_postfix(loss=loss_value)

        # Guard against an empty loader (0 batches) instead of dividing by zero.
        avg_loss = loss_train / max(num_batches, 1)
        with open(loss_log_path, "a") as file:
            file.write(f"{avg_loss}\n")
        tb_writer.flush()  # push logs to Drive in case the Colab session disconnects

        print('{} Epoch {}, Training loss {}'.format(datetime.datetime.now(), epoch, avg_loss))

        if epoch % checkpoint_every == 0 or epoch == n_epochs - 1:
            if show_img or save_img:
                # Draw one sample so the displayed and saved images are the same.
                image = _sample_image(model, sample_args)
                if show_img:
                    plt.imshow(image)
                    plt.show()
                if save_img:
                    _save_image(image, sample_dir, epoch)
            # A single file is overwritten each time; the epoch is stored inside the checkpoint.
            _save_checkpoint(model, optimizer, epoch, model_path)

In [10]:
# Pick the fastest available backend: CUDA, then Apple MPS, otherwise CPU.
# (`path` is already set in the data-loading cell; it is not re-assigned here.)
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f"using device: {device}")

model = VQVAE().to(device)  # VQVAE comes from the cloned repo's `model` module
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
# Reconstruction term only; the VQ term is returned by the model's forward pass.
# MSELoss has no parameters/buffers, so it needs no .to(device).
loss_fn = nn.MSELoss()
print(count_parameters(model))

using device: cpu
31209996


In [11]:
# Optionally resume from a checkpoint saved by training_loop (a dict written with torch.save).
# To resume: uncomment these lines and define model_path. Then call training_loop with
# epoch_start=checkpoint['epoch'] + 1, because the stored 'epoch' is the last one that
# finished before the save.
# Pass map_location=device to torch.load if the checkpoint was saved on a different device.
#model_path = '/content/drive/MyDrive/VQ-VAE Weights/waifu-vqvae.pth'
#checkpoint = torch.load(model_path)
#model.load_state_dict(checkpoint['model_state_dict'])
#optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#epoch = checkpoint['epoch']

In [None]:
# Train from scratch (epoch 0). To continue an earlier run, use the optional
# checkpoint-loading cell above and adjust epoch_start.
training_loop(
    n_epochs=1000,
    epoch_start=0,
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    data_loader=data_loader,
    device=device,
)

Epoch 0:   0%|          | 0/2882 [00:00<?, ? batch/s]