In [4]:
#@@@@@@@@@@@@@@@@@@@@@@@
#import useful libraries
#@@@@@@@@@@@@@@@@@@@@@@@
import pandas as pd
import numpy as np
import copy
from scipy import interpolate
from scipy.stats import truncnorm
import pickle
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
from pydantic import BaseModel
import csv
import os
import torch
import torch.nn as nn
import torch.optim as optim
from typing import List

from game import TheGang
from models import HandFeatures

In [5]:
import torch

# Check if a GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU")
else:
    device = torch.device("cpu")
    print("Using CPU")

Using GPU


In [6]:
#@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
class PolicyNetwork(nn.Module):
    def __init__(self, input_size, hidden_layer_size, hidden_layer_count = 1):
        super(PolicyNetwork, self).__init__()
        self.hidden_layers = [nn.Linear(input_size, hidden_layer_size)]
        nn.init.normal_(self.hidden_layers[0].weight, mean = 0, std = 0.01)
        nn.init.normal_(self.hidden_layers[0].bias, mean = 0, std = 0.01)
        for _ in range(hidden_layer_count-1):
            layer = nn.Linear(hidden_layer_size, hidden_layer_size)
            self.hidden_layers.append(layer)
            nn.init.normal_(layer.weight, mean = 0, std = 0.01)
            nn.init.normal_(layer.bias, mean = 0, std = 0.01)
        self.output_layer = nn.Linear(hidden_layer_size, 4)    #the output layer with probabiliy for each action
        self.output_layer.bias = nn.Parameter(torch.tensor([0.0, 0.0, 0.0, 0.0])) 

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        '''A function to do the forward pass
            Takes:
                s -- the state representation
            Returns:
                a tensor of probabilities
        '''         
        device_tensor = s.to(device)
        for layer in self.hidden_layers:
            s = torch.relu(layer(s))    #pass through the hidden layers
        s = self.output_layer(s)
        action_probs = torch.softmax(s, dim=1)    #use softmax to get action probabilities
        return action_probs.to('cpu')


In [12]:
class AgentConfig(BaseModel):
    hidden_layer_size: int = 16
    hidden_layer_count: int = 16
    learning_rate: float = 0.0005

