In [1]:
import torch
import torchvision
from Architectures.ConvolutionalAutoEncoder import ConvVAE
from DataObjects.DataLoader import DataLoader
import os

In [2]:
# Fix torch's global RNG (all devices) so shuffling, init and sampling are repeatable.
SEED = 34
torch.manual_seed(SEED)

<torch._C.Generator at 0x7a99e0fa5ed0>

In [3]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
# Paths are resolved from the notebook kernel's working directory (os.getcwd()),
# not from any script location. Its parent directory is treated as the project root.
current_script_dir = os.getcwd()
project_root, _ = os.path.split(current_script_dir)  # same as os.path.dirname(...)

# Image data is expected under <project_root>/data/downscaled.
data_dir = os.path.join(project_root, 'data', 'downscaled')

In [5]:
# Build one shuffled loader per class and report how many batches each yields.
batch_size = 32

# NOTE(review): the "cat" loader reads the whole data/downscaled directory, while the
# dog loader reads data/downscaled/dog/. Confirm which directory layout
# DataObjects.DataLoader expects (the recorded run found 0 dog batches).
Data_cats = DataLoader(data_dir, batch_size=batch_size, shuffle=True)
print(f"Number of cat batches: {len(Data_cats)}")

# os.path.join keeps separators OS-correct, matching how data_dir was built; the final ''
# preserves the trailing separator of the original data_dir + '/dog/' string.
dog_dir = os.path.join(data_dir, 'dog', '')
Data_dogs = DataLoader(dog_dir, batch_size=batch_size, shuffle=True)
print(f"Number of dog batches: {len(Data_dogs)}")

# An empty loader would make downstream training/evaluation silently do nothing,
# so surface it instead of letting it pass unnoticed.
for _class_name, _loader in (("cat", Data_cats), ("dog", Data_dogs)):
    if len(_loader) == 0:
        print(f"WARNING: the {_class_name} loader produced no batches; check its data directory.")

Number of cat batches: 933
Number of dog batches: 0


In [6]:
# Train the VAE on the cat images and checkpoint it.
# `device` comes from the device-selection cell above; it is not re-derived here.
epochs = 20

# Resolve the checkpoint location from project_root instead of a cwd-relative "../" path.
# Both point at <project_root>/Saved_Models, but this one is explicit and OS-portable.
saved_models_dir = os.path.join(project_root, "Saved_Models")
# Create the folder up front in case save_model does not create missing directories.
os.makedirs(saved_models_dir, exist_ok=True)
model_path = os.path.join(saved_models_dir, "cae_model.pt")

model = ConvVAE.from_pretrained(device)
model.train_architecture(Data_cats, epochs)
model.save_model(model_path)
# Reload the checkpoint just written, presumably as a smoke test that it round-trips.
model.load_model(model_path, map_location=device)

VAE initialized on device: cuda
First conv layer device: cuda:0
VAE model moved to device: cuda
VAE model parameters on device: cuda:0
Epoch [1/20], Batch [0/933], Total Loss: 13.3593, Recon Loss: 0.7248, KL Loss: 12.6345
Epoch [1/20], Batch [100/933], Total Loss: 0.2238, Recon Loss: 0.0833, KL Loss: 0.1406
Epoch [1/20], Batch [200/933], Total Loss: 0.1582, Recon Loss: 0.0772, KL Loss: 0.0810
Epoch [1/20], Batch [300/933], Total Loss: 0.1239, Recon Loss: 0.0738, KL Loss: 0.0501
Epoch [1/20], Batch [400/933], Total Loss: 0.0936, Recon Loss: 0.0604, KL Loss: 0.0332
Epoch [1/20], Batch [500/933], Total Loss: 0.0958, Recon Loss: 0.0583, KL Loss: 0.0375
Epoch [1/20], Batch [600/933], Total Loss: 0.1121, Recon Loss: 0.0649, KL Loss: 0.0472
Epoch [1/20], Batch [700/933], Total Loss: 0.1052, Recon Loss: 0.0734, KL Loss: 0.0317
Epoch [1/20], Batch [800/933], Total Loss: 0.1151, Recon Loss: 0.0630, KL Loss: 0.0521
Epoch [1/20], Batch [900/933], Total Loss: 0.0960, Recon Loss: 0.0589, KL Loss: 0.

In [7]:
# Draw a few images from the trained model and save them beside the checkpoint.
# eval() + no_grad() is the standard PyTorch inference setup: dropout/batch-norm switch to
# inference behaviour and no autograd graph is built. Assumes ConvVAE is an nn.Module;
# call model.train() again before any further training.
model.eval()

samples_dir = os.path.join(project_root, "Saved_Models")
os.makedirs(samples_dir, exist_ok=True)

with torch.no_grad():
    samples = model.generate(num_samples=4)

for i, img in enumerate(samples):
    torchvision.utils.save_image(img, os.path.join(samples_dir, f"sample_{i}.png"))