This notebook implements Implicit Q-Learning (IQL), an offline reinforcement-learning algorithm, to solve the D4RL Maze2D task (`maze2d-umaze`).

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from typing import Tuple, Optional

import numpy as np
from loguru import logger
import matplotlib.pyplot as plt
from IPython.display import Image
from src.utils import (
    get_device,
    set_seed,
    eval_policy,
    demo_policy,
    plot_returns,
    save_frames_as_gif,
    return_range
)
from tqdm import tqdm
import einops
import os
import copy

from src.d4rl_dataset import D4RLSampler

# Turn on matplotlib interactive mode so figures are drawn without blocking the cell.
plt.ion()

  from pkg_resources import resource_stream, resource_exists


<contextlib.ExitStack at 0x155551008f70>

We start by importing D4RL, a library that provides offline training datasets collected by running policies of varying quality on a few environments.

In [2]:
import gym
import d4rl

  from distutils.dep_util import newer, newer_group
No module named 'flow'
/home/hice1/ibaali3/.conda/envs/cs8803drl_hw2/lib/python3.10/site-packages/glfw/__init__.py:914: GLFWError: (65544) b'X11: The DISPLAY environment variable is missing'
No module named 'carla'
pybullet build time: Nov 28 2023 23:45:17


In [3]:
SEED: int = 42
ENVIRONMENT_NAME: str = 'maze2d-umaze'

# Torch defaults: choose the compute device once and keep float32 as the default dtype.
DEVICE = get_device()
torch.set_default_dtype(torch.float32)

# Seed the random number generators up front so runs are reproducible.
set_seed(SEED)

# Build the environment and read the action/state sizes off its spaces.
env = gym.make(ENVIRONMENT_NAME)
action_dimension = env.action_space.shape[0]
state_dimension = env.observation_space.shape[0]

for message in (
    f'Action Dimension: {action_dimension}',
    f'Action High: {env.action_space.high}',
    f'Action Low: {env.action_space.low}',
    f'State Dimension: {state_dimension}',
):
    logger.info(message)

