In [1]:
import os
import random
import warnings

import torch
import torchvision

from Architectures.ConvolutionalAutoEncoder import ConvAutoEncoder
from DataObjects.DataLoader import DataLoader

In [2]:
# Reproducibility: seed the RNGs that data shuffling and training may draw from.
# NOTE(review): if the project DataLoader shuffles with numpy, seed numpy as well — TODO confirm.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
# os.getcwd() is the kernel's working directory (for Jupyter, normally the folder
# holding this notebook) — it is not the location of any particular script.
current_script_dir = os.getcwd()

# The project root is taken to be the parent of that working directory;
# os.path.split(...)[0] is exactly what os.path.dirname returns.
project_root = os.path.split(current_script_dir)[0]

# Downscaled image data lives at <project_root>/data/downscaled.
data_dir = os.path.join(project_root, 'data', 'downscaled')

In [4]:
batch_size = 32

# NOTE(review): the "cat" loader reads the whole data_dir rather than a cat-specific
# subfolder — confirm this is intended for the directory layout under data/downscaled.
Data_cats = DataLoader(data_dir, batch_size=batch_size, shuffle=True)
print(f"Number of cat batches: {len(Data_cats)}")

# The trailing '' keeps the trailing separator the path previously had ('.../dog/').
Data_dogs = DataLoader(os.path.join(data_dir, 'dog', ''), batch_size=batch_size, shuffle=True)
print(f"Number of dog batches: {len(Data_dogs)}")

# A loader with zero batches does not raise on its own; surface it so a wrong
# dataset path is noticed instead of silently training/evaluating on nothing.
for _label, _loader in (("cat", Data_cats), ("dog", Data_dogs)):
    if len(_loader) == 0:
        warnings.warn(f"{_label} DataLoader has 0 batches - check its dataset path")

Number of cat batches: 933
Number of dog batches: 0


In [5]:
epochs = 1

# Resolve the checkpoint location from project_root (the same way data_dir is resolved)
# instead of relying on the kernel's working directory, and make sure it exists.
saved_models_dir = os.path.join(project_root, 'Saved_Models')
os.makedirs(saved_models_dir, exist_ok=True)
checkpoint_path = os.path.join(saved_models_dir, 'cae_model.pt')

# `device` is defined once in the setup cell at the top of the notebook.
model = ConvAutoEncoder.from_pretrained(device)

model.train_architecture(Data_cats, epochs)
model.save_model(checkpoint_path)
# Reload the checkpoint just written, presumably to confirm it round-trips — TODO confirm intent.
model.load_model(checkpoint_path, map_location=device)

Model initialized on device: cuda
First conv layer device: cuda:0
Model moved to device: cuda
Model parameters on device: cuda:0
Epoch [1/1], Batch [0/933], Loss: 0.7591
Epoch [1/1], Batch [100/933], Loss: 0.0587
Epoch [1/1], Batch [200/933], Loss: 0.0396
Epoch [1/1], Batch [300/933], Loss: 0.0303
Epoch [1/1], Batch [400/933], Loss: 0.0252
Epoch [1/1], Batch [500/933], Loss: 0.0216
Epoch [1/1], Batch [600/933], Loss: 0.0233
Epoch [1/1], Batch [700/933], Loss: 0.0189
Epoch [1/1], Batch [800/933], Loss: 0.0166
Epoch [1/1], Batch [900/933], Loss: 0.0174
Epoch [1/1] completed, Average Loss: 0.0349
Model saved to ../Saved_Models/cae_model.pt
Model loaded from ../Saved_Models/cae_model.pt and moved to cuda


In [7]:
# Draw samples from the model via `generate` and write each one to
# <project_root>/Saved_Models/sample_<i>.png.
num_samples = 4
samples_dir = os.path.join(project_root, 'Saved_Models')
os.makedirs(samples_dir, exist_ok=True)

samples = model.generate(num_samples=num_samples)
for i, img in enumerate(samples):
    torchvision.utils.save_image(img, os.path.join(samples_dir, f"sample_{i}.png"))