In [35]:
import sqlite3
import pandas as pd

# Upper bound on how many rows are pulled from the `evaluations` table.
ROWS_TO_READ = 850000

# Open the local SQLite file and read the first ROWS_TO_READ rows into a DataFrame.
sql = sqlite3.connect('test.db')
query = f'SELECT * FROM evaluations LIMIT {ROWS_TO_READ}'
df = pd.read_sql_query(sql=query, con=sql)

# Cell output: preview of the first ten rows.
df.head(10)

Unnamed: 0,id,fen,binary,eval
0,1,rnbqkbnr/pppppppp/8/8/3P4/8/PPP1PPPP/RNBQKBNR ...,b'\x08\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00...,0.0
1,2,rnbqkbnr/ppp1pppp/8/3p4/3P4/8/PPP1PPPP/RNBQKBN...,b'\x08\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00...,0.27
2,3,rnbqkbnr/ppp1pppp/8/3p4/2PP4/8/PP2PPPP/RNBQKBN...,b'\x08\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00...,0.0
3,4,rnbqkbnr/ppp2ppp/4p3/3p4/2PP4/8/PP2PPPP/RNBQKB...,b'\x08\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00...,0.09
4,5,rnbqkbnr/ppp2ppp/4p3/3P4/3P4/8/PP2PPPP/RNBQKBN...,b'\x08\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00...,0.1
5,6,rnbqkbnr/ppp2ppp/8/3p4/3P4/8/PP2PPPP/RNBQKBNR ...,b'\x08\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00...,0.0
6,7,rnbqkbnr/ppp2ppp/8/3p4/3P4/2N5/PP2PPPP/R1BQKBN...,b'\x08\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00...,0.09
7,8,rnbqkb1r/ppp2ppp/5n2/3p4/3P4/2N5/PP2PPPP/R1BQK...,b'\x08\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00...,0.13
8,9,rnbqkb1r/ppp2ppp/5n2/3p4/3P4/2N1P3/PP3PPP/R1BQ...,b'\x08\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00...,-0.22
9,10,rnbqk2r/ppp2ppp/3b1n2/3p4/3P4/2N1P3/PP3PPP/R1B...,b'\x08\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00...,-0.26


In [36]:
import numpy as np

def make_matrix(fen):
    """Expand the piece-placement part of a FEN string into a grid of symbols.

    The string is split on '/'; within each segment everything from the first
    space onwards is discarded (this strips the trailing FEN fields from the
    last segment). Digits are expanded into that many '.' placeholders and
    every other character is kept as-is.

    Returns a list of rows, each row being a list of one-character strings.
    """
    board = []
    for rank in fen.split('/'):
        placement = rank.split(" ", 1)[0]
        squares = []
        for symbol in placement:
            # A digit N stands for N consecutive empty squares.
            squares.extend('.' * int(symbol) if symbol.isdigit() else symbol)
        board.append(squares)
    return board

def extract_metadata(fen):
    """Turn the side-to-move and castling fields of a FEN string into 0/1 flags.

    The string is split on single spaces; field 1 is the side to move and
    field 2 the castling availability.

    Returns a list of five ints:
        [white_to_move, 'K' present, 'Q' present, 'k' present, 'q' present]
    """
    fields = fen.split(' ')
    side_to_move, castling = fields[1][0], fields[2]

    flags = [int(side_to_move == 'w')]
    # One flag per castling right, in the fixed order K, Q, k, q.
    flags.extend(int(right in castling) for right in "KQkq")
    return flags

# One-hot lookup for board symbols. Channel order is P, B, N, R, Q, K for the
# upper-case pieces followed by p, b, n, r, q, k for the lower-case ones;
# the empty-square marker '.' maps to an all-zero vector of the same length.
table = {
    symbol: [int(symbol == piece) for piece in "PBNRQKpbnrqk"]
    for symbol in ".PBNRQKpbnrqk"
}

def vectorize(fen, table):
    """Encode a FEN string as a numpy array using a per-symbol lookup table.

    Each symbol of the grid produced by make_matrix(fen) is replaced by
    table.get(symbol) (so a symbol missing from `table` becomes None), and
    the nested result is converted with np.array.
    """
    encoded = [
        [table.get(symbol) for symbol in rank]
        for rank in make_matrix(fen)
    ]
    return np.array(encoded)

