# Imports

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np

from Modules import Autoencoder
from Modules import LinearAutoencoder
from LossFunctions import tan_square, log_product

# Device

In [None]:
# Select the compute device: use CUDA when a GPU is visible, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# Uncomment to force CPU execution even when a GPU is available:
# device = 'cpu'
print(f"Using {device} device")

# Dataset

In [None]:
# We only train autoencoders here, so the labels are never used for learning.
# Each image is squeezed (dropping its single channel dimension) and its two
# spatial dimensions are flattened into a single vector.
def _flatten_image(image):
    """Squeeze away size-1 dims of a ToTensor image, then flatten its last two dims.

    This is a named module-level function rather than a lambda so the transform
    (and therefore the dataset) stays picklable. DataLoader needs that when
    ``num_workers > 0`` with the spawn start method (the default on Windows).
    """
    return torch.flatten(torch.squeeze(image), start_dim=-2)


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_flatten_image)])

train_dataset = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transform
)

test_dataset = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=transform
)

# Train and test data tensors loaded on the GPU
# train_data = train_dataset.data.flatten(start_dim=-2).to(device)
# test_data = test_dataset.data.flatten(start_dim=-2).to(device)

In [None]:
"""figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(train_dataset), size=(1,)).item()
    img, label = train_dataset[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()"""

In [None]:
batch_size = 8

# Create data loaders. The training set is reshuffled every epoch so mini-batches
# are not presented in the same fixed order each time. Evaluation order does not
# affect the average test loss, so the test loader stays unshuffled.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

# Peek at a single batch to check the input shape the model will receive.
for X, _ in train_dataloader:
    print(f"Shape of X: {X.shape}")
    break

# Identity Transformer

In [None]:
class IdentityTransformer(nn.Module):
    """Transformer trained to reproduce its input, with each image row as a token.

    A 28 x 28 picture is a "sentence" of S = 28 tokens of size E = 28. Because the
    transformer is built with ``batch_first=True``, it sees tensors of shape (N, S, E).
    """

    def __init__(self) -> None:
        super().__init__()
        # Registering the encoding as an nn.Parameter makes .to(device)/.cuda() move it with the module:
        # https://stackoverflow.com/questions/60908827/why-pytorch-nn-module-cuda-not-moving-module-tensor-but-only-parameters-and-bu
        # One learnable vector per row position, shape (S, E) = (28, 28). A (28,)
        # vector would broadcast over the feature axis and add the same offset to
        # every row, so it could not tell positions apart. nn.Parameter already
        # requires grad, so no explicit requires_grad is needed.
        self.positional_encoding = nn.Parameter(torch.rand((28, 28)))
        self.transformer = nn.Transformer(
            d_model=28,
            nhead=1,
            num_encoder_layers=3,
            num_decoder_layers=3,
            batch_first=True)  # https://discuss.pytorch.org/t/why-is-sequence-batch-features-the-default-instead-of-bxsxf/8244

    def forward(self, input):
        """Reconstruct a batch of 28 x 28 pictures.

        Accepts flat images (N, 784), (N, 28, 28) or (N, 1, 28, 28). Each picture
        becomes (S, E) = (28, 28), so every row is a token. The output is returned
        in the same shape as *input*, so a loss can compare it with the input directly.
        """
        original_shape = input.shape
        # reshape (not squeeze) keeps the batch axis even when N == 1
        tokens = input.reshape(-1, 28, 28)
        tokens = tokens + self.positional_encoding  # broadcast over the batch axis
        output = self.transformer(tokens, tokens)
        return output.reshape(original_shape)

# identity_transformer = IdentityTransformer().to(device)

# loss_function = nn.MSELoss()
# optimizer = torch.optim.Adam(identity_transformer.parameters(), lr=1e-4) # TODO tune the learning rate during training or adjust it with a scheduler
# https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

# Autoencoder

In [None]:
# Autoencoder over the flattened 28*28 image, moved to the selected device.
autoencoder = Autoencoder(
    dim_sequence=28 * 28,
    dim_bottleneck=2,
    dim_positional_encoding=2,
    num_heads=2,
    num_layers=2,
).to(device)
# linear_autoencoder = LinearAutoencoder(dim_input=28*28, dim_bottleneck=2, step= 10, activation_function='ReLU').to(device)

# Reconstruction objective (log_product from LossFunctions is an alternative).
loss_function = nn.MSELoss()
# loss_function = log_product
optimizer = torch.optim.Adam(
    autoencoder.parameters(),
    lr=1e-3,
    weight_decay=0.01,
)
# TODO: there are many hyperparameters left to fine-tune here.

# Train

In [None]:
def train_autoencoder(dataloader: DataLoader, model, loss_function, optimizer):
    """Run one training epoch of *model* as an autoencoder.

    Each batch is passed through the model, and the loss compares the
    reconstruction with the batch itself. The dataloader yields ``(data, label)``
    pairs; the label is ignored. Progress is printed every 100 batches.

    TODO: make the dataloader yield only the images, without the labels.
    """
    total = len(dataloader.dataset)
    model.train()
    for step, (images, _labels) in enumerate(dataloader):
        images = images.to(device)

        # Reconstruction error of the batch against itself.
        reconstruction = model(images)
        batch_loss = loss_function(reconstruction, images)

        # Backpropagation and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if step % 100 == 0:
            seen = step * len(images)
            print(f"loss: {batch_loss.item():>7f}  [{seen:>5d}/{total:>5d}]")

