In [2]:
%load_ext autoreload
%autoreload 2

In [44]:
import pickle
# Road definitions used as experiment inputs.
# NOTE: pickle.load can execute arbitrary code — only open trusted files.
with open('inputTestCases/_input2ways_n=4_.pickle', mode='rb') as f:
    roadDefs = pickle.load(f)

In [69]:
import math
from junctionart.roundabout.encodingGFN.gflownet.gflownet import GFlowNet
from junctionart.roundabout.encodingGFN.policy import ForwardPolicy, BackwardPolicy
from junctionart.roundabout.encodingGFN.setGenerationEnv import SetGenerationEnv
from junctionart.roundabout.encodingGFN.gflownet.utils import trajectory_balance_loss
from torch.optim import Adam
from tqdm import tqdm
import torch 
import numpy as np

# size: width of the all-zero start state (torch.zeros(batch, size)) and first arg of SetGenerationEnv.
# nActions: number of actions given to SetGenerationEnv, also passed as nSegments when building roundabouts.
size, nActions = 4, 30



def train(batch_size, num_epochs, setEnv, state_size=None, hidden_dim=128,
          lr=5e-5, flow_lr=5e-2, max_grad_norm=1e-2):
    """Train a GFlowNet on ``setEnv`` with the trajectory-balance loss.

    Args:
        batch_size: number of trajectories sampled per optimisation step.
        num_epochs: number of optimisation steps.
        setEnv: environment providing ``state_dim`` and ``num_actions``; it is
            handed to the policies and to ``GFlowNet``.
        state_size: width of the all-zero start state ``s0``. Defaults to the
            notebook-level ``size`` (the previous, implicit behaviour).
        hidden_dim: hidden width of the forward policy.
        lr: learning rate for the forward-policy parameters.
        flow_lr: learning rate for ``model.total_flow``.
        max_grad_norm: threshold for gradient-norm clipping over ``model.parameters()``.

    Returns:
        ``(model, losses, rewards, flows)``: the trained model and per-step
        histories of the loss, the mean batch reward, and the total-flow value,
        all stored as Python floats.
    """
    if state_size is None:
        state_size = size  # notebook-level global; kept as default for compatibility

    forward_policy = ForwardPolicy(setEnv.state_dim, hidden_dim=hidden_dim, num_actions=setEnv.num_actions)
    # Only the forward policy and total flow are given to the optimiser below,
    # so any parameters of the backward policy are never updated.
    backward_policy = BackwardPolicy(setEnv.state_dim, num_actions=setEnv.num_actions)
    model = GFlowNet(forward_policy, backward_policy, setEnv)

    opt = Adam(
        [
            {"params": model.forward_policy.parameters()},
            {"params": model.total_flow, "lr": flow_lr},
        ],
        lr=lr,
    )

    losses, rewards, flows = [], [], []
    for i in (p := tqdm(range(num_epochs))):
        s0 = torch.zeros(batch_size, state_size).float()
        _, log, traj_length = model.sample_states(s0, return_log=True)

        loss = trajectory_balance_loss(log.total_flow,
                                       log.rewards,
                                       log.fwd_probs,
                                       log.back_probs)
        loss_value = loss.item()

        flows.append(model.total_flow.item())
        # .item() keeps the history as plain floats (like losses/flows) instead of
        # accumulating one tensor per epoch.
        rewards.append(log.rewards.mean().item())
        losses.append(loss_value)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        opt.step()
        opt.zero_grad()
        if i % 10 == 0:
            p.set_description(f"{loss_value:.3f} T.L : {traj_length}")

    return model, losses, rewards, flows

# NOTE(review): these are not used by the experiment loop, which calls train(64, 500, env).
batch_size, num_epochs = 32, 1000



In [1]:
import seaborn as sns
import pandas as pd
import ast


