# Image classification task

## Setup

In [8]:
FORCE_CPU = True

SEED = 349287

In [9]:
import os

import torch
from torch.utils.data import DataLoader as TorchDataLoader

import torchvision
from torchvision.datasets import VisionDataset, MNIST, CIFAR10, FashionMNIST
from torchvision import transforms

from tensordict.nn import (
    TensorDictModule,
    TensorDictModuleBase,
    TensorDictSequential,
    ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import CompositeDistribution
from tensordict.tensordict import TensorDict, TensorDictBase

from tqdm import tqdm

from pvg import Parameters, ImageClassificationParameters, ScenarioType, TrainerType
from pvg.experiment_settings import ExperimentSettings
from pvg.scenario_base import DataLoader, Dataset
from pvg.constants import IC_DATA_DIR
from pvg.image_classification import IMAGE_DATASETS, ImageClassificationDataset

In [10]:
torch.manual_seed(SEED)
torch_generator = torch.Generator().manual_seed(SEED)

In [11]:
if not FORCE_CPU and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


## Parameters

In [12]:
params = Parameters(
    scenario=ScenarioType.IMAGE_CLASSIFICATION,
    trainer=TrainerType.SOLO_AGENT,
    dataset="fashion_mnist",
    image_classification=ImageClassificationParameters(
        selected_classes=(0, 2),
    ),
)
params

Parameters(scenario=<ScenarioType.IMAGE_CLASSIFICATION: 'image_classification'>, trainer=<TrainerType.SOLO_AGENT: 'solo_agent'>, dataset='fashion_mnist', seed=6198, max_message_rounds=8, pretrain_agents=False, test_size=0.2, d_representation=16, prover_reward=1.0, verifier_reward=1.0, verifier_terminated_penalty=-1.0, agents=None, ppo=None, solo_agent=SoloAgentParameters(num_epochs=100, batch_size=256, learning_rate=0.001, body_lr_factor=None), image_classification=ImageClassificationParameters(selected_classes=(0, 2)))

In [13]:
settings = ExperimentSettings(device=device)

## Dataset

In [14]:
dataset = ImageClassificationDataset(params, settings)
dataset

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /home/sam/Code/Projects/PVG Experiments/data/image_classification/fashion_mnist/raw/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:03<00:00, 7179394.82it/s]


Extracting /home/sam/Code/Projects/PVG Experiments/data/image_classification/fashion_mnist/raw/FashionMNIST/raw/train-images-idx3-ubyte.gz to /home/sam/Code/Projects/PVG Experiments/data/image_classification/fashion_mnist/raw/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /home/sam/Code/Projects/PVG Experiments/data/image_classification/fashion_mnist/raw/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 1356433.27it/s]


Extracting /home/sam/Code/Projects/PVG Experiments/data/image_classification/fashion_mnist/raw/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /home/sam/Code/Projects/PVG Experiments/data/image_classification/fashion_mnist/raw/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /home/sam/Code/Projects/PVG Experiments/data/image_classification/fashion_mnist/raw/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 6938275.20it/s]


Extracting /home/sam/Code/Projects/PVG Experiments/data/image_classification/fashion_mnist/raw/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /home/sam/Code/Projects/PVG Experiments/data/image_classification/fashion_mnist/raw/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /home/sam/Code/Projects/PVG Experiments/data/image_classification/fashion_mnist/raw/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 10796138.50it/s]


Extracting /home/sam/Code/Projects/PVG Experiments/data/image_classification/fashion_mnist/raw/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /home/sam/Code/Projects/PVG Experiments/data/image_classification/fashion_mnist/raw/FashionMNIST/raw



ImageClassificationDataset(
    fields={
        image: MemoryMappedTensor(shape=torch.Size([12000, 1, 28, 28]), device=cpu, dtype=torch.float32, is_shared=True),
        x: MemoryMappedTensor(shape=torch.Size([12000, 28, 28, 8]), device=cpu, dtype=torch.float32, is_shared=True),
        y: MemoryMappedTensor(shape=torch.Size([12000]), device=cpu, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([12000]),
    device=cpu,
    is_shared=False)