In [5]:
import torch
import itertools 
inputs = list(itertools.product([-1, 1], repeat=4))

inputs

[(-1, -1, -1, -1),
 (-1, -1, -1, 1),
 (-1, -1, 1, -1),
 (-1, -1, 1, 1),
 (-1, 1, -1, -1),
 (-1, 1, -1, 1),
 (-1, 1, 1, -1),
 (-1, 1, 1, 1),
 (1, -1, -1, -1),
 (1, -1, -1, 1),
 (1, -1, 1, -1),
 (1, -1, 1, 1),
 (1, 1, -1, -1),
 (1, 1, -1, 1),
 (1, 1, 1, -1),
 (1, 1, 1, 1)]

In [6]:
all_possible_labels  = []
for i in itertools.product([-1, 1], repeat=16):
    all_possible_labels.append(torch.tensor(i, dtype=torch.float32))

In [7]:
all_possible_labels[100]

tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1.,  1.,  1., -1., -1.,  1.,
        -1., -1.])

In [16]:
class Lut2(torch.nn.Module):
    def __init__(self):
        super(Lut2, self).__init__()
        self.fc1 = torch.nn.Linear(3, 1)

    def forward(self, x):
        x1, x2 = x
        x = torch.tensor([x1, x2, x1*x2], dtype=torch.float32)
        x = x.unsqueeze(0)
        x = self.fc1(x)
        return x

class Lut4(torch.nn.Module):
    def __init__(self):
        super(Lut4, self).__init__()
        self.fc1 = torch.nn.Linear(15, 1)

    def forward(self, x):
        x1, x2, x3, x4 = x
        x = torch.tensor([x1, x2, x3, x4, 
                          x1*x2, x1*x3, x1*x4, x2*x3, x2*x4, x3*x4,
                          x1*x2*x3, x1*x2*x4, x1*x3*x4, x2*x3*x4,
                          x1*x2*x3*x4], dtype=torch.float32)
        x = x.unsqueeze(0)
        x = self.fc1(x)
        return x

In [11]:

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.lut1 = Lut2()
        self.lut2 = Lut2()
        self.lut3 = Lut2()
        self.debug = False

    def forward(self, x):
        x1, x2, x3, x4 = x
        y1 = self.lut1([x1, x2])
        y2 = self.lut2([x3, x4])
        y1 = torch.where(y1 > 0, 1, -1)
        y2 = torch.where(y2 > 0, 1, -1)
        if self.debug:
            print(y1, y2)
        return self.lut3([y1, y2])

In [17]:
import random

labels_models_map = {}

labels_to_train = random.sample(all_possible_labels, 50)


for i in labels_to_train:
    model = Lut4()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.003)
    for epoch in range(1000):
        for j in range(len(inputs)):
            optimizer.zero_grad()
            outputs = model(inputs[j])
            loss = criterion(outputs, i[j])
            loss.backward()
            optimizer.step()
        if epoch % 100 == 0:
            print(f'Epoch {epoch} Loss: {loss.item()}')

    print(f'Final Loss: {loss.item()}')
    labels_models_map[i] = model

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 0 Loss: 0.33074048161506653
Epoch 100 Loss: 5.403677505455562e-10
Epoch 200 Loss: 2.7853275241795927e-12
Epoch 300 Loss: 2.7853275241795927e-12
Epoch 400 Loss: 2.7853275241795927e-12
Epoch 500 Loss: 2.7853275241795927e-12
Epoch 600 Loss: 2.7853275241795927e-12
Epoch 700 Loss: 2.7853275241795927e-12
Epoch 800 Loss: 2.7853275241795927e-12
Epoch 900 Loss: 2.7853275241795927e-12
Final Loss: 2.7853275241795927e-12
Epoch 0 Loss: 0.5481227040290833
Epoch 100 Loss: 9.459313332627062e-10
Epoch 200 Loss: 1.566746732351021e-12
Epoch 300 Loss: 1.566746732351021e-12
Epoch 400 Loss: 1.566746732351021e-12
Epoch 500 Loss: 1.566746732351021e-12
Epoch 600 Loss: 1.566746732351021e-12
Epoch 700 Loss: 1.566746732351021e-12
Epoch 800 Loss: 1.566746732351021e-12
Epoch 900 Loss: 1.566746732351021e-12
Final Loss: 1.566746732351021e-12
Epoch 0 Loss: 0.24808713793754578
Epoch 100 Loss: 4.228617456192296e-10
Epoch 200 Loss: 2.2737367544323206e-13
Epoch 300 Loss: 2.2737367544323206e-13
Epoch 400 Loss: 2.2737

