In [None]:
import os
import sys

# Resolve the notebook's working directory and its parent; the parent is put on
# the import path so the project packages imported below can be found.
current_dir = os.path.abspath("")
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))

# Guard the append so that re-running this cell does not add duplicate entries.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [None]:
import idx2numpy
import numpy as np

#from torchvision import transforms
from torchvision.transforms import v2 # for torchvision > 0.15
import torch.nn as nn
import torcheval.metrics as metrics

%matplotlib inline 
import matplotlib.pyplot as plt

from utils.dataloaders import *
from utils.preprocessing import *
from utils.training_utils import *
from utils.metrics import *
from utils.losses import *

from datasets.load_cifar10 import *

from denoising_autoencoder.models import *
from datetime import date

# Load datasets

In [None]:
# Location of the CIFAR-10 batch files handed to load_cifar10; the path is
# relative to the notebook's working directory.
data_dir = "../datasets/CIFAR10/cifar-10-batches-py"

In [None]:
# Read the train/test images, file names and labels plus the label names.
(
    train_images,
    train_filenames,
    train_labels,
    test_images,
    test_filenames,
    test_labels,
    label_names,
) = load_cifar10(data_dir)

# Preprocessing Data

In [None]:
# Rescale pixel values by 1/255 (assumes the raw values span [0, 255] — TODO confirm).
# NOTE: the arrays are rebound in place, so re-running this cell rescales them again.
train_images = train_images / 255
test_images = test_images / 255

# Per-channel statistics of the training set: reduce over every axis except the
# last (channel) one, giving one mean/std value per channel.
reduce_axes = (0, 1, 2)
mean = np.mean(train_images, axis=reduce_axes)
std = np.std(train_images, axis=reduce_axes)

In [None]:
# Per-image transform pipeline handed to the datasets.
train_transform = v2.Compose([
    # NOTE: v2.ToTensor is deprecated. It turns the [H, W, C] array into a
    # [C, H, W] tensor, which is the layout torch expects.
    v2.ToTensor(),
    # Standardize each channel with the mean/std computed from the training images.
    v2.Normalize(mean=mean, std=std),
    # Project-defined noise transform; p=0.5 is presumably the drop probability —
    # verify against AddDropoutNoise.
    AddDropoutNoise(p=0.5),
    # Cast to float32 without rescaling the values.
    v2.ToDtype(torch.float32, scale=False),
])

In [None]:
# Both splits are wrapped with the same transform pipeline (`train_transform`),
# so the test images go through the same normalization and noise steps.
train_dataset = CifarGenerativeDataset(
    images=train_images,
    transforms=train_transform,
)
test_dataset = CifarGenerativeDataset(
    images=test_images,
    transforms=train_transform,
)

# Define training configurations

In [None]:
# Model configuration: channel sizes passed to Conv_AutoEncoder.
in_channels = 3
enc_hidden_channels = 16
emb_channels = 4
dec_hidden_channels = 16

# Training configuration
epochs = 25
batchsize = 128
learning_rate = 0.001
num_workers = 4
device = "cuda"  # used by the metrics and by train_epochs
use_wandb = True
use_tensorboard = False

# Checkpoint directory: ./checkpoints/<experiment name>, one folder per day.
# The path is assembled with os.path.join so no separator is hard-coded; a
# literal ".\checkpoints" contains an invalid "\c" escape and only works on Windows.
exp_name = f"denoising_conv_autoencoder_{date.today()}"
save_dir = os.path.join(".", "checkpoints", exp_name)
# makedirs creates missing parent folders and is a no-op if the folder already
# exists; a shelled-out "mkdir" fails silently when "checkpoints" is missing.
os.makedirs(save_dir, exist_ok=True)

In [None]:
# The training loader shuffles its samples, the test loader keeps them in order;
# neither drops the last (possibly smaller) batch.
train_loader = get_loader(
    train_dataset,
    batchsize,
    num_workers,
    shuffle=True,
    drop_last=False,
)
test_loader = get_loader(
    test_dataset,
    batchsize,
    num_workers,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Build the autoencoder from the channel sizes set in the configuration cell.
# The model is not moved to `device` here — presumably train_epochs does that; verify.
model = Conv_AutoEncoder(in_channels, enc_hidden_channels, emb_channels, dec_hidden_channels)

In [None]:
# Mean-squared-error criterion; passed to train_epochs as the training loss.
loss_func = nn.MSELoss()

In [None]:
# Adam over all model parameters, using the learning rate from the configuration cell.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# This cell rebinds the name `metrics` (originally the torcheval.metrics module
# alias) to a list. Referring to the module through a separate local alias keeps
# the cell re-runnable; otherwise a second run fails with
# "'list' object has no attribute 'FrechetInceptionDistance'".
import torcheval.metrics as torcheval_metrics

metrics = [
    # These metrics only work on images.
    # NOTE(review): torcheval's FID expects inputs in [0, 1]; confirm that
    # train_epochs undoes the normalization before updating it.
    torcheval_metrics.FrechetInceptionDistance(device=device),  # only works on three-channel images
    # torcheval_metrics.StructuralSimilarity(device=device),
    torcheval_metrics.PeakSignalNoiseRatio(device=device),
    # torcheval_metrics.MeanSquaredError(device=device)
]
# One weight per entry of `metrics`, in the same order.
metric_weights = [1.0, 0.0]

In [None]:
if use_wandb:
    # Authenticate and start a Weights & Biases run for this experiment.
    wandb.login()
    wandb.init(
        # Project under which this run is logged
        project="denoising-conv-cifar-10-autoencoder",
        # Explicit run name (otherwise one is assigned at random, like sunshine-lollypop-10)
        name=exp_name,
        # Hyperparameters and run metadata tracked with the run
        config={
            "learning_rate": learning_rate,
            "architecture": str(model.__class__.__name__),
            # The notebook trains on CIFAR-10; the label previously said "MNIST".
            "dataset": "CIFAR-10",
            "epochs": epochs,
        })

In [None]:
# Launch training with everything configured in the cells above.
train_epochs(
    # Model and optimization
    model=model,
    optimizer=optimizer,
    loss_func=loss_func,
    # Data: the test loader doubles as the validation loader
    train_loader=train_loader,
    val_loader=test_loader,
    # Evaluation metrics and their weights
    metrics=metrics,
    metric_weights=metric_weights,
    # Run length and hardware
    num_epochs=epochs,
    device=device,
    # Logging / checkpointing settings (their exact meaning is defined by train_epochs)
    log_rate=5,
    save_rate=10,
    save_dir=save_dir,
    interval=0,
    # Experiment trackers
    use_tensorboard=use_tensorboard,
    use_wandb=use_wandb,
)