data augmentation

In [None]:
import numpy as np
import pickle
from utils import State, load_data
from collections import OrderedDict

def update_prev_local_action(prev: tuple, transform: str) -> tuple:
    """Map a (row, col) coordinate on a 3x3 grid through a board symmetry.

    The mapping must agree with how ``transform_state`` moves cells: rotations
    there use ``np.rot90`` (counter-clockwise for positive k), "horizontal" is
    ``np.fliplr`` and "vertical" is ``np.flipud``. Composite names such as
    "rotate90_horizontal" rotate first and then flip.

    Args:
        prev: ``(row, col)`` with each component in 0..2.
        transform: symmetry name; unrecognised names leave ``prev`` unchanged.

    Returns:
        The transformed ``(row, col)`` tuple.
    """
    r, c = prev
    if "_" in transform:
        base, flip = transform.split("_")
    elif transform in ("horizontal", "vertical"):
        base, flip = "identity", transform
    else:
        base, flip = transform, None

    quarter_turns = {"rotate90": 1, "rotate180": 2, "rotate270": 3}.get(base, 0)
    for _ in range(quarter_turns):
        # One np.rot90 step (counter-clockwise) sends (r, c) to (2 - c, r);
        # the clockwise mapping (c, 2 - r) would desynchronise it from the board.
        r, c = 2 - c, r

    if flip == "horizontal":
        c = 2 - c
    elif flip == "vertical":
        r = 2 - r
    return (r, c)

def transform_state(state: State, transform: str) -> State:
    """Return a new State whose board is ``state.board`` mapped through ``transform``.

    The board is viewed as a 9x9 grid (transpose(0, 2, 1, 3) then reshape),
    flipped/rotated with numpy, and folded back the same way.

    Accepted names:
    - "identity"
    - "horizontal" (np.fliplr)
    - "vertical" (np.flipud)
    - "rotate90" / "rotate180" / "rotate270" (np.rot90 with k = 1 / 2 / 3)
    - "<rotation>_horizontal" / "<rotation>_vertical" (rotate, then flip)

    ``prev_local_action`` goes through ``update_prev_local_action``;
    ``fill_num`` is copied unchanged.

    Raises:
        ValueError: for any other transform name.
    """
    quarter_turns = {"rotate90": 1, "rotate180": 2, "rotate270": 3}
    flips = {"horizontal": np.fliplr, "vertical": np.flipud}

    grid = state.board.transpose(0, 2, 1, 3).reshape(9, 9)

    if transform == "identity":
        out = grid.copy()
    elif transform in flips:
        out = flips[transform](grid)
    elif not transform.startswith("rotate"):
        raise ValueError(f"Unknown transform: {transform}")
    elif "_" not in transform:
        if transform not in quarter_turns:
            raise ValueError(f"Unknown rotation transform: {transform}")
        out = np.rot90(grid, k=quarter_turns[transform])
    else:
        base, extra = transform.split("_")
        if base not in quarter_turns:
            raise ValueError(f"Unknown base rotation: {base}")
        if extra not in flips:
            raise ValueError(f"Unknown extra transform: {extra}")
        out = flips[extra](np.rot90(grid, k=quarter_turns[base]))

    folded = out.reshape(3, 3, 3, 3).transpose(0, 2, 1, 3)
    prev = state.prev_local_action
    moved_prev = None if prev is None else update_prev_local_action(prev, transform)
    return State(board=folded, fill_num=state.fill_num, prev_local_action=moved_prev)

def augment_entire_state(state: State) -> list[State]:
    """Return ``state`` under twelve named board symmetries, identity first.

    Some names yield the same board as another entry (e.g. "rotate180_horizontal"
    gives the same board as "vertical"). Callers such as augment_dataset
    therefore deduplicate the results.
    """
    names = ["identity", "horizontal", "vertical"]
    for rotation in ("rotate90", "rotate180", "rotate270"):
        names += [rotation, f"{rotation}_horizontal", f"{rotation}_vertical"]
    return list(map(lambda name: transform_state(state, name), names))

def state_to_key(state: State) -> tuple:
    """Build a hashable identity: (raw board bytes, fill_num, prev_local_action)."""
    board_bytes = state.board.tobytes()
    return board_bytes, state.fill_num, state.prev_local_action

def augment_dataset(data: list[tuple[State, float]]) -> list[tuple[State, float]]:
    """Expand every (state, utility) pair by board symmetries and drop repeats.

    Each augmented state keeps its source utility. When several augmented
    states share a ``state_to_key`` key, the first one encountered wins, and
    output order follows first appearance.
    """
    seen = {}
    expanded = (
        (variant, utility)
        for state, utility in data
        for variant in augment_entire_state(state)
    )
    for variant, utility in expanded:
        seen.setdefault(state_to_key(variant), (variant, utility))
    return list(seen.values())

