In [1]:
import import_ipynb
import chess
import math
import random
import torch
import torch.nn as nn
import torch.optim as optim
import dset
import net
import autoencoder
import bitboards

c_const = 0.5
samplingRate = 0.4
seed = random.randint(0, 100)
mse = nn.MSELoss()

def cross_entropy(y_hat, y):
    y1, y2 = y[0], y[1]
    y_hat1 = (torch.clamp(y_hat[0], 1e-9, 1 - 1e-9))
    y_hat2 = (torch.clamp(y_hat[1], 1e-9, 1 - 1e-9))
    
    return -1/2 * ((y1 * torch.log(y_hat1)).sum(dim=1).mean() + (y2 * torch.log(y_hat2)).sum(dim=1).mean())

def train_mcts(batch, dataset_size, encoder, nnet, optimizer, reinf, *args):
    dataset = dset.SearchDataset(dataset_size, dset.Encode(encoder), reinf, *args)
    pick = math.floor(samplingRate*len(dataset))
    subset = torch.utils.data.random_split(dataset, [pick, len(dataset) - pick], generator=torch.Generator().manual_seed(seed))
    
    DataLoader = torch.utils.data.DataLoader(subset[0], batch_size=batch, shuffle=True, drop_last=True)
    
    noBatch = 0
    for embedding, value, policy in DataLoader:
        optimizer.zero_grad()
        
        value_hat, policy_hat = nnet(embedding.view(embedding.shape[0],1, 256))
        mse_value = mse(value_hat, value)
        cross_entropy_value = cross_entropy(policy_hat, policy)
        loss = c_const * mse_value + (1 - c_const) * cross_entropy_value
        print(f"Loss ({noBatch}): \t", loss.item(), "\n\t\t Value loss: ", mse_value.item(), "\n\t\t Policy loss: ", cross_entropy_value.item(), end='\n\n')

        loss.backward()
        optimizer.step()
        noBatch += 1
        
    torch.save(nnet.state_dict(), "nnet_mcts.pt")
                
def train_alpha_beta(batch, dataset_size, encoder, nnet, optimizer, reinf, *args):
    dataset = dset.SearchDataset(dataset_size, dset.Encode(encoder), reinf, *args)
    pick = math.floor(samplingRate*len(dataset))
    subset = torch.utils.data.random_split(dataset, [pick, len(dataset) - pick], generator=torch.Generator().manual_seed(seed))
    
    DataLoader = torch.utils.data.DataLoader(subset[0], batch_size=batch, shuffle=True, drop_last=True)
    
    noBatch = 0
    for embedding, value in DataLoader:
        optimizer.zero_grad()
        value_hat = nnet(embedding.view(embedding.shape[0],1, 256))

        mse_value = mse(value_hat, value)
        print(f"Loss ({noBatch}): ", mse_value.item(), end='\n')

        mse_loss.backward()
        optimizer.step()
        noBatch += 1
        
    torch.save(nnet.state_dict(), "nnet_alpha_beta.pt")

importing Jupyter notebook from dset.ipynb


In [2]:
BATCH = 64
DATASET_SIZE = 2048

encoder = autoencoder.autoencoder().cuda()
encoder.load_state_dict(torch.load("autoencoderftest2.pt"))
nnet = net.Net().cuda()
optimizer = optim.Adam(nnet.parameters(), weight_decay=0.01)

for i in range(0, 10):
    ARGS = (chess.Board(), nnet, encoder, dset.SearchType.MCTS, 5)
    train_mcts(BATCH, DATASET_SIZE, encoder, nnet, optimizer, dset.ReinforcementType.MC, *ARGS)
    nnet.load_state_dict(torch.load("nnet_mcts.pt"))

  return F.mse_loss(input, target, reduction=self.reduction)


Loss (0): 	 0.12205387651920319 
		 Value loss:  0.24410775303840637 
		 Policy loss:  -0.0

Loss (1): 	 7.126970012905076e-05 
		 Value loss:  0.00014253940025810152 
		 Policy loss:  -0.0

Loss (2): 	 2.1082231072000468e-08 
		 Value loss:  4.2164462144000936e-08 
		 Policy loss:  -0.0

Loss (3): 	 0.02860681712627411 
		 Value loss:  1.0531103072919379e-10 
		 Policy loss:  0.05721363425254822

Loss (4): 	 6.642988025981622e-13 
		 Value loss:  1.3285976051963244e-12 
		 Policy loss:  -0.0

Loss (5): 	 9.553817426847383e-14 
		 Value loss:  1.9107634853694766e-13 
		 Policy loss:  -0.0

Loss (6): 	 8.088411125768691e-14 
		 Value loss:  1.6176822251537382e-13 
		 Policy loss:  -0.0

Loss (7): 	 3.365243739054946e-13 
		 Value loss:  6.730487478109892e-13 
		 Policy loss:  -0.0

Loss (8): 	 0.03359181806445122 
		 Value loss:  6.720845668189979e-12 
		 Policy loss:  0.06718363612890244

Loss (9): 	 7.154484643612236e-11 
		 Value loss:  1.430896928722447e-10 
		 Policy loss:  -0.0

L

KeyboardInterrupt: 