# 0. Dependencies & Imports

In [None]:
!pip install torch torchvision torchviz matplotlib numpy tqdm scikit-image tensorboard_logger

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
from torch.distributions import Normal
from torch.utils.data import Dataset
from torchviz import make_dot

import skimage.measure

import matplotlib.pyplot as plt
from matplotlib import patches
import numpy as np


import unittest
import time
import shutil
import os
import math

import pickle

from tqdm.notebook import tqdm

from enum import Enum

In [None]:
from enum import Enum

class DatasetType(Enum):
    """Which split of a dataset is requested: training, validation or test."""

    TRAIN = 1
    VALID = 2
    TEST = 3

class DatasetName(Enum):
    """Datasets that DatasetLocator knows how to build."""

    MNIST = 1              # plain MNIST
    AUGMENTED = 2          # MNIST with random rotation / scaling
    TRANSFORMED = 3        # MNIST digit pasted at a random offset on a 60x60 canvas
    AUGMENTED_MEDICAL = 4  # composite 256x256 "tumor cell" images
    CLOSED_SQUARES = 5     # synthetic open/closed squares


In [None]:
class Config:
    """Hyper-parameters and runtime settings for the recurrent attention model.

    Every setting is a plain instance attribute assigned in ``__init__``; override
    values on the instance after construction. ``model_name`` is derived once,
    at construction time, from ``num_glimpses``, ``patch_size`` and
    ``glimpse_scale``; changing those afterwards does not update it.
    """

    def __init__(self):
        # ---- glimpse network ----------------------------------------------------
        self.patch_size = 12             # side length of the highest-resolution patch
        self.glimpse_scale = 2           # size ratio between successive patches
        self.num_patches = 5             # number of patches extracted per glimpse
        self.loc_hidden = 128            # location-pathway width (GlimpseNetwork sums it with glimpse_hidden)
        self.glimpse_hidden = 128        # glimpse-pathway width (GlimpseNetwork sums it with loc_hidden)

        # ---- core network -------------------------------------------------------
        self.num_glimpses = 6            # glimpses per image, i.e. BPTT iterations
        self.hidden_size = 256           # size of the recurrent hidden state

        # ---- REINFORCE ----------------------------------------------------------
        self.std = 0.11                  # standard deviation of the Gaussian location policy
        self.M = 1                       # Monte Carlo samples for the valid and test sets
        self.reward_multi = 1            # reward multiplier in (0, 1]; the author expected values < 1
                                         # to make the policy less erratic (effect not yet measured)

        # ---- action network -----------------------------------------------------
        self.num_classes = 10            # number of output classes
        # NOTE(review): the default dataset (AUGMENTED_MEDICAL) only has 4 classes,
        # so 6 of these outputs are never targeted.

        # ---- data loading -------------------------------------------------------
        # NOTE(review): DatasetLocator uses a fixed 90/10 split and num_workers=0
        # instead of reading valid_size / num_workers.
        self.valid_size = 0.1            # fraction of the training set held out for validation
        self.batch_size = 100            # images per mini-batch
        self.num_workers = 4             # data-loading subprocesses
        self.shuffle = True              # shuffle the train and valid indices
        self.show_sample = False         # visualise a sample grid of the data

        # ---- training -----------------------------------------------------------
        self.is_train = True             # train (True) or test (False) the model
        self.resume = False              # resume training from a checkpoint
        self.weight_decay = 1e-5         # weight decay for regularisation
        self.momentum = 0.5              # Nesterov momentum value (TODO: not used)
        self.epochs = 500                # number of epochs to train for
        self.init_lr = 0.001             # initial learning rate
        self.lr_patience = 50            # epochs to wait before reducing the learning rate
        self.train_patience = 100        # epochs to wait before stopping training

        # ---- misc ---------------------------------------------------------------
        self.use_gpu = True              # run on the GPU (when CUDA is available)
        self.best = True                 # load the best (True) or most recent (False) model for testing
        self.random_seed = 1             # seed for reproducibility
        self.data_dir = "./data"         # directory in which data is stored
        self.ckpt_dir = "./ckpt"         # directory in which model checkpoints are saved
        self.logs_dir = "./logs/"        # directory in which TensorBoard logs are stored
        self.use_tensorboard = False     # use TensorBoard for visualisation
        self.print_freq = 100            # how often to print training details
        self.plot_freq = 1               # how often to plot glimpses
        self.dataset = DatasetName.AUGMENTED_MEDICAL
        self.model_name = (
            f"ram_{self.num_glimpses}_{self.patch_size}x{self.patch_size}_{self.glimpse_scale}"
        )

