split first then augment

In [1]:
import numpy as np
import pickle
from utils import State, Action, load_data  # Ensure your utils.py is in your directory
from collections import OrderedDict
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import time

In [1]:
import numpy as np
from dataclasses import dataclass
import pickle
from collections import OrderedDict
from utils import State, load_data
import random
from sklearn.model_selection import train_test_split  # Ensure scikit-learn is installed

# (Your existing functions like convert_board_to_string, get_local_board_status, etc. are assumed to be defined/imported)

def update_prev_local_action(prev: tuple, transform: str) -> tuple:
    r, c = prev
    if transform == "identity":
        return (r, c)
    elif transform == "horizontal":
        return (r, 2 - c)
    elif transform == "vertical":
        return (2 - r, c)
    elif transform == "rotate90":
        return (c, 2 - r)
    elif transform == "rotate180":
        return (2 - r, 2 - c)
    elif transform == "rotate270":
        return (2 - c, r)
    elif "_" in transform:
        base, extra = transform.split("_")
        if base == "rotate90":
            new = (c, 2 - r)
        elif base == "rotate180":
            new = (2 - r, 2 - c)
        elif base == "rotate270":
            new = (2 - c, r)
        else:
            new = (r, c)
        if extra == "horizontal":
            new = (new[0], 2 - new[1])
        elif extra == "vertical":
            new = (2 - new[0], new[1])
        return new
    else:
        return (r, c)

def transform_state(state: State, transform: str) -> State:
    global_board = state.board.transpose(0, 2, 1, 3).reshape(9, 9)

    if transform == "identity":
        transformed_global_board = global_board.copy()
    elif transform == "horizontal":
        transformed_global_board = np.fliplr(global_board)
    elif transform == "vertical":
        transformed_global_board = np.flipud(global_board)
    elif transform.startswith("rotate"):
        if "_" in transform:
            base, extra = transform.split("_")
            if base == "rotate90":
                k = 1
            elif base == "rotate180":
                k = 2
            elif base == "rotate270":
                k = 3
            else:
                raise ValueError(f"Unknown base rotation: {base}")
            transformed_global_board = np.rot90(global_board, k=k)
            if extra == "horizontal":
                transformed_global_board = np.fliplr(transformed_global_board)
            elif extra == "vertical":
                transformed_global_board = np.flipud(transformed_global_board)
            else:
                raise ValueError(f"Unknown extra transform: {extra}")
        else:
            if transform == "rotate90":
                k = 1
            elif transform == "rotate180":
                k = 2
            elif transform == "rotate270":
                k = 3
            else:
                raise ValueError(f"Unknown rotation transform: {transform}")
            transformed_global_board = np.rot90(global_board, k=k)
    else:
        raise ValueError(f"Unknown transform: {transform}")

    new_board = transformed_global_board.reshape(3, 3, 3, 3).transpose(0, 2, 1, 3)
    new_prev = None
    if state.prev_local_action is not None:
        new_prev = update_prev_local_action(state.prev_local_action, transform)

    return State(board=new_board, fill_num=state.fill_num, prev_local_action=new_prev)

def augment_entire_state(state: State) -> list[State]:
    transforms = [
        "identity", "horizontal", "vertical",
        "rotate90", "rotate90_horizontal", "rotate90_vertical",
        "rotate180", "rotate180_horizontal", "rotate180_vertical",
        "rotate270", "rotate270_horizontal", "rotate270_vertical"
    ]
    return [transform_state(state, t) for t in transforms]

def state_to_key(state: State) -> tuple:
    return (state.board.tobytes(), state.fill_num, state.prev_local_action)

def augment_dataset(data: list[tuple[State, float]]) -> list[tuple[State, float]]:
    augmented_data = []
    for state, utility in data:
        for aug_state in augment_entire_state(state):
            augmented_data.append((aug_state, utility))
    unique_data = {}
    for state, utility in augmented_data:
        key = state_to_key(state)
        if key not in unique_data:
            unique_data[key] = (state, utility)
    return list(unique_data.values())

def swap_state(state: State) -> State:
    new_board = np.where(state.board == 1, 2, np.where(state.board == 2, 1, state.board))
    new_fill_num = 3 - state.fill_num
    return State(board=new_board, fill_num=new_fill_num, prev_local_action=state.prev_local_action)