def test_autoencoder(dataloader, model, loss_function):
    """Evaluate *model*'s reconstruction loss over *dataloader*.

    Runs in eval mode without gradient tracking and ignores the labels the
    dataloader yields. The per-batch losses are averaged over the number of
    batches, then the average is printed and returned. An empty dataloader gives
    ``nan`` instead of raising ZeroDivisionError.
    """
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for X, _ in dataloader:
            X = X.to(device)
            pred = model(X)
            test_loss += loss_function(pred, X).item()
    test_loss = test_loss / num_batches if num_batches else float("nan")
    print(f"Test Error: Avg loss: {test_loss:>8f} \n")
    return test_loss

def train_linear_autoencoder(data, model, loss_function, optimizer):
    """Take one full-batch optimization step that reconstructs *data* with *model*.

    The model is put in training mode first, because an earlier evaluation pass
    may have left it in eval mode. The loss is printed as a plain number (not a
    tensor carrying its autograd history) and returned.
    """
    model.train()
    pred = model(data)
    loss = loss_function(pred, data)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    print("Train loss:", loss_value)
    return loss_value

def test_linear_autoencoder(data, model, loss_function):
    with torch.no_grad():
        pred = model(data)
        loss = loss_function(pred, data)

    print("Test loss:", loss)

# Train loop

In [None]:
# Train linear autoencoder (full-batch: the whole training set in one step per epoch).
# Tensors are built from the raw MNIST pixels (assumed to be the uint8 `.data`
# attribute of torchvision's MNIST). Dividing by 255 matches the [0, 1] float
# images that ToTensor produces in the DataLoader pipeline, and MSELoss needs
# floating-point input.
train_data = train_dataset.data.flatten(start_dim=-2).float().div(255).to(device)
test_data = test_dataset.data.flatten(start_dim=-2).float().div(255).to(device)

linear_autoencoder = LinearAutoencoder(dim_input=28*28, dim_bottleneck=2, step=10, activation_function='ReLU').to(device)
# The shared `optimizer` only tracks the transformer autoencoder's parameters, so
# the linear model needs its own optimizer or it would never be updated.
linear_optimizer = torch.optim.Adam(linear_autoencoder.parameters(), lr=1e-3, weight_decay=0.01)

epochs = 10
for t in range(epochs):
    print("Epoch:", t)
    train_linear_autoencoder(train_data, linear_autoencoder, loss_function, linear_optimizer)
    test_linear_autoencoder(test_data, linear_autoencoder, loss_function)

In [None]:
# Each epoch: one optimization pass over the training set, then one evaluation
# pass over the test set. The printed epoch number is 1-based.
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_autoencoder(
        dataloader=train_dataloader,
        model=autoencoder,
        loss_function=loss_function,
        optimizer=optimizer,
    )
    test_autoencoder(
        dataloader=test_dataloader,
        model=autoencoder,
        loss_function=loss_function,
    )
print("Done!")

In [None]:
# Show n random test images (left column) next to the autoencoder's reconstruction (right column).
n = 5
fig, axs = plt.subplots(n, 2, figsize=(2, 5))
# Inference only: disable gradient tracking and training-mode behaviour, then
# restore the model's previous mode afterwards.
was_training = autoencoder.training
autoencoder.eval()
try:
    with torch.no_grad():
        for i in range(n):
            # Input. The dataset transform flattens each image to 28*28 values,
            # so it is reshaped back to 28 x 28 for display.
            random_index = np.random.randint(len(test_dataset))
            img, _label = test_dataset[random_index]
            ax = axs[i, 0]
            ax.imshow(img.reshape(28, 28).numpy(), cmap='gray')
            ax.axis("off")
            if i == 0: ax.set_title('Input')

            # Output. A batch dimension is added so the model gets (1, 784), the
            # same (batch, 784) layout the DataLoader feeds it during training.
            # Assumes the output also holds 28*28 values.
            input_image = img.to(device).unsqueeze(0)
            output_image = autoencoder(input_image)

            ax = axs[i, 1]
            ax.imshow(output_image.reshape(28, 28).cpu().numpy(), cmap="gray")
            ax.axis("off")
            if i == 0: ax.set_title('Output')
finally:
    autoencoder.train(was_training)
plt.subplots_adjust(wspace=0.0)
plt.show()

In [None]:
# TODO wrap training in a "try" block (or context manager) so that, when it exits, the model's current state, the training curve, etc. are saved.
# Check that if Windows decides to install an update, Python is shut down cleanly and given enough time to save everything.

# ---

In [None]:
import gc

# Release memory held by the experiment.
# - Global references are dropped. The optimizer keeps references to the
#   autoencoder's parameters, so it has to go too, or deleting the model frees nothing.
# - pop(..., None) skips names that were never defined. `y` is never assigned in
#   this notebook, so a plain `del(y)` raises NameError.
# - gc then collects the garbage, and PyTorch releases its cached but unused CUDA memory.
for _name in ("X", "y", "img", "input_image", "output_image",
              "autoencoder", "optimizer",
              "linear_autoencoder", "linear_optimizer",
              "train_data", "test_data"):
    globals().pop(_name, None)
del _name
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