# Sanity check: encode the first loaded position and print the array shape.
print(vectorize(df['fen'][0], table).shape) # should be 8 by 8 by 12

(8, 8, 12)


In [37]:
from tqdm.notebook import tqdm

# Per-position features and targets, accumulated as plain Python lists.
vecs = []   # encoded boards from vectorize()
meta = []   # side-to-move / castling flags from extract_metadata()
eval = []   # NOTE: shadows the builtin `eval`; name kept because the tensor-building cell reads it

# Iterate over the two needed columns directly. DataFrame.iterrows() builds a
# full Series object for every row, which is needlessly slow at this row count
# and yields exactly the same 'fen' / 'eval' values.
for fen, score in tqdm(zip(df['fen'], df['eval']), total=df.shape[0]):
    vecs.append(vectorize(fen, table))
    meta.append(extract_metadata(fen))
    eval.append(score)

  0%|          | 0/850000 [00:00<?, ?it/s]

In [38]:
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split
import torch

# Board planes: stack the per-position arrays and move the channel axis
# (last) to position 1, the layout nn.Conv2d expects.
X = torch.tensor(np.array(vecs), dtype=torch.float).permute(0, 3, 1, 2)
# Side-to-move / castling flags, one row per position.
m = torch.tensor(np.array(meta), dtype=torch.float)
# Regression targets, one value per position.
y = torch.tensor(np.array(eval), dtype=torch.float)

processed_dataset = TensorDataset(X, m, y)

# set proportion and split dataset into train and validation parts
proportion = 0.2
# Seed the split explicitly: without a generator, random_split draws from the
# global RNG and the train/validation partition changes on every kernel
# restart, so validation losses from different runs are not comparable.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    processed_dataset, [1-proportion, proportion], generator=split_generator
)

In [39]:
# Mini-batch size used by both loaders.
batch_size = 512

# Only the training loader reshuffles; validation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size)

In [40]:
def train(
    model,
    optimizer,
    scheduler,
    loss_fn,
    train_loader,
    val_loader,
    epochs=1,
    device="cpu",
    ckpt_path="best.pt",
):
    """Train `model`, validate after every epoch and checkpoint the best epoch.

    Args:
        model: module called as ``model(inputs, meta)``.
        optimizer: optimizer over the model's parameters.
        scheduler: object whose ``step(metric)`` is called once per epoch with
            the mean validation loss (e.g. ReduceLROnPlateau).
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        train_loader: iterable of ``(inputs, meta, labels)`` batches; must
            support ``len()``.
        val_loader: same batch format, used for validation.
        epochs: number of passes over ``train_loader``.
        device: device the model and every batch are moved to.
        ckpt_path: where the ``state_dict`` of the best epoch is written.

    Returns:
        None. Side effects: updates the model in place, prints the mean
        validation loss per epoch, and writes ``ckpt_path`` whenever the mean
        validation loss improves on the best value seen so far.
    """
    # Track the best (lowest) mean validation loss directly. The previous
    # "score = batches / summed_loss" formulation divided by zero when the
    # validation loss was exactly 0.
    best_val_loss = float("inf")

    # The target device never changes between epochs, so move the model once.
    model.to(device)

    # iterating over epochs
    for epoch in range(epochs):
        # training loop description
        train_loop = tqdm(
            enumerate(train_loader, 0), total=len(train_loader), desc=f"Epoch {epoch}"
        )
        model.train()
        train_loss = 0.0
        # iterate over dataset
        for i, data in train_loop:
            inputs, meta, labels = data
            inputs, meta, labels = inputs.to(device), meta.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward pass and loss calculation
            outputs = model(inputs, meta)

            # drop singleton dims so outputs and labels have matching shapes
            labels = torch.squeeze(labels)
            outputs = torch.squeeze(outputs)

            loss = loss_fn(outputs, labels)

            # backward pass
            loss.backward()

            # optimizer run
            optimizer.step()

            train_loss += loss.item()
            # running mean of the per-batch training loss
            train_loop.set_postfix({"loss": train_loss/(i+1)})

        # validation
        eval_loss = 0.0
        # Count batches explicitly: relying on the loop index would silently
        # reuse the training loop's `i` if the validation loader were empty.
        val_batches = 0
        model.eval()  # evaluation mode
        with torch.no_grad():
            val_loop = tqdm(enumerate(val_loader, 0), total=len(val_loader), desc="Val")
            for _, data in val_loop:
                inputs, meta, labels = data
                inputs, meta, labels = inputs.to(device), meta.to(device), labels.to(device)

                outputs = model(inputs, meta)
                labels = torch.squeeze(labels)
                outputs = torch.squeeze(outputs)

                loss = loss_fn(outputs, labels)

                eval_loss += loss.item()
                val_batches += 1

        if val_batches == 0:
            # Nothing was validated: there is no metric to checkpoint on or to
            # hand to the scheduler for this epoch.
            continue

        mean_val_loss = eval_loss / val_batches
        print(f'eval_loss: {mean_val_loss}')

        if mean_val_loss < best_val_loss:
            torch.save(model.state_dict(), ckpt_path)
            best_val_loss = mean_val_loss

        scheduler.step(mean_val_loss)


