In [60]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym
from collections import deque
import random
import pickle
from matplotlib import pyplot as plt    

# Reproducibility: seed every random number generator this notebook draws from
# (Python's `random`, NumPy, PyTorch) before any network weights are created.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Environment
# NOTE(review): render_mode='human' draws the environment on screen; presumably this
# slows data collection during training — consider a non-rendering env for training.
env = gym.make("CartPole-v1",render_mode='human')
env.reset(seed=SEED)  # Seed the env's RNG once; later env.reset() calls continue that stream
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Hyperparameters
lr_actor = 3e-4          # Adam learning rate for the policy
lr_discriminator = 1e-3  # Adam learning rate for the discriminator
gamma = 0.99
epsilon = 0.2            # PPO clipping range: ratio is clamped to [1 - epsilon, 1 + epsilon]
batch_size = 64
num_episodes = 100       # Number of training episodes run by train_gail()

In [61]:
class Policy(nn.Module):
    """Stochastic policy network: maps a state to a probability for each action.

    The network is a one-hidden-layer MLP (``state_dim`` -> 64 -> ``action_dim``)
    followed by a softmax over the last dimension, so every output row sums to 1.
    """

    def __init__(self):
        super().__init__()
        hidden_width = 64
        layers = [
            nn.Linear(state_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, action_dim),
            nn.Softmax(dim=-1),
        ]
        # Stored under the attribute name `fc`, in this order, so parameter names stay the same.
        self.fc = nn.Sequential(*layers)

    def forward(self, state):
        """Return the action probabilities produced by the network for ``state``."""
        action_probs = self.fc(state)
        return action_probs

# Instantiate the policy network and the Adam optimizer that updates its parameters
# (learning rate taken from the `lr_actor` hyperparameter).
policy = Policy()
optimizer_actor = optim.Adam(policy.parameters(), lr=lr_actor)

