In [364]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np

In [365]:
# --- Data-generation settings -------------------------------------------
# When True, the train/test CSV files are (re)written by the guarded block
# at the end of this cell.
GENERATE_DATA = False
# Number of training rows to generate.
NUM_TRAIN_DATA = 10_000
# Size of the test set expressed as a fraction of the training set.
RATIO_TEST_TRAIN = 0.25
# Start coordinates are drawn from [0, STATE_UPPER_LIM].
STATE_UPPER_LIM = 1000

# Transition matrices applied as `state @ ACTIONS[action]` to a (1, 2) row:
#   index 0 maps [a, b] -> [a, b - a]
#   index 1 maps [a, b] -> [b, a]
ACTIONS = (
    torch.tensor([[1, -1], [0, 1]], dtype=torch.float),
    torch.tensor([[0, 1], [1, 0]], dtype=torch.float),
)


def gen_start_config():
    """Sample a random start state.

    Returns:
        A float tensor of shape (1, 2) whose entries are uniform random
        values scaled by ``STATE_UPPER_LIM`` and rounded to whole numbers.
    """
    # TODO make a better stochastic thing
    scaled = torch.rand((1, 2)) * STATE_UPPER_LIM
    return scaled.round().float()

# predefined heuristic
def get_action(state):
    """Pick the action index for ``state`` using a fixed rule.

    Args:
        state: Indexable as ``state[0][0]`` and ``state[0][1]`` (e.g. a
            (1, 2) tensor).

    Returns:
        1 when the first coordinate is greater than or equal to the second,
        otherwise 0.
    """
    # NOTE(review): with equal coordinates this returns 1, and ACTIONS[1]
    # swaps the coordinates, which leaves such a state unchanged -- confirm
    # that `>=` (rather than `>`) is intended if this drives a rollout loop.
    first, second = state[0][0], state[0][1]
    return 1 if first >= second else 0


def update_state(state, action):
    """Apply the transition selected by ``action`` to ``state``.

    Args:
        state: Row state that is right-multiplied by the transition matrix.
        action: Index into the module-level ``ACTIONS`` tuple.

    Returns:
        ``state @ ACTIONS[action]``.
    """
    transition = ACTIONS[action]
    return state @ transition


def terminate(state, eps=1e-3):
    """Report whether ``state`` has (numerically) reached a zero coordinate.

    Args:
        state: Indexable as ``state[0][0]`` and ``state[0][1]`` (e.g. a
            (1, 2) tensor).
        eps: Absolute tolerance; a coordinate counts as zero when its
            magnitude is strictly below this value. Defaults to the
            previously hard-coded 1e-3.

    Returns:
        A truthy value when either coordinate is within ``eps`` of zero,
        falsy otherwise. For tensor input this is a 0-dim bool tensor.
    """
    first, second = state[0][0], state[0][1]
    return abs(first) < eps or abs(second) < eps


def gen_example_data(filename, n):
    """Write ``n`` labelled examples to ``filename`` as CSV rows.

    Each row has the form ``a,b,action`` where ``a`` and ``b`` are the
    integer-truncated coordinates of a state from ``gen_start_config()``
    and ``action`` is the label returned by ``get_action`` for that state.

    Args:
        filename: Destination path; an existing file is overwritten.
        n: Number of rows to write.
    """
    import os

    # Create the target directory up front so a fresh checkout (without
    # e.g. train_data/) does not fail with FileNotFoundError on open().
    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(filename, 'w') as f:
        for _ in range(n):
            state = gen_start_config()
            action = get_action(state)
            f.write(f"{int(state[0][0])},{int(state[0][1])},{action}\n")
        
        
# Regenerate the CSV datasets only when GENERATE_DATA is switched on.
# NOTE(review): the files written here are train.csv / test.csv, whereas the
# loading cell further down reads train_simple.csv / test_simple.csv --
# confirm which pair is the intended dataset.
if GENERATE_DATA:
    fname = "train_data/train.csv"
    gen_example_data(fname, NUM_TRAIN_DATA)
    fname = "test_data/test.csv"
    # The test set holds RATIO_TEST_TRAIN times as many rows as the train set.
    gen_example_data(fname, int(NUM_TRAIN_DATA * RATIO_TEST_TRAIN))

In [366]:
class NN(nn.Module):
    """Fully connected binary classifier over 2-feature inputs.

    The network is a ReLU MLP with layer widths 2 -> 128 -> 128 -> 64 ->
    16 -> 1, followed by a sigmoid, so ``forward`` yields one value in
    (0, 1) per input row.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

        # Consecutive pairs of widths define the Linear layers; a ReLU sits
        # between every pair of Linear layers but not after the last one.
        widths = (2, 128, 128, 64, 16, 1)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # drop the ReLU that would follow the output layer
        self.stack = nn.Sequential(*layers)

        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid-activated output of the MLP for batch ``x``."""
        logits = self.stack(self.flatten(x))
        return self.activation(logits)
    

