In [1]:
import torch
import torchvision
from Architectures.ConvolutionalAutoEncoder import ConvVAE
from Architectures.EnhancedVAE import EnhancedConvVAE
from Architectures.VAE_utilis import *
from DataObjects.DataLoader import DataLoader
import os

VAE initialized on device: cuda
First conv layer device: cuda:0
VAE model moved to device: cuda
VAE model parameters on device: cuda:0


In [2]:
# Seed the random number generators up front so reruns are repeatable.
# The import is aliased so that any `random` name brought in by the wildcard
# import from Architectures.VAE_utilis is left untouched.
import random as py_random

SEED = 34
# torch.manual_seed does not touch Python's built-in RNG; seed it as well in
# case project code (e.g. batch shuffling) draws from it — TODO confirm.
py_random.seed(SEED)
# Seeds torch's default generator and returns it (shown as the cell output).
torch.manual_seed(SEED)

<torch._C.Generator at 0x7055a43aded0>

In [3]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
# Start from the process's current working directory (os.getcwd()); note this
# is wherever the kernel was launched, not necessarily this file's folder.
current_script_dir = os.getcwd()

# The project root is taken to be the parent of that directory.
project_root = os.path.split(current_script_dir)[0]

# Dataset location: <project_root>/data/downscaled
data_dir = os.path.join(os.path.join(project_root, 'data'), 'downscaled')

In [5]:
batch_size = 32

# NOTE(review): the cat loader reads from `data_dir` itself while the dog
# loader reads from its 'dog' subfolder — confirm this asymmetry matches the
# on-disk layout.
dog_dir = data_dir + '/dog/'

Data_cats = DataLoader(data_dir, batch_size=batch_size, shuffle=True)
print(f"Number of cat batches: {len(Data_cats)}")
Data_dogs = DataLoader(dog_dir, batch_size=batch_size, shuffle=True)
print(f"Number of dog batches: {len(Data_dogs)}")

# A loader with zero batches is almost certainly a path problem; report it
# explicitly here instead of letting it pass unnoticed. This only warns, so
# the rest of the notebook still runs.
for _label, _folder, _loader in (("cat", data_dir, Data_cats), ("dog", dog_dir, Data_dogs)):
    if len(_loader) == 0:
        print(
            f"WARNING: no {_label} batches loaded from {_folder!r} "
            f"(directory exists: {os.path.isdir(_folder)})"
        )

Number of cat batches: 933
Number of dog batches: 0


In [6]:
# Keyword arguments handed to the EnhancedConvVAE constructor, gathered in one
# place so they are easy to find and tune.
vae_config = dict(
    device=device,
    input_channels=3,
    latent_dim=256,
    base_channels=64,
    image_size=64,
    learning_rate=1e-4,
    beta=0.5,
    gamma=1000,
    use_spectral_norm=True,
)

# Build the model, then move it onto `device` explicitly.
vae = EnhancedConvVAE(**vae_config).to(device)

# Ask the model for its optimizer.
optimizer = vae.configure_optimizers()

print("VAE model initialized.")

EnhancedConvVAE initialized on device: cuda
VAE model initialized.


In [7]:
# Train the model on the cat loader and keep whatever train_architecture returns.
num_epochs = 10  # raise this for better results
print("\nTraining EnhancedConvVAE...")
training_history = vae.train_architecture(
    Data_cats,
    epochs=num_epochs,
    use_capacity=True,
)



Training EnhancedConvVAE...
Starting EnhancedConvVAE training...
Epoch [1/10], Batch [100/933], Total Loss: 3253.8054, Recon Loss: 0.0809, KL Loss: 3.2787, Capacity: 0.03
Epoch [1/10], Batch [200/933], Total Loss: 1970.4435, Recon Loss: 0.0698, KL Loss: 2.0204, Capacity: 0.05
Epoch [1/10], Batch [300/933], Total Loss: 1174.7891, Recon Loss: 0.0707, KL Loss: 1.2497, Capacity: 0.07
Epoch [1/10], Batch [400/933], Total Loss: 1004.3373, Recon Loss: 0.0616, KL Loss: 1.1043, Capacity: 0.10
Epoch [1/10], Batch [500/933], Total Loss: 774.5616, Recon Loss: 0.0675, KL Loss: 0.8995, Capacity: 0.12
Epoch [1/10], Batch [600/933], Total Loss: 669.1547, Recon Loss: 0.0597, KL Loss: 0.8191, Capacity: 0.15
Epoch [1/10], Batch [700/933], Total Loss: 428.9076, Recon Loss: 0.0654, KL Loss: 0.6038, Capacity: 0.17
Epoch [1/10], Batch [800/933], Total Loss: 375.9575, Recon Loss: 0.0619, KL Loss: 0.5759, Capacity: 0.20
Epoch [1/10], Batch [900/933], Total Loss: 2248.2878, Recon Loss: 0.0796, KL Loss: 0.5201,

In [8]:
# Generate samples from the trained model and write them to disk.
vae.eval()  # switch the model to evaluation mode before sampling

num_samples_gen = 25  # how many images to generate
print("\nGenerating diverse samples...")
diverse_samples = vae.generate_diverse(
    num_samples_gen,
    temperature=1.2,
)
vae.save_generated_images(
    diverse_samples,
    folder_path='./generated_images_diverse',
    prefix='diverse_sample',
)



Generating diverse samples...
Saved 25 generated images to ./generated_images_diverse