In [62]:
class Discriminator(nn.Module):
    """Binary classifier over (state, action) pairs.

    The action is one-hot encoded and concatenated with the state, then passed
    through a one-hidden-layer MLP ending in a sigmoid, so the output is a value
    in (0, 1) per pair.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim + action_dim, 64),  # Input is [state, one-hot(action)]
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, state, action):
        """Score a batch of (state, action) pairs.

        Args:
            state: float tensor of shape [batch, state_dim].
            action: tensor of discrete action indices, shaped [batch] or [batch, 1]
                (any numeric dtype; values are cast to integer indices).

        Returns:
            Tensor of shape [batch, 1] with the sigmoid output for each pair.
        """
        # scatter_ needs a 2-D index; accept both [batch] and [batch, 1] inputs.
        action_index = action.long().reshape(-1, 1).to(state.device)
        # Build the one-hot buffer on the same device/dtype as `state` so the
        # concatenation below also works when the inputs live on a GPU.
        action_onehot = torch.zeros(
            action_index.shape[0], action_dim, dtype=state.dtype, device=state.device
        )
        action_onehot.scatter_(1, action_index, 1.0)  # Convert action to one-hot
        x = torch.cat([state, action_onehot], dim=-1)  # Shape: [batch, state_dim + action_dim]
        return self.fc(x)
    
# Instantiate the discriminator and the Adam optimizer that updates its parameters
# (learning rate taken from the `lr_discriminator` hyperparameter).
discriminator = Discriminator()
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=lr_discriminator)

In [63]:
# def generate_expert_data(num_trajectories=50):
#     expert_states, expert_actions = [], []
#     for _ in range(num_trajectories):
#         state = env.reset()[0]
#         done = False
#         while not done:
#             action = 1 if state[2] > 0 else 0  # Biased policy (balance pole)
#             expert_states.append(state)
#             expert_actions.append(action)
#             state, _, done,_ , _ = env.step(action)
#     return np.array(expert_states), np.array(expert_actions)

# expert_states, expert_actions = generate_expert_data()

In [65]:
dest_dir = "./expert_data"
checkpoint_files = ["ckpt0.pkl"]

# Accumulate the demonstrations from every checkpoint; assigning inside the loop
# would keep only the last file once more than one checkpoint is listed.
expert_states, expert_actions = [], []
for ckpt in checkpoint_files:
    # load expert data from pickle
    # SECURITY: pickle.load can execute arbitrary code — only load checkpoints you trust.
    with open(f"{dest_dir}/{ckpt}", "rb") as f:
        all_data = pickle.load(f)
    expert_states.extend(all_data["states"])
    expert_actions.extend(all_data["actions"])

In [66]:
# Convert the expert data to plain Python containers (lists of floats / ints).
# The lists are rebuilt instead of mutated in place so this cell can be re-run:
# np.asarray() accepts both the original arrays and already-converted lists.
expert_states = [np.asarray(s).tolist() for s in expert_states]
expert_actions = [int(a) for a in expert_actions]

In [67]:
# Preview the first five expert states and the matching actions, one list per line
print(expert_states[:5], expert_actions[:5], sep="\n")

[[0.021411817520856857, 0.017583424225449562, -0.018309468403458595, -0.023649049922823906], [0.021763484925031662, -0.17727123200893402, -0.018782449886202812, 0.2632012665271759], [0.018218060955405235, 0.01811371184885502, -0.013518424704670906, -0.035346176475286484], [0.01858033426105976, -0.1768117994070053, -0.014225348830223083, 0.2530410885810852], [0.01504409871995449, 0.018510356545448303, -0.00916452705860138, -0.044094622135162354]]
[0, 1, 0, 1, 0]


In [68]:

def update_discriminator(expert_states, expert_actions, policy_states, policy_actions):
    """Run one gradient step on the global ``discriminator``.

    The loss is the binary cross-entropy that pushes the discriminator output
    towards 1 on expert (state, action) pairs and towards 0 on policy pairs.

    Args:
        expert_states: sequence of expert states, one row per sample.
        expert_actions: sequence of expert action indices (flat or one per row).
        policy_states: sequence of states visited by the current policy.
        policy_actions: sequence of action indices taken by the current policy.

    Returns:
        float: the discriminator loss evaluated before the optimizer step.
    """
    def _as_float_tensor(values):
        # Going through a single float32 ndarray avoids the slow element-wise path
        # PyTorch takes when handed a Python list of ndarrays.
        return torch.as_tensor(np.asarray(values, dtype=np.float32))

    expert_states_tensor = _as_float_tensor(expert_states)
    # reshape(-1, 1) yields a [N, 1] column whether actions arrive flat or already as a column
    expert_actions_tensor = _as_float_tensor(expert_actions).reshape(-1, 1)
    policy_states_tensor = _as_float_tensor(policy_states)
    policy_actions_tensor = _as_float_tensor(policy_actions).reshape(-1, 1)

    # Discriminator loss (1e-8 keeps the logs finite when an output saturates at 0 or 1)
    expert_probs = discriminator(expert_states_tensor, expert_actions_tensor)
    policy_probs = discriminator(policy_states_tensor, policy_actions_tensor)
    loss_d = -torch.mean(torch.log(expert_probs + 1e-8)) - torch.mean(torch.log(1 - policy_probs + 1e-8))

    optimizer_discriminator.zero_grad()
    loss_d.backward()
    optimizer_discriminator.step()

    return loss_d.item()


In [69]:

def compute_advantages(rewards):
    """Standardise per-step rewards to zero mean and (approximately) unit std.

    Simplified advantage calculation (GAIL typically uses GAE): no discounting
    or value baseline is applied, each reward is only centred and rescaled.

    Args:
        rewards: array-like of per-step rewards (ndarray, list, tuple, ...).

    Returns:
        np.ndarray with the same shape as ``rewards``; an empty input yields an
        empty array instead of NaN.
    """
    rewards = np.asarray(rewards)  # Accept plain lists as well as ndarrays
    if rewards.size == 0:
        # np.mean / np.std of an empty array would emit warnings and return NaN
        return rewards.astype(float)
    # 1e-8 guards against division by zero when all rewards are identical
    return (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-8)


In [70]:

def update_policy(states, actions, advantages, ppo_epochs=3):
    """Update the global ``policy`` with the PPO clipped surrogate objective.

    Args:
        states: sequence of visited states, one row per step.
        actions: sequence of integer action indices taken in those states.
        advantages: array-like with one advantage value per step.
        ppo_epochs: number of gradient steps taken on this batch (default 3).

    Returns:
        float | None: surrogate loss of the last epoch, or None if ``ppo_epochs`` is 0.
    """
    states_tensor = torch.as_tensor(np.asarray(states, dtype=np.float32))
    # Column vectors ([T, 1]) so they line up with the output of gather() below
    action_index = torch.as_tensor(np.asarray(actions, dtype=np.int64)).reshape(-1, 1)
    advantages_column = torch.as_tensor(np.asarray(advantages, dtype=np.float32)).reshape(-1, 1)

    # Probabilities under the pre-update policy are constants in the PPO objective,
    # so compute them without building an autograd graph. The clamp only guards the
    # division below against a probability that underflowed to zero.
    with torch.no_grad():
        old_probs = policy(states_tensor).gather(1, action_index).clamp_min(1e-8)

    last_loss = None
    for _ in range(ppo_epochs):  # PPO epochs
        new_probs = policy(states_tensor).gather(1, action_index)
        ratio = new_probs / old_probs
        clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
        # Pessimistic (element-wise minimum) of the unclipped and clipped objectives
        surrogate_loss = -torch.min(
            ratio * advantages_column,
            clipped_ratio * advantages_column
        ).mean()

        optimizer_actor.zero_grad()
        surrogate_loss.backward()
        optimizer_actor.step()
        last_loss = surrogate_loss.item()

    return last_loss


In [71]:
def train_gail():
    """Train the global policy and discriminator for ``num_episodes`` episodes.

    Each iteration collects one episode with the current policy, scores every
    (state, action) pair with the discriminator (reward = log D(s, a)), updates
    the policy via ``update_policy`` and then the discriminator via
    ``update_discriminator``. One line is printed per episode; the value labelled
    "Reward" in that line is the episode length in steps.
    """
    for episode in range(num_episodes):
        # Collect policy trajectories
        states, actions = [], []
        state, _ = env.reset()
        done = False
        while not done:
            state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():  # Sampling only; update_policy re-evaluates the policy itself
                action_probs = policy(state_tensor)
            action = torch.multinomial(action_probs, 1).item()

            # step() returns separate `terminated` and `truncated` flags; the episode
            # is over when either is set (ignoring truncation would keep stepping
            # an environment that already hit its time limit).
            next_state, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            states.append(state)
            actions.append(action)
            state = next_state

        # Convert to tensors
        states_tensor = torch.as_tensor(np.asarray(states, dtype=np.float32))
        actions_tensor = torch.as_tensor(np.asarray(actions, dtype=np.float32)).unsqueeze(1)

        # Compute rewards using discriminator
        with torch.no_grad():
            policy_rewards = torch.log(discriminator(states_tensor, actions_tensor) + 1e-8)
        # reshape(-1) keeps a 1-D array even for a one-step episode (squeeze() would make it 0-d)
        rewards = policy_rewards.reshape(-1).numpy()

        # Update policy (PPO)
        advantages = compute_advantages(rewards)  # Simplified advantage calculation
        update_policy(states, actions, advantages)

        # Update discriminator
        update_discriminator(expert_states, expert_actions, states, actions)

        print(f"Episode {episode}, Reward: {len(states)}")



In [72]:

# Start training: runs `num_episodes` GAIL iterations and prints one line per episode
train_gail()

Episode 0, Reward: 9
Episode 1, Reward: 37
Episode 2, Reward: 10
Episode 3, Reward: 15
Episode 4, Reward: 16
Episode 5, Reward: 24
Episode 6, Reward: 12
Episode 7, Reward: 13
Episode 8, Reward: 11
Episode 9, Reward: 13
Episode 10, Reward: 19
Episode 11, Reward: 22
Episode 12, Reward: 18
Episode 13, Reward: 15
Episode 14, Reward: 11
Episode 15, Reward: 10
Episode 16, Reward: 15
Episode 17, Reward: 13
Episode 18, Reward: 19
Episode 19, Reward: 18
Episode 20, Reward: 11
Episode 21, Reward: 22
Episode 22, Reward: 13
Episode 23, Reward: 18
Episode 24, Reward: 18
Episode 25, Reward: 24
Episode 26, Reward: 27
Episode 27, Reward: 11
Episode 28, Reward: 23
Episode 29, Reward: 16
Episode 30, Reward: 9
Episode 31, Reward: 15
Episode 32, Reward: 10
Episode 33, Reward: 24
Episode 34, Reward: 15
Episode 35, Reward: 8
Episode 36, Reward: 12
Episode 37, Reward: 11
Episode 38, Reward: 12
Episode 39, Reward: 13
Episode 40, Reward: 13
Episode 41, Reward: 9
Episode 42, Reward: 14
Episode 43, Reward: 17
Ep

In [73]:
# Sanity check: query the policy on a hand-picked state and sample one action from it.
st = [0.5, 0.5, 0.5, 0.5]
with torch.no_grad():  # Inference only; a single forward pass is reused for sampling
    example_probs = policy(torch.FloatTensor(st).unsqueeze(0))  # Example usage of policy
torch.multinomial(example_probs, 1).item()  # Example action selection (shown as the cell output)

1

In [79]:
# Re-create the environment with on-screen rendering for the evaluation episodes below
env = gym.make("CartPole-v1", render_mode='human')

In [80]:
def run_episode(env, policy, render):
    """Roll out one episode, sampling actions from ``policy``, and return the total reward.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` and ``render()``;
            ``step`` must return ``(obs, reward, terminated, truncated, info)``.
        policy: callable mapping a [1, state_dim] float tensor to a [1, n_actions]
            tensor of action probabilities.
        render: if truthy, ``env.render()`` is called before every step.

    Returns:
        The sum of the rewards returned by ``env.step`` over the episode.
    """
    state, _ = env.reset()
    total_reward = 0

    done = False
    while not done:
        if render:
            env.render()  # Visualize the episode

        # Convert state to tensor and get action probabilities (no gradients needed here)
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            action_probs = policy(state_tensor)

        # Sample an action from the probability distribution
        action = torch.multinomial(action_probs, 1).item()

        # Take the action in the environment. The episode ends when it is either
        # terminated or truncated (time limit); checking only `terminated`
        # would keep stepping after the time limit was reached.
        state, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        done = terminated or truncated

    return total_reward

# Test the policy for 10 episodes
try:
    for episode in range(10):
        reward = run_episode(env, policy, render=True)
        print(f"Episode {episode + 1}, Reward: {reward}")
finally:
    env.close()  # Close the environment even if an episode raises

Episode 1, Reward: 13.0
Episode 2, Reward: 10.0
Episode 3, Reward: 11.0
Episode 4, Reward: 15.0
Episode 5, Reward: 10.0
Episode 6, Reward: 31.0
Episode 7, Reward: 11.0
Episode 8, Reward: 9.0
Episode 9, Reward: 15.0
Episode 10, Reward: 13.0