In [367]:
def get_data(fname):
    """Load a comma-separated numeric file into feature and label tensors.

    Args:
        fname: Path of a CSV file with at least two feature columns; the
            last column is used as the label.

    Returns:
        A tuple ``(X, Y)`` of float tensors: ``X`` holds the first two
        columns with shape (rows, 2) and ``Y`` holds the last column with
        shape (rows,).
    """
    # ndmin=2 keeps a single-row file two-dimensional; without it loadtxt
    # returns a 1-D array and the column slicing below raises IndexError.
    table = np.loadtxt(fname, delimiter=",", ndmin=2)
    x = torch.tensor(table, dtype=torch.float)
    return x[:, 0:2], x[:, -1]

# Load the train/test splits as (features, labels) tensor pairs.
train_X, train_Y = get_data("train_data/train_simple.csv")
test_X, test_Y = get_data("test_data/test_simple.csv")

# NOTE: Tensor.reshape returns a new tensor rather than modifying in place.
# Neither result is assigned, so train_Y / test_Y stay 1-D; train() and
# test() add the column dimension themselves via unsqueeze(dim=1). The last
# expression only serves as this cell's displayed output.
train_Y.reshape(-1, 1)
test_Y.reshape(-1, 1)

tensor([[0.],
        [1.],
        [1.],
        ...,
        [1.],
        [1.],
        [0.]])

In [368]:
# --- Training hyperparameters -------------------------------------------
EPOCHS = 50       # full passes over the training data
BATCH_SIZE = 10   # rows per optimisation step
LR = 1e-3         # Adam learning rate

# Network, objective and optimiser shared by train() and test() below.
model = NN()
loss_fn = nn.BCELoss()
optimizer = optim.Adam(params=model.parameters(), lr=LR)


def train(x, y):
    """Fit the module-level ``model`` on ``(x, y)`` with mini-batch updates.

    Runs ``EPOCHS`` passes over the data in consecutive, unshuffled slices
    of ``BATCH_SIZE`` rows, using the module-level ``loss_fn`` and
    ``optimizer``. Every 10th epoch (starting with epoch 0) the sum of the
    per-batch losses for that epoch is printed.

    Args:
        x: Feature tensor whose first dimension indexes examples.
        y: 1-D label tensor aligned with ``x``; a trailing dimension is
            added per batch so it matches the model output.
    """
    for epoch in range(EPOCHS):
        total_loss = 0
        for start in range(0, len(x), BATCH_SIZE):
            end = start + BATCH_SIZE
            x_b = x[start:end]
            y_b = y[start:end]

            # Call the module itself rather than .forward() so that any
            # registered forward hooks are executed.
            pred = model(x_b)
            loss = loss_fn(pred, y_b.unsqueeze(dim=1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if epoch % 10 == 0:
            print("epoch", (epoch), "\ttotal loss:", total_loss)
        
def test(x, y):
    """Compute the classification accuracy of the module-level ``model``.

    Args:
        x: Feature tensor passed to ``model``.
        y: 1-D label tensor aligned with ``x``.

    Returns:
        A 0-dim tensor: the number of rows whose rounded model output
        equals the label, divided by ``len(y)``.
    """
    with torch.no_grad():
        predicted = model(x).round()
        targets = y.unsqueeze(dim=1)
        hits = torch.eq(predicted, targets)
        return torch.count_nonzero(hits) / len(y)

In [369]:
# Fit the model on the full training split; train() prints the summed batch
# loss every 10th epoch.
train(train_X, train_Y) 

epoch 0 	total loss: 77.95457186339316
epoch 10 	total loss: 154.42037779253238
epoch 20 	total loss: 8.535392225158965
epoch 30 	total loss: 14.125405677541115
epoch 40 	total loss: 85.91433879753838


In [371]:
# Report the fraction of test rows classified correctly.
# The outer no_grad() is redundant: test() already runs under torch.no_grad().
with torch.no_grad():
    print("Accuracy", test(test_X, test_Y).item())

Accuracy 0.9887999892234802


In [372]:
# Flip to True to persist the trained parameters (state_dict) to disk.
# NOTE(review): torch.save is given a path under trained_weights/; presumably
# that directory must already exist -- confirm before enabling.
SAVE_WEIGHTS = False
if SAVE_WEIGHTS:
    torch.save(model.state_dict(), "trained_weights/supervised_simple_weights.pth")