In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym
from collections import deque
import random

# Reproducibility: one seed for every RNG the notebook touches
# (python `random`, numpy, torch sampling/initialisation, and the environment).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Environment
# Gymnasium render mode: 'human' shows the environment on screen; None disables rendering.
RENDER_MODE = 'human'
env = gym.make("CartPole-v1", render_mode=RENDER_MODE)
env.reset(seed=SEED)  # seed the env once; later reset() calls continue from this RNG state
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Hyperparameters
lr_actor = 3e-4
lr_discriminator = 1e-3
gamma = 0.99
epsilon = 0.2      # PPO clip range
batch_size = 64
num_episodes = 100
num_episodes = 100

In [2]:
class Policy(nn.Module):
    """Stochastic actor mapping observations to action probabilities.

    Network: Linear(state_dim -> 64) -> ReLU -> Linear(64 -> action_dim) -> Softmax
    over the last dimension. ``state_dim`` and ``action_dim`` are read from
    module-level globals when the network is constructed.
    """

    def __init__(self):
        super().__init__()
        hidden_units = 64
        layers = [
            nn.Linear(state_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, action_dim),
            # Rows sum to 1, so the output can be fed straight to torch.multinomial.
            nn.Softmax(dim=-1),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, state):
        """Return action probabilities for a batch of states (last dim = state_dim)."""
        probabilities = self.fc(state)
        return probabilities

# Actor network and the Adam optimiser that updates only its parameters.
policy = Policy()
optimizer_actor = optim.Adam(params=policy.parameters(), lr=lr_actor)

