# Image classification task

## Setup

In [1]:
# When True, run on the CPU even if a CUDA device is available
FORCE_CPU: bool = True

# Seed for torch's global RNG and for the dedicated data generator
SEED: int = 349287

In [2]:
import os

import torch
from torch.utils.data import DataLoader as TorchDataLoader

import torchvision
from torchvision.datasets import VisionDataset, MNIST, CIFAR10, FashionMNIST
from torchvision import transforms

from tensordict.nn import (
    TensorDictModule,
    TensorDictModuleBase,
    TensorDictSequential,
    ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import CompositeDistribution
from tensordict.tensordict import TensorDict, TensorDictBase

from tqdm import tqdm

from pvg import Parameters, ImageClassificationParameters, ScenarioType, TrainerType
from pvg.experiment_settings import ExperimentSettings
from pvg.scenario_base import DataLoader, Dataset
from pvg.constants import IC_DATA_DIR

In [3]:
# Seed the global torch RNG, then build a separate generator seeded the same way
torch.manual_seed(SEED)
torch_generator = torch.Generator()
torch_generator.manual_seed(SEED)

In [4]:
# Use CUDA only when it is not disabled by FORCE_CPU and a GPU is actually present
device = torch.device(
    "cuda" if not FORCE_CPU and torch.cuda.is_available() else "cpu"
)
print(device)

cpu


## Parameters

In [5]:
# Solo-agent training on the image classification scenario, restricted to two
# Fashion-MNIST classes (0 and 2)
params = Parameters(
    dataset="fashion-mnist",
    scenario=ScenarioType.IMAGE_CLASSIFICATION,
    trainer=TrainerType.SOLO_AGENT,
    image_classification=ImageClassificationParameters(selected_classes=(0, 2)),
)
# Display the full parameter set
params

Parameters(scenario=<ScenarioType.IMAGE_CLASSIFICATION: 'image_classification'>, trainer=<TrainerType.SOLO_AGENT: 'solo_agent'>, dataset='fashion-mnist', seed=6198, max_message_rounds=8, prover_reward=1.0, verifier_reward=1.0, verifier_terminated_penalty=-1.0, agents=None, ppo=None, solo_agent=SoloAgentParameters(num_epochs=500, batch_size=256, learning_rate=0.001, body_lr_factor=0.01, test_size=0.2), image_classification=ImageClassificationParameters(selected_classes=(0, 1)))

In [6]:
# Experiment-wide settings; only the torch device selected in the setup section is set here
settings = ExperimentSettings(device=device)

## Dataset

In [7]:
class ImageClassificationDataset(Dataset):
    """A binary image classification dataset built from a torchvision dataset.

    The torchvision dataset named by `params.dataset` (a key of `dataset_class_map`)
    is downloaded to `raw_dir`, and its training split is loaded in full. All classes
    apart from the two given by `params.image_classification.selected_classes` are
    removed. The labels are remapped so that the first selected class becomes 0 and
    the second becomes 1.

    The resulting TensorDict has the following fields:

    - ``image``: the images, converted to tensors and normalised with mean 0.5 and
      standard deviation 0.5.
    - ``x``: all-zero pixel features of shape ``(N, H, W, max_message_rounds)``.
    - ``y``: the binary labels, with dtype `y_dtype`.
    """

    x_dtype = torch.float32
    y_dtype = torch.int64

    # Maps each supported `params.dataset` name to the torchvision dataset class used
    # to load it. The values are classes rather than instances, hence `type[...]`.
    dataset_class_map: dict[str, type[VisionDataset]] = {
        "mnist": MNIST,
        "fashion-mnist": FashionMNIST,
        "cifar10": CIFAR10,
    }

    def _build_tensor_dict(self) -> TensorDict:
        """Load the torchvision dataset and build the binary classification data.

        Returns
        -------
        tensor_dict : TensorDict
            A TensorDict with fields ``image``, ``x`` and ``y``. Its batch size is
            the number of training examples that belong to one of the two selected
            classes.

        Raises
        ------
        ValueError
            If `params.image_classification.selected_classes` does not consist of
            exactly two distinct classes.
        KeyError
            If `params.dataset` is not a key of `dataset_class_map`.
        """
        selected_classes = self.params.image_classification.selected_classes

        # Check the classes before downloading anything. With anything other than two
        # distinct classes, the binary relabelling below would silently produce a
        # degenerate dataset (e.g. every label 0 when both classes are equal).
        if len(selected_classes) != 2 or selected_classes[0] == selected_classes[1]:
            raise ValueError(
                "Expected exactly two distinct classes in "
                "`params.image_classification.selected_classes`, got "
                f"{selected_classes!r}."
            )
        negative_class, positive_class = selected_classes

        # Look up the torchvision dataset class, with an informative error on failure
        try:
            dataset_class = self.dataset_class_map[self.params.dataset]
        except KeyError:
            raise KeyError(
                f"Unknown image classification dataset {self.params.dataset!r}; "
                f"expected one of {sorted(self.dataset_class_map)}."
            ) from None

        # Load the training split, converting images to tensors and normalising them
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )
        torch_dataset = dataset_class(
            self.raw_dir, train=True, transform=transform, download=True
        )

        # Get the whole dataset as a single batch
        full_dataset_loader = TorchDataLoader(
            torch_dataset, batch_size=len(torch_dataset)
        )
        images, labels = next(iter(full_dataset_loader))

        # Keep only the two selected classes, and relabel them as 0 (first selected
        # class) and 1 (second selected class)
        index = (labels == negative_class) | (labels == positive_class)
        images = images[index]
        labels = (labels[index] == positive_class).to(self.y_dtype)

        # Create the pixel features, which are all zeros. They share the last two
        # (spatial) dimensions of the images, followed by a trailing dimension of
        # size `max_message_rounds`.
        x = torch.zeros(
            images.shape[0],
            *images.shape[-2:],
            self.params.max_message_rounds,
            dtype=self.x_dtype,
        )

        return TensorDict(
            dict(image=images, x=x, y=labels), batch_size=images.shape[0]
        )

    @property
    def raw_dir(self) -> str:
        """The path to the directory containing the raw (downloaded) data.

        This depends only on the dataset name, so the download is shared between
        different class selections.
        """
        return os.path.join(IC_DATA_DIR, self.params.dataset, "raw")

    @property
    def processed_dir(self) -> str:
        """The path to the directory containing the processed data.

        The directory name encodes `max_message_rounds` and the ordered pair of
        selected classes, because the built data depends on both.
        """
        selected_classes = self.params.image_classification.selected_classes
        return os.path.join(
            IC_DATA_DIR,
            self.params.dataset,
            (
                f"processed_{self.params.max_message_rounds}"
                f"_{selected_classes[0]},{selected_classes[1]}"
            ),
        )

In [8]:
# Build the binary image classification dataset for the current parameters and settings
dataset = ImageClassificationDataset(params, settings)
# Display a summary of the dataset's fields and batch size
dataset

ImageClassificationDataset(
    fields={
        image: MemoryMappedTensor(shape=torch.Size([12000, 1, 28, 28]), device=cpu, dtype=torch.float32, is_shared=True),
        x: MemoryMappedTensor(shape=torch.Size([12000, 28, 28, 8]), device=cpu, dtype=torch.float32, is_shared=True),
        y: MemoryMappedTensor(shape=torch.Size([12000]), device=cpu, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([12000]),
    device=None,
    is_shared=False)