if __name__ == "__main__":
    original_data = load_data()
    print(f"Original data size: {len(original_data)}")

    augmented_data = augment_dataset(original_data)
    print(f"Augmented data size (duplicates removed): {len(augmented_data)}")

    # Store each state as a plain (board, fill_num, prev_local_action) tuple;
    # load_aug_data rebuilds State objects from this layout.
    augmented_data_for_saving = [
        ((s.board, s.fill_num, s.prev_local_action), utility)
        for s, utility in augmented_data
    ]
    with open("augmented_data.pkl", "wb") as f:
        pickle.dump(augmented_data_for_saving, f)


Data augmentation including swapping players 1 and 2 (dataset size: 1,276,600)

In [None]:
import numpy as np
import pickle
from utils import State, load_data

def update_prev_local_action(prev: tuple, transform: str) -> tuple:
    """Map a (row, col) coordinate on a 3x3 grid through a board symmetry.

    The mapping must agree with how ``transform_state`` moves cells: rotations
    there use ``np.rot90`` (counter-clockwise for positive k), "horizontal" is
    ``np.fliplr`` and "vertical" is ``np.flipud``. Composite names such as
    "rotate90_horizontal" rotate first and then flip.

    Args:
        prev: ``(row, col)`` with each component in 0..2.
        transform: symmetry name; unrecognised names leave ``prev`` unchanged.

    Returns:
        The transformed ``(row, col)`` tuple.
    """
    r, c = prev
    if "_" in transform:
        base, flip = transform.split("_")
    elif transform in ("horizontal", "vertical"):
        base, flip = "identity", transform
    else:
        base, flip = transform, None

    quarter_turns = {"rotate90": 1, "rotate180": 2, "rotate270": 3}.get(base, 0)
    for _ in range(quarter_turns):
        # One np.rot90 step (counter-clockwise) sends (r, c) to (2 - c, r);
        # the clockwise mapping (c, 2 - r) would desynchronise it from the board.
        r, c = 2 - c, r

    if flip == "horizontal":
        c = 2 - c
    elif flip == "vertical":
        r = 2 - r
    return (r, c)

def transform_state(state: State, transform: str) -> State:
    """Return a new State whose board is ``state.board`` mapped through ``transform``.

    The board is viewed as a 9x9 grid (transpose(0, 2, 1, 3) then reshape),
    flipped/rotated with numpy, and folded back the same way.

    Accepted names:
    - "identity", "horizontal" (np.fliplr), "vertical" (np.flipud)
    - "rotate90" / "rotate180" / "rotate270" (np.rot90 with k = 1 / 2 / 3)
    - "<rotation>_horizontal" / "<rotation>_vertical" (rotate, then flip)

    Raises:
        ValueError: for any other name. Unknown rotations are rejected instead
            of silently acting as identity, which would add mislabelled copies.
    """
    rotations = {"rotate90": 1, "rotate180": 2, "rotate270": 3}
    global_board = state.board.transpose(0, 2, 1, 3).reshape(9, 9)
    if transform == "identity":
        transformed_global_board = global_board.copy()
    elif transform == "horizontal":
        transformed_global_board = np.fliplr(global_board)
    elif transform == "vertical":
        transformed_global_board = np.flipud(global_board)
    elif transform.startswith("rotate"):
        if "_" in transform:
            base, extra = transform.split("_")
            if base not in rotations:
                raise ValueError(f"Unknown base rotation: {base}")
            transformed_global_board = np.rot90(global_board, k=rotations[base])
            if extra == "horizontal":
                transformed_global_board = np.fliplr(transformed_global_board)
            elif extra == "vertical":
                transformed_global_board = np.flipud(transformed_global_board)
            else:
                raise ValueError(f"Unknown extra transform: {extra}")
        else:
            if transform not in rotations:
                raise ValueError(f"Unknown rotation transform: {transform}")
            transformed_global_board = np.rot90(global_board, k=rotations[transform])
    else:
        raise ValueError(f"Unknown transform: {transform}")
    new_board = transformed_global_board.reshape(3, 3, 3, 3).transpose(0, 2, 1, 3)
    new_prev = (
        None
        if state.prev_local_action is None
        else update_prev_local_action(state.prev_local_action, transform)
    )
    return State(board=new_board, fill_num=state.fill_num, prev_local_action=new_prev)

def augment_entire_state(state: State) -> list[State]:
    """Return ``state`` under twelve named board symmetries, identity first.

    Some names yield the same board as another entry (e.g. "rotate180_horizontal"
    gives the same board as "vertical"). Callers such as augment_dataset
    therefore deduplicate the results.
    """
    names = ["identity", "horizontal", "vertical"]
    for rotation in ("rotate90", "rotate180", "rotate270"):
        names += [rotation, f"{rotation}_horizontal", f"{rotation}_vertical"]
    return list(map(lambda name: transform_state(state, name), names))

