In [1]:
import showdown
import agent
import asyncio
import torch
import json
import numpy as np
from datetime import datetime
from IPython.display import clear_output
from matplotlib import pyplot as plt

In [2]:

async def agent_battle(agent, showdown):
    """Play a single battle to completion and store its transitions.

    The battle is (re)started through ``showdown.restart()``. On every step
    the agent chooses an action from the current state and the currently
    valid actions, and the resulting transition is buffered. Once
    ``showdown.executeAction`` reports the battle as done, a terminal bonus
    (+5 when ``winner`` is truthy, -5 otherwise) is split evenly over all
    buffered steps, added to each step's own reward, and every adjusted
    transition is handed to ``agent.remember``.

    Args:
        agent: object providing ``act(state, validActions)`` and
            ``remember(state, action, reward, nextState, done)``.
        showdown: object providing awaitable ``restart()`` and
            ``executeAction(action)``, plus ``getState()`` and
            ``getValidActions()``.

    Returns:
        A ``(winner, totalReward)`` tuple: ``winner`` is the value returned
        by the final ``executeAction`` call and ``totalReward`` is the sum of
        the adjusted per-step rewards.
    """
    await showdown.restart()

    transitions = []
    finished = False
    while not finished:
        state = showdown.getState()
        action = agent.act(state, showdown.getValidActions())
        nextState, reward, finished, winner = await showdown.executeAction(action)
        transitions.append((state, action, reward, nextState, finished))

    # Share of the win/loss bonus that each recorded step receives.
    bonusPerStep = (5 if winner else -5) / len(transitions)

    totalReward = 0
    for state, action, reward, nextState, stepDone in transitions:
        shapedReward = reward + bonusPerStep
        totalReward += shapedReward
        agent.remember(state, action, shapedReward, nextState, stepDone)

    print("Finishing battle...")
    return winner, totalReward

In [3]:
def makePlot(x, y, battle, timestamp, winrate, ratio):
    """Plot the learning curve, save it as a PNG, and display it.

    A dedicated figure is created for every call and closed afterwards, so
    repeated calls from the training loop neither stack curves on a shared
    implicit figure nor accumulate open figures in memory.

    Args:
        x: x-axis values (battle indices).
        y: y-axis values (reward per battle), same length as ``x``.
        battle: battle index, used in the output file name.
        timestamp: timestamp string, used in the output file name.
        winrate: win rate shown in the plot title.
        ratio: best ratio so far, shown in the plot title.

    The image is written to ``data/logs/plots/plot-{battle}-{timestamp}.png``;
    the directory must already exist.
    """
    fig, ax = plt.subplots()
    ax.plot(x, y)
    ax.set_xlabel('Battles')
    ax.set_ylabel('Rewards')
    ax.set_title(f'Learning Curve - Winrate: {winrate} - Best Ratio: {ratio}')
    fig.savefig(f"data/logs/plots/plot-{battle}-{timestamp}.png")
    plt.show()
    # Release the figure explicitly; not every backend closes it on show().
    plt.close(fig)