def sample(nIter, nBatch, model, setEnv):
    """Sample terminal states from ``model`` and tabulate the distinct ones.

    Runs ``nIter`` rounds of ``model.sample_states`` on ``nBatch`` all-zero
    start states each, then groups identical terminal states.

    Returns:
        dict mapping ``str(state)`` (``state`` being a row of
        ``setEnv.getStateForm(s).long().tolist()``) to ``(reward, frequency)``.
        The reward is the one from ``setEnv.reward`` for the first occurrence
        of that state; the frequency counts all occurrences.
    """
    # Pure sampling: gradients are never used, so skip building autograd graphs.
    with torch.no_grad():
        batches = []
        for _ in tqdm(range(nIter)):
            s0 = torch.zeros(nBatch, size).float()
            terminal, _trajLength = model.sample_states(s0, return_log=False)
            batches.append(terminal)
        s = torch.concat(batches)

    statesForPlot = setEnv.getStateForm(s).long().tolist()
    rewardsForPlot = setEnv.reward(s, showProgress=True).tolist()

    stateDict = {}
    for state, reward in tqdm(zip(statesForPlot, rewardsForPlot), total=len(statesForPlot)):
        key = str(state)
        if key in stateDict:
            firstReward, freq = stateDict[key]
            stateDict[key] = (firstReward, freq + 1)
        else:
            stateDict[key] = (reward, 1)

    return stateDict

def getTopK(stateDict, K):
    """Return the K highest-reward configurations from a ``sample`` result.

    Args:
        stateDict: mapping ``str(list_of_ints)`` -> ``(reward, frequency)``.
        K: number of configurations to keep; ``K <= 0`` yields empty results.

    Returns:
        ``(modes, proxyRewards)`` ordered by decreasing reward (ties keep the
        dict's insertion order). Each mode is the parsed configuration with 1
        subtracted from every entry (presumably 1-based action ids -> 0-based;
        confirm against SetGenerationEnv); ``proxyRewards`` are the matching
        rewards.
    """
    if K <= 0:
        return [], []

    # sorted() is stable, so configurations with equal reward keep insertion order.
    ranked = sorted(
        ((reward, config) for config, (reward, _freq) in stateDict.items()),
        key=lambda item: item[0],
        reverse=True,
    )[:K]

    modes = [[value - 1 for value in ast.literal_eval(config)] for _reward, config in ranked]
    proxyRewards = [reward for reward, _config in ranked]
    return modes, proxyRewards



In [7]:
from junctionart.roundabout.encodingGFN.RoundaboutLaneEncodingEnv import RoundaboutLaneEncodingEnv
from junctionart.roundabout.RewardUtil import RewardUtil

def getRoundabouts(roadDefinition, modes):
    """Build one roundabout per entry of ``modes`` for ``roadDefinition``.

    A single RoundaboutLaneEncodingEnv is reused: for each mode it is
    regenerated with ``laneToCircularId`` set to that mode (no outgoing-lane
    merging, ``nActions`` segments) and the resulting roundabout is collected
    in input order.
    """
    laneEnv = RoundaboutLaneEncodingEnv()
    built = []
    for idx in tqdm(range(len(modes))):
        laneEnv.generateWithRoadDefinition(
            roadDefinition=roadDefinition,
            outgoingLanesMerge=False,
            nSegments=nActions,
            laneToCircularId=modes[idx],
        )
        built.append(laneEnv.getRoundabout())
    return built

def getRewards(roundabouts):
    """Return each roundabout's ``getReward()`` value, as a list in input order."""
    collected = []
    for roundabout in roundabouts:
        collected.append(roundabout.getReward())
    return collected

def getDiversityScore(roundabouts):
    """Pairwise-distance diversity of a list of roundabouts.

    Sums ``RewardUtil.getDistance(a, b)`` over every pair with i < j and
    divides by ``n * (n - 1)``.

    NOTE(review): because only i < j pairs are summed, this is half the mean
    pairwise distance (if getDistance is symmetric). The formula is kept as-is
    so results match the scores already recorded in this notebook — confirm
    it is intended.

    Returns 0.0 when fewer than two roundabouts are given (there are no pairs);
    previously that case computed 0/0 and produced nan with a RuntimeWarning.
    """
    n = len(roundabouts)
    if n < 2:
        return 0.0

    distances = []
    for i in tqdm(range(n)):
        for j in range(i + 1, n):
            distances.append(RewardUtil.getDistance(roundabouts[i], roundabouts[j]))

    return np.asarray(distances).sum() / (n * (n - 1))

