In [1]:
import os
import warnings

import torch
import torchvision

from Architectures.ConvolutionalAutoEncoder import ConvVAE
from DataObjects.DataLoader import DataLoader

In [2]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
# Later cells reuse this `device` rather than recomputing it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed torch's RNGs (torch.manual_seed covers CPU and all CUDA devices) so that
# torch-based randomness in training and sampling repeats across
# Restart & Run All. NOTE(review): the project DataLoader's shuffle may draw
# from a different RNG (python `random` / numpy) — confirm and seed it too.
SEED = 42
torch.manual_seed(SEED)

Using device: cuda


In [3]:
# Resolve paths relative to the kernel's working directory (not the notebook
# file itself). This assumes the notebook is started one level below the
# project root — adjust if launched from elsewhere.
current_script_dir = os.getcwd()
project_root = os.path.split(current_script_dir)[0]

# Training images: <project_root>/data/downscaled
data_dir = os.path.join(project_root, 'data', 'downscaled')

In [4]:
batch_size = 32

# NOTE(review): the "cat" loader is built from the whole `downscaled` root, not
# a `cat` sub-folder, so it presumably picks up every image under that root —
# confirm this is intended before treating its batches as cats only.
Data_cats = DataLoader(data_dir, batch_size=batch_size, shuffle=True)
print(f"Number of cat batches: {len(Data_cats)}")

# Build the dog path with os.path.join instead of string concatenation; the
# trailing '' keeps the trailing separator the original '/dog/' string had.
Data_dogs = DataLoader(os.path.join(data_dir, 'dog', ''), batch_size=batch_size, shuffle=True)
print(f"Number of dog batches: {len(Data_dogs)}")

# An empty loader yields no batches, so anything iterating it would silently do
# nothing. Surface that instead of letting it pass unnoticed.
for loader_name, loader in (("cat", Data_cats), ("dog", Data_dogs)):
    if len(loader) == 0:
        warnings.warn(
            f"{loader_name} DataLoader has 0 batches — check the directory it was built from."
        )

Number of cat batches: 933
Number of dog batches: 0


In [5]:
epochs = 10

# Checkpoint location: <project_root>/Saved_Models. This is the same folder the
# cwd-relative "../Saved_Models" pointed at, now anchored like data_dir.
# Create it up front in case save_model does not create missing directories.
save_dir = os.path.join(project_root, "Saved_Models")
os.makedirs(save_dir, exist_ok=True)
model_path = os.path.join(save_dir, "cae_model.pt")

# `device` is defined in the device-selection cell; no need to recompute it here.
model = ConvVAE.from_pretrained(device)

model.train_architecture(Data_cats, epochs)
model.save_model(model_path)
# Reload the checkpoint that was just written.
# NOTE(review): presumably a round-trip check that the file is loadable on `device`.
model.load_model(model_path, map_location=device)

VAE initialized on device: cuda
First conv layer device: cuda:0
VAE model moved to device: cuda
VAE model parameters on device: cuda:0
Epoch [1/10], Batch [0/933], Total Loss: 13.1695, Recon Loss: 0.8283, KL Loss: 12.3411
Epoch [1/10], Batch [100/933], Total Loss: 0.3134, Recon Loss: 0.1418, KL Loss: 0.1716
Epoch [1/10], Batch [200/933], Total Loss: 0.1907, Recon Loss: 0.0717, KL Loss: 0.1190
Epoch [1/10], Batch [300/933], Total Loss: 0.1379, Recon Loss: 0.0666, KL Loss: 0.0714
Epoch [1/10], Batch [400/933], Total Loss: 0.1246, Recon Loss: 0.0738, KL Loss: 0.0508
Epoch [1/10], Batch [500/933], Total Loss: 0.1185, Recon Loss: 0.0592, KL Loss: 0.0593
Epoch [1/10], Batch [600/933], Total Loss: 0.1203, Recon Loss: 0.0673, KL Loss: 0.0530
Epoch [1/10], Batch [700/933], Total Loss: 0.1019, Recon Loss: 0.0562, KL Loss: 0.0457
Epoch [1/10], Batch [800/933], Total Loss: 0.0916, Recon Loss: 0.0617, KL Loss: 0.0299
Epoch [1/10], Batch [900/933], Total Loss: 0.0841, Recon Loss: 0.0585, KL Loss: 0.

In [6]:
num_samples = 4

# Sampling is inference only. Disable autograd in case generate() does not,
# so no computation graph is retained for the generated tensors.
with torch.no_grad():
    samples = model.generate(num_samples=num_samples)

# Write the PNGs next to the checkpoint, anchored at project_root like the
# data path, and make sure the folder exists first.
sample_dir = os.path.join(project_root, "Saved_Models")
os.makedirs(sample_dir, exist_ok=True)
for i, img in enumerate(samples):
    torchvision.utils.save_image(img, os.path.join(sample_dir, f"sample_{i}.png"))