In [45]:
import pandas as pd
import numpy as np

from rl_lib.swiss_round.environment import SwissRoundEnv
from rl_lib.swiss_round.agent import DQNAgent

%reload_ext autoreload
%autoreload 2

# Utils 

In [46]:
def probability_tables(team_strengths, max_draw_probability):
    """Build pairwise win / draw / loss probability tables for all ordered team pairs.

    For a row team ``i`` and a column team ``j`` with strength difference
    ``d = team_strengths[i] - team_strengths[j]``, three unnormalised scores
    are computed:

    * win:  ``1 / (1 + exp(-d))``
    * loss: ``1 / (1 + exp(+d))``
    * draw: ``max_draw_probability * exp(-|d|)``

    Each score is then divided by the sum of the three, so that the outcomes
    of every pairing add up to 1 (plain sum-normalisation, not a softmax).

    Parameters
    ----------
    team_strengths : sequence of float
        One strength per team; the position in the sequence is used as the
        row / column label of the returned tables.
    max_draw_probability : float
        Unnormalised draw score of two equally strong teams (``d == 0``).

    Returns
    -------
    tuple of pandas.DataFrame
        ``(win, draw, loss)`` tables of shape ``(n_teams, n_teams)``. Entry
        ``[i, j]`` is the probability of that outcome for team ``i`` (row)
        when facing team ``j`` (column).
    """
    strengths = np.asarray(team_strengths, dtype=float)
    index = range(len(strengths))

    # Row strength minus column strength for every ordered pair at once.
    strength_diff = strengths[:, None] - strengths[None, :]

    raw_win = 1 / (1 + np.exp(-strength_diff))
    raw_loss = 1 / (1 + np.exp(+strength_diff))
    raw_draw = max_draw_probability * np.exp(-np.abs(strength_diff))

    # Shared normaliser, computed once instead of once per outcome.
    total = raw_win + raw_draw + raw_loss

    def _as_table(values):
        # Wrap a raw (n_teams, n_teams) array with team ids as labels on both axes.
        return pd.DataFrame(values, index=index, columns=index)

    return _as_table(raw_win / total), _as_table(raw_draw / total), _as_table(raw_loss / total)

In [47]:
def check_probability(team_strengths, max_draw_probability):
    """Tell whether win + draw + loss is numerically 1 for every pair of teams.

    Builds the three tables with ``probability_tables`` and compares their
    element-wise sum against an array of ones using ``np.allclose``.
    """
    win, draw, loss = probability_tables(
        team_strengths=team_strengths,
        max_draw_probability=max_draw_probability,
    )
    totals = (win + draw + loss).to_numpy()
    expected = np.ones_like(totals)

    # NOTE(review): 10e-5 is 1e-4 and 10e-8 is 1e-7, i.e. ten times looser than
    # NumPy's default tolerances - confirm this is intended.
    return np.allclose(totals, expected, rtol=10e-5, atol=10e-8)

# Environment

In [48]:
# Tournament set-up: 18 teams whose strengths decrease by 1 from 9 down to -8.
n_teams = 18
threshold_ranks = [4, 12]
# Team ids are 0-based and sorted by decreasing strength, so the team whose id
# equals the last threshold rank is the first one outside that threshold.
agent_id = threshold_ranks[-1]
team_strengths = [9 - i for i in range(n_teams)]
mdp = 0.3

print(np.array(team_strengths).round(2))
wp, dp, lp = probability_tables(team_strengths=team_strengths, max_draw_probability=mdp)

# One row per opponent, one column per outcome, from the agent's point of view.
pd.concat(
    [
        table.loc[[agent_id]].rename(index={agent_id: outcome})
        for outcome, table in (('Win', wp), ('Draw', dp), ('Loss', lp))
    ]
).T.round(2)

[ 9  8  7  6  5  4  3  2  1  0 -1 -2 -3 -4 -5 -6 -7 -8]


