# Import PyTorch 

In [2]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import logging
from torch import nn
class QNetwork(nn.Module):
    """Nature-DQN style convolutional Q-network.

    Maps a batch of stacked frames to one Q-value per discrete action.
    The first ``Conv2d`` expects 4 input channels, and the ``3136`` input size
    of the first ``Linear`` equals 64 * 7 * 7, the flattened conv output for
    84x84 frames. Both match ``make_env``'s preprocessing (84x84 grayscale,
    4-frame stack); other input sizes will fail at the ``Linear`` layer.

    Args:
        env: a vectorized environment; only ``env.single_action_space.n`` is
            read, to size the output layer.
        quantize: stored as ``self.quantize`` so callers can inspect it. The
            layer stack itself is the same either way; quantization is applied
            afterwards to the built module (e.g. via FX graph-mode APIs).
    """

    def __init__(self, env, quantize: bool = False):
        super().__init__()
        # Previously accepted but silently dropped; keep it on the instance.
        self.quantize = quantize
        self.network = nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),  # 3136 = 64 channels * 7 * 7 for 84x84 input
            nn.ReLU(),
            nn.Linear(512, env.single_action_space.n),
        )
        # Lazy %-formatting: the repr is only built if INFO logging is enabled.
        logging.info("QNetwork: %s", self.network)

    def forward(self, x):
        """Return Q-values, shape (batch, n_actions), for observations ``x``.

        ``x`` is divided by 255 first; this presumes raw pixel values in
        [0, 255] (TODO confirm the observation dtype/range upstream).
        """
        return self.network(x / 255.0)

In [4]:
import os
import argparse
import gym
from distutils.util import strtobool

def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="weather to capture videos of the agent performances (check out `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=10000000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=1e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--buffer-size", type=int, default=1000000,
        help="the replay memory buffer size")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--target-network-frequency", type=int, default=1000,
        help="the timesteps it takes to update the target network")
    parser.add_argument("--batch-size", type=int, default=32,
        help="the batch size of sample from the reply memory")
    parser.add_argument("--start-e", type=float, default=1,
        help="the starting epsilon for exploration")
    parser.add_argument("--end-e", type=float, default=0.01,
        help="the ending epsilon for exploration")
    parser.add_argument("--exploration-fraction", type=float, default=0.10,
        help="the fraction of `total-timesteps` it takes from start-e to go end-e")
    parser.add_argument("--learning-starts", type=int, default=80000,
        help="timestep to start learning")
    parser.add_argument("--train-frequency", type=int, default=4,
        help="the frequency of training")
    
    # Quantization
    parser.add_argument("--quantize", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True)
    args = parser.parse_args()
    # fmt: on
    return args

