In [3]:
%matplotlib inline

In [2]:
import os
import json
from collections import namedtuple

import numpy as np
from scipy import sparse

from tqdm import tqdm

class BatchEnv(object):
    def __init__(self):
        pass

    def init(self, path, root, race, enemy_race, step_mul=8, n_replays=4, n_steps=5, epochs=10, seed=None):
        np.random.seed(seed)

        with open(path) as f:
            replays = json.load(f)

        self.replays = self.__generate_replay_list__(replays, root, race)

        self.race = race
        self.enemy_race = enemy_race

        self.step_mul = step_mul
        self.n_replays = n_replays
        self.n_steps = n_steps

        self.epochs = epochs
        self.epoch = -1
        self.steps = 0

        self.replay_idx = -1
        self.replay_list = [None for _ in range(self.n_replays)]
        
        ## Display Progress Bar
        """
        self.epoch_pbar = tqdm(total=self.epochs, desc='Epoch')
        self.replay_pbar = None
        """

        self.__post_init__()

    def __generate_replay_list__(self, replays, race):
        raise NotImplementedError

    def __init_epoch__(self):
        self.epoch += 1
        """
        if self.epoch > 0:
            self.epoch_pbar.update(1)
        """
        if self.epoch == self.epochs:
            return False

        np.random.shuffle(self.replays)
        ## Display Progress Bar
        """
        if self.replay_pbar is not None:
            self.replay_pbar.close()
        self.replay_pbar = tqdm(total=len(self.replays), desc='  Replays')
        """
        return True

    def __reset__(self):
        self.replay_idx += 1
        if self.replay_idx % len(self.replays) == 0:
            has_more = self.__init_epoch__()
            if not has_more:
                return None

        path = self.replays[self.replay_idx%len(self.replays)]

        return self.__load_replay__(path)

    def __load_replay__(self, path):
        raise NotImplementedError

    def step(self, **kwargs):
        require_init = [False for _ in range(self.n_replays)]
        for i in range(self.n_replays):
            if self.replay_list[i] is None or self.replay_list[i]['done']:
                if self.replay_list[i] is not None:
                    keys = set(self.replay_list[i].keys())
                    for k in keys:
                        del self.replay_list[i][k]
                self.replay_list[i] = self.__reset__()
                require_init[i] = True
            if self.replay_list[i] is None:
                return None

        result = []
        for step in range(self.n_steps):
            result_per_step = []
            for i in range(self.n_replays):
                replay_dict = self.replay_list[i]

                features = self.__one_step__(replay_dict, replay_dict['done'])

                result_per_step.append(features)

            result.append(result_per_step)

        return self.__post_process__(result, **kwargs), require_init

    def __one_step__(self, replay_dict, done):
        raise NotImplementedError

    def __post_process__(self, result, **kwargs):
        raise NotImplementedError

    def step_count(self):
        return self.steps

    def close(self):
        """
        if self.epoch_pbar is not None:
            self.epoch_pbar.close()
        if self.replay_pbar is not None:
            self.replay_pbar.close()
        """
            
class BatchGlobalFeatureEnv(BatchEnv):
    n_features_dic = {'Terran':  {'Terran': 738,  'Protoss': 648,  'Zerg': 1116},
                      'Protoss': {'Terran': 638,  'Protoss': 548,  'Zerg': 1016},
                      'Zerg':    {'Terran': 1106, 'Protoss': 1016, 'Zerg': 1484}}
    n_actions_dic = {'Terran': 75, 'Protoss': 61, 'Zerg': 74}

    def __post_init__(self):
        self.n_features = self.n_features_dic[self.race][self.enemy_race]
        self.n_actions = self.n_actions_dic[self.race]

    def __generate_replay_list__(self, replays, root, race):
        result = []
        for path_dict in replays:
            for player_path in path_dict[race]:
                result.append(os.path.join(root, player_path['global_path']))

        return result

    def __load_replay__(self, path):
        replay_dict = {}
        replay_dict['ptr'] = 0
        replay_dict['done'] = False
        replay_dict['states'] = np.asarray(sparse.load_npz(path).todense())

        return replay_dict

    def __one_step__(self, replay_dict, done):
        states = replay_dict['states']
        feature_shape = states.shape[1:]
        if done:
            return np.zeros(feature_shape)

        self.steps += 1
        state = states[replay_dict['ptr']]
        replay_dict['ptr'] += 1
        if replay_dict['ptr'] == states.shape[0]:
            #self.replay_pbar.update(1)
            replay_dict['done'] = True
    
        return state

    def __post_process__(self, result, reward=True, action=False, score=False):
        result = np.asarray(result)

        result_return = [result[:, :, 15:]]
        if reward:
            result_return.append(result[:, :, 0:1])
        if action:
            result_return.append(result[:, :, 1:2])
        if score:
            result_return.append(result[:, :, 2:15])

        return result_return



