### Continuous Training

#### Libraries

In [None]:
import os

import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib.pyplot as plt

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

###### create folder structure

In [47]:
# Create the "eval/" directory if it is missing. exist_ok=True makes this a
# single call with no check-then-create race.
folder_path = "eval/"
os.makedirs(folder_path, exist_ok=True)

In [48]:
# Training hyperparameters.
num_epochs = 1  # number of epochs the training loop runs
batch_size = 8  # samples per DataLoader batch
learning_rate = 1e-3  # learning rate passed to the Adam optimizer

In [49]:
# Dataset root. It gets its own name so that `folder_path` ("eval/", assigned
# earlier in this file) is not overwritten: the training loop saves its
# figures via `folder_path`, and they should not land in the dataset folder.
data_path = "../data/"
os.makedirs(data_path, exist_ok=True)

# Directory that create_checkpoint() writes checkpoints into.
quicksave_path = "./quicksave/"
os.makedirs(quicksave_path, exist_ok=True)

# NOTE(review): Normalize([0.5], [0.5]) computes (x - 0.5) / 0.5, so pixel
# values end up in [-1, 1]. These images are later used as targets for
# nn.BCELoss, whose documentation asks for targets in [0, 1], and the
# decoder's final Sigmoid cannot produce negative values. Left unchanged
# because the checkpoint being continued was presumably trained with this
# preprocessing -- confirm before changing.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# MNIST training split, downloaded into data_path on first use.
dataset = datasets.MNIST(root=data_path, train=True, download=True, transform=transform)
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

#### Load model

In [None]:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder over flattened 28x28 inputs.

    The encoder narrows 784 -> 128 -> 64 -> 36 -> 18 -> 9 features with a
    ReLU between consecutive linear layers. The decoder mirrors those widths
    back up to 784 and finishes with a sigmoid.
    """

    # Layer widths from the flattened input down to the bottleneck.
    _WIDTHS = (28 * 28, 128, 64, 36, 18, 9)

    def __init__(self):
        super().__init__()

        shrinking = self._WIDTHS
        growing = shrinking[::-1]

        # No activation after the last encoder layer (the bottleneck).
        self.encoder = nn.Sequential(*self._linear_relu_stack(shrinking))

        # Same alternation in reverse, with a sigmoid on the output.
        self.decoder = nn.Sequential(*self._linear_relu_stack(growing), nn.Sigmoid())

    @staticmethod
    def _linear_relu_stack(widths):
        """Return Linear layers for consecutive widths, separated by ReLU."""
        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            if modules:
                modules.append(nn.ReLU())
            modules.append(nn.Linear(fan_in, fan_out))
        return modules

    def forward(self, x):
        """Encode ``x`` to the bottleneck and return its reconstruction."""
        return self.decoder(self.encoder(x))

In [51]:
# Load the fully pickled model and map its tensors onto `device`, the same
# device the training batches are moved to. Without map_location the model
# stays on whatever device it was saved from, which either fails to load or
# mismatches the input tensors.
# NOTE(review): torch.load unpickles arbitrary objects -- only load files from
# a trusted source. Recent PyTorch releases default to weights_only=True,
# which rejects whole-model pickles; pass weights_only=False there if the
# installed version requires it (confirm against the version in use).
model = torch.load('complete.pth', map_location=device)

# Binary cross-entropy between the reconstruction and the input.
lossFunction = nn.BCELoss()

# Created after the model is on its final device so the optimizer tracks the
# parameters that are actually trained.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)

In [52]:
def create_checkpoint(model, idx, epoch, loss):
    """Write a full-model pickle and a state checkpoint.

    Two files are written under the module-level ``quicksave_path``:
    ``<idx>.pth`` holds the whole pickled model, and ``<idx>`` (no extension)
    holds a dict with the epoch, the model and optimizer state dicts, and the
    loss. ``idx`` is zero-padded to the number of digits of the module-level
    ``num_epochs``. The optimizer state is read from the module-level
    ``optimizer``.

    Args:
        model: Module to save.
        idx: Integer used as the checkpoint file name.
        epoch: Epoch number stored in the checkpoint.
        loss: Loss value stored in the checkpoint. A tensor is detached and
            moved to the CPU first so the file carries no autograd history
            and can be loaded on a machine without the original device.
    """
    width = len(str(num_epochs))
    filename = os.path.join(quicksave_path, "{:0{}}".format(idx, width))
    torch.save(model, filename + ".pth")

    if torch.is_tensor(loss):
        loss = loss.detach().cpu()

    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        },
        filename,
    )

### Training

In [53]:
start_epoch = 1
end_epoch = start_epoch + num_epochs

for epoch in range(start_epoch, end_epoch):
    # Per-batch loss values for this epoch, kept as plain floats so the
    # loss tensors (and their autograd history) are not retained.
    losses = []

    for (image, _) in loader:
        # Flatten each 28x28 image to a 784-vector; the labels are unused.
        image = image.reshape(-1, 28 * 28).to(device)

        reconstructed = model(image)

        # The input itself is the reconstruction target.
        loss = lossFunction(reconstructed, image)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    # NOTE(review): with start_epoch = 1 and fewer than 10 epochs this branch
    # never runs, so no figure or checkpoint is written -- confirm that is
    # intended.
    if epoch % 10 == 0:
        # Show the first input of the last batch next to its reconstruction.
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
        axes[0].imshow(image[0].reshape(28, 28).to("cpu"), cmap="gray")
        axes[0].axis("off")
        axes[0].set_title("Input Image")
        axes[1].imshow(
            reconstructed[0].detach().to("cpu").numpy().reshape(28, 28), cmap="gray"
        )
        axes[1].axis("off")
        axes[1].set_title("Reconstructed Image")
        plt.tight_layout()
        plt.savefig(os.path.join(folder_path, f"test_{epoch}.png"))  # Save the figure
        plt.show()
        plt.close()

        # Name the checkpoint after the current epoch. Passing num_epochs as
        # the index made every checkpoint overwrite the same file.
        create_checkpoint(model, epoch, epoch, loss)