In [54]:
import gymnasium as gym
from gymnasium.spaces import MultiBinary, Box, Discrete

import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Bernoulli
import torchvision.datasets as torchdata
from torch.utils.data import DataLoader, RandomSampler
import torchvision.transforms as transforms

from stable_baselines3.common.policies import ActorCriticCnnPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.distributions import BernoulliDistribution
from stable_baselines3.common.utils import get_device
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

In [38]:
def action_space_model(dset):
    # Model the action space by dividing the image space into equal size patches
    if dset == 'C10' or dset == 'C100':
        img_size = 32
        patch_size = 8
    elif dset == 'fMoW':
        img_size = 224
        patch_size = 56
    elif dset == 'ImgNet':
        img_size = 224
        patch_size = 56

    mappings = []
    for cl in range(0, img_size, patch_size):
        for rw in range(0, img_size, patch_size):
            mappings.append([cl, rw])

    return mappings, img_size, patch_size

In [103]:
batch_size = 512
lr_size = 8
mappings, img_size, patch_size = action_space_model('C10')
penalty = -0.5
num_patches = 16
alpha=0.8

In [121]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, initial_kernel_size, num_classes):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=initial_kernel_size, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, res):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out
    

class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.actor = ResNet(BasicBlock, [1,1,1,1], 3, 16)
        self.critic = ResNet(BasicBlock, [1,1,1,1], 3, 1)

    def forward(self, x, res):
        actor_out = self.actor(x, res)
        critic_out = self.critic(x, res)
        return actor_out, critic_out


class CustomCNNFeaturesExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space, features_dim):
        super(CustomCNNFeaturesExtractor, self).__init__(observation_space, features_dim)
        self.cnn = ActorCritic()
        self._features_dim = features_dim

    def forward(self, observations, res='lr'):
        features = self.cnn(observations, res)[0]
        return features


class Policy(ActorCriticCnnPolicy):
    def __init__(self, observation_space, action_space, lr_schedule, *args, **kwargs):
        super(Policy, self).__init__(observation_space, action_space, lr_schedule,
                                     features_extractor_class=CustomCNNFeaturesExtractor,
                                     features_extractor_kwargs=dict(features_dim=512),
                                     *args, **kwargs)
        
        self.action_dist = BernoulliDistribution(self.action_space.n)

    def _build_mlp_extractor(self):
        return super()._build_mlp_extractor()

    def forward(self, obs):
        features = self.extract_features(obs)
        latent_pi, latent_vf = self.mlp_extractor(features)
        
        logits = self.action_net(latent_pi)
        values = self.value_net(latent_vf)
        
        distribution = self.action_dist.proba_distribution(logits)
        actions = distribution.get_actions()
        log_prob = distribution.log_prob(actions)
        return actions, values, log_prob

    def evaluate_actions(self, obs, actions):
        features = self.extract_features(obs)
        latent_pi, latent_vf = self.mlp_extractor(features)
        
        logits = self.action_net(latent_pi)
        values = self.value_net(latent_vf)
        
        distribution = self.action_dist.proba_distribution(logits)
        log_prob = distribution.log_prob(actions)
        entropy = distribution.entropy()
        return values, log_prob, entropy


rnet_hr = ResNet(BasicBlock, [3,4,6,3], 3, 10)

In [113]:
mean = [x/255.0 for x in [125.3, 123.0, 113.9]]
std = [x/255.0 for x in [63.0, 62.1, 66.7]]
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

root = '../data'

trainset = torchdata.CIFAR10(root=root, train=True, download=False, transform=transform_train)
testset = torchdata.CIFAR10(root=root, train=False, download=False, transform=transform_test)

trainloader = RandomSampler(trainset, replacement=True, num_samples=1)
testloader = RandomSampler(testset, replacement=True, num_samples=1)
# trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
# testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

In [106]:
r = next(iter(trainloader))
# trainset[r]
i = trainloader.data_source[r][0]
F.interpolate(i.unsqueeze(0), (lr_size, lr_size)).squeeze(0).shape

torch.Size([3, 8, 8])

In [107]:
dataloader = trainloader

In [123]:
class PatchDrop(gym.Env):
    metadata = {"render_modes": None}

    def __init__(self, render_mode=None, num_patches=16, img_size=8):
        super(PatchDrop, self).__init__()

        self.observation_space = Box(low=0, high=255, shape=(3, img_size, img_size), dtype=np.uint8)
        self.action_space = MultiBinary(num_patches)

    def reset(self, dataloader=trainloader, seed=None, options=None):
        self.curr_img_idx = next(iter(dataloader))
        self.hr_img, self.target = dataloader.data_source[self.curr_img_idx]
        # if len(self.hr_img.shape) == 3:
        #     self.hr_img = self.hr_img.unsqueeze(0)
        lr_img = F.interpolate(self.hr_img.unsqueeze(0), (lr_size, lr_size)).squeeze(0)
        # print("PASS")
        return lr_img, {}
    
    def step(self, policy_sample):
        inputs_sample = self.agent_chosen_input(self.hr_img, policy_sample, mappings, patch_size)
        preds_sample = rnet_hr.forward(inputs_sample, 'C10', 'hr')

        reward_sample, _ = self.compute_reward(preds_sample, self.target, policy_sample.data, penalty)

        done = True
        lr_img = self.reset()
        return lr_img, reward_sample, done, {}


    def agent_chosen_input(self, input_org, policy, mappings, patch_size):
        """ Generate masked images w.r.t policy learned by the agent."""
        input_full = input_org.clone()
        sampled_img = torch.zeros([input_org.shape[0], input_org.shape[1], input_org.shape[2], input_org.shape[3]])
        for pl_ind in range(policy.shape[1]):
            mask = (policy[:, pl_ind] == 1).cpu()
            sampled_img[:, :, mappings[pl_ind][0]:mappings[pl_ind][0]+patch_size, mappings[pl_ind][1]:mappings[pl_ind][1]+patch_size] = input_full[:, :, mappings[pl_ind][0]:mappings[pl_ind][0]+patch_size, mappings[pl_ind][1]:mappings[pl_ind][1]+patch_size]
            sampled_img[:, :, mappings[pl_ind][0]:mappings[pl_ind][0]+patch_size, mappings[pl_ind][1]:mappings[pl_ind][1]+patch_size] *= mask.unsqueeze(1).unsqueeze(1).unsqueeze(1).float()
        input_org = sampled_img

        return input_org.cuda()
    
    def compute_reward(self, preds, targets, policy, penalty):
        # Reward function favors policies that drops patches only if the classifier
        # successfully categorizes the image
        patch_use = policy.sum(1).float() / policy.size(1)
        sparse_reward = 1.0 - patch_use**2

        _, pred_idx = preds.max(1)
        match = (pred_idx==targets).data

        reward = sparse_reward
        reward[~match] = penalty
        reward = reward.unsqueeze(1)

        return reward, match.float()

In [124]:
env = make_vec_env(lambda: PatchDrop(), n_envs=batch_size)

model = PPO(Policy, env, verbose=1, device='cuda')
model.learn(total_timesteps=10000)

model.save('ppo_C10')



Using cuda device
Actor output shape: torch.Size([512, 16])
Critic output shape: torch.Size([512, 1])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x16 and 512x64)

In [None]:
model = PPO.load('ppo_C10')

# Evaluate agent
obs = env.reset()
done = False
while not done:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    print(reward)