def state_to_key(state: State) -> tuple:
    """Build a hashable identity: (raw board bytes, fill_num, prev_local_action)."""
    board_bytes = state.board.tobytes()
    return board_bytes, state.fill_num, state.prev_local_action

def augment_dataset(data: list[tuple[State, float]]) -> list[tuple[State, float]]:
    """Expand every (state, utility) pair by board symmetries and drop repeats.

    Each augmented state keeps its source utility. When several augmented
    states share a ``state_to_key`` key, the first one encountered wins, and
    output order follows first appearance.
    """
    seen = {}
    expanded = (
        (variant, utility)
        for state, utility in data
        for variant in augment_entire_state(state)
    )
    for variant, utility in expanded:
        seen.setdefault(state_to_key(variant), (variant, utility))
    return list(seen.values())

def swap_state(state: State) -> State:
    """Exchange the roles of players 1 and 2 in ``state``.

    Board cells holding 1 become 2 and vice versa; any other value is kept.
    ``fill_num`` becomes ``3 - fill_num``, which swaps 1 and 2 (assuming
    fill_num is one of those). ``prev_local_action`` is copied unchanged.
    """
    board = state.board
    swapped = np.where(board == 2, 1, board)
    swapped = np.where(board == 1, 2, swapped)
    return State(
        board=swapped,
        fill_num=3 - state.fill_num,
        prev_local_action=state.prev_local_action,
    )

def augment_dataset_with_swap(data: list[tuple[State, float]]) -> list[tuple[State, float]]:
    """Return the player-swapped counterpart of every pair, deduplicated.

    Each state goes through ``swap_state`` and its utility is negated. Only the
    swapped pairs are returned (first occurrence per ``state_to_key`` wins);
    callers combine them with the originals themselves.

    NOTE(review): negating the utility assumes it is scored from a fixed
    player's point of view rather than the side to move; confirm against the
    data source.
    """
    unique = {}
    for state, utility in data:
        flipped = swap_state(state)
        unique.setdefault(state_to_key(flipped), (flipped, -utility))
    return list(unique.values())

if __name__ == "__main__":
    original_data = load_data()
    print(f"Original data size: {len(original_data)}")
    sym_augmented_data = augment_dataset(original_data)
    print(f"Symmetry-augmented data size (duplicates removed): {len(sym_augmented_data)}")
    swapped_augmented_data = augment_dataset_with_swap(sym_augmented_data)
    print(f"Swapped-augmented data size: {len(swapped_augmented_data)}")

    # Symmetric states come first, so on a key collision their utility is kept.
    final_unique_data = {}
    for state, utility in sym_augmented_data + swapped_augmented_data:
        final_unique_data.setdefault(state_to_key(state), (state, utility))
    final_data = list(final_unique_data.values())
    print(f"Final augmented data size (combined and duplicates removed): {len(final_data)}")

    # Same (board, fill_num, prev_local_action) row layout as augmented_data.pkl.
    final_data_for_saving = [
        ((s.board, s.fill_num, s.prev_local_action), utility)
        for s, utility in final_data
    ]
    with open("augmented_data2.pkl", "wb") as f:
        pickle.dump(final_data_for_saving, f)


load data from pkl

In [None]:
from utils import State, Action
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle


def load_aug_data() -> list[tuple[State, float]]:
    """Read ``augmented_data.pkl`` and rebuild ``State`` objects from its rows.

    Each pickled row is ``((board, fill_num, prev_local_action), utility)``.
    Only load pickles you produced yourself: unpickling can execute code.
    """
    with open("augmented_data.pkl", "rb") as f:
        rows = pickle.load(f)
    return [
        (State(board=board, fill_num=fill_num, prev_local_action=prev), utility)
        for (board, fill_num, prev), utility in rows
    ]

# Load the augmented dataset and start with empty feature/target accumulators.
data = load_aug_data()
X_list, y_list = [], []


In [None]:
from utils import State, Action
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
from sklearn.model_selection import train_test_split