In [4]:
async def training_loop(agent1, agent2, showdown1, showdown2, numBattles=5000):
    """Train ``agent1`` through self-play against ``agent2``.

    Each iteration runs one battle for both agents concurrently (via
    ``agent_battle``), records agent1's result, and runs agent1's replay step.
    On top of that, periodic housekeeping happens:

    * every 10 battles: clear the notebook output, write a short stats file
      and print the current stats;
    * every 50 battles: optionally raise epsilon, save model and memory,
      save the learning-curve plot, copy agent1's weights into agent2 and
      dump cumulative stats as JSON;
    * every 100 battles: call ``agent1.loadTargetModel()``;
    * every 500 battles: remember the checkpoint if the cumulative win ratio
      is the best seen so far, otherwise roll back to the best checkpoint.

    Args:
        agent1: the learning agent (needs ``model``, ``epsilon``, ``memory``,
            ``replay``, ``saveModel``, ``saveMemory``, ``loadModel`` and
            ``loadTargetModel``).
        agent2: the opponent agent; only its ``model`` weights are updated.
        showdown1: battle environment driven by ``agent1``.
        showdown2: battle environment driven by ``agent2``.
        numBattles: number of battles to play.
    """
    agent1Wins = 0
    latestWins = 0
    currentBestRatio = 0
    rewards1 = 0
    plotX = []
    plotY = []
    # Path of the best checkpoint so far; None until one has been recorded.
    currentBestModel = None
    agent2.model.load_state_dict(agent1.model.state_dict())

    for battle in range(numBattles):

        # Concurrently execute both agents and get the results from agent_battle
        results = await asyncio.gather(agent_battle(agent1, showdown1), agent_battle(agent2, showdown2))
        winner, battleReward = results[0]
        print(results)
        if winner == 1:
            agent1Wins += 1
            latestWins += 1

        # Reward bookkeeping is the same for wins and losses.
        rewards1 += battleReward
        plotY.append(battleReward)
        plotX.append(battle)

        # Two replay passes per battle, as in the original loop.
        # NOTE(review): confirm the second call is intentional and not a leftover.
        agent1.replay()
        agent1.replay()

        # Every 10 battles, output the current state and clear the old output.
        # Notebooks are so laggy.
        if battle % 10 == 0 and battle > 0:
            clear_output(wait=True)

            timestamp = datetime.now().strftime("%Y_%m%d-%p%I_%M_%S")
            # Save output to file
            # NOTE(review): "Wins This Cycle" is fed the cumulative win count
            # (agent1Wins); the per-cycle count is latestWins.
            with open(f"data/logs/outputs/output-{battle}-{timestamp}.txt", "w") as file:
                file.write(f"Current Stats: \n Wins This Cycle: {agent1Wins} \n Battles: {battle} \n Epsilon: {agent1.epsilon}")

            print(f"Cleared Output! Current Stats: \n Wins This Cycle: {agent1Wins} \n  Battles: {battle} \n Epsilon: {agent1.epsilon}, \n Latest Wins: {latestWins} \n Memory: {len(agent1.memory)}")

        # Every 50 battles, save the model and memory.
        # `timestamp` is reused from the block above: every multiple of 50 is
        # also a multiple of 10, so it has just been refreshed.
        if battle % 50 == 0 and battle > 0:

            # `battle` is a 0-based index, so battle + 1 battles have been
            # played; dividing by `battle` could yield a ratio above 1.0.
            winRatio = agent1Wins / (battle + 1)

            # Reset epsilon according to win ratio: keep exploring while the
            # agent is neither clearly winning nor clearly losing.
            if 0.3 < winRatio < 0.7:
                agent1.epsilon = max(agent1.epsilon, 0.5)

            # Save model and memory
            agent1.saveModel(f"data/models/model_{battle}.pt")
            agent1.saveMemory(f"data/memory/memory_{battle}.json")

            # Save plot
            makePlot(plotX, plotY, battle, timestamp, winRatio, currentBestRatio)

            # Set agent 2's weights to agent 1's.
            agent2.model.load_state_dict(agent1.model.state_dict())

            # Context manager guarantees the file is closed even if the write fails.
            with open(f"data/stats/{battle}.json", "w") as statsFile:
                statsFile.write(json.dumps({"wins": agent1Wins, "rewards": rewards1, "winsThisCycle": latestWins, "epsilon": agent1.epsilon}))
            latestWins = 0

        if battle % 100 == 0 and battle > 0:
            # Presumably refreshes the target network from the online model;
            # verify against the Agent implementation.
            agent1.loadTargetModel()

        if battle % 500 == 0 and battle > 0:
            # `winRatio` was just computed above (500 is a multiple of 50).
            # Keep this checkpoint if the cumulative win ratio beats the best
            # seen so far; otherwise roll back to the best checkpoint.
            if winRatio > currentBestRatio:
                print("Noting best model")
                currentBestModel = f"data/models/model_{battle}.pt"
                currentBestRatio = winRatio
            elif currentBestModel is not None:
                print("Reloading Model...!")
                agent1.loadModel(currentBestModel)
                agent1.epsilon = 0.3
            else:
                # No checkpoint has qualified as "best" yet (e.g. zero wins so
                # far), so there is nothing valid to reload.
                print("No best model recorded yet; skipping reload.")
                
                


In [None]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
stateSize = 671

possibleActions = json.load(open("data/possible_actions.json", "r"))
actionSize = len(possibleActions)

agent1 = agent.Agent(stateSize, actionSize, device, possibleActions)
agent2 = agent.Agent(stateSize, actionSize, device, possibleActions)
#agent1.loadModel("data/models/model_2850.pt")
#agent1.loadMemory("data/memory/memory_2850.json")

sd1 = showdown.Showdown("https://play.pokemonshowdown.com/action.php", "PoryAI-1", "password", "ws://localhost:8000/showdown/websocket", "gen9randombattle", True, True)
sd2 = showdown.Showdown("https://play.pokemonshowdown.com/action.php", "PoryAI-2", "password", "ws://localhost:8000/showdown/websocket", "gen9randombattle", False, False)
await sd1.connectNoSecurity()
await sd2.connectNoSecurity()
await training_loop(agent1, agent2, sd1, sd2, 10000)

Cleared Output! Current Stats: 
 Wins This Cycle: 8 
  Battles: 20 
 Epsilon: 1.0, 
 Latest Wins: 8 
 Memory: 1348
Recording Active Mon: heatran, with condition: ['100/100']
Player: p2, State: {'request': 3, 'playerSide': {'activeMon': {'pokeid': 1148, 'type1': 5, 'type2': 11, 'ability': 216, 'item': 250, 'teraType': 13, 'terrastillized': 0, 'moves': [{'disabled': 0, 'locked': 0, 'pp': 16, 'moveid': 542, 'category': 2, 'power': 110, 'type': 11}, {'disabled': 0, 'locked': 0, 'pp': 8, 'moveid': 355, 'category': 0, 'power': 0, 'type': 11}, {'disabled': 0, 'locked': 0, 'pp': 32, 'moveid': 483, 'category': 0, 'power': 0, 'type': 15}, {'disabled': 0, 'locked': 0, 'pp': 24, 'moveid': 686, 'category': 2, 'power': 90, 'type': 0}], 'condition': {'hp': 1.0, 'status': 0, 'taunted': 0, 'encored': 0, 'perishSong': 0, 'struggling': 0, 'substitute': 0, 'confusion': 0, 'leechSeed': 0, 'trapped': 0}, 'stats': {'atk': 122, 'atkMod': 0, 'def': 166, 'defMod': 0, 'spa': 213, 'spaMod': 0, 'spd': 166, 'spdMod