Unnamed: 0,Win,Draw,Loss
0,0.0,0.0,1.0
1,0.0,0.0,1.0
2,0.0,0.0,1.0
3,0.0,0.0,1.0
4,0.0,0.0,1.0
5,0.0,0.0,1.0
6,0.0,0.0,1.0
7,0.01,0.0,0.99
8,0.02,0.01,0.98
9,0.05,0.01,0.94


In [49]:
# Tournament length, and the bonus paired with each entry of threshold_ranks.
n_rounds = 6
bonus_points = [30, 20]
# Number of tournaments simulated for each baseline policy.
n_baselines_simu = 2000
# Reminder: agent_id (set in the previous cell) is the team just below the last threshold.

# Sanity check: win + draw + loss must sum to 1 for every pairing.
print(f"Valid probability set-up : {check_probability(team_strengths, mdp)}")

env = SwissRoundEnv(
    n_teams=n_teams,
    n_rounds=n_rounds,
    team_strengths=team_strengths,
    max_draw_probability=mdp,
    threshold_ranks=threshold_ranks,
    bonus_points=bonus_points,
    agent_id=agent_id,
)

Valid probability set-up : True


#### Detailed tables

In [50]:
# Win probability of the row team against the column team.
wp.round(decimals=2)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,0.38,0.66,0.85,0.94,0.98,0.99,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
1,0.24,0.38,0.66,0.85,0.94,0.98,0.99,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
2,0.11,0.24,0.38,0.66,0.85,0.94,0.98,0.99,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
3,0.05,0.11,0.24,0.38,0.66,0.85,0.94,0.98,0.99,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
4,0.02,0.05,0.11,0.24,0.38,0.66,0.85,0.94,0.98,0.99,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
5,0.01,0.02,0.05,0.11,0.24,0.38,0.66,0.85,0.94,0.98,0.99,1.0,1.0,1.0,1.0,1.0,1.0,1.0
6,0.0,0.01,0.02,0.05,0.11,0.24,0.38,0.66,0.85,0.94,0.98,0.99,1.0,1.0,1.0,1.0,1.0,1.0
7,0.0,0.0,0.01,0.02,0.05,0.11,0.24,0.38,0.66,0.85,0.94,0.98,0.99,1.0,1.0,1.0,1.0,1.0
8,0.0,0.0,0.0,0.01,0.02,0.05,0.11,0.24,0.38,0.66,0.85,0.94,0.98,0.99,1.0,1.0,1.0,1.0
9,0.0,0.0,0.0,0.0,0.01,0.02,0.05,0.11,0.24,0.38,0.66,0.85,0.94,0.98,0.99,1.0,1.0,1.0


In [51]:
# Draw probability between the row team and the column team.
dp.round(decimals=2)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,0.23,0.1,0.04,0.01,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.1,0.23,0.1,0.04,0.01,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.04,0.1,0.23,0.1,0.04,0.01,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.01,0.04,0.1,0.23,0.1,0.04,0.01,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.01,0.01,0.04,0.1,0.23,0.1,0.04,0.01,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,0.0,0.01,0.01,0.04,0.1,0.23,0.1,0.04,0.01,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,0.0,0.0,0.01,0.01,0.04,0.1,0.23,0.1,0.04,0.01,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7,0.0,0.0,0.0,0.01,0.01,0.04,0.1,0.23,0.1,0.04,0.01,0.01,0.0,0.0,0.0,0.0,0.0,0.0
8,0.0,0.0,0.0,0.0,0.01,0.01,0.04,0.1,0.23,0.1,0.04,0.01,0.01,0.0,0.0,0.0,0.0,0.0
9,0.0,0.0,0.0,0.0,0.0,0.01,0.01,0.04,0.1,0.23,0.1,0.04,0.01,0.01,0.0,0.0,0.0,0.0