class Net(nn.Module):
    """Fully connected value head: 93 inputs -> 256 -> 128 -> 1.

    Each hidden stage is Linear -> BatchNorm1d -> ReLU -> Dropout(0.2).
    Attribute names are the state_dict keys written to the weights export,
    so they must stay stable.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(93, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(0.2)
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(0.2)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, x):
        stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
        )
        for linear, norm, drop in stages:
            x = drop(F.relu(norm(linear(x))))
        return self.fc3(x)

# Build raw (non one-hot) features from the augmented dataset and train Net full-batch.
data = load_aug_data()
X_list = []
y_list = []

for state, eval_value in data:
    # Raw board cells + local-board statuses + fill_num + prev (row, col), or
    # (-1, -1) when there is no previous action. Net expects 93 values in total.
    global_board_features = state.board.reshape(-1)
    local_board_features = state.local_board_status.reshape(-1)
    fill_num_feature = np.array([state.fill_num], dtype=np.float32)
    if state.prev_local_action is None:
        prev_action = np.array([-1, -1], dtype=np.float32)
    else:
        prev_action = np.array(state.prev_local_action, dtype=np.float32)
    features = np.concatenate([global_board_features, local_board_features, fill_num_feature, prev_action])
    X_list.append(features)
    y_list.append(eval_value)

# Stack into one ndarray first: torch.tensor on a Python list of ndarrays is
# extremely slow at this dataset size (PyTorch emits a warning about it).
X = torch.tensor(np.stack(X_list), dtype=torch.float32)
y = torch.tensor(y_list, dtype=torch.float32).view(-1, 1)

X_train, X_eval, y_train, y_eval = train_test_split(X, y, test_size=0.2, random_state=42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

X_train, X_eval = X_train.to(device), X_eval.to(device)
y_train, y_eval = y_train.to(device), y_eval.to(device)

net = Net().to(device)

optimizer = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-5)
loss_fn = nn.MSELoss()
losses = []

num_epochs = 1500
for epoch in range(num_epochs):
    net.train()
    y_pred_train = net(X_train)
    loss = loss_fn(y_pred_train, y_train)
    losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {loss.item():.4f}")

net.eval()
with torch.no_grad():
    y_pred_eval = net(X_eval)

eval_mse = loss_fn(y_pred_eval, y_eval).item()
print(f"Evaluation MSE: {eval_mse:.4f}")

# Dump every state_dict entry as "NAME = <nested list>" lines.
trained_weights = net.state_dict()
with open("weights.txt", "w") as f:
    for name, param in trained_weights.items():
        values = param.detach().cpu().numpy().tolist()
        f.write(f"{name.upper()} = {values}\n\n")


Cross-validation

In [None]:
from utils import State, Action
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
from sklearn.model_selection import train_test_split

# Fold-based training on the augmented data. Relies on `data` (from
# load_aug_data()) and the 93-input `Net` class defined in earlier cells.

# Reset the accumulators: the previous cell already fills X_list / y_list from
# the same `data`, and appending again would duplicate every sample, letting
# copies of test rows leak into the training split.
X_list = []
y_list = []
for state, eval_value in data:
    global_board_features = state.board.reshape(-1)
    local_board_features = state.local_board_status.reshape(-1)
    fill_num_feature = np.array([state.fill_num], dtype=np.float32)
    if state.prev_local_action is None:
        prev_action = np.array([-1, -1], dtype=np.float32)
    else:
        prev_action = np.array(state.prev_local_action, dtype=np.float32)
    features = np.concatenate([global_board_features, local_board_features, fill_num_feature, prev_action])
    X_list.append(features)
    y_list.append(eval_value)

# Stack first: torch.tensor on a list of ndarrays is extremely slow.
X = torch.tensor(np.stack(X_list), dtype=torch.float32)
y = torch.tensor(y_list, dtype=torch.float32).view(-1, 1)

X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Transfer once; the fold loop only indexes tensors that already live on the device.
X_train_full, y_train_full = X_train_full.to(device), y_train_full.to(device)
X_test, y_test = X_test.to(device), y_test.to(device)

net = Net().to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-5)
loss_fn = nn.MSELoss()

num_epochs = 600
num_folds = 3
fold_size = int(0.2 * X_train_full.size(0))
# NOTE: a single network is updated on every fold, so each validation slice
# holds rows it has already trained on. The averaged MSE is a training monitor,
# not an unbiased k-fold cross-validation estimate.
for epoch in range(num_epochs):
    fold_val_losses = []
    # Shuffle once per epoch so the folds' validation slices are disjoint
    # (a fresh permutation per fold let them overlap).
    indices = torch.randperm(X_train_full.size(0)).to(device)
    for i in range(num_folds):
        val_indices = indices[i * fold_size:(i + 1) * fold_size]
        train_indices = torch.cat([indices[:i * fold_size], indices[(i + 1) * fold_size:]])

        X_train_fold, y_train_fold = X_train_full[train_indices], y_train_full[train_indices]
        X_val_fold, y_val_fold = X_train_full[val_indices], y_train_full[val_indices]

        net.train()
        optimizer.zero_grad()
        loss_train = loss_fn(net(X_train_fold), y_train_fold)
        loss_train.backward()
        optimizer.step()

        net.eval()
        with torch.no_grad():
            loss_val = loss_fn(net(X_val_fold), y_val_fold)
        fold_val_losses.append(loss_val.item())

    avg_val_loss = np.mean(fold_val_losses)
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Average Validation MSE: {avg_val_loss:.4f}")

net.eval()
with torch.no_grad():
    y_pred_test = net(X_test)
    test_loss = loss_fn(y_pred_test, y_test).item()

print(f"\nFinal Test MSE: {test_loss:.4f}")

trained_weights = net.state_dict()
with open("weights.txt", "w") as f:
    for name, param in trained_weights.items():
        values = param.detach().cpu().numpy().tolist()
        f.write(f"{name.upper()} = {values}\n\n")


One-hot encoding on the 639,004-sample dataset

In [None]:
# one hot encoding load
from utils import State, Action
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
import pickle
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score

def load_aug_data() -> list[tuple[State, float]]:
    """Read ``augmented_data.pkl`` and rebuild ``State`` objects from its rows.

    Each pickled row is ``((board, fill_num, prev_local_action), utility)``.
    Only load pickles you produced yourself: unpickling can execute code.
    """
    with open("augmented_data.pkl", "rb") as f:
        rows = pickle.load(f)
    return [
        (State(board=board, fill_num=fill_num, prev_local_action=prev), utility)
        for (board, fill_num, prev), utility in rows
    ]

def custom_encode_2(val: int) -> np.ndarray:
    """Encode 0, 1 or 2 as a 2-element float32 vector.

    0 -> [0, 0], 1 -> [0, 1], 2 -> [1, 0]. A new array is returned on every call.

    Raises:
        ValueError: for any other value.
    """
    for code, bits in ((0, (0, 0)), (1, (0, 1)), (2, (1, 0))):
        if val == code:
            return np.array(bits, dtype=np.float32)
    raise ValueError("Value must be 0, 1, or 2.")

def custom_encode_status(val: int) -> np.ndarray:
    """Encode a local-board status 0..3 as a 3-element float32 vector.

    0 -> [0, 0, 0]; 1, 2, 3 -> a single 1 at index 2, 1, 0 respectively.

    Raises:
        ValueError: for any other value.
    """
    table = ((0, (0, 0, 0)), (1, (0, 0, 1)), (2, (0, 1, 0)), (3, (1, 0, 0)))
    for code, bits in table:
        if val == code:
            return np.array(bits, dtype=np.float32)
    raise ValueError("Local board status must be 0, 1, 2, or 3.")

def custom_encode_coord(val: int) -> np.ndarray:
    """Encode one prev_local_action coordinate as 2 float32 values.

    Negative values (the -1 sentinel used when there is no previous action)
    become [0, 0]; anything else goes through ``custom_encode_2``.

    NOTE(review): ``custom_encode_2(0)`` is also [0, 0], so coordinate 0 and
    "no previous action" produce identical features. Separating them would
    change the 195-wide feature layout that Net and exported weights rely on.
    """
    return np.array([0, 0], dtype=np.float32) if val < 0 else custom_encode_2(val)

# Encode each (state, value) pair into the one-hot feature layout:
# board cells x 2 + local-board statuses x 3 + fill_num x 2 + prev (row, col) x 2 x 2.
# (195 values when the board has 81 cells and local_board_status has 9 entries.)
data = load_aug_data()
X_list = []
y_list = []

for state, eval_value in data:
    cell_codes = [custom_encode_2(v) for v in state.board.reshape(-1)]
    status_codes = [custom_encode_status(v) for v in state.local_board_status.reshape(-1)]
    prev = state.prev_local_action
    prev_r, prev_c = (-1, -1) if prev is None else prev
    parts = [
        np.array(cell_codes).flatten(),
        np.array(status_codes).flatten(),
        custom_encode_2(state.fill_num),
        custom_encode_coord(prev_r),
        custom_encode_coord(prev_c),
    ]
    X_list.append(np.concatenate(parts))
    y_list.append(eval_value)


In [None]:
# one hot encoding
from utils import State, Action
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
from sklearn.model_selection import train_test_split
import pprint
from collections import OrderedDict

class Net(nn.Module):
    """One-hot value head: 195 inputs -> 256 -> 128 -> 64 -> 1.

    Each hidden stage is Linear -> BatchNorm1d -> ReLU -> Dropout(0.3).
    Attribute names become the keys in the exported weights file, so they
    must stay stable.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(195, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(0.3)
        self.fc_mid = nn.Linear(128, 64)
        self.bn_mid = nn.BatchNorm1d(64)
        self.dropout_mid = nn.Dropout(0.3)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc_mid, self.bn_mid, self.dropout_mid),
        )
        for linear, norm, drop in stages:
            x = drop(F.relu(norm(linear(x))))
        return self.fc3(x)

