In [1]:
import torch
from torch import nn
from time import sleep
import threading

In [24]:
class Box:
    """A node that runs ``net`` once, on its own thread, when all inputs are set.

    A worker thread is started by the constructor. It blocks until every slot
    in ``inputs_ready`` is True, then builds a tensor from ``inputs``, calls
    ``net`` on it, forwards the result to each registered listener, prints it,
    clears the ready flags and exits. Each Box therefore predicts at most once.

    Inputs must be delivered through ``set_input`` / ``set_inputs``: those
    methods wake the worker. Writing to ``inputs_ready`` directly does not.

    Args:
        net: Callable applied to the tensor returned by ``create_input_tensor``.
        id_: Identifier of this box; also the slot index this box writes to
            on each listener (see ``send_output``).
        input_size: Number of input slots that must be filled before predicting.
        name: Optional display name; defaults to ``"Id" + str(id_)``.
    """

    def __init__(self, net, id_ , input_size, name=None):
        self.name = "Id"+str(id_) if name is None else name
        self.id = id_
        self.net = net
        self.input_size = input_size
        self.listeners = []
        self.inputs = [0]*input_size
        self.inputs_ready = [False]*input_size
        # Guards inputs / inputs_ready and lets the worker sleep (instead of
        # spinning) until a setter signals that something changed.
        self._ready_cond = threading.Condition()
        # Started last so the worker only ever sees a fully initialised object.
        self.thread = threading.Thread(target=self._predict)
        self.thread.start()

    def add_listener(self, listener):
        """Register an object whose ``set_input(idx, value)`` receives the prediction."""
        self.listeners.append(listener)

    def set_input(self, idx, input_):
        """Store ``input_`` in slot ``idx``, mark that slot ready and wake the worker."""
        with self._ready_cond:
            self.inputs[idx] = input_
            self.inputs_ready[idx] = True
            self._ready_cond.notify_all()

    def set_inputs(self, inputs):
        """Replace all inputs at once, mark every slot ready and wake the worker.

        ``inputs`` is stored by reference (not copied).
        """
        with self._ready_cond:
            self.inputs = inputs
            self.inputs_ready = [True]*self.input_size
            self._ready_cond.notify_all()

    def create_input_tensor(self):
        """Return the current inputs as a ``torch.FloatTensor``."""
        return torch.FloatTensor(self.inputs)

    def send_output(self, listener, output):
        """Deliver ``output`` to ``listener`` in the slot numbered by this box's id."""
        #TODO: Send via IP or something here
        listener.set_input(self.id, output)

    def send_prediction(self, prediction):
        """Forward ``prediction`` to every registered listener."""
        for listener in self.listeners:
            self.send_output(listener, prediction)

    def _predict(self):
        """Worker body: wait for all inputs, predict once, notify listeners, exit."""
        with self._ready_cond:
            # Blocks without consuming CPU; re-evaluated whenever a setter notifies.
            self._ready_cond.wait_for(lambda: all(self.inputs_ready))
            # Snapshot while holding the lock so a concurrent setter cannot
            # change the inputs halfway through building the tensor.
            input_tensor = self.create_input_tensor()
        # Run the net and notify listeners outside the lock: listeners take
        # their own lock in set_input and we must not hold ours meanwhile.
        prediction = self.net(input_tensor)
        self.send_prediction(prediction)
        print(self.name + "\tPred: " + str(prediction))
        with self._ready_cond:
            self.inputs_ready = [False]*self.input_size

In [25]:
# Seed PyTorch's global RNG so the layer initialisation and the random inputs
# below are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Upper bound (seconds) on how long this cell waits for the worker threads.
JOIN_TIMEOUT_S = 10

# Downstream net: takes the two single-value predictions of net1 and net2.
boxnet = nn.Sequential(
          nn.Linear(2,1))

# Two identically shaped upstream nets: 5 features -> 3 hidden -> 1 output.
net1 = nn.Sequential(
          nn.Linear(5,3),
          nn.ReLU(),
          nn.Linear(3,1))

net2 = nn.Sequential(
          nn.Linear(5,3),
          nn.ReLU(),
          nn.Linear(3,1))

# One random 5-feature input per upstream net.
x1 = torch.rand(5)
x2 = torch.rand(5)


# Box ids 0 and 1 are also the slots each upstream box fills on `box`.
in1 = Box(net1, 0, 5)
in2 = Box(net2, 1, 5)
box = Box(boxnet, 2, 2)

in1.add_listener(box)
in2.add_listener(box)

# Supplying the inputs triggers the upstream predictions, which in turn
# fill both slots of `box` and trigger the final prediction.
in1.set_inputs(x1)
in2.set_inputs(x2)

# Predictions are printed from worker threads; wait for them so the output
# lands in this cell instead of appearing after it has finished.
for node in (in1, in2, box):
    node.thread.join(timeout=JOIN_TIMEOUT_S)

Id1	Pred: tensor([0.2417], grad_fn=<AddBackward0>)
Id0	Pred: tensor([-0.1062], grad_fn=<AddBackward0>)
Id2	Pred: tensor([0.4807], grad_fn=<AddBackward0>)
