<a href="https://colab.research.google.com/github/OmrySh/DecisionTransformer/blob/main/DecisionTransformersConnect4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --upgrade open_spiel

Collecting open_spiel
  Downloading open_spiel-1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.4/5.4 MB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: open_spiel
Successfully installed open_spiel-1.3


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim


class DecisionTransformer(nn.Module):
    def __init__(self, d_model, nhead, num_encoder_layers, num_decoder_layers, input_dim, action_dim, dropout=0.5):
        super(DecisionTransformer, self).__init__()

        # Define the transformer model
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dropout=dropout)

        # Define the input embedding layer
        self.input_embedding = nn.Linear(input_dim+1, d_model)  # +1 for the reward

        # Define the output action layer
        self.output_action = nn.Linear(d_model, action_dim)

    def forward(self, src, tgt):
        print(f"Source shape (before embedding): {src.shape}")
        print(f"Target shape (before embedding): {tgt.shape}")

        src = self.input_embedding(src)
        tgt = self.input_embedding(tgt)

        print(f"Source shape (after embedding): {src.shape}")
        print(f"Target shape (after embedding): {tgt.shape}")

        transformer_output = self.transformer(src.unsqueeze(0), tgt.unsqueeze(0))
        action_logits = self.output_action(transformer_output.squeeze(0))
        return action_logits
# # Testing the class instantiation
# test_model = DecisionTransformer(D_MODEL, NHEAD, NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, INPUT_DIM, ACTION_DIM)
# test_model


In [3]:
import torch.optim as optim

class DecisionTransformerModel:
    def __init__(self, state_dim, action_dim, d_model, nhead, num_encoder_layers, num_decoder_layers):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Adjust the state dimension to account for the reward concatenated
        self.model = DecisionTransformer(d_model, nhead, num_encoder_layers, num_decoder_layers,
                                            state_dim, action_dim).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        self.criterion = nn.CrossEntropyLoss()

    def train(self, data, epochs, batch_size):
        self.model.train()

        for epoch in range(epochs):
            np.random.shuffle(data)
            batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
            # dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True)

            for batch in batches:
                states, prev_actions, actions, rewards = zip(*batch)

                # Reshape states to match the expected input shape for the model
                states = torch.stack(states).view(batch_size, -1).to(self.device)
                prev_actions = torch.stack([torch.tensor(pa) for pa in prev_actions]).to(self.device)  # Convert prev_actions to GPU tensor
                actions = torch.stack([torch.tensor(a) for a in actions]).to(self.device)
                rewards = torch.stack([torch.tensor(r) for r in rewards]).to(self.device)
                if actions.shape[0] < 32:
                    continue
                # print(f"Batch states shape: {states.shape}")
                # print(f"Batch actions shape: {actions.shape}")
                # print(f"Batch rewards shape: {rewards.shape}")

                # Concatenate the rewards to the states
                src = torch.cat((states, prev_actions.unsqueeze(1), rewards.unsqueeze(1)), dim=1)  # Include prev_actions in concatenation
                tgt = src.clone().detach()


                print(f"States shape: {states.shape}")
                print(f"Actions shape: {actions.shape}")
                print(f"Rewards shape: {rewards.shape}")
                print(f"Source shape: {src.shape}")
                print(f"Target shape: {tgt.shape}")

                # Forward pass
                outputs = self.model(src, tgt)

                # Compute loss
                loss = self.criterion(outputs, actions)

                # Backward pass and optimization
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

    def predict(self, state, prev_action, prev_reward):
        self.model.eval()

        with torch.no_grad():
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            prev_action = torch.LongTensor([prev_action]).unsqueeze(0).to(self.device)
            prev_reward = torch.FloatTensor([prev_reward]).unsqueeze(0).to(self.device)

            src = torch.cat((state, prev_action, prev_reward), dim=-1)
            tgt = src.clone().detach()

            print(f"State shape: {state.shape}")
            print(f"Prev_action shape: {prev_action.shape}")
            print(f"Prev_reward shape: {prev_reward.shape}")

            src = torch.cat((state, prev_action, prev_reward), dim=-1)
            tgt = src.clone().detach()

            print(f"Source shape (before model): {src.shape}")
            print(f"Target shape (before model): {tgt.shape}")

            output = self.model(src, tgt)

            predicted_action = torch.argmax(output, dim=-1).item()

        return predicted_action