# Train the one-hot Net full-batch and export its weights as an importable module.
# Stack first: torch.tensor on a Python list of ndarrays is extremely slow here.
X = torch.tensor(np.stack(X_list), dtype=torch.float32)
y = torch.tensor(y_list, dtype=torch.float32).view(-1, 1)

X_train, X_eval, y_train, y_eval = train_test_split(X, y, test_size=0.2, random_state=42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

X_train, X_eval = X_train.to(device), X_eval.to(device)
y_train, y_eval = y_train.to(device), y_eval.to(device)

net = Net().to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)
loss_fn = nn.MSELoss()
losses = []

num_epochs = 450
for epoch in range(num_epochs):
    net.train()
    y_pred_train = net(X_train)
    loss = loss_fn(y_pred_train, y_train)
    losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {loss.item():.4f}")

net.eval()
with torch.no_grad():
    y_pred_eval = net(X_eval)

eval_mse = loss_fn(y_pred_eval, y_eval).item()
print(f"Evaluation MSE: {eval_mse:.4f}")

trained_weights = net.state_dict()
weights_dict = OrderedDict()
for name, param in trained_weights.items():
    values = param.detach().cpu().numpy().tolist()
    weights_dict[name] = values

with open("weights.py", "w") as f:
    # The generated module uses OrderedDict and torch, so import them there;
    # without this header `import weights` fails with NameError.
    f.write("from collections import OrderedDict\n\nimport torch\n\n")
    f.write("weights = OrderedDict([\n")
    for key, value in weights_dict.items():
        f.write(f"    ('{key}', torch.tensor(\n")
        f.write(pprint.pformat(value, indent=8))
        f.write("\n    )),\n")
    f.write("])\n")


