# environments/dungeonmayhem

## card.py

In [None]:
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple
from itertools import count

@dataclass
class DungeonMayhem_Card:
    """A single card, described by how many of each effect it carries.

    ``power`` is either ``None`` or an ``(power_id, function)`` pair; the
    function is called as ``function(game, player, target)``.
    """

    attack: int = 0
    defend: int = 0
    draw: int = 0
    heal: int = 0
    play: int = 0
    power: Optional[Tuple[int, Callable[[Any, Any, Any], None]]] = None

    def __post_init__(self):
        # Bookkeeping slots assigned by other code after construction. They
        # are not dataclass fields, so the generated __eq__/__repr__ ignore them.
        self.action_encoding = None
        self.action_id = -1
        self.card_id = -1
        self.player = None

    def discard(self):
        """Append the card to its owner's discard pile if the owner has health left."""
        owner = self.player
        if owner is None or owner.health <= 0:
            return
        owner.discardpile.append(self)

def DestroyDefense_Power(game: Any, player: Any, target: Any):
    """Destroy one of the target's defenses; do nothing if it has none."""
    if target.defenses:
        target.destroy_defense()

def BarbarianDiscardHand_Power(game: Any, player: Any, target: Any):
    """Have every active player discard their hand, then draw three cards."""
    for participant in game.active_players:
        participant.discard_hand()
        participant.draw(3)

def BarbarianHeal_Power(game: Any, player: Any, target: Any):
    """For each other active player: heal ``player`` by 1 and deal that opponent 1 damage."""
    for opponent in game.active_players:
        if opponent == player:
            continue
        player.heal(1)
        opponent.take_damage(1)

def PaladinDestroyAllDefenses_Power(game: Any, player: Any, target: Any):
    """Destroy every defense of every active player (the caster included, if active)."""
    for participant in game.active_players:
        while participant.defenses:
            participant.destroy_defense()

def PaladinDrawDiscard_Power(game: Any, player: Any, target: Any):
    """Move a card from the player's own discard pile into their hand.

    The choice is fixed: the card at index 0 of the discard pile is taken.
    Nothing happens when the discard pile is empty.
    """
    if not player.discardpile:
        return
    chosen_index = 0
    recovered = player.discardpile.pop(chosen_index)
    player.hand.append(recovered)

def RogueGainImmunity_Power(game: Any, player: Any, target: Any):
    """Make ``player`` immune: move them from the active list to the immune list.

    The operation is idempotent. A player holding extra plays can use this
    power twice in one turn; the second use must not fail because the player
    has already left ``game.active_players``, nor list them as immune twice.
    """
    if player in game.active_players:
        game.active_players.remove(player)
    if player not in game.immune_players:
        game.immune_players.append(player)
    
def RogueStealDiscard_Power(game: Any, player: Any, target: Any):
    """Take the top card of the target's discard pile and play it immediately.

    The card passes through the player's hand so that ``game.play_card`` can
    process it like any other card. Nothing happens when the target's discard
    pile is empty.
    """
    if not target.discardpile:
        return
    stolen = target.discardpile.pop()
    player.hand.append(stolen)
    game.play_card(player, stolen)

def WizardFireball_Power(game: Any, player: Any, target: Any):
    """Deal 3 damage to every active player (the caster included, if active)."""
    for participant in game.active_players:
        participant.take_damage(3)

def WizardStealDefense_Power(game: Any, player: Any, target: Any):
    """Move the target's last (strongest) defense over to ``player``.

    Defenses are ``(remaining_strength, card)`` pairs. The player's list is
    re-sorted by remaining strength after the transfer, because the rest of
    the code indexes the list by position (front absorbs damage, back is
    destroyed) and relies on it being kept in ascending order.
    Nothing happens when the target has no defenses.
    """
    if not target.defenses:
        return
    stolen = target.defenses.pop()
    player.defenses.append(stolen)
    # Sort on strength only; the card objects themselves are not orderable.
    player.defenses.sort(key=lambda defense: defense[0])

def WizardSwapHealth_Power(game: Any, player: Any, target: Any):
    """Exchange health totals with the target, unless the target has no health left."""
    if target.health <= 0:
        return
    mine, theirs = player.health, target.health
    player.health = theirs
    target.health = mine

# Registry of special powers. Each public name below is an ``(id, function)``
# pair; ids are handed out consecutively from 1 in the order listed.
power_counter = count(1)
(
    DestroyDefense,
    BarbarianDiscardHand,
    BarbarianHeal,
    PaladinDestroyAllDefenses,
    PaladinDrawDiscard,
    RogueGainImmunity,
    RogueStealDiscard,
    WizardFireball,
    WizardStealDefense,
    WizardSwapHealth,
) = [
    (next(power_counter), power_function)
    for power_function in (
        DestroyDefense_Power,
        BarbarianDiscardHand_Power,
        BarbarianHeal_Power,
        PaladinDestroyAllDefenses_Power,
        PaladinDrawDiscard_Power,
        RogueGainImmunity_Power,
        RogueStealDiscard_Power,
        WizardFireball_Power,
        WizardStealDefense_Power,
        WizardSwapHealth_Power,
    )
]

## character.py

In [None]:
from itertools import count

# from environments.dungeonmayhem.card import *

class DungeonMayhem_Character:
    """Base class for characters; subclasses supply their cards in ``deck``."""

    deck = []

    @classmethod
    def get_deck(cls):
        """Return a new list containing the class's card objects.

        The copy is shallow: the list is fresh, the cards inside it are the
        same objects stored on the class.
        """
        return [*cls.deck]

# Source of consecutive ``character_id`` values (starting at 0); each character
# class takes the next value when its class body is executed.
character_counter = count()

class DungeonMayhem_Barbarian(DungeonMayhem_Character):
    """Barbarian character: 28 cards, listed grouped by card name.

    Every copy is built by its own constructor call, so each entry in
    ``deck`` is a distinct card object.
    """

    character_id = next(character_counter)
    deck = [
        *(DungeonMayhem_Card(attack=2) for _ in range(2)),  # Brutal Punch
        *(DungeonMayhem_Card(play=2) for _ in range(2)),  # Two Axes
        *(DungeonMayhem_Card(attack=1, play=1) for _ in range(2)),  # Head Butt
        *(DungeonMayhem_Card(attack=3) for _ in range(5)),  # Big Axe
        *(DungeonMayhem_Card(attack=4) for _ in range(2)),  # Rage
        DungeonMayhem_Card(defend=3),  # Riff
        DungeonMayhem_Card(defend=3),  # Raff
        DungeonMayhem_Card(defend=2),  # Spiked Shield
        DungeonMayhem_Card(defend=1, draw=1),  # Bag Of Rats
        DungeonMayhem_Card(draw=2, heal=1),  # Snack Time
        *(DungeonMayhem_Card(draw=2) for _ in range(2)),  # Open The Armory
        *(DungeonMayhem_Card(draw=1, heal=1) for _ in range(2)),  # Flex
        *(DungeonMayhem_Card(draw=1, power=DestroyDefense) for _ in range(2)),  # Mighty Toss
        *(DungeonMayhem_Card(play=1, power=BarbarianDiscardHand) for _ in range(2)),  # Battle Roar
        *(DungeonMayhem_Card(power=BarbarianHeal) for _ in range(2)),  # Whirling Axes
    ]

class DungeonMayhem_Paladin(DungeonMayhem_Character):
    """Paladin character: 28 cards, listed grouped by card name.

    Every copy is built by its own constructor call, so each entry in
    ``deck`` is a distinct card object.
    """

    character_id = next(character_counter)
    deck = [
        *(DungeonMayhem_Card(attack=3) for _ in range(2)),  # For The Most Justice
        *(DungeonMayhem_Card(attack=2) for _ in range(4)),  # For Even More Justice
        *(DungeonMayhem_Card(attack=1, play=1) for _ in range(3)),  # For Justice
        *(DungeonMayhem_Card(play=2) for _ in range(2)),  # Finger Wag Of Judgment
        *(DungeonMayhem_Card(attack=3, heal=1) for _ in range(3)),  # Divine Smite
        *(DungeonMayhem_Card(attack=2, heal=1) for _ in range(3)),  # Fighting Words
        *(DungeonMayhem_Card(play=2) for _ in range(2)),  # High Charisma
        DungeonMayhem_Card(draw=2, heal=1),  # Cure Wounds
        *(DungeonMayhem_Card(defend=1, draw=1) for _ in range(2)),  # Spinning Parry
        *(DungeonMayhem_Card(defend=3) for _ in range(2)),  # Divine Shield
        DungeonMayhem_Card(defend=2),  # Fluffy
        DungeonMayhem_Card(play=1, power=PaladinDestroyAllDefenses),  # Banishing Smite
        *(DungeonMayhem_Card(heal=2, power=PaladinDrawDiscard) for _ in range(2)),  # Divine Inspiration
    ]

