In [36]:
import import_ipynb
import chess
import math
import torch
import torch.nn as nn
import torch.optim as optim
import dset
import net
import autoencoder
import bitboards

c_const = 2
samplingRate = 0.2
seed = 42
mse = nn.MSELoss()

def cross_entropy(y_hat, y):
    y1, y2 = y[0], y[1]
    y_hat1 = (torch.clamp(y_hat[0], 1e-9, 1 - 1e-9))
    y_hat2 = (torch.clamp(y_hat[1], 1e-9, 1 - 1e-9))
    
    return -1/2 * ((y1 * torch.log(y_hat1)).sum(dim=1).mean() + (y2 * torch.log(y_hat2)).sum(dim=1).mean())

def train_mcts(batch, dataset_size, encoder, nnet, optimizer, *args):
    dataset = dset.SearchDataset(dataset_size, dset.Encode(encoder), *args)
    pick = math.floor(samplingRate*len(dataset))
    subset = torch.utils.data.random_split(dataset, [pick, len(dataset) - pick], generator=torch.Generator().manual_seed(seed))
    
    DataLoader = torch.utils.data.DataLoader(subset[0], batch_size=batch, shuffle=True, drop_last=True)
    
    noBatch = 0
    for embedding, value, policy in DataLoader:
        value_hat, policy_hat = nnet(embedding.view(embedding.shape[0],1, 256))

        mse_value = mse(value_hat, value)
        cross_entropy_value = cross_entropy(policy_hat, policy)
        loss = c_const * mse_value + cross_entropy_value
        print(f"Loss ({noBatch}): ", loss, mse_value, cross_entropy_value, end='\n')

        loss.backward()
        optimizer.step()
        noBatch += 1
        
    torch.save(nnet.state_dict(), "nnet_mcts.pt")
                
def train_alpha_beta(batch, dataset_size, encoder, nnet, optimizer, *args):
    dataset = dset.SearchDataset(dataset_size, dset.Encode(encoder), *args)
    pick = math.floor(samplingRate*len(dataset))
    subset = torch.utils.data.random_split(dataset, [pick, len(dataset) - pick], generator=torch.Generator().manual_seed(seed))
    
    DataLoader = torch.utils.data.DataLoader(subset[0], batch_size=batch, shuffle=True, drop_last=True)
    
    noBatch = 0
    for embedding, value in DataLoader:
        value_hat = nnet(embedding.view(embedding.shape[0],1, 256))

        mse_value = mse(value_hat, value)
        print(f"Loss ({noBatch}): ", mse_value, end='\n')

        mse_loss.backward()
        optimizer.step()
        noBatch += 1
        
    torch.save(nnet.state_dict(), "nnet_alpha_beta.pt")

In [37]:
BATCH = 2
DATASET_SIZE = 20
ARGS = (chess.Board(), net.Net().cuda(), autoencoder.autoencoder().cuda(), dset.SearchType.MCTS, 50)
encoder = autoencoder.autoencoder().cuda()
nnet = net.Net().cuda()
optimizer = optim.Adam(nnet.parameters(), weight_decay=0.01)

train_mcts(BATCH, DATASET_SIZE, encoder, nnet, optimizer, *ARGS)

Loss (0):  tensor(0.1299, device='cuda:0', grad_fn=<AddBackward0>) tensor(1.0529e-06, device='cuda:0', grad_fn=<MseLossBackward>) tensor(0.1299, device='cuda:0', grad_fn=<MulBackward0>)
Loss (1):  tensor(0.1302, device='cuda:0', grad_fn=<AddBackward0>) tensor(1.1383e-07, device='cuda:0', grad_fn=<MseLossBackward>) tensor(0.1302, device='cuda:0', grad_fn=<MulBackward0>)
