In [None]:
import torch
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data.dataloader import DataLoader
from autoencoder_class import Autoencoder
import torch.nn as nn
import matplotlib.pyplot as plt
from autoencoder_functions import fit
from sklearn import metrics
import numpy as np

# --- Hyperparameters -------------------------------------------------------
BATCH_SIZE = 5
EPOCHS = 10

# --- Preprocessing ---------------------------------------------------------
# ToTensor is the only transform applied to each sample.
transform = transforms.Compose([transforms.ToTensor()])

# --- Datasets --------------------------------------------------------------
# Both splits share the same root directory and transform; download=True lets
# torchvision fetch the files when they are not already present under root.
train_dataset = MNIST(root="../../data/MNIST", train=True, transform=transform, download=True)
test_dataset = MNIST(root="../../data/MNIST", train=False, transform=transform, download=True)

# --- Loaders ---------------------------------------------------------------
# Training batches are reshuffled and loaded with pinned memory; the test
# loader keeps the dataset order.
train_dl = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
test_dl = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

# Pull one batch from the training loader. `data` is not referenced again in
# the visible cells; it is kept available for interactive inspection.
# The built-in next() is used instead of the iterator's .next() method:
# .next() was a Python 2 style alias that current PyTorch DataLoader iterators
# no longer provide, whereas next() works on every version.
dataiter = iter(train_dl)
data = next(dataiter)

# Pick the compute device: CUDA when a GPU is available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Instantiate the autoencoder and move it onto the selected device.
model = Autoencoder()
model.to(device=device)

# Mean-squared-error loss, optimized with Adam (with weight decay).
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

# Run training; `outputs` holds whatever fit() returns and is consumed by the
# plotting cell below.
outputs = fit(
    epochs=EPOCHS,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_dl=train_dl,
    test_dl=test_dl,
    metric=None,
)

In [None]:
# Visual sanity check of the reconstructions, followed by the batch RMSE.
# Top row of each figure: entries of outputs[k][1]; bottom row: entries of
# outputs[k][2]. Presumably these are the input batch and its reconstruction
# as recorded by fit() for epoch k — TODO confirm against fit().
N_EPOCHS_TO_SHOW = 1  # only the first epoch is rendered; raise to show more
N_PREVIEW = 9         # maximum number of images drawn per row

for k in range(min(EPOCHS, N_EPOCHS_TO_SHOW)):
    plt.figure(figsize=(9, 2))
    plt.gray()
    imgs = outputs[k][1].cpu().detach().numpy()
    imgs_recon = outputs[k][2].cpu().detach().numpy()

    # Row 0 holds `imgs`, row 1 holds `imgs_recon`; both rows are drawn by the
    # same loop instead of two copy-pasted ones.
    for row, batch in enumerate((imgs, imgs_recon)):
        for batch_number, item in enumerate(batch[:N_PREVIEW]):
            plt.subplot(2, N_PREVIEW, row * N_PREVIEW + batch_number + 1)
            item = item.reshape(-1, 28, 28)  # -1 is for the channel
            plt.imshow(item[0])

    # sklearn's mean_squared_error only accepts 1-D or 2-D arrays, so each
    # batch is flattened to (n_samples, n_features) first. This is a no-op when
    # the batches are already 2-D and avoids a ValueError when they are
    # image-shaped (e.g. (B, 1, 28, 28)).
    score1 = np.sqrt(
        metrics.mean_squared_error(
            imgs.reshape(len(imgs), -1),
            imgs_recon.reshape(len(imgs_recon), -1),
        )
    )
    print(score1)