# Import Everything

In [1]:
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy
import argparse
from cmath import log
import os
import random
import time
from distutils.util import strtobool
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from quantize_methods import get_eager_quantization


import gym
from algos.opt import Adan, hAdam
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
from torch.ao.quantization.fake_quantize import default_fused_wt_fake_quant , default_weight_fake_quant

  from .autonotebook import tqdm as notebook_tqdm
  tensorboard.__version__


# import Logging

In [2]:
logging.basicConfig(filename="tests.log", level=logging.NOTSET,
                    filemode='w',
                    format='%(asctime)s:%(levelname)s:%(filename)s:%(lineno)d:%(message)s')
try:
    from quantize_methods import size_of_model
except ModuleNotFoundError as e:
    logging.error(e)

# Arguments

In [3]:
def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    ## Jupyter Notebook Arguments
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="weather to capture videos of the agent performances (check out `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="CartPole-v1",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=500000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--buffer-size", type=int, default=10000,
        help="the replay memory buffer size")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--target-network-frequency", type=int, default=500,
        help="the timesteps it takes to update the target network")
    parser.add_argument("--batch-size", type=int, default=128,
        help="the batch size of sample from the reply memory")
    parser.add_argument("--start-e", type=float, default=1,
        help="the starting epsilon for exploration")
    parser.add_argument("--end-e", type=float, default=0.05,
        help="the ending epsilon for exploration")
    parser.add_argument("--exploration-fraction", type=float, default=0.5,
        help="the fraction of `total-timesteps` it takes from start-e to go end-e")
    parser.add_argument("--learning-starts", type=int, default=10000,
        help="timestep to start learning")
    parser.add_argument("--train-frequency", type=int, default=10,
        help="the frequency of training")
    
    # Quantization specific arguments
    ## Quantize Weight
    parser.add_argument("--quantize-weight", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True)
    parser.add_argument("--quantize-weight-bitwidth", type=int, default=8)
    parser.add_argument("--quantize-weight-quantize-min", type=int, default= 0)
    parser.add_argument("--quantize-weight-quantize-max", type=int, default= 255)
    parser.add_argument("--quantize-weight-dtype", type=str, default="quint8")
    parser.add_argument("--quantize-weight-qschme", type=str, default="per_tensor_symmetric")
    parser.add_argument("--quantize-weight-reduce-range", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True)
    ## Quantize Activation
    parser.add_argument("--quantize-activation" , type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True)
    parser.add_argument("--quantize-activation-bitwidth", type=int, default=8)
    parser.add_argument("--quantize-activation-quantize-min", type=int, default= 0)
    parser.add_argument("--quantize-activation-quantize-max", type=int, default= 255)
    parser.add_argument("--quantize-activation-qscheme", type=str, default="per_tensor_symmetric")
    parser.add_argument("--quantize-activation-quantize-dtype", type=str, default="quint8")
    parser.add_argument("--quantize-activation-reduce-range", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True)
    ## Other papers algorithm and ideas
    parser.add_argument("--optimizer" , type=str, default="Adam")
    args = parser.parse_args()
    # fmt: on
    return args

# Make Env

In [4]:
def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk

# Q Network

In [5]:
class QNetwork(nn.Module):
    def __init__(self, 
                 env ,
                 ):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, env.single_action_space.n),
        )
        self.quantize_modules = torch.ao.quantization.QuantStub()
        self.dequantize_modules = torch.ao.quantization.DeQuantStub()
        logging.info(f"The model is {self.network} GB")  
        logging.info(f"The size of the model is {size_of_model(self.network)}")
    def forward(self, x , quantize = False):
        return  self.quantize_modules(self.quantize_modules(self.network(x))) if quantize else self.network(x)
    ## Fuse the model
    def fuse_model(self):
        layers = list()
        for index in range( 0, len(self.network) - 2 , 2):
            layers.append([str(index) , str(index + 1)])
        logging.info(f"Layers to fuse {layers}")
        print(f"Layers to fuse {layers}")
        torch.ao.quantization.fuse_modules(self.network, layers, inplace=True)
    def get_quantization_config(self):
        if self.quantize_weight:
            if self.quantize_activation:
                return torch.ao.quantization.QConfig(
                    activation = torch.ao.quantization.FakeQuantize.with_args(
                        observer = torch.ao.quantization.MovingAverageMinMaxObserver(
                            dtype = self.quantize_activation_quantize_dtype,
                            reduce_range = self.quantize_activation_quantize_reduce_range,
                            quant_min = self.quantize_activation_quantize_min,
                            quant_max = self.quantize_activation_quantize_max,
                        )
                    ),
                    weight = torch.ao.quantization.FakeQuantize.with_args(
                        observer = torch.ao.quantization.MovingAverageMinMaxObserver(
                            dtype = torch.quint8,
                            quant_min = -128,
                            quant_max = 127,
                        )
                    )
                )
            else:
                return torch.ao.quantization.QConfig(
                    activation = torch.nn.Identity,
                    weight = default_weight_fake_quant,
                )