splitting dataset before augmentation

In [None]:
import numpy as np
import pickle
from utils import State, load_data
import random
from sklearn.model_selection import train_test_split

def update_prev_local_action(prev: tuple, transform: str) -> tuple:
    """Map a (row, col) coordinate on a 3x3 grid through a board symmetry.

    The mapping must agree with how ``transform_state`` moves cells: rotations
    there use ``np.rot90`` (counter-clockwise for positive k), "horizontal" is
    ``np.fliplr`` and "vertical" is ``np.flipud``. Composite names such as
    "rotate90_horizontal" rotate first and then flip.

    Args:
        prev: ``(row, col)`` with each component in 0..2.
        transform: symmetry name; unrecognised names leave ``prev`` unchanged.

    Returns:
        The transformed ``(row, col)`` tuple.
    """
    r, c = prev
    if "_" in transform:
        base, flip = transform.split("_")
    elif transform in ("horizontal", "vertical"):
        base, flip = "identity", transform
    else:
        base, flip = transform, None

    quarter_turns = {"rotate90": 1, "rotate180": 2, "rotate270": 3}.get(base, 0)
    for _ in range(quarter_turns):
        # One np.rot90 step (counter-clockwise) sends (r, c) to (2 - c, r);
        # the clockwise mapping (c, 2 - r) would desynchronise it from the board.
        r, c = 2 - c, r

    if flip == "horizontal":
        c = 2 - c
    elif flip == "vertical":
        r = 2 - r
    return (r, c)

def transform_state(state: State, transform: str) -> State:
    """Return a new State whose board is ``state.board`` mapped through ``transform``.

    The board is viewed as a 9x9 grid (transpose(0, 2, 1, 3) then reshape),
    flipped/rotated with numpy, and folded back the same way.

    Accepted names:
    - "identity"
    - "horizontal" (np.fliplr)
    - "vertical" (np.flipud)
    - "rotate90" / "rotate180" / "rotate270" (np.rot90 with k = 1 / 2 / 3)
    - "<rotation>_horizontal" / "<rotation>_vertical" (rotate, then flip)

    ``prev_local_action`` goes through ``update_prev_local_action``;
    ``fill_num`` is copied unchanged.

    Raises:
        ValueError: for any other transform name.
    """
    quarter_turns = {"rotate90": 1, "rotate180": 2, "rotate270": 3}
    flips = {"horizontal": np.fliplr, "vertical": np.flipud}

    grid = state.board.transpose(0, 2, 1, 3).reshape(9, 9)

    if transform == "identity":
        out = grid.copy()
    elif transform in flips:
        out = flips[transform](grid)
    elif not transform.startswith("rotate"):
        raise ValueError(f"Unknown transform: {transform}")
    elif "_" not in transform:
        if transform not in quarter_turns:
            raise ValueError(f"Unknown rotation transform: {transform}")
        out = np.rot90(grid, k=quarter_turns[transform])
    else:
        base, extra = transform.split("_")
        if base not in quarter_turns:
            raise ValueError(f"Unknown base rotation: {base}")
        if extra not in flips:
            raise ValueError(f"Unknown extra transform: {extra}")
        out = flips[extra](np.rot90(grid, k=quarter_turns[base]))

    folded = out.reshape(3, 3, 3, 3).transpose(0, 2, 1, 3)
    prev = state.prev_local_action
    moved_prev = None if prev is None else update_prev_local_action(prev, transform)
    return State(board=folded, fill_num=state.fill_num, prev_local_action=moved_prev)

def augment_entire_state(state: State) -> list[State]:
    """Return ``state`` under twelve named board symmetries, identity first.

    Some names yield the same board as another entry (e.g. "rotate180_horizontal"
    gives the same board as "vertical"). Callers such as augment_dataset
    therefore deduplicate the results.
    """
    names = ["identity", "horizontal", "vertical"]
    for rotation in ("rotate90", "rotate180", "rotate270"):
        names += [rotation, f"{rotation}_horizontal", f"{rotation}_vertical"]
    return list(map(lambda name: transform_state(state, name), names))

