In [None]:
from Dataset_Generator import DatasetGenerator
from Encoder import ObservationEncoder
from GNN_file import PaperGNN
from Extarct_Actions import Action_Extractor
from Adjacency_Matrix import adj_mat
from pre_processing import Preprocessing
from MLP_Action import ActionMLP
from TR6 import decentralizedModel

import torch
import torch
import numpy as np
import random

seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for multi-GPU setups

# Optional: Force deterministic behavior on GPU (may impact performance)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
grid = [[0 for _ in range(20)] for _ in range(20)]  # Generate grid map
dataset_generator = DatasetGenerator(num_cases=5000, num_agents=6, grid=grid)
# cases = dataset_generator.generate_cases()
# dataset_generator.save_cases_to_file(cases, "dataset.json")
# print(f"Generated and saved {len(cases)} cases.")
cases = dataset_generator.load_cases_from_file("dataset.json")

In [None]:
p = Preprocessing(grid,cases,3)
data_tensors  = p.begin()

In [None]:
adj = adj_mat(cases,3)
adj_matrices = adj.get_adj_mat()

In [14]:
trained_model = decentralizedModel(cases, data_tensors,rfov = 3,adjacency_matrix = adj_matrices,num_of_robots = 6, k_hops = 2, num_epochs = 30)

No checkpoint found, starting training from scratch.


In [15]:
if torch.cuda.is_available():
    trained_model.to(device="cuda")

In [16]:
trained_model.train_model()

Epoch [1/30] Train Loss: 0.2736 Val Loss: 0.2130 Val Acc: 91.95%
Checkpoint saved at epoch 1.
Epoch [2/30] Train Loss: 0.1615 Val Loss: 0.2918 Val Acc: 87.40%
Checkpoint saved at epoch 2.
Epoch [3/30] Train Loss: 0.1447 Val Loss: 0.2793 Val Acc: 89.12%
Checkpoint saved at epoch 3.
Epoch [4/30] Train Loss: 0.1359 Val Loss: 0.1417 Val Acc: 95.23%
Checkpoint saved at epoch 4.
Epoch [5/30] Train Loss: 0.1295 Val Loss: 0.2683 Val Acc: 90.28%
Checkpoint saved at epoch 5.
Epoch [6/30] Train Loss: 0.1228 Val Loss: 0.4398 Val Acc: 88.05%
Checkpoint saved at epoch 6.
Epoch [7/30] Train Loss: 0.1184 Val Loss: 0.1224 Val Acc: 95.31%
Checkpoint saved at epoch 7.
Epoch [8/30] Train Loss: 0.1149 Val Loss: 0.1182 Val Acc: 95.42%
Checkpoint saved at epoch 8.
Epoch [9/30] Train Loss: 0.1121 Val Loss: 0.1354 Val Acc: 95.20%
Checkpoint saved at epoch 9.
Epoch [10/30] Train Loss: 0.1105 Val Loss: 0.1649 Val Acc: 92.77%
Checkpoint saved at epoch 10.
Epoch [11/30] Train Loss: 0.1071 Val Loss: 0.1780 Val Acc:

In [17]:
trained_model.test_model()

Test Loss: 0.0848, Test Accuracy: 96.78%


(0.08476832363268604, 0.9677687198067633)

In [None]:
def prepare_data(tensors,num_of_robots,cases,rfov):
        dataset = []
        adj_obj = adj_mat(cases,rfov)
        adj_cases = adj_obj.get_adj_mat()
        for itr in range(len(tensors.keys())):
            # first_channels = self.tensors[itr]['channel 1']
            second_channels = tensors[itr]['channel 2']
            third_channels = tensors[itr]['channel 3']

            for step in range(len(second_channels)):
                batch_channels = []
                for i in range(num_of_robots):
                    agent_channels = np.stack([
                        second_channels[step][i],
                        third_channels[step][i]
                    ])
                    batch_channels.append(agent_channels)
                batch_tensor = torch.tensor(np.array(batch_channels), dtype=torch.float32)
                adj_step = torch.tensor(np.array(adj_cases[itr][step]),dtype=torch.float32)
                dataset.append((batch_tensor,adj_step))
        return dataset