# Linear Schedule

In [6]:
def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)

# Parse Arguments

In [36]:
args = dict()
args["env"] = "CartPole-v0"
args["quantize_activation_quantize_dtype"] = torch.quint8
args["seed"] = 0
args["torch_deterministic"] = True
args["env_id"] = "CartPole-v0"
args["capture_video"] = False
## Quantization
args["quantize_weight"] = True
args["quantize_weight_bitwidth"] = 8
args["quantize_observation_type"] = "moving_average_min_max"
args["quantize_weight_quantize_min"] = 0
args["quantize_weight_quantize_max"] = 255
args["quantize_weight_dtype"] = "quint8"
args["quantize_weight_qscheme"] = "per_tensor_symmetric"
args["quantize_weight_reduce_range"] = False
## Quantize Activation
args["quantize_activation"] = True
args["quantize-activation-bitwidth"] = 8
args["quantize_observation_type"] = "moving_average_min_max"
args["quantize_activation_quantize_min"] = 0
args["quantize_activation_quantize_max"] = 255
args["quantize_activation_dtype"] = "quint8"
args["quantize_activation_qscheme"] = "per_tensor_symmetric"
args["quantize_activation_reduce_range"] = False

## Convert the dictionary to a namespace
args = argparse.Namespace(**args)
assert args.quantize_activation_dtype == "quint8", f"The activation dtype must be quint8 and {args.quantize_activation_dtype}"

# Modified the argument where the string pytorch quantization datatype will be PyTorch Quantization Datatype

In [37]:
if args.quantize_activation_quantize_dtype is not None and isinstance(args.quantize_activation_quantize_dtype, str):
    if args.quantize_activation_quantize_dtype == "quint8":
        args.quantize_activation_quantize_dtype = torch.quint8
    elif args.quantize_activation_quantize_dtype == "qint8":
        args.quantize_activation_quantize_dtype = torch.qint8
    else:
        print(args.quantize_activation_quantize_dtype)
        raise ValueError(f"{args.quantize_activation_quantize_dtype} is not supported for quantization")
if args.quantize_weight_dtype is not None and isinstance(args.quantize_weight_dtype, str):
    if args.quantize_weight_dtype == "quint8":
        args.quantize_weight_dtype = torch.quint8
    elif args.quantize_weight_dtype == "qint8":
        args.quantize_weight_dtype = torch.qint8
    else:
        raise ValueError(f"{args.quantize_weight_dtype} is not supported for quantization")
assert isinstance(args.quantize_activation_quantize_dtype, torch.dtype), f"The activation dtype must be torch.dtype and {type(args.quantize_activation_quantize_dtype)}"
print(args.quantize_activation_dtype)
print(args.quantize_weight_dtype)

# Modified the argument where the string pytorch quantization Quantization Scheme 

In [38]:
if args.quantize_activation_qscheme is not None and isinstance( args.quantize_activation_qscheme , str):
    if args.quantize_activation_qscheme == "per_tensor_symmetric":
        args.quantize_activation_qscheme = torch.per_tensor_symmetric
    elif args.quantize_activation_qscheme == "per_tensor_affine":
        args.quantize_activation_qscheme = torch.per_tensor_affine
    else:
        raise ValueError(f"{args.quantize_activation_qscheme} is not supported for quantization")
if args.quantize_weight_qscheme is not None and isinstance(args.quantize_weight_qscheme, str):
    if args.quantize_weight_qscheme == "per_tensor_symmetric":
        args.quantize_weight_qscheme = torch.per_tensor_symmetric
    elif args.quantize_weight_qscheme == "per_tensor_affine":
        args.quantize_weight_qscheme = torch.per_tensor_affine
    else:
        raise ValueError(f"{args.quantize_weight_qscheme} is not supported for quantization")

# Set the Seed

In [39]:
# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic

# Set the Devices

In [40]:
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
logging.info(f"The device the DQN is running on: {device}")

# Env

In [41]:
run_name = f"{args.env}_{args.seed}_{int(time.time())}"
# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

  f"The environment {path} is out of date. You should consider "
  "Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "


# Build the Neural Network 

In [42]:
from rich import print
q_network = QNetwork(
                    env = envs,
                        )
print(q_network)
from rich.pretty import pprint
pprint(q_network)

# Applied Eager Quantization to the method fuse the modules

1. To quantize the model for inference in using eager quantization, we need to fuse the model in where the model is eval mode
2. Called the fuse method
3. Set the Quantization Configuration for the model
4. Called the prepare QAT mode 

In [43]:
q_network.eval()
q_network.fuse_model()
pprint(q_network)

# Appled the Q Config to the model