In [None]:
from torch.utils.data import Dataset
import numpy as np
import torch
import os
import torchvision
class AugmentedMedicalMNISTDataset(Dataset):
    """
    Augmented MNIST meant to mimic whole-slide images of tumor cells.
    9's represent cancer cells. There are 4 different labels, based on the number of 9's:

    zero 9's          - no cancer
    one 9             - isolated tumor cell
    two 9's           - micro-metastasis
    three or more 9's - macro-metastasis

    Each 256 x 256 image contains between 3 and 10 cells chosen at random, which may
    overlap. With the default sizes there are 5,000 images per class (20,000 total)
    for training and 500 per class (2,000 total) for testing.

    The dataset is generated once and cached as ``.npy`` files under
    ``<root_dir>/<data_dir>/{train,test}``; if that directory is non-empty the cached
    files are loaded instead of being regenerated.

    Args:
        root_dir: base directory of the cache.
        train: build/load the training split if True, otherwise the test split.
        data_dir: sub-directory of ``root_dir`` holding the cache.
        mnist_transform: transform applied to the source MNIST digits; it must yield
            tensors (e.g. end with ``ToTensor``) so digits can be batched and masked.
        transform: optional transform applied to each composite image in ``__getitem__``.
        total_train / total_test: number of composite images generated per split.
        n_partitions_train / n_partitions_test: number of ``part_<i>.npy`` files the
            images are spread across on disk.
    """

    def __init__(self,
                 root_dir,
                 train,
                 data_dir="MEDNIST",
                 mnist_transform=None,
                 transform=None,
                 total_train=20000,
                 total_test=2000,
                 n_partitions_test=1,
                 n_partitions_train=5):

        self.mnist_transform = mnist_transform
        self.root_dir = root_dir
        self.train = train
        self.total = total_train if self.train else total_test
        self.n_partitions_test = n_partitions_test
        self.n_partitions_train = n_partitions_train
        self.dir = os.path.join(root_dir, data_dir, "train" if train else "test")
        self.transform = transform

        self.__create_dataset_if_needed()

        self.__load_data()

    def __dataset_exists(self):
        """Create the cache directory if needed; return True when it already holds files."""
        os.makedirs(self.dir, exist_ok=True)
        len_files = len(os.listdir(self.dir))
        if len_files > 0:
            print("Data existing, skipping creation.")
            return True
        else:
            print("Dataset missing. Creating...")
        return False

    def __combine_images(self, images, output_dim):
        """
        Paste every image at a random position of a blank (output_dim, output_dim)
        float32 canvas. Images always lie completely within bounds; overlapping is
        possible, in which case later images overwrite earlier ones (including with
        their zero background).
        """
        np_images = np.array(images)
        input_dim = np_images.shape[-1]
        new_image = np.zeros(shape=(output_dim, output_dim), dtype=np.float32)
        for image in np_images:
            # randint's upper bound is exclusive: +1 lets a cell touch the bottom/right
            # edge and keeps output_dim == input_dim valid.
            i, j = np.random.randint(0, output_dim - input_dim + 1, size=2)
            new_image[i:i + input_dim, j:j + input_dim] = image
        return new_image

    def __get_cell_counts(self, class_index, max_items=11, min_number_of_cells=3):
        """Return ``(num_tumor_cells, num_healthy_cells)`` for one image of ``class_index``.

        Classes 0, 1, 2 have exactly that many tumor cells; class 3 has between 3 and
        ``max_items - 1`` (``max_items`` is exclusive). The healthy count is random so
        the total lies in ``[min_number_of_cells, max_items - 1]``.
        """
        if class_index == 3:
            num_tumor_cells = np.random.randint(3, max_items)
        else:
            num_tumor_cells = class_index

        num_healthy_cells = np.random.randint(0, max_items - num_tumor_cells)
        if num_healthy_cells + num_tumor_cells < min_number_of_cells:
            num_healthy_cells = min_number_of_cells - num_tumor_cells

        return (num_tumor_cells, num_healthy_cells)

    def __generate_for_class(self,
                             items_per_class_count,
                             class_index,
                             uid,
                             all_tumor_cell_images,
                             all_healthy_cell_images):
        """Append ``items_per_class_count`` composite images of ``class_index``.

        Returns ``uid`` advanced by the number of images created.
        """
        for _ in range(items_per_class_count):
            num_tumors, num_healthy = self.__get_cell_counts(class_index)

            healthy_idxs = np.random.randint(0, len(all_healthy_cell_images), num_healthy)
            tumor_idxs = np.random.randint(0, len(all_tumor_cell_images), num_tumors)

            healthy_cells = all_healthy_cell_images[healthy_idxs]
            tumor_cells = all_tumor_cell_images[tumor_idxs]
            # healthy cells first, so tumor cells are pasted on top
            cells = np.vstack((healthy_cells, tumor_cells))
            image = self.__combine_images(cells, 256)
            image = np.expand_dims(image, axis=0)
            self.data.append(image)
            self.source_images.append(tumor_cells.numpy())
            self.labels.append(class_index)
            uid += 1
        return uid

    def __create_dataset_if_needed(self):
        """Generate the composite images and store them, unless a cache already exists."""
        if self.__dataset_exists():
            return

        self.data = []
        self.labels = []
        self.source_images = []

        # in how many partitions to split dataset creation
        partitions_count = 10

        # number of classes in output (fixed)
        num_classes = 4

        # Use the MNIST split that matches this split so test images never reuse
        # digits that appear in the training images.
        mnist = torchvision.datasets.MNIST(root='./data',
                                           train=self.train,
                                           download=True,
                                           transform=self.mnist_transform)

        mnist_loader = iter(torch.utils.data.DataLoader(mnist,
                                                        batch_size=int(self.total / partitions_count),
                                                        shuffle=False,
                                                        num_workers=0))
        items_per_class_count = int(self.total / (num_classes * partitions_count))
        uid = 0
        for _ in range(partitions_count):
            # a fresh batch of MNIST digits for every partition
            batch, mnist_labels = next(mnist_loader)
            # 9's represent tumors
            all_tumor_cell_images = batch[mnist_labels == 9]
            # everything else except 6's are healthy cells (presumably because a
            # rotated 6 can look like a 9)
            all_healthy_cell_images = batch[(mnist_labels != 9) & (mnist_labels != 6)]

            for class_index in range(num_classes):
                uid = self.__generate_for_class(items_per_class_count,
                                                class_index,
                                                uid,
                                                all_tumor_cell_images,
                                                all_healthy_cell_images)
        self.__store()
        print("Done.")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, uid):
        if torch.is_tensor(uid):
            uid = uid.tolist()
        label = self.labels[uid]
        sample = self.data[uid]
        if self.transform:
            sample = self.transform(sample)

        return (sample, label)

    def __store(self):
        """Write images as ``part_<i>.npy`` partitions plus ``labels.npy``.

        For the test split the tumor crops of each image are also written to
        ``sources.npy``; the crops are ragged, so that file holds an object array
        and must be read with ``np.load(..., allow_pickle=True)``.
        """
        n_partitions = self.n_partitions_train if self.train else self.n_partitions_test

        assert (len(self.data) == len(self.labels))
        max_index = len(self.data)
        partition_size = max_index / n_partitions
        for i in range(n_partitions):
            start, end = (int(partition_size * i), int(partition_size * (i + 1)))
            partition = np.array(self.data[start:end])
            np.save(os.path.join(self.dir, "part_" + str(i)), partition)

        np.save(os.path.join(self.dir, "labels"), np.array(self.labels))

        if not self.train:
            # Each entry has a different number of tumor crops, so np.array on the
            # list would be ragged; fill an explicit 1-D object array instead.
            sources = np.empty(len(self.source_images), dtype=object)
            for idx, crops in enumerate(self.source_images):
                sources[idx] = crops
            np.save(os.path.join(self.dir, "sources"), sources)

    def __load_data(self):
        """Load the image partitions and labels from the cache directory."""
        n_partitions = self.n_partitions_train if self.train else self.n_partitions_test
        data = []
        for i in range(n_partitions):
            data.append(np.load(os.path.join(self.dir, "part_" + str(i) + ".npy")))
        self.data = np.vstack(data)
        self.labels = np.load(os.path.join(self.dir, "labels.npy"))