def state_to_key(state: State) -> tuple:
    """Build a hashable identity: (raw board bytes, fill_num, prev_local_action)."""
    board_bytes = state.board.tobytes()
    return board_bytes, state.fill_num, state.prev_local_action

def augment_dataset(data: list[tuple[State, float]]) -> list[tuple[State, float]]:
    """Expand every (state, utility) pair by board symmetries and drop repeats.

    Each augmented state keeps its source utility. When several augmented
    states share a ``state_to_key`` key, the first one encountered wins, and
    output order follows first appearance.
    """
    seen = {}
    expanded = (
        (variant, utility)
        for state, utility in data
        for variant in augment_entire_state(state)
    )
    for variant, utility in expanded:
        seen.setdefault(state_to_key(variant), (variant, utility))
    return list(seen.values())

def swap_state(state: State) -> State:
    """Exchange the roles of players 1 and 2 in ``state``.

    Board cells holding 1 become 2 and vice versa; any other value is kept.
    ``fill_num`` becomes ``3 - fill_num``, which swaps 1 and 2 (assuming
    fill_num is one of those). ``prev_local_action`` is copied unchanged.
    """
    board = state.board
    swapped = np.where(board == 2, 1, board)
    swapped = np.where(board == 1, 2, swapped)
    return State(
        board=swapped,
        fill_num=3 - state.fill_num,
        prev_local_action=state.prev_local_action,
    )

def augment_dataset_with_swap(data: list[tuple[State, float]]) -> list[tuple[State, float]]:
    """Return the player-swapped counterpart of every pair, deduplicated.

    Each state goes through ``swap_state`` and its utility is negated. Only the
    swapped pairs are returned (first occurrence per ``state_to_key`` wins);
    callers combine them with the originals themselves.

    NOTE(review): negating the utility assumes it is scored from a fixed
    player's point of view rather than the side to move; confirm against the
    data source.
    """
    unique = {}
    for state, utility in data:
        flipped = swap_state(state)
        unique.setdefault(state_to_key(flipped), (flipped, -utility))
    return list(unique.values())

def custom_encode_2(val: int) -> np.ndarray:
    """Encode 0, 1 or 2 as a 2-element float32 vector.

    0 -> [0, 0], 1 -> [0, 1], 2 -> [1, 0]. A new array is returned on every call.

    Raises:
        ValueError: for any other value.
    """
    for code, bits in ((0, (0, 0)), (1, (0, 1)), (2, (1, 0))):
        if val == code:
            return np.array(bits, dtype=np.float32)
    raise ValueError("Value must be 0, 1, or 2.")

def custom_encode_status(val: int) -> np.ndarray:
    """Encode a local-board status 0..3 as a 3-element float32 vector.

    0 -> [0, 0, 0]; 1, 2, 3 -> a single 1 at index 2, 1, 0 respectively.

    Raises:
        ValueError: for any other value.
    """
    table = ((0, (0, 0, 0)), (1, (0, 0, 1)), (2, (0, 1, 0)), (3, (1, 0, 0)))
    for code, bits in table:
        if val == code:
            return np.array(bits, dtype=np.float32)
    raise ValueError("Local board status must be 0, 1, 2, or 3.")

def custom_encode_coord(val: int) -> np.ndarray:
    """Encode one prev_local_action coordinate as 2 float32 values.

    Negative values (the -1 sentinel used when there is no previous action)
    become [0, 0]; anything else goes through ``custom_encode_2``.

    NOTE(review): ``custom_encode_2(0)`` is also [0, 0], so coordinate 0 and
    "no previous action" produce identical features. Separating them would
    change the 195-wide feature layout that Net and exported weights rely on.
    """
    return np.array([0, 0], dtype=np.float32) if val < 0 else custom_encode_2(val)