## Example of Q Config from fbgemm

In [44]:
## fbgemm
from torch.ao.quantization import get_default_qat_qconfig

example_qconfig = get_default_qat_qconfig("fbgemm")
print(example_qconfig)
pprint(example_qconfig)

## Create Eager Quantization Cofiguration

In [45]:
import torch
import os
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.qconfig import QConfig
import torch
def get_eager_quantization(
    weight_quantize:bool  = True,
    weight_observer_type:str = "moving_average_minmax",
    weight_quantization_min:int = 0,
    weight_quantization_max:int = 255,
    weight_quantization_dtype:torch.dtype = torch.quint8,
    weight_quantization_qscheme:torch.qscheme = torch.per_tensor_symmetric,
    weight_reduce_range = True,
    activation_quantize:bool = True,
    activation_observer_type:str = "moving_average_minmax",
    activation_quantization_min:int = -128,
    activation_quantization_max:int = 127,
    activation_quantization_dtype:torch.dtype = torch.quint8,
    activation_quantization_qscheme:torch.qscheme = torch.per_tensor_symmetric,
    *args, **kwargs
):
    assert isinstance( weight_quantization_dtype , torch.dtype)
    assert isinstance( activation_quantization_dtype , torch.dtype)
    assert isinstance( weight_quantization_qscheme , torch.qscheme)
    assert isinstance( activation_quantization_qscheme , torch.qscheme)
    ## all quantization  in eager mode are unifrom quantization 
    weight_quantization_fake_quantize = None
    if weight_quantize:
        weight_quantization_fake_quantize = FakeQuantize.with_args(
                    observer =  MinMaxObserver.with_args(
                        dtype = weight_quantization_dtype,
                        qscheme = weight_quantization_qscheme,
                        reduce_range = False , 
                        quant_min= weight_quantization_min,
                        quant_max = weight_quantization_max,
                    ))
    activation_quantization_fake_quantize = None
    if activation_quantize:
            activation_quantization_fake_quantize = FakeQuantize.with_args(
                    observer =   MinMaxObserver.with_args( 
                        quant_min = activation_quantization_min,
                        quant_max = activation_quantization_max,
                        dtype = activation_quantization_dtype,
                        qscheme = activation_quantization_qscheme,
                        reduce_range = False
                    ))
    return QConfig(
        weight = weight_quantization_fake_quantize,
        activation = activation_quantization_fake_quantize
    )

### Print out the Default Q Config

In [46]:
q_config = get_eager_quantization()
pprint(q_config)
print(q_config)

## Set the Q Configuration of the Q Netork 

In [47]:
q_network.train()
q_network.qconfig = q_config

## Prepare the model for QAT

In [48]:
torch.ao.quantization.prepare_qat(q_network, inplace=True)

QNetwork(
  (network): Sequential(
    (0): LinearReLU(
      in_features=4, out_features=120, bias=True
      (weight_fake_quant): FakeQuantize(
        fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
        (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
      )
      (activation_post_process): FakeQuantize(
        fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.quint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
        (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
    (1): Identity()
    (2): LinearReLU(
      in_features=120, out_features=84, bias=True
     

In [49]:
print(q_network)
pprint(q_network)

# Applied Use Spefic Q Config Pattern

In [50]:
'''
1. Create another Q Network
2. Fuse the modules
3. Set the Quantization Config
4. called prepare for QAT
'''
q_network = QNetwork(
    env=envs)
## set the model to eval mode
q_network.eval()
## fuse the modules
q_network.fuse_model()
pprint(q_network)
## set the model to trian 
q_network.train()
## set the user spefici quantization config
q_network.qconfig = get_eager_quantization(
            weight_quantize = args.quantize_weight,
            weight_observer_type = "moving_average_min_max",
            weight_quantization_min =  args.quantize_weight_quantize_min , 
            weight_quantization_max = args.quantize_weight_quantize_max,
            weight_quantization_dtype = args.quantize_weight_dtype,
            weight_reduce_range= args.quantize_weight_reduce_range,
            activation_quantize= args.quantize_activation,
            activation_quantization_min = args.quantize_activation_quantize_min,
            activation_quantization_max = args.quantize_activation_quantize_max,
            activation_quantization_dtype = args.quantize_activation_quantize_dtype,
            activation_quantization_qscheme = args.quantize_activation_qscheme,
            activation_reduce_range = args.quantize_activation_reduce_range,
)
torch.ao.quantization.prepare_qat(q_network, inplace=True)

QNetwork(
  (network): Sequential(
    (0): LinearReLU(
      in_features=4, out_features=120, bias=True
      (weight_fake_quant): FakeQuantize(
        fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
        (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
      )
      (activation_post_process): FakeQuantize(
        fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
        (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
    (1): Identity()
    (2): LinearReLU(
      in_features=120, out_features=84, bias=True
      (w