In [52]:
# Loss probability of the row team against the column team.
lp.round(decimals=2)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,0.38,0.24,0.11,0.05,0.02,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.66,0.38,0.24,0.11,0.05,0.02,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.85,0.66,0.38,0.24,0.11,0.05,0.02,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.94,0.85,0.66,0.38,0.24,0.11,0.05,0.02,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.98,0.94,0.85,0.66,0.38,0.24,0.11,0.05,0.02,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,0.99,0.98,0.94,0.85,0.66,0.38,0.24,0.11,0.05,0.02,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,1.0,0.99,0.98,0.94,0.85,0.66,0.38,0.24,0.11,0.05,0.02,0.01,0.0,0.0,0.0,0.0,0.0,0.0
7,1.0,1.0,0.99,0.98,0.94,0.85,0.66,0.38,0.24,0.11,0.05,0.02,0.01,0.0,0.0,0.0,0.0,0.0
8,1.0,1.0,1.0,0.99,0.98,0.94,0.85,0.66,0.38,0.24,0.11,0.05,0.02,0.01,0.0,0.0,0.0,0.0
9,1.0,1.0,1.0,1.0,0.99,0.98,0.94,0.85,0.66,0.38,0.24,0.11,0.05,0.02,0.01,0.0,0.0,0.0


In [53]:
# The three outcome tables should add up to 1 in every cell.
wp.add(dp).add(lp)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
1,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
2,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
3,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
4,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
5,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
6,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
7,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
8,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
9,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


# Simulations

### Baseline simulations

In [54]:
# Baseline with the 'win_all' policy.
simulation_wa = env.simulate_n_tournaments(
    n_baselines_simu,
    n_cores=32,
    policy='win_all',
    display_results=False,
)

# Average reward = average points + each bonus weighted by the matching "Top-k %"
# value (presumably the fraction of tournaments finished inside that rank).
expected_bonus_wa = sum(
    bonus * simulation_wa.loc[agent_id, f"Top-{rank} %"]
    for bonus, rank in zip(bonus_points, threshold_ranks)
)
baseline_reward_wa = simulation_wa.loc[agent_id, 'Avg_Points'] + expected_bonus_wa

print(f"Baseline WinAll average reward = {baseline_reward_wa:.1f}")
simulation_wa.loc[agent_id]

Simulating tournaments: 100%|██████████| 2000/2000 [00:15<00:00, 129.71it/s]


Baseline WinAll average reward = 16.6


Team          12.000000
Strength      -3.000000
Avg_Points     7.189162
Avg_Rank      12.634458
Top-4 %        0.000511
Top-12 %       0.468303
Name: 12, dtype: float64

In [55]:
# Baseline with the 'lose_first' policy.
simulation_lf = env.simulate_n_tournaments(
    n_baselines_simu,
    n_cores=32,
    policy='lose_first',
    display_results=False,
)

# Average reward = average points + each bonus weighted by the matching "Top-k %"
# value (presumably the fraction of tournaments finished inside that rank).
expected_bonus_lf = sum(
    bonus * simulation_lf.loc[agent_id, f"Top-{rank} %"]
    for bonus, rank in zip(bonus_points, threshold_ranks)
)
baseline_reward_lf = simulation_lf.loc[agent_id, 'Avg_Points'] + expected_bonus_lf

print(f"Baseline LoseFirst average reward = {baseline_reward_lf:.1f}")
simulation_lf.loc[agent_id]

Simulating tournaments: 100%|██████████| 2000/2000 [00:17<00:00, 111.17it/s]

Baseline LoseFirst average reward = 16.3





Team          12.000000
Strength      -3.000000
Avg_Points     7.138037
Avg_Rank      12.835890
Top-4 %        0.000000
Top-12 %       0.458078
Name: 12, dtype: float64

### RL Agent