class DungeonMayhem_Rogue(DungeonMayhem_Character):
    """Rogue character: 28 cards, listed grouped by card name.

    Every copy is built by its own constructor call, so each entry in
    ``deck`` is a distinct card object.
    """

    character_id = next(character_counter)
    deck = [
        *(DungeonMayhem_Card(attack=3) for _ in range(3)),  # All The Thrown Daggers
        *(DungeonMayhem_Card(attack=2) for _ in range(4)),  # Two Thrown Daggers
        *(DungeonMayhem_Card(attack=1, play=1) for _ in range(5)),  # One Thrown Dagger
        *(DungeonMayhem_Card(play=2) for _ in range(2)),  # Cunning Action
        *(DungeonMayhem_Card(play=1, heal=1) for _ in range(2)),  # Stolen Potion
        DungeonMayhem_Card(draw=2, heal=1),  # Even More Daggers
        *(DungeonMayhem_Card(draw=1, defend=1) for _ in range(2)),  # Winged Serpent
        *(DungeonMayhem_Card(defend=2) for _ in range(2)),  # The Goon Squad
        DungeonMayhem_Card(defend=3),  # My Little Friend
        *(DungeonMayhem_Card(power=RogueGainImmunity) for _ in range(2)),  # Clever Disguise
        *(DungeonMayhem_Card(play=1, power=DestroyDefense) for _ in range(2)),  # Sneak Attack
        *(DungeonMayhem_Card(play=1, power=RogueStealDiscard) for _ in range(2)),  # Pick Pocket
    ]

class DungeonMayhem_Wizard(DungeonMayhem_Character):
    """Wizard character: 28 cards, listed grouped by card name.

    Every copy is built by its own constructor call, so each entry in
    ``deck`` is a distinct card object.
    """

    character_id = next(character_counter)
    deck = [
        *(DungeonMayhem_Card(attack=3) for _ in range(4)),  # Lightning Bolt
        *(DungeonMayhem_Card(attack=2) for _ in range(3)),  # Burning Hands
        *(DungeonMayhem_Card(attack=1, play=1) for _ in range(3)),  # Magic Missile
        *(DungeonMayhem_Card(play=2) for _ in range(3)),  # Speed Of Thought
        *(DungeonMayhem_Card(heal=1, play=1) for _ in range(2)),  # Evil Sneer
        *(DungeonMayhem_Card(draw=3) for _ in range(3)),  # Knowledge Is Power
        *(DungeonMayhem_Card(defend=1, draw=1) for _ in range(2)),  # Shield
        DungeonMayhem_Card(defend=2),  # Stoneskin
        DungeonMayhem_Card(defend=3),  # Mirror Image
        *(DungeonMayhem_Card(power=WizardFireball) for _ in range(2)),  # Fireball
        *(DungeonMayhem_Card(power=WizardStealDefense) for _ in range(2)),  # Charm
        *(DungeonMayhem_Card(power=WizardSwapHealth) for _ in range(2)),  # Vampiric Touch
    ]

## player.py

In [None]:
class DungeonMayhem_Player:
    """State of one player: deck, hand, discard pile, defenses, health and plays.

    ``defenses`` holds ``(remaining_strength, card)`` pairs kept in ascending
    order of strength by ``add_defense``; damage is absorbed from the front of
    the list and ``destroy_defense`` removes from the back.
    """

    def __init__(self, player_id, np_rng):
        self.player_id = player_id
        self.np_rng = np_rng
        self.reset()

    def reset(self):
        """Return the player to the start-of-game state (no character, 10 health, 1 play)."""
        self.character = None
        self.health = 10
        self.plays = 1
        self.deck = []
        self.hand = []
        self.discardpile = []
        self.defenses = []

    def get_character_deck(self, character):
        """Adopt ``character``, take its deck, shuffle it and claim ownership of every card."""
        self.character = character
        self.deck = character.get_deck()
        self.np_rng.shuffle(self.deck)
        for owned_card in self.deck:
            owned_card.player = self

    def draw(self, n):
        """Draw up to ``n`` cards, recycling the shuffled discard pile when the deck runs out."""
        for _ in range(n):
            if not self.deck:
                # Deck exhausted: the discard pile becomes the new deck.
                self.deck, self.discardpile = self.discardpile, []
                self.np_rng.shuffle(self.deck)
            if not self.deck:
                # Nothing left anywhere to draw from.
                return
            self.hand.append(self.deck.pop())

    def discard_hand(self):
        """Move the whole hand onto the discard pile."""
        self.discardpile += self.hand
        self.hand = []

    def total_health(self):
        """Health plus the remaining strength of all defenses."""
        return self.health + self.total_defenses()

    def total_defenses(self):
        """Sum of the remaining strength of all defenses."""
        return sum(strength for strength, _ in self.defenses)

    def start_turn(self):
        """Begin a turn: one play available, draw one card."""
        self.plays = 1
        self.draw(1)

    def add_defense(self, card):
        """Put ``card`` into play as a defense and keep the list sorted by strength."""
        self.defenses.append((card.defend, card))
        self.defenses.sort(key=lambda entry: entry[0])

    def destroy_defense(self):
        """Remove the last defense in the list and discard its card."""
        _, destroyed_card = self.defenses.pop()
        destroyed_card.discard()

    def take_damage(self, n):
        """Apply ``n`` damage: defenses absorb it front-first, the rest hits health (floor 0)."""
        remaining = n
        while remaining != 0:
            if not self.defenses:
                self.health = max(self.health - remaining, 0)
                return
            strength, shield_card = self.defenses[0]
            if strength > remaining:
                # The front defense survives with reduced strength.
                self.defenses[0] = (strength - remaining, shield_card)
                return
            # The front defense is used up; carry the overflow to the next one.
            del self.defenses[0]
            shield_card.discard()
            remaining -= strength

    def heal(self, n):
        """Restore ``n`` health, capped at 10."""
        self.health = min(self.health + n, 10)

    def play(self, n):
        """Adjust the number of remaining plays by ``n`` (may be negative)."""
        self.plays += n

## game.py

In [None]:
# from environments.dungeonmayhem.player import DungeonMayhem_Player
# from environments.dungeonmayhem.character import DungeonMayhem_Barbarian, DungeonMayhem_Paladin, DungeonMayhem_Rogue, DungeonMayhem_Wizard

