In [1]:
import import_ipynb
import chess
import math
import random
import torch
import torch.nn as nn
import torch.optim as optim
import dset
import net
import autoencoder
import bitboards

##Hyperparameter
c_const = 0.5
samplingRate = 0.4
seed = random.randint(0, 100)


mse = nn.MSELoss()

def cross_entropy(y_hat, y):
    loss = nn.BCELoss()    
    y_hat_concat = torch.cat((y_hat[0], y_hat[1]), 1)
    
    return loss(y_hat_concat, y)

def generate_dataset(dataset_size, encoder, reinf_type, game_generator, *args):
    return dset.SearchDataset(dataset_size, dset.Encode(encoder), reinf, game_generator, *args)

def subset_of_dataset(dataset, sampling_rate):
    pick = math.floor(samplingRate*len(dataset))
    subset = torch.utils.data.random_split(dataset, [pick, len(dataset) - pick], generator=torch.Generator().manual_seed(seed))
    
    return subset

def get_dataloader(subset, batch):
    return torch.utils.data.DataLoader(subset[0], batch_size=batch, shuffle=True, drop_last=True)

def generate_trainable_data(batch, dataset_size, encoder, nnet, optimizer, reinf, game_generator, *args):
    dataset = generate_dataset(dataset_size, encoder, reinf, game_generator, *args)
    subset = subset_of_dataset(dataset, samplingRate)    
    DataLoader = get_dataloader(subset, batch)
    
    return DataLoader
    

def train_mcts(nnet, optimizer, data_loader, file="nnet_mcts.pt"):    
    noBatch = 0
    running_loss, running_mse, running_cross_entropy = 0, 0, 0
    
    for embedding, value, policy in data_loader:
        optimizer.zero_grad()
        
        value_hat, policy_hat = nnet(embedding.view(embedding.shape[0],1, 256))
        mse_value = mse(value_hat, value)
        cross_entropy_value = cross_entropy(policy_hat, policy)
        loss = c_const * mse_value + (1 - c_const) * cross_entropy_value
        
        running_loss += loss.item()
        running_mse += mse_value.item()
        running_cross_entropy += cross_entropy_value.item()
        
        loss.backward()
        optimizer.step()
        noBatch += 1
    
    print(f"Loss: \t", running_loss/noBatch, "\n\t\t Value loss: ", running_mse/noBatch, "\n\t\t Policy loss: ", running_cross_entropy/noBatch, end='\n\n')

    torch.save(nnet.state_dict(), file)
                
def train_alpha_beta(batch, dataset_size, encoder, nnet, optimizer, reinf, game_generator, *args):
    dataset = dset.SearchDataset(dataset_size, dset.Encode(encoder), reinf, game_generator, *args)
    pick = math.floor(samplingRate*len(dataset))
    subset = torch.utils.data.random_split(dataset, [pick, len(dataset) - pick], generator=torch.Generator().manual_seed(seed))
    
    DataLoader = torch.utils.data.DataLoader(subset[0], batch_size=batch, shuffle=True, drop_last=True)
    
    noBatch = 0
    for embedding, value in DataLoader:
        optimizer.zero_grad()
        value_hat = nnet(embedding.view(embedding.shape[0],1, 256))

        mse_value = mse(value_hat, value)
        print(f"Loss ({noBatch}): ", mse_value.item(), end='\n')

        mse_loss.backward() 
        optimizer.step()
        noBatch += 1
        
    torch.save(nnet.state_dict(), "nnet_alpha_beta.pt")

importing Jupyter notebook from dset.ipynb


In [2]:
encoder = autoencoder.autoencoder().cuda()
encoder.load_state_dict(torch.load("autoencoderftest2.pt"))
nnet = net.Net().cuda()
optimizer = optim.Adam(nnet.parameters(), weight_decay=0.01)

##Hyperparameters 
BATCH = 2048
DATASET_SIZE = 5120
reinf = dset.ReinforcementType.PARAM
ARGS = (chess.Board(), nnet, encoder, dset.SearchType.CUSTOM, 5)
GameGenerator = dset.GameGenerator(64, 0, 0.2)