In [41]:
from torch import nn

class EvalConvMetaModel(nn.Module):
    """Convolutional regressor over board planes plus a small metadata vector.

    `conv` holds three Conv2d -> BatchNorm2d -> ReLU stages (12 -> 128 -> 256
    -> 512 channels) followed by Flatten. `linear` is an MLP that takes the
    flattened features concatenated with 5 metadata values (2048 + 5 inputs)
    and narrows 512 -> 256 -> 128 -> 64 -> 1.

    Layers are added in a fixed order so the `state_dict` keys
    (`conv.0.*`, ..., `linear.12.*`) stay stable for saved checkpoints.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, kernel_size, padding) for each conv stage.
        conv_specs = (
            (12, 128, 5, 1),
            (128, 256, 3, 0),
            (256, 512, 3, 0),
        )
        conv_layers = []
        for in_ch, out_ch, kernel, pad in conv_specs:
            conv_layers += [
                nn.Conv2d(in_ch, out_ch, kernel, padding=pad),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]
        conv_layers.append(nn.Flatten())
        self.conv = nn.Sequential(*conv_layers)

        # Widths of the hidden stack; the first entry is the flattened conv
        # output (2048) plus the 5 metadata features.
        widths = (2048+5, 512, 256, 128, 64)
        dense_layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            dense_layers += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
            ]
        dense_layers.append(nn.Linear(64, 1))
        self.linear = nn.Sequential(*dense_layers)

    def forward(self, x, meta):
        """Return one prediction per sample, shape (batch, 1).

        `x` goes through the conv stack, is concatenated with `meta` along
        dim 1, and the result is fed to the MLP.
        """
        features = self.conv(x)
        combined = torch.cat([features, meta], dim=1)
        return self.linear(combined)


In [42]:
import torch.optim as optim

# Fresh, untrained network.
model = EvalConvMetaModel()
# Mean absolute error between predicted and stored evaluations.
loss_fn = nn.L1Loss()
# torch.cuda.is_available must be *called*: the bare function object is always
# truthy, which selected 'cuda' even on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [69]:
# Adam optimizer; the plateau scheduler multiplies the learning rate by 0.25
# once the monitored metric has failed to improve for more than 3 epochs.
optimizer = optim.Adam(params=model.parameters(), lr=0.05)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=3)

# 40 epochs on the selected device. ckpt_path is left at train()'s default.
train(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=40,
    device=device,
)

Epoch 0:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.39594300158389


Epoch 1:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.4439395179261676


Epoch 2:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.424370135630931


Epoch 3:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.40392016755926


Epoch 4:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.371052711933583


Epoch 5:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.3685893221302434


Epoch 6:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.4749368762110806


Epoch 7:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.405768390890356


Epoch 8:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.456753427559907


Epoch 9:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.4257902454685523


Epoch 10:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2490791812673345


Epoch 11:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.262035276438739


Epoch 12:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.24818504107249


Epoch 13:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2501185087112336


Epoch 14:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.241701264997144


Epoch 15:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.249470864688312


Epoch 16:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.254563343775523


Epoch 17:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2465683382910653


Epoch 18:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2391732766463592


Epoch 19:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.249673304257092


Epoch 20:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2355154963942976


Epoch 21:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2777376500693887


Epoch 22:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2498905257777766


Epoch 23:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2496414750187963


Epoch 24:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.243514164432033


Epoch 25:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2300917096324153


Epoch 26:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.239573993840375


Epoch 27:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2430261280443577


Epoch 28:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2330133156017498


Epoch 29:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.228627746885603


Epoch 30:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.239970188599091


Epoch 31:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2303813638629855


Epoch 32:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2332353770911872


Epoch 33:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2351373083240635


Epoch 34:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2336442398953364


Epoch 35:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.22857277242987


Epoch 36:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2395195560054377


Epoch 37:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.237451232231415


Epoch 38:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.2297165784033925


Epoch 39:   0%|          | 0/1329 [00:00<?, ?it/s]

Val:   0%|          | 0/333 [00:00<?, ?it/s]

eval_loss: 2.236039873000022


In [70]:
# Rebuild the architecture on the CPU and restore saved weights into it.
model = EvalConvMetaModel()
# map_location="cpu": without it, a checkpoint whose tensors were saved from a
# CUDA device cannot be deserialized on a machine with no GPU. The freshly
# constructed model lives on the CPU anyway.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
ckpt = torch.load("Conv5-Final.pt", map_location="cpu")
model.load_state_dict(ckpt)

<All keys matched successfully>

In [71]:
import sqlite3
import pandas as pd

# Skip the first OFFSET rows of the table and read the next ROWS_TO_READ.
# NOTE: this rebinds ROWS_TO_READ, sql and df from the data-loading cell.
OFFSET = 890000
ROWS_TO_READ = 20

sql = sqlite3.connect('test.db')
query = f'SELECT * FROM evaluations LIMIT {OFFSET}, {ROWS_TO_READ}'
df = pd.read_sql_query(sql=query, con=sql)

# Cell output: the last (up to) 20 rows of the sample.
df.tail(20)

Unnamed: 0,id,fen,binary,eval
0,890001,2q4r/4n1pp/pQn1N1k1/3pPpN1/P1pP1P2/2P5/6PP/1R4...,b'\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00...,6.58
1,890002,2q4r/4n1pp/pQn1N1k1/3pPpN1/P1pP1P1P/2P5/6P1/1R...,b'\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00...,3.67
2,890003,2q4r/4n1p1/pQn1N1kp/3pPpN1/P1pP1P1P/2P5/6P1/1R...,b'\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00...,3.87
3,890004,2q4r/4n1p1/pQn1N1kp/3pPpNP/P1pP1P2/2P5/6P1/1R4...,b'\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00...,4.22
4,890005,2q4r/4n1p1/pQn1N2p/3pPpNk/P1pP1P2/2P5/6P1/1R4K...,b'\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00...,3.78
5,890006,2q4r/4nNp1/pQn1N2p/3pPp1k/P1pP1P2/2P5/6P1/1R4K...,b'\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00...,3.11
6,890007,2q3r1/4nNp1/pQn1N2p/3pPp1k/P1pP1P2/2P5/6P1/1R4...,b'\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00...,4.52
7,890008,2q3r1/4nNp1/pQn1N2p/3pPp1k/P1pP1P2/2P5/5KP1/1R...,b'\x00\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00...,-5.56
8,890009,6r1/4nNp1/pQn1q2p/3pPp1k/P1pP1P2/2P5/5KP1/1R6 ...,b'\x00\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00...,-5.4
9,890010,6r1/4nNp1/pQn1q2p/3pPp1k/P1pP1P2/2P5/5KP1/7R b...,b'\x00\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00...,-7.46


In [78]:
import numpy as np

# Position to score: the row labelled 15 in the sample frame.
fen = df['fen'][15]

# Board planes: move the last axis of the encoded board to the front, then
# add a leading batch axis of size 1.
data = torch.tensor(vectorize(fen, table), dtype=torch.float).permute(2, 0, 1).cpu()
data = data.unsqueeze(0)

# Metadata flags, likewise given a leading batch axis of size 1.
meta = torch.tensor(extract_metadata(fen), dtype=torch.float).cpu()
meta = meta.unsqueeze(0)

# Inference on the CPU in evaluation mode, without gradient tracking.
model = model.cpu()
model.eval()

with torch.no_grad():
    res = model(data, meta).item()
print(res)

-0.40396976470947266