In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(6, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 128)
        self.fc4 = nn.Linear(128, 128)
        self.fc5 = nn.Linear(128, 128)
        self.fc6 = nn.Linear(128, 75)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = F.relu(self.fc5(x))
        x = self.fc6(x)
        return x

In [3]:
# Removes all features except for
# minerals, vespene, food_cap, food_used, food_army, food_workers
def transform(states):
    mask = np.ones([1, 256, 738], dtype=bool)
    mask[:, :, 7:] = False
    mask[:, :, 0] = False
    states=states[mask]
    return states

# Removes doing nothing action
def transform2(states, actions, n_replays):
    mask = np.ones(n_replays, dtype=bool)
    for i, action in enumerate(actions.squeeze()):
        if action == 74:
            mask[i] = False
    actions = actions.squeeze()[mask]
    states = transform(states)
    states = np.reshape(states, (256, -1))
    states = states[mask]
    return states, actions


In [None]:
import torch
import visdom
import sys

# Require grad?!?
# weight ?!?!?
# Todo:
# 1. skapa baseline
# 2. jämföra en epok och see accuracy
# 3. top 3 error rate test?!
# 4. see om predictions ser rimliga ut för många fall


def train(model, env):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    env_return = env.step(reward=True, action=True)
    print(env.epochs)
    
    if env_return is not None:
        (states, reward, actions_gt), require_init = env_return
   
    # Remove doing nothing
    states = torch.from_numpy(states).float().view(env.n_replays, -1)
    actions_gt = torch.from_numpy(actions_gt).long().squeeze()
    
    states, actions_gt = transform2(states, actions_gt, n_replays)
    running_loss = 0
    i = 0
    
    while True:
        #print("steps: {} \n epochs: {} ", end='\r')#.format(env.steps, env.epochs))
        print("steps: {}, replay: {}/{} epoch: {}".format(env.step_count(), env.replay_idx, len(env.replays), env.epoch+1), end="\r")
        sys.stdout.flush()
        
        # since we are not using steps we need o reshape result
        actions = model(states)
        
        loss = 0
        loss += F.cross_entropy(actions, actions_gt)
        optimizer.zero_grad()
        loss.backward()
        
        #print(model.fc1.weight)
        optimizer.step()
        #print(model.fc1.weight)
        
        
        env_return = env.step(reward=False, action=True)
        if env_return is not None:
            (raw_states, raw_actions_gt), require_init = env_return
            states = torch.from_numpy(raw_states).float().view(env.n_replays, -1)
            actions_gt = torch.from_numpy(raw_actions_gt).long().squeeze()
            states, actions_gt = transform2(states, actions_gt, n_replays)
        
        #env.step_count() > save or
        
        running_loss += loss.item()
        if env.steps % 2000 == 0:
            print("loss: {}".format(running_loss/2000))
            running_loss = 0
        if env.epoch == env.epochs:
            torch.save(model.state_dict(), 'model_iter_{}.pth'.format(env.step_count()))
            return
        if env_return is None:
            env.close()
            break

replay_path = ''
dataset_path = ''
race = 'Terran'
enemy_race = 'Terran'
steps = 20 # ?
n_replays = 256
epochs = 1
model = Net()

path = 'train_val_test/Terran_vs_Terran/train.json'
phrase = 'train'

env = BatchGlobalFeatureEnv()
env.init(path, './', race, enemy_race, n_replays=n_replays, n_steps=1, epochs=3)

train(model, env)
path = "train_val_test/Terran_vs_Terran/val.json"
env.init(path, './', race, enemy_race, n_replays=n_replays, n_steps=1, epochs=1)
accuracy, actions = test(model, env)
print()
print(accuracy)

#initial features
"""
tensor([0.0001, 0.0008, 0.0000, 0.0750, 0.0600, 0.0000, 0.0600, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000
"""

In [5]:
from collections import defaultdict