In [5]:
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one preprocessed Atari env.

    Calling the factory wraps ``gym.make(env_id)`` in this order:
    episode statistics -> video recording (only if ``capture_video`` and this
    is sub-env ``idx == 0``, saved under ``videos/<run_name>``) -> no-op reset
    (max 30) -> max-and-skip (4) -> episodic life -> fire-on-reset (only if the
    game exposes a "FIRE" action) -> reward clipping -> resize to 84x84 ->
    grayscale -> 4-frame stack. The env, its action space and its observation
    space are then all seeded with ``seed``.
    """

    def _build():
        wrapped = gym.make(env_id)
        wrapped = gym.wrappers.RecordEpisodeStatistics(wrapped)
        if capture_video and idx == 0:
            wrapped = gym.wrappers.RecordVideo(wrapped, f"videos/{run_name}")
        wrapped = NoopResetEnv(wrapped, noop_max=30)
        wrapped = MaxAndSkipEnv(wrapped, skip=4)
        wrapped = EpisodicLifeEnv(wrapped)
        action_meanings = wrapped.unwrapped.get_action_meanings()
        if "FIRE" in action_meanings:
            wrapped = FireResetEnv(wrapped)
        wrapped = ClipRewardEnv(wrapped)
        wrapped = gym.wrappers.ResizeObservation(wrapped, (84, 84))
        wrapped = gym.wrappers.GrayScaleObservation(wrapped)
        wrapped = gym.wrappers.FrameStack(wrapped, 4)

        wrapped.seed(seed)
        for space in (wrapped.action_space, wrapped.observation_space):
            space.seed(seed)
        return wrapped

    return _build

  tensorboard.__version__


In [8]:
# env setup
import gym
#args = parse_args()

In [9]:
# One Breakout sub-environment (seed 42, index 0, no video) used to build and inspect the network.
envs = gym.vector.SyncVectorEnv(
    [
        make_env(
            env_id="BreakoutNoFrameskip-v4",
            seed=42,
            idx=0,
            capture_video=False,
            run_name="run_name",
        )
    ]
)
assert isinstance(
    envs.single_action_space, gym.spaces.Discrete
), "only discrete action space is supported"

q_network = QNetwork(envs, quantize=False)

A.L.E: Arcade Learning Environment (version 0.8.0+919230b)
[Powered by Stella]


In [10]:
# NOTE: this rebinds the name `print` to rich's version for every cell executed after this one.
from rich import print
print(q_network)

## Apply PyTorch quantization-aware training (QAT)

In [11]:
# Despite its name, `observation_space` holds the *shape* tuple of one observation, not a gym Space.
observation_space = envs.single_observation_space.shape
print(observation_space)

In [12]:
import copy

from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx
from torch.ao.quantization import get_default_qat_qconfig_mapping

# Default QAT qconfig mapping for the "fbgemm" backend.
qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")

# `example_inputs` must be a tuple of sample inputs to `forward`, not a shape tuple:
# build one zero batch (batch size 1) shaped like a single observation.
qat_example_inputs = (torch.zeros(1, *observation_space),)

# prepare_qat_fx is the FX entry point for quantization-aware training
# (prepare_fx targets post-training quantization); its documented usage puts the
# model in train mode first. Prepare a deep copy so the float `q_network` shown
# above is not mutated. The variable name is kept as-is because later cells use it.
prepare_qat_q_netork = prepare_qat_fx(
    copy.deepcopy(q_network).train(),
    qconfig_mapping=qconfig_mapping,
    example_inputs=qat_example_inputs,
)

  reduce_range will be deprecated in a future release of PyTorch."


In [19]:
# Show the QAT qconfig mapping, then its plain-dict form.
print( qconfig_mapping )
# NOTE(review): `example` holds the dict form of the mapping, not example inputs.
example =  qconfig_mapping.to_dict()
print(example)

In [33]:
# Inspect the default QAT qconfig for the "fbgemm" backend.
from rich.pretty import pprint
pprint(torch.ao.quantization.get_default_qat_qconfig('fbgemm'))

In [31]:
# `pprint` is imported again because this cell (In [31]) was executed before the one above (In [33]).
from rich.pretty import pprint
# Pretty-print the prepared GraphModule.
pprint( prepare_qat_q_netork ) 

In [29]:
# GraphModule.print_readable() prints its generated code by default *and* returns it, so
# wrapping the default call in print() emitted the code twice. Print it exactly once.
print(prepare_qat_q_netork.print_readable(print_output=False))

class GraphModule(torch.nn.Module):
    def forward(self, x):
        
        # File: /tmp/ipykernel_142013/1800980637.py:23, code: return self.network(x / 255.0)
        truediv = x / 255.0;  x = None
        
        # No stacktrace found for following nodes 
        activation_post_process_0 = self.activation_post_process_0(truediv);  truediv = None
        
        # File: /home/null/miniconda3/envs/cleanrl/lib/python3.7/site-packages/torch/ao/quantization/fx/tracer.py:103, code: return super().call_module(m, forward, args, kwargs)
        network_0 = getattr(self.network, "0")(activation_post_process_0);  activation_post_process_0 = None
        
        # No stacktrace found for following nodes 
        activation_post_process_1 = self.activation_post_process_1(network_0);  network_0 = None
        
        # File: /home/null/miniconda3/envs/cleanrl/lib/python3.7/site-packages/torch/ao/quantization/fx/tracer.py:103, code: return super().call_module(m, forward, args, kwargs)
    