In [None]:
from torch.utils.data import Dataset
import numpy as np
import torch
class ClosedSquaresDataset(Dataset):
    """Synthetic dataset whose label is the number of *open* squares in the image.

    Each sample is a float32 tensor of shape (1, size, size) containing ``n_circles``
    non-overlapping filled squares of side ``object_width`` (despite the parameter
    name they are squares). For label ``k`` the first ``k`` squares are "opened" by
    zeroing ``n_missing`` randomly chosen pixels inside them. Labels run from 0 to
    ``n_classes - 1`` with ``int(total / n_classes)`` samples each.

    Data is generated eagerly and deterministically (seed 1 for train, 2 for test)
    from a private ``RandomState``, so NumPy's global RNG is left untouched.

    NOTE(review): the opened pixel positions may coincide (fewer than ``n_missing``
    pixels removed), and removing only the centre of a 3x3 square leaves a closed
    ring - confirm this labelling is intended.
    """

    # Upper bound on placement retries. The previous recursive retry would already
    # have hit Python's recursion limit (about 1000 by default) far earlier.
    _MAX_PLACEMENT_ATTEMPTS = 10_000

    def __init__(self,
                 train,
                 size=64,
                 object_width=3,
                 n_missing=2,
                 n_classes=4,
                 n_circles=6,
                 total_train=16000,
                 total_test=1600):
        # A RandomState seeded like the former np.random.seed(1 / 2) calls reproduces
        # the same images without resetting the process-wide NumPy RNG.
        self._rng = np.random.RandomState(1 if train else 2)
        self.n = total_train if train else total_test
        self.__create_data(n_classes, n_circles, size, object_width, n_missing)

    def __create_data(self, n_classes, n_circles, size, object_width, n_missing):
        """Fill ``self.data`` / ``self.labels`` class by class."""
        self.labels = []
        self.data = []

        per_class = int(self.n / n_classes)
        for class_i in range(n_classes):
            for _ in range(per_class):
                image = self.__generate_image(class_i, n_circles, size, object_width, n_missing)
                self.data.append(torch.tensor(image))
                self.labels.append(class_i)

    def __generate_image(self, n_open, n_all, size, object_width, n_missing):
        """Return one (1, size, size) float32 image with ``n_open`` of ``n_all`` squares opened.

        Raises:
            RuntimeError: if no non-overlapping placement is found within
                ``_MAX_PLACEMENT_ATTEMPTS`` tries.
        """
        rng = self._rng
        footprint = object_width + 2  # square plus a 1-pixel margin so squares never touch

        for _ in range(self._MAX_PLACEMENT_ATTEMPTS):
            # top-left corners of the footprints, kept within bounds
            top_lefts = (rng.rand(n_all, 2) * (size - footprint)).astype(int)
            occupancy = np.zeros((size, size))
            for x_0, y_0 in top_lefts:
                occupancy[x_0:x_0 + footprint, y_0:y_0 + footprint] += 1
            # any pixel covered twice means two footprints overlap -> resample
            if np.all(occupancy <= 1):
                break
        else:
            raise RuntimeError(
                f"could not place {n_all} squares of width {object_width} "
                f"in a {size}x{size} image without overlap"
            )

        image = np.zeros((size, size), dtype=np.float32)
        for i, (x_0, y_0) in enumerate(top_lefts + 1):  # +1 skips the margin
            image[x_0:x_0 + object_width, y_0:y_0 + object_width] = 1
            if i < n_open:
                pos = (rng.rand(n_missing, 2) * object_width).astype(int)
                for p_x, p_y in pos:
                    image[x_0 + p_x, y_0 + p_y] = 0
        return image.reshape(1, size, size)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, uid):
        if torch.is_tensor(uid):
            uid = uid.tolist()
        label = self.labels[uid]
        sample = self.data[uid]

        return (sample, label)

In [None]:
import torch
import torchvision
from torchvision.transforms import transforms
import numpy as np


