In [19]:
import os

import numpy as np
import torch
import torch.nn as nn
import tqdm

import value_net

# Training databases: one text file per game outcome, read line by line by
# sample_from_file and passed to value_net.Encoder.from_fen (presumably one FEN per line).
data_dir = f'{os.getcwd()}/data'
db_white_path = f'{data_dir}/train/white.txt'
db_black_path = f'{data_dir}/train/black.txt'
db_draw_path = f'{data_dir}/train/draw.txt'

# Seed the RNGs the notebook relies on (np.random for sampling positions, torch
# for anything torch-random such as — presumably — value_net.Net's weight init),
# so Restart & Run All reproduces the same run. This must run before the model cell.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [20]:
# Fresh network plus the per-step curves appended to by the training loop.
# NOTE: the optimizer cell binds to *this* model's parameters, so re-running this
# cell without re-running the optimizer cell would train a stale model.
model = value_net.Net()
loss_history, acc_history = [], []

In [21]:
epochs = 1000  # number of training steps; each `epoch()` call processes one sampled batch
# gen_uniform_sample draws batch_size // 3 positions per outcome, so keep this a
# multiple of 3 — any remainder would be silently dropped (257 trained on 255).
batch_size = 255
lr = 1e-4
betas = (0.9, 0.99)

# Targets are one-hot rows over (white, black, draw); CrossEntropyLoss expects raw logits.
loss_fn = nn.CrossEntropyLoss()
# opt captures the parameters of the current `model`; re-run this cell whenever
# the model cell is re-run.
opt = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

In [24]:
# Parsed database lines keyed by file path, so each file is read once per kernel
# session instead of on every call. Re-run this cell if the files on disk change.
_fen_cache = {}


def _load_fens(path):
    """Return the non-blank, whitespace-stripped lines of ``path`` as a numpy array (cached)."""
    if path not in _fen_cache:
        with open(path, 'r') as file:
            # strip() removes '\r' from CRLF files; blank lines (including the empty
            # string a trailing newline used to produce) are skipped so they can
            # never be sampled as a position.
            _fen_cache[path] = np.array([line.strip() for line in file if line.strip()])
    return _fen_cache[path]


def sample_from_file(size, path):
    """Draw ``size`` distinct lines from a position file and encode them.

    Parameters
    ----------
    size : int
        Number of lines to sample, without replacement.
    path : str
        Text file with one position string per line (blank lines are ignored).

    Returns
    -------
    torch.Tensor
        ``torch.stack`` of ``value_net.Encoder.from_fen(line)`` for each sampled
        line; the leading dimension is ``size``.

    Raises
    ------
    ValueError
        From ``np.random.choice`` if ``size`` exceeds the number of non-blank lines.
    """
    fens = _load_fens(path)
    fen_batch = np.random.choice(fens, size=size, replace=False)
    return torch.stack([value_net.Encoder.from_fen(fen) for fen in fen_batch])
    
def gen_label(one_hot, size):
    """Return a ``(size, 3)`` float tensor whose every row is one-hot at column ``one_hot``.

    Parameters
    ----------
    one_hot : int
        Column to set to 1 (0 = white, 1 = black, 2 = draw as used by gen_uniform_sample).
    size : int
        Number of identical rows.
    """
    labels = torch.zeros(size, 3)
    labels[:, one_hot] = 1.
    return labels

def gen_uniform_sample(batch_size):
    """Build a class-balanced batch of encoded positions and one-hot targets.

    Each outcome file contributes ``batch_size // 3`` positions, so the returned
    batch has ``3 * (batch_size // 3)`` rows. Rows are grouped by outcome in the
    order white (label 0), black (label 1), draw (label 2).

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        ``(x, labels)``: encoded positions concatenated along dim 0, and the
        matching ``(n, 3)`` one-hot label rows.
    """
    per_class = batch_size // 3
    sources = ((0, db_white_path), (1, db_black_path), (2, db_draw_path))

    inputs, targets = [], []
    for class_idx, path in sources:
        inputs.append(sample_from_file(per_class, path))
        targets.append(gen_label(class_idx, per_class))

    return torch.cat(inputs, 0), torch.cat(targets, 0)

In [25]:
def epoch(model, optimizer=None, criterion=None):
    """Run one optimisation step on a freshly sampled class-balanced batch.

    Despite the name, this is a single gradient step on one batch drawn by
    ``gen_uniform_sample(batch_size)``, not a full pass over the data.

    Parameters
    ----------
    model : torch.nn.Module
        Network producing one row of 3 logits (white / black / draw) per position.
    optimizer : torch.optim.Optimizer, optional
        Optimizer over ``model``'s parameters. Defaults to the global ``opt``.
        Pass one explicitly when training any model other than the global
        ``model``; otherwise the step would update the wrong parameters.
    criterion : callable, optional
        Loss taking ``(logits, one_hot_targets)``. Defaults to the global ``loss_fn``.

    Returns
    -------
    tuple[float, float]
        Batch loss and batch accuracy (fraction of argmax predictions that match
        the target class).
    """
    optimizer = opt if optimizer is None else optimizer
    criterion = loss_fn if criterion is None else criterion

    # Guard against the model having been left in eval mode by an earlier cell.
    model.train()
    optimizer.zero_grad()

    x, target = gen_uniform_sample(batch_size)
    y = model(x)
    loss = criterion(y, target)

    with torch.no_grad():
        correct = torch.argmax(y, -1) == torch.argmax(target, -1)
        accuracy = correct.float().mean()

    loss.backward()
    optimizer.step()

    return loss.item(), accuracy.item()


# NOTE(review): the saved output (loss ≈ 1.0986 = ln 3, acc ≈ 1/3) is chance level
# for 3 balanced classes. If that persists over many steps, check whether
# value_net.Net already applies softmax — CrossEntropyLoss expects raw logits.
progress = tqdm.trange(epochs)
for _ in progress:
    loss, acc = epoch(model)
    loss_history.append(loss)
    acc_history.append(acc)
    # set_description refreshes the bar by default; no explicit refresh() needed.
    progress.set_description(f'Loss: {loss:.4f}, Acc: {acc:.4f}')

Loss: 1.0986889600753784, Acc: 0.3333333432674408: 100%|██████████| 2/2 [00:17<00:00,  8.70s/it]
