<a href="https://colab.research.google.com/github/BitWeavre/DRL_Classification_Projects/blob/main/PPO_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Download data to Colab

In [None]:
# Only the first 6 of the 12 NIH image archives are fetched by default;
# uncomment the remaining archive IDs below to download the full image set.
urls = [
    f'https://nihcc.box.com/shared/static/{archive_id}.gz'
    for archive_id in (
        'vfk49d74nhbxq3nqjg0900w5nvkorp5c',
        'i28rlmbvmfjbl8p2n3ril0pptcmcu9d1',
        'f1t00wrtdk94satdfb9olcolqx20z2jp',
        '0aowwzs5lhjrceb3qp67ahp0rd1l1etg',
        'v5e3goj22zr6h8tzualxfsqlqaygfbsn',
        'asi7ikud9jwnkrnkj99jnpfkjdes7l6l',
        # 'jn1b4mw4n6lnh74ovmcjb8y48h8xj07n',
        # 'tvpxmn7qyrgl0w8wfh9kqfjskv6nmm1j',
        # 'upyy3ml7qdumlgk2rfcvlb9k6gvqq2pj',
        # 'l6nilvfa9cg3s28tqv1qc1olm3gnz54p',
        # 'hhq8fkdgvcari67vfhs7ppg2w6ni4jze',
        # 'ioqwiy20ihqwyr8pf4c24eazhh281pbu',
    )
]

Create the directory where all the data will be placed

In [None]:
import os

# exist_ok makes this cell idempotent: re-running it (or Restart & Run All)
# no longer fails with FileExistsError once the directory already exists.
os.makedirs('all-data', exist_ok=True)

FileExistsError: ignored

Download each archive, extract it, move its images into ./all-data/ and delete the archive (a .tar.gz file) to save space.

In [None]:
import os
import shutil
import subprocess
from pathlib import Path

# For every archive: download it, unpack it, move its images into ./all-data/
# and delete the archive. subprocess.run(..., check=True) stops at the first
# failed download/extraction instead of silently continuing (the exit status
# of os.system was ignored, so a failed wget only surfaced later as an error
# from os.unlink). Argument lists avoid building shell command strings.
os.makedirs('all-data', exist_ok=True)
for idx, url in enumerate(urls):
    filename = url.rsplit('/', maxsplit=1)[1]
    progress = f'{idx + 1}/{len(urls)}'

    print(f'Downloading data file {progress} ....')
    # -O pins the output name, so a re-run overwrites instead of creating
    # "<name>.1" copies that would never be cleaned up
    subprocess.run(['wget', '-O', filename, url], check=True)

    print(f'Extracting data file {progress} ....')
    subprocess.run(['tar', '-xzf', filename], check=True)

    print(f'Moving images of data file {progress} to ./all-data/ ....')
    for image_path in Path('images').iterdir():
        shutil.move(str(image_path), os.path.join('all-data', image_path.name))

    print(f'Removing zipped data file {progress} ....')
    os.unlink(filename)
    shutil.rmtree('images', ignore_errors=True)

    print('\n')

In [None]:
# count every entry that ended up in ./all-data/
num_downloads = sum(1 for _ in os.scandir('all-data'))
print(f'Number of images download: {num_downloads}')

Download the metadata file as well (it holds the finding labels, patient age and patient gender of every image).

In [None]:
!wget https://raw.githubusercontent.com/ingus-t/SPAI/master/resources/Data_Entry_2017.csv

# Install packages

In [None]:
!pip install stable-baselines3[extra]

# Dataset

Imports

In [None]:
import torch
import pandas as pd
import torch.utils.data
import matplotlib.pyplot as plt

from pathlib import Path
from typing import List, Tuple, Dict, Optional
from torchvision.transforms import transforms

In [None]:
class NIHBreastCancerDataset(torch.utils.data.Dataset):
    """NIH chest X-ray dataset class that can be used by torch Dataloaders.

    Despite "Breast Cancer" in its name, the targets are the 15 thoracic
    findings listed in ``classes`` (parsed from the ``Finding Labels`` column
    of the metadata CSV). A sample is ``(image, features, labels)``:

    * ``image``: ``(3, H, W)`` uint8 tensor, the grayscale image resized to
      ``image_resize_shape`` and repeated over 3 channels;
    * ``features``: uint8 tensor ``[age, gender]`` (male 0, female 1);
    * ``labels``: list of class IDs (indices into ``classes``).

    Only images that are present in the (filtered) metadata are used.
    """

    classes = [
        'Atelectasis',
        'Cardiomegaly',
        'Consolidation',
        'Edema',
        'Effusion',
        'Emphysema',
        'Fibrosis',
        'Hernia',
        'Infiltration',
        'Mass',
        'No Finding',
        'Nodule',
        'Pleural_Thickening',
        'Pneumonia',
        'Pneumothorax'
    ]
    class_to_id = {class_: id_ for id_, class_ in enumerate(classes)}

    def __init__(
            self,
            root_dir: Path,
            metadata_path: Path,
            data_to_use: Optional[Path] = None,
            image_resize_shape: Optional[Tuple[int, int]] = None
    ) -> None:
        """
        :param root_dir: Directory that directly contains the image files.
        :param metadata_path: Path of the metadata CSV (Data_Entry_2017.csv).
        :param data_to_use: Optional text file with one image file name per
            line; if given, only those images are used.
        :param image_resize_shape: (H, W) every image is resized to; defaults
            to (224, 224).
        """
        super().__init__()

        # initialize attributes
        self.root_dir = root_dir
        self.metadata_path = metadata_path
        self.data_to_use = data_to_use
        self.image_resize_shape = image_resize_shape or (224, 224)

        # initialize the image transformations pipeline; the final Lambda
        # repeats the single grayscale channel to get a 3-channel tensor
        self.transforms = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(self.image_resize_shape),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1))
        ])

        # parse features and labels first, so that images without a usable
        # metadata row (e.g. rows removed by the age filter) can be skipped;
        # indexing such an image would otherwise raise a KeyError
        metadata = self._parse_metadata()
        self._img_name_to_features, self._img_name_to_labels = metadata
        self._image_paths = self._parse_root_directory(
            known_image_names=set(self._img_name_to_labels)
        )

    def _parse_root_directory(
            self,
            known_image_names: Optional[set] = None
    ) -> List[Path]:
        """Returns the sorted paths of the image files in the root directory.

        :param known_image_names: If given, only images whose file name is in
            this set are kept.
        :return: The image paths, restricted to the names listed in
            ``self.data_to_use`` when that file is set.
        """
        image_paths = sorted(
            path for path in self.root_dir.iterdir() if path.is_file()
        )

        # keep only the image names listed (one per line) in ``data_to_use``
        if self.data_to_use is not None:
            with open(self.data_to_use, 'r') as fp:
                wanted_names = {line.strip() for line in fp if line.strip()}
            image_paths = [p for p in image_paths if p.name in wanted_names]

        if known_image_names is not None:
            image_paths = [p for p in image_paths
                           if p.name in known_image_names]

        return image_paths

    def _parse_metadata(self) -> \
            Tuple[Dict[str, Tuple[int, int]], Dict[str, List[int]]]:
        """Parses the metadata file that contains features and the labels.

        :return: ``(image name -> (age, gender), image name -> label IDs)``.
        """
        # read data and keep useful columns
        df = pd.read_csv(self.metadata_path)
        df = df[['Image Index', 'Finding Labels',
                 'Patient Age', 'Patient Gender']]

        # ages above 122 are treated as invalid entries, see
        # https://www.science.org/doi/10.1126/science.279.5358.1831h
        # (.copy() so the column assignments below modify an independent
        # frame rather than a slice -> no SettingWithCopyWarning)
        df = df[df['Patient Age'] <= 122].copy()

        # convert male-female to binary value (male == 0, female == 1)
        df['Patient Gender'] = (df['Patient Gender'] == 'F').astype(int)

        # split "A|B" label strings into lists, then map them to class IDs
        df['Finding Labels'] = df['Finding Labels'].str.split('|').apply(
            lambda labels: [self.class_to_id[label] for label in labels]
        )

        # create a dictionary for the extra features
        img_names_and_features = df[['Image Index', 'Patient Age',
                                     'Patient Gender']].values
        img_names_to_features = {
            img_name: (age, gender)
            for img_name, age, gender in img_names_and_features
        }

        # and one for the labels
        img_names_and_labels = df[['Image Index', 'Finding Labels']].values
        img_name_to_labels = {
            img_name: labels for img_name, labels in img_names_and_labels
        }

        return img_names_to_features, img_name_to_labels

    def _read_img(self, item: int) -> torch.Tensor:
        """Loads image ``item`` as a ``(3, H, W)`` uint8 tensor."""
        img = plt.imread(self._image_paths[item])
        if img.ndim == 3:
            # drop an alpha channel (LA / RGBA input) so it does not get
            # mixed into the grayscale intensity, then average the colours
            if img.shape[2] in (2, 4):
                img = img[..., :-1]
            img = img.mean(axis=2)
        img = self.transforms(img)
        return (img * 255).to(torch.uint8)

    def _get_features(self, item: int) -> torch.Tensor:
        """Returns the ``[age, gender]`` uint8 tensor of image ``item``."""
        img_name = self._image_paths[item].name
        return torch.tensor(self._img_name_to_features[img_name],
                            dtype=torch.uint8)

    def _get_labels(self, item: int) -> List[int]:
        """Returns the label IDs of image ``item``."""
        img_name = self._image_paths[item].name
        return self._img_name_to_labels[img_name]

    @staticmethod
    def num_classes() -> int:
        """Number of distinct finding labels."""
        return len(NIHBreastCancerDataset.classes)

    @property
    def input_shapes(self) -> Tuple[Tuple[int, int], Tuple[int]]:
        """Returns ``((H, W), (num_features,))``; reads the first sample."""
        return self.image_resize_shape, (self._get_features(0).numel(),)

    def __len__(self) -> int:
        return len(self._image_paths)

    def __getitem__(
            self,
            item: int
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        return self._read_img(item), self._get_features(item), \
            self._get_labels(item)

In [None]:
root_directory, metadata_path = Path('all-data'), Path('Data_Entry_2017.csv')

# small sanity check: default 224x224 resize, no image-name filter
example_dataset = NIHBreastCancerDataset(root_directory, metadata_path, None)
len(example_dataset)

In [None]:
# peek at one sample: (image tensor, [age, gender] tensor, list of label IDs)
example_dataset[35]

# Environment

Imports

In [None]:
import gym
import numpy as np

from stable_baselines3.common.vec_env import DummyVecEnv
from typing import List, Tuple, Dict, Optional, Callable, Union

A basic reward function for the agent: +1 if the predicted class is one of the sample's labels, -1 otherwise.

In [None]:
def basic_reward_function(
        prediction: int,
        labels: List[int],
        correct_reward: float = 1.0,
        wrong_reward: float = -1.0
) -> float:
    """A basic reward function for a (possibly multi-label) sample.

    :param prediction: The class ID chosen by the agent.
    :param labels: The truth class IDs of the sample.
    :param correct_reward: Reward when ``prediction`` is one of ``labels``.
    :param wrong_reward: Reward otherwise.
    :return: ``correct_reward`` (+1.0 by default) if the predicted label is
        inside the list of truth labels, else ``wrong_reward`` (-1.0 by
        default). Always a float, matching the annotated return type.
    """
    return float(correct_reward if prediction in labels else wrong_reward)

In [None]:
class NIHBreastCancerEnv(gym.Env):
    """Class used to represent the NIH Breast Cancer decision Environment.

    Every episode shows up to ``horizon`` distinct samples of ``dataset``,
    drawn at random without replacement. Observations are dicts
    ``{'image': (3, H, W) uint8, 'features': uint8 vector}``; an action is a
    class ID and its reward is ``reward_function(action, labels)`` for the
    labels of the sample that was shown.
    """

    metadata = {'render_modes': ['human']}

    def __init__(
        self,
        reward_function: Callable,
        dataset: NIHBreastCancerDataset,
        horizon: int = 100,
        seed: Optional[int] = None
    ) -> None:
        """
        :param reward_function: The function that computes the reward
            of an action given the labels.
        :param dataset: The dataset from which the env will sample images.
        :param horizon: The horizon (max number of steps) of an episode; an
            episode never has more steps than ``len(dataset)``.
        :param seed: The random seed used for the episode sampling and the
            action space, for reproducibility. If None, a random seed will be
            used.
        """
        super(NIHBreastCancerEnv, self).__init__()

        # initialize important variables
        self.reward_function = reward_function
        self.dataset = dataset
        self.horizon = horizon
        self.default_seed = seed

        # initialize state and action spaces
        img_shape, features_shape = dataset.input_shapes
        num_classes = dataset.num_classes()
        self.observation_space = gym.spaces.Dict({
            'image': gym.spaces.Box(low=0,
                                    high=255,
                                    shape=(3, *img_shape),
                                    dtype=np.uint8),
            'features': gym.spaces.Box(low=np.array([0, 0]),
                                       high=np.array([122, 1]),
                                       shape=features_shape,
                                       dtype=np.uint8)
        })
        self.action_space = gym.spaces.Discrete(num_classes)

        # RNG that draws the samples of each episode. It is kept across
        # resets, so a seeded env yields a reproducible sequence of
        # *different* episodes (the seed used to be stored but never used)
        self._rng = np.random.RandomState(seed)
        self.seed(seed)

        # dataset indices of the samples still to be shown in this episode
        self.states: List[int] = []

        # number of steps taken in the current episode
        self.current_episode: int = 0

    def seed(self, seed: Optional[int] = None) -> List[Optional[int]]:
        """Re-seeds the episode sampler and the action space.

        Defined here because ``gym.Env.seed`` may not exist in newer gym
        versions, while vectorized-env wrappers may still call it.
        """
        self._rng = np.random.RandomState(seed)
        self.action_space.seed(seed)
        return [seed]

    def _initialize_episode_data(self, seed: Optional[int] = None) -> List[int]:
        """Creates a list with the indices of the samples to be used in
            the current episode.

        :param seed: If given, the episode sampler is re-seeded first.
        :return: ``min(horizon, len(dataset))`` distinct dataset indices
            (capped instead of raising when the dataset is small).
        """
        if seed is not None:
            self._rng = np.random.RandomState(seed)
        num_samples = min(self.horizon, len(self.dataset))
        indices = self._rng.choice(len(self.dataset), num_samples,
                                   replace=False)
        return indices.tolist()

    @property
    def terminal_state(self) -> Dict[str, np.ndarray]:
        """Returns a dummy terminal state, where all features are 0.

        Shapes and dtypes match ``observation_space`` (previously the image
        lacked the channel axis and both arrays were float64).
        """
        return {key: np.zeros(space.shape, dtype=space.dtype)
                for key, space in self.observation_space.spaces.items()}

    def _sample(self, item: int) -> \
            Tuple[np.ndarray, np.ndarray, List[int]]:
        """Returns ``(image, features, labels)`` of the ``item``-th remaining
            state of the episode, as numpy arrays."""
        assert item < len(self.states)
        image, features, labels = self.dataset[self.states[item]]
        return image.detach().cpu().numpy(), \
            features.detach().cpu().numpy(), \
            labels

    def observe(self) -> Dict[str, np.ndarray]:
        """Returns an observation of the current state."""
        image, features, _ = self._sample(0)
        return {'image': image, 'features': features}

    def done(self) -> bool:
        """Returns True if the environment is in a terminal state;
            Else False."""
        return len(self.states) == 0 or self.current_episode == self.horizon

    def reset(
            self,
            *,
            seed: Optional[int] = None,
            return_info: bool = False,
            options: Optional[dict] = None,
    ) -> Union[Dict[str, np.ndarray], Tuple[Dict[str, np.ndarray], Dict]]:
        """Starts a new episode and returns its first observation
            (plus an empty info dict if ``return_info``)."""
        self.states = self._initialize_episode_data(seed)
        self.current_episode = 0
        obs = self.observe()
        return (obs, {}) if return_info else obs

    def step(
            self,
            action: int
    ) -> Tuple[Dict[str, np.ndarray], float, bool, Dict]:
        """Performs one step in the environment, by applying the
            specified action to the current sample."""
        # get the labels of the current state and then remove it
        _, _, labels = self._sample(0)
        del self.states[0]

        # advance one time step
        self.current_episode += 1

        # the next state is the (image, features) at the start of the list,
        # or the dummy terminal state once the episode is over
        done = self.done()
        next_state = self.terminal_state if done else self.observe()
        reward = self.reward_function(action, labels)
        info = {}

        return next_state, reward, done, info

    def render(self, mode='human'):
        """Renders the environment. Used for visualization,
            currently not used."""
        pass

In [None]:
def create_env() -> NIHBreastCancerEnv:
    """Factory for one smoke-test env over ``example_dataset`` (horizon 1000)."""
    return NIHBreastCancerEnv(basic_reward_function, example_dataset, 1000)


# vectorize a few copies of the env and take a single step in all of them
num_vec_envs = 4
env_initializers = [create_env] * num_vec_envs
vec_envs = DummyVecEnv(env_initializers)

assert len(vec_envs.envs) == num_vec_envs
vec_envs.reset()

# every env gets a different action; result = (next states, rewards, dones, infos)
result = vec_envs.step(np.arange(num_vec_envs))
print(result[0]['image'].shape, result[0]['features'].shape)  # next states
print(result[1].shape)  # rewards
print(result[2].shape)  # dones
print(len(result[3]))  # infos

# Feature Extractor

Imports

In [None]:
import gym
import torch
import torch.nn as nn
import torch.fx as fx

from typing import Dict, Tuple, Union, Optional
from torchvision.models import feature_extraction
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

Wrapper model for a torch feature extractor

In [None]:
class FeatureSubnetWrapper(nn.Module):
    """Wraps a feature extractor so that it can be used directly without
        instantiating the desired layer.

    The wrapped module returns a dict of named outputs (as produced by
    ``feature_extraction.create_feature_extractor``); calling the wrapper
    returns only the tensor stored under ``until_layer``.
    """

    def __init__(self, model: fx.GraphModule, until_layer: str) -> None:
        """
        :param model: Module whose forward returns a dict keyed by node name.
        :param until_layer: Key of the output that ``forward`` returns.
        """
        super(FeatureSubnetWrapper, self).__init__()
        self.model = model
        self.until_layer = until_layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Runs the wrapped model and selects the ``until_layer`` output."""
        return self.model(x)[self.until_layer]

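Combined feature extractor: a pretrained vision backbone (a CNN or, as configured below, a ViT) extracts features from the image, a small MLP processes the tabular features (age, gender), and both outputs are concatenated.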

In [None]:
class CustomCombinedExtractor(BaseFeaturesExtractor):
    """Custom Feature Extraction Module for the env's Dict observations.

    * ``'image'``: pretrained torchvision model (loaded via ``torch.hub``),
      truncated at ``until_layer``, flattened and optionally frozen.
    * ``'features'``: MLP ``Linear(num_dataset_features, 64) -> ReLU ->
      Linear(64, 32)``.

    ``forward`` concatenates both outputs into a ``(B, features_dim)`` tensor.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Dict,
            device: torch.device,
            model_name: str,
            weights: Optional[str],
            until_layer: Union[int, str],
            freeze_pretrained_model: bool,
            img_dims: Tuple[int, int],
            num_dataset_features: int,
            n_classes: int
    ) -> None:
        """
        :param observation_space: The env's Dict observation space.
        :param device: Device on which the sub-networks are placed.
        :param model_name: Model name in the ``pytorch/vision`` hub repo.
        :param weights: Name of the pretrained weights to load.
        :param until_layer: Graph node name whose output is used as image
            features, or an int ``n`` meaning "drop the last ``n`` graph
            nodes" (``0`` keeps the model's final output).
        :param freeze_pretrained_model: If True, the image backbone is not
            trained and is always kept in eval mode.
        :param img_dims: (H, W) of the input images.
        :param num_dataset_features: Number of tabular features.
        :param n_classes: Number of classes (stored, not used here).
        """
        # features_dim=1 is a placeholder; the real size is only known once
        # the extractors exist and is set at the end of __init__
        super(CustomCombinedExtractor, self).__init__(observation_space,
                                                      features_dim=1)

        # define the device where the models will be placed on
        self.device = device

        # define some important variables
        self.img_dims = img_dims
        self.num_dataset_features = num_dataset_features
        self.n_classes = n_classes
        # remembered so that ``train()`` keeps a frozen backbone in eval mode
        self.freeze_pretrained_model = freeze_pretrained_model

        # define the feature extractors
        cnn_extractor, mlp_extractor = self._get_feature_extractors(
            model_name=model_name,
            weights=weights,
            freeze_pretrained_model=freeze_pretrained_model,
            until_layer=until_layer
        )
        extractors = {'image': cnn_extractor, 'features': mlp_extractor}
        self.extractors = nn.ModuleDict(extractors)

        # do a forward pass in the CNN to see output shape
        cnn_out_shape = self.get_model_out_size(
            model=cnn_extractor,
            img_dims=img_dims,
            device=device
        )
        # get the MLP output size in order to compute the final features dim
        mlp_out_size = mlp_extractor[-1].weight.shape[0]

        # features dimension
        self._features_dim = cnn_out_shape + mlp_out_size

    def train(self, mode: bool = True) -> 'CustomCombinedExtractor':
        """Sets train/eval mode, but keeps a frozen backbone in eval mode.

        Toggling the whole policy (e.g. SB3's ``set_training_mode``) would
        otherwise put the frozen backbone back into train mode, re-enabling
        dropout and BatchNorm running-statistics updates.
        """
        super().train(mode)
        if getattr(self, 'freeze_pretrained_model', False) \
                and hasattr(self, 'extractors'):
            self.extractors['image'].eval()
        return self

    def forward(
            self,
            observations: Union[Dict[str, torch.Tensor], torch.Tensor]
    ) -> torch.Tensor:
        """Extract features and concatenate them in a Tensor."""
        # extract features and place them in a list
        extractors_out = [extractor(observations[key])
                          for key, extractor in self.extractors.items()]

        # concatenate extracted features to get out shape:
        #   (B, self._features_dim)
        latent_vector = torch.cat(extractors_out, dim=1)
        return latent_vector

    def load_pretrained_model(self, model_name: str, weights: str) -> nn.Module:
        """Loads or Downloads a pretrained model from ``pytorch/vision`` via
            ``torch.hub`` and returns it on ``self.device``."""
        try:
            model = torch.hub.load('pytorch/vision', model_name, weights=weights)
        except Exception:
            # Some issues arise with the cached hubconf.py: patch its import of
            # ``get_model_weights`` (presumably missing from the installed
            # torchvision — TODO confirm) and retry once. ``Exception`` instead
            # of a bare ``except`` so KeyboardInterrupt is not swallowed.
            os.system("""sed -i 's/from torchvision.models import get_model_weights, get_weight/from torchvision.models import get_weight/g' ~/.cache/torch/hub/pytorch_vision_main/hubconf.py""")
            model = torch.hub.load('pytorch/vision', model_name, weights=weights)
        return model.to(self.device)

    @staticmethod
    def get_feature_subnet(
            model: nn.Module,
            until_layer: Union[int, str]
    ) -> nn.Module:
        """Given a pretrained model, it returns the model that constructs the
            features, having removed layers from the end.

        :param until_layer: Node name to stop at, or the number of graph
            nodes to drop from the end (``0`` keeps the final node).
        :raises RuntimeError: On an unknown node name, an out-of-range int or
            a wrong argument type.
        """
        model_layers = feature_extraction.get_graph_node_names(model)[0]
        if isinstance(until_layer, str):
            if until_layer not in model_layers:
                raise RuntimeError(f'Layer "{until_layer}" not in model.')
        elif isinstance(until_layer, int):
            # keep the node right before the dropped ones. The previous
            # ``model_layers[:-until_layer]`` produced a *list* of names,
            # which is not a valid return node.
            if not 0 <= until_layer < len(model_layers):
                raise RuntimeError(
                    f'"until_layer" must be in [0, {len(model_layers) - 1}], '
                    f'got {until_layer}.')
            until_layer = model_layers[-until_layer - 1]
        else:
            raise RuntimeError('Wrong type for "until_layer" argument.')

        # get the feature extractor and return it
        feature_extractor = feature_extraction.create_feature_extractor(
            model,
            return_nodes=[until_layer]
        )
        return FeatureSubnetWrapper(feature_extractor, until_layer)

    @staticmethod
    def get_model_out_size(
            model: nn.Module,
            img_dims: Tuple[int, int],
            device: torch.device
    ) -> int:
        """Returns the output size that is produced by a model.

        Runs one all-zeros image through ``model`` in eval mode and without
        autograd, so the probe neither builds a graph nor updates BatchNorm
        running statistics; the original train/eval mode is restored.
        """
        dims = (1, 3, img_dims[0], img_dims[1])
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                out = model(torch.zeros(dims).to(device))
        finally:
            model.train(was_training)
        return out.shape[1]

    @staticmethod
    def get_layer_out_size(
            model: nn.Module,
            img_dims: Tuple[int, int],
            until_layer: Union[int, str],
            device: torch.device
    ) -> int:
        """Returns the output size that is produced by a layer of a model."""
        feature_subnet = CustomCombinedExtractor.get_feature_subnet(
            model,
            until_layer
        )
        return CustomCombinedExtractor.get_model_out_size(
            feature_subnet,
            img_dims,
            device
        )

    @staticmethod
    def freeze_model(model: nn.Module) -> None:
        """Freezes a PyTorch model, i.e. makes it non trainable."""
        for param in model.parameters():
            param.requires_grad = False
        model.eval()

    def _get_feature_extractors(
            self,
            model_name: str,
            weights: Optional[str],
            freeze_pretrained_model: bool = True,
            until_layer: Union[int, str] = None
    ) -> Tuple[nn.Module, nn.Module]:
        """Returns ``(image extractor, tabular-feature MLP)``, both on
            ``self.device``."""
        # load the model and get the output until the specified layer
        model = self.load_pretrained_model(model_name, weights=weights)
        feature_subnet = self.get_feature_subnet(
            model,
            until_layer=until_layer
        )

        # flatten the outputs (might not be needed)
        feature_subnet = nn.Sequential(
            feature_subnet,
            nn.Flatten()
        )

        # if specified, freeze the weights of the pre-trained model
        if freeze_pretrained_model:
            self.freeze_model(feature_subnet)

        # create the MLP for the other dataset features
        feat_mlp = nn.Sequential(
            nn.Linear(in_features=self.num_dataset_features, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=32),
        )

        return feature_subnet.to(self.device), feat_mlp.to(self.device)

# PPO

## PPO Policy

Imports

In [None]:
import gym
import torch

from torch import nn
from stable_baselines3.common.policies import ActorCriticPolicy
from typing import Optional, Union, Type, List, Tuple, Dict, Callable

In [None]:
class CustomNetwork(nn.Module):
    """
    Actor-critic head that sits on top of the features extractor.

    The policy and value branches each consist of one ``Linear -> ReLU``
    layer applied to the same extracted features; they share no weights.

    :param feature_dim: Size of the feature vector produced by the
        features_extractor (e.g. image features concatenated with the
        tabular-feature MLP output).
    :param last_layer_dim_pi: Output size of the policy branch.
    :param last_layer_dim_vf: Output size of the value branch.
    """

    def __init__(
            self,
            feature_dim: int,
            last_layer_dim_pi: int = 64,
            last_layer_dim_vf: int = 64,
    ):
        super(CustomNetwork, self).__init__()

        # output sizes, read by the policy to build the distribution heads
        self.latent_dim_pi, self.latent_dim_vf = last_layer_dim_pi, last_layer_dim_vf

        # the policy branch is created before the value branch so that the
        # random weight initialization happens in the same order
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
        )
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
        )

    def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :return: ``(latent_policy, latent_value)`` computed from the same
            features by the two independent branches.
        """
        latent_pi = self.forward_actor(features)
        latent_vf = self.forward_critic(features)
        return latent_pi, latent_vf

    def forward_actor(self, features: torch.Tensor) -> torch.Tensor:
        """Latent representation used by the action distribution."""
        return self.policy_net(features)

    def forward_critic(self, features: torch.Tensor) -> torch.Tensor:
        """Latent representation used by the value head."""
        return self.value_net(features)


class CustomActorCriticPolicy(ActorCriticPolicy):
    """Actor-critic policy whose MLP head is ``CustomNetwork``.

    The extra constructor arguments (``device`` ... ``n_classes``) are
    forwarded as ``features_extractor_kwargs`` to the features extractor
    class given in ``kwargs`` (``CustomCombinedExtractor`` in this notebook).
    Orthogonal initialization is always disabled.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            lr_schedule: Callable[[float], float],
            net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
            activation_fn: Type[nn.Module] = nn.Tanh,
            device: torch.device = None,
            model_name: str = None,
            weights: str = None,
            until_layer: Union[str, int] = None,
            freeze_pretrained_model: bool = None,
            img_dims: Tuple[int, int] = None,
            num_dataset_features: int = None,
            n_classes: int = None,
            *args,
            **kwargs,
    ):
        # merge with any features_extractor_kwargs the caller also supplied;
        # the explicit arguments above take precedence. Passing both used to
        # fail with "got multiple values for keyword argument".
        features_extractor_kwargs = dict(
            kwargs.pop('features_extractor_kwargs', None) or {}
        )
        features_extractor_kwargs.update({
            'device': device,
            'model_name': model_name,
            'weights': weights,
            'until_layer': until_layer,
            'freeze_pretrained_model': freeze_pretrained_model,
            'img_dims': img_dims,
            'num_dataset_features': num_dataset_features,
            'n_classes': n_classes
        })

        # orthogonal init is always off; discard a caller-supplied value
        # instead of failing on a duplicate keyword
        kwargs.pop('ortho_init', None)

        super(CustomActorCriticPolicy, self).__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            *args,
            ortho_init=False,
            features_extractor_kwargs=features_extractor_kwargs,
            **kwargs,
        )

        # disable orthogonal initialization (also passed to the parent above)
        self.ortho_init = False

    def _build_mlp_extractor(self) -> None:
        """Uses ``CustomNetwork`` as the policy/value MLP head."""
        self.mlp_extractor = CustomNetwork(self.features_dim).to(self.device)

## PPO Agent

Imports

In [None]:
import time
import torch
import numpy as np

from pathlib import Path
from stable_baselines3 import PPO
from typing import Tuple, Dict, Union, Optional
from stable_baselines3.common.vec_env import DummyVecEnv

In [None]:
class PPOAgent:
    """Wrapper class for the PPO agent of Stable Baselines 3.

    Either loads a saved agent or builds a new PPO model whose policy is
    ``CustomActorCriticPolicy`` with a ``CustomCombinedExtractor`` features
    extractor.
    """

    def __init__(
            self,
            envs: DummyVecEnv = None,
            device: torch.device = None,
            cnn: str = None,
            weights: str = None,
            load_pretrained_agent: bool = False,
            pretrained_agent_path: Path = None,
            freeze_pretrained_model: bool = True,
            img_dims: Tuple[int, int] = (224, 224),
            num_dataset_features: int = 2,
            n_classes: int = 15,
            until_layer: Union[str, int] = None,
            n_steps: int = 32,
            batch_size: int = 32,
            gamma: float = 0.99,
            gae_lambda: float = 0.95,
            clip_range: float = 0,
            tensorboard_log: Optional[Path] = None,
            verbose: int = 1
    ) -> None:
        """
        :param envs: Vectorized environments to train on; also attached to a
            loaded agent so that it can keep training.
        :param device: Torch device for the policy and the feature extractor.
            If None, CUDA is used when available, else the CPU.
        :param cnn: Hub name of the pretrained image model.
        :param weights: Name of the pretrained weights of ``cnn``.
        :param load_pretrained_agent: If True, load the agent saved at
            ``pretrained_agent_path`` instead of building a new one.
        :param pretrained_agent_path: Path of the saved agent to load.
        :param freeze_pretrained_model: Whether the image model is frozen.
        :param img_dims: (H, W) of the observed images.
        :param num_dataset_features: Number of tabular features per sample.
        :param n_classes: Number of classes (actions).
        :param until_layer: Layer of ``cnn`` whose output is used as features.
        :param n_steps: PPO rollout length per environment.
        :param batch_size: PPO minibatch size.
        :param gamma: Discount factor.
        :param gae_lambda: GAE lambda.
        :param clip_range: PPO clipping parameter. NOTE(review): the default
            of 0 clamps the probability ratio to exactly 1 (SB3's own default
            is 0.2; this notebook passes 0.2) — confirm 0 is intended.
        :param tensorboard_log: Directory for TensorBoard logs (None: none).
        :param verbose: SB3 verbosity level.
        """
        # set variables
        self.envs = envs

        # resolve None the way SB3's 'auto' does, so that SB3 and the feature
        # extractor (which receives the device via policy_kwargs) both get a
        # concrete device; None itself is not a documented SB3 device value
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # either load a pretrained agent or create a new one from scratch
        if load_pretrained_agent:
            # attach the envs so that ``learn`` can continue training
            self.model = PPO.load(pretrained_agent_path, env=envs, device=device)
            print(f'Loaded model from path: {pretrained_agent_path}.')
        else:
            # specify model
            policy_kwargs = dict(
                features_extractor_class=CustomCombinedExtractor,
                device=device,
                model_name=cnn,
                weights=weights,
                freeze_pretrained_model=freeze_pretrained_model,
                img_dims=img_dims,
                num_dataset_features=num_dataset_features,
                n_classes=n_classes,
                until_layer=until_layer
            )
            self.model = PPO(
                CustomActorCriticPolicy,
                self.envs,
                n_steps=n_steps,
                batch_size=batch_size,
                gamma=gamma,
                gae_lambda=gae_lambda,
                clip_range=clip_range,
                policy_kwargs=policy_kwargs,
                verbose=verbose,
                tensorboard_log=(str(tensorboard_log)
                                 if tensorboard_log is not None else None),
                device=device
            )

    def learn(self, timesteps: int) -> None:
        """Trains the agent for ``timesteps`` environment steps and prints the
            wall-clock training time."""
        start = time.time()
        self.model.learn(total_timesteps=timesteps)
        print(f'\nTime for training: {time.time() - start:.2f} seconds.\n')

    def choose_action(
            self,
            observation: Dict[str, np.ndarray],
            deterministic: bool = False
    ) -> int:
        """Chooses (and returns) an action, given a specific observation.

        :param deterministic: If True, return the most likely action instead
            of sampling one from the policy (the default, as before).
        """
        action, _ = self.model.predict(observation, deterministic=deterministic)
        return int(action)

    def get_scores(self, observation: Dict[str, np.ndarray]) -> np.ndarray:
        """Returns the predicted action distribution (policy)
            for the given state, as a vector of class probabilities."""
        # run dropout/BatchNorm layers in inference mode, as predict() does
        self.model.policy.set_training_mode(False)
        obs_tensor, _ = self.model.policy.obs_to_tensor(observation)
        with torch.no_grad():
            features = self.model.policy.extract_features(obs_tensor)
            latent_pi, _ = self.model.policy.mlp_extractor(features)
            action_logits = self.model.policy.action_net(latent_pi)
            probabilities = torch.softmax(action_logits, dim=1)
        # only the first (single) observation of the batch is returned
        return probabilities[0].cpu().numpy()

    def save(self, save_path: Path) -> None:
        """Saves the underlying SB3 model at ``save_path``."""
        self.model.save(save_path)
        print(f'Saved the model at path: {save_path}.')

## PPO Main

Fixed variables

In [None]:
# dataset location, metadata file, and the (H, W) every image is resized to
root_directory, metadata_path, image_resize_shape = (
    Path('all-data'), Path('Data_Entry_2017.csv'), (384, 384)
)

Environment Hyperparameters

In [None]:
# max steps per episode, and number of parallel (vectorized) environments
horizon, num_vec_envs = 1000, 2

Agent Hyperparameters

In [None]:
# https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.PPO

# pretrained image backbone (hub name + weights), the graph node whose output
# is used as image features, and whether the backbone is frozen
cnn_extractor_name, cnn_extractor_weights_name = 'vit_b_16', 'IMAGENET1K_SWAG_E2E_V1'
until_layer = 'getitem_5'
freeze_pretrained_cnn = True

# PPO rollout / optimisation settings
n_steps, batch_size = 32, 32
gamma, gae_lambda, clip_range = 0.99, 0.95, 0.2
total_timesteps = 32 * 200  # ToDo: CHANGE THIS TO TRAIN FOR LONGER :)

Tensorboard Logs directory

In [None]:
tensorboard_logdir = Path().joinpath('PPO-logs')  # passed to PPO as tensorboard_log

Glue everything together

In [None]:
# training dataset: all images, resized to image_resize_shape
nih_dataset = NIHBreastCancerDataset(
    root_directory, metadata_path, None, image_resize_shape
)


def create_env() -> NIHBreastCancerEnv:
    """Factory used by DummyVecEnv: one training env over ``nih_dataset``."""
    return NIHBreastCancerEnv(basic_reward_function, nih_dataset, horizon)


env_initializers = [create_env] * num_vec_envs
vec_envs = DummyVecEnv(env_initializers)

Create the agent and start training

In [None]:
# Build the agent. All data-dependent sizes (image size, number of tabular
# features, number of classes) are read from the dataset instead of relying
# on PPOAgent's hard-coded defaults.
agent = PPOAgent(
    envs=vec_envs,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    cnn=cnn_extractor_name,
    weights=cnn_extractor_weights_name,
    freeze_pretrained_model=freeze_pretrained_cnn,
    img_dims=nih_dataset.input_shapes[0],
    num_dataset_features=nih_dataset.input_shapes[1][0],
    n_classes=nih_dataset.num_classes(),
    until_layer=until_layer,
    n_steps=n_steps,
    batch_size=batch_size,
    gamma=gamma,
    gae_lambda=gae_lambda,
    clip_range=clip_range,
    tensorboard_log=tensorboard_logdir,
    verbose=1
)

In [None]:
agent.learn(total_timesteps)  # prints the wall-clock training time at the end

# Evaluation

Some utilities for evaluation

In [None]:
import numpy as np

from typing import List, Optional, Tuple


def get_predictions_on_test_env(
        agent: PPOAgent,
        test_env: NIHBreastCancerEnv
) -> Tuple[List[np.ndarray], List[List[int]]]:
    """Plays one full episode of ``test_env`` with ``agent``.

    At every step the agent's class probabilities for the current observation
    are recorded together with the truth labels of the shown sample. The
    action sent to the env is a fixed dummy (0): the env only uses the action
    for the reward, so the visited samples do not depend on it.

    :return: ``(predictions, truth_labels)``, one probability vector and one
        list of label IDs per visited sample, in visiting order.
    """
    predictions: List[np.ndarray] = []
    truth_labels: List[List[int]] = []

    observation = test_env.reset()
    finished = False
    while not finished:
        current_labels = test_env._sample(0)[2]
        predictions.append(agent.get_scores(observation))
        truth_labels.append(current_labels)
        observation, _, finished, _ = test_env.step(0)

    return predictions, truth_labels


def adapt_truth_labels(
        all_logits: List[np.ndarray],
        truth_labels: List[List[int]],
        num_classes: int,
        rng: Optional[np.random.Generator] = None
) -> np.ndarray:
    """Returns a new list of labels such that only one label exists for
        each data point.

    The label of a data point is the predicted label (``argmax`` of its
    scores) if that is one of its truth labels. Otherwise it is a random
    label that is neither a truth label nor the predicted label, so a wrong
    prediction can never be counted as correct. (Previously only the truth
    labels were excluded, so the draw could return the predicted label.)

    :param all_logits: One score vector per data point.
    :param truth_labels: The truth label IDs of each data point.
    :param num_classes: Total number of classes.
    :param rng: Optional NumPy Generator for reproducible draws; if None, the
        global ``np.random`` state is used, as before.
    :return: Integer array with one label per data point.
    """
    choose = np.random.choice if rng is None else rng.choice
    new_labels = []
    for logits, labels in zip(all_logits, truth_labels):
        pred_label = logits.argmax()
        if pred_label in labels:  # predicted correctly -> argmax == truth
            new_labels.append(pred_label)
            continue

        # predicted wrong -> pick a label different from the prediction
        candidates = set(range(num_classes)).difference(labels, {int(pred_label)})
        if not candidates:
            # every other class is a truth label; any truth label still
            # differs from the (wrong) prediction
            candidates = set(labels)
        new_labels.append(choose(sorted(candidates)))
    return np.array(new_labels)

Create test dataset and test environment

In [None]:
# For demonstration purposes we use the same dataset as the test set; on your
# institutional machine you HAVE TO point this at a separate test split.
test_root_directory = Path('all-data')

nih_test_dataset = NIHBreastCancerDataset(
    root_dir=test_root_directory,
    metadata_path=metadata_path,
    data_to_use=None,
    image_resize_shape=image_resize_shape
)

# Set to True for a real evaluation that visits every test image once; the
# default (False) only evaluates 10 samples for a quick demonstration.
# (The former hint referenced ``test_dataset``, which is not defined.)
evaluate_full_test_set = False

test_env = NIHBreastCancerEnv(
    reward_function=basic_reward_function,
    dataset=nih_test_dataset,
    horizon=len(nih_test_dataset) if evaluate_full_test_set else 10
)

Evaluate

In [None]:
# score every visited test sample with the trained agent
all_logits, all_truth_labels = get_predictions_on_test_env(agent, test_env)

# predicted class = index of the highest score of each sample
pred_labels = np.array([scores.argmax() for scores in all_logits])

# reduce the multi-label ground truth to one label per sample, as required
# by the multiclass metrics computed below
num_classes = NIHBreastCancerDataset.num_classes()
adapted_labels = adapt_truth_labels(all_logits, all_truth_labels,
                                    num_classes=num_classes)

# stack the per-sample score vectors into one (N, num_classes) array
all_logits = np.stack(all_logits)

Compute various metrics

In [None]:
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


accuracy = accuracy_score(y_true=adapted_labels, y_pred=pred_labels)
f1 = f1_score(y_true=adapted_labels, y_pred=pred_labels, average='micro')

# macro-averaged one-vs-one ROC AUC over the full set of class IDs
auc = roc_auc_score(
    y_true=adapted_labels,
    y_score=all_logits,
    labels=np.arange(NIHBreastCancerDataset.num_classes()),
    multi_class='ovo',
    average='macro',
)

Print statistics

In [None]:
print('\n'.join([
    '----------    PPO Performance    ----------',
    f'\tAccuracy: {accuracy:.2f}',
    f'\tF1 Score: {f1:.2f}',
    f'\tAUC: {auc:.2f}',
]))

Don't forget to check the tensorboard logs for the plots!

In [None]:
import time

# Short pause before TensorBoard is launched in the next cell; presumably to
# let the training logger finish writing its event files.
# NOTE(review): confirm the pause is actually needed.
time.sleep(5)

In [None]:
%load_ext tensorboard
%tensorboard --logdir PPO-logs

Save the model

In [None]:
# Persist the trained agent; it can be reloaded with
# PPOAgent(load_pretrained_agent=True, pretrained_agent_path=save_path, ...).
save_path = Path('PPO-model')
agent.save(save_path)