for j in range(0, 10000):
    dataset = generate_dataset(DATASET_SIZE, encoder, reinf, GameGenerator, *ARGS)

    subset = subset_of_dataset(dataset, samplingRate)
    DataLoader = get_dataloader(subset, BATCH)

    train_mcts(nnet, optimizer, DataLoader)
    nnet.load_state_dict(torch.load("nnet_mcts.pt"))

  return F.mse_loss(input, target, reduction=self.reduction)


Loss: 	 0.11751578996578853 
		 Value loss:  0.15457196036974588 
		 Policy loss:  0.08045961707830429

Loss: 	 0.10410129030545552 
		 Value loss:  0.1277504786849022 
		 Policy loss:  0.08045210440953572

Loss: 	 0.11162883788347244 
		 Value loss:  0.14291685819625854 
		 Policy loss:  0.08034082253774007

Loss: 	 0.10108283162117004 
		 Value loss:  0.12175237387418747 
		 Policy loss:  0.08041328440109889

Loss: 	 0.11997389296690623 
		 Value loss:  0.1595462510983149 
		 Policy loss:  0.0804015298684438

Loss: 	 0.11664571613073349 
		 Value loss:  0.15293515721956888 
		 Policy loss:  0.08035627504189809

Loss: 	 0.11824163049459457 
		 Value loss:  0.15602863828341165 
		 Policy loss:  0.08045462518930435

Loss: 	 0.12008560945590337 
		 Value loss:  0.15971514582633972 
		 Policy loss:  0.08045607556899388

Loss: 	 0.11771867175896962 
		 Value loss:  0.15498030185699463 
		 Policy loss:  0.08045704414447148

Loss: 	 0.11802918215592702 
		 Value loss:  0.15568620463212332 
	

Loss: 	 0.11679733296235402 
		 Value loss:  0.15313669045766196 
		 Policy loss:  0.08045797795057297

Loss: 	 0.1240023747086525 
		 Value loss:  0.16754663983980814 
		 Policy loss:  0.08045811206102371

Loss: 	 0.1170232022802035 
		 Value loss:  0.15358894566694895 
		 Policy loss:  0.08045745889345805

Loss: 	 0.11777183910210927 
		 Value loss:  0.15525421500205994 
		 Policy loss:  0.08028946320215861

Loss: 	 0.12061337133248647 
		 Value loss:  0.16076893111069998 
		 Policy loss:  0.08045780658721924

Loss: 	 0.12036361545324326 
		 Value loss:  0.16026940445105234 
		 Policy loss:  0.08045782645543416

Loss: 	 0.11585667729377747 
		 Value loss:  0.15125460426012674 
		 Policy loss:  0.08045874536037445

Loss: 	 0.11903956284125645 
		 Value loss:  0.15762048959732056 
		 Policy loss:  0.08045863608519237

Loss: 	 0.11701323091983795 
		 Value loss:  0.15356662372748056 
		 Policy loss:  0.08045983811219533

Loss: 	 0.11498903979857762 
		 Value loss:  0.14951951305071512 


Loss: 	 0.11666097243626912 
		 Value loss:  0.1528646151224772 
		 Policy loss:  0.08045732478300731

Loss: 	 0.11457087844610214 
		 Value loss:  0.14868445694446564 
		 Policy loss:  0.08045729746421178

Loss: 	 0.11938898513714473 
		 Value loss:  0.15832037727038065 
		 Policy loss:  0.08045759052038193

Loss: 	 0.11547728627920151 
		 Value loss:  0.15049666166305542 
		 Policy loss:  0.0804579108953476

Loss: 	 0.11570057769616444 
		 Value loss:  0.15094123780727386 
		 Policy loss:  0.08045991758505504

Loss: 	 0.11808006962140401 
		 Value loss:  0.1557006686925888 
		 Policy loss:  0.08045947055021922

Loss: 	 0.12271096060673396 
		 Value loss:  0.16496237615744272 
		 Policy loss:  0.08045954753955205

Loss: 	 0.11800543467203777 
		 Value loss:  0.15555210411548615 
		 Policy loss:  0.08045876522858937

