In [1]:
import random
from rich import print
from tqdm import tqdm
import os
import json
from diplomacy.engine.map import Map
import sys
sys.path.append("../../")
from sotopia.database.persistent_profile import AgentProfile, EnvironmentProfile
from typing import Any


In [3]:
all_task_pks = list(EnvironmentProfile.all_pks())
for i in all_task_pks:
    if EnvironmentProfile.get(i).env_tag == "tv3":
        print(EnvironmentProfile.get(i))


### Adjuncat cleaning

In [4]:
games_dir = "/data/user_data/wenkail/sotopia_diplomacy/filter_games"
def read_games_from_folder(game_folder):
    games = []
    for root, dirs, files in os.walk(game_folder):
        for file in files:
            file_path = os.path.join(root, file)
            with open(file_path, 'r') as f:
                games.append(json.load(f))
    return games
def has_neighboring_units(country1, country2, unit_positions):
    def get_location(unit):
        return unit.split()[-1]
    map = Map(name='standard', use_cache=True)
    neighboring_units = []

    for unit1 in unit_positions[country1]:
        loc1 = get_location(unit1)
        for unit2 in unit_positions[country2]:
            loc2 = get_location(unit2)
            if loc2 in map.abut_list(loc1, incl_no_coast=True):
                neighboring_units.append((unit1, unit2))

    if neighboring_units:
        return neighboring_units
    else:
        return None
def adjunction_selection(games_dir, countries):
    games = read_games_from_folder(games_dir)
    game_phases = {}
    country1, country2 = countries
    for game in games:
        game_phases[game["id"]] = []
        for i in range(len(game['phases'])):
            result = has_neighboring_units(country1, country2, game['phases'][i]['state']['units'])
            if result:
                game_phases[game["id"]].append(game['phases'][i]['name'])
    game_phases = {k: v for k, v in game_phases.items() if v}
    return game_phases

In [5]:
# games = read_games_from_folder(games_dir)
game_phases = adjunction_selection(games_dir, ["ENGLAND", "ITALY"])

In [19]:
def get_env_pks(game_phases: dict[str, Any]):
    selected_pks = []
    all_task_pks = list(EnvironmentProfile.all_pks())
    
    # Create a dictionary to map game_id and phase_name to pks
    env_dict = {}
    for pk in tqdm(all_task_pks, desc = "Processing env profile dictionary"):
        env = EnvironmentProfile.get(pk)
        if env.game_id not in env_dict:
            env_dict[env.game_id] = {}
        if env.phase_name not in env_dict[env.game_id]:
            env_dict[env.game_id][env.phase_name] = []
        env_dict[env.game_id][env.phase_name].append(pk)
    
    # Search using the pre-built dictionary
    for game_id, phases in tqdm(game_phases.items(), desc = "Processing game phases"):
        if game_id in env_dict:
            for phase in phases:
                if phase in env_dict[game_id]:
                    selected_pks.extend(env_dict[game_id][phase])
    
    return selected_pks


In [20]:
pk_list = get_env_pks(game_phases)

Processing env profile dictionary: 100%|██████████| 1824/1824 [00:12<00:00, 142.08it/s]
Processing game phases: 100%|██████████| 38/38 [00:00<00:00, 48224.98it/s]


In [18]:
pk_list

['01J1SYR1Q0HNJD08JRCPPJS0HX',
 '01J1SYR1Q7S63MNHP5GXH69TQ8',
 '01J1SYR1R3C7S9WPEDFM61PDH3',
 '01J1SYR1RASDFEPP3TX72YN2RR',
 '01J1SYR1RF1E4KF42XFJ85F2JJ',
 '01J1SYR1SEW0P867WRBHM6178S',
 '01J1SYR1SKSFGS8M6ZZCST2CS6',
 '01J1SYR1SWGBNC40F01QEZ02F9',
 '01J1SYR1T10EKV70ZDAB6WVAX2',
 '01J1SYR1T7VXXTHBFHPSEZX5HH',
 '01J1SYR1TCZK9HP3QDBPGG6ZB9',
 '01J1SYR1TM5K0P07SMA9RJ3VH9',
 '01J1SYR1TXMJRXBQP7E484T00D',
 '01J1SYR1V3QJPN2GJXV3MCN95S',
 '01J1SYR1Z9HNC9CG261JHV58JY',
 '01J1SYR1ZG2RXJCDQS1KBY3MW1',
 '01J1SYR1ZNXB9TTAP2NZXH2T2R',
 '01J1SYR1ZW47W32CXCZWQEHW46',
 '01J1SYR205RDVMBJ5GEAMN41GQ',
 '01J1SYR20C6AC6PZZNE67XVPJB',
 '01J1SYR22THKMGP8TJWWNTQSA1',
 '01J1SYR231FKGNG7A16XD57KMG',
 '01J1SYR236F5X6MVY68FXVVK3M',
 '01J1SYR23B8AR7MG896W8QFB59',
 '01J1SYR23JTZZAAFFHKWAHZ937',
 '01J1SYR2462B1G675D28DYKCV4',
 '01J1SYR24B17X5Q3B0MDPM4PMR',
 '01J1SYR27HCZK1N3K0P1HQD1HV',
 '01J1SYR27PTD8GWXXZEHV3PR68',
 '01J1SYR27ZTCJC2RKB78M1P00X',
 '01J1SYR284ETV2FNNCHRX5DDKE',
 '01J1SYR28DW0DB1NMYBBEZHRKX',
 '01J1SY

In [25]:
data = games[0]
for i in range(len(data['phases'])):
    # rich.print(phase['orders'])
    country1 = 'ENGLAND'
    country2 = 'GERMANY'
    result = has_neighboring_units(country1, country2, data['phases'][i]['state']['units'])
    if result:
        # if result == []:
        #     print(i)
        # for unit1, unit2 in result:
        #     print(f"In the {data['phases'][i]['name']} phase, {unit1} ({country1}) and {unit2} ({country2}) are neighboring units")

        # else:
        #     print(f"In the {data['phases'][i]['name']} phase, {unit1} ({country1}) and {unit2} ({country2}) do not have neighboring units")

In [46]:
data = games[0]
for i in range(len(data['phases'])):
    # rich.print(phase['orders'])
    country1 = 'ENGLAND'
    country2 = 'ITALY'
    result = has_neighboring_units(country1, country2, data['phases'][i]['state']['units'])
    if result:
        print(f"Phase: {data['phases'][i]['name']}'s Result: {result}")

In [70]:
countries = ["good", "bad"]
country1, country2 = countries

In [72]:
country1.upper()

'GOOD'

In [41]:
games[0]['id']

'13818'

In [28]:
data = games[0]
for i in range(len(data['phases'])):
    # rich.print(phase['orders'])
    country1 = 'ENGLAND'
    country2 = 'France'
    result = has_neighboring_units(country1, country2, data['phases'][i]['state']['units'])
    if result:
        # if result == []:
        #     print(i)
        for unit1, unit2 in result:
            print(f"In the {data['phases'][i]['name']} phase, {unit1} ({country1}) and {unit2} ({country2}) are neighboring units")

        else:
            print(f"In the {data['phases'][i]['name']} phase, {unit1} ({country1}) and {unit2} ({country2}) are not neighboring units")

KeyError: 'France'