class DatasetLocator:
    """Resolve ``conf.dataset`` into train / valid / test datasets and build loaders.

    The chosen dataset's training split is divided 90/10 into train and validation
    subsets with ``torch.utils.data.random_split``; the test split is loaded
    separately. ``data_loader`` shuffles only the training subset.
    """

    def __init__(self, conf: Config):
        self.dataset = conf.dataset
        self.gpu_run = conf.use_gpu
        self.batch_size = conf.batch_size

        splits = self.__load_data()
        kinds = (DatasetType.TRAIN, DatasetType.VALID, DatasetType.TEST)
        self.dataset_dict = dict(zip(kinds, splits))

    @staticmethod
    def __f(image):
        """Paste a digit at a random offset inside a blank 60x60 float32 canvas."""
        canvas_side = 60
        digit = np.array(image)
        digit_side = digit.shape[-1]
        canvas = np.zeros(shape=(canvas_side, canvas_side), dtype=np.float32)
        row, col = np.random.randint(0, canvas_side - digit_side, size=2)
        canvas[row:row + digit_side, col:col + digit_side] = digit
        return canvas

    def __transformed_mnist_transformation(self):
        """Random placement on a 60x60 canvas, then tensor conversion and MNIST normalisation."""
        steps = [
            torchvision.transforms.Lambda(self.__f),
            torchvision.transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
        return transforms.Compose(steps)

    @staticmethod
    def __augmented_mnist_transformation():
        """Random rotation in [-180, 180] degrees and scaling in [0.5, 1.0], then ToTensor."""
        affine = torchvision.transforms.RandomAffine(degrees=(-180, 180), scale=(0.5, 1.0))
        return transforms.Compose([affine, torchvision.transforms.ToTensor()])

    @staticmethod
    def __augmented_mnist_simple_transformation():
        """Milder variant: rotation in [0, 90] degrees and scaling in [0.9, 1.0], then ToTensor."""
        affine = torchvision.transforms.RandomAffine(degrees=(0, 90), scale=(0.9, 1.0))
        return transforms.Compose([affine, torchvision.transforms.ToTensor()])

    def __load_data(self):
        """Return ``(train, valid, test)``; train/valid come from a 90/10 random split."""
        full_train = self.__load_dataset(True)
        test = self.__load_dataset(False)

        n_train = int(len(full_train) * 0.9)
        n_valid = len(full_train) - n_train
        train, valid = torch.utils.data.random_split(full_train, (n_train, n_valid))
        return train, valid, test

    def __load_dataset(self, is_train):
        """Build the requested split of the configured dataset."""
        kind = self.dataset
        if kind == DatasetName.AUGMENTED_MEDICAL:
            return AugmentedMedicalMNISTDataset(root_dir='.',
                                                data_dir="MEDNIST",
                                                train=is_train,
                                                mnist_transform=self.__augmented_mnist_transformation())
        if kind == DatasetName.CLOSED_SQUARES:
            return ClosedSquaresDataset(train=is_train)

        if kind == DatasetName.MNIST:
            transform = torchvision.transforms.ToTensor()
        elif kind == DatasetName.AUGMENTED:
            transform = self.__augmented_mnist_transformation()
        elif kind == DatasetName.TRANSFORMED:
            transform = self.__transformed_mnist_transformation()
        else:
            transform = None
        return torchvision.datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)

    def data_loader(self, dataset: DatasetType):
        """Return a DataLoader over the requested split (shuffled only for TRAIN)."""
        return torch.utils.data.DataLoader(
            self.dataset_dict[dataset],
            batch_size=self.batch_size,
            pin_memory=self.gpu_run,
            shuffle=dataset == DatasetType.TRAIN,
            num_workers=0,
        )

In [None]:
import logging

import torch.nn as nn
import torch.nn.functional as F


class ActionNetwork(nn.Module):
    """Linear softmax classifier on top of the core network's hidden state.

    ``forward`` sends ``h_t`` through a single fully connected layer and applies
    ``log_softmax`` over dim 1, so the result holds log-probabilities over the
    classes (not probabilities).

    Args:
        input_size: number of features in ``h_t``.
        output_size: number of classes.

    forward(h_t) returns a tensor of shape (B, output_size).
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)
        logging.info(self)

    def forward(self, h_t):
        logging.debug("\n\nActionNetwork")
        logging.debug(f"Input:   {h_t.shape}")
        logits = self.fc(h_t)
        log_probs = F.log_softmax(logits, dim=1)
        logging.debug(f"Softmax: {log_probs.shape}\n\n")
        return log_probs


In [None]:
import logging

import torch.nn as nn
import torch.nn.functional as F

class BaselineNetwork(nn.Module):
    """Regresses the REINFORCE baseline from the core network's hidden state.

    The baseline is subtracted from the reward to reduce gradient variance. The
    input is detached, so training this head never backpropagates into the core
    network. The output is passed through ReLU and is therefore non-negative.

    Args:
        input_size: number of features in ``h_t``.
        output_size: output width (the model uses 1).

    forward(h_t) returns a tensor of shape (B, output_size).
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)
        logging.info(self)

    def forward(self, h_t):
        logging.debug("\n\nBaselineNetwork")
        logging.debug(f"Input: {h_t.shape}")
        frozen_state = h_t.detach()
        b_t = F.relu(self.fc(frozen_state))
        logging.debug(f"Fc1:   {b_t.shape}\n\n")
        return b_t

In [None]:
import logging

import torch.nn as nn
import torch
import torch.nn.functional as F

