In [1]:
import import_ipynb
import chess
import math
import random
import torch
import torch.nn as nn
import torch.optim as optim
import dset
import net
import autoencoder
import bitboards

c_const = 0.5
samplingRate = 0.4
seed = random.randint(0, 100)
mse = nn.MSELoss()

def cross_entropy(y_hat, y):
    y1, y2 = y[0], y[1]
    y_hat1 = (torch.clamp(y_hat[0], 1e-9, 1 - 1e-9))
    y_hat2 = (torch.clamp(y_hat[1], 1e-9, 1 - 1e-9))
    
    return -1/2 * ((y1 * torch.log(y_hat1)).sum(dim=1).mean() + (y2 * torch.log(y_hat2)).sum(dim=1).mean())

def train_mcts(batch, dataset_size, encoder, nnet, optimizer, reinf, *args):
    dataset = dset.SearchDataset(dataset_size, dset.Encode(encoder), reinf, *args)
    pick = math.floor(samplingRate*len(dataset))
    subset = torch.utils.data.random_split(dataset, [pick, len(dataset) - pick], generator=torch.Generator().manual_seed(seed))
    
    DataLoader = torch.utils.data.DataLoader(subset[0], batch_size=batch, shuffle=True, drop_last=True)
    
    noBatch = 0
    for embedding, value, policy in DataLoader:
        optimizer.zero_grad()
        
        value_hat, policy_hat = nnet(embedding.view(embedding.shape[0],1, 256))
        mse_value = mse(value_hat, value)
        cross_entropy_value = cross_entropy(policy_hat, policy)
        loss = c_const * mse_value + (1 - c_const) * cross_entropy_value
        print(f"Loss ({noBatch}): \t", loss.item(), "\n\t\t Value loss: ", mse_value.item(), "\n\t\t Policy loss: ", cross_entropy_value.item(), end='\n\n')

        loss.backward()
        optimizer.step()
        noBatch += 1
        
    torch.save(nnet.state_dict(), "nnet_mcts.pt")
                
def train_alpha_beta(batch, dataset_size, encoder, nnet, optimizer, reinf, *args):
    dataset = dset.SearchDataset(dataset_size, dset.Encode(encoder), reinf, *args)
    pick = math.floor(samplingRate*len(dataset))
    subset = torch.utils.data.random_split(dataset, [pick, len(dataset) - pick], generator=torch.Generator().manual_seed(seed))
    
    DataLoader = torch.utils.data.DataLoader(subset[0], batch_size=batch, shuffle=True, drop_last=True)
    
    noBatch = 0
    for embedding, value in DataLoader:
        optimizer.zero_grad()
        value_hat = nnet(embedding.view(embedding.shape[0],1, 256))

        mse_value = mse(value_hat, value)
        print(f"Loss ({noBatch}): ", mse_value.item(), end='\n')

        mse_loss.backward()
        optimizer.step()
        noBatch += 1
        
    torch.save(nnet.state_dict(), "nnet_alpha_beta.pt")

importing Jupyter notebook from dset.ipynb
Game result mean:  0.5  Standard deviation:  0.5009643210501632
Example policy:  [0.         0.06289169 0.06325486 0.         0.1257084  0.
 0.         0.         0.         0.         0.06304029 0.03160544
 0.0937361  0.17957807 0.         0.1262531  0.         0.15884805
 0.         0.         0.03149363 0.03167457 0.         0.
 0.03173487 0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.        ]


In [2]:
BATCH = 64
DATASET_SIZE = 2048

encoder = autoencoder.autoencoder().cuda()
encoder.load_state_dict(torch.load("autoencoderftest2.pt"))
nnet = net.Net().cuda()
optimizer = optim.Adam(nnet.parameters(), weight_decay=0.01)

for i in range(0, 10):
    ARGS = (chess.Board(), nnet, encoder, dset.SearchType.CUSTOM, 5)
    train_mcts(BATCH, DATASET_SIZE, encoder, nnet, optimizer, dset.ReinforcementType.MC, *ARGS)
    nnet.load_state_dict(torch.load("nnet_mcts.pt"))

  return F.mse_loss(input, target, reduction=self.reduction)


Loss (0): 	 0.19073253870010376 
		 Value loss:  0.2513963580131531 
		 Policy loss:  0.13006871938705444

Loss (1): 	 0.17051826417446136 
		 Value loss:  0.21111547946929932 
		 Policy loss:  0.1299210488796234

Loss (2): 	 0.18035593628883362 
		 Value loss:  0.23052576184272766 
		 Policy loss:  0.13018612563610077

Loss (3): 	 0.17262554168701172 
		 Value loss:  0.21529752016067505 
		 Policy loss:  0.1299535483121872

Loss (4): 	 0.1839885115623474 
		 Value loss:  0.23812207579612732 
		 Policy loss:  0.1298549324274063

Loss (5): 	 0.18532413244247437 
		 Value loss:  0.2408774197101593 
		 Policy loss:  0.12977084517478943

Loss (6): 	 0.1881498098373413 
		 Value loss:  0.24665874242782593 
		 Policy loss:  0.12964089214801788

Loss (7): 	 0.1827000230550766 
		 Value loss:  0.23548530042171478 
		 Policy loss:  0.12991474568843842

Loss (8): 	 0.18688678741455078 
		 Value loss:  0.2437898963689804 
		 Policy loss:  0.12998366355895996

Loss (9): 	 0.17627492547035217 
		 V

KeyboardInterrupt: 