In [12]:
# !pip install cova
!pip install torch torchvision
!pip install matplotlib
!pip install numpy



In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
import covalent as ct

In [14]:
@ct.electron
def create_linear_ae():
    """
    Build a fully connected autoencoder for flattened 28x28 images.

    Returns:
        A newly initialised ``LinearAutoencoder`` instance.
    """

    class LinearAutoencoder(nn.Module):
        """Fully connected autoencoder: 784 -> 128 -> 64 -> 12 -> 3 -> 12 -> 64 -> 128 -> 784."""

        def __init__(self):
            super().__init__()
            # Widths from the flattened image down to the 3-unit bottleneck;
            # the decoder walks the same widths in reverse.
            widths = [28 * 28, 128, 64, 12, 3]

            def mlp(sizes):
                # Linear layers separated by in-place ReLUs, with no
                # activation after the final Linear.
                layers = []
                for fan_in, fan_out in zip(sizes, sizes[1:]):
                    if layers:
                        layers.append(nn.ReLU(True))
                    layers.append(nn.Linear(fan_in, fan_out))
                return layers

            self.encoder = nn.Sequential(*mlp(widths))
            # Sigmoid keeps reconstructed pixel intensities in [0, 1].
            self.decoder = nn.Sequential(*mlp(widths[::-1]), nn.Sigmoid())

        def forward(self, x):
            return self.decoder(self.encoder(x))

    return LinearAutoencoder()

In [15]:
def create_conv_ae():
    """
    Build a convolutional autoencoder for single-channel 28x28 images.

    Returns:
        A newly initialised ``ConvAutoencoder`` that compresses a 1x28x28
        input to a 64x1x1 code and reconstructs a 1x28x28 output in [0, 1].
    """

    class ConvAutoencoder(nn.Module):
        """Convolutional autoencoder: 1x28x28 -> 64x1x1 -> 1x28x28."""

        def __init__(self):
            super().__init__()
            # Two stride-2 convs halve the spatial size (28 -> 14 -> 7); the
            # final 7x7 kernel collapses the 7x7 map to a single position.
            down = [
                nn.Conv2d(1, 16, 3, stride=2, padding=1),
                nn.ReLU(True),
                nn.Conv2d(16, 32, 3, stride=2, padding=1),
                nn.ReLU(True),
                nn.Conv2d(32, 64, 7),
            ]
            # Mirror of the encoder; output_padding=1 recovers the even sizes
            # (7 -> 14 -> 28) that the stride-2 transposed convs would
            # otherwise undershoot by one.
            up = [
                nn.ConvTranspose2d(64, 32, 7),
                nn.ReLU(True),
                nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
                nn.ReLU(True),
                nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),
                nn.Sigmoid(),  # pixel intensities in [0, 1]
            ]
            self.encoder = nn.Sequential(*down)
            self.decoder = nn.Sequential(*up)

        def forward(self, x):
            return self.decoder(self.encoder(x))

    return ConvAutoencoder()

In [16]:
@ct.electron
def data_loader(
    batch_size: int,
    train: bool,
    download: bool = True,
    shuffle: bool = False,
    transform: transforms.Compose = None,
    root: str = "./data",
) -> torch.utils.data.DataLoader:
    """
    Loads the Fashion MNIST dataset.

    Args:
        batch_size: The batch size.
        train: Whether to load the training or test set.
        download: Whether to download the dataset if it is not under ``root``.
        shuffle: Whether to shuffle the dataset.
        transform: A transform to apply to the dataset. Defaults to
            ``transforms.Compose([transforms.ToTensor()])``.
        root: Directory the dataset is read from / downloaded to. The default
            ``"./data"`` is relative to the working directory of whichever
            process executes this function.

    Returns:
        A DataLoader for the Fashion MNIST dataset.
    """
    if transform is None:
        transform = transforms.Compose([transforms.ToTensor()])

    dataset = datasets.FashionMNIST(
        root, train=train, download=download, transform=transform
    )

    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle
    )

In [17]:
@ct.electron
def train_model(
    model: nn.Module,
    lr: float,
    data_loader: torch.utils.data.DataLoader,
    epochs: int,
    log_interval: int,
):
    """
    Trains the given model on the Fashion MNIST dataset.

    The model is optimised with Adam to minimise the mean squared error
    between its reconstruction and its input. If the model's class is named
    ``LinearAutoencoder``, each batch is flattened to ``(batch, -1)`` first;
    any other model receives the batch unchanged.

    Args:
        model: A PyTorch model. It is trained in place and also returned.
        lr: The learning rate.
        data_loader: A DataLoader for the Fashion MNIST dataset.
        epochs: The number of epochs to train for.
        log_interval: Record an entry in ``outputs`` (and print the loss)
            every ``log_interval`` epochs. Must be >= 1.

    Returns:
        train_loss: The per-sample average training loss for each epoch.
        outputs: A list of ``(epoch, original batch, reconstructed batch)``
            tuples, one per logged epoch, holding the last batch of that
            epoch. The reconstruction is detached from the autograd graph.
        model: The trained model.

    Raises:
        ValueError: If ``log_interval`` < 1, or if ``data_loader`` yields no
            batches during an epoch.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()  # mean squared error, averaged over all elements
    flatten = model.__class__.__name__ == 'LinearAutoencoder'
    model.train()
    outputs = []
    train_loss = []
    print(f"Training {model.__class__.__name__}")
    for epoch in range(1, epochs + 1):
        running_loss = 0.0
        n_samples = 0
        for (data, _) in data_loader:
            if flatten:
                # The linear model consumes flat vectors rather than images.
                data = data.view(data.size(0), -1)
            recon = model(data)
            loss = criterion(recon, data)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight each batch by its size so a smaller final batch does not
            # skew the epoch average.
            running_loss += loss.item() * data.size(0)
            n_samples += data.size(0)
        if n_samples == 0:
            raise ValueError("data_loader yielded no batches")
        epoch_loss = running_loss / n_samples
        train_loss.append(epoch_loss)
        if epoch % log_interval == 0:
            # Detach so the stored reconstruction does not keep a reference to
            # the autograd graph (and can be converted with .numpy() directly).
            outputs.append((epoch, data, recon.detach()))
            print(f"Epoch {epoch}, loss: {epoch_loss}")
    return train_loss, outputs, model

In [18]:
@ct.electron
def test_model(
    model: nn.Module,
    data_loader: torch.utils.data.DataLoader,
):
    """
    Tests the given model on the Fashion MNIST dataset.

    Switches the model to eval mode (and leaves it there), then computes the
    mean squared reconstruction error without tracking gradients. If the
    model's class is named ``LinearAutoencoder``, each batch is flattened to
    ``(batch, -1)`` first.

    Args:
        model: A PyTorch model.
        data_loader: A DataLoader for the Fashion MNIST dataset.

    Returns:
        avg_test_loss: The per-sample average loss over the test set.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    criterion = nn.MSELoss()  # mean squared error, averaged over all elements
    flatten = model.__class__.__name__ == 'LinearAutoencoder'
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for (data, _) in data_loader:
            if flatten:
                # The linear model consumes flat vectors rather than images.
                data = data.view(data.size(0), -1)
            recon = model(data)
            # Weight by batch size so a partial last batch is not over-counted.
            total_loss += criterion(recon, data).item() * data.size(0)
            n_samples += data.size(0)
    if n_samples == 0:
        raise ValueError("data_loader yielded no batches")
    avg_test_loss = total_loss / n_samples
    print(f"Average test loss: {avg_test_loss}")
    return avg_test_loss

In [19]:
@ct.electron
def experiment(
    model: nn.Module,
    epochs: int,
    log_interval: int,
    batch_size_train: int = 64,
    batch_size_test: int = 1000,
    lr: float = 1e-3
):
    """
    Workflow of training and testing a given autoencoder (linear or convolutional) on the Fashion MNIST dataset.

    The training set is reshuffled every epoch; the test set is read in a
    fixed order.

    Args:
        model: A PyTorch model.
        epochs: The number of epochs to train for.
        log_interval: The number of epochs to wait before logging in the outputs.
        batch_size_train: The batch size for the training set.
        batch_size_test: The batch size for the test set.
        lr: The learning rate.

    Returns:
        train_loss: The training loss at each epoch.
        avg_test_loss: The average loss for the test set.
        outputs: ``(epoch, original batch, reconstructed batch)`` tuples
            recorded every ``log_interval`` epochs during training.
    """
    # Shuffle the training data so each epoch sees the batches in a new
    # order; evaluation order does not affect the test loss.
    train_loader = data_loader(batch_size=batch_size_train, train=True, shuffle=True)
    test_loader = data_loader(batch_size=batch_size_test, train=False)

    train_loss, outputs, model = train_model(
        model=model,
        lr=lr,
        data_loader=train_loader,
        epochs=epochs,
        log_interval=log_interval,
    )
    avg_test_loss = test_model(model=model, data_loader=test_loader)

    return train_loss, avg_test_loss, outputs

In [20]:
@ct.lattice
def run_experiments(
    models: list,
    epochs: int,
    log_interval: int,
    batch_size_train: int = 64,
    batch_size_test: int = 1000,
    lr: float = 1e-3
):
    """
    Run experiments of training and testing a series of autoencoders on the Fashion MNIST dataset.

    Each model is passed, in order, to ``experiment`` with the same
    hyperparameters, and the per-model results are collected into lists.

    Args:
        models: A list of PyTorch models.
        epochs: The number of epochs to train each model for.
        log_interval: The number of epochs to wait before logging in the outputs.
        batch_size_train: The batch size for the training set.
        batch_size_test: The batch size for the test set.
        lr: The learning rate.

    Returns:
        train_losses: A list with one entry per model: that model's
            per-epoch training losses.
        avg_test_losses: A list with one entry per model: that model's
            average test loss.
        full_outputs: A list with one entry per model: that model's logged
            ``outputs`` from training.
        All three lists are in the same order as ``models``.
    """
    train_losses = []
    avg_test_losses = []
    full_outputs = []
    for model in models:
        # NOTE(review): inside a Covalent lattice this call presumably yields
        # an electron result rather than a plain tuple; keep the three-name
        # unpacking form unless Covalent's unpacking support is confirmed.
        train_loss, avg_test_loss, outputs = experiment(
            model=model,
            epochs=epochs,
            log_interval=log_interval,
            batch_size_train=batch_size_train,
            batch_size_test=batch_size_test,
            lr=lr,
        )
        train_losses.append(train_loss)
        avg_test_losses.append(avg_test_loss)
        full_outputs.append(outputs)

    return train_losses, avg_test_losses, full_outputs

In [21]:
# Submit both autoencoders as a single Covalent workflow, then block
# (wait=True) until the dispatch finishes and fetch its result.
dispatch_id = ct.dispatch(run_experiments)(
    models=[create_linear_ae(), create_conv_ae()],
    epochs=50,
    log_interval=10,
)
results = ct.get_result(dispatch_id=dispatch_id, wait=True)