Loss: 	 0.11476611842711766 
		 Value loss:  0.14907423158486685 
		 Policy loss:  0.08045800775289536

Loss: 	 0.11855674783388774 
		 Value loss:  0.15665595730145773 
	

Loss: 	 0.11787186811367671 
		 Value loss:  0.1553707718849182 
		 Policy loss:  0.08037296185890834

Loss: 	 0.12014958014090855 
		 Value loss:  0.15984109044075012 
		 Policy loss:  0.080458069841067

Loss: 	 0.11700024704138438 
		 Value loss:  0.15354210138320923 
		 Policy loss:  0.0804583951830864

Loss: 	 0.11126382648944855 
		 Value loss:  0.1420687735080719 
		 Policy loss:  0.08045887698729833

Loss: 	 0.11948770533005397 
		 Value loss:  0.15851863225301108 
		 Policy loss:  0.08045677592356999

Loss: 	 0.11593619237343471 
		 Value loss:  0.15141520897547403 
		 Policy loss:  0.08045717577139537

Loss: 	 0.11328715831041336 
		 Value loss:  0.1461168328921 
		 Policy loss:  0.0804574837287267

Loss: 	 0.12000205864508946 
		 Value loss:  0.15954575935999551 
		 Policy loss:  0.08045835544665654

Loss: 	 0.11378192404905955 
		 Value loss:  0.14710656305154166 
		 Policy loss:  0.08045728504657745

Loss: 	 0.12291755775610606 
		 Value loss:  0.16546033322811127 
		 Polic

Loss: 	 0.11305534094572067 
		 Value loss:  0.1456513206164042 
		 Policy loss:  0.08045935879151027

Loss: 	 0.1148065874973933 
		 Value loss:  0.14923799534638724 
		 Policy loss:  0.08037517964839935

Loss: 	 0.1228496680657069 
		 Value loss:  0.16524035235246023 
		 Policy loss:  0.08045898129542668

Loss: 	 0.11680557082096736 
		 Value loss:  0.1531511147816976 
		 Policy loss:  0.08046002686023712

Loss: 	 0.11749717841545741 
		 Value loss:  0.15453593929608664 
		 Policy loss:  0.08045842001835506

Loss: 	 0.12284885098536809 
		 Value loss:  0.16523910562197366 
		 Policy loss:  0.08045859634876251

Loss: 	 0.12050204475720723 
		 Value loss:  0.16054533421993256 
		 Policy loss:  0.08045875529448192

Loss: 	 0.11156331251064937 
		 Value loss:  0.14266853531201681 
		 Policy loss:  0.0804580847422282

Loss: 	 0.12385579695304234 
		 Value loss:  0.16725301245848337 
		 Policy loss:  0.08045858144760132

Loss: 	 0.11480042586723964 
		 Value loss:  0.14914168417453766 
		 

Loss: 	 0.11900587379932404 
		 Value loss:  0.15763740738232931 
		 Policy loss:  0.08037434269984563

Loss: 	 0.11509658147891362 
		 Value loss:  0.1497355451186498 
		 Policy loss:  0.08045761535565059

Loss: 	 0.11587273826201756 
		 Value loss:  0.15128685037295023 
		 Policy loss:  0.08045862863461177

Loss: 	 0.11290709426005681 
		 Value loss:  0.1453558107217153 
		 Policy loss:  0.08045837531487147

Loss: 	 0.11557145913441975 
		 Value loss:  0.15068378547827402 
		 Policy loss:  0.08045913279056549

Loss: 	 0.11123305062452953 
		 Value loss:  0.1420071025689443 
		 Policy loss:  0.08045900116364162

Loss: 	 0.11726964513460796 
		 Value loss:  0.15416505436102548 
		 Policy loss:  0.08037423590819041

Loss: 	 0.12151023993889491 
		 Value loss:  0.16256136198838553 
		 Policy loss:  0.0804591178894043

Loss: 	 0.11637068539857864 
		 Value loss:  0.15228140850861868 
		 Policy loss:  0.08045996228853862

Loss: 	 0.11621640374263127 
		 Value loss:  0.15197356541951498 
		

KeyboardInterrupt: 