class Agent():
    def __init__(self, config: AgentConfig):
        self.config = config
        self.pi = PolicyNetwork(17, config.hidden_layer_size, config.hidden_layer_count).to(device)
        self.optimizer = optim.Adam(self.pi.parameters(), lr=self.config.learning_rate)

    def generate_actions(self, state_array: List[List[int]]):
        action_probs = self.pi(torch.tensor(state_array, dtype=torch.float32)) # Create tensor and feed state through model
        action_probs.to("cpu")
        sampled_actions = torch.multinomial(action_probs, 1).squeeze(dim=1)
        action_space = torch.tensor([1, 2, 3, 4])
        final_actions = action_space[sampled_actions]
        return final_actions

    def generate_action_probs(self, a: List[int], state_array: List[List[int]]):
        all_action_probs = self.pi(torch.tensor(state_array, dtype=torch.float32) # Create tensor and feed state through model
        all_action_probs.to("cpu")
        a_tensor = torch.tensor(a, dtype=torch.long)
        a_tensor = a_tensor - 1
        performed_action_probs = all_action_probs.gather(1, a_tensor.unsqueeze(1)).squeeze(1)
        log_probs = torch.log(performed_action_probs)
        return log_probs

    def checkpoint(self, model_name, epoch_count: int):
        directory = f'checkpoints/{model_name}'
        os.makedirs(directory, exist_ok=True)
        torch.save(self.pi.state_dict(), f'{directory}/epi_{epoch_count}.pth')
        
    def load_checkpoint(self, model_name: str, epoch_count: int):
        self.pi.load_state_dict(torch.load(f'checkpoints/{model_name}/epi_{epoch_count}.pth'))


In [13]:
class TrainingConfig(BaseModel):
    discount_factor: float = 0.5
    batch_size: int = 100
    checkpoint_freq: int = 5000

In [14]:
def train_agent(
    agent: Agent,
    env: TheGang,
    training_config: TrainingConfig,
    num_epochs: int,
    starting_epoch: int,
    model_name: str
):
    average_epoch_rewards = []
    for epoch in tqdm(range(num_epochs)):
        epoch_batch_rewards = []
        batch = []
        # Repead episodes until epoch batch size is reached
        for epi in range(training_config.batch_size):
            env.reset()

            episode_rewards = []
            episode_states = []
            episode_actions = []
            state = {'state':env.generate_state_array([0,0,0,0]),'reward':[0,0,0,0],'done':False}

            # Loop turns in episode
            while not state['done']:
                episode_states.append(state['state'])
                a = agent.generate_actions(state['state']).tolist()
                episode_actions.append(a)
                state = env.step(a)
                episode_rewards.append(state['reward'])

            # Track fire rate and terminal rewards for non-fired episodes
            epoch_batch_rewards.extend(episode_rewards)

            # Calcualte causal reward
            causal_returns = []
            rolling_causal_return = np.zeros(len(episode_rewards[0]))
            # Calculate causal returns in reverse order
            for rewards in episode_rewards[::-1]:
                causal_return = rewards + training_config.discount_factor * rolling_causal_return
                rolling_causal_return = causal_return
                causal_returns.append(causal_return)
            # Reverse the causal returns to get the correct order
            causal_returns = causal_returns[::-1]

            for state, action, reward, causal_return in zip(episode_states, episode_actions, episode_rewards, causal_returns):
                batch.append({'s_t': state, 'a_t': action, 'r_t': reward, 'cr_t': causal_return})

        # Calculate epoch performance
        average_epoch_rewards.append(np.mean(epoch_batch_rewards))
        
        # Checkpoint if applicable
        if epoch % training_config.checkpoint_freq == 0:
            agent.checkpoint(model_name, starting_epoch + num_epochs)
            with open(f'checkpoints/{model_name}/epoch_rewards.csv', 'w', newline="") as f:
                writer = csv.writer(f)
                writer.writerow(epoch_batch_rewards)

        log_probs = []
        batch_causal_returns = []
        for step in batch:    # Loop over batch
            states = step['s_t']
            actions = step['a_t']
            action_probs = agent.generate_action_probs(actions, states)    # Compute action probability from policy
            log_probs.append(action_probs.squeeze())    # Record the log probability of the chosen action
            batch_causal_returns.append(step['cr_t'])

        # Perform baselining
        baseline = np.mean(batch_causal_returns)
        baselined_causal_returns = batch_causal_returns - baseline

        baselined_causal_returns = torch.tensor(baselined_causal_returns)
        log_probs = torch.stack(log_probs)    #reshape to compute gradient over the whole episode #Shape: (batch_size, 4)

        objective = -torch.sum(log_probs * baselined_causal_returns)/len(batch) # Batch, and do "-" to convert "loss" to "gain"

        agent.optimizer.zero_grad()    #zero gradients from the previous step
        objective.backward()    #compute gradients
        agent.optimizer.step()

    agent.checkpoint(model_name, starting_epoch + num_epochs)
    with open(f'checkpoints/{model_name}/epoch_rewards.csv', 'w', newline="") as f:
        writer = csv.writer(f)
        writer.writerow(epoch_batch_rewards)

    return epoch_batch_rewards



In [18]:
next(agent.pi.parameters()).device

device(type='cuda', index=0)

In [15]:
agent_config = AgentConfig()
training_config = TrainingConfig()
agent = Agent(agent_config)
env = TheGang()
fire_rate, average_terminal_states = train_agent(agent, env, training_config, 40000, 0, 'Baseline')

  0%|          | 0/40000 [00:00<?, ?it/s]


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_addmm)