def augment_dataset_with_swap(data: list[tuple[State, float]]) -> list[tuple[State, float]]:
    swapped_data = []
    for state, utility in data:
        swapped_state = swap_state(state)
        swapped_utility = -utility
        swapped_data.append((swapped_state, swapped_utility))
    
    unique_data = {}
    for state, utility in swapped_data:
        key = state_to_key(state)
        if key not in unique_data:
            unique_data[key] = (state, utility)
    return list(unique_data.values())

# --- Custom Encoding Functions ---

def custom_encode_2(val: int) -> np.ndarray:
    if val == 0:
        return np.array([0, 0], dtype=np.float32)
    elif val == 1:
        return np.array([0, 1], dtype=np.float32)
    elif val == 2:
        return np.array([1, 0], dtype=np.float32)
    else:
        raise ValueError("Value must be 0, 1, or 2.")

def custom_encode_status(val: int) -> np.ndarray:
    if val == 0:
        return np.array([0, 0, 0], dtype=np.float32)
    elif val == 1:
        return np.array([0, 0, 1], dtype=np.float32)
    elif val == 2:
        return np.array([0, 1, 0], dtype=np.float32)
    elif val == 3:
        return np.array([1, 0, 0], dtype=np.float32)
    else:
        raise ValueError("Local board status must be 0, 1, 2, or 3.")

def custom_encode_coord(val: int) -> np.ndarray:
    if val < 0:
        return np.array([0, 0], dtype=np.float32)
    return custom_encode_2(val)

# --- Main processing: splitting, augmenting, feature extraction ---

if __name__ == "__main__":
    # Load the original dataset (each element is a (State, utility) pair).
    original_data = load_data()
    print(f"Original data size: {len(original_data)}")

    # --- Split the dataset into train and test sets (80/20 split) ---
    train_data, test_data = train_test_split(original_data, test_size=0.15, random_state=42)
    print(f"Train data size: {len(train_data)} | Test data size: {len(test_data)}")

    # --- Augment the training set ---
    sym_aug_train = augment_dataset(train_data)
    swap_aug_train = augment_dataset_with_swap(sym_aug_train)
    combined_train = sym_aug_train + swap_aug_train
    unique_train = {}
    for state, utility in combined_train:
        key = state_to_key(state)
        if key not in unique_train:
            unique_train[key] = (state, utility)
    final_train = list(unique_train.values())
    print(f"Final augmented train data size: {len(final_train)}")

    # --- Augment the test set ---
    sym_aug_test = augment_dataset(test_data)
    swap_aug_test = augment_dataset_with_swap(sym_aug_test)
    combined_test = sym_aug_test + swap_aug_test
    unique_test = {}
    for state, utility in combined_test:
        key = state_to_key(state)
        if key not in unique_test:
            unique_test[key] = (state, utility)
    final_test = list(unique_test.values())
    print(f"Final augmented test data size: {len(final_test)}")

    # --- Feature Extraction ---
    # Define a helper function for feature extraction.
    def extract_features(state: State) -> np.ndarray:
        # Global board: shape (3,3,3,3) -> flatten to 81 elements; encode each cell into 2 features.
        board_flat = state.board.reshape(-1)  # 81 elements
        global_board_encoded = np.array([custom_encode_2(val) for val in board_flat])
        global_board_features = global_board_encoded.flatten()  # 81*2 = 162 features

        # Local board status: assume state.local_board_status exists with shape (3,3) -> 9 elements; encode each into 3 features.
        local_board_flat = state.local_board_status.reshape(-1)
        local_board_encoded = np.array([custom_encode_status(val) for val in local_board_flat])
        local_board_features = local_board_encoded.flatten()  # 9*3 = 27 features

        # Fill number: encode into 2 features.
        fill_num_feature = custom_encode_2(state.fill_num)  # 2 features

        # Previous local action: encode each coordinate into 2 features.
        if state.prev_local_action is None:
            prev_r, prev_c = -1, -1
        else:
            prev_r, prev_c = state.prev_local_action
        prev_r_enc = custom_encode_coord(prev_r)  # 2 features
        prev_c_enc = custom_encode_coord(prev_c)  # 2 features
        prev_action = np.concatenate([prev_r_enc, prev_c_enc])  # 4 features

        # Total features: 162 + 27 + 2 + 4 = 195
        features = np.concatenate([global_board_features, local_board_features, fill_num_feature, prev_action])
        return features

    # Process training set features.
    X_train = []
    y_train = []
    for state, utility in final_train:
        features = extract_features(state)
        X_train.append(features)
        y_train.append(utility)

    # Process test set features.
    X_test = []
    y_test = []
    for state, utility in final_test:
        features = extract_features(state)
        X_test.append(features)
        y_test.append(utility)

    # --- Save the final augmented datasets (features and labels) ---
    data = {
        'X_train': X_train,
        'y_train': y_train,
        'X_test': X_test,
        'y_test': y_test
    }

    # Save to a single pickle file.
    with open("features15.pkl", "wb") as f:
        pickle.dump(data, f)

    print("Augmented and feature-extracted datasets have been saved.")