In [83]:
scoresList = []
diversityScores = []
import math
output = {"roundabouts" : [], "modes" : [], "proxyRewards" : []}

# train() is retried on ValueError (each attempt builds a fresh model), but the
# number of attempts is capped so a persistent error cannot loop forever.
MAX_TRAIN_ATTEMPTS = 10

# NOTE(review): only the first two road definitions are evaluated here.
for roadDefinition in roadDefs[:2]:
    env = SetGenerationEnv(size, nActions, roadDefinition, base = 10**9)
    for attempt in range(1, MAX_TRAIN_ATTEMPTS + 1):
        try:
            model, losses, rewards, flows = train(64, 500, env)
            break
        except ValueError as err:
            print(f"Error , trying again. (attempt {attempt}/{MAX_TRAIN_ATTEMPTS}: {err})")
    else:
        raise RuntimeError(
            f"train() raised ValueError {MAX_TRAIN_ATTEMPTS} times in a row; giving up."
        )

    # Training used base=10**9; sampling and the proxy rewards below use base e.
    env.base = math.exp(1)
    stateDict = sample(1, 500, model, env)
    modes, proxyRewards = getTopK(stateDict, 50)
    roundabouts = getRoundabouts(roadDefinition, modes)

    output["roundabouts"].append(roundabouts)
    output["modes"].append(modes)
    output["proxyRewards"].append(proxyRewards)

    # getReward() of (at most) the first 50 roundabouts; this reuses the name `rewards` from train().
    rewards = getRewards(roundabouts[:50])
    scoresList.append(rewards)
    # diversityScores.append(getDiversityScore(roundabouts))


# import pickle
# with open('analysis/expGFN_N=8_K=200.pkl', 'wb') as file:
#     pickle.dump(output, file)

9.420 T.L : 5: 100%|██████████████████████████████████████████████████████████████████████████| 500/500 [00:16<00:00, 29.86it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 31.90it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:00<00:00, 2754.93it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:00<00:00, 777298.74it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:07<00:00,  6.98it/s]
8.462 T.L : 5: 100%|██████████████████████████████████████████████████████████████████████████| 500/500 [00:16<00:00, 29.75it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 55.03it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 500

In [84]:
import numpy as np
# Mean +- std of the getReward() scores over all road definitions and modes.
scores = np.asarray(scoresList)
print(scores.mean(), "+-", scores.std())

# Proxy rewards were obtained after setting env.base = e; summarise their natural log.
proxy = np.asarray(output["proxyRewards"])
proxy = np.log(proxy)
print(proxy.mean(), "+-", proxy.std())

6.58 +- 0.8594184079946158
0.7399999880811943 +- 0.048989785471673056


In [39]:
# Mean +- std of the per-road-definition diversity scores.
diversityScores = np.asarray(diversityScores)
if diversityScores.size == 0:
    # mean()/std() of an empty array return nan with a RuntimeWarning; say so explicitly instead.
    print("diversityScores is empty - no diversity scores were computed.")
else:
    print(diversityScores.mean(), "+-", diversityScores.std())

6.857124875450433 +- 0.48311343129633405


{'[7, 13, 8, 4]': (1.491824746131897, 1),
 '[18, 29, 27, 27]': (1.6487212181091309, 1),
 '[24, 3, 2, 15]': (1.6487212181091309, 1),
 '[30, 11, 28, 9]': (1.6487212181091309, 1),
 '[3, 21, 2, 15]': (1.6487212181091309, 1),
 '[30, 29, 19, 25]': (1.8221187591552734, 1),
 '[21, 2, 29, 25]': (1.8221187591552734, 1),
 '[21, 21, 17, 17]': (2.22554087638855, 1),
 '[16, 19, 1, 25]': (1.491824746131897, 1),
 '[18, 26, 26, 11]': (1.6487212181091309, 1),
 '[15, 14, 20, 22]': (1.2214027643203735, 1),
 '[19, 12, 15, 17]': (1.8221187591552734, 1),
 '[12, 4, 30, 25]': (1.3498587608337402, 1),
 '[26, 10, 7, 5]': (1.8221187591552734, 1),
 '[10, 29, 4, 7]': (1.8221187591552734, 1),
 '[19, 15, 13, 13]': (1.8221187591552734, 1),
 '[18, 18, 25, 28]': (1.491824746131897, 1),
 '[18, 22, 21, 2]': (1.491824746131897, 1),
 '[17, 9, 11, 12]': (1.8221187591552734, 1),
 '[14, 15, 8, 6]': (1.8221187591552734, 1),
 '[1, 4, 22, 5]': (1.6487212181091309, 1),
 '[26, 12, 15, 6]': (2.0137526988983154, 1),
 '[20, 24, 19, 21

In [16]:
# Print sampled states whose reward falls below a threshold.
# A named loop variable is used instead of `_`: assigning `_` at notebook top
# level stops IPython from updating its "last output" variables.
LOW_REWARD_THRESHOLD = 0.8
for state, entry in stateDict.items():
    reward, _freq = entry
    if reward < LOW_REWARD_THRESHOLD:
        print(state, entry)

In [22]:
# Quick smoke test: a single round of 10 samples from the last trained model/env.
k = sample(nIter=1, nBatch=10, model=model, setEnv=env)

100%|████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 221.09it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 1169.05it/s]


tensor(0.6000)
tensor(0.4000)
tensor(0.7000)
tensor(0.5000)
tensor(0.7000)
tensor(0.3000)
tensor(0.5000)
tensor(0.6000)
tensor(0.3000)
tensor(0.6000)


100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 178481.02it/s]


