In [None]:
from Dataset_Generator import DatasetGenerator
from Encoder import ObservationEncoder
from GNN_file import PaperGNN
from Extarct_Actions import Action_Extractor
from Adjacency_Matrix import adj_mat
from pre_processing import Preprocessing
from MLP_Action import ActionMLP
from TR6 import decentralizedModel

import torch
import torch
import numpy as np
import random

seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for multi-GPU setups

# Optional: Force deterministic behavior on GPU (may impact performance)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
grid = [[0 for _ in range(20)] for _ in range(20)]  # Generate grid map
dataset_generator = DatasetGenerator(num_cases=5000, num_agents=6, grid=grid)
# cases = dataset_generator.generate_cases()
# dataset_generator.save_cases_to_file(cases, "dataset.json")
# print(f"Generated and saved {len(cases)} cases.")
cases = dataset_generator.load_cases_from_file("dataset.json")

In [None]:
p = Preprocessing(grid,cases,3)
data_tensors  = p.begin()

In [None]:
adj = adj_mat(cases,3)
adj_matrices = adj.get_adj_mat()

In [31]:
trained_model = decentralizedModel(cases, data_tensors,rfov = 3,adjacency_matrix = adj_matrices,num_of_robots = 6, k_hops = 2, num_epochs = 30)

No checkpoint found, starting training from scratch.


In [32]:
if torch.cuda.is_available():
    trained_model.to(device="cuda")

In [33]:
trained_model.train_model()

Epoch [1/30] Train Loss: 1.0030 Val Loss: 0.9697 Val Acc: 93.40%
Checkpoint saved at epoch 1.
Epoch [2/30] Train Loss: 0.9616 Val Loss: 0.9771 Val Acc: 93.34%
Checkpoint saved at epoch 2.
Epoch [3/30] Train Loss: 0.9581 Val Loss: 0.9727 Val Acc: 93.37%
Checkpoint saved at epoch 3.
Epoch [4/30] Train Loss: 0.9556 Val Loss: 1.1070 Val Acc: 79.49%
Checkpoint saved at epoch 4.
Epoch [5/30] Train Loss: 0.9540 Val Loss: 0.9788 Val Acc: 93.00%
Checkpoint saved at epoch 5.
Epoch [6/30] Train Loss: 0.9526 Val Loss: 0.9839 Val Acc: 92.05%
Checkpoint saved at epoch 6.
Epoch [7/30] Train Loss: 0.9518 Val Loss: 0.9946 Val Acc: 90.83%
Checkpoint saved at epoch 7.
Epoch [8/30] Train Loss: 0.9507 Val Loss: 0.9559 Val Acc: 94.78%
Checkpoint saved at epoch 8.
Epoch [9/30] Train Loss: 0.9505 Val Loss: 0.9692 Val Acc: 93.48%
Checkpoint saved at epoch 9.
Epoch [10/30] Train Loss: 0.9497 Val Loss: 0.9710 Val Acc: 93.33%
Checkpoint saved at epoch 10.
Epoch [11/30] Train Loss: 0.9491 Val Loss: 0.9916 Val Acc:

In [34]:
trained_model.test_model()

Test Loss: 0.9430, Test Accuracy: 96.17%


(0.9430361258378928, 0.961682895531401)

In [None]:
def prepare_data(tensors,num_of_robots,cases,rfov):
        dataset = []
        adj_obj = adj_mat(cases,rfov)
        adj_cases = adj_obj.get_adj_mat()
        for itr in range(len(tensors.keys())):
            # first_channels = self.tensors[itr]['channel 1']
            second_channels = tensors[itr]['channel 2']
            third_channels = tensors[itr]['channel 3']

            for step in range(len(second_channels)):
                batch_channels = []
                for i in range(num_of_robots):
                    agent_channels = np.stack([
                        second_channels[step][i],
                        third_channels[step][i]
                    ])
                    batch_channels.append(agent_channels)
                batch_tensor = torch.tensor(np.array(batch_channels), dtype=torch.float32)
                adj_step = torch.tensor(np.array(adj_cases[itr][step]),dtype=torch.float32)
                dataset.append((batch_tensor,adj_step))
        return dataset

In [None]:
device = torch.device("cpu" if torch.cuda.is_available() else "cpu")
cnn = ObservationEncoder().to(device=device)
gnn = PaperGNN(K = 2).to(device=device)
mlp = ActionMLP().to(device=device)
checkpoint = torch.load('TR6_96_acc.pth', map_location=device)
cnn.load_state_dict(checkpoint['cnn_state_dict'])
gnn.load_state_dict(checkpoint['gnn_state_dict'])
mlp.load_state_dict(checkpoint['mlp_state_dict'])
cnn.eval()
gnn.eval()
mlp.eval()

