In [1]:
import bitboards
import autoencoder
import pgn_reader as reader
import torch
import torch.nn as nn
import torch.optim as optim
import random
import math
import gc

gc.collect()
torch.cuda.empty_cache()

dsetSize, batch = 2**16, 2**8
autoEncoder = autoencoder.autoencoder().cuda()
autoEncoder.load_state_dict(torch.load("autoencoderftest.pt"))
optimizer = optim.Adam(autoEncoder.parameters())
criterion = nn.BCELoss(reduction='sum')

class chessDataset(torch.utils.data.Dataset):
    def __init__(self, data_list):
        self.data = data_list
        
    def __getitem__(self, idx):
            return self.data[idx]
    
    def __len__(self):
        return len(self.data)


data = reader.get_dataset(dsetSize)
while data:
    random.shuffle(data)
    
    dataset = chessDataset(data[:math.floor(1/16*len(data))])
    testset = chessDataset(data[math.floor(1/16*len(data)):-1])

    DataLoader = torch.utils.data.DataLoader(dataset, batch_size = batch, shuffle = True)
    TestLoader = torch.utils.data.DataLoader(testset, batch_size = batch, shuffle = True)   
    
    autoEncoder.eval()
    test_loss = 0
    for x in TestLoader:
        x = x.cuda()
        y = autoEncoder(x).cuda()
        loss = criterion(y, x)
        test_loss += loss.item()
    print("Test loss: ", test_loss / len(testset))
        
    autoEncoder.train()
    training_loss = 0
    for x in DataLoader:
        optimizer.zero_grad()
        x = x.cuda()
        y = autoEncoder(x).cuda()
        loss = criterion(y, x)
        loss.backward()
        optimizer.step()
        training_loss += loss.item()
    print("Training loss: ", training_loss / len(dataset), end='\n\n')
    
    data = reader.get_dataset(dsetSize)
    
    torch.save(autoEncoder.state_dict(), "autoencoderftest2.pt")

Test loss:  12.318046566105828
Training loss:  17.096751836692594

Test loss:  12.447418504810845
Training loss:  11.38854553851923

Test loss:  10.324967121278807
Training loss:  10.068977683542647

Test loss:  9.748676207046802
Training loss:  9.219289302825928

Test loss:  9.266333962615573
Training loss:  9.236233253479003

Test loss:  9.708820348101693
Training loss:  9.36866996567937

Test loss:  9.453738949626924
Training loss:  9.11348894682491

Test loss:  9.271951903125279
Training loss:  8.93566622298889

Test loss:  9.62319194696003
Training loss:  8.986421215583675

Test loss:  9.866692984840276
Training loss:  9.455791404538038

Test loss:  9.88932819273662
Training loss:  9.444732684695724

Test loss:  9.908909173267446
Training loss:  9.591388402945352

Test loss:  9.494810578456367
Training loss:  8.841355216381086

Test loss:  9.599794342651913
Training loss:  9.272343456745148

Test loss:  9.262106353618803
Training loss:  9.019902850663504

Test loss:  10.0599526154

KeyboardInterrupt: 