In [1]:
import chess
import math
import random
import torch
import torch.nn as nn
import torch.optim as optim
import dset
import net
import autoencoder
import bitboards

##Hyperparameter
c_const = 0.3
samplingRate = 0.4
seed = random.randint(0, 100)


mse = nn.MSELoss()

def cross_entropy(y_hat, y):    
    loss = nn.CrossEntropyLoss()   
    y_hat_concat = torch.cat((y_hat[0], y_hat[1]), 1)
    y = torch.reshape(y, (32, 128))
            
    return loss(y_hat_concat, y)

def generate_dataset(dataset_size, transform, reinf_type, game_generator, *args):
    return dset.SearchDataset(dataset_size, transform, reinf_type, game_generator, *args)

def subset_of_dataset(dataset, sampling_rate):
    pick = math.floor(samplingRate*len(dataset))
    subset = torch.utils.data.random_split(dataset, [pick, len(dataset) - pick], generator=torch.Generator().manual_seed(seed))
    
    return subset

def get_dataloader(subset, batch):
    return torch.utils.data.DataLoader(subset[0], batch_size=batch, shuffle=True, drop_last=True)

def generate_trainable_data(batch, dataset_size, nnet, optimizer, reinf, game_generator, *args):
    dataset = generate_dataset(dataset_size, reinf, game_generator, *args)
    subset = subset_of_dataset(dataset, samplingRate)    
    DataLoader = get_dataloader(subset, batch)
    
    return DataLoader
    

def train_mcts(nnet, encoder, optimizer, data_loader, file="nnet_mcts.pt"):    
    noBatch = 0
    running_loss, running_mse, running_cross_entropy = 0, 0, 0
    
    for position, value, policy in data_loader:        
        optimizer.zero_grad()
        
        embedding = encoder.encode(position.squeeze())
        value_hat, policy_hat = nnet(embedding.squeeze())
        mse_value = mse(value_hat, torch.tensor(value, dtype=torch.float).cuda())
        cross_entropy_value = cross_entropy(policy_hat, policy)
        loss = torch.tensor(c_const).cuda() * mse_value + torch.tensor(1.0 - c_const).cuda() * cross_entropy_value        
    
        running_loss += loss.item()
        running_mse += mse_value.item()
        running_cross_entropy += cross_entropy_value.item()
        
        loss.backward()
        optimizer.step()
        noBatch += 1
    
    print(f"Loss: \t", running_loss/noBatch, "\n\t\t Value loss: ", running_mse/noBatch, "\n\t\t Policy loss: ", running_cross_entropy/noBatch, end='\n\n')

    torch.save(nnet.state_dict(), file)
                
def train_alpha_beta(batch, dataset_size, encoder, nnet, optimizer, reinf, game_generator, *args):
    dataset = dset.SearchDataset(dataset_size, reinf, game_generator, *args)
    pick = math.floor(samplingRate*len(dataset))
    subset = torch.utils.data.random_split(dataset, [pick, len(dataset) - pick], generator=torch.Generator().manual_seed(seed))
    
    DataLoader = torch.utils.data.DataLoader(subset[0], batch_size=batch, shuffle=True, drop_last=True)
    
    noBatch = 0
    for embedding, value in DataLoader:
        optimizer.zero_grad()
        value_hat = nnet(embedding.view(embedding.shape[0],1, 256))

        mse_value = mse(value_hat, value.cuda())
        print(f"Loss ({noBatch}): ", mse_value.item(), end='\n')

        mse_loss.backward() 
        optimizer.step()
        noBatch += 1
        
    torch.save(nnet.state_dict(), "nnet_alpha_beta.pt")

In [None]:
encoder = autoencoder.autoencoder().cuda()
nnet = net.Net().cuda()

nnet.load_state_dict(torch.load("nnet_mcts.pt"))
encoder.load_state_dict(torch.load("autoencoderftest2.pt"))

encoder.train()
nnet.train()

params = list(encoder.parameters()) + list(nnet.parameters())
optimizer = optim.Adam(params, weight_decay=0.01)

##Hyperparameters 
BATCH = 32
DATASET_SIZE = 256
reinf = dset.ReinforcementType.PARAM
ARGS = (chess.Board(), nnet, encoder, dset.SearchType.CUSTOM, 5)
GameGenerator = dset.GameGenerator(8, 0, 1, dset.ReinforcementType.MC)
transform = None

for j in range(0, 100):
    dataset = generate_dataset(DATASET_SIZE, transform, reinf, GameGenerator, *ARGS)

    subset = subset_of_dataset(dataset, samplingRate)
    DataLoader = get_dataloader(subset, BATCH)

    train_mcts(nnet, encoder, optimizer, DataLoader)
    nnet.load_state_dict(torch.load("nnet_mcts.pt"))
    encoder.load_state_dict(torch.load("autoencoderftest2.pt"))

  self.weight = Parameter(torch.empty(
  mse_value = mse(value_hat, torch.tensor(value, dtype=torch.float).cuda())
  return F.mse_loss(input, target, reduction=self.reduction)


Loss: 	 6.8625664710998535 
		 Value loss:  0.2374453842639923 
		 Policy loss:  9.701904296875

Loss: 	 6.857539812723796 
		 Value loss:  0.2212721904118856 
		 Policy loss:  9.701654752095541

Loss: 	 6.8448459307352705 
		 Value loss:  0.18342192967732748 
		 Policy loss:  9.699741999308268

Loss: 	 6.852213064829509 
		 Value loss:  0.20855568846066794 
		 Policy loss:  9.69949467976888

Loss: 	 6.861380100250244 
		 Value loss:  0.23886252442995706 
		 Policy loss:  9.699602127075195

Loss: 	 6.8560285568237305 
		 Value loss:  0.21713222563266754 
		 Policy loss:  9.70127010345459

Loss: 	 6.8626400629679365 
		 Value loss:  0.23840253551801047 
		 Policy loss:  9.70159943898519

Loss: 	 6.867472012837728 
		 Value loss:  0.2542979617913564 
		 Policy loss:  9.701689720153809

Loss: 	 6.8614020347595215 
		 Value loss:  0.2341767648855845 
		 Policy loss:  9.701641400655111