if __name__ == "__main__":
    original_data = load_data()
    print(f"Original data size: {len(original_data)}")

    # Split the raw states first so augmentation never copies a test state into train.
    train_data, test_data = train_test_split(original_data, test_size=0.15, random_state=42)
    print(f"Train data size: {len(train_data)} | Test data size: {len(test_data)}")

    def augment_split(split):
        """Symmetry- then swap-augment one split; keep the first pair per state key."""
        sym_aug = augment_dataset(split)
        swap_aug = augment_dataset_with_swap(sym_aug)
        unique = {}
        for state, utility in sym_aug + swap_aug:
            unique.setdefault(state_to_key(state), (state, utility))
        return unique

    unique_train = augment_split(train_data)
    final_train = list(unique_train.values())
    print(f"Final augmented train data size: {len(final_train)}")

    # Distinct raw states can still coincide after augmentation (identical
    # positions, or symmetric / player-swapped copies of one another). The
    # augmented test split could then contain states that are also in train.
    # Drop those so the test error is measured only on unseen states.
    unique_test = augment_split(test_data)
    final_test = [pair for key, pair in unique_test.items() if key not in unique_train]
    print(f"Dropped {len(unique_test) - len(final_test)} test states already present in train")
    print(f"Final augmented test data size: {len(final_test)}")

    def extract_features(state: State) -> np.ndarray:
        """Encode a state in the one-hot layout (195 values with a 9-entry local status)."""
        board_flat = state.board.reshape(-1)
        global_board_encoded = np.array([custom_encode_2(val) for val in board_flat])
        global_board_features = global_board_encoded.flatten()

        local_board_flat = state.local_board_status.reshape(-1)
        local_board_encoded = np.array([custom_encode_status(val) for val in local_board_flat])
        local_board_features = local_board_encoded.flatten()

        fill_num_feature = custom_encode_2(state.fill_num)

        if state.prev_local_action is None:
            prev_r, prev_c = -1, -1
        else:
            prev_r, prev_c = state.prev_local_action
        prev_r_enc = custom_encode_coord(prev_r)
        prev_c_enc = custom_encode_coord(prev_c)
        prev_action = np.concatenate([prev_r_enc, prev_c_enc])

        return np.concatenate([global_board_features, local_board_features, fill_num_feature, prev_action])

    X_train = [extract_features(state) for state, _ in final_train]
    y_train = [utility for _, utility in final_train]
    X_test = [extract_features(state) for state, _ in final_test]
    y_test = [utility for _, utility in final_test]

    data = {
        'X_train': X_train,
        'y_train': y_train,
        'X_test': X_test,
        'y_test': y_test
    }

    with open("features15.pkl", "wb") as f:
        pickle.dump(data, f)



In [None]:
import pickle

# Load the pre-split one-hot features; the pickled "test" split serves as the
# evaluation set (X_eval / y_eval). Only unpickle files you created yourself.
with open("features15.pkl", "rb") as f:
    data_loaded = pickle.load(f)

X_train, y_train = data_loaded['X_train'], data_loaded['y_train']
X_eval, y_eval = data_loaded['X_test'], data_loaded['y_test']

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
import pprint
from collections import OrderedDict

class Net(nn.Module):
    """One-hot value head: 195 inputs -> 512 -> 128 -> 64 -> 1.

    Each hidden stage is Linear -> BatchNorm1d -> ReLU -> Dropout, with the
    dropout rate tapering 0.4 -> 0.3 -> 0.2. Attribute names become the keys
    in the exported weights file, so they must stay stable.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(195, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(0.4)

        self.fc2 = nn.Linear(512, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(0.3)

        self.fc_mid = nn.Linear(128, 64)
        self.bn_mid = nn.BatchNorm1d(64)
        self.dropout_mid = nn.Dropout(0.2)

        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc_mid, self.bn_mid, self.dropout_mid),
        )
        for linear, norm, drop in stages:
            x = drop(F.relu(norm(linear(x))))
        return self.fc3(x)

# Train the 512-wide Net full-batch on the pre-split features and export its
# weights as an importable Python module.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Stack first: torch.tensor on a Python list of ndarrays is extremely slow.
X_train_tensor = torch.tensor(np.stack(X_train), dtype=torch.float32).to(device)
X_eval_tensor = torch.tensor(np.stack(X_eval), dtype=torch.float32).to(device)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1).to(device)
y_eval_tensor = torch.tensor(y_eval, dtype=torch.float32).unsqueeze(1).to(device)

net = Net().to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)
loss_fn = nn.MSELoss()
losses = []

num_epochs = 1200
for epoch in range(num_epochs):
    net.train()
    y_pred_train = net(X_train_tensor)
    loss = loss_fn(y_pred_train, y_train_tensor)
    losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Report every 50 epochs and always on the last one, exactly once
    # (separate checks printed the final epoch twice when it was a multiple of 50).
    if (epoch + 1) % 50 == 0 or epoch + 1 == num_epochs:
        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {loss.item():.4f}")

net.eval()
with torch.no_grad():
    y_pred_eval = net(X_eval_tensor)
eval_mse = loss_fn(y_pred_eval, y_eval_tensor).item()
print(f"Evaluation MSE: {eval_mse:.4f}")

trained_weights = net.state_dict()
weights_dict = OrderedDict()
for name, param in trained_weights.items():
    values = param.detach().cpu().numpy().tolist()
    weights_dict[name] = values

with open("weights.py", "w") as f:
    # The generated module uses OrderedDict and torch, so import them there;
    # without this header `import weights` fails with NameError.
    f.write("from collections import OrderedDict\n\nimport torch\n\n")
    f.write("weights = OrderedDict([\n")
    for key, value in weights_dict.items():
        f.write(f"    ('{key}', torch.tensor(\n")
        f.write(pprint.pformat(value, indent=8))
        f.write("\n    )),\n")
    f.write("])\n")