[32m2024-12-10 16:09:40.448[0m | [1mINFO    [0m | [36msrc.utils[0m:[36mget_device[0m:[36m52[0m - [1mUsing cuda device.[0m
[32m2024-12-10 16:09:40.463[0m | [1mINFO    [0m | [36msrc.utils[0m:[36mset_seed[0m:[36m38[0m - [1mRandom seed set as 42.[0m
  logger.warn(
[32m2024-12-10 16:09:40.668[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m18[0m - [1mAction Dimension: 2[0m
[32m2024-12-10 16:09:40.669[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m19[0m - [1mAction High: [1. 1.][0m
[32m2024-12-10 16:09:40.670[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m20[0m - [1mAction Low: [-1. -1.][0m
[32m2024-12-10 16:09:40.670[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m21[0m - [1mState Dimension: 4[0m


Next, we download the training data using the D4RL library.

In [4]:
# Offline transitions for the current environment, returned as a dict of arrays.
dataset = d4rl.qlearning_dataset(env)

for message in (
    f'Dataset type: {type(dataset)}',
    f'Dataset keys: {dataset.keys()}',
    f'# Samples: {len(dataset["observations"])}',
):
    logger.info(message)

# Mini-batch sampler over the dataset with batch size 256 (device handling is up to D4RLSampler — TODO confirm).
sampler = D4RLSampler(dataset, 256, DEVICE)

load datafile: 100%|██████████| 8/8 [00:00<00:00,  8.19it/s]
[32m2024-12-10 16:09:45.715[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mDataset type: <class 'dict'>[0m
[32m2024-12-10 16:09:45.716[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mDataset keys: dict_keys(['observations', 'actions', 'next_observations', 'rewards', 'terminals'])[0m
[32m2024-12-10 16:09:45.717[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1m# Samples: 987540[0m


Next, we define the Q network, which estimates state-action values. As in the original IQL paper, it has two independent Q heads (clipped double Q-learning), and the minimum of their two estimates is used downstream.

In [5]:
class QNetwork(nn.Module):
    """Twin state-action value network used for clipped double Q-learning.

    Holds two independently parameterised MLPs, ``Q1`` and ``Q2``. Each maps the
    concatenation of a state and an action to a single scalar value.
    """

    def __init__(self, state_dimension, action_dimension, hidden_dim, n_hidden):
        """
        Args:
            state_dimension: Size of the state vector.
            action_dimension: Size of the action vector.
            hidden_dim: Width of every hidden layer.
            n_hidden: Number of additional hidden->hidden layers after the input layer.
        """
        super(QNetwork, self).__init__()
        in_dim = state_dimension + action_dimension
        self.Q1 = self._mlp(in_dim, hidden_dim, n_hidden)
        self.Q2 = self._mlp(in_dim, hidden_dim, n_hidden)

    @staticmethod
    def _mlp(in_dim, hidden_dim, n_hidden):
        """Build a Linear/ReLU MLP with unshared hidden layers and a scalar output."""
        layers = [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
        for _ in range(n_hidden):
            # Create a new Linear on every iteration: `[layer, act] * n` would
            # repeat the SAME module n times, silently tying all hidden weights.
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, 1))
        return nn.Sequential(*layers)

    def forward(self, state, action):
        """Return ``(Q1(s, a), Q2(s, a))``; inputs are concatenated on the last axis."""
        xu = torch.cat([state, action], dim=-1)
        return self.Q1(xu), self.Q2(xu)

In [6]:
class VNetwork(nn.Module):
    """State-value network V(s): a Linear/ReLU MLP with a single scalar output."""

    def __init__(self, state_dimension, hidden_dim, n_hidden):
        """
        Args:
            state_dimension: Size of the state vector.
            hidden_dim: Width of every hidden layer.
            n_hidden: Number of additional hidden->hidden layers after the input layer.
        """
        super(VNetwork, self).__init__()
        layers = [nn.Linear(state_dimension, hidden_dim), nn.ReLU()]
        for _ in range(n_hidden):
            # Fresh Linear per depth; list multiplication would alias one module
            # and tie every hidden layer to the same weights.
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, 1))
        self.V = nn.Sequential(*layers)

    def forward(self, state):
        """Return V(state)."""
        return self.V(state)

In [7]:
from torch.distributions.normal import Normal

DEVICE = get_device()

# Default MLP size used by GaussianPolicy when no size is given.
HIDDEN_DIMENSION: int = 256
N_HIDDEN: int = 3


def tensor(x: np.array, type=torch.float32, device=DEVICE) -> torch.Tensor:
    """Convert ``x`` into a tensor of dtype ``type`` on ``device``.

    Built on ``torch.as_tensor``, so inputs that already match the dtype/device
    are reused rather than copied. The ``device`` default is captured from
    ``DEVICE`` at the moment this function is defined.
    """
    converted = torch.as_tensor(x, dtype=type, device=device)
    return converted


def network(
        in_dimension: int, 
        out_dimension: int, 
        hidden_dimension: int = 256, 
        n_hidden: int = 3) -> nn.Module:
    """
    Build a fully connected MLP with Mish activations between layers.

    Args:
        in_dimension (int): Dimension of the input layer.
        out_dimension (int): Dimension of the output layer.
        hidden_dimension (int): Width of each hidden layer.
        n_hidden (int): Number of hidden layers (0 gives a single Linear map).

    Returns:
        nn.Module: An ``nn.Sequential`` of Linear/Mish pairs ending in a bare Linear.
    """
    dims = [in_dimension, *([hidden_dimension] * n_hidden), out_dimension]
    edges = list(zip(dims[:-1], dims[1:]))

    modules = []
    # Every edge except the last is followed by a non-linearity.
    for fan_in, fan_out in edges[:-1]:
        modules += [nn.Linear(fan_in, fan_out), nn.Mish()]
    final_in, final_out = edges[-1]
    modules.append(nn.Linear(final_in, final_out))
    return nn.Sequential(*modules)


class GaussianPolicy(nn.Module):
    """Diagonal Gaussian policy pi(a|s) parameterised by an MLP.

    The network maps a state to ``2 * action_dimension`` outputs, which are split
    into the per-dimension mean and log standard deviation of a ``Normal``.
    """

    def __init__(
            self,
            state_dimension: int,
            action_dimension: int,
            hidden_dimension: int = HIDDEN_DIMENSION,
            n_hidden: int = N_HIDDEN,
    ):
        super(GaussianPolicy, self).__init__()
        self.network = network(
            state_dimension, 2 * action_dimension, hidden_dimension, n_hidden
        )
        self.action_dimension = action_dimension

    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the policy network.

        Args:
            state (torch.Tensor): Input state(s); the feature axis is the last one.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(mean, log_std)``, each with last
            dimension ``action_dimension``. ``log_std`` is clamped to ``[-10, 2]``.
        """
        out = self.network(state)
        mean, log_std = torch.split(out, self.action_dimension, dim=-1)
        # Keep sigma within [e^-10, e^2] so the distribution never collapses or explodes.
        log_std = torch.clamp(log_std, -10, 2)
        return mean, log_std

    def pi(self, state: torch.Tensor) -> Normal:
        """
        Computes the action distribution pi(a|s) for a given state.

        Args:
            state (torch.Tensor): The input state(s).

        Returns:
            Normal: Independent Gaussian per action dimension.
        """
        mean, log_std = self(state)
        return Normal(mean, log_std.exp())

    def action(self, state: np.ndarray, eval=False) -> np.ndarray:
        """
        Selects an action for environment interaction (no log-probability returned).

        Args:
            state (np.ndarray): The input state.
            eval (bool): If True, return the distribution mean (deterministic);
                otherwise draw a sample.

        Returns:
            np.ndarray: The selected action.
        """
        # Inference only: without no_grad, `policy.mean` carries autograd history
        # and `.numpy()` raises "Can't call numpy() on Tensor that requires grad".
        with torch.no_grad():
            # Place the input wherever this module's weights live, not a global device.
            device = next(self.parameters()).device
            policy = self.pi(tensor(state, device=device))
            chosen = policy.mean if eval else policy.sample()
        return chosen.cpu().numpy()

[32m2024-12-10 16:09:46.007[0m | [1mINFO    [0m | [36msrc.utils[0m:[36mget_device[0m:[36m52[0m - [1mUsing cuda device.[0m


In [8]:
def expectile_loss(diff, expectile=0.8):
    """Element-wise asymmetric squared loss used for expectile regression.

    Residuals with ``diff >= 0`` are weighted by ``expectile``; residuals with
    ``diff < 0`` are weighted by ``|expectile - 1|``. No reduction is applied.
    """
    below = (diff < 0).int()
    weight = (expectile - below).abs()
    return weight * diff.pow(2)

In [None]:
EPOCHS = 150
EVAL_FREQ = 15
LOAD_FROM_CKPT = True
CHECKPOINT_PATH = 'iql_checkpoint.pth'

# These parameters should work fine but you may tune them if you want to
hidden_dim: int = 256
n_hidden: int = 2
lr: float = 3e-4
discount = 0.99
alpha = 0.005            # Polyak averaging rate for the target Q network
exp_advantage_max = 100  # upper clamp on exp(beta * advantage) in the policy weights

tau = 0.7                # expectile used for the V(s) regression
beta = 3                 # inverse temperature of the advantage weights

# Return range of the dataset, used below to rescale rewards.
min_rew, max_rew = return_range(dataset, 1000)

#############################################################################################

sampler = D4RLSampler(dataset, 256, DEVICE)

iql_policy = GaussianPolicy(state_dimension, action_dimension, hidden_dim, n_hidden).to(DEVICE)
policy_optimizer = Adam(iql_policy.parameters(), lr)
# One cosine period spanning every gradient step of the full run.
policy_lr_schedule = CosineAnnealingLR(policy_optimizer, EPOCHS * len(sampler))

v_critic = VNetwork(state_dimension, hidden_dim, n_hidden).to(DEVICE)
v_optimizer = Adam(v_critic.parameters(), lr)

q_critic = QNetwork(state_dimension, action_dimension, hidden_dim, n_hidden).to(DEVICE)
q_critic_target = copy.deepcopy(q_critic)
q_critic_target.requires_grad_(False)  # only updated by Polyak averaging below
q_optimizer = Adam(q_critic.parameters(), lr)

means, stds, start_epoch = [], [], 0
if os.path.exists(CHECKPOINT_PATH) and LOAD_FROM_CKPT:
    # Local file written by this notebook; torch.load unpickles it, so only load trusted files.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

    iql_policy.load_state_dict(checkpoint['iql_policy'])
    policy_optimizer.load_state_dict(checkpoint['policy_optimizer'])
    v_critic.load_state_dict(checkpoint['v_critic'])
    v_optimizer.load_state_dict(checkpoint['v_optimizer'])
    q_critic.load_state_dict(checkpoint['q_critic'])
    q_critic_target.load_state_dict(checkpoint['q_critic_target'])
    q_optimizer.load_state_dict(checkpoint['q_optimizer'])
    # Restore the cosine schedule position; otherwise the LR would restart on resume.
    if 'policy_lr_schedule' in checkpoint:
        policy_lr_schedule.load_state_dict(checkpoint['policy_lr_schedule'])
    else:
        logger.warning('Checkpoint has no LR scheduler state; the cosine schedule restarts.')

    start_epoch = checkpoint['epoch']
    means = checkpoint['means']
    stds = checkpoint['stds']

    print(f'Resuming run from epoch {start_epoch}')

for epoch in range(start_epoch, EPOCHS):
    total_q_loss = total_v_loss = total_policy_loss = count = 0
    for batch in tqdm(sampler):
        state = batch['state'].to(DEVICE)
        next_state = batch['next_state'].to(DEVICE)
        action = batch['action'].to(DEVICE)
        reward = einops.rearrange(batch['reward'], 'b -> b 1').to(DEVICE)
        reward = reward / (max_rew - min_rew) * 1000
        not_done = einops.rearrange(batch['not_done'], 'b -> b 1').to(DEVICE)

        # V step: expectile regression of V(s) toward min(Q1, Q2)(s, a) of the target critic.
        with torch.no_grad():
            target_q = torch.minimum(*q_critic_target(state, action))
        v_loss = expectile_loss(target_q - v_critic(state), expectile=tau).mean()
        v_optimizer.zero_grad()
        v_loss.backward()
        v_optimizer.step()

        # Q step: the TD target is a constant; do not backprop into the value network.
        with torch.no_grad():
            target = reward + not_done * discount * v_critic(next_state)
        q1, q2 = q_critic(state, action)
        q_loss = F.mse_loss(q1, target) + F.mse_loss(q2, target)
        q_optimizer.zero_grad()
        q_loss.backward()
        q_optimizer.step()

        # Polyak update: target <- (1 - alpha) * target + alpha * online, in place.
        with torch.no_grad():
            for var, var_target in zip(q_critic.parameters(), q_critic_target.parameters()):
                var_target.lerp_(var, alpha)

        # Policy step: advantage-weighted Gaussian NLL; the weights are constants.
        with torch.no_grad():
            advantage = torch.minimum(*q_critic_target(state, action)) - v_critic(state)
            weight = torch.clamp(torch.exp(beta * advantage), max=exp_advantage_max)
        mean, log_std = iql_policy(state)
        policy_loss = (weight * F.gaussian_nll_loss(mean, action, log_std.exp() ** 2, reduction='none')).mean()
        policy_optimizer.zero_grad()
        policy_loss.backward()
        policy_optimizer.step()

        policy_lr_schedule.step()
        total_v_loss += v_loss.item()
        total_q_loss += q_loss.item()
        total_policy_loss += policy_loss.item()
        count += 1

    if (epoch + 1) % EVAL_FREQ == 0:
        rew_mean, rew_std = eval_policy(iql_policy, environment_name=ENVIRONMENT_NAME, eval_episodes=50)
        n_batches = max(count, 1)  # guard against an empty sampler
        print(f'Epoch: {epoch + 1}. Q Loss: {total_q_loss / n_batches:.4f}. V Loss: {total_v_loss / n_batches:.4f}. P Loss: {total_policy_loss / n_batches:.4f}. Reward: {rew_mean:.4f} +/- {rew_std:.4f}')
        means.append(rew_mean)
        stds.append(rew_std)

    # Save a checkpoint so that you can resume training if it crashes
    checkpoint = {
        'iql_policy': iql_policy.state_dict(),
        'policy_optimizer': policy_optimizer.state_dict(),
        'policy_lr_schedule': policy_lr_schedule.state_dict(),
        'v_critic': v_critic.state_dict(),
        'v_optimizer': v_optimizer.state_dict(),
        'q_critic': q_critic.state_dict(),
        'q_critic_target': q_critic_target.state_dict(),
        'q_optimizer': q_optimizer.state_dict(),
        'epoch': epoch + 1,
        'means': means,
        'stds': stds
    }
    torch.save(checkpoint, CHECKPOINT_PATH)

# One x-value per evaluation: EVAL_FREQ, 2*EVAL_FREQ, ... up to and including EPOCHS.
epochs = np.arange(EVAL_FREQ, EPOCHS + 1, step=EVAL_FREQ)
plot_returns(means, stds, 'Implicit Q Learning', goal=0.4, epochs=epochs)

 64%|██████▎   | 2452/3858 [00:13<00:07, 189.08it/s]