In [None]:
device = torch.device("cpu" if torch.cuda.is_available() else "cpu")
cnn = ObservationEncoder().to(device=device)
gnn = PaperGNN(K = 2).to(device=device)
mlp = ActionMLP().to(device=device)
checkpoint = torch.load('TR6_99_SR.pth', map_location=device)
cnn.load_state_dict(checkpoint['cnn_state_dict'])
gnn.load_state_dict(checkpoint['gnn_state_dict'])
mlp.load_state_dict(checkpoint['mlp_state_dict'])
cnn.eval()
gnn.eval()
mlp.eval()

ActionMLP(
  (fc1): Linear(in_features=128, out_features=128, bias=True)
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=128, out_features=5, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [32]:
np.random.seed(61)
random.seed(61)
num_of_robots = 6
new_dataset_generator = DatasetGenerator(num_cases=100, num_agents=num_of_robots, grid=grid)
eval_cases = new_dataset_generator.generate_cases()

case  0
case added
case  1
case added
case  2
case added
case  3
case added
case  4
case added
case  5
case added
case  6
case added
case  7
case added
case  8
case added
case  9
case added
case  10
case added
case  11
case added
case  12
case added
case  13
case added
case  14
case added
case  15
case added
case  16
case added
case  17
case added
case  18
case added
case  19
case added
case  20
case added
case  21
case added
case  22
case added
case  23
case added
case  24
case added
case  25
case added
case  26
case added
case  27
case added
case  28
case added
case  29
case added
case  30
case added
case  31
case added
case  32
case added
case  33
case added
case  34
case added
case  35
case added
case  36
case added
case  37
case added
case  38
case added
case  39
case added
case  40
case added
case  41
case added
case  42
case added
case  43
case added
case  44
case added
case  45
case added
case  46
case added
case  47
case added
case  48
case added
case  49
case added
case  50
c

In [34]:
print(eval_cases[1]['start_positions'])
print(eval_cases[1]['goal_positions'])
print(eval_cases[1]['paths'])
# action_extractor = Action_Extractor(eval_cases, 6)
# actions = action_extractor.extract()
# print(actions)

[(12, 9), (9, 10), (14, 10), (19, 11), (5, 19), (8, 13)]
[(13, 7), (4, 5), (1, 17), (11, 15), (0, 17), (12, 2)]
[[(12, 9), (12, 8), (12, 7), (13, 7)], [(9, 10), (8, 10), (8, 9), (7, 9), (7, 8), (6, 8), (6, 7), (5, 7), (5, 6), (4, 6), (4, 5)], [(14, 10), (13, 10), (12, 10), (11, 10), (10, 10), (9, 10), (8, 10), (7, 10), (7, 11), (6, 11), (6, 12), (5, 12), (5, 13), (4, 13), (4, 14), (3, 14), (3, 15), (2, 15), (2, 16), (1, 16), (1, 17)], [(19, 11), (18, 11), (17, 11), (16, 11), (15, 11), (14, 11), (14, 12), (13, 12), (13, 13), (12, 13), (12, 14), (11, 14), (11, 15)], [(5, 19), (4, 19), (3, 19), (2, 19), (1, 19), (1, 18), (0, 18), (0, 17)], [(8, 13), (8, 12), (8, 11), (8, 10), (8, 9), (8, 8), (8, 7), (8, 6), (8, 5), (9, 5), (9, 4), (10, 4), (10, 3), (11, 3), (11, 2), (12, 2)]]


In [35]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import numpy as np
success = 0
actions = [[0,0],[-1,0],[0,1],[1,0],[0,-1]]
seed = 42
torch.manual_seed(seed)
with torch.inference_mode():
    for case in eval_cases:
        threshold = 0
        for path in case['paths']:
            threshold+= len(path)
        threshold = 3*threshold # from the paper
        num_of_steps = 0
        robots_reached = set()
        step_number = 0
        current_nodes = []
        path_robot_zero = []
        while num_of_steps <= threshold and len(robots_reached) <=6:
            paths = []
            goals = []
            for robot in range(num_of_robots):
                if step_number == 0:
                    current_node = case['paths'][robot][step_number]
                    current_nodes.append(list(current_node))
                    if robot == 0:
                        print(current_node)
                        print(list(case['paths'][robot][-1]))
                    paths.append([list(current_node)])
                    goals.append(list(case['goal_positions'][robot]))
                else:
                    paths.append([current_nodes[robot]])
                    goals.append(list(case['goal_positions'][robot]))
            # print(current_nodes[0])
            step_case = {"start_positions": current_nodes, "goal_positions": goals,"paths":paths}
            # print(paths)
            # print("start_preprocessing")

            p_new = Preprocessing(grid,[step_case],3)
            data_tensors_new  = p_new.begin()
            dataset = prepare_data(data_tensors_new,num_of_robots,[step_case],3)
            x = torch.tensor(np.array([tensor for tensor,adj_ in dataset]),dtype=torch.float32)
            x = x.view(num_of_robots,2,9,9)
            # print(np.array(x).shape)
            new_encoder = cnn(x)
            # new_encoded_tensors = new_encoder.begin()
            # print(new_encoded_tensors[0])
            # print("start adj")
            new_adj = adj_mat([step_case],3)
            new_adj_matrices = new_adj.get_adj_mat()
            # print("start gnn")
            new_comm_gnn = gnn(new_encoder,torch.tensor(np.array([adj_ for tensor,adj_ in dataset])).view(6,6)).view(1,num_of_robots,128)
            # print(new_comm_gnn.shape)
            future_nodes = list(range(0,num_of_robots))
            for robot in range(num_of_robots):
                if robot not in robots_reached:
                    # prediction = model.model(torch.tensor(reshaped_gnn_features[0][robot],dtype=torch.float32).unsqueeze(0).to(device))
                    prediction = mlp(torch.tensor(new_comm_gnn[0][robot],dtype=torch.float32).unsqueeze(0).to(device))
                    # print(prediction)
                    predicted_class = torch.argmax(prediction, dim=1)
                    right_state = False
                    next = [current_nodes[robot][0] + actions[predicted_class][0],current_nodes[robot][1] + actions[predicted_class][1]]
                    # future_nodes[robot] = (next[0],next[1])
                    if (next[0]>=20 or next[0]<0 or next[1]>=20 or next[1]<0):
                        # current_nodes[robot] = [current_nodes[robot][0] + actions[0][0],current_nodes[robot][1] + actions[0][1]]
                        pass
                    else:
                        current_nodes[robot] = [current_nodes[robot][0] + actions[predicted_class][0],current_nodes[robot][1] + actions[predicted_class][1]]
                    if (current_nodes[robot][0] == goals[robot][0]) and (current_nodes[robot][1] == goals[robot][1]):
                        robots_reached.add(robot)
                        print("trueeeeee")
            step_number = 1
            num_of_steps+=1
        if len(robots_reached) == num_of_robots:
            success += 1


(15, 5)
[12, 14]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee


  prediction = mlp(torch.tensor(new_comm_gnn[0][robot],dtype=torch.float32).unsqueeze(0).to(device))


(12, 9)
[13, 7]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(9, 0)
[11, 19]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(15, 3)
[14, 2]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(18, 14)
[14, 4]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(4, 5)
[16, 2]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(1, 1)
[8, 17]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(3, 4)
[19, 12]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(2, 14)
[17, 8]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(6, 19)
[10, 14]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(10, 9)
[6, 17]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(2, 9)
[17, 7]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(7, 4)
[9, 19]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(0, 3)
[14, 8]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(8, 3)
[15, 15]

In [36]:
print(success)

99


In [37]:
num_of_robots = 8
new_dataset_generator = DatasetGenerator(num_cases=100, num_agents=num_of_robots, grid=grid)
eval_cases = new_dataset_generator.generate_cases()

case  0
case added
case  1
case added
case  2
case added
case  3
case added
case  4
case added
case  5
case added
case  6
case added
case  7
case added
case  8
case added
case  9
case added
case  10
case added
case  11
case added
case  12
case added
case  13
case added
case  14
case added
case  15
case added
case  16
case added
case  17
case added
case  18
case added
case  19
case added
case  20
case added
case  21
case added
case  22
case added
case  23
case added
case  24
case added
case  25
case added
case  26
case added
case  27
case added
case  28
case added
case  29
case added
case  30
case added
case  31
case added
case  32
case added
case  33
case added
case  34
case added
case  35
case added
case  36
case added
case  37
case added
case  38
case added
case  39
case added
case  40
case added
case  41
case added
case  42
case added
case  43
case added
case  44
case added
case  45
case added
case  46
case added
case  47
case added
case  48
case added
case  49
case added
case  50
c

In [38]:
print(eval_cases[0]['start_positions'])
print(eval_cases[0]['goal_positions'])
print(eval_cases[0]['paths'])
# action_extractor = Action_Extractor(eval_cases, 6)
# actions = action_extractor.extract()
# print(actions)

[(11, 9), (17, 5), (14, 12), (14, 10), (8, 4), (7, 3), (13, 19), (4, 10)]
[(3, 10), (12, 14), (4, 7), (0, 9), (10, 19), (6, 5), (2, 6), (0, 13)]
[[(11, 9), (10, 9), (9, 9), (8, 9), (7, 9), (6, 9), (5, 9), (4, 9), (3, 9), (3, 10)], [(17, 5), (17, 6), (17, 7), (17, 8), (17, 9), (16, 9), (16, 10), (15, 10), (15, 11), (14, 11), (14, 12), (13, 12), (13, 13), (12, 13), (12, 14)], [(14, 12), (13, 12), (12, 12), (11, 12), (10, 12), (9, 12), (8, 12), (8, 11), (7, 11), (7, 10), (6, 10), (6, 9), (5, 9), (5, 8), (4, 8), (4, 7)], [(14, 10), (13, 10), (12, 10), (11, 10), (10, 10), (9, 10), (8, 10), (7, 10), (6, 10), (5, 10), (4, 10), (4, 9), (3, 9), (2, 9), (1, 9), (0, 9)], [(8, 4), (8, 5), (8, 6), (8, 7), (8, 8), (8, 9), (9, 9), (9, 10), (9, 11), (9, 12), (9, 13), (9, 14), (9, 15), (9, 16), (9, 17), (9, 18), (9, 19), (10, 19)], [(7, 3), (7, 4), (6, 4), (6, 5)], [(13, 19), (13, 18), (13, 17), (12, 17), (12, 16), (11, 16), (11, 15), (10, 15), (10, 14), (9, 14), (8, 14), (8, 13), (8, 12), (7, 12), (7,

In [39]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import numpy as np
success = 0
actions = [[0,0],[-1,0],[0,1],[1,0],[0,-1]]
seed = 42
torch.manual_seed(seed)
with torch.inference_mode():
    for case in eval_cases:
        threshold = 0
        for path in case['paths']:
            threshold+= len(path)
        threshold = 3*threshold # from the paper
        num_of_steps = 0
        robots_reached = set()
        step_number = 0
        current_nodes = []
        path_robot_zero = []
        while num_of_steps <= threshold and len(robots_reached) <=6:
            paths = []
            goals = []
            for robot in range(num_of_robots):
                if step_number == 0:
                    current_node = case['paths'][robot][step_number]
                    current_nodes.append(list(current_node))
                    if robot == 0:
                        print(current_node)
                        print(list(case['paths'][robot][-1]))
                    paths.append([list(current_node)])
                    goals.append(list(case['goal_positions'][robot]))
                else:
                    paths.append([current_nodes[robot]])
                    goals.append(list(case['goal_positions'][robot]))
            # print(current_nodes[0])
            step_case = {"start_positions": current_nodes, "goal_positions": goals,"paths":paths}
            # print(paths)
            # print("start_preprocessing")

            p_new = Preprocessing(grid,[step_case],3)
            data_tensors_new  = p_new.begin()
            dataset = prepare_data(data_tensors_new,num_of_robots,[step_case],3)
            x = torch.tensor(np.array([tensor for tensor,adj_ in dataset]),dtype=torch.float32)
            x = x.view(num_of_robots,2,9,9)
            # print(np.array(x).shape)
            new_encoder = cnn(x)
            # new_encoded_tensors = new_encoder.begin()
            # print(new_encoded_tensors[0])
            # print("start adj")
            new_adj = adj_mat([step_case],3)
            new_adj_matrices = new_adj.get_adj_mat()
            # print("start gnn")
            new_comm_gnn = gnn(new_encoder,torch.tensor(np.array([adj_ for tensor,adj_ in dataset])).view(num_of_robots,num_of_robots)).view(1,num_of_robots,128)
            # print(new_comm_gnn.shape)
            future_nodes = list(range(0,num_of_robots))
            for robot in range(num_of_robots):
                if robot not in robots_reached:
                    # prediction = model.model(torch.tensor(reshaped_gnn_features[0][robot],dtype=torch.float32).unsqueeze(0).to(device))
                    prediction = mlp(torch.tensor(new_comm_gnn[0][robot],dtype=torch.float32).unsqueeze(0).to(device))
                    # print(prediction)
                    predicted_class = torch.argmax(prediction, dim=1)
                    right_state = False
                    next = [current_nodes[robot][0] + actions[predicted_class][0],current_nodes[robot][1] + actions[predicted_class][1]]
                    # future_nodes[robot] = (next[0],next[1])
                    if (next[0]>=20 or next[0]<0 or next[1]>=20 or next[1]<0):
                        # current_nodes[robot] = [current_nodes[robot][0] + actions[0][0],current_nodes[robot][1] + actions[0][1]]
                        pass
                    else:
                        current_nodes[robot] = [current_nodes[robot][0] + actions[predicted_class][0],current_nodes[robot][1] + actions[predicted_class][1]]
                    if (current_nodes[robot][0] == goals[robot][0]) and (current_nodes[robot][1] == goals[robot][1]):
                        robots_reached.add(robot)
                        print("trueeeeee")
            step_number = 1
            num_of_steps+=1
        if len(robots_reached) == num_of_robots:
            success += 1


(11, 9)
[3, 10]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(1, 6)
[4, 17]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(12, 14)
[10, 6]
trueeeeee


  prediction = mlp(torch.tensor(new_comm_gnn[0][robot],dtype=torch.float32).unsqueeze(0).to(device))


trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(13, 9)
[2, 16]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(13, 8)
[5, 3]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(12, 5)
[13, 1]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(3, 19)
[2, 3]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(10, 18)
[18, 3]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(3, 7)
[9, 17]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(3, 17)
[6, 17]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(10, 18)
[6, 12]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(18, 8)
[15, 11]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(13, 9)
[12, 8]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(10, 12)
[12, 11]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
tr

In [40]:
print(success)

9


In [41]:
num_of_robots = 4
new_dataset_generator = DatasetGenerator(num_cases=100, num_agents=num_of_robots, grid=grid)
eval_cases = new_dataset_generator.generate_cases()

case  0
case added
case  1
case added
case  2
case added
case  3
case added
case  4
case added
case  5
case added
case  6
case added
case  7
case added
case  8
case added
case  9
case added
case  10
case added
case  11
case added
case  12
case added
case  13
case added
case  14
case added
case  15
case added
case  16
case added
case  17
case added
case  18
case added
case  19
case added
case  20
case added
case  21
case added
case  22
case added
case  23
case added
case  24
case added
case  25
case added
case  26
case added
case  27
case added
case  28
case added
case  29
case added
case  30
case added
case  31
case added
case  32
case added
case  33
case added
case  34
case added
case  35
case added
case  36
case added
case  37
case added
case  38
case added
case  39
case added
case  40
case added
case  41
case added
case  42
case added
case  43
case added
case  44
case added
case  45
case added
case  46
case added
case  47
case added
case  48
case added
case  49
case added
case  50
c

In [42]:
print(eval_cases[0]['start_positions'])
print(eval_cases[0]['goal_positions'])
print(eval_cases[0]['paths'])
# action_extractor = Action_Extractor(eval_cases, 6)
# actions = action_extractor.extract()
# print(actions)

[(1, 17), (9, 14), (14, 9), (7, 18)]
[(11, 15), (4, 9), (18, 6), (14, 16)]
[[(1, 17), (2, 17), (3, 17), (4, 17), (5, 17), (6, 17), (7, 17), (8, 17), (9, 17), (9, 16), (10, 16), (10, 15), (11, 15)], [(9, 14), (8, 14), (8, 13), (7, 13), (7, 12), (6, 12), (6, 11), (5, 11), (5, 10), (4, 10), (4, 9)], [(14, 9), (15, 9), (15, 8), (16, 8), (16, 7), (17, 7), (17, 6), (18, 6)], [(7, 18), (8, 18), (9, 18), (10, 18), (11, 18), (12, 18), (12, 17), (13, 17), (13, 16), (14, 16)]]


In [43]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import numpy as np
success = 0
actions = [[0,0],[-1,0],[0,1],[1,0],[0,-1]]
seed = 42
torch.manual_seed(seed)
with torch.inference_mode():
    for case in eval_cases:
        threshold = 0
        for path in case['paths']:
            threshold+= len(path)
        threshold = 3*threshold # from the paper
        num_of_steps = 0
        robots_reached = set()
        step_number = 0
        current_nodes = []
        path_robot_zero = []
        while num_of_steps <= threshold and len(robots_reached) <=6:
            paths = []
            goals = []
            for robot in range(num_of_robots):
                if step_number == 0:
                    current_node = case['paths'][robot][step_number]
                    current_nodes.append(list(current_node))
                    if robot == 0:
                        print(current_node)
                        print(list(case['paths'][robot][-1]))
                    paths.append([list(current_node)])
                    goals.append(list(case['goal_positions'][robot]))
                else:
                    paths.append([current_nodes[robot]])
                    goals.append(list(case['goal_positions'][robot]))
            # print(current_nodes[0])
            step_case = {"start_positions": current_nodes, "goal_positions": goals,"paths":paths}
            # print(paths)
            # print("start_preprocessing")

            p_new = Preprocessing(grid,[step_case],3)
            data_tensors_new  = p_new.begin()
            dataset = prepare_data(data_tensors_new,num_of_robots,[step_case],3)
            x = torch.tensor(np.array([tensor for tensor,adj_ in dataset]),dtype=torch.float32)
            x = x.view(num_of_robots,2,9,9)
            # print(np.array(x).shape)
            new_encoder = cnn(x)
            # new_encoded_tensors = new_encoder.begin()
            # print(new_encoded_tensors[0])
            # print("start adj")
            new_adj = adj_mat([step_case],3)
            new_adj_matrices = new_adj.get_adj_mat()
            # print("start gnn")
            new_comm_gnn = gnn(new_encoder,torch.tensor(np.array([adj_ for tensor,adj_ in dataset])).view(num_of_robots,num_of_robots)).view(1,num_of_robots,128)
            # print(new_comm_gnn.shape)
            future_nodes = list(range(0,num_of_robots))
            for robot in range(num_of_robots):
                if robot not in robots_reached:
                    # prediction = model.model(torch.tensor(reshaped_gnn_features[0][robot],dtype=torch.float32).unsqueeze(0).to(device))
                    prediction = mlp(torch.tensor(new_comm_gnn[0][robot],dtype=torch.float32).unsqueeze(0).to(device))
                    # print(prediction)
                    predicted_class = torch.argmax(prediction, dim=1)
                    right_state = False
                    next = [current_nodes[robot][0] + actions[predicted_class][0],current_nodes[robot][1] + actions[predicted_class][1]]
                    # future_nodes[robot] = (next[0],next[1])
                    if (next[0]>=20 or next[0]<0 or next[1]>=20 or next[1]<0):
                        # current_nodes[robot] = [current_nodes[robot][0] + actions[0][0],current_nodes[robot][1] + actions[0][1]]
                        pass
                    else:
                        current_nodes[robot] = [current_nodes[robot][0] + actions[predicted_class][0],current_nodes[robot][1] + actions[predicted_class][1]]
                    if (current_nodes[robot][0] == goals[robot][0]) and (current_nodes[robot][1] == goals[robot][1]):
                        robots_reached.add(robot)
                        print("trueeeeee")
            step_number = 1
            num_of_steps+=1
        if len(robots_reached) == num_of_robots:
            success += 1


(1, 17)
[11, 15]
trueeeeee
trueeeeee
trueeeeee
trueeeeee


  prediction = mlp(torch.tensor(new_comm_gnn[0][robot],dtype=torch.float32).unsqueeze(0).to(device))


(3, 5)
[3, 0]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(16, 2)
[18, 2]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(2, 14)
[16, 1]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(7, 16)
[7, 5]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(12, 17)
[16, 14]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(19, 17)
[14, 2]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(11, 17)
[5, 11]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(9, 19)
[4, 15]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(19, 14)
[12, 7]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(11, 9)
[0, 14]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(1, 13)
[2, 15]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(5, 7)
[5, 14]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(18, 12)
[2, 6]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(12, 8)
[9, 19]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(5, 12)
[16, 15]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(19, 16)
[15, 5]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(9, 5)
[13, 2]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(9, 8)
[6, 7]
trueeeeee
trueeeeee
trueeeeee
tr

In [44]:
print(success)

100