Original data size: 80000
Train data size: 68000 | Test data size: 12000
Final augmented train data size: 1085528
Final augmented test data size: 191896
Augmented and feature-extracted datasets have been saved.


In [1]:
import pickle

with open("features15.pkl", "rb") as f:
    data_loaded = pickle.load(f)

X_train = data_loaded['X_train']
y_train = data_loaded['y_train']
X_eval = data_loaded['X_test']
y_eval = data_loaded['y_test']

1252

In [4]:
# one hot encoding

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
import pprint
from collections import OrderedDict

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(195, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(0.45)

        self.fc2 = nn.Linear(512, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(0.20)

        self.out = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout1(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout2(x)
        x = self.out(x)
        return x

    

# Set up GPU device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
X_eval_tensor = torch.tensor(X_eval, dtype=torch.float32).to(device)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1).to(device)  # Add extra dimension
y_eval_tensor = torch.tensor(y_eval, dtype=torch.float32).unsqueeze(1).to(device)

# --- Training ---
net = Net().to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)
loss_fn = nn.MSELoss()
losses = []

num_epochs = 1200
for epoch in range(num_epochs):
    net.train()
    y_pred_train = net(X_train_tensor)  # Use the tensor version.
    loss = loss_fn(y_pred_train, y_train_tensor)
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {loss.item():.4f}")
    if epoch + 1 == num_epochs:
        print(f"Epoch {num_epochs}, Training Loss: {loss.item():.4f}")

# --- Evaluation ---
net.eval()
with torch.no_grad():
    y_pred_eval = net(X_eval_tensor)
eval_mse = loss_fn(y_pred_eval, y_eval_tensor).item()
print(f"Evaluation MSE: {eval_mse:.4f}")

trained_weights = net.state_dict()
weights_dict = OrderedDict()
for name, param in trained_weights.items():
    values = param.detach().cpu().numpy().tolist()
    weights_dict[name] = values

with open("weights.py", "w") as f:
    f.write("weights = OrderedDict([\n")
    for key, value in weights_dict.items():
        f.write(f"    ('{key}', torch.tensor(\n")
        f.write(pprint.pformat(value, indent=8))
        f.write("\n    )),\n")  # close parentheses here
    f.write("])\n")



Using device: cuda
Epoch 50/1200, Training Loss: 0.2821
Epoch 100/1200, Training Loss: 0.2473
Epoch 150/1200, Training Loss: 0.2179
Epoch 200/1200, Training Loss: 0.1957
Epoch 250/1200, Training Loss: 0.1823
Epoch 300/1200, Training Loss: 0.1743
Epoch 350/1200, Training Loss: 0.1681
Epoch 400/1200, Training Loss: 0.1636
Epoch 450/1200, Training Loss: 0.1594
Epoch 500/1200, Training Loss: 0.1572
Epoch 550/1200, Training Loss: 0.1529
Epoch 600/1200, Training Loss: 0.1504
Epoch 650/1200, Training Loss: 0.1480
Epoch 700/1200, Training Loss: 0.1452
Epoch 750/1200, Training Loss: 0.1428
Epoch 800/1200, Training Loss: 0.1408
Epoch 850/1200, Training Loss: 0.1391
Epoch 900/1200, Training Loss: 0.1372
Epoch 950/1200, Training Loss: 0.1358
Epoch 1000/1200, Training Loss: 0.1341
Epoch 1050/1200, Training Loss: 0.1335
Epoch 1100/1200, Training Loss: 0.1320
Epoch 1150/1200, Training Loss: 0.1312
Epoch 1200/1200, Training Loss: 0.1308
Epoch 1200, Training Loss: 0.1308
Evaluation MSE: 0.1313