In [27]:
import pickle
# NOTE(review): the file name says K=200, but the experiment cell keeps the top
# 50 modes (getTopK(stateDict, 50)) — confirm the name matches the run.
with open('analysis/expGFN_N=4_K=200.pkl', mode='wb') as file:
    pickle.dump(output, file)

In [1]:
import numpy as np
import pickle
# Load previously saved results (file name suggests an SAC run, N=8, K=200);
# this replaces any `output` dict already in the namespace.
# NOTE: pickle.load can execute arbitrary code — only open trusted files.
with open('analysis/expSAC_N=8_K=200.pkl', mode='rb') as f:
    output = pickle.load(f)

from tqdm import tqdm 
from junctionart.roundabout.RewardUtil import RewardUtil
def getRewards(roundabouts):
    """Per-group rewards: for every list in ``roundabouts`` return the list of
    its members' ``getReward()`` values, preserving order.

    NOTE(review): this redefines the flat ``getRewards`` from an earlier cell
    with a different input/output shape (list of lists here).
    """
    return [
        [member.getReward() for member in group]
        for group in tqdm(roundabouts)
    ]

def getDiversityScore(roundabouts):
    """Pairwise-distance diversity of a list of roundabouts.

    Sums ``RewardUtil.getDistance(a, b)`` over every pair with i < j and
    divides by ``n * (n - 1)``.

    NOTE(review): because only i < j pairs are summed, this is half the mean
    pairwise distance (if getDistance is symmetric). The formula is kept as-is
    so results match the scores already recorded in this notebook — confirm
    it is intended.

    Returns 0.0 when fewer than two roundabouts are given (there are no pairs);
    previously that case computed 0/0 and produced nan with a RuntimeWarning.
    """
    n = len(roundabouts)
    if n < 2:
        return 0.0

    distances = []
    for i in tqdm(range(n)):
        for j in range(i + 1, n):
            distances.append(RewardUtil.getDistance(roundabouts[i], roundabouts[j]))

    return np.asarray(distances).sum() / (n * (n - 1))

from junctionart.roundabout.Roundabout import Roundabout

# Diversity of (at most) the first 50 roundabouts in each group, then mean +- std over groups.
roundabouts = output['roundabouts']
diversityScores = np.asarray(
    list(map(lambda group: getDiversityScore(group[:50]), roundabouts))
)

print(diversityScores.mean(), "+-", diversityScores.std())

100%|███████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:17<00:00,  1.55s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:02<00:00,  1.24s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:59<00:00,  1.19s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:59<00:00,  1.18s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:57<00:00,  1.16s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:59<00:00,  1.19s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:59<00:00,  1.20s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████|

3.363292109436151 +- 1.9678179251599142