class DungeonMayhem_Game:
    """Rules engine for a four-player game.

    Attributes set by ``reset``:
        active_players: players that can currently be hit by area effects.
        immune_players: players temporarily removed from ``active_players``.
        ordered_players: surviving players in turn order; index 0 is on turn.
        current_player: the player whose turn it is.
        winner: ``None`` while the game runs; the surviving player on a win;
            a placeholder player with ``player_id`` 4 when nobody survives or
            the 200-card time budget runs out.
    """

    def __init__(self, np_rng):
        self.np_rng = np_rng
        self.players = [DungeonMayhem_Player(i, self.np_rng) for i in range(4)]
        self.characters = [DungeonMayhem_Barbarian, DungeonMayhem_Paladin, DungeonMayhem_Rogue, DungeonMayhem_Wizard]
        self.reset()

    def reset(self, char_perm=None, order_perm=None):
        """Start a new game.

        Args:
            char_perm: optional permutation of range(4) choosing which
                character goes to which seat in the turn order; random if None.
            order_perm: optional permutation of range(4) giving the turn
                order of ``self.players``; random if None.
        """
        for player in self.players:
            player.reset()
        self.active_players = list(self.players)
        self.immune_players = []
        if order_perm is None:
            order_perm = self.np_rng.permutation(4)
        self.ordered_players = [self.players[i] for i in order_perm]
        if char_perm is None:
            char_perm = self.np_rng.permutation(4)
        self.permuted_characters = [self.characters[i] for i in char_perm]
        for player, character in zip(self.ordered_players, self.permuted_characters):
            player.get_character_deck(character)
            player.draw(3)
        self.current_player = self.ordered_players[0]
        self.current_player.start_turn()
        self.winner = None
        self.time_left = 200

    @staticmethod
    def _remove_from_hand(player, card):
        """Remove exactly this card object (not merely an equal one) from the hand."""
        for index, held in enumerate(player.hand):
            if held is card:
                del player.hand[index]
                return
        raise ValueError("card is not in the player's hand")

    def play_card(self, player, card):
        """Resolve ``card`` for ``player`` and advance the game state.

        Effects are applied in the order power, attack, defend/discard, draw,
        heal, extra plays. Afterwards eliminated players are removed, the time
        budget is decremented, and either a winner is recorded or the turn
        passes on when the player has no plays left.

        Raises:
            ValueError: if ``card`` is not in ``player.hand``.
        """
        # Cards compare equal by their stats, so removal must go by identity;
        # otherwise a different copy could leave the hand while the played
        # object stays in it and also lands on a discard pile.
        self._remove_from_hand(player, card)
        player.play(-1)

        if card.power:
            card.power[1](self, player, self.select_target(player))
        if card.attack:
            self.select_target(player).take_damage(card.attack)
        if card.defend:
            player.add_defense(card)
        else:
            card.discard()
        if card.draw:
            player.draw(card.draw)
        if card.heal:
            player.heal(card.heal)
        if card.play:
            player.play(card.play)

        if len(player.hand) == 0:
            player.draw(2)

        # Iterate over a snapshot: removing from the list being iterated
        # would skip the player following each eliminated one.
        for p in list(self.active_players):
            if p.health == 0:
                self.active_players.remove(p)
                self.ordered_players.remove(p)
                for (_, c) in p.defenses:
                    c.discard()
                p.deck = []
                p.hand = []
                p.discardpile = []
                p.defenses = []

        self.time_left -= 1
        if len(self.ordered_players) == 0 or self.time_left == 0:
            # Nobody left or out of time: placeholder player 4 marks a draw.
            self.winner = DungeonMayhem_Player(4, self.np_rng)
        elif len(self.ordered_players) == 1:
            self.winner = self.ordered_players[0]
        else:
            previous_player = self.current_player
            if player == self.ordered_players[0] and player.plays == 0:
                # Out of plays: move to the back of the turn order.
                self.ordered_players.remove(player)
                self.ordered_players.append(player)
            self.current_player = self.ordered_players[0]
            # Only a change of player starts a new turn. Restarting the turn
            # of a player who still has plays would reset the play count to 1,
            # grant an extra draw and end any immunity gained this turn.
            if self.current_player is not previous_player:
                self.current_player.start_turn()
                if self.current_player in self.immune_players:
                    self.immune_players.remove(self.current_player)
                    self.active_players.append(self.current_player)

    def select_target(self, player):
        """Return the non-immune opponent with the highest total health.

        A random jitter below 0.5 breaks ties between equal totals.
        Eliminated players are not filtered out, so one can be returned when
        no living opponent is targetable. Returns ``None`` if no other
        non-immune player exists at all.
        """
        sorted_players = sorted(self.players, key=lambda x: -x.total_health()+self.np_rng.random()/2)
        for p in sorted_players:
            if p != player and p not in self.immune_players:
                return p

    def get_game_state(self):
        """Return ``(private_state, public_states)`` for the current player.

        ``private_state`` is a list: player id, character id, health, total
        defenses, active flag, plays, deck (card list), hand (card list).
        ``public_states`` has one entry per player, ordered by character id:
        health, total defenses, active flag, discard pile (card list).
        """
        current_player = self.current_player
        private_state = [
            current_player.player_id,
            current_player.character.character_id,
            current_player.health,
            current_player.total_defenses(),
            int(current_player in self.active_players),
            current_player.plays,
            current_player.deck,
            current_player.hand,
        ]
        public_states = [
            [
                player.health,
                player.total_defenses(),
                int(player in self.active_players),
                player.discardpile,
            ]
            for player in sorted(self.players, key=lambda x: x.character.character_id)
        ]
        return private_state, public_states

## environment.py

In [None]:
from itertools import count, chain
import numpy as np
import torch

# from environments.dungeonmayhem.game import DungeonMayhem_Game

class DungeonMayhem_Environment:
    """Tensor-based wrapper around ``DungeonMayhem_Game`` for learning agents.

    ``encoding_complexity`` selects how much is encoded in the state vector:
    0 = basics and hand, >0 additionally the current player's deck, >1
    additionally the combined discard piles.
    """

    def __init__(self, device = None, random_seed = 6450, encoding_complexity = 0):
        if device is None:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.np_rng = np.random.default_rng(random_seed)
        self.game = DungeonMayhem_Game(self.np_rng)
        self.num_actions, self.num_cards = self.prepare_cards(self.game)
        self.encoding_complexity = encoding_complexity

        deck_width = self.num_actions if self.encoding_complexity > 0 else 0
        discard_width = self.num_cards if self.encoding_complexity > 1 else 0
        self.state_dimensions = (
            4                   # one-hot character id of the current player
            + 1                 # current player's health
            + 1                 # current player's total defenses
            + 1                 # current player's active flag
            + 1                 # current player's remaining plays
            + deck_width        # current player's deck, counted per action id
            + self.num_actions  # current player's hand, counted per action id
            + 4                 # health of all four players
            + 4                 # total defenses of all four players
            + 4                 # active flags of all four players
            + discard_width     # discard piles, counted per card id
        )
        self.action_dimensions = 6

    def prepare_cards(self, game):
        """Attach encodings and ids to every card of every character.

        Each card gets a 6-element ``action_encoding`` (attack, defend, draw,
        heal, play, power id or 0), an ``action_id`` shared by all cards with
        the same encoding across characters, and a ``card_id`` shared only by
        identical cards within one character's deck.

        Returns:
            ``(number of distinct action ids, number of distinct card ids)``.
        """
        decks = [character.get_deck() for character in game.characters]

        # Positional weights turn the six stats into a single sortable number.
        weights = torch.Tensor([1, 10, 100, 1000, 10000, 100000])
        for card in chain.from_iterable(decks):
            features = torch.Tensor([
                card.attack,
                card.defend,
                card.draw,
                card.heal,
                card.play,
                0 if card.power is None else card.power[0],
            ])
            card.action_id = torch.dot(features, weights).item()
            card.action_encoding = features.to(device=self.device)

        # Card ids: consecutive over all decks, equal cards within a deck share one.
        card_counter = count()
        for deck in decks:
            deck.sort(key=lambda c: c.action_id)
            last_hash, last_id = -1, -1
            for card in deck:
                if card.action_id != last_hash:
                    last_hash = card.action_id
                    last_id = next(card_counter)
                card.card_id = last_id

        # Action ids: the temporary hash is replaced by a dense index shared
        # by equal cards regardless of deck.
        action_counter = count()
        every_card = sorted(chain.from_iterable(decks), key=lambda c: c.action_id)
        last_hash, last_id = -1, -1
        for card in every_card:
            if card.action_id != last_hash:
                last_hash = card.action_id
                last_id = next(action_counter)
            card.action_id = last_id

        return next(action_counter), next(card_counter)

    def reset(self):
        """Start a new game with random seating and characters."""
        self.game.reset()

    def get_winner(self):
        """Return the game's winner object, or ``None`` while the game is running."""
        return self.game.winner

    def get_current_player_details(self):
        """Encode the situation of the player on turn.

        Returns:
            ``(player_id, state, actions)`` where ``state`` is a 1-D tensor of
            length ``state_dimensions`` and ``actions`` stacks the encodings of
            the cards in hand, one row per card in hand order.
        """
        private_state, public_states = self.game.get_game_state()
        (player_id, character_id, health, defenses,
         is_active, plays, deck, hand) = private_state

        state = torch.zeros(self.state_dimensions)
        state[character_id] = 1
        state[4] = health
        state[5] = defenses
        state[6] = is_active
        state[7] = plays
        offset = 8
        if self.encoding_complexity > 0:
            for card in deck:
                state[offset + card.action_id] += 1
            offset += self.num_actions
        for card in hand:
            state[offset + card.action_id] += 1
        offset += self.num_actions
        # Three blocks of four: health, total defenses, active flag.
        for column in range(3):
            for seat in range(4):
                state[offset + seat] = public_states[seat][column]
            offset += 4
        if self.encoding_complexity > 1:
            for seat in range(4):
                for card in public_states[seat][3]:
                    state[offset + card.card_id] += 1
        state = state.to(device=self.device)

        actions = torch.vstack([card.action_encoding for card in hand])

        return player_id, state, actions

    def apply_action(self, idx):
        """Play the card at position ``idx`` of the current player's hand."""
        acting_player = self.game.current_player
        chosen_card = acting_player.hand[idx]
        self.game.play_card(acting_player, chosen_card)

    def get_health_defense_idxs(self):
        """Return state-vector positions of health and defense values.

        Returns:
            ``((own_health, own_defense), [(health_i, defense_i) for the four players])``.
        """
        offset = 8 + self.num_actions
        if self.encoding_complexity > 0:
            offset += self.num_actions
        return (4, 5), [(offset + i, offset + 4 + i) for i in range(4)]