ActionMLP(
  (fc1): Linear(in_features=128, out_features=128, bias=True)
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=128, out_features=5, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [37]:
num_of_robots = 6
new_dataset_generator = DatasetGenerator(num_cases=100, num_agents=num_of_robots, grid=grid)
eval_cases = new_dataset_generator.generate_cases()

case  0
case added
case  1
case added
case  2
case added
case  3
case added
case  4
case added
case  5
case added
case  6
case added
case  7
case added
case  8
case added
case  9
case added
case  10
case added
case  11
case added
case  12
case added
case  13
case added
case  14
case added
case  15
case added
case  16
case added
case  17
case added
case  18
case added
case  19
case added
case  20
case added
case  21
case added
case  22
case added
case  23
case added
case  24
case added
case  25
case added
case  26
case added
case  27
case added
case  28
case added
case  29
case added
case  30
case added
case  31
case added
case  32
case added
case  33
case added
case  34
case added
case  35
case added
case  36
case added
case  37
case added
case  38
case added
case  39
case added
case  40
case added
case  41
case added
case  42
case added
case  43
case added
case  44
case added
case  45
case added
case  46
case added
case  47
case added
case  48
case added
case  49
case added
case  50
c

In [38]:
print(eval_cases[0]['start_positions'])
print(eval_cases[0]['goal_positions'])
print(eval_cases[0]['paths'])
# action_extractor = Action_Extractor(eval_cases, 6)
# actions = action_extractor.extract()
# print(actions)

[(0, 10), (5, 4), (11, 14), (4, 11), (18, 2), (7, 0)]
[(0, 0), (6, 14), (10, 5), (18, 19), (9, 10), (10, 16)]
[[(0, 10), (0, 9), (0, 8), (0, 7), (0, 6), (0, 5), (0, 4), (0, 3), (0, 2), (0, 1), (0, 0)], [(5, 4), (5, 5), (5, 6), (5, 7), (5, 8), (5, 9), (5, 10), (5, 11), (5, 12), (5, 13), (5, 14), (6, 14)], [(11, 14), (11, 13), (11, 12), (11, 11), (11, 10), (11, 9), (11, 8), (11, 7), (11, 6), (10, 6), (10, 5)], [(4, 11), (5, 11), (6, 11), (7, 11), (8, 11), (9, 11), (10, 11), (10, 12), (11, 12), (11, 13), (12, 13), (12, 14), (13, 14), (13, 15), (14, 15), (14, 16), (15, 16), (15, 17), (16, 17), (16, 18), (17, 18), (17, 19), (18, 19)], [(18, 2), (17, 2), (16, 2), (16, 3), (15, 3), (15, 4), (14, 4), (14, 5), (13, 5), (13, 6), (12, 6), (12, 7), (11, 7), (11, 8), (10, 8), (10, 9), (9, 9), (9, 10)], [(7, 0), (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (7, 10), (7, 11), (7, 12), (7, 13), (7, 14), (8, 14), (8, 15), (9, 15), (9, 16), (10, 16)]]


In [39]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import numpy as np
success = 0
actions = [[0,0],[-1,0],[0,1],[1,0],[0,-1]]
seed = 42
torch.manual_seed(seed)
with torch.inference_mode():
    for case in eval_cases:
        threshold = 0
        for path in case['paths']:
            threshold+= len(path)
        threshold = 3*threshold # from the paper
        num_of_steps = 0
        robots_reached = set()
        step_number = 0
        current_nodes = []
        path_robot_zero = []
        while num_of_steps <= threshold and len(robots_reached) <=6:
            paths = []
            goals = []
            for robot in range(num_of_robots):
                if step_number == 0:
                    current_node = case['paths'][robot][step_number]
                    current_nodes.append(list(current_node))
                    if robot == 0:
                        print(current_node)
                        print(list(case['paths'][robot][-1]))
                    paths.append([list(current_node)])
                    goals.append(list(case['goal_positions'][robot]))
                else:
                    paths.append([current_nodes[robot]])
                    goals.append(list(case['goal_positions'][robot]))
            # print(current_nodes[0])
            step_case = {"start_positions": current_nodes, "goal_positions": goals,"paths":paths}
            # print(paths)
            # print("start_preprocessing")

            p_new = Preprocessing(grid,[step_case],3)
            data_tensors_new  = p_new.begin()
            dataset = prepare_data(data_tensors_new,num_of_robots,[step_case],3)
            x = torch.tensor(np.array([tensor for tensor,adj_ in dataset]),dtype=torch.float32)
            x = x.view(num_of_robots,2,9,9)
            # print(np.array(x).shape)
            new_encoder = cnn(x)
            # new_encoded_tensors = new_encoder.begin()
            # print(new_encoded_tensors[0])
            # print("start adj")
            new_adj = adj_mat([step_case],3)
            new_adj_matrices = new_adj.get_adj_mat()
            # print("start gnn")
            new_comm_gnn = gnn(new_encoder,torch.tensor(np.array([adj_ for tensor,adj_ in dataset])).view(6,6)).view(1,num_of_robots,128)
            # print(new_comm_gnn.shape)
            future_nodes = list(range(0,num_of_robots))
            for robot in range(num_of_robots):
                if robot not in robots_reached:
                    # prediction = model.model(torch.tensor(reshaped_gnn_features[0][robot],dtype=torch.float32).unsqueeze(0).to(device))
                    prediction = mlp(torch.tensor(new_comm_gnn[0][robot],dtype=torch.float32).unsqueeze(0).to(device))
                    # print(prediction)
                    predicted_class = torch.argmax(prediction, dim=1)
                    right_state = False
                    next = [current_nodes[robot][0] + actions[predicted_class][0],current_nodes[robot][1] + actions[predicted_class][1]]
                    # future_nodes[robot] = (next[0],next[1])
                    if (next[0]>=20 or next[0]<0 or next[1]>=20 or next[1]<0):
                        # current_nodes[robot] = [current_nodes[robot][0] + actions[0][0],current_nodes[robot][1] + actions[0][1]]
                        pass
                    else:
                        current_nodes[robot] = [current_nodes[robot][0] + actions[predicted_class][0],current_nodes[robot][1] + actions[predicted_class][1]]
                    if (current_nodes[robot][0] == goals[robot][0]) and (current_nodes[robot][1] == goals[robot][1]):
                        robots_reached.add(robot)
                        print("trueeeeee")
            step_number = 1
            num_of_steps+=1
        if len(robots_reached) == num_of_robots:
            success += 1


(0, 10)
[0, 0]
trueeeeee
trueeeeee
trueeeeee


  prediction = mlp(torch.tensor(new_comm_gnn[0][robot],dtype=torch.float32).unsqueeze(0).to(device))


trueeeeee
trueeeeee
trueeeeee
(4, 14)
[10, 3]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(13, 6)
[1, 19]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(7, 0)
[7, 5]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(1, 16)
[5, 11]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(3, 19)
[16, 0]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(0, 12)
[0, 6]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(3, 15)
[13, 11]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(3, 14)
[13, 2]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(16, 18)
[10, 9]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(14, 5)
[11, 3]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(13, 3)
[17, 14]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(15, 12)
[7, 2]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(4, 15)
[19, 8]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
tr

In [40]:
print(success)

93


In [41]:
num_of_robots = 4
new_dataset_generator = DatasetGenerator(num_cases=100, num_agents=num_of_robots, grid=grid)
eval_cases = new_dataset_generator.generate_cases()

case  0
case added
case  1
case added
case  2
case added
case  3
case added
case  4
case added
case  5
case added
case  6
case added
case  7
case added
case  8
case added
case  9
case added
case  10
case added
case  11
case added
case  12
case added
case  13
case added
case  14
case added
case  15
case added
case  16
case added
case  17
case added
case  18
case added
case  19
case added
case  20
case added
case  21
case added
case  22
case added
case  23
case added
case  24
case added
case  25
case added
case  26
case added
case  27
case added
case  28
case added
case  29
case added
case  30
case added
case  31
case added
case  32
case added
case  33
case added
case  34
case added
case  35
case added
case  36
case added
case  37
case added
case  38
case added
case  39
case added
case  40
case added
case  41
case added
case  42
case added
case  43
case added
case  44
case added
case  45
case added
case  46
case added
case  47
case added
case  48
case added
case  49
case added
case  50
c

In [42]:
print(eval_cases[0]['start_positions'])
print(eval_cases[0]['goal_positions'])
print(eval_cases[0]['paths'])
# action_extractor = Action_Extractor(eval_cases, 6)
# actions = action_extractor.extract()
# print(actions)

[(17, 7), (9, 18), (3, 16), (17, 15)]
[(8, 14), (8, 18), (5, 13), (14, 3)]
[[(17, 7), (16, 7), (15, 7), (14, 7), (14, 8), (13, 8), (13, 9), (12, 9), (12, 10), (11, 10), (11, 11), (10, 11), (10, 12), (9, 12), (9, 13), (8, 13), (8, 14)], [(9, 18), (8, 18)], [(3, 16), (3, 15), (3, 14), (4, 14), (4, 13), (5, 13)], [(17, 15), (17, 14), (17, 13), (17, 12), (17, 11), (17, 10), (17, 9), (17, 8), (17, 7), (17, 6), (16, 6), (16, 5), (15, 5), (15, 4), (14, 4), (14, 3)]]


In [43]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import numpy as np
success = 0
actions = [[0,0],[-1,0],[0,1],[1,0],[0,-1]]
seed = 42
torch.manual_seed(seed)
with torch.inference_mode():
    for case in eval_cases:
        threshold = 0
        for path in case['paths']:
            threshold+= len(path)
        threshold = 3*threshold # from the paper
        num_of_steps = 0
        robots_reached = set()
        step_number = 0
        current_nodes = []
        path_robot_zero = []
        while num_of_steps <= threshold and len(robots_reached) <=6:
            paths = []
            goals = []
            for robot in range(num_of_robots):
                if step_number == 0:
                    current_node = case['paths'][robot][step_number]
                    current_nodes.append(list(current_node))
                    if robot == 0:
                        print(current_node)
                        print(list(case['paths'][robot][-1]))
                    paths.append([list(current_node)])
                    goals.append(list(case['goal_positions'][robot]))
                else:
                    paths.append([current_nodes[robot]])
                    goals.append(list(case['goal_positions'][robot]))
            # print(current_nodes[0])
            step_case = {"start_positions": current_nodes, "goal_positions": goals,"paths":paths}
            # print(paths)
            # print("start_preprocessing")

            p_new = Preprocessing(grid,[step_case],3)
            data_tensors_new  = p_new.begin()
            dataset = prepare_data(data_tensors_new,num_of_robots,[step_case],3)
            x = torch.tensor(np.array([tensor for tensor,adj_ in dataset]),dtype=torch.float32)
            x = x.view(num_of_robots,2,9,9)
            # print(np.array(x).shape)
            new_encoder = cnn(x)
            # new_encoded_tensors = new_encoder.begin()
            # print(new_encoded_tensors[0])
            # print("start adj")
            new_adj = adj_mat([step_case],3)
            new_adj_matrices = new_adj.get_adj_mat()
            # print("start gnn")
            new_comm_gnn = gnn(new_encoder,torch.tensor(np.array([adj_ for tensor,adj_ in dataset])).view(num_of_robots,num_of_robots)).view(1,num_of_robots,128)
            # print(new_comm_gnn.shape)
            future_nodes = list(range(0,num_of_robots))
            for robot in range(num_of_robots):
                if robot not in robots_reached:
                    # prediction = model.model(torch.tensor(reshaped_gnn_features[0][robot],dtype=torch.float32).unsqueeze(0).to(device))
                    prediction = mlp(torch.tensor(new_comm_gnn[0][robot],dtype=torch.float32).unsqueeze(0).to(device))
                    # print(prediction)
                    predicted_class = torch.argmax(prediction, dim=1)
                    right_state = False
                    next = [current_nodes[robot][0] + actions[predicted_class][0],current_nodes[robot][1] + actions[predicted_class][1]]
                    # future_nodes[robot] = (next[0],next[1])
                    if (next[0]>=20 or next[0]<0 or next[1]>=20 or next[1]<0):
                        # current_nodes[robot] = [current_nodes[robot][0] + actions[0][0],current_nodes[robot][1] + actions[0][1]]
                        pass
                    else:
                        current_nodes[robot] = [current_nodes[robot][0] + actions[predicted_class][0],current_nodes[robot][1] + actions[predicted_class][1]]
                    if (current_nodes[robot][0] == goals[robot][0]) and (current_nodes[robot][1] == goals[robot][1]):
                        robots_reached.add(robot)
                        print("trueeeeee")
            step_number = 1
            num_of_steps+=1
        if len(robots_reached) == num_of_robots:
            success += 1


(17, 7)
[8, 14]
trueeeeee
trueeeeee
trueeeeee
trueeeeee


  prediction = mlp(torch.tensor(new_comm_gnn[0][robot],dtype=torch.float32).unsqueeze(0).to(device))


(7, 5)
[3, 11]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(9, 2)
[17, 5]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(11, 12)
[18, 7]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(0, 3)
[6, 5]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(1, 19)
[3, 0]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(14, 16)
[1, 19]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(15, 6)
[13, 4]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(17, 3)
[8, 0]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(17, 11)
[18, 8]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(1, 3)
[13, 17]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(14, 0)
[7, 10]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(6, 18)
[0, 4]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(10, 13)
[15, 5]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(8, 2)
[4, 10]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(8, 4)
[3, 2]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(16, 11)
[11, 16]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(11, 9)
[3, 0]
trueeeeee
trueeeeee
trueeeeee
trueeeeee
(12, 17)
[2, 7]
trueeeeee
trueeeeee
trueeeeee
trueeee

In [44]:
print(success)

99