KeyboardInterrupt: 

In [18]:

def predict(model):
    predictions = []
    for i in range(16):
        x = inputs[i]
        output = model(x)
        predictions.append(output.item())   
    return predictions

# for labels, model in labels_models_map.items():
#     predictions = predict(model)
#     if predictions == list(labels):
#         print(f'found model for labels {labels}')
#         break
    

In [19]:
def visualise_binary_result(x):
    out = ""
    for i in x:
        if i == 1:
            out += "1"
        else:
            out += "0"
    return out

def count_errors(x, y):
    errors = 0
    for i in range(len(x)):
        if x[i] != y[i]:
            errors += 1
    return errors

for labels, model in labels_models_map.items():
    predictions = predict(model)
    binary_labels = [l == 1 for l in labels]
    binary_predictions = [p > 0 for p in predictions]
    if binary_labels == binary_predictions:
        print(f'found model for labels {labels}')
    else:
        print("-"*30)
        print(visualise_binary_result(binary_labels))
        print(visualise_binary_result(binary_predictions))
        print(count_errors(binary_labels, binary_predictions))
    

found model for labels tensor([-1.,  1., -1., -1., -1.,  1., -1.,  1., -1., -1., -1., -1., -1., -1.,
        -1., -1.])
found model for labels tensor([-1., -1.,  1.,  1., -1., -1., -1., -1.,  1.,  1., -1., -1., -1., -1.,
        -1.,  1.])
found model for labels tensor([ 1.,  1.,  1.,  1., -1., -1., -1., -1., -1.,  1., -1.,  1.,  1.,  1.,
         1., -1.])
found model for labels tensor([-1., -1., -1., -1., -1.,  1.,  1.,  1., -1., -1.,  1., -1., -1., -1.,
        -1.,  1.])
found model for labels tensor([-1.,  1., -1., -1.,  1.,  1., -1., -1., -1., -1., -1.,  1., -1., -1.,
        -1., -1.])
found model for labels tensor([ 1., -1., -1.,  1.,  1.,  1.,  1.,  1., -1.,  1., -1., -1., -1., -1.,
        -1., -1.])
found model for labels tensor([ 1., -1.,  1.,  1.,  1., -1., -1.,  1., -1., -1., -1.,  1., -1.,  1.,
        -1., -1.])
found model for labels tensor([-1., -1.,  1.,  1., -1., -1.,  1.,  1., -1.,  1.,  1., -1.,  1., -1.,
        -1.,  1.])
found model for labels tensor([ 1.,  1.,

In [15]:
for labels, model in labels_models_map.items():
    model.debug = True
    predictions = predict(model)
    break


tensor([[-1]]) tensor([[-1]])
tensor([[-1]]) tensor([[-1]])
tensor([[-1]]) tensor([[-1]])
tensor([[-1]]) tensor([[1]])
tensor([[-1]]) tensor([[-1]])
tensor([[-1]]) tensor([[-1]])
tensor([[-1]]) tensor([[-1]])
tensor([[-1]]) tensor([[1]])
tensor([[1]]) tensor([[-1]])
tensor([[1]]) tensor([[-1]])
tensor([[1]]) tensor([[-1]])
tensor([[1]]) tensor([[1]])
tensor([[1]]) tensor([[-1]])
tensor([[1]]) tensor([[-1]])
tensor([[1]]) tensor([[-1]])
tensor([[1]]) tensor([[1]])


In [19]:
model((-1, 0.3))

tensor([[-0.3000]], grad_fn=<AddmmBackward0>)