<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/AlphaGo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

def create_random_go_data(num_samples=1000, board_size=19):
    """Generate pseudo-training data with random board states and move probabilities.

    Args:
        num_samples: Number of random positions to generate.
        board_size: Side length of the square board.

    Returns:
        A tuple ``(X, y_policy, y_value)`` of float32 tensors:

        * ``X``: shape ``(num_samples, 1, board_size, board_size)``, each
          entry drawn uniformly from {0, 1, 2} (single-channel board state).
        * ``y_policy``: shape ``(num_samples, board_size * board_size)``,
          non-negative rows that each sum to 1 (a distribution over moves).
        * ``y_value``: shape ``(num_samples, 1)``, game outcome in [-1, 1].
    """
    num_points = board_size * board_size

    # Board state (1 channel)
    X = np.random.randint(0, 3, (num_samples, 1, board_size, board_size)).astype(np.float32)

    # Move probabilities. Raw uniform samples are not a distribution, so each
    # row is normalized to sum to 1; otherwise a softmax policy output (whose
    # rows always sum to 1) can never match the target.
    raw_policy = np.random.rand(num_samples, num_points)
    y_policy = (raw_policy / raw_policy.sum(axis=1, keepdims=True)).astype(np.float32)

    # Game outcome (-1 to 1)
    y_value = np.random.uniform(-1, 1, (num_samples, 1)).astype(np.float32)

    # The arrays are freshly allocated and local, so sharing their memory with
    # the returned tensors is safe and avoids an extra copy.
    return torch.from_numpy(X), torch.from_numpy(y_policy), torch.from_numpy(y_value)

class AlphaGoNet(nn.Module):
    """Small convolutional network with a shared trunk and policy/value heads.

    Args:
        board_size: Side length of the square board. The heads' linear layers
            are sized for ``board_size * board_size`` points, so inputs must
            have that spatial size.
        in_channels: Number of input feature planes. Defaults to 1 (a single
            plane holding the board state), which matches the original
            behavior.
    """

    def __init__(self, board_size=19, in_channels=1):
        super().__init__()
        self.board_size = board_size
        self.in_channels = in_channels
        num_points = board_size * board_size

        # Common Feature Extraction
        # 3x3 kernels with padding=1 keep the spatial size unchanged, so the
        # heads still see a board_size x board_size feature map.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.ReLU(),
        )

        # Policy Head: 1x1 conv down to 2 planes, then a linear layer to one
        # score per board point, normalized with softmax across points.
        self.policy_head = nn.Sequential(
            nn.Conv2d(256, 2, kernel_size=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2 * num_points, num_points),
            nn.Softmax(dim=1)
        )

        # Value Head: 1x1 conv down to 1 plane, then an MLP squashed by tanh
        # to a single scalar in [-1, 1].
        self.value_head = nn.Sequential(
            nn.Conv2d(256, 1, kernel_size=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(num_points, 64), nn.ReLU(),
            nn.Linear(64, 1), nn.Tanh()
        )

    def forward(self, x):
        """Run the network on a batch of boards.

        Args:
            x: Tensor of shape ``(batch, in_channels, board_size, board_size)``.

        Returns:
            A tuple ``(policy, value)`` where ``policy`` has shape
            ``(batch, board_size * board_size)`` with rows summing to 1, and
            ``value`` has shape ``(batch, 1)`` with entries in [-1, 1].
        """
        features = self.conv_layers(x)
        policy = self.policy_head(features)
        value = self.value_head(features)
        return policy, value

# Training Setup
NUM_EPOCHS = 10  # Small training loop; this is a smoke test on random data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AlphaGoNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion_policy = nn.MSELoss()
criterion_value = nn.MSELoss()

# Generate Random Data
X_train, y_policy_train, y_value_train = create_random_go_data()
X_train, y_policy_train, y_value_train = X_train.to(device), y_policy_train.to(device), y_value_train.to(device)

# Training Loop
# Full-batch updates: every epoch is one optimizer step over the whole dataset.
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    policy_pred, value_pred = model(X_train)
    loss_policy = criterion_policy(policy_pred, y_policy_train)
    loss_value = criterion_value(value_pred, y_value_train)
    # Both heads share the trunk, so their losses are summed and optimized jointly.
    loss = loss_policy + loss_value
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Testing Inference
# Switch to eval mode and disable autograd: no gradients are needed here, so
# there is no reason to build (and then detach from) a computation graph.
model.eval()
with torch.no_grad():
    # Size the probe board from the model so it cannot drift from board_size.
    test_board = torch.randint(
        0, 3, (1, 1, model.board_size, model.board_size), dtype=torch.float32
    ).to(device)
    policy_out, value_out = model(test_board)
print("Sample Policy Output:", policy_out.cpu().numpy())
print("Sample Value Output:", value_out.cpu().numpy())


Epoch 1, Loss: 0.6759
Epoch 2, Loss: 0.6712
Epoch 3, Loss: 0.6646
Epoch 4, Loss: 0.6650
Epoch 5, Loss: 0.6649
Epoch 6, Loss: 0.6655
Epoch 7, Loss: 0.6654
Epoch 8, Loss: 0.6649
Epoch 9, Loss: 0.6643
Epoch 10, Loss: 0.6643
Sample Policy Output: [[1.74505869e-04 1.00048957e-04 4.68107779e-03 8.30252189e-03
  1.50735184e-04 7.09612359e-05 1.43133942e-02 2.00251307e-04
  2.17691253e-04 1.14968456e-02 1.87502810e-04 1.10543406e-04
  3.88346525e-04 1.89784609e-04 8.88134455e-05 1.19790726e-04
  1.23454526e-03 3.78453959e-04 2.49162840e-04 5.89653244e-03
  5.12389094e-03 3.72252194e-04 2.08187499e-04 9.02030058e-03
  8.07481993e-05 1.54934329e-04 5.37571823e-03 3.41055333e-04
  9.17945465e-04 4.64270590e-03 1.87805170e-04 1.14766613e-03
  1.68724859e-04 2.34748982e-03 1.08893623e-03 1.42826354e-02
  1.68903556e-04 2.53110949e-04 5.10392291e-03 1.09903878e-04
  1.98277342e-03 4.61724633e-03 6.34160533e-04 1.97918643e-03
  2.23617084e-04 4.49238578e-05 6.36371376e-04 1.60747059e-02
  8.58140629e