# agents

## random_agent.py

In [None]:
import numpy as np
from itertools import count

class Random_Agent:
    """Agent that picks uniformly among the offered actions."""

    # Shared across instances so every agent gets a different seed offset.
    agent_counter = count()

    def __init__(self, random_seed = 6450):
        instance_offset = next(Random_Agent.agent_counter)
        self.np_rng = np.random.default_rng(random_seed + instance_offset)

    def step(self, state_actions):
        """Return a random index into ``state_actions``."""
        num_choices = len(state_actions)
        return self.np_rng.integers(num_choices)

## nn_estimator.py

In [None]:
import torch
import torch.nn as nn
from itertools import count

class NN_Model(nn.Module):
    """Fully connected value network.

    Linear + LeakyReLU per hidden width, followed by a single-unit Linear
    layer with Tanh, so every output lies in [-1, 1].
    """

    def __init__(
        self,
        input_dim,
        hidden_dims,
    ):
        super().__init__()
        widths = [input_dim] + hidden_dims
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU()]
        layers += [nn.Linear(widths[-1], 1), nn.Tanh()]
        self.fc_layers = nn.Sequential(*layers)

    def forward(self, sa):
        """Map a batch of state-action rows to one value per row (dim 1 squeezed away)."""
        return self.fc_layers(sa).squeeze(1)

class NN_Estimator:
    """Training wrapper around ``NN_Model``: MSE loss, Adam optimizer, save/load.

    Args:
        input_dim: width of one state-action input row.
        hidden_dims: list of hidden layer widths.
        learning_rate: Adam learning rate.
        device: torch device; defaults to ``cuda:0`` when available, else CPU.
        name: prefix for the file names written by ``save_model``.
    """

    def __init__(
        self,
        input_dim,
        hidden_dims,
        learning_rate,
        device = None,
        name = '',
    ):
        if device is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        self.model = NN_Model(input_dim, hidden_dims).to(self.device)
        # Xavier initialisation for weight matrices; biases keep their defaults.
        for p in self.model.parameters():
            if len(p.data.shape) > 1:
                nn.init.xavier_uniform_(p.data)

        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

        self.name = name
        self.save_counter = count()

    def predict(self, sa):
        """Evaluate the model on ``sa`` without tracking gradients."""
        with torch.no_grad():
            q_sa = self.model(sa)
        return q_sa

    def train(self, sa, y):
        """Take one optimizer step towards targets ``y``; return the loss as a float."""
        self.optimizer.zero_grad()
        q_sa = self.model(sa)
        loss = self.criterion(q_sa, y)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def save_model(self):
        """Write the model weights to ``<name>_<n>.pt``, with n counting up per call."""
        save_dir = '{}_{}.pt'.format(self.name, next(self.save_counter))
        torch.save(self.model.state_dict(), save_dir)

    def load_model(self, load_dir):
        """Load model weights from the file at ``load_dir``.

        Tensors are mapped onto this estimator's device, so a checkpoint
        written on a GPU machine can be loaded on a CPU-only one.
        """
        state_dict = torch.load(load_dir, map_location=self.device)
        self.model.load_state_dict(state_dict)

## dqn_agent.py

In [None]:
import numpy as np
from itertools import count
from copy import deepcopy
import logging

# from agents.nn_estimator import NN_Estimator

class DQN_Agent:
    """Double-DQN style agent over state-action rows.

    The learner network picks the best next action, the target network
    evaluates it. A non-zero reward is treated as terminal: the bootstrap
    term is only added where the reward is exactly 0.
    """

    # Shared across instances so every agent gets a different seed offset.
    agent_counter = count()

    def __init__(
        self,
        state_action_dims,
        learning_rate=0.0001,
        exploration_rate=0.1,
        batch_size=1024,
        replay_memory_max_size=100000,
        replay_memory_min_size=10000,
        train_learner_estimator_every=100,
        update_target_estimator_every=100,
        hidden_dims=[512]*3,
        device=None,
        name='',
        random_seed=6450,
    ):

        self.exploration_rate = exploration_rate
        self.batch_size = batch_size
        self.replay_memory_min_size = replay_memory_min_size
        self.replay_memory_max_size = replay_memory_max_size
        self.train_learner_estimator_every = train_learner_estimator_every
        self.update_target_estimator_every = update_target_estimator_every

        if device is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        self.total_t = 0  # transitions fed so far
        self.train_t = 0  # training steps performed so far
        self.memory = []

        self.learner_estimator = NN_Estimator(
            input_dim=state_action_dims,
            hidden_dims=hidden_dims,
            learning_rate=learning_rate,
            device=self.device,
            name='{}_dqn_learner'.format(name),
        )
        self.target_estimator = NN_Estimator(
            input_dim=state_action_dims,
            hidden_dims=hidden_dims,
            learning_rate=learning_rate,
            device=self.device,
            name='{}_dqn_target'.format(name),
        )

        self.np_rng = np.random.default_rng(random_seed + next(DQN_Agent.agent_counter))

    def step(self, state_actions):
        """Pick an action index: random with ``exploration_rate``, else the learner's argmax."""
        if self.np_rng.random() < self.exploration_rate:
            return self.np_rng.integers(len(state_actions))
        return self.learner_estimator.predict(state_actions).argmax().item()

    def feed(self, state_action, reward, next_state_actions):
        """Store one transition and train when the schedule says so.

        Returns:
            The training loss if a training step ran, otherwise 0.
        """
        if len(self.memory) == self.replay_memory_max_size:
            self.memory.pop(0)
        self.memory.append([state_action, reward, next_state_actions])
        self.total_t += 1
        tmp = self.total_t - self.replay_memory_min_size
        if tmp >= 0 and tmp % self.train_learner_estimator_every == 0:
            return self.train()
        return 0

    def train(self):
        """Run one learner update on a sampled batch; periodically sync the target.

        Returns:
            The loss of the update, or 0 when the replay memory holds fewer
            than ``batch_size`` transitions.
        """
        # The sampling below needs at least batch_size stored transitions.
        if len(self.memory) < self.batch_size:
            return 0

        state_action_list = []
        reward_list = []
        next_state_actions_list = []
        # Entry i of the batch is taken from position i + idx, which is why
        # the random offsets are limited to len(memory) - batch_size.
        idxs = self.np_rng.integers(len(self.memory)-self.batch_size+1, size=self.batch_size)
        for i, idx in enumerate(idxs):
            x = self.memory[i+idx]
            state_action_list.append(x[0])
            reward_list.append(x[1])
            next_state_actions_list.append(x[2])
        state_action_batch = torch.stack(state_action_list)
        reward_batch = torch.Tensor(reward_list).to(self.device)
        next_state_actions_batch = torch.cat(next_state_actions_list)
        next_state_actions_sizes = [len(x) for x in next_state_actions_list]

        # Learner chooses the best next action, target network scores it.
        q_values_next_learner = self.learner_estimator.predict(next_state_actions_batch)
        q_values_next_learner_split = torch.split(q_values_next_learner, next_state_actions_sizes)
        best_state_actions_list = [sa_s[q_sa_s.argmax()] for sa_s, q_sa_s in zip(next_state_actions_list, q_values_next_learner_split)]
        best_state_actions_batch = torch.stack(best_state_actions_list)
        q_values_next_target = self.target_estimator.predict(best_state_actions_batch)
        target_batch = reward_batch + reward_batch.eq(0).float() * q_values_next_target

        loss = self.learner_estimator.train(state_action_batch, target_batch)

        if self.train_t % self.update_target_estimator_every == 0:
            logging.info('Step: {} Loss: {}'.format(self.total_t, loss))
            # Copy only the weights. Deep-copying the whole estimator would
            # also duplicate the optimizer and the itertools-based save
            # counter (copy support for itertools objects is deprecated since
            # Python 3.12) and would overwrite the target's own name.
            self.target_estimator.model.load_state_dict(
                self.learner_estimator.model.state_dict()
            )
            logging.info('Target Updated')

        self.train_t += 1

        return loss

