In [1]:
import import_ipynb
import chess
import math
import random
import torch
import torch.nn as nn
import torch.optim as optim
import dset
import net
import autoencoder
import bitboards

c_const = 0.5
samplingRate = 0.4
seed = random.randint(0, 100)
mse = nn.MSELoss()

def cross_entropy(y_hat, y):
    loss = nn.BCELoss()    
    y_hat_concat = torch.cat((y_hat[0], y_hat[1]), 1)
    
    return loss(y_hat_concat, y)

def train_mcts(batch, dataset_size, encoder, nnet, optimizer, reinf, game_generator, *args):
    dataset = dset.SearchDataset(dataset_size, dset.Encode(encoder), reinf, game_generator, *args)
    pick = math.floor(samplingRate*len(dataset))
    subset = torch.utils.data.random_split(dataset, [pick, len(dataset) - pick], generator=torch.Generator().manual_seed(seed))
    
    DataLoader = torch.utils.data.DataLoader(subset[0], batch_size=batch, shuffle=True, drop_last=True)
    
    noBatch = 0
    running_loss, running_mse, running_cross_entropy = 0, 0, 0
    for embedding, value, policy in DataLoader:
        optimizer.zero_grad()
        
        value_hat, policy_hat = nnet(embedding.view(embedding.shape[0],1, 256))
        mse_value = mse(value_hat, value)
        cross_entropy_value = cross_entropy(policy_hat, policy)
        loss = c_const * mse_value + (1 - c_const) * cross_entropy_value
        
        running_loss += loss.item()
        running_mse += mse_value.item()
        running_cross_entropy += cross_entropy_value.item()
        
        loss.backward()
        optimizer.step()
        noBatch += 1
    
    print(f"Loss: \t", running_loss/noBatch, "\n\t\t Value loss: ", running_mse/noBatch, "\n\t\t Policy loss: ", running_cross_entropy/noBatch, end='\n\n')

        
    torch.save(nnet.state_dict(), "nnet_mcts.pt")
                
def train_alpha_beta(batch, dataset_size, encoder, nnet, optimizer, reinf, game_generator, *args):
    dataset = dset.SearchDataset(dataset_size, dset.Encode(encoder), reinf, game_generator, *args)
    pick = math.floor(samplingRate*len(dataset))
    subset = torch.utils.data.random_split(dataset, [pick, len(dataset) - pick], generator=torch.Generator().manual_seed(seed))
    
    DataLoader = torch.utils.data.DataLoader(subset[0], batch_size=batch, shuffle=True, drop_last=True)
    
    noBatch = 0
    for embedding, value in DataLoader:
        optimizer.zero_grad()
        value_hat = nnet(embedding.view(embedding.shape[0],1, 256))

        mse_value = mse(value_hat, value)
        print(f"Loss ({noBatch}): ", mse_value.item(), end='\n')

        mse_loss.backward()
        optimizer.step()
        noBatch += 1
        
    torch.save(nnet.state_dict(), "nnet_alpha_beta.pt")

importing Jupyter notebook from dset.ipynb


In [2]:
BATCH = 64
DATASET_SIZE = 2048

encoder = autoencoder.autoencoder().cuda()
encoder.load_state_dict(torch.load("autoencoderftest2.pt"))
nnet = net.Net().cuda()
optimizer = optim.Adam(nnet.parameters(), weight_decay=0.01)
    
for i in range(0, 100):
    ARGS = (chess.Board(), nnet, encoder, dset.SearchType.CUSTOM, 5)
    GameGenerator = dset.GameGenerator(128, 0, 0)
    train_mcts(BATCH, DATASET_SIZE, encoder, nnet, optimizer, dset.ReinforcementType.MC, GameGenerator, *ARGS)
    nnet.load_state_dict(torch.load("nnet_mcts.pt"))

  return F.mse_loss(input, target, reduction=self.reduction)


Loss: 	 0.16193206837544075 
		 Value loss:  0.2434749832520118 
		 Policy loss:  0.08038915579135601

Loss: 	 0.16637234638134638 
		 Value loss:  0.25227900967001915 
		 Policy loss:  0.08046568309267361

Loss: 	 0.166471799214681 
		 Value loss:  0.2524803876876831 
		 Policy loss:  0.08046321012079716

Loss: 	 0.1652719428141912 
		 Value loss:  0.25008121877908707 
		 Policy loss:  0.08046266809105873

Loss: 	 0.1575735198954741 
		 Value loss:  0.23468338822325072 
		 Policy loss:  0.08046364970505238

Loss: 	 0.15115154658754668 
		 Value loss:  0.22192280739545822 
		 Policy loss:  0.08038028702139854

Loss: 	 0.16519233584403992 
		 Value loss:  0.2499204513927301 
		 Policy loss:  0.08046421781182289

Loss: 	 0.16437221834292778 
		 Value loss:  0.24835808002031767 
		 Policy loss:  0.08038635208056523

Loss: 	 0.1652281110485395 
		 Value loss:  0.24999184533953667 
		 Policy loss:  0.08046437799930573

Loss: 	 0.16750573987762132 
		 Value loss:  0.2545459493994713 
		 Poli

Loss: 	 0.163885818173488 
		 Value loss:  0.24731109042962393 
		 Policy loss:  0.08046054343382518

Loss: 	 0.16318000165315774 
		 Value loss:  0.24605478117099175 
		 Policy loss:  0.0803052238546885

Loss: 	 0.16834144982007834 
		 Value loss:  0.2562218721096332 
		 Policy loss:  0.08046102638428028

Loss: 	 0.16474351860009706 
		 Value loss:  0.24910198954435495 
		 Policy loss:  0.08038505052144711

Loss: 	 0.16530372574925423 
		 Value loss:  0.2501450392107169 
		 Policy loss:  0.08046241228779157

Loss: 	 0.16500265896320343 
		 Value loss:  0.24962020149597755 
		 Policy loss:  0.08038511528418614

Loss: 	 0.1647491569702442 
		 Value loss:  0.24911535244721633 
		 Policy loss:  0.08038296149327205

Loss: 	 0.16266390222769517 
		 Value loss:  0.24494393055255598 
		 Policy loss:  0.0803838733297128

Loss: 	 0.16090421607861152 
		 Value loss:  0.24134745047642633 
		 Policy loss:  0.08046097824206719

Loss: 	 0.17278762658437094 
		 Value loss:  0.2651154225071271 
		 Pol