In [5]:
import numpy as np
import pyspiel
from open_spiel.python.algorithms import tabular_qlearner
from open_spiel.python.rl_environment import Environment
import torch

def collect_data_fixed_length(num_games=1000, sequence_length=10):
    """Collects data of fixed sequence length from Q-learning players playing games."""
    game = pyspiel.load_game("connect_four")
    env = Environment(game)

    # Initialize two Q-learning agents
    players = [tabular_qlearner.QLearner(player_id=player_id, num_actions=env.action_spec()["num_actions"])
               for player_id in [0, 1]]

    dataset = []

    for _ in range(num_games):
        time_step = env.reset()
        game_data = []
        prev_action = 0  # Initialize previous action as 0 for the first state

        while not time_step.last():
            current_player = time_step.observations["current_player"]
            state_tensor = torch.tensor(time_step.observations["info_state"][current_player])
            action = players[current_player].step(time_step).action
            game_data.append((state_tensor, prev_action, action))
            prev_action = action  # Update the previous action
            time_step = env.step([action])

        # Determine the reward
        if time_step.rewards[0] == 1:  # Player 1 wins
            reward = 1
        elif time_step.rewards[0] == -1:  # Player 2 wins
            reward = -1
        else:  # Draw
            reward = 0

        # Ensure the game data is of fixed length
        last_valid_state = game_data[-1][0] if game_data else torch.zeros_like(state_tensor)
        last_valid_prev_action = game_data[-1][1] if game_data else 0
        while len(game_data) < sequence_length:
            game_data.append((last_valid_state, last_valid_prev_action, 0))

        # If the game is longer than sequence_length, truncate it
        game_data = game_data[:sequence_length]

        # Attach the reward to each (state, action) tuple and add to the main dataset
        game_data_with_rewards = [(s, pa, a, reward) for s, pa, a in game_data]
        dataset.extend(game_data_with_rewards)

    return dataset


def check_data_shapes(data):
    state_shapes = set()
    prev_action_shapes = set()
    action_shapes = set()
    reward_shapes = set()

    for state, prev_action, action, reward in data:
        state_shapes.add(state.shape)
        prev_action_shapes.add(np.array(prev_action).shape)
        action_shapes.add(np.array(action).shape)
        reward_shapes.add(np.array(reward).shape)

    print("Unique state shapes:", state_shapes)
    print("Unique previous action shapes:", prev_action_shapes)
    print("Unique action shapes:", action_shapes)
    print("Unique reward shapes:", reward_shapes)


data = collect_data_fixed_length()
check_data_shapes(data)

# 2. Train the Decision Transformer

# Hyperparameters
D_MODEL = 512
NHEAD = 8
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6
INPUT_DIM = 127  # Connect 4 board has 7x6 = 42 positions
ACTION_DIM = 7  # 7 possible columns to drop a piece
EPOCHS = 1
BATCH_SIZE = 32

data = collect_data_fixed_length()
model = DecisionTransformerModel(INPUT_DIM, ACTION_DIM, D_MODEL, NHEAD, NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS)
model.train(data, EPOCHS, BATCH_SIZE)

# 3. Test the trained model

def test_model(model, num_tests=100):
    game = pyspiel.load_game("connect_four")

    correct_predictions = 0

    for _ in range(num_tests):
        state = game.new_initial_state().observation_tensor()
        true_action = np.random.choice(7)  # Randomly choose among the 7 possible actions
        predicted_action = model.predict(state, true_action, 0)  # Assume previous reward as 0 for simplicity

        if true_action == predicted_action:
            correct_predictions += 1

    accuracy = correct_predictions / num_tests
    print(f"Accuracy on test data: {accuracy:.2f}")

test_model(model)


TypeError: ignored