class CoreNetwork(nn.Module):
    """Simple recurrent core: ``h_t = relu(i2h(g_t) + h2h(h_{t-1}))``.

    The hidden state lives on the module as ``self.hidden_state`` and must be
    initialised with ``reset`` before the first ``forward`` call of a sequence;
    each ``forward`` overwrites it with the new state and returns it.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        logging.info(self)

    def forward(self, g_t):
        from_input = self.i2h(g_t)
        from_state = self.h2h(self.hidden_state)
        self.hidden_state = F.relu(from_input + from_state)
        return self.hidden_state

    def reset(self, batch_size, device):
        """Start a new sequence with an all-zero (batch_size, hidden_size) state on ``device``."""
        shape = (batch_size, self.hidden_size)
        self.hidden_state = torch.zeros(shape, dtype=torch.float, device=device, requires_grad=True)

In [None]:
import torch
import torch.nn.functional as F


class Retina:
    """
    Extracts a multi-resolution glimpse ``phi`` around location ``l`` from images ``x``.

    The region around ``l`` is kept at high resolution while progressively larger
    surrounding regions are captured at progressively lower resolution, giving a
    compressed representation of the image.

    Settings read from ``conf``:
        patch_size: side length of the first (highest-resolution) square patch.
        num_patches: number of patches extracted per glimpse.
        glimpse_scale: factor by which each successive patch grows. ``foveate``
            downsamples with ``avg_pool2d`` and so assumes an integer scale.
    """

    def __init__(self, conf: "Config"):
        self.patch_size = conf.patch_size
        self.num_patches = conf.num_patches
        self.scale = conf.glimpse_scale

    def foveate(self, x, l):
        """Extract ``num_patches`` square patches centred at ``l``.

        The first patch has side ``patch_size`` and each following one is ``scale``
        times larger. Every larger patch is average-pooled back to
        ``patch_size x patch_size`` and all patches are concatenated along the
        channel dimension.

        Args:
            x: image batch of shape (B, C, H, W).
            l: (B, 2) locations ``[x, y]`` in [-1, 1].

        Returns:
            Tensor of shape (B, num_patches * C, patch_size, patch_size).
        """
        phi = []
        size = self.patch_size

        # extract patches of increasing size
        for _ in range(self.num_patches):
            phi.append(self.extract_patch(x, l, size))
            size = int(self.scale * size)

        # downsample the larger patches to patch_size x patch_size
        for i in range(1, len(phi)):
            k = phi[i].shape[-1] // self.patch_size
            phi[i] = F.avg_pool2d(phi[i], k)

        return torch.cat(phi, 1)

    def extract_patch(self, x, l, size):
        """Extract one ``size x size`` patch per image, centred at ``l``.

        Areas outside the image are zero-padded.

        Args:
            x: image batch of shape (B, C, H, W).
            l: (B, 2) locations ``[x, y]`` in [-1, 1]; ``x`` spans the width and
                ``y`` the height.
            size: side length of the patch.

        Returns:
            Tensor of shape (B, C, size, size).
        """
        B, C, H, W = x.shape

        # column start from the width, row start from the height (identical for square images)
        start = torch.stack((self.denormalize(W, l[:, 0]), self.denormalize(H, l[:, 1])), dim=1)
        end = start + size

        # Pad size//2 before and size - size//2 after, so the padded image is exactly
        # `size` larger: even at the largest start (H or W, when l == 1) the slice
        # still holds `size` pixels, including for odd sizes.
        before = size // 2
        after = size - before
        x = F.pad(x, (before, after, before, after))

        patch = [x[i, :, start[i, 1]:end[i, 1], start[i, 0]:end[i, 0]] for i in range(B)]
        return torch.stack(patch)

    def denormalize(self, T, coords):
        """Convert coordinates in [-1, 1] to integer pixel coordinates in [0, T]."""
        return (0.5 * ((coords + 1.0) * T)).long()

In [None]:
import logging

import torch.nn as nn
import torch.nn.functional as F


class GlimpseNetwork(nn.Module):
    """The glimpse network: fuses "what" (glimpse content) and "where" (location) features.

    The retina extracts ``conf.num_patches`` patches around ``l_t_prev`` and stacks
    them along the channel axis. They are encoded by three 3x3 convolutions
    (16 channels, padding 1 so the spatial size stays ``patch_size``), with batch
    norm after the second, followed by a fully connected layer. The location is
    encoded by its own fully connected layer. Both branches have width
    ``conf.glimpse_hidden + conf.loc_hidden`` and are combined multiplicatively:
    ``g_t = relu(what * where)``.

    NOTE(review): conv1 expects ``num_patches`` input channels, which matches the
    retina output only for single-channel images.

    Args:
        conf: configuration providing patch_size, num_patches, glimpse_scale,
            glimpse_hidden and loc_hidden.

    forward(x, l_t_prev):
        x: image batch of shape (B, C, H, W).
        l_t_prev: (B, 2) glimpse locations ``[x, y]`` from the previous timestep.
        Returns g_t of shape (B, glimpse_hidden + loc_hidden).
    """

    def __init__(self, conf: Config):
        super().__init__()

        self.retina = Retina(conf)

        width = conf.glimpse_hidden + conf.loc_hidden
        channels = 16

        # "what" pathway; padding=1 keeps the patch_size x patch_size resolution
        self.conv1 = nn.Conv2d(in_channels=self.retina.num_patches, out_channels=channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=channels, track_running_stats=True)
        self.conv3 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)

        flat_features = channels * conf.patch_size * conf.patch_size
        self.fc1 = nn.Linear(in_features=flat_features, out_features=width)

        # "where" pathway: a location is an (x, y) pair
        self.loc_fc1 = nn.Linear(in_features=2, out_features=width)

        logging.info(self)

    def forward(self, x, l_t_prev):
        logging.debug("\n\nGlimpseNetwork shapes")
        logging.debug("#### What ####")
        glimpse = self.retina.foveate(x, l_t_prev)
        logging.debug(glimpse.shape)

        what = self.conv1(glimpse)
        logging.debug(f"Conv1:      {what.shape}")
        what = F.relu(what)
        logging.debug(f"Conv1 ReLu: {what.shape}")
        what = F.relu(self.bn1(self.conv2(what)))
        logging.debug(f"Conv2:        {what.shape}")
        logging.debug(f"Bn1 ReLu:   {what.shape}")
        what = F.relu(self.conv3(what))
        logging.debug(f"Conv3:      {what.shape}")

        # flatten everything but the batch dimension
        batch_size = x.shape[0]
        what = what.view(batch_size, -1)
        logging.debug(f"Flatten:    {what.shape}")
        what = F.relu(self.fc1(what))
        logging.debug(f"Fc1:        {what.shape}")

        logging.debug("#### Where ####")
        coords = l_t_prev.view(l_t_prev.size(0), -1)
        logging.debug(f"Input:         {coords.shape}")
        where = F.relu(self.loc_fc1(coords))
        logging.debug(f"Fc1(loc):      {where.shape}")

        logging.debug("#### Combined ####")
        g = F.relu(what * where)
        logging.debug(f"relu(h * l):   {g.shape}\n\n")
        return g

In [None]:
import logging

import torch.nn as nn
import torch
from torch.distributions import Normal

import torch.nn.functional as F

class LocationNetwork(nn.Module):
    """The location network (the stochastic "where to look next" policy).

    Feeds the detached hidden state ``h_t`` through a fully connected layer and a
    tanh to obtain the mean of a 2-D Gaussian with fixed standard deviation
    ``std``. Because ``h_t`` is detached, gradients from this head never reach the
    core network.

    In training mode ``l_t`` is a detached sample from that Gaussian; in eval mode
    ``l_t`` is the mean itself. In both cases ``l_t`` is clamped to [-1, 1].

    Args:
        input_size: number of features in ``h_t``.
        output_size: dimensionality of the location (the model uses 2).
        std: standard deviation of the Gaussian policy.

    forward(h_t) returns ``(mean, l_t)``, both of shape (B, output_size).
    """

    def __init__(self, input_size, output_size, std):
        super().__init__()

        self.std = std
        self.fc = nn.Linear(input_size, output_size)
        logging.info(self)

    def forward(self, h_t):
        logging.debug("\n\nLocationNetwork")
        logging.debug(f"Input:     {h_t.shape}")
        mean = torch.tanh(self.fc(h_t.detach()))
        logging.debug(f"fc2+tanh:  {mean.shape}")

        if self.training:
            # sampled location is treated as a constant action
            l_t = torch.distributions.Normal(mean, self.std).rsample().detach()
        else:
            # evaluation is deterministic
            l_t = mean

        # a Gaussian sample can leave [-1, 1]; keep it on the image
        l_t = torch.clamp(l_t, -1, 1)

        return mean, l_t

In [None]:
import logging

import torch
import torch.nn as nn


class RecurrentAttention(nn.Module):
    """A Recurrent Model of Visual Attention (RAM) [1].

    RAM is a recurrent neural network that processes inputs sequentially,
    attending to different locations within the image one at a time, and
    incrementally combining information from these fixations to build up a
    dynamic internal representation of the image.

    References:
      [1]: Mnih et al., https://arxiv.org/abs/1406.6247
    """

    def __init__(self, config):
        """Build the sub-networks from ``config``.

        Args:
            config: configuration object. Besides the glimpse settings consumed by
                GlimpseNetwork it supplies ``hidden_size`` (recurrent state size),
                ``std`` (std-dev of the Gaussian location policy) and
                ``num_classes`` (classifier output size).
        """
        super().__init__()

        self.sensor = GlimpseNetwork(config)
        self.rnn = CoreNetwork(config.hidden_size, config.hidden_size)
        self.locator = LocationNetwork(config.hidden_size, 2, config.std)
        self.classifier = ActionNetwork(config.hidden_size, config.num_classes)
        self.baseliner = BaselineNetwork(config.hidden_size, 1)

    def forward(self, x, l_t_prev, last=False):
        """Run RAM for one timestep on a minibatch of images.

        The recurrent hidden state is kept inside ``self.rnn``; call ``reset``
        before the first timestep of a sequence.

        Args:
            x: image batch of shape (B, C, H, W).
            l_t_prev: (B, 2) glimpse location ``[x, y]`` from the previous timestep.
            last: if True, also run the classifier on the current hidden state.

        Returns:
            ``(h_t, l_t, b_t, mean_t)``, or ``(h_t, l_t, b_t, probabilities, mean_t)``
            when ``last`` is True, where:
            h_t: (B, hidden_size) hidden state for the current timestep.
            l_t: (B, 2) location to look at next.
            b_t: (B,) baseline for the current timestep.
            probabilities: (B, num_classes) log-probabilities over the classes.
            mean_t: (B, 2) mean of the Gaussian location policy.
        """
        g_t = self.sensor(x, l_t_prev)
        h_t = self.rnn(g_t)
        mean_t, l_t = self.locator(h_t.detach())
        # squeeze only the trailing unit dim so a batch of one still yields shape (1,)
        b_t = self.baseliner(h_t.detach()).squeeze(-1)

        if last:
            probabilities = self.classifier(h_t)
            return h_t, l_t, b_t, probabilities, mean_t

        return h_t, l_t, b_t, mean_t

    def reset(self, batch_size, device):
        """Reset the recurrent state and return a random initial location.

        Returns:
            (batch_size, 2) float32 tensor on ``device`` with values drawn uniformly
            from [-1, 1].
        """
        # h_t is maintained by the rnn itself
        self.rnn.reset(batch_size=batch_size, device=device)

        # Sampled on the CPU and then moved, so the random stream matches the
        # former torch.FloatTensor(...).uniform_(...) construction.
        l_t = torch.empty(batch_size, 2, dtype=torch.float32).uniform_(-1, 1).to(device)
        logging.debug(f"DRAM reset, l_0: {l_t}")
        # NOTE(review): the original author questioned whether this flag is needed;
        # kept for compatibility with existing callers.
        l_t.requires_grad = True

        return l_t

In [None]:
import os
import json

# https://github.com/pytorch/examples/blob/master/imagenet/main.py
class AverageMeter:
    """Keep the most recent value and a running, count-weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running sum, the sample count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value and fold it into the mean ``n`` times.

        Args:
            val: the new observation (e.g. a per-batch average).
            n: weight of ``val`` in the running mean, typically the batch size.
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def prepare_dirs(config):
    """Create the data, checkpoint and log directories named in ``config``.

    Args:
        config: object exposing ``data_dir``, ``ckpt_dir`` and ``logs_dir``
            path attributes. Missing directories (and missing parents) are
            created; directories that already exist are left untouched.
    """
    for path in (config.data_dir, config.ckpt_dir, config.logs_dir):
        # exist_ok replaces the racy "os.path.exists then makedirs" check.
        os.makedirs(path, exist_ok=True)


