In [1]:
import os
from dataclasses import dataclass
from typing import Callable, Iterator, List, Optional

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms

Let's define one class that stores the model hyperparameters, which makes it easier to experiment with different settings, and another one that stores the parameters describing the input data format.

In [2]:
@dataclass
class ModelArgs:
    """Training hyperparameters and runtime settings shared by the notebook cells."""
    # Device the model and the batches are moved to: CUDA when available, else CPU.
    # The default is evaluated once, when this class body is executed.
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Batch size handed to the DataLoaders.
    batch_size: int = 256
    # Number of passes over the training loader.
    epochs: int = 20
    # Presumably the optimizer step size (`lr`); this class does not define the
    # optimizer itself — confirm against the training code.
    learning_rate: float = 0.001
    # Loss between the reconstruction and the input batch. A single L1Loss
    # instance is created here and used as the shared default.
    criterion: nn.Module = nn.L1Loss()


@dataclass
class InputArgs:
    """Parameters describing where the input frames live and how they are cropped and split."""
    # Image-to-tensor conversion.
    # NOTE(review): process_images does not take this transform as an argument —
    # confirm where it is meant to be applied.
    transform: Callable = transforms.Compose([
        transforms.ToTensor()  # Transform to [0, 1] range
    ])
    # Built from the working directory at the time this class body is executed.
    source_dir: str = os.getcwd() + "/data/example" # Directory where data is stored
    test_size: float = 0.2  # Ratio for train-test split
    # A fraction (0.2 == 20 %), not a 0-100 value: the crop helpers use
    # int(min(height, width) * crop_percentage) as the side of the square crop.
    crop_percentage: float = 0.2  # Percentage of the image to crop
    # Presumably the number of consecutive frames grouped per sample (cf. load_images);
    # it is not among the arguments of process_images — confirm intended use.
    depth: int = 1
    use_grid: bool = False  # Whether to divide frames into grid or center-crop


First, we will prepare the data.

In [3]:

def get_sorted_image_paths(source_dir: str) -> List[str]:
    """
    Collect every .jpg file under ``source_dir`` (searched recursively).

    Args:
        source_dir: Root directory to search.

    Returns:
        Full paths of all files whose name ends in ``.jpg`` (case-insensitive),
        sorted lexicographically by full path. The list is empty when the
        directory does not exist or holds no matching file.
    """
    image_paths = [
        os.path.join(root, name)
        for root, _dirs, files in os.walk(source_dir)
        for name in files
        if name.lower().endswith('.jpg')
    ]
    # os.walk visits sub-directories in an arbitrary, OS-dependent order, so
    # sorting the file names of each directory separately is not enough:
    # sort the complete list to get the same ordering on every machine.
    image_paths.sort()
    return image_paths

def load_images(image_paths: List[str], depth: int) -> Iterator[torch.Tensor]:
    """
    Lazily load images and yield them in consecutive groups of ``depth``.

    Args:
        image_paths: Paths of the images to load, in the order they are grouped.
        depth: Number of images stacked into each yielded tensor; must be >= 1.

    Yields:
        Tensors built with ``torch.stack`` from ``depth`` consecutive images,
        each converted to RGB and passed through ``transforms.ToTensor``.
        When the number of paths is not a multiple of ``depth`` the last
        yielded tensor holds only the remaining images.

    Raises:
        ValueError: If ``depth`` is smaller than 1. Because this function is a
            generator, the error surfaces when iteration starts.
    """
    if depth < 1:
        # Without this check a non-positive depth would never match the group
        # size and silently lump every image into one final tensor.
        raise ValueError(f"depth must be >= 1, got {depth}")

    images = []
    for path in image_paths:
        # The context manager releases the file handle once the pixels are decoded.
        with Image.open(path) as img:
            images.append(transforms.ToTensor()(img.convert('RGB')))
        if len(images) == depth:
            yield torch.stack(images)
            images = []
    if images:
        yield torch.stack(images)

def crop_into_grid(image: torch.Tensor, crop_percentage: float) -> torch.Tensor:
    """
    Tile an image with non-overlapping square patches.

    The patch side is ``int(min(height, width) * crop_percentage)``. Patches are
    collected row by row starting at the top-left corner; a strip at the right
    or bottom edge that is too small for a full patch is left out.

    Returns a tensor of shape (num_patches, channels, side, side).
    """
    height, width = image.shape[1:]
    side = int(min(height, width) * crop_percentage)

    # Top-left corners of the patches along each axis.
    row_starts = range(0, height - side + 1, side)
    col_starts = range(0, width - side + 1, side)

    patches = [
        image[:, top:top + side, left:left + side]
        for top in row_starts
        for left in col_starts
    ]
    return torch.stack(patches)

