In [37]:
from DataLoading import DataLoad
# Test DataLoading: smoke-test the project-local DataLoad helper by loading
# the whole dataset once and reading back the genre names.
data_path = "./Data/images_original"
data_loader = DataLoad(data_path)
# The meaning of dx/dy/dimx/dimy is defined by DataLoad.fetch_dataset (not
# visible in this notebook). X and y presumably hold the images and their
# integer genre labels — TODO confirm against DataLoading.py.
X, y = data_loader.fetch_dataset(dx=0, dy=0, dimx=128, dimy=128)
genre_names = data_loader.get_genre_names()

In [38]:
import os
import imageio
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from DataLoading import DataLoad
from EncoderCVAE import CVAE_Encoder
from DecoderCVAE import CVAE_Decoder
from Loss import get_loss
from CVAETrainerConstruct import TrainerCVAE
from datetime import datetime
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [39]:
class GenreDataset(Dataset):
    """Map-style dataset pairing each image with a one-hot condition vector.

    Args:
        images: Indexable collection of images (must support ``len`` and
            integer indexing).
        labels: Indexable collection of integer class ids, aligned with
            ``images`` by position.
        transform: Optional callable applied to each image before it is
            returned. Skipped when falsy (e.g. ``None``).
        num_classes: Length of the one-hot condition vector. Defaults to 10,
            the value that used to be hard-coded in ``__getitem__``.
    """

    def __init__(self, images, labels, transform=None, num_classes=10):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.num_classes = num_classes

    def __len__(self):
        """Return the number of samples, i.e. ``len(images)``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, condition)`` for the sample at position ``idx``.

        ``condition`` is a float one-hot tensor of length ``num_classes``
        built from the integer label, as required by the CVAE.
        """
        img = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        # one_hot only accepts int64 index tensors. Without the explicit
        # dtype, a label stored as e.g. numpy int32 would become an int32
        # tensor and one_hot would raise.
        label_index = torch.tensor(label, dtype=torch.long)
        condition = torch.nn.functional.one_hot(
            label_index, num_classes=self.num_classes
        ).float()
        return img, condition

In [40]:
def prepare_data(dataset_path="./Data/images_original", dim=128, batch_size=64,
                 val_fraction=0.1, seed=365, dx=0, dy=0):
    """Load the image dataset and build training and validation loaders.

    Every default reproduces the value that used to be hard-coded, so a bare
    ``prepare_data()`` call behaves exactly as before.

    Args:
        dataset_path: Directory handed to ``DataLoad``.
        dim: Value passed as both size arguments of
            ``DataLoad.fetch_dataset`` (presumably the square image size in
            pixels — confirm against DataLoading.py).
        batch_size: Batch size used for both loaders.
        val_fraction: Fraction of samples held out for validation
            (``test_size`` of ``train_test_split``).
        seed: ``random_state`` of the split, which keeps it reproducible.
        dx, dy: First two arguments of ``DataLoad.fetch_dataset``; their
            meaning is defined there.

    Returns:
        tuple: ``(train_loader, val_loader)``. Only the training loader
        shuffles.
    """
    dataload = DataLoad(dataset_path)
    all_photos, all_labels = dataload.fetch_dataset(dx, dy, dim, dim)

    # ToTensor turns a uint8 HWC image in [0, 255] into a float CHW tensor in
    # [0, 1]; this assumes fetch_dataset yields such images — TODO confirm.
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # Hold out a fixed, reproducible slice of the data for validation.
    X_train, X_val, y_train, y_val = train_test_split(
        all_photos, all_labels, test_size=val_fraction, random_state=seed
    )

    train_dataset = GenreDataset(X_train, y_train, transform=transform)
    val_dataset = GenreDataset(X_val, y_val, transform=transform)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader

In [41]:
def main():
    """Train the conditional VAE and save its encoder and decoder weights.

    Builds the data loaders, constructs the encoder and decoder, trains them
    with ``TrainerCVAE`` and writes both ``state_dict``s to timestamped
    ``.pth`` files in the current working directory.
    """
    LATENT_SPACE_SIZE = 128
    CONDITION_DIM = 10

    # Use the GPU when one is available and otherwise fall back to the CPU,
    # instead of failing on machines without CUDA. This assumes TrainerCVAE
    # accepts any torch device string — TODO confirm.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # The validation split from prepare_data is what the trainer receives as
    # its "testloader".
    train_loader, val_loader = prepare_data()

    encoder = CVAE_Encoder(
        latent_dim=LATENT_SPACE_SIZE,
        condition_dim=CONDITION_DIM,
        input_shape=(3, 128, 128),
        use_embedding=False,
    )
    decoder = CVAE_Decoder(latent_dim=LATENT_SPACE_SIZE, condition_dim=CONDITION_DIM)

    trainer = TrainerCVAE(
        trainloader=train_loader,
        testloader=val_loader,
        Encoder=encoder,
        Decoder=decoder,
        latent_dim=LATENT_SPACE_SIZE,
        device=device
    )

    trainer.train(num_epochs=50, factor=10)

    # A shared timestamp keeps the two weight files paired and prevents
    # overwriting the files from earlier runs.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    torch.save(encoder.state_dict(), f"cvae_encoder_{timestamp}.pth")
    torch.save(decoder.state_dict(), f"cvae_decoder_{timestamp}.pth")
    print(f"Models saved as cvae_encoder_{timestamp}.pth and cvae_decoder_{timestamp}.pth")

# Entry point: start training only when this code runs as the main program
# (__name__ == "__main__"), not when it is imported as a module.
if __name__ == "__main__":
    main()

Epoch [1/50], Train Loss: 66204.4737, Test Loss: 35487.7109
Epoch [2/50], Train Loss: 18599.9132, Test Loss: 11320.0312
Epoch [3/50], Train Loss: 8128.0350, Test Loss: 7815.2983
Epoch [4/50], Train Loss: 5873.3525, Test Loss: 10292.7354
Epoch [5/50], Train Loss: 5079.1071, Test Loss: 17583.8945
Epoch [6/50], Train Loss: 4291.8943, Test Loss: 4350.2566
Epoch [7/50], Train Loss: 3897.8581, Test Loss: 7825.2402
Epoch [8/50], Train Loss: 3679.6737, Test Loss: 3954.7218
Epoch [9/50], Train Loss: 3739.7708, Test Loss: 8792.3535
Epoch [10/50], Train Loss: 4007.2799, Test Loss: 5630.2461
Epoch [11/50], Train Loss: 3403.7545, Test Loss: 8118.7983
Epoch [12/50], Train Loss: 3249.3814, Test Loss: 3697.3263
Epoch [13/50], Train Loss: 3138.9090, Test Loss: 3744.4889
Epoch [14/50], Train Loss: 3055.0602, Test Loss: 6068.6685
Epoch [15/50], Train Loss: 3185.0572, Test Loss: 4247.1689
Epoch [16/50], Train Loss: 3043.8439, Test Loss: 4100.9945
Epoch [17/50], Train Loss: 2969.6236, Test Loss: 3267.2416
