In [None]:
from federated import agglomerate
import torch
from time import sleep
from server import move_left, move_forward, move_right, stop_robot
from server import display_camera_stream
from server import get_state, set_drive, get_weights, set_weights
import numpy as np

In [None]:
def x_y_to_action(x: float = -1.0, y: float = -1.0):
    """
    Convert a normalised target position into a (forward, turn) drive command.

    X and Y are expected to lie between 0 and 1; pass -1, -1 (the defaults)
    when no position is applicable.

    Args:
        x: Horizontal position. 0.5 produces no turn; the turn magnitude grows
            linearly to 1 at x == 0 or x == 1.
        y: Vertical position. 0.5 produces no forward speed; the forward speed
            grows linearly to 1 at y == 0 or y == 1.

    Returns:
        A tuple ``(forward_speed, signed_turn_speed)``. When either coordinate
        is negative the fixed command ``(0, 0.2)`` is returned instead.
    """
    # Negative coordinates are the "not applicable" sentinel.
    if x < 0 or y < 0:
        return (0, 0.2)

    # Sign of the turn: +1 when x is below 0.5, -1 otherwise.
    direction = 1 if x < 0.5 else -1
    turn_speed = abs(x - 0.5) / 0.5

    # NOTE(review): forward speed is symmetric around y == 0.5, so positions
    # above and below the middle give the same speed -- confirm this is intended.
    forward_speed = abs(y - 0.5) / 0.5

    return (forward_speed, turn_speed * direction)

In [None]:
def PSO(colour_cords: list[dict], colors):
    """
    Compute one drive action per robot, pulling the swarm towards the robot
    with the best view of the red target.

    Step 1: every robot gets the action derived from its own "red" detection
    via ``x_y_to_action``. The robot whose red detection lies closest to the
    point (0.5, 0.5) is selected as the best robot.

    Step 2: if at least one robot saw the red target and there is more than
    one robot, every other robot that can see the best robot's colour marker
    replaces its action with a random convex blend of its own action and the
    action that steers towards that marker.

    Args:
        colour_cords: One dict per robot. Each dict is indexed with "red" and
            with the entries of ``colors``; each value is a 2-element (x, y)
            sequence where a negative coordinate means "not detected".
        colors: Marker colour of each robot; ``colors[i]`` is the key used to
            look up robot ``i`` in the other robots' dicts.

    Returns:
        A list with one ``(forward, turn)`` tuple per robot, in input order.
    """
    acts = [(0, 0)] * len(colour_cords)
    best_robot = None
    best_distance = float("inf")

    for index, robot in enumerate(colour_cords):
        box_cords = robot["red"]
        acts[index] = x_y_to_action(*box_cords)

        if box_cords[0] >= 0 and box_cords[1] >= 0:
            # Rank robots by how far their red detection is from the centre;
            # a scalar distance gives a meaningful ordering, unlike comparing
            # signed offset tuples lexicographically.
            distance = float(np.hypot(0.5 - box_cords[0], 0.5 - box_cords[1]))
            if distance < best_distance:
                # Track the score together with the index so later robots are
                # compared against the current best, not the initial sentinel.
                best_distance = distance
                best_robot = index

    # Nothing to follow if nobody saw the target or there is only one robot.
    if best_robot is None or len(colour_cords) <= 1:
        return acts

    best_robot_colour = colors[best_robot]
    for index, robot in enumerate(colour_cords):
        if index == best_robot:
            continue
        robot_cords = robot[best_robot_colour]
        if robot_cords[0] >= 0 and robot_cords[1] >= 0:
            # Random weight in [0, 1): share of the robot's own action kept.
            mix = float(np.random.random())
            own_action = acts[index]
            follow_action = x_y_to_action(*robot_cords)
            acts[index] = tuple(
                mix * own + (1 - mix) * follow
                for own, follow in zip(own_action, follow_action)
            )

    return acts

In [None]:
# Test connection: query every robot 100 times and show the last round of states.
IPS = ["194.47.156.140", "194.47.156.22", "194.47.156.221"]
COLOURS = ["blue", "yellow", "pink"]
for _ in range(100):
    # One state per robot, in the same order as IPS.
    states = [get_state(robot_ip) for robot_ip in IPS]
states

In [None]:
# Control loop: for 20 steps, read every robot's state, ask PSO for one action
# per robot, drive for one second, then stop. With USE_FL enabled the robots'
# weights are additionally aggregated and redistributed, and the final
# aggregate is saved to "model.pth".
USE_FL = False
weights = [None] * len(IPS)

try:
    for i in range(0, 20):
        states = []
        for index, IP in enumerate(IPS):
            state = get_state(IP)
            print(f"{IP=} | {COLOURS[index]} | {state=}")
            states.append(state)

        decisions = PSO(states, COLOURS)

        for IP, (forward_cmd, turn_cmd) in zip(IPS, decisions):
            # Limit each component to its own range before sending it.
            forward = float(np.clip(forward_cmd, 0.1, 0.3))
            turn = float(np.clip(turn_cmd, -0.5, 0.5))
            set_drive(IP, forward, turn)

        # Let the robots drive for one second before halting them again.
        sleep(1)

        for IP in IPS:
            stop_robot(IP)

        # NOTE(review): `i % 10 == 0` is also true at i == 0, so weights are
        # aggregated right after the very first step -- confirm this is intended.
        if USE_FL and i % 10 == 0:
            for index, IP in enumerate(IPS):
                weights[index] = get_weights(IP)

            agg_weight = agglomerate(weights)

            for IP in IPS:
                set_weights(IP, agg_weight)
finally:
    # Always halt the robots, even when a request failed or the cell was
    # interrupted mid-step. A robot that cannot be reached must not prevent
    # the remaining robots from being stopped, so failures are reported and
    # the loop continues.
    for IP in IPS:
        try:
            stop_robot(IP)
        except Exception as exc:
            print(f"Failed to stop {IP}: {exc!r}")

if USE_FL:
    for index, IP in enumerate(IPS):
        weights[index] = get_weights(IP)

    agg_weight = agglomerate(weights)
    torch.save(agg_weight, "model.pth")

In [None]:
# Display the current value of `states` (assigned by the polling / control-loop cells).
states

In [None]:
# Manual stop cell: send a stop command to every robot listed in IPS.
for robot_ip in IPS:
    stop_robot(robot_ip)