def crop_center(image: torch.Tensor, crop_percentage: float) -> torch.Tensor:
    """
    Cut a square patch out of the middle of an image.

    The patch side is ``int(min(height, width) * crop_percentage)`` and the patch
    is centred on both axes (offsets are rounded down).

    Returns a tensor of shape (channels, side, side).
    """
    height, width = image.shape[1:]
    side = int(min(height, width) * crop_percentage)

    row_offset = (height - side) // 2
    col_offset = (width - side) // 2
    rows = slice(row_offset, row_offset + side)
    cols = slice(col_offset, col_offset + side)
    return image[:, rows, cols]

class ImageDataset(Dataset):
    """
    Dataset that loads RGB images from disk on demand.

    Each item is the image at the corresponding path, converted to RGB, turned
    into a tensor with ``transforms.ToTensor`` and then passed through the
    optional ``transform``. Items carry no label.
    """

    def __init__(self, image_paths: List[str], transform: Optional[Callable] = None):
        """
        Args:
            image_paths: Paths of the image files, one dataset item per path.
            transform: Optional callable applied to the image *tensor* (not to
                the PIL image), for example one of the crop helpers.
        """
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self) -> int:
        """Number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return it after ``transform``."""
        # The context manager releases the file handle once the pixels are decoded.
        with Image.open(self.image_paths[idx]) as img:
            img_tensor = transforms.ToTensor()(img.convert('RGB'))
        if self.transform is not None:
            img_tensor = self.transform(img_tensor)
        return img_tensor

def process_images(source_dir: str, crop_percentage: float, batch_size: int, test_size: float, use_grid: bool,
                   seed: Optional[int] = None):
    """
    Build train and test DataLoaders from the .jpg images under ``source_dir``.

    Args:
        source_dir: Directory searched recursively for .jpg files.
        crop_percentage: Fraction of the shorter image side used as the side of
            the square crop(s).
        batch_size: Batch size of both loaders.
        test_size: Fraction of the images assigned to the test split, in [0, 1].
        use_grid: If True every image is tiled with ``crop_into_grid``,
            otherwise a single patch is taken with ``crop_center``.
        seed: Optional seed for the random train/test split. ``None`` (default)
            keeps the previous behaviour of drawing from the global RNG.

    Returns:
        ``(train_loader, test_loader)``; the training loader shuffles, the test
        loader does not.

    Raises:
        ValueError: If no .jpg image is found or ``test_size`` is outside [0, 1].
    """
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(f"test_size must be within [0, 1], got {test_size}")

    image_paths = get_sorted_image_paths(source_dir)
    if not image_paths:
        # Fail here with a clear message instead of deep inside the DataLoader.
        raise ValueError(f"No .jpg images found under '{source_dir}'")

    crop = crop_into_grid if use_grid else crop_center
    dataset = ImageDataset(image_paths, transform=lambda x: crop(x, crop_percentage))

    # Split sizes; the ratio argument is kept intact rather than re-bound to a count.
    total_size = len(dataset)
    num_test = int(total_size * test_size)
    num_train = total_size - num_test

    # Only pass a generator when a seed was requested, so the default call is unchanged.
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_dataset, test_dataset = random_split(dataset, [num_train, num_test], **split_kwargs)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

# TODO: depth is currently always 1 - add support for depth > 1.
# (For now we only consider the spatial neighbourhood at a single point in time, without looking at past frames.)
# TODO: add the option to keep the original (or close to original) image size, but at a lower resolution.

# Example usage:
# Instantiate the configuration objects instead of aliasing the dataclass types,
# so that changing a field for one experiment does not alter the class-wide default.
params = InputArgs()
model_params = ModelArgs()
train_loader, test_loader = process_images(params.source_dir, params.crop_percentage, model_params.batch_size, params.test_size, params.use_grid)