In [43]:
# Recap of both baseline rewards, as a reference point for the agent's training rewards.
print(f"Baselines average reward : WinAll = {baseline_reward_wa:.1f}, LoseFirst = {baseline_reward_lf:.1f}")

Baselines average reward : WinAll = 16.5, LoseFirst = 16.2


In [44]:
# Build the DQN agent on the environment defined above, then train and evaluate it.
agent = DQNAgent(
    env,
    hidden_dims=[256, 128, 64],
    dropout=0.1,
    buffer_size=10_000,
    epsilon_decay=0.9995,
)

agent.train(n_episodes=3000)
agent.evaluate(n_episodes=400)

--- Training in progress ---
Episode 100/3000 | Avg Reward: 7.14 | Avg nb gambits played 3.86 | Epsilon: 0.765 | Failed episodes: 0
Episode 200/3000 | Avg Reward: 7.77 | Avg nb gambits played 3.36 | Epsilon: 0.564 | Failed episodes: 2
Episode 300/3000 | Avg Reward: 10.69 | Avg nb gambits played 2.70 | Epsilon: 0.415 | Failed episodes: 5
Episode 400/3000 | Avg Reward: 10.31 | Avg nb gambits played 2.57 | Epsilon: 0.306 | Failed episodes: 9
Episode 500/3000 | Avg Reward: 11.21 | Avg nb gambits played 2.18 | Epsilon: 0.225 | Failed episodes: 11
Episode 600/3000 | Avg Reward: 11.63 | Avg nb gambits played 2.08 | Epsilon: 0.166 | Failed episodes: 14
Episode 700/3000 | Avg Reward: 10.57 | Avg nb gambits played 1.63 | Epsilon: 0.123 | Failed episodes: 14
Episode 800/3000 | Avg Reward: 12.81 | Avg nb gambits played 1.59 | Epsilon: 0.090 | Failed episodes: 18
Episode 900/3000 | Avg Reward: 16.30 | Avg nb gambits played 1.41 | Epsilon: 0.067 | Failed episodes: 19


KeyboardInterrupt: 

### Verbose simulation

In [None]:
# Play a single tournament, logging every game as it is simulated.
final_standings = env.simulate_tournament(verbose=True)

print("\nFinal standings (team_id, points, opponent_average):")
# Each standing unpacks to (team_id, points, opponent average, strength); ranks start at 1.
for rank, standing in enumerate(final_standings, start=1):
    team_id, points, opp_avg, strength = standing
    print(f"Rank {rank}: Team {team_id} - Strength {strength:.2f} - Points: {points} - Opponent Avg: {opp_avg:.2f}")


--- Simulating round n°1 ---
Game : Team 8 (points : 0, strength : -0.40) vs Team 15 (points : 0, strength : -2.50) : Team 8 wins
Game : Team 4 (points : 0, strength : 0.80) vs Team 1 (points : 0, strength : 1.70) : Team 1 wins
Game : Team 13 (points : 0, strength : -1.90) vs Team 7 (points : 0, strength : -0.10) : Team 7 wins
Game : Team 11 (points : 0, strength : -1.30) vs Team 0 (points : 0, strength : 2.00) : Team 0 wins
Game : Team 17 (points : 0, strength : -3.10) vs Team 14 (points : 0, strength : -2.20) : Team 14 wins
Game : Team 5 (points : 0, strength : 0.50) vs Team 16 (points : 0, strength : -2.80) : Team 5 wins
Game : Team 9 (points : 0, strength : -0.70) vs Team 2 (points : 0, strength : 1.40) : Team 9 wins
Game : Team 3 (points : 0, strength : 1.10) vs Team 12 (points : 0, strength : -1.60) : Team 12 wins
Game : Team 6 (points : 0, strength : 0.20) vs Team 10 (points : 0, strength : -1.00) : Team 6 wins
--- Simulating round n°2 ---
Game : Team 0 (points : 3, strength : 2

ValueError: Could not find a perfect matching after multiple attempts