## dmc_agent.py

In [None]:
import numpy as np
from itertools import count
import logging

# from agents.nn_estimator import NN_Estimator

class DMC_Agent:
    """Monte-Carlo agent: regresses state-action values onto episode rewards."""

    # Shared across instances so every agent gets a different seed offset.
    agent_counter = count()

    def __init__(
        self,
        state_action_dims,
        learning_rate=0.0001,
        exploration_rate=0.1,
        train_estimator_every=10,
        hidden_dims=[512]*3,
        device=None,
        name='',
        random_seed=6450,
    ):

        self.exploration_rate = exploration_rate
        self.train_estimator_every = train_estimator_every

        if device is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        self.t = 0  # number of feed() calls so far
        self.memory = []

        self.estimator = NN_Estimator(
            input_dim=state_action_dims,
            hidden_dims=hidden_dims,
            learning_rate=learning_rate,
            device=self.device,
            name='{}_dmc_estimator'.format(name),
        )

        self.np_rng = np.random.default_rng(random_seed + next(DMC_Agent.agent_counter))

    def step(self, state_actions):
        """Pick an action index: random with ``exploration_rate``, else the estimator's argmax."""
        if self.np_rng.random() < self.exploration_rate:
            return self.np_rng.integers(len(state_actions))
        return self.estimator.predict(state_actions).argmax().item()

    def feed(self, episodes):
        """Store finished episodes and train every ``train_estimator_every`` calls.

        Args:
            episodes: iterable of ``(state_actions, reward)`` pairs, where
                ``state_actions`` is a list of tensors and ``reward`` is the
                episode's outcome, assigned to every step. Episodes without
                any step are skipped.

        Returns:
            The training loss if a training step ran, otherwise 0.
        """
        for state_actions, reward in episodes:
            T = len(state_actions)
            if T == 0:
                continue
            state_actions = torch.stack(state_actions)
            reward = reward * torch.ones(T).to(self.device)
            self.memory.append((state_actions, reward))
        self.t += 1
        if self.t % self.train_estimator_every == 0:
            return self.train()
        return 0

    def train(self):
        """Fit the estimator on everything collected so far, then clear the memory.

        Returns:
            The loss of the update, or 0 when nothing has been collected
            (for example when all episodes fed so far were empty).
        """
        if not self.memory:
            return 0
        state_actions_batch = torch.cat([x[0] for x in self.memory])
        reward_batch = torch.cat([x[1] for x in self.memory])
        loss = self.estimator.train(state_actions_batch, reward_batch)
        logging.info('Step: {} Loss: {}'.format(self.t, loss))
        self.memory = []
        return loss

# Pretraining

In [None]:
import numpy as np
import torch
from math import exp
import logging
import pickle

# from environments.dungeonmayhem.environment import DungeonMayhem_Environment
# from agents.nn_estimator import NN_Estimator

