In [1]:
import torch
import numpy as np
from scipy import sparse
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import os
import random
import time

class ChessEpochDataLoader:
    def __init__(self, folder_path, device, batch_size=32):
        self.folder_path = folder_path
        self.device = device
        self.batch_size = batch_size
        self.batch_files = self.get_batch_files()

    def get_batch_files(self):
        batch_files = []
        for file in os.listdir(self.folder_path):
            if file.startswith("encoding_") and file.endswith(".npz"):
                batch_num = file.split("_")[1].split(".")[0]
                batch_files.append(batch_num)
        return sorted(batch_files)

    def load_batch(self, batch_num):
        # print(f"Loading batch {batch_num}")
        encodings = sparse.load_npz(os.path.join(self.folder_path, f"encoding_{batch_num}.npz"))
        to_move = np.load(os.path.join(self.folder_path, f"to_move_{batch_num}.npy"))
        outcomes = np.load(os.path.join(self.folder_path, f"outcomes_{batch_num}.npy"))

        encodings = torch.from_numpy(encodings.todense()).float().to(self.device)
        encodings = encodings.view(-1, 2, 6, 8, 8)
        to_move = torch.from_numpy(to_move).float().to(self.device)
        outcomes = torch.from_numpy(outcomes).float().to(self.device)

        # Flip boards where black is to move
        black_to_move = to_move == 0
        encodings[black_to_move] = encodings[black_to_move].flip(1)
        encodings[black_to_move] = encodings[black_to_move].flip(3)
        outcomes[black_to_move] = 1 - outcomes[black_to_move]

        return encodings, outcomes

    def __iter__(self):
        # random.shuffle(self.batch_files)  # Shuffle batch order for each epoch
        is_first = 0
        while True:
            for batch_num in self.batch_files:
                encodings, outcomes = self.load_batch(batch_num)
                for i in range(0, len(encodings), self.batch_size):
                    batch_encodings = encodings[i:i + self.batch_size]
                    batch_outcomes = outcomes[i:i + self.batch_size]

                    yield batch_encodings, batch_outcomes
                # Clear the current batch from GPU memory
                del encodings, outcomes
                torch.cuda.empty_cache()

    def __len__(self):
        return len(self.batch_files)


In [2]:
class SimpleChessNet(nn.Module):
    def __init__(self, acc_size):
        super(SimpleChessNet, self).__init__()
        self.fc1 = nn.Linear(768, acc_size)
        self.fc2 = nn.Linear(acc_size, 32)
        self.fc3 = nn.Linear(32, 32)
        self.fc4 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_1):
        x_1 = x_1.view(-1, 768)
        x = self.fc1(x_1).clamp(0, 1)
        x = self.fc2(x).clamp(0, 1)
        x = self.fc3(x).clamp(0, 1)
        x = self.sigmoid(self.fc4(x))
        return x


In [3]:
def train(model, data_loader, num_epochs, learning_rate, gamma, device):
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

    for epoch in range(num_epochs):
        # print(f"Epoch {epoch + 1}/{num_epochs}")
        model.train()
        s_loss = 0
        for i, (player_view, labels) in enumerate(data_loader):
            optimizer.zero_grad()
            outputs = model(player_view)
            loss = criterion(outputs.squeeze(), labels)
            loss.backward()
            optimizer.step()
            s_loss += loss.item()
            if i % 10000 == 0:
                if i % 120000 == 0 and i != 0:
                    print(f"  Batch {i}, Loss: {s_loss / 10000:.5f}")
                s_loss = 0
                # torch.save(model.state_dict(), f'models/pre_train_{i}.pth')
                scheduler.step()
            if i == 360000:
                break ## zadnih 10000 bacthov (320_000 primerov) je testnih. podatke samo enkrat uporabim



In [4]:
t0 = time.time()
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    folder_path = "/home/luka/PycharmProjects/pythonProject/diplomska/nets"  # Replace with your data folder path
    data_loader = ChessEpochDataLoader(folder_path, device, batch_size=32)

    num_epochs = 1
    learning_rates = [0.001 * 0.5 ** i for i in range(5)]
    gammas = [0.99, 0.95, 0.9]
    input_sizes = [32, 32 * 16, 32 * 16 * 16]  # [32, 512, 4096]
    for input_size in input_sizes:
        for lr in learning_rates:
            for gamma in gammas:
                print(f"Training model with input_size={input_size}, learning_rate={lr}, gamma={gamma}, starting_time={time.time()-t0}")
                model = SimpleChessNet(input_size).to(device)
                # checkpoint_path = r"/home/luka/PycharmProjects/pythonProject/diplomska/nets/models/pre_train_360000.pth"
                # checkpoint = torch.load(checkpoint_path)
                model.train()
                train(model, data_loader, num_epochs, lr, gamma, device)

Training model with input_size=32, learning_rate=0.001, gamma=0.99, starting_time=0.03031444549560547
  Batch 120000, Loss: 0.10415
  Batch 240000, Loss: 0.10300
  Batch 360000, Loss: 0.10135
Training model with input_size=32, learning_rate=0.001, gamma=0.95, starting_time=386.5256402492523
  Batch 120000, Loss: 0.10289
  Batch 240000, Loss: 0.10097
  Batch 360000, Loss: 0.09902
Training model with input_size=32, learning_rate=0.001, gamma=0.9, starting_time=772.0853114128113
  Batch 120000, Loss: 0.10210
  Batch 240000, Loss: 0.10044
  Batch 360000, Loss: 0.09912
Training model with input_size=32, learning_rate=0.0005, gamma=0.99, starting_time=1157.6811785697937
  Batch 120000, Loss: 0.10389
  Batch 240000, Loss: 0.10219
  Batch 360000, Loss: 0.10047
Training model with input_size=32, learning_rate=0.0005, gamma=0.95, starting_time=1543.3803930282593
  Batch 120000, Loss: 0.10294
  Batch 240000, Loss: 0.10107
  Batch 360000, Loss: 0.09914
Training model with input_size=32, learning_r