In [5]:
class Autoencoder(nn.Module):
    """
    Convolutional autoencoder for 3-channel images with values in [0, 1].

    The encoder is a stack of ``Conv2d -> BatchNorm2d -> ReLU`` stages. The decoder
    mirrors it stage by stage with ``ConvTranspose2d`` layers that reuse the kernel
    size, stride and padding of the encoder stage they undo, and ends in a
    ``Sigmoid``. With the default hyper-parameters (kernel 4, stride 2, padding 1)
    every stage halves / doubles the spatial size, so the reconstruction has the
    same shape as the input whenever the input height and width are divisible by
    ``2 ** (len(sizes) - 1)``.

    Args:
        sizes: Output channels of the encoder stages. One stage is built for every
            entry except the last one, which is not used (kept this way for
            backward compatibility). Needs at least two entries.
        kernel_sizes: Kernel size per stage; defaults to 4 for every stage.
        strides: Stride per stage; defaults to 2 for every stage.
        paddings: Padding per stage; defaults to 1 for every stage.

    Raises:
        ValueError: If ``sizes`` has fewer than two entries or one of the per-stage
            lists is shorter than the number of stages.
    """

    def __init__(self, sizes: Optional[List[int]] = None, kernel_sizes: Optional[List[int]] = None,
                 strides: Optional[List[int]] = None, paddings: Optional[List[int]] = None):
        super().__init__()

        # Default values for sizes, kernel sizes, strides, and paddings if not provided
        sizes = sizes or [64, 128, 256, 512, 1024, 2048, 4096]
        num_stages = len(sizes) - 1
        if num_stages < 1:
            raise ValueError("sizes must contain at least two entries")
        kernel_sizes = kernel_sizes or [4] * num_stages
        strides = strides or [2] * num_stages
        paddings = paddings or [1] * num_stages
        for label, values in (("kernel_sizes", kernel_sizes), ("strides", strides), ("paddings", paddings)):
            if len(values) < num_stages:
                raise ValueError(f"{label} needs at least {num_stages} entries, got {len(values)}")

        image_channels = 3  # RGB input
        # TODO: past frames should contribute additional input channels

        # Encoder: image_channels -> sizes[0] -> ... -> sizes[num_stages - 1]
        encoder_layers = []
        in_channels = image_channels
        for i in range(num_stages):
            encoder_layers += [
                nn.Conv2d(in_channels, sizes[i], kernel_size=kernel_sizes[i], stride=strides[i], padding=paddings[i]),
                nn.BatchNorm2d(sizes[i]),
                nn.ReLU(),
            ]
            in_channels = sizes[i]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: exact mirror of the encoder. It has one transposed convolution per
        # encoder convolution; an extra upsampling stage would make the output larger
        # than the input and the reconstruction loss could not be computed.
        decoder_layers = []
        for i in range(num_stages - 1, 0, -1):
            decoder_layers += [
                nn.ConvTranspose2d(sizes[i], sizes[i - 1], kernel_size=kernel_sizes[i], stride=strides[i], padding=paddings[i]),
                nn.BatchNorm2d(sizes[i - 1]),
                nn.ReLU(),
            ]
        # Last stage undoes the first encoder stage and maps back to image channels.
        decoder_layers.append(
            nn.ConvTranspose2d(sizes[0], image_channels, kernel_size=kernel_sizes[0], stride=strides[0], padding=paddings[0])
        )
        decoder_layers.append(nn.Sigmoid())  # Final output layer - outputs from the range [0,1]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

    def get_latent_space(self, x):
        """
        Extract the latent space representation from the encoder.

        Runs under ``torch.no_grad()``, so the result carries no gradient. The
        module's train/eval mode is not changed here; call ``eval()`` beforehand
        if batch normalisation should use its running statistics.
        """
        with torch.no_grad():
            latent = self.encoder(x)
        return latent


Training the model:

In [None]:
# With default architecture
model = Autoencoder().to(model_params.device)
# ModelArgs defines no optimizer, so build one here from the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=model_params.learning_rate)
model.train()

for epoch in range(model_params.epochs):
    train_loss = 0.0
    for images in train_loader:
        images = images.to(model_params.device)
        # Grid cropping yields (batch, tiles, C, H, W) batches while Conv2d expects
        # 4-D input: fold the tile axis into the batch axis. This is a no-op for
        # the (batch, C, H, W) batches produced by center cropping.
        images = images.reshape(-1, *images.shape[-3:])

        outputs = model(images)
        loss = model_params.criterion(outputs, images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Mean of the per-batch losses over the epoch.
    train_loss /= len(train_loader)
    print(f'Epoch [{epoch+1}/{model_params.epochs}], Loss: {train_loss:.4f}')

Testing and visualisation:

In [None]:
model.eval()
with torch.no_grad():
    # The dataset yields bare image tensors without labels, so a batch is a single
    # tensor and must not be unpacked into (images, labels).
    examples = next(iter(test_loader))
    test_images = examples.to(model_params.device)
    # Fold a possible grid/tile axis into the batch axis (no-op for 4-D batches).
    test_images = test_images.reshape(-1, *test_images.shape[-3:])
    reconstructed = model(test_images)

    test_images = test_images.clamp(0, 1).cpu()
    reconstructed = reconstructed.clamp(0, 1).cpu()

# Visualisation of the reconstruction
n = min(4, len(test_images))  # the batch may hold fewer than four images
# squeeze=False keeps `axes` two-dimensional even when only one column is drawn.
fig, axes = plt.subplots(2, n, figsize=(20, 10), squeeze=False)
for i in range(n):
    # Original cropped images (channels-last layout for imshow)
    axes[0, i].imshow(test_images[i].permute(1, 2, 0))
    axes[0, i].set_title("Original")
    axes[0, i].axis('off')

    # Reconstructed images
    axes[1, i].imshow(reconstructed[i].permute(1, 2, 0))
    axes[1, i].set_title("Reconstruction")
    axes[1, i].axis('off')

plt.show()

TODO: Gather statistics and test out different parameter settings.