def test(model, env):
    print("Testing")
    ######################### SAVE RESULT ############################
    total_data_points = 0
    correct = 0
    correct_dumb = 0
    correct_three = 0
    saved_actions = defaultdict(int)
    ######################### TEST ###################################
    env_return = env.step(reward=False, action=True)
    if env_return is not  None:
        (states, actions_gt), require_init = env_return
        states = transform(states)
        states = torch.from_numpy(states).float().view(env.n_replays, -1)
        actions_gt = torch.from_numpy(actions_gt).long().squeeze()
    not_bad = 0
    while True:
        print("steps: {}, replay: {}/{} epoch: {}".format(env.step_count(), env.replay_idx, len(env.replays), env.epoch+1), end="\r")
        actions = model(states).view(env.n_replays, -1)
        actions.detach()
        #rint(np.argpartition(actions[0].detach().numpy(), -3)[-3:])
        #beak
        ########################### NEXT BATCH #############################################
        actions_np = np.squeeze(np.vstack([np.argpartition(action.detach().numpy(), -3)[-3:] for action in actions]))
        #actions_np = np.squeeze(np.vstack([np.argmax(action.data.cpu().numpy(), axis=0) for action in actions]))
        for three_actions in actions_np:
            for action in three_actions:
                saved_actions[action] +=1
        actions_gt_np = np.squeeze(actions_gt.cpu().numpy())
        for i, action in enumerate(actions_gt_np):
            if action != 74:
                if action == actions_np[i][2]:
                    correct += 1
                if action in actions_np[i]:
                    correct_three += 1
                if action == 62:
                    correct_dumb += 1
                total_data_points += 1
                
        if env.epoch == env.epochs:
            break
        
        env_return = env.step(reward=False, action=True)
        if env_return is not None:
            (raw_states, raw_actions), require_init = env_return
            raw_states = transform(raw_states)
            states = states.copy_(torch.from_numpy(raw_states).float().view(env.n_replays, -1))
            actions_gt = actions_gt.copy_(torch.from_numpy(raw_actions).long().squeeze())
        else:
            env.close()
    

    return correct / total_data_points, correct_dumb/total_data_points, correct_three/total_data_points, saved_actions


In [6]:
replay_path = ''
dataset_path = ''
race = 'Terran'
enemy_race = 'Terran'
steps = 20 # ?
n_replays = 256
epochs = 1
model = Net()

path = 'train_val_test/Terran_vs_Terran/train.json'
phrase = 'train'

model = Net()
path = "train_val_test/Terran_vs_Terran/val.json"
PATH = "model_iter_11638016.pth"
model.load_state_dict(torch.load(PATH))
model.eval()
env = BatchGlobalFeatureEnv()
env.init(path, './', race, enemy_race, n_replays=n_replays, n_steps=1, epochs=1)
accuracy, accuracy_dumb, accuracy_three, actions = test(model, env)

Testing
steps: 469248, replay: 980/980 epoch: 2

In [7]:
actions

defaultdict(int,
            {34: 8425,
             11: 34696,
             62: 450008,
             41: 31576,
             48: 175707,
             15: 27104,
             27: 16969,
             13: 20598,
             51: 414421,
             22: 15799,
             60: 8340,
             21: 7394,
             36: 81,
             47: 8725,
             50: 5301,
             42: 1395,
             29: 23298,
             26: 89631,
             64: 42001,
             46: 600,
             52: 526,
             68: 25311,
             31: 551,
             19: 50,
             28: 2,
             71: 2,
             35: 1})

In [22]:
action_id =  {'140': '1',
               '168': '11',
               '261': '0',
               '300': '16',
               '301': '17',
               '304': '18',
               '305': '19',
               '309': '22',
               '312': '23',
               '317': '26',
               '318': '27',
               '319': '28',
               '320': '33',
               '321': '30',
               '322': '31',
               '326': '35',
               '327': '38',
               '352': '58',
               '353': '55',
               '354': '57',
               '355': '56',
               '361': '61',
               '362': '63',
               '363': '65',
               '369': '67',
               '370': '70',
               '371': '69',
               '375': '72',
               '378': '73',
               '39': '10',
               '402': '2',
               '403': '3',
               '405': '4',
               '406': '5',
               '410': '6',
               '414': '7',
               '418': '8',
               '419': '9',
               '42': '13',
               '423': '12',
               '43': '14',
               '44': '15',
               '453': '34',
               '459': '39',
               '460': '40',
               '464': '42',
               '468': '44',
               '469': '45',
               '470': '46',
               '475': '54',
               '476': '49',
               '477': '51',
               '478': '52',
               '487': '59',
               '488': '60',
               '490': '62',
               '492': '64',
               '496': '66',
               '498': '68',
               '50': '20',
               '502': '71',
               '53': '21',
               '56': '24',
               '58': '25',
               '64': '29',
               '66': '32',
               '71': '36',
               '72': '37',
               '79': '41',
               '83': '43',
               '89': '47',
               '91': '48',
               '92': '50',
               '93': '53'}