In [23]:
class Discriminator(nn.Module):
    """Scores (state, action) pairs with the probability that they came from the expert.

    The discrete action is one-hot encoded and concatenated to the state, so the
    MLP input width is ``state_dim + action_dim`` (module-level globals read at
    construction and call time). The output is a ``[batch, 1]`` sigmoid probability.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim + action_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, state, action):
        """Return D(state, action).

        Args:
            state: float tensor of shape ``[batch, state_dim]``.
            action: integer-valued action indices of shape ``[batch]`` or
                ``[batch, 1]``; any dtype is accepted (callers in this notebook
                pass float tensors).

        Returns:
            Tensor of shape ``[batch, 1]``.
        """
        # Flatten so both [batch] and [batch, 1] index tensors are accepted.
        action_index = action.long().reshape(-1)
        # Build the one-hot on the state's device/dtype so torch.cat also works off-CPU.
        action_onehot = nn.functional.one_hot(action_index, num_classes=action_dim).to(
            device=state.device, dtype=state.dtype
        )
        x = torch.cat([state, action_onehot], dim=-1)  # [batch, state_dim + action_dim]
        return self.fc(x)
    
# Discriminator network and its own Adam optimiser (separate learning rate from the actor).
discriminator = Discriminator()
optimizer_discriminator = optim.Adam(params=discriminator.parameters(), lr=lr_discriminator)

In [7]:
def generate_expert_data(num_trajectories=50, environment=None):
    """Roll out a hand-coded heuristic "expert" and record its (state, action) pairs.

    The heuristic picks action 1 when ``state[2] > 0`` and action 0 otherwise
    (presumably pushing toward the side the pole leans on in CartPole).

    Args:
        num_trajectories: number of episodes to record.
        environment: Gymnasium-style env whose ``reset()`` returns
            ``(obs, info)`` and whose ``step(a)`` returns
            ``(obs, reward, terminated, truncated, info)``. Defaults to the
            notebook's global ``env``.

    Returns:
        ``(states, actions)``: a numpy array of the visited observations and a
        numpy array of the chosen integer actions, aligned by index.
    """
    if environment is None:
        environment = env
    expert_states, expert_actions = [], []
    for _ in range(num_trajectories):
        state, _ = environment.reset()
        done = False
        while not done:
            action = 1 if state[2] > 0 else 0  # Biased policy (balance pole)
            expert_states.append(state)
            expert_actions.append(action)
            state, _, terminated, truncated, _ = environment.step(action)
            # End on truncation as well as termination; otherwise a time-limited
            # episode keeps stepping after it has been cut off.
            done = terminated or truncated
    return np.array(expert_states), np.array(expert_actions)

# Demonstration dataset used as the "real" samples for the discriminator.
expert_states, expert_actions = generate_expert_data(num_trajectories=50)

In [8]:

def update_discriminator(expert_states, expert_actions, policy_states, policy_actions):
    """Take one optimisation step on the global ``discriminator``.

    Trains D toward 1 on expert (state, action) pairs and toward 0 on policy
    pairs by minimising ``-mean(log D(expert)) - mean(log(1 - D(policy)))``.

    Args:
        expert_states: array-like of expert observations, one row per step.
        expert_actions: array-like of expert action indices, one per step.
        policy_states: array-like (e.g. a list of per-step observation arrays)
            collected from the current policy.
        policy_actions: array-like of policy action indices, one per step.

    Returns:
        float: the discriminator loss computed before this step (for logging).
    """
    def _as_float_tensor(values):
        # Go through np.asarray first: building a tensor directly from a list of
        # ndarrays is very slow in PyTorch and emits a UserWarning.
        return torch.as_tensor(np.asarray(values, dtype=np.float32))

    expert_states_tensor = _as_float_tensor(expert_states)
    expert_actions_tensor = _as_float_tensor(expert_actions).unsqueeze(1)
    policy_states_tensor = _as_float_tensor(policy_states)
    policy_actions_tensor = _as_float_tensor(policy_actions).unsqueeze(1)

    # Binary cross-entropy; the 1e-8 keeps log() finite if D saturates at 0 or 1.
    expert_probs = discriminator(expert_states_tensor, expert_actions_tensor)
    policy_probs = discriminator(policy_states_tensor, policy_actions_tensor)
    loss_d = -torch.mean(torch.log(expert_probs + 1e-8)) - torch.mean(torch.log(1 - policy_probs + 1e-8))

    optimizer_discriminator.zero_grad()
    loss_d.backward()
    optimizer_discriminator.step()
    return loss_d.item()


In [9]:

def compute_advantages(rewards, discount=0.0):
    """Turn one trajectory's per-step rewards into normalised advantage estimates.

    With the default ``discount=0.0`` each step's advantage is its own reward,
    standardised to zero mean / unit std over the trajectory (the simplified
    scheme; GAIL typically uses GAE). With ``discount > 0`` the discounted
    reward-to-go ``G_t = r_t + discount * G_{t+1}`` is standardised instead,
    e.g. ``compute_advantages(r, discount=gamma)``.

    NOTE(review): if every reward is negative (as ``log D`` is), a non-zero
    discount makes earlier steps of longer trajectories look worse, i.e. it
    acts like a survival penalty. Consider a positive reward such as
    ``-log(1 - D)`` before enabling it.

    Args:
        rewards: 1-D sequence of per-step rewards (a 0-d array counts as one step).
        discount: discount factor for reward-to-go; 0.0 keeps per-step rewards.

    Returns:
        float64 numpy array with one entry per step; empty if ``rewards`` is empty.
    """
    rewards = np.asarray(rewards, dtype=np.float64).reshape(-1)
    if rewards.size == 0:
        # mean/std of an empty array would be NaN with a RuntimeWarning.
        return rewards
    returns = rewards.copy()
    if discount:
        # Accumulate reward-to-go backwards through the trajectory.
        for t in range(returns.size - 2, -1, -1):
            returns[t] += discount * returns[t + 1]
    # 1e-8 avoids division by zero when all values are equal (e.g. a single step).
    return (returns - returns.mean()) / (returns.std() + 1e-8)


In [10]:

def update_policy(states, actions, advantages, ppo_epochs=3, clip_epsilon=None):
    """Update the global ``policy`` with the PPO clipped surrogate objective.

    Each epoch is one full-batch gradient step over the given rollout.

    Args:
        states: sequence of observations from one rollout.
        actions: sequence of integer action indices taken at those states.
        advantages: per-step advantage estimates, one per state.
        ppo_epochs: number of gradient steps on the same rollout (default 3).
        clip_epsilon: PPO clip range; ``None`` uses the global ``epsilon``.

    Returns:
        float or None: surrogate loss of the last epoch, or None when ``ppo_epochs`` is 0.
    """
    if clip_epsilon is None:
        clip_epsilon = epsilon
    states_tensor = torch.FloatTensor(np.array(states))
    # Column shapes ([T, 1]) so gather() and the advantage product line up.
    actions_tensor = torch.LongTensor(np.array(actions)).unsqueeze(1)
    advantages_tensor = torch.as_tensor(np.asarray(advantages), dtype=torch.float32).unsqueeze(1)

    # Probabilities under the pre-update policy are constants in the ratio: no graph needed.
    with torch.no_grad():
        old_probs = policy(states_tensor).gather(1, actions_tensor)

    last_loss = None
    for _ in range(ppo_epochs):
        new_probs = policy(states_tensor).gather(1, actions_tensor)
        ratio = new_probs / old_probs
        clipped_ratio = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon)
        surrogate_loss = -torch.min(
            ratio * advantages_tensor,
            clipped_ratio * advantages_tensor
        ).mean()

        optimizer_actor.zero_grad()
        surrogate_loss.backward()
        optimizer_actor.step()
        last_loss = surrogate_loss.item()
    return last_loss


In [21]:
def _collect_rollout():
    """Run one episode in the global ``env`` with the global ``policy``; return (states, actions)."""
    states, actions = [], []
    state, _ = env.reset()
    done = False
    while not done:
        # Sampling only: no gradients are needed during the rollout.
        with torch.no_grad():
            action_probs = policy(torch.FloatTensor(state).unsqueeze(0))
        action = torch.multinomial(action_probs, 1).item()

        next_state, _, terminated, truncated, _ = env.step(action)
        # Stop on truncation as well as termination; otherwise a time-limited
        # episode keeps stepping after it has been cut off.
        done = terminated or truncated

        states.append(state)
        actions.append(action)
        state = next_state
    return states, actions


def _discriminator_rewards(states, actions):
    """Per-step surrogate rewards log D(s, a) from the global ``discriminator``."""
    states_tensor = torch.FloatTensor(np.array(states))
    actions_tensor = torch.FloatTensor(np.array(actions)).unsqueeze(1)
    with torch.no_grad():
        log_d = torch.log(discriminator(states_tensor, actions_tensor) + 1e-8)
    # squeeze(-1) rather than squeeze(): a 1-step episode stays a length-1 array.
    return log_d.squeeze(-1).numpy()


def train_gail():
    """Run the GAIL loop for ``num_episodes`` iterations.

    Each iteration rolls out one episode with the current policy, scores every
    step with log D(s, a), takes PPO steps on the policy with the normalised
    advantages, then takes one discriminator step against the expert data.
    The episode length is printed (labelled "Reward") after each iteration.
    """
    for episode in range(num_episodes):
        states, actions = _collect_rollout()
        rewards = _discriminator_rewards(states, actions)

        # Update policy (PPO)
        advantages = compute_advantages(rewards)
        update_policy(states, actions, advantages)

        # Update discriminator
        update_discriminator(expert_states, expert_actions, states, actions)

        print(f"Episode {episode}, Reward: {len(states)}")


In [26]:

# Start training: runs `num_episodes` GAIL iterations, printing each episode's length.
train_gail()

Episode 0, Reward: 17
Episode 1, Reward: 68
Episode 2, Reward: 25
Episode 3, Reward: 23
Episode 4, Reward: 30
Episode 5, Reward: 30
Episode 6, Reward: 28
Episode 7, Reward: 12
Episode 8, Reward: 30
Episode 9, Reward: 18
Episode 10, Reward: 17
Episode 11, Reward: 14
Episode 12, Reward: 10
Episode 13, Reward: 20
Episode 14, Reward: 47
Episode 15, Reward: 13
Episode 16, Reward: 11
Episode 17, Reward: 26
Episode 18, Reward: 12
Episode 19, Reward: 19
Episode 20, Reward: 9
Episode 21, Reward: 9
Episode 22, Reward: 13
Episode 23, Reward: 29
Episode 24, Reward: 24
Episode 25, Reward: 14
Episode 26, Reward: 30
Episode 27, Reward: 14
Episode 28, Reward: 13
Episode 29, Reward: 16
Episode 30, Reward: 11
Episode 31, Reward: 17
Episode 32, Reward: 18
Episode 33, Reward: 19
Episode 34, Reward: 16
Episode 35, Reward: 17
Episode 36, Reward: 14
Episode 37, Reward: 10
Episode 38, Reward: 13
Episode 39, Reward: 12
Episode 40, Reward: 25
Episode 41, Reward: 30
Episode 42, Reward: 32
Episode 43, Reward: 12
