In [4]:
import gymnasium as gym

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib
from matplotlib import pyplot as plt
from collections import namedtuple, deque
import random
import math
from itertools import count

is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

<contextlib.ExitStack at 0x174e17710>

In [6]:
device = torch.device("cuda")
print(device)

cuda


In [7]:
Transition = namedtuple('Transition', ('state','action','next_state','reward'))

class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen = capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


In [8]:
class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

In [9]:
def one_hot_encoding(x):
    ans = torch.zeros(1,500)
    ans[0,x] = 1
    return ans

In [10]:
def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END)*math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1

    if sample>eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1).indices.view(1,1)

    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.float)

In [11]:

def plot_durations(i_episode, show_results = False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype = torch.float)
    tile = "Training " + str(i_episode)
    if(show_results):
        plt.title('Results')
    else:
        plt.clf() #clear the current figure
        plt.title(tile)

    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())

    if(len(durations_t)>=100):
        means = durations_t.unfold(0,100,1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means)) #for intitial set
        plt.plot(means.numpy())

    plt.pause(0.001)

    if is_ipython:
        if not show_results:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())

In [12]:
def optimize_model():
    if len(memory)<BATCH_SIZE:
        return

    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions)) #convert (s,a,s,r),(s1,a1,s1,r1) to (s,s1) (a,a1)....
    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    # temp = [(s.shape,s) for s in batch.next_state if s is not None]
    # print(temp)
    # print("batch: ", len(batch.next_state))
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device = device, dtype = torch.bool)
    # print("non_final_mask: ", non_final_mask)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action).long()
    reward_batch = torch.cat(batch.reward)
    # print(policy_net(state_batch).shape, action_batch.shape)
    # print("action_batch: ", action_batch, type(action_batch))
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    # torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()


In [13]:
10000/200

50.0

In [19]:
# #####################Training########################
# num_episodes = 200
# episode_durations = []
# BATCH_SIZE = 128
# GAMMA = 0.99
# EPS_START = 1.0
# EPS_END = 0.01
# EPS_DECAY = 5000
# TAU = 0.001
# LR = 1e-4

# n_actions = env.action_space.n
# state, info = env.reset()
# n_observations = 1

# policy_net = DQN(500, n_actions).to(device)
# target_net = DQN(500, n_actions).to(device)
# target_net.load_state_dict(policy_net.state_dict())

# optimizer = optim.AdamW(policy_net.parameters(), lr = LR, amsgrad = True)
# memory = ReplayMemory(10000)

# steps_done = 0


# for i_episode in range(num_episodes):
#     state, info = env.reset()

#     state = one_hot_encoding(state).to(device).view(1,500)
#     # print(state)
#     for t in count():
#         action = select_action(state)
#         observation, reward, terminated, truncated, _ = env.step(action.item())
#         reward = torch.tensor([reward], device=device)
#         done = terminated or truncated

#         if terminated:
#             next_state = None
#         else:
#             # next_state = torch.tensor(observation, dtype=torch.float32, device=device).view(1,1)
#             next_state = one_hot_encoding(observation).to(device).view(1,500)


#         # Store the transition in memory
#         memory.push(state, action, next_state, reward)

#         # Move to the next state
#         state = next_state

#         # Perform one step of the optimization (on the policy network)
#         optimize_model()

#         # Soft update of the target network's weights
#         # θ′ ← τ θ + (1 −τ )θ′
#         target_net_state_dict = target_net.state_dict()
#         policy_net_state_dict = policy_net.state_dict()
#         for key in policy_net_state_dict:
#             target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
#         target_net.load_state_dict(target_net_state_dict)

#         if done:
#             episode_durations.append(t + 1)
#             # plot_durations()
#             if i_episode%100==0:
#               plot_durations(i_episode)



#             break


# print('Complete')
# plot_durations(0, show_results=True)
# plt.ioff()
# plt.show()

In [15]:
new_policy_net = DQN(500, 6)
device = 'mps'

In [16]:
new_policy_net.load_state_dict(torch.load('works_ig_3000.pth', map_location=torch.device('cpu')))
new_policy_net.to(device)

  new_policy_net.load_state_dict(torch.load('works_ig_3000.pth', map_location=torch.device('cpu')))


DQN(
  (layer1): Linear(in_features=500, out_features=128, bias=True)
  (layer2): Linear(in_features=128, out_features=128, bias=True)
  (layer3): Linear(in_features=128, out_features=6, bias=True)
)

In [17]:
new_policy_net

DQN(
  (layer1): Linear(in_features=500, out_features=128, bias=True)
  (layer2): Linear(in_features=128, out_features=128, bias=True)
  (layer3): Linear(in_features=128, out_features=6, bias=True)
)

In [18]:
# ######################Record Video######################
# import gymnasium as gym
# from gym.wrappers.record_video import RecordVideo
# import torch

# # Create and wrap the environment
# env = gym.make("Taxi-v3", render_mode="rgb_array")
# env = RecordVideo(env, video_folder="./videos", episode_trigger=lambda x: True)  # Record every episode

# # Function to get action from your policy network
# def get_action(state, policy_net):
#     state = one_hot_encoding(state).to(device).view(1,500)
#     with torch.no_grad():
#         state_tensor = torch.tensor(state, dtype=torch.float32)
#         q_values = policy_net(state_tensor)
#         return q_values.argmax().item()

# # Record a few episodes
# num_episodes = 3
# for episode in range(num_episodes):
#     state = env.reset()[0]  # Get initial state
#     done = False
#     total_reward = 0
    
#     while not done:
#         # Get action from your policy
#         action = get_action(state, new_policy_net)
#         # print(action, type(action))
#         # Take step in environment
#         next_state, reward, done, truncated, info = env.step(action)
#         total_reward += reward
#         state = next_state
        
#     print(f"Episode {episode + 1} completed with reward: {total_reward}")

# env.close()

# AIRL Training