action_name = { '140': 'Cancel_quick',
                 '168': 'Cancel_Last_quick',
                 '261': 'Halt_quick',
                 '300': 'Morph_Hellbat_quick',
                 '301': 'Morph_Hellion_quick',
                 '304': 'Morph_LiberatorAAMode_quick',
                 '305': 'Morph_LiberatorAGMode_screen',
                 '309': 'Morph_OrbitalCommand_quick',
                 '312': 'Morph_PlanetaryFortress_quick',
                 '317': 'Morph_SiegeMode_quick',
                 '318': 'Morph_SupplyDepot_Lower_quick',
                 '319': 'Morph_SupplyDepot_Raise_quick',
                 '320': 'Morph_ThorExplosiveMode_quick',
                 '321': 'Morph_ThorHighImpactMode_quick',
                 '322': 'Morph_Unsiege_quick',
                 '326': 'Morph_VikingAssaultMode_quick',
                 '327': 'Morph_VikingFighterMode_quick',
                 '352': 'Research_AdvancedBallistics_quick',
                 '353': 'Research_BansheeCloakingField_quick',
                 '354': 'Research_BansheeHyperflightRotors_quick',
                 '355': 'Research_BattlecruiserWeaponRefit_quick',
                 '361': 'Research_CombatShield_quick',
                 '362': 'Research_ConcussiveShells_quick',
                 '363': 'Research_DrillingClaws_quick',
                 '369': 'Research_HiSecAutoTracking_quick',
                 '370': 'Research_HighCapacityFuelTanks_quick',
                 '371': 'Research_InfernalPreigniter_quick',
                 '375': 'Research_NeosteelFrame_quick',
                 '378': 'Research_PersonalCloaking_quick',
                 '39': 'Build_Armory_screen',
                 '402': 'Research_RavenCorvidReactor_quick',
                 '403': 'Research_RavenRecalibratedExplosives_quick',
                 '405': 'Research_Stimpack_quick',
                 '406': 'Research_TerranInfantryArmor_quick',
                 '410': 'Research_TerranInfantryWeapons_quick',
                 '414': 'Research_TerranShipWeapons_quick',
                 '418': 'Research_TerranStructureArmorUpgrade_quick',
                 '419': 'Research_TerranVehicleAndShipPlating_quick',
                 '42': 'Build_Barracks_screen',
                 '423': 'Research_TerranVehicleWeapons_quick',
                 '43': 'Build_Bunker_screen',
                 '44': 'Build_CommandCenter_screen',
                 '453': 'Stop_quick',
                 '459': 'Train_Banshee_quick',
                 '460': 'Train_Battlecruiser_quick',
                 '464': 'Train_Cyclone_quick',
                 '468': 'Train_Ghost_quick',
                 '469': 'Train_Hellbat_quick',
                 '470': 'Train_Hellion_quick',
                 '475': 'Train_Liberator_quick',
                 '476': 'Train_Marauder_quick',
                 '477': 'Train_Marine_quick',
                 '478': 'Train_Medivac_quick',
                 '487': 'Train_Raven_quick',
                 '488': 'Train_Reaper_quick',
                 '490': 'Train_SCV_quick',
                 '492': 'Train_SiegeTank_quick',
                 '496': 'Train_Thor_quick',
                 '498': 'Train_VikingFighter_quick',
                 '50': 'Build_EngineeringBay_screen',
                 '502': 'Train_WidowMine_quick',
                 '53': 'Build_Factory_screen',
                 '56': 'Build_FusionCore_screen',
                 '58': 'Build_GhostAcademy_screen',
                 '64': 'Build_MissileTurret_screen',
                 '66': 'Build_Nuke_quick',
                 '71': 'Build_Reactor_quick',
                 '72': 'Build_Reactor_screen',
                 '79': 'Build_Refinery_screen',
                 '83': 'Build_SensorTower_screen',
                 '89': 'Build_Starport_screen',
                 '91': 'Build_SupplyDepot_screen',
                 '92': 'Build_TechLab_quick',
                 '93': 'Build_TechLab_screen'}
#sorted([int(val) for val in test['action_id'].values()])
test = {}
for key, value in action_id.items():
    test[value] = key
i=0
for key, _ in actions.items():
    print(action_name[str(test[str(key)])])
    i+=1
print(i) 
len(actions.keys())

Stop_quick
Cancel_Last_quick
Train_SCV_quick
Build_Refinery_screen
Build_SupplyDepot_screen
Build_CommandCenter_screen
Morph_SupplyDepot_Lower_quick
Build_Barracks_screen
Train_Marine_quick
Morph_OrbitalCommand_quick
Train_Reaper_quick
Build_Factory_screen
Build_Reactor_quick
Build_Starport_screen
Build_TechLab_quick
Train_Cyclone_quick
Build_MissileTurret_screen
Morph_SiegeMode_quick
Train_SiegeTank_quick
Train_Hellion_quick
Train_Medivac_quick
Train_VikingFighter_quick
Morph_Unsiege_quick
Morph_LiberatorAGMode_screen
Morph_SupplyDepot_Raise_quick
Train_WidowMine_quick
Morph_VikingAssaultMode_quick
27


27