def pretrain(num_episodes = 100000, num_epochs = 100, batch_size = 1024, name=''):
    """Pretrain an estimator on shaped targets gathered from random play.

    Phase 1 plays ``num_episodes`` games with uniformly random actions and
    records ``(state, action, target)`` samples.  Phase 2 fits a fresh
    ``NN_Estimator`` on those samples for ``num_epochs`` epochs, saving the
    model at the half-way and final epoch and pickling the loss history.

    Args:
        num_episodes: Number of random-play games to generate.
        num_epochs: Number of passes over the collected samples.
        batch_size: Approximate mini-batch size.
        name: Prefix for the log file (``'<name>_pretraining.log'``), the
            estimator name (``'<name>_pretrained'``) and the loss pickle
            (``'<name>_pretraining_loss_list'``).

    Returns:
        ``(estimator, loss_list)`` where ``loss_list`` holds every positive
        batch loss in training order.
    """
    log = logging.getLogger()
    for hdlr in log.handlers[:]:
        if isinstance(hdlr, logging.FileHandler):
            log.removeHandler(hdlr)
            hdlr.close()  # release the previous run's log file
    filehandler = logging.FileHandler('{}_pretraining.log'.format(name), 'w')
    log.addHandler(filehandler)

    environment = DungeonMayhem_Environment()
    # presumably indices of health/defense entries in the state vector:
    # ``player_idxs`` for the current player, ``game_idxs`` one pair per player
    player_idxs, game_idxs = environment.get_health_defense_idxs()
    dims = environment.state_dimensions + environment.action_dimensions
    estimator = NN_Estimator(dims, hidden_dims=[512]*3, learning_rate=0.0001, name='{}_pretrained'.format(name))
    memory = []
    loss_list = []

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    np_rng = np.random.default_rng(6450)

    # ---- Phase 1: collect samples from random play -------------------------
    for episode_id in range(1, 1+num_episodes):
        environment.reset()
        prev_state_actions = [None, None, None, None]

        while environment.get_winner() is None:
            player_id, state, actions = environment.get_current_player_details()
            previous = prev_state_actions[player_id]
            if previous is not None:
                # Target for the player's previous move: a softmax-style share
                # of the indexed state entries, rescaled from [0, 1] to [-1, 1].
                numerator = exp(state[player_idxs[0]] + state[player_idxs[1]])
                denominator = sum(exp(state[pair[0]] + state[pair[1]]) for pair in game_idxs)
                share = numerator / denominator
                memory.append((previous[0], previous[1], 2*share-1))
            action_idx = np_rng.integers(len(actions))
            prev_state_actions[player_id] = (state, actions[action_idx])
            environment.apply_action(action_idx)

        # Each player's last move is labelled with the actual game outcome.
        winner_id = environment.get_winner().player_id
        for player_id in range(4):
            state_action = prev_state_actions[player_id]
            if state_action is None:
                continue
            reward = 1 if player_id == winner_id else -1
            memory.append((state_action[0], state_action[1], reward))

        if episode_id % 1000 == 0:
            logging.info('Episode: {}/{} DONE!'.format(episode_id, num_episodes))

    # Fuse state and action into a single estimator input per sample.
    memory = [(torch.cat((state, action)), target) for state, action, target in memory]

    # ---- Phase 2: supervised training --------------------------------------
    loss = 0
    for epoch_id in range(1, 1+num_epochs):
        batches = np.array_split(np_rng.permutation(len(memory)), len(memory)//batch_size+1)
        for batch in batches:
            if len(batch) == 0:
                # array_split can produce an empty chunk; torch.stack would fail on it
                continue
            state_action_batch = torch.stack([memory[i][0] for i in batch])
            target_batch = torch.Tensor([memory[i][1] for i in batch]).to(device)
            loss = estimator.train(state_action_batch, target_batch)
            if loss > 0:
                loss_list.append(loss)
        logging.info('Loss: {}'.format(loss))
        logging.info('Epoch: {}/{} DONE!'.format(epoch_id, num_epochs))

        if epoch_id == num_epochs // 2 or epoch_id == num_epochs:
            estimator.save_model()
    with open('{}_pretraining_loss_list'.format(name), 'wb') as fp:
        pickle.dump(loss_list, fp)

    return estimator, loss_list

# Training

In [None]:
def round(environment, agents):
    """Play one complete game and return the payoff of each of the 4 seats.

    Args:
        environment: Game environment; it is reset before play starts.
        agents: Either a list indexed by player id, or a single agent that
            will then control all four seats.  Each agent must provide
            ``step(state_actions) -> action index``.

    Returns:
        A list of four payoffs: ``1`` for the winner's seat, ``-1`` otherwise.
    """
    environment.reset()
    if not isinstance(agents, list):
        agents = [agents] * 4

    while environment.get_winner() is None:
        player_id, state, actions = environment.get_current_player_details()
        # Repeat the state once per candidate action and append the action
        # encoding, giving one estimator input row per legal action.
        state = state.expand(len(actions), -1)
        state_actions = torch.cat((state, actions), dim=1)
        action_idx = agents[player_id].step(state_actions)
        environment.apply_action(action_idx)

    winner_id = environment.get_winner().player_id
    return [1 if seat == winner_id else -1 for seat in range(4)]

def tournament(environment, agents, num_games):
    """Play ``num_games`` games and return the mean payoff of each of 4 seats.

    Every game is delegated to ``round(environment, agents)``; the per-seat
    payoffs are summed and then divided by ``num_games``.

    NOTE(review): later cells call ``tournament`` with ``order_perm`` and
    ``char_perm`` keywords, which this signature does not accept -- they
    presumably target a different version of this function; confirm.
    """
    totals = [0] * 4
    for _game in range(num_games):
        game_payoffs = round(environment, agents)
        for seat in range(len(totals)):
            totals[seat] += game_payoffs[seat]
    return [total / num_games for total in totals]

In [None]:
import logging
import pickle
# from agents.dqn_agent import DQN_Agent
# from environments.dungeonmayhem.environment import DungeonMayhem_Environment

def train_DQN(
    hidden_dims = None,
    encoding_complexity = 0,
    pretraining_file_name = None,
    num_episodes = 100000,
    random_seed = 6450,
    name='',
):
    """Train a single ``DQN_Agent`` by self-play on all four seats.

    Every 100 episodes the agent is evaluated greedily (exploration off)
    over 100 games against three random opponents; every 1000 episodes the
    learner estimator is saved.  Loss and payoff histories are pickled to
    ``'<name>_dqn_training_loss_list'`` and
    ``'<name>_dqn_training_payoffs_list'``.

    Args:
        hidden_dims: Hidden layer sizes; ``None`` means ``[512] * 3``.
        encoding_complexity: Forwarded to ``DungeonMayhem_Environment``.
        pretraining_file_name: Optional checkpoint loaded into both the
            learner and the target estimator before training.
        num_episodes: Number of self-play games.
        random_seed: Seed for the environment and the agent.
        name: Prefix for the log file, the agent name and the pickles.

    Returns:
        ``(agent, loss_list, payoffs_list)``.
    """
    # Build the default per call instead of sharing one mutable list.
    if hidden_dims is None:
        hidden_dims = [512] * 3

    log = logging.getLogger()
    for hdlr in log.handlers[:]:
        if isinstance(hdlr, logging.FileHandler):
            log.removeHandler(hdlr)
            hdlr.close()  # release the previous run's log file
    filehandler = logging.FileHandler('{}_dqn_training.log'.format(name), 'w')
    log.addHandler(filehandler)

    environment = DungeonMayhem_Environment(encoding_complexity = encoding_complexity, random_seed = random_seed)
    state_action_dims = environment.state_dimensions + environment.action_dimensions
    agent = DQN_Agent(state_action_dims, hidden_dims = hidden_dims, random_seed = random_seed, name=name)

    if pretraining_file_name is not None:
        agent.learner_estimator.load_model(pretraining_file_name)
        agent.target_estimator.load_model(pretraining_file_name)

    loss_list = []
    payoffs_list = []

    for episode_id in range(1, 1+num_episodes):
        environment.reset()
        # Last chosen state-action row per seat, pending its successor.
        prev_state_actions = [None, None, None, None]

        while environment.get_winner() is None:
            player_id, state, actions = environment.get_current_player_details()
            state = state.expand(len(actions), -1)
            state_actions = torch.cat((state, actions), dim=1)
            if prev_state_actions[player_id] is not None:
                # Non-terminal transition: reward 0, successor is the set of
                # candidate rows the same player faces now.
                loss = agent.feed(prev_state_actions[player_id], 0, state_actions)
                if loss > 0:
                    loss_list.append(loss)
            action_idx = agent.step(state_actions)
            prev_state_actions[player_id] = state_actions[action_idx]
            environment.apply_action(action_idx)

        winner_id = environment.get_winner().player_id
        for player_id in range(4):
            state_action = prev_state_actions[player_id]
            if state_action is None:
                continue
            reward = 1 if player_id == winner_id else -1
            # Terminal transition: an all-zero row stands in for the successor.
            terminal = torch.zeros_like(state_action).unsqueeze(dim=0)
            loss = agent.feed(state_action, reward, terminal)
            if loss > 0:
                loss_list.append(loss)

        if episode_id % 100 == 0:
            logging.info('Episode: {}/{} DONE!'.format(episode_id, num_episodes))
            saved_exploration_rate = agent.exploration_rate
            agent.exploration_rate = 0
            try:
                payoffs = tournament(environment, [agent]+[Random_Agent()]*3, 100)
            finally:
                # Never leave the agent greedy if evaluation fails.
                agent.exploration_rate = saved_exploration_rate
            payoffs_list.append(payoffs)
            logging.info('Payoffs: {}'.format(payoffs))

        if episode_id % 1000 == 0:
            agent.learner_estimator.save_model()
            logging.info('Estimator Saved')

    with open('{}_dqn_training_loss_list'.format(name), 'wb') as fp:
        pickle.dump(loss_list, fp)
    with open('{}_dqn_training_payoffs_list'.format(name), 'wb') as fp:
        pickle.dump(payoffs_list, fp)

    return agent, loss_list, payoffs_list

In [None]:
import logging
import pickle
# from agents.dmc_agent import DMC_Agent
# from environments.dungeonmayhem.environment import DungeonMayhem_Environment

def train_DMC(
    hidden_dims = None,
    encoding_complexity = 0,
    pretraining_file_name = None,
    num_episodes = 200000,
    random_seed = 6450,
    name='',
):
    """Train a single ``DMC_Agent`` by self-play on all four seats.

    After each game, every seat's full trajectory of chosen state-action
    rows is fed to the agent together with that seat's final reward.  Every
    100 episodes the agent is evaluated greedily (exploration off) over 100
    games against three random opponents; every 1000 episodes the estimator
    is saved.  Histories are pickled to ``'<name>_dmc_training_loss_list'``
    and ``'<name>_dmc_training_payoffs_list'``.

    Args:
        hidden_dims: Hidden layer sizes; ``None`` means ``[512] * 3``.
        encoding_complexity: Forwarded to ``DungeonMayhem_Environment``.
        pretraining_file_name: Optional checkpoint loaded into the estimator
            before training.
        num_episodes: Number of self-play games.
        random_seed: Seed for the environment and the agent.
        name: Prefix for the log file, the agent name and the pickles.

    Returns:
        ``(agent, loss_list, payoffs_list)``.
    """
    # Build the default per call instead of sharing one mutable list.
    if hidden_dims is None:
        hidden_dims = [512] * 3

    log = logging.getLogger()
    for hdlr in log.handlers[:]:
        if isinstance(hdlr, logging.FileHandler):
            log.removeHandler(hdlr)
            hdlr.close()  # release the previous run's log file
    filehandler = logging.FileHandler('{}_dmc_training.log'.format(name), 'w')
    log.addHandler(filehandler)

    environment = DungeonMayhem_Environment(encoding_complexity = encoding_complexity, random_seed = random_seed)
    state_action_dims = environment.state_dimensions + environment.action_dimensions
    agent = DMC_Agent(state_action_dims, hidden_dims = hidden_dims, random_seed = random_seed, name=name)

    if pretraining_file_name is not None:
        agent.estimator.load_model(pretraining_file_name)

    loss_list = []
    payoffs_list = []

    for episode_id in range(1, 1+num_episodes):
        environment.reset()
        # One trajectory (list of chosen state-action rows) per seat.
        prev_state_actions = [[], [], [], []]

        while environment.get_winner() is None:
            player_id, state, actions = environment.get_current_player_details()
            state = state.expand(len(actions), -1)
            state_actions = torch.cat((state, actions), dim=1)
            action_idx = agent.step(state_actions)
            prev_state_actions[player_id].append(state_actions[action_idx])
            environment.apply_action(action_idx)

        winner_id = environment.get_winner().player_id
        episodes = [
            (prev_state_actions[seat], 1 if seat == winner_id else -1)
            for seat in range(4)
        ]

        loss = agent.feed(episodes)
        if loss > 0:
            loss_list.append(loss)

        if episode_id % 100 == 0:
            logging.info('Episode: {}/{} DONE!'.format(episode_id, num_episodes))
            saved_exploration_rate = agent.exploration_rate
            agent.exploration_rate = 0
            try:
                payoffs = tournament(environment, [agent]+[Random_Agent()]*3, 100)
            finally:
                # Never leave the agent greedy if evaluation fails.
                agent.exploration_rate = saved_exploration_rate
            payoffs_list.append(payoffs)
            logging.info('Payoffs: {}'.format(payoffs))

        if episode_id % 1000 == 0:
            agent.estimator.save_model()
            logging.info('Estimator Saved')

    with open('{}_dmc_training_loss_list'.format(name), 'wb') as fp:
        pickle.dump(loss_list, fp)
    with open('{}_dmc_training_payoffs_list'.format(name), 'wb') as fp:
        pickle.dump(payoffs_list, fp)

    return agent, loss_list, payoffs_list

# Run

In [None]:
# # For Connecting to Google Drive
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# # For Creating Folder 'For Albert'
# folder_path = '/content/drive/MyDrive/For Albert'
# import os
# if not os.path.exists(folder_path):
#     os.makedirs(folder_path)

In [None]:
# For Showing Logs on Screen
import logging
log = logging.getLogger()
# Re-running this cell must not stack another console handler (that would
# print every record several times).  FileHandler subclasses StreamHandler,
# so compare exact types rather than using isinstance.
if not any(type(hdlr) is logging.StreamHandler for hdlr in log.handlers):
    log.addHandler(logging.StreamHandler())
log.setLevel(logging.INFO)

In [None]:
# For Pretraining
# estimator, loss_list = pretrain(
#     name=folder_path+'/RL',
# )

In [None]:
# For Training Model A
# agent, loss_list, payoffs_list = train_DQN(
#     name=folder_path+'/50E_pretrained',
#     pretraining_file_name=folder_path+'/RL_pretrained_0.pt',
# )

In [None]:
# # For Training Model B
# agent, loss_list, payoffs_list = train_DQN(
#     name=folder_path+'/100E_pretrained',
#     pretraining_file_name=folder_path+'/RL_pretrained_1.pt',
# )

In [None]:
# # For Training Model C
# agent, loss_list, payoffs_list = train_DMC(
#     name=folder_path+'/50E_pretrained',
#     pretraining_file_name=folder_path+'/RL_pretrained_0.pt',
# )

In [None]:
# # For Training Model D
# agent, loss_list, payoffs_list = train_DMC(
#     name=folder_path+'/100E_pretrained',
#     pretraining_file_name=folder_path+'/RL_pretrained_1.pt',
# )

In [None]:
# import multiprocessing

# def train_DQN_multiprocessing_wrapper(q, i):
#     if i == 0:
#         agent, loss_list, payoffs_list = train_DQN(
#             name=folder_path+'/50E_pretrained',
#             pretraining_file_name=folder_path+'/RL_pretrained_0.pt',
#         )        
#     else:
#         agent, loss_list, payoffs_list = train_DQN(
#             name=folder_path+'/100E_pretrained',
#             pretraining_file_name=folder_path+'/RL_pretrained_1.pt',
#         )
#     q.put(('agent_{}'.format(i), agent))
#     q.put(('loss_list_{}'.format(i), loss_list))
#     q.put(('payoffs_list_{}'.format(i), payoffs_list))

In [None]:
# q = multiprocessing.Queue()
# p0 = multiprocessing.Process(target=train_DQN_multiprocessing_wrapper, args=(q,0))
# p1 = multiprocessing.Process(target=train_DQN_multiprocessing_wrapper, args=(q,1))
# p0.start()
# p1.start()
# p0.join()
# p1.join()

Episode: 100/100000 DONE!
Episode: 100/100000 DONE!
Payoffs: [-0.12, -0.58, -0.62, -0.68]
Payoffs: [-0.22, -0.58, -0.56, -0.64]
Step: 10000 Loss: 0.16688847541809082
Target Updated
Step: 10000 Loss: 0.14960427582263947
Target Updated
Episode: 200/100000 DONE!
Episode: 200/100000 DONE!
Payoffs: [-0.14, -0.8, -0.4, -0.66]
Payoffs: [0.16, -0.8, -0.7, -0.68]
Step: 20000 Loss: 0.09804263710975647
Target Updated
Step: 20000 Loss: 0.09436871856451035
Target Updated
Episode: 300/100000 DONE!
Episode: 300/100000 DONE!
Payoffs: [-0.14, -0.68, -0.62, -0.56]
Payoffs: [-0.14, -0.68, -0.7, -0.5]
Step: 30000 Loss: 0.07889556884765625
Target Updated
Step: 30000 Loss: 0.06730889528989792
Target Updated
Episode: 400/100000 DONE!
Payoffs: [-0.24, -0.46, -0.56, -0.74]
Episode: 400/100000 DONE!
Payoffs: [0.1, -0.62, -0.7, -0.78]
Step: 40000 Loss: 0.06561850011348724
Target Updated
Step: 40000 Loss: 0.048055216670036316
Target Updated
Episode: 500/100000 DONE!
Payoffs: [-0.12, -0.8, -0.54, -0.56]
Episode: 5

KeyboardInterrupt: ignored

In [None]:


def load_agent(
    environment,
    random_seed=6450,
    ckpt=None,
    prefix="dqn",
    layers=3,
    layer_size=512,
    complexity=0,
    pretraining=0,
):
    """Build an agent and load its weights from a checkpoint on disk.

    The model directory and file stem are both
    ``f"{prefix}{layer_size}_{layers}{complexity}_{pretraining}"``.

    Args:
        environment: Supplies ``state_dimensions`` and ``action_dimensions``.
        random_seed: Forwarded to the agent constructor.
        ckpt: Checkpoint suffix (``<name>/<name>_<ckpt>.pt``).  ``None``
            selects the file whose trailing ``_<number>.pt`` is largest.
        prefix: ``"dqn"`` -> ``DQN_Agent``, ``"dmc"`` -> ``DMC_Agent``,
            anything else -> ``Random_Agent``.
        layers: Number of hidden layers.
        layer_size: Width of every hidden layer.
        complexity: Encoding complexity digit used in the name.
        pretraining: Pretraining tag used in the name.

    Returns:
        The constructed agent; DQN/DMC agents have their weights loaded and
        ``exploration_rate`` set to 0.

    Raises:
        ValueError: ``ckpt`` is ``None`` and no checkpoint file matches.
    """
    name = f"{prefix}{layer_size}_{layers}{complexity}_{pretraining}"  # 0 is for no pretraining
    if ckpt is None:
        import glob

        ckpts = glob.glob(f"{name}/{name}_*.pt")
        # Check the *result list*; the pattern string itself is never empty.
        if not ckpts:
            raise ValueError("No checkpoint found")
        # Order by the integer between the last "_" and the extension.
        ckpts.sort(key=lambda path: int(path.split("_")[-1].split(".")[0]))
        ckpt = ckpts[-1]
        print(ckpt)
    else:
        ckpt = f"{name}/{name}_{ckpt}.pt"

    state_action_dims = environment.state_dimensions + environment.action_dimensions
    if prefix == "dqn":
        cls = DQN_Agent
    elif prefix == "dmc":
        cls = DMC_Agent
    else:
        cls = Random_Agent
    agent = cls(
        state_action_dims=state_action_dims,
        random_seed=random_seed,
        hidden_dims=[layer_size] * layers,
        name=name,
    )
    # Loaded agents are meant for evaluation, so exploration is switched off.
    if isinstance(agent, DQN_Agent):
        agent.learner_estimator.load_model(ckpt)
        agent.exploration_rate = 0
    elif isinstance(agent, DMC_Agent):
        agent.estimator.load_model(ckpt)
        agent.exploration_rate = 0

    return agent


def eval_a_model(
    ckpt=None,
    prefix="dqn",
    layers=3,
    layer_size=512,
    complexity=0,
    pretraining=0,
    num_games=10000,
    order_perm=None,
    char_perm=None,
    random_seed=6450,
):
    """Evaluate one saved model against three random opponents.

    Configures logging (file ``pretraining.log`` plus console), builds the
    environment, loads the agent via ``load_agent`` and runs ``tournament``
    with the agent in seat 0 and one shared ``Random_Agent`` instance in the
    remaining seats.

    Returns:
        ``(ckpt, payoffs)`` -- ``ckpt`` is the argument exactly as passed in
        (not the resolved file path) and ``payoffs`` is the tournament result.

    NOTE(review): ``order_perm`` / ``char_perm`` are forwarded to
    ``tournament``; their meaning is not visible here -- confirm.
    """
    log_targets = [
        logging.FileHandler("pretraining.log"),
        logging.StreamHandler(),
    ]
    logging.basicConfig(level=logging.INFO, handlers=log_targets)

    environment = DungeonMayhem_Environment(
        encoding_complexity=complexity, random_seed=random_seed
    )
    model_spec = {
        "ckpt": ckpt,
        "random_seed": random_seed,
        "prefix": prefix,
        "layers": layers,
        "layer_size": layer_size,
        "complexity": complexity,
        "pretraining": pretraining,
    }
    agent = load_agent(environment, **model_spec)

    # The same random opponent object fills all three remaining seats.
    opponent = Random_Agent(random_seed)
    agents = [agent, opponent, opponent, opponent]

    payoffs = tournament(
        environment,
        agents,
        num_games=num_games,
        order_perm=order_perm,
        char_perm=char_perm,
    )
    logging.info("Payoffs {}".format(payoffs))
    return ckpt, payoffs


def train_a_model(
    prefix, size=512, num=3, complexity=0, pretraining=0, load_file=None, func=None
):
    """Train one model configuration with the trainer matching ``prefix``.

    Args:
        prefix: ``"dqn"`` or ``"dmc"``; selects ``train_DQN`` / ``train_DMC``
            when ``func`` is not given, and is the first part of the run name.
        size: Width of every hidden layer.
        num: Number of hidden layers.
        complexity: Encoding complexity forwarded to the trainer.
        pretraining: Tag used only in the run name.
        load_file: Optional pretraining checkpoint forwarded to the trainer.
        func: Explicit training function; overrides the prefix lookup.

    Returns:
        Whatever the training function returns.

    Raises:
        ValueError: ``func`` is ``None`` and ``prefix`` is not recognised.
    """
    if func is None:
        if prefix == "dqn":
            func = train_DQN
        elif prefix == "dmc":
            func = train_DMC
        else:
            # Fail clearly instead of calling None further down.
            raise ValueError(
                f"Unknown prefix {prefix!r}; expected 'dqn' or 'dmc', or pass func"
            )
    return func(
        hidden_dims=[size] * num,
        encoding_complexity=complexity,
        name=f"{prefix}{size}_{num}{complexity}_{pretraining}",
        pretraining_file_name=load_file,
    )


def round_robin(env, agents, num_games=100):
    """Yield tournament payoffs for every order/character permutation pair.

    The given agents are padded with four references to one shared
    ``Random_Agent`` so there are always enough players; the caller's
    ``agents`` sequence is left untouched.

    Args:
        env: Environment handed to ``tournament``.
        agents: Sequence of agents to seat first (may be empty).
        num_games: Games per permutation pair.

    Yields:
        ``(order_perm, char_perm, payoffs)`` for each of the 24 x 24
        permutation pairs of ``range(4)``.

    NOTE(review): the exact meaning of ``order_perm`` / ``char_perm`` is
    decided by ``tournament`` and is not visible here -- confirm.
    """
    from itertools import permutations

    # Build a new list: ``agents += ...`` would extend the caller's list in place.
    lineup = list(agents) + [Random_Agent(random_seed=6450)] * 4

    for order_perm in permutations(range(4)):
        for char_perm in permutations(range(4)):
            payoffs = tournament(
                env,
                lineup,
                num_games=num_games,
                order_perm=order_perm,
                char_perm=char_perm,
            )
            logging.info("Payoffs {} {} = {}".format(order_perm, char_perm, payoffs))
            yield (order_perm, char_perm, payoffs)

def round_robin_helper(layers=3, layer_size=512, complexity=0, pretraining=0, num_games=100, all_rando=False):
    """Run a full round-robin for one model configuration and pickle it.

    Loads the latest DQN and DMC checkpoints of the configuration (or no
    trained agents at all when ``all_rando`` is true), runs ``round_robin``
    and writes the list of results to
    ``tournament_<layer_size>_<layers><complexity>_<pretraining>.pickle``.

    Args:
        layers: Number of hidden layers of the saved models.
        layer_size: Width of every hidden layer.
        complexity: Encoding complexity of the saved models; also used to
            build the environment so its dimensions match the checkpoints.
        pretraining: Pretraining tag of the saved models.
        num_games: Games per permutation pair.
        all_rando: When true, only random agents play.
    """
    logging.basicConfig(
        level=logging.INFO,
        handlers=[
            logging.FileHandler("pretraining.log"),
            logging.StreamHandler(),
        ],
    )

    # The environment must use the same encoding the models were trained
    # with, otherwise state/action dimensions will not match the checkpoint.
    environment = DungeonMayhem_Environment(
        random_seed=6450, encoding_complexity=complexity
    )

    if all_rando:
        trained_agents = []
    else:
        trained_agents = [
            load_agent(
                environment,
                prefix=kind,
                layers=layers,
                layer_size=layer_size,
                complexity=complexity,
                pretraining=pretraining,
            )
            for kind in ("dqn", "dmc")
        ]

    # Run every game first so a failure part-way through cannot leave a
    # freshly truncated, empty pickle behind.
    results = list(round_robin(environment, trained_agents, num_games=num_games))

    with open(f"tournament_{layer_size}_{layers}{complexity}_{pretraining}.pickle", "wb") as fp:
        pickle.dump(results, fp)

# rr_num_games = 1000
# round_robin_helper(layers=3, layer_size=512, complexity=0, pretraining=0, num_games=rr_num_games)
# round_robin_helper(layers=3, layer_size=256, complexity=0, pretraining=0, num_games=rr_num_games)
# round_robin_helper(layers=4, layer_size=512, complexity=0, pretraining=0, num_games=rr_num_games)
# round_robin_helper(layers=5, layer_size=512, complexity=0, pretraining=0, num_games=rr_num_games)

# Evaluate every trained configuration against random opponents and store
# the (ckpt, payoffs) pairs in run order.
payoffs = []
single_model_num = 10000

# One entry per evaluation, in execution order; keys omitted here fall back
# to eval_a_model's defaults.
_eval_runs = [
    {"prefix": "dqn", "layers": 3, "layer_size": 512, "complexity": 0},
    {"prefix": "dqn", "layers": 3, "layer_size": 512, "complexity": 1},
    {"prefix": "dqn", "layers": 3, "layer_size": 512, "complexity": 2},
    {"prefix": "dqn", "layers": 4, "layer_size": 512, "complexity": 0},
    {"prefix": "dqn", "layers": 5, "layer_size": 512, "complexity": 0},
    {"prefix": "dqn", "layers": 3, "layer_size": 512, "complexity": 0, "pretraining": 50},
    {"prefix": "dqn", "layers": 3, "layer_size": 512, "complexity": 0, "pretraining": 100},
    {"prefix": "dmc", "layers": 3, "layer_size": 512, "complexity": 0},
    {"prefix": "dmc", "layers": 3, "layer_size": 512, "complexity": 1},
    {"prefix": "dmc", "layers": 3, "layer_size": 512, "complexity": 2},
    {"prefix": "dmc", "layers": 4, "layer_size": 512, "complexity": 0},
    {"prefix": "dmc", "layers": 5, "layer_size": 512, "complexity": 0},
    {"prefix": "dmc", "layers": 3, "layer_size": 512, "complexity": 0, "pretraining": 50},
    {"prefix": "dmc", "layers": 3, "layer_size": 512, "complexity": 0, "pretraining": 100},
    {"prefix": "dmc", "layers": 3, "layer_size": 256, "complexity": 0},
    {"prefix": "dmc", "layers": 4, "layer_size": 256, "complexity": 0},
    {"prefix": "dqn", "layers": 3, "layer_size": 256, "complexity": 0},
    {"prefix": "dqn", "layers": 4, "layer_size": 256, "complexity": 0},
]
for _run in _eval_runs:
    payoffs.append(eval_a_model(num_games=single_model_num, **_run))

with open("testing.pickle", "wb") as fp:
    pickle.dump(payoffs, fp)


# agent, loss_list, payoffs_list = train_a_model("dmc", 2048, 3, 0)
# agent, loss_list, payoffs_list = train_a_model("dmc", 256, 3, 0)
# agent, loss_list, payoffs_list = train_a_model("dqn", 256, 3, 0)
# agent, loss_list, payoffs_list = train_a_model("dmc", 256, 4, 0)
# agent, loss_list, payoffs_list = train_a_model("dqn", 256, 4, 0)
# agent, loss_list, payoffs_list = train_a_model("dqn", 128, 4, 0) # TODO
# agent, loss_list, payoffs_list = train_a_model("dmc", 128, 4, 0)
# raise ValueError("Training finished")