def save_config(config):
    """Write ``config.__dict__`` as indented, key-sorted JSON into ``config.ckpt_dir``.

    The file is named ``ram_<num_glimpses>_<patch>x<patch>_<glimpse_scale>_params.json``.
    The checkpoint directory and the target path are printed before writing.
    """
    patch = config.patch_size
    stem = "ram_{}_{}x{}_{}".format(config.num_glimpses, patch, patch, config.glimpse_scale)
    param_path = os.path.join(config.ckpt_dir, stem + "_params.json")

    print("[*] Model Checkpoint Dir: {}".format(config.ckpt_dir))
    print("[*] Param Path: {}".format(param_path))

    with open(param_path, "w") as out_file:
        json.dump(config.__dict__, out_file, indent=4, sort_keys=True)

In [None]:
import logging
import os
import time
import shutil
import pickle

import torch
import torch.nn.functional as F
from torch.distributions import Normal

import itertools

from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tensorboard_logger import configure, log_value

import numpy as np
class Trainer:
    """A Recurrent Attention Model trainer.

    All hyperparameters are provided by the user in the
    config file. ``train``, ``validate`` and ``test`` all run batches
    through ``one_batch``, which returns the hybrid
    (action + baseline + REINFORCE) loss and the batch accuracy.
    """

    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args:
            config: object containing command line arguments.
            data_loader: when ``config.is_train`` is set, a
                ``(train_loader, valid_loader)`` pair; otherwise the test
                loader.
        """
        self.config = config

        if config.use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.num_glimpses = config.num_glimpses

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.dataset)
            self.num_valid = len(self.valid_loader.dataset)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.lr = config.init_lr

        # misc params
        self.best = config.best
        self.best_valid_acc = 0.0
        self.counter = 0

        self.plot_dir = "./plots/" + self.config.model_name + "/"
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.config.use_tensorboard:
            # os.path.join so a logs_dir without a trailing slash still yields
            # <logs_dir>/<model_name> instead of <logs_dir><model_name>.
            tensorboard_dir = os.path.join(self.config.logs_dir, self.config.model_name)
            logging.info("[*] Saving tensorboard logs to {}".format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(config)
        self.model.to(self.device)

        # initialize optimizer and scheduler
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.config.init_lr,
            weight_decay=self.config.weight_decay
        )
        self.scheduler = ReduceLROnPlateau(
            self.optimizer, "min", patience=config.lr_patience
        )

    def train(self):
        """Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        Training stops early once the validation accuracy has not improved
        for more than ``config.train_patience`` epochs.
        """
        # load the most recent checkpoint
        if self.config.resume:
            self.load_checkpoint(best=False)

        logging.info("\n[*] Train on {} samples, validate on {} samples"
              .format(self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):
            logging.info("\nEpoch: {}/{} - LR: {:.6f}".format(epoch + 1, self.epochs, self.optimizer.param_groups[0]["lr"]))

            # train for 1 epoch
            train_loss, train_acc, loss_act, loss_base, loss_reinf = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # reduce lr when validation accuracy plateaus (the scheduler is in
            # "min" mode, so it is fed the negated accuracy)
            self.scheduler.step(-valid_acc)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} - action: {:.3f}, baseline: {:.3f} reinforce: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            logging.info(msg.format(train_loss, train_acc, loss_act, loss_base, loss_reinf, valid_loss, valid_acc))

            # check for improvement
            if not is_best:
                self.counter += 1
            if self.counter > self.config.train_patience:
                logging.info("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint({
                "epoch": epoch + 1,
                "model_state": self.model.state_dict(),
                "optim_state": self.optimizer.state_dict(),
                "best_valid_acc": self.best_valid_acc,
            },
                is_best)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Returns:
            ``(loss, acc, loss_action, loss_baseline, loss_reinforce)``
            averaged over the epoch, weighted by batch size.
        """
        self.model.train()
        batch_time = AverageMeter()
        losses = AverageMeter()
        losses_action = AverageMeter()
        losses_reinforce = AverageMeter()
        losses_baseline = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):

                loss, acc, preds, locs, imgs, loss_action, loss_baseline, loss_reinforce = self.one_batch(x, y)

                self.optimizer.zero_grad()

                # compute gradients and update SGD
                loss.backward()
                self.optimizer.step()

                # store
                losses.update(loss.item(), x.size()[0])
                losses_reinforce.update(loss_reinforce.item(), x.size()[0])
                losses_baseline.update(loss_baseline.item(), x.size()[0])
                losses_action.update(loss_action.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # measure elapsed time (since the start of the epoch)
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format((toc - tic), loss.item(), acc.item())))

                batch_size = x.shape[0]
                pbar.update(batch_size)

                # log to tensorboard
                if self.config.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value("train_loss", losses.avg, iteration)
                    log_value("train_acc", accs.avg, iteration)

            return losses.avg, accs.avg, losses_action.avg, losses_baseline.avg, losses_reinforce.avg

    def one_batch(self, x, y):
        """Run the glimpse sequence on one batch and compute the hybrid loss.

        Args:
            x: input batch; moved to ``self.device``. Its first dimension is
                the batch size.
            y: class labels for ``x``; moved to ``self.device``.

        Returns:
            ``(loss, acc, predicted, locs, imgs, loss_action, loss_baseline,
            loss_reinforce)`` where ``loss`` is
            ``loss_action + loss_baseline + reward_multi * loss_reinforce``,
            ``acc`` is the batch accuracy in percent, ``predicted`` is the
            argmax class per sample (on ``self.device``), ``locs`` holds the
            first 9 rows of every location tensor (l_0 included) and ``imgs``
            the first 9 inputs.
        """
        batch_size = x.shape[0]
        x, y = x.to(self.device), y.to(self.device)

        imgs = []
        locs = []
        means = []
        baselines = []
        locations = []

        # initialize location vector and hidden state
        l_t = self.model.reset(batch_size, self.device)
        locs.append(l_t[0:9])
        for t in range(self.num_glimpses - 1):
            # forward pass through model
            h_t, l_t, b_t, mean_t = self.model(x, l_t)

            # save locs for plotting
            locs.append(l_t[0:9])
            locations.append(l_t)
            # reshape guards against b_t having been squeezed to a 0-dim
            # tensor when batch_size == 1
            baselines.append(b_t.reshape(batch_size))
            means.append(mean_t)

        # last iteration: only the class log-probabilities are used
        _, _, _, probabilities, _ = self.model(x, l_t, last=True)

        # save images for plotting
        imgs.append(x[0:9])

        # stack along dim=1 so everything is batch-first, T = num_glimpses - 1:
        # baselines (B, T), means / locations (B, T, 2)
        baselines = torch.stack(baselines, dim=1)
        means = torch.stack(means, dim=1)
        locations = torch.stack(locations, dim=1)

        # reward: 1 if the final prediction is correct, else 0, credited to
        # every glimpse
        predicted = torch.argmax(probabilities, 1)
        R = (predicted.detach() == y).float()
        R = R.unsqueeze(1).repeat(1, self.num_glimpses - 1)

        # losses for differentiable modules; `probabilities` are
        # log-probabilities, as nll_loss expects
        loss_action = F.nll_loss(probabilities, y)
        loss_baseline = F.mse_loss(baselines, R)

        # REINFORCE: sum the log-density of the (x, y) coordinates of each
        # sampled location into one log-prob per glimpse, weight it by that
        # same glimpse's advantage, sum over glimpses, average over the batch.
        adjusted_reward = (R - baselines).detach()
        log_pi = Normal(means, self.model.locator.std).log_prob(locations).sum(dim=2)
        loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
        loss_reinforce = torch.mean(loss_reinforce, dim=0)

        # sum up into a hybrid loss
        loss = loss_action + loss_baseline + loss_reinforce * self.config.reward_multi

        # compute accuracy
        correct = (predicted == y).float()
        acc = 100 * (correct.sum() / len(y))

        return loss, acc, predicted, locs, imgs, loss_action, loss_baseline, loss_reinforce

    def __save_images_if_plotting(self, epoch, i, locs, imgs, y):
        """Pickle the first batch's inputs and glimpse locations every ``plot_freq`` epochs."""
        # dump the glimpses and locs
        if (epoch % self.config.plot_freq == 0) and (i == 0):
            imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
            locs = [l.cpu().data.numpy() for l in locs]
            # context managers so the files are flushed and closed right away
            with open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb") as fp:
                pickle.dump(imgs, fp)
            with open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb") as fp:
                pickle.dump(locs, fp)

    @torch.no_grad()
    def validate(self, epoch):
        """Evaluate the RAM model on the validation set.

        Returns:
            ``(loss, acc)`` averaged over the validation set, weighted by
            batch size.
        """
        losses = AverageMeter()
        accs = AverageMeter()
        self.model.eval()

        for i, (x, y) in enumerate(self.valid_loader):
            loss, acc, preds, locs, imgs, _, _, _ = self.one_batch(x, y)
            self.__save_images_if_plotting(epoch, i, locs, imgs, y)
            # store
            losses.update(loss.item(), x.size()[0])
            accs.update(acc.item(), x.size()[0])

            # log to tensorboard
            if self.config.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                log_value("valid_loss", losses.avg, iteration)
                log_value("valid_acc", accs.avg, iteration)

        return losses.avg, accs.avg

    @torch.no_grad()
    def test(self):
        """
        Test the RAM model.

        Loads the checkpoint selected by ``self.best``, logs the test
        accuracy and returns the concatenated predictions.
        """
        correct = 0
        preds = []

        # load the best checkpoint
        self.load_checkpoint(best=self.best)
        self.model.eval()

        for x, y in self.test_loader:
            _, _, predictions, _, _, _, _, _ = self.one_batch(x, y)

            # one_batch returns predictions on self.device while y is still on
            # the loader's device: compare on one device, accumulate an int.
            correct += (predictions == y.to(predictions.device)).sum().item()
            preds.append(predictions)
        perc = (100.0 * correct) / (self.num_test)
        error = 100 - perc

        logging.info("[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)".format(
            correct, self.num_test, perc, error))
        return torch.cat(preds)

    def save_checkpoint(self, state, is_best):
        """Saves a checkpoint of the model.

        If this model has reached the best validation accuracy thus
        far, a separate file with the suffix `best` is created.
        """
        filename = self.config.model_name + "_ckpt.pth.tar"
        ckpt_path = os.path.join(self.config.ckpt_dir, filename)
        torch.save(state, ckpt_path)
        if is_best:
            filename = self.config.model_name + "_model_best.pth.tar"
            shutil.copyfile(ckpt_path, os.path.join(self.config.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """Load the best copy of a model.
        Args:
            best: if set to True, loads the best model.
        """
        logging.info("[*] Loading model from {}".format(self.config.ckpt_dir))

        filename = self.config.model_name + "_ckpt.pth.tar"
        if best:
            filename = self.config.model_name + "_model_best.pth.tar"
        ckpt_path = os.path.join(self.config.ckpt_dir, filename)
        logging.info(os.path.abspath(ckpt_path))
        ckpt = torch.load(ckpt_path, map_location="cpu")

        # load variables from checkpoint
        self.start_epoch = ckpt["epoch"]
        self.best_valid_acc = ckpt["best_valid_acc"]
        self.model.load_state_dict(ckpt["model_state"])
        self.optimizer.load_state_dict(ckpt["optim_state"])

        if best:
            logging.info(
                "[*] Loaded {} checkpoint @ epoch {} "
                "with best valid acc of {:.3f}".format(
                    filename, ckpt["epoch"], ckpt["best_valid_acc"]
                )
            )
        else:
            logging.info("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))


In [None]:


def main(config):
    """Prepare directories, seed the RNGs, build the data loaders, then train or test.

    With ``config.is_train`` set, a ``(train_loader, valid_loader)`` pair is
    handed to ``Trainer`` and ``Trainer.train`` runs; otherwise a test loader
    is built and ``Trainer.test`` runs.
    """
    prepare_dirs(config)

    seed = config.random_seed
    torch.random.manual_seed(seed)
    if config.use_gpu:
        torch.cuda.manual_seed(seed)

    locator = DatasetLocator(config)

    if not config.is_train:
        # evaluation only: load a pretrained model and test it
        Trainer(config, locator.data_loader(DatasetType.TEST)).test()
        return

    loaders = (
        locator.data_loader(DatasetType.TRAIN),
        locator.data_loader(DatasetType.VALID),
    )
    Trainer(config, loaders).train()


In [None]:
# Console + file logging for the run. The file handler gets the same format as
# the console; basicConfig's formatter is not applied to handlers added later.
log_level = logging.INFO
log_format = '%(name)-12s %(levelname)-8s %(message)s'
log_datefmt = '%m-%d %H:%M'
logging.basicConfig(level=log_level, format=log_format, datefmt=log_datefmt)

# Re-running this notebook cell must not attach another handler for run.log,
# otherwise every message gets written to the file once per re-run.
root_logger = logging.getLogger()
log_path = os.path.abspath('run.log')
if not any(isinstance(h, logging.FileHandler) and h.baseFilename == log_path
           for h in root_logger.handlers):
    fh = logging.FileHandler('run.log')
    fh.setLevel(log_level)
    fh.setFormatter(logging.Formatter(log_format, datefmt=log_datefmt))
    root_logger.addHandler(fh)

config = Config()
main(config)