# section 1.2.1

In [None]:
import utils
import torch
import encoder1
import random
import numpy as np
from torch.utils.data import DataLoader, random_split
from utils import plot_tsne
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import random
MNIST = False

In [None]:

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
autoencoder = encoder1.Autoencoder(mnist=MNIST).to(device)
# The classifier is built on top of the autoencoder's encoder module (shared object).
classifier = encoder1.Classifier(autoencoder.encoder, num_classes=10).to(device)
def freeze_seeds(seed=0):
    """Seed the global RNGs of Python's ``random``, NumPy and PyTorch with ``seed``.

    ``torch.manual_seed`` seeds PyTorch's default generator on every device,
    so no separate CUDA call is needed.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

freeze_seeds(0)

# Normalize((x - 0.5) / 0.5) maps ToTensor's [0, 1] pixel range to [-1, 1],
# one (mean, std) entry per channel: 1 for MNIST, 3 for CIFAR-10.
num_channels = 1 if MNIST else 3
mean = [0.5] * num_channels
std = [0.5] * num_channels
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])

# MNIST is expected to already be on disk; CIFAR-10 is fetched if missing.
data_root = "/datasets/cv_datasets/data"
dataset_class, download = (datasets.MNIST, False) if MNIST else (datasets.CIFAR10, True)
train_dataset = dataset_class(root=data_root, train=True, download=download, transform=transform)
test_dataset = dataset_class(root=data_root, train=False, download=download, transform=transform)
         

: 

In [None]:
batch_size = 8
validation_split = 0.2

# Hold out a fraction of the training set for validation. No explicit generator
# is passed, so the split is drawn from torch's global RNG.
n_train_total = len(train_dataset)
validation_size = int(n_train_total * validation_split)
train_size = n_train_total - validation_size
train_subset, val_subset = random_split(train_dataset, [train_size, validation_size])

# Settings shared by every loader; only the training loader reshuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(train_subset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_subset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

: 

In [None]:
def test(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    print(f'Accuracy of the network on the test images: {accuracy * 100:.2f}%')
    return accuracy


In [None]:
# Plotting helpers for visually inspecting the trained autoencoder.
def tensor_to_image(tensor, mean=0.5, std=0.5):
    """Convert a normalized ``(C, H, W)`` tensor into something ``imshow`` accepts.

    Undoes ``Normalize(mean, std)`` via ``x * std + mean`` and clamps to [0, 1],
    so matplotlib does not clip the negative half of the range to black.
    For a single-channel tensor, it drops the channel axis, because matplotlib
    rejects ``(H, W, 1)`` arrays.

    Args:
        tensor: image tensor with the channel axis first.
        mean, std: scalars or per-channel sequences matching the ``Normalize``
            transform used on the data.

    Returns:
        ``(image, cmap)``: a CPU tensor of shape ``(H, W)`` or ``(H, W, C)``, and
        the colormap to pass to ``imshow`` (``'gray'`` for one channel, else None).
    """
    image = tensor.detach().cpu()
    # view(-1, 1, 1) broadcasts a scalar or per-channel list over (C, H, W).
    mean_t = torch.as_tensor(mean, dtype=image.dtype).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=image.dtype).view(-1, 1, 1)
    image = (image * std_t + mean_t).clamp(0.0, 1.0)
    if image.size(0) == 1:
        return image[0], 'gray'
    return image.permute(1, 2, 0), None


def plot_reconstructions(autoencoder, data_loader, device, save_path='reconstructions.png',
                         num_images=5, mean=0.5, std=0.5):
    """Save a figure comparing random inputs with their reconstructions.

    Takes the first batch of ``data_loader``, runs it through ``autoencoder``
    (in eval mode) and plots ``num_images`` randomly chosen
    (original, reconstruction) pairs, one pair per row.

    Args:
        autoencoder: model whose forward pass returns reconstructions of its input.
        data_loader: iterable of ``(images, labels)`` batches.
        device: device the batch is moved to.
        save_path: destination file for the figure.
        num_images: number of pairs to plot; capped at the batch size.
        mean, std: ``Normalize`` parameters undone before display. Reconstructions
            are assumed to be in the same normalized space as the inputs;
            confirm this against the autoencoder's output layer.
    """
    autoencoder.eval()
    images, _ = next(iter(data_loader))
    images = images.to(device)
    with torch.no_grad():
        reconstructions = autoencoder(images)

    # random.sample raises if asked for more items than the batch holds.
    num_images = min(num_images, len(images))
    indices = random.sample(range(len(images)), num_images)

    # squeeze=False keeps `axes` 2-D even when only one row is plotted.
    fig, axes = plt.subplots(num_images, 2, figsize=(10, 3 * num_images), squeeze=False)
    for row, idx in enumerate(indices):
        pair = ((images[idx], "Original"), (reconstructions[idx], "Reconstructed"))
        for col, (tensor, title) in enumerate(pair):
            image, cmap = tensor_to_image(tensor, mean, std)
            ax = axes[row, col]
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close(fig)
    print(f"Reconstructions saved to {save_path}")


def interpolate_and_plot(autoencoder, data_loader, device, save_path='interpolation.png',
                         steps=10, mean=0.5, std=0.5):
    """Save a figure of decoded linear interpolations between two latent codes.

    Picks two random images from the first batch of ``data_loader``, encodes
    both, linearly blends the codes over ``steps`` evenly spaced weights in
    [0, 1], and decodes every blend. The two originals are shown at either end.

    Args:
        autoencoder: model exposing ``encoder`` and ``decoder`` submodules.
        data_loader: iterable of ``(images, labels)`` batches; the first batch
            must hold at least two images.
        device: device the batch is moved to.
        save_path: destination file for the figure.
        steps: number of interpolation points, endpoints included.
        mean, std: ``Normalize`` parameters undone before display.
    """
    autoencoder.eval()
    images, _ = next(iter(data_loader))
    images = images.to(device)

    first, second = random.sample(range(len(images)), 2)
    image1 = images[first].unsqueeze(0)  # keep a batch dimension of 1
    image2 = images[second].unsqueeze(0)

    with torch.no_grad():
        latent1 = autoencoder.encoder(image1)
        latent2 = autoencoder.encoder(image2)
        # One blend weight per step, shaped (steps, 1, ..., 1) to broadcast over the
        # latent's non-batch dims. Since the latents have batch size 1, this yields a
        # batch of `steps` codes that is decoded in a single call.
        alphas = torch.linspace(0, 1, steps, device=latent1.device, dtype=latent1.dtype)
        alphas = alphas.view(-1, *([1] * (latent1.dim() - 1)))
        decoded = autoencoder.decoder((1 - alphas) * latent1 + alphas * latent2)

    panels = [(image1[0], "Image 1")]
    panels += [(img, f"Step {i + 1}") for i, img in enumerate(decoded)]
    panels.append((image2[0], "Image 2"))

    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, (tensor, title) in zip(axes, panels):
        image, cmap = tensor_to_image(tensor, mean, std)
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close(fig)
    print(f"Interpolation saved to {save_path}")

def compute_mean_reconstruction_error(autoencoder, data_loader, device):
    """Print and return the mean per-sample absolute reconstruction error (MAE).

    For each sample, the absolute difference between input and reconstruction is
    averaged over all non-batch dimensions. Those per-sample values are then
    averaged over every sample yielded by ``data_loader``. Puts ``autoencoder``
    into eval mode.

    Args:
        autoencoder: model whose forward pass returns a tensor shaped like its input.
        data_loader: iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: device each batch is moved to.

    Returns:
        The mean reconstruction error as a Python float.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    autoencoder.eval()
    total_error = 0.0
    total_samples = 0
    with torch.no_grad():
        for images, _ in data_loader:
            images = images.to(device)
            reconstructions = autoencoder(images)
            # flatten(1) averages over all non-batch dims, whatever the input rank.
            per_sample = (images - reconstructions).abs().flatten(start_dim=1).mean(dim=1)
            total_error += per_sample.sum().item()
            total_samples += images.size(0)
    if total_samples == 0:
        raise ValueError("data_loader yielded no samples; cannot compute reconstruction error")
    mean_error = total_error / total_samples
    print(f"Mean Reconstruction Error (MAE): {mean_error:.4f}")
    return mean_error

: 

In [None]:
# Load the checkpoint pair matching the dataset selected by MNIST.
checkpoint_tag = 'MNIST' if MNIST else 'CIFAR10'
# map_location places every tensor on `device`. Without it, torch.load restores
# tensors to the device they were saved from, so GPU-trained checkpoints fail to
# load on a CPU-only machine.
autoencoder.load_state_dict(torch.load(f'autoencoder_{checkpoint_tag}.pth', map_location=device))
classifier.load_state_dict(torch.load(f'classifier_{checkpoint_tag}.pth', map_location=device))

# Evaluate the loaded models on the test split; uncomment the other calls as needed.
#plot_reconstructions(autoencoder, test_loader, device, save_path='reconstructions.png')
compute_mean_reconstruction_error(autoencoder, test_loader, device)
#test(classifier, test_loader, device)