# **REINFORCEMENT LEARNING BASED SOLUTION FOR SOLVING A SUDOKU**

**State Representation**

Sudoku boards are encoded as 9×9 grids.

Each cell is one-hot encoded across 10 channels (channel 0 marks an empty cell, channels 1–9 mark the digits 1–9), so the network can distinguish filled cells from empty ones.

This transforms the puzzle into a structured tensor input suitable for convolutional or dense neural layers.

**Action Space**

Defined as 729 discrete actions = 81 cells × 9 possible digits.

Each action represents placing a number (1–9) in a specific cell.

This extremely large action space makes exploration and convergence challenging.

**Reward Structure**

Positive rewards for placements that respect Sudoku rules (no duplicate digit in the row, column, or box).

Penalties for illegal moves, repeated actions, or stagnation.

Additional shaping:

Rewards for completing rows, columns, or boxes.

Stagnation counters prevent looping behavior.

The goal is to encourage incremental correctness instead of only rewarding complete solutions.

**Methodology**

Environment: A custom SudokuEnv built on Gym, generating puzzles from the Kaggle dataset with adjustable clue fraction.

Agent:

*   Dueling Double Deep Q-Network (DQN) with epsilon-greedy exploration and a soft-updated target network.
*   Prioritized Replay Buffer to sample more useful experiences.
*   Periodic updates with batches from memory.

Training regime:

*   Curriculum-style — start with easier puzzles (clue_fraction ~0.9) and gradually make them harder.
*   Training loop executes thousands of attempts with capped max steps.
*   Uses real-time rendering to observe puzzle filling during training.


**Successes**
*   The agent learns local constraints (e.g., avoiding duplicate numbers in a row/column).
*   The reward shaping provides enough signal to progress beyond random guessing.
*   The curriculum approach ensures stability on simpler puzzles before tackling harder ones.
*   Demonstrates proof-of-concept: RL can capture logical structures in Sudoku.









**Failures / Limitations**



*   Exploration problem: With 729 possible actions per step, the agent wastes many steps exploring invalid moves.
*   Sparse ultimate reward: Full puzzle completion is rare; most training focuses only on partial correctness.




In [None]:
#import libraries

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gymnasium import spaces
import kagglehub
from collections import deque
import random
from IPython.display import clear_output
import torch.nn.functional as F
import gym
from gym import spaces



# Device setup: use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


In [None]:
# Download and load dataset
# Log in to Kaggle through kagglehub before the dataset download that follows.
kagglehub.login()

In [None]:


# Fetch the "bryanpark/sudoku" dataset through kagglehub; `path` holds the call's return value.
path = kagglehub.dataset_download("bryanpark/sudoku")
# IPython shell escape: move the cached dataset version directory into the current working directory.
# NOTE(review): the loading cell reads '/kaggle/input/sudoku/sudoku.csv', not this moved copy — confirm which location applies.
!mv /root/.cache/kagglehub/datasets/bryanpark/sudoku/versions/3 .


In [None]:

# Load puzzles and solutions into tensors
#
# Every data row is split on "," into a puzzle string and a solution string;
# each character is parsed as one digit. The later .view(-1, 9, 9) requires
# 81 digits per string.

sudoku_puzzles, sudoku_solutions = [], []
with open('/kaggle/input/sudoku/sudoku.csv', 'r') as f:
    next(f, None)  # skip the header row
    # Stream the file line by line instead of materialising it with readlines().
    for line in f:
        line = line.strip()
        if not line:
            # A blank (e.g. trailing) line would otherwise break the two-value unpack below.
            continue
        puzzle, solution = line.split(",")
        sudoku_puzzles.append([int(ch) for ch in puzzle])
        sudoku_solutions.append([int(ch) for ch in solution])

# One (N, 9, 9) float tensor per collection, placed on the training device.
sudoku_puzzles = torch.tensor(sudoku_puzzles, dtype=torch.float32).view(-1, 9, 9).to(device)
sudoku_solutions = torch.tensor(sudoku_solutions, dtype=torch.float32).view(-1, 9, 9).to(device)


In [None]:
# Display the first loaded puzzle as a quick sanity check of the load step.
sudoku_puzzles[0]

In [None]:
# Draw a random solved grid and blank out roughly (1 - clue_fraction) of its cells.
def get_random_puzzle(clue_fraction=0.8):
    """Return a ``(puzzle, solution)`` pair built from a randomly chosen solved grid.

    Each cell is kept as a clue with probability ``clue_fraction``; every other
    cell is set to 0 in the puzzle. The solution is a copy of the stored
    solved grid, so callers may modify either tensor freely.
    """
    pick = np.random.randint(len(sudoku_puzzles))
    solved = sudoku_solutions[pick].clone()
    # Cells whose uniform draw exceeds clue_fraction are hidden from the agent.
    hidden = torch.rand_like(solved) > clue_fraction
    blanked = solved.masked_fill(hidden, 0)
    return blanked, solved


In [None]:
#Sudoku Gym environment
class SudokuEnv(gym.Env):
    """Sudoku as a Gym environment: the agent writes one digit into one cell per step.

    Actions are integers in ``[0, 729)``: ``action // 9`` is the row-major cell
    index and ``action % 9 + 1`` is the digit. Writes to clue cells are refused
    and penalised; writes that duplicate a digit in the row, column or 3x3 box
    are penalised and rolled back.
    """

    def __init__(self, clue_fraction=0.8):
        """Create the environment; ``clue_fraction`` is forwarded to ``get_random_puzzle`` on reset."""
        super().__init__()
        self.action_space = spaces.Discrete(81 * 9) #81 cells, 9 numbers
        self.observation_space = spaces.Box(low=0, high=9, shape=(9, 9), dtype=np.float32)
        self.grid = None
        self.solution = None
        self.original_grid = None
        self.current_step = 0
        # Episode length cap; reported through the `truncated` flag of step().
        self.max_steps = 5000
        self.clue_fraction = clue_fraction
        # Units / cells that have already paid out their one-time bonus this episode.
        self.rewarded_rows = set()
        self.rewarded_cols = set()
        self.rewarded_boxes = set()
        self.rewarded_cells = set()
        self.recent_moves = deque(maxlen=10)
        self.stagnation_counter = 0
        # True while the agent is stagnating; the training loop passes it to the agent's act().
        self.temp_epsilon_boost = False

    def reset(self, *, seed=None, options=None):
        """Start a new episode on a fresh random puzzle and return ``(grid_copy, info)``."""
        super().reset(seed=seed)
        self.grid, self.solution = get_random_puzzle(clue_fraction=self.clue_fraction)
        self.original_grid = self.grid.clone()
        self.current_step = 0
        self.rewarded_rows.clear()
        self.rewarded_cols.clear()
        self.rewarded_boxes.clear()
        self.rewarded_cells.clear()
        self.recent_moves.clear()
        self.stagnation_counter = 0
        self.temp_epsilon_boost = False

        # Units that the clues alone already complete are not the agent's
        # achievement: mark them as rewarded so the first step does not
        # collect their completion bonus for free.
        rows_done, cols_done, boxes_done = self.check_subgoal_completion()
        self.rewarded_rows.update(i for i, ok in enumerate(rows_done) if ok)
        self.rewarded_cols.update(i for i, ok in enumerate(cols_done) if ok)
        self.rewarded_boxes.update(i for i, ok in enumerate(boxes_done) if ok)
        return self.grid.clone(), {}

    @staticmethod
    def _unit_complete(unit):
        """Return True when ``unit`` has no empty (0) cell and no repeated digit."""
        cells = unit.flatten()
        if bool((cells == 0).any()):
            return False
        return torch.unique(cells).numel() == cells.numel()

    #checks if rows, columns, boxes are completed
    def check_subgoal_completion(self):
        """Return three 9-element bool lists: completed rows, columns and boxes.

        A unit counts as completed only when it is fully filled and holds no
        duplicate. Boxes are listed row-major (top-left box first).
        """
        row_ok = [self._unit_complete(self.grid[r]) for r in range(9)]
        col_ok = [self._unit_complete(self.grid[:, c]) for c in range(9)]
        box_ok = [
            self._unit_complete(self.grid[r:r+3, c:c+3])
            for r in range(0, 9, 3)
            for c in range(0, 9, 3)
        ]
        return row_ok, col_ok, box_ok

    #Returns the number of conflicts in a grid for a given num in row and col
    @staticmethod
    def count_conflicts_static(grid, row, col, num):
        """Count other cells in the row, column and box of ``(row, col)`` that hold ``num``.

        The cell itself is excluded from each of the three counts, so the
        result is the same whether or not ``num`` is already written there.
        A cell sharing both the box and the row (or column) is counted twice.
        """
        own = 1 if grid[row, col] == num else 0
        r0, c0 = (row//3)*3, (col//3)*3
        conflicts = torch.sum(grid[row, :] == num).item() - own
        conflicts += torch.sum(grid[:, col] == num).item() - own
        conflicts += torch.sum(grid[r0:r0+3, c0:c0+3] == num).item() - own
        return int(conflicts)

    def count_conflicts(self, row, col, num):
        """Instance shortcut for ``count_conflicts_static`` on the current grid."""
        return SudokuEnv.count_conflicts_static(self.grid, row, col, num)

    #renders the puzzle: correct - green, wrong - red (the agent of course does not know if it's correct or wrong)
    def render(self):
        """Print the board; agent-written digits are green if they match the solution, red otherwise."""
        print("\n" + "-" * 25)
        for i in range(9):
            row_str = ""
            for j in range(9):
                val = int(self.grid[i, j])
                if self.original_grid[i, j] != 0:
                    # Original clues = normal
                    cell = f"{val}"
                else:
                    if val == 0:
                        cell = "."
                    elif self.solution is not None and val == int(self.solution[i, j]):
                        cell = f"\033[92m{val}\033[0m"
                    else:
                        cell = f"\033[91m{val}\033[0m"

                sep = " | " if (j + 1) % 3 == 0 and j != 8 else " "
                row_str += cell + sep
            print(row_str)
            if (i + 1) % 3 == 0 and i != 8:
                print("------+-------+------")
        print("-" * 25 + "\n")

    #every step of an episode has the following reward structure
    def step(self, action):
        """Apply ``action`` and return ``(grid_copy, reward, terminated, truncated, info)``.

        Reward terms, in the order they are applied:
          * -1.0 for targeting a clue cell (the grid is left untouched);
          * -0.2 per conflict for a clashing digit, which is then rolled back;
          * +0.1 for a conflict-free write, plus a one-time per-cell bonus of
            0.5 + filled_ratio when the cell was empty;
          * +0.5 per unit newly completed by this step, and +10 the first time
            each row / column / box is completed in the episode;
          * +40 when no empty cell remains, +200 more (and termination) when
            every row, column and box is complete;
          * -2.0 per step once the empty-cell count has not changed for more
            than 50 consecutive steps, -0.1 for repeating one of the last 10
            moves, and a constant -0.01 time penalty.

        ``truncated`` becomes True once ``max_steps`` steps have been taken
        without solving the puzzle.
        """
        pos, num = divmod(action, 9)
        num += 1
        row, col = divmod(pos, 9)

        reward = 0.0
        done = False

        empty_cells_before = torch.sum(self.grid == 0).item()
        row_ok_before, col_ok_before, box_ok_before = self.check_subgoal_completion()
        score_before = sum(row_ok_before) + sum(col_ok_before) + sum(box_ok_before)
        prev_val = self.grid[row, col].item()

        if self.original_grid[row, col] != 0:
            reward -= 1.0   # stronger penalty for changing clues
        else:
            self.grid[row, col] = num
            conflicts = self.count_conflicts(row, col, num)

            #penalty for conflicts
            #wrong numbers are removed from the board and penalized instead, because they
            #end up polluting the board due to sparse rewards vs penalties
            if conflicts > 0:
                reward -= 0.2 * conflicts
                self.grid[row, col] = prev_val
            else:
                #we reward for filling in a cell for the first time to encourage exploration
                if prev_val == 0 and (row, col) not in self.rewarded_cells:
                    filled_ratio = 1 - (empty_cells_before / 81)
                    reward += 0.5 + 1.0 * filled_ratio
                    self.rewarded_cells.add((row, col))
                reward += 0.1

        empty_cells_after = torch.sum(self.grid == 0).item()
        row_ok_after, col_ok_after, box_ok_after = self.check_subgoal_completion()
        score_after = sum(row_ok_after) + sum(col_ok_after) + sum(box_ok_after)

        if score_after > score_before:
            reward += 0.5 * (score_after - score_before)

        #rewards first time for subgoals
        for unit_ok, rewarded in (
            (row_ok_after, self.rewarded_rows),
            (col_ok_after, self.rewarded_cols),
            (box_ok_after, self.rewarded_boxes),
        ):
            for idx, ok in enumerate(unit_ok):
                if ok and idx not in rewarded:
                    reward += 10.0
                    rewarded.add(idx)

        #rewards if the cells are all filled or complete
        if empty_cells_after == 0:
            reward += 40.0
            if all(row_ok_after) and all(col_ok_after) and all(box_ok_after):
                reward += 200.0
                done = True

        if empty_cells_after == empty_cells_before:
            self.stagnation_counter += 1
        else:
            self.stagnation_counter = 0

        #if the state stays stagnant penalize, and ask the agent for extra exploration
        stagnating = self.stagnation_counter > 50
        if stagnating:
            reward -= 2.0
        self.temp_epsilon_boost = stagnating

        #stop doing the same move over and over again
        if (row, col, num) in self.recent_moves:
            reward -= 0.1
        self.recent_moves.append((row, col, num))

        #to encourage faster completion
        reward -= 0.01

        self.current_step += 1
        # Time-limit cut-off, reported separately from genuine termination.
        truncated = (not done) and self.current_step >= self.max_steps
        return self.grid.clone(), reward, done, truncated, {}






In [None]:
#Dueling Q-Network
class DuelingQNetwork(nn.Module):
    """Convolutional dueling Q-network over a 10-channel 9x9 board encoding.

    Three 3x3 convolutions (padding 1, so the 9x9 layout is kept) feed a
    512-unit dense layer, which splits into a scalar state-value head and an
    ``output_size``-wide advantage head. ``input_size`` is accepted for
    interface compatibility but is not used.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        widths = (10, 32, 64, 128)
        stack = []
        for c_in, c_out in zip(widths, widths[1:]):
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            stack.append(nn.ReLU())
        stack.append(nn.Flatten())
        self.conv_layers = nn.Sequential(*stack)

        self.fc = nn.Linear(128 * 9 * 9, 512)

        # State value V(s): one scalar per sample.
        self.value_stream = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )
        # Advantages A(s, a): one entry per action.
        self.adv_stream = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, output_size),
        )

    def forward(self, x):
        """Return Q-values of shape ``(batch, output_size)``; input is reshaped to ``(-1, 10, 9, 9)``."""
        features = self.conv_layers(x.view(-1, 10, 9, 9))
        hidden = F.relu(self.fc(features))
        state_value = self.value_stream(hidden)
        advantages = self.adv_stream(hidden)
        mean_advantage = advantages.mean(dim=1, keepdim=True)
        # Q = V + A - mean(A)
        return state_value + advantages - mean_advantage


In [None]:
#prioritized replay buffer to retain those instances with max TDerror
class PrioritizedReplayBuffer:
    """Fixed-capacity replay memory that samples transitions in proportion to ``priority ** alpha``.

    Transitions and priorities live in two parallel deques with the same
    ``maxlen``, so the oldest entry of each is dropped together once the
    buffer is full.
    """

    def __init__(self, max_size=10000, alpha=0.6):
        """``max_size`` bounds the buffer; ``alpha`` is the exponent applied to priorities when sampling."""
        self.buffer = deque(maxlen=max_size)
        self.priorities = deque(maxlen=max_size)
        self.alpha = alpha

    def add(self, exp, td_error=1.0):
        """Append transition ``exp`` with priority ``|td_error| + 1e-5``."""
        self.buffer.append(exp)
        # float() keeps the stored priority a plain number even if a 0-d tensor is passed.
        self.priorities.append(abs(float(td_error)) + 1e-5)

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions (with replacement) weighted by priority.

        Returns:
            ``(batch, indices)`` where ``batch`` is ``zip(*samples)`` — one
            tuple per transition field — and ``indices`` are the buffer
            positions drawn, for a later ``update_priorities`` call.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self.buffer:
            raise ValueError("cannot sample from an empty replay buffer")
        probs = np.asarray(self.priorities, dtype=np.float64) ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        samples = [self.buffer[i] for i in indices]
        return zip(*samples), indices

    def update_priorities(self, indices, td_errors):
        """Set the priority at each index to ``|td_error| + 1e-5``.

        ``td_errors`` may be a tensor, a NumPy array or any iterable of
        numbers. Indices must refer to the buffer as it was when sampled,
        i.e. call this before adding new transitions.
        """
        if hasattr(td_errors, "detach"):
            # Tensor input: bring all values to the host in one transfer
            # rather than one .item() call per element.
            td_errors = td_errors.detach().reshape(-1).tolist()
        for i, td in zip(indices, td_errors):
            self.priorities[i] = abs(float(td)) + 1e-5

    def size(self):
        """Return the number of stored transitions."""
        return len(self.buffer)

In [None]:
#The game playing agent
class DQNAgent:
    """Double-DQN agent using a dueling network, epsilon-greedy exploration and a soft-updated target network."""

    def __init__(self, state_size, action_size, tau=0.01, device="cuda"):
        """Build the online and target networks on ``device`` and initialise exploration state.

        Args:
            state_size: forwarded to ``DuelingQNetwork``.
            action_size: number of discrete actions (width of the Q output).
            tau: interpolation factor for the soft target update.
            device: torch device (or device string) the networks and batches live on.
        """
        self.state_size = state_size
        self.action_size = action_size
        #Double DQN
        self.q_network = DuelingQNetwork(state_size, action_size).to(device)
        self.target_network = DuelingQNetwork(state_size, action_size).to(device)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=0.001)
        self.loss_fn = nn.SmoothL1Loss()
        self.gamma = 0.99
        # Epsilon decays linearly from 1.0 to epsilon_min over epsilon_decay_steps training acts.
        self.epsilon = 1.0
        self.epsilon_min = 0.05
        self.epsilon_decay_steps = 100_000
        self.epsilon_decay_rate = (self.epsilon - self.epsilon_min) / self.epsilon_decay_steps
        self.steps_done = 0
        self.tau = tau
        self.device = device
        # Start with the target network as an exact copy of the online network.
        self.update_target_network(hard=True)

    def update_target_network(self, hard=False):
        """Sync the target network: full copy when ``hard``, else ``tau``-weighted blend."""
        if hard:
            self.target_network.load_state_dict(self.q_network.state_dict())
        else:
            for target_param, local_param in zip(self.target_network.parameters(), self.q_network.parameters()):
                target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data)

    def act(self, state, grid=None, original_grid=None, inference=False, temp_epsilon_boost=False):
        """Choose an action index for ``state``.

        During training (``inference=False``) epsilon is decayed by one step
        and a uniformly random action is returned with probability epsilon
        (raised to at least 0.3 while ``temp_epsilon_boost`` is set). Otherwise
        the greedy action is returned. In inference mode, when both ``grid``
        and ``original_grid`` are given, actions on clue cells and actions
        that would create a conflict are masked out before the argmax.
        """
        self.steps_done += 1
        if not inference:
            self.epsilon = max(self.epsilon_min, self.epsilon - self.epsilon_decay_rate)

        effective_epsilon = self.epsilon
        if temp_epsilon_boost:
            effective_epsilon = max(0.3, effective_epsilon)

        if not inference and np.random.rand() <= effective_epsilon:
            return np.random.randint(self.action_size)

        with torch.no_grad():
            q_values = self.q_network(state)

            # Mask invalid actions during inference
            if inference and grid is not None and original_grid is not None:
                for pos in range(81):
                    row, col = divmod(pos, 9)
                    if original_grid[row, col] != 0:  # fixed clue
                        q_values[0, pos*9:(pos+1)*9] = -1e9
                    else:
                        for num in range(1, 10):
                            if SudokuEnv.count_conflicts_static(grid, row, col, num) > 0:
                                q_values[0, pos*9+num-1] = -1e9
            return torch.argmax(q_values).item()

    def train(self, batch, indices=None, replay_buffer=None):
        """Run one optimisation step on ``batch`` and soft-update the target network.

        Args:
            batch: iterable of five sequences ``(states, actions, rewards,
                next_states, dones)``; states are stacked with ``torch.stack``.
            indices: buffer positions of the sampled transitions, or None.
            replay_buffer: when given together with ``indices``, its
                priorities are refreshed with the new TD errors.
        """
        states, actions, rewards, next_states, dones = batch
        states = torch.stack(states).to(self.device)
        next_states = torch.stack(next_states).to(self.device)
        actions = torch.tensor(actions, dtype=torch.long, device=self.device)
        rewards = torch.tensor(rewards, dtype=torch.float32, device=self.device)
        dones = torch.tensor(dones, dtype=torch.float32, device=self.device)

        q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Double DQN: the online network picks the next action, the target
        # network evaluates it. The target is a constant for the loss, so it
        # is built without autograd; otherwise backward() would also run
        # through the target network and pile up gradients that nothing clears.
        with torch.no_grad():
            next_actions = self.q_network(next_states).argmax(1)
            next_q_values = self.target_network(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)
            target_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        loss = self.loss_fn(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.update_target_network(hard=False)

        # Explicit None checks: do not depend on the buffer object's truthiness.
        if replay_buffer is not None and indices is not None:
            td_errors = (q_values - target_q_values).detach()
            replay_buffer.update_priorities(indices, td_errors)

In [None]:

def one_hot_encode(grid, device="cuda"):
    """One-hot encode a 9x9 Sudoku grid into a ``(1, 10, 9, 9)`` float tensor on ``device``.

    Channel ``d`` holds 1.0 at every cell whose value is ``d`` (0 = empty
    cell), and 0.0 elsewhere. Values are truncated to integers first.

    Args:
        grid: 9x9 tensor (or array) of cell values in 0..9.
        device: device on which the encoded tensor is created.
    """
    cells = torch.as_tensor(grid, device=device).long().reshape(1, 9, 9)
    one_hot = torch.zeros(10, 9, 9, device=device)
    # A single scatter along the channel axis replaces 81 scalar reads of the
    # grid (each a host/device sync when the grid lives on the GPU).
    one_hot.scatter_(0, cells, 1.0)
    return one_hot.unsqueeze(0)

In [None]:
#Training loop for the agent
def train_and_resolve_multiple_times(agent, replay_buffer, env, num_attempts=10, max_steps=None, render_frequency=10):
    """Train ``agent`` on ``num_attempts`` episodes of ``env`` with a clue-fraction curriculum.

    Args:
        agent: provides ``act``, ``train`` and ``device``.
        replay_buffer: provides ``size``, ``sample`` and ``add``.
        env: Sudoku environment; its ``clue_fraction`` is lowered every attempt.
        num_attempts: number of episodes to run.
        max_steps: optional per-episode step cap (``None`` or 0 disables it).
        render_frequency: render the board every this many steps (0 disables rendering).
    """
    # Linear curriculum measured from the starting value: each attempt is
    # 0.0005 harder than the previous one, floored at 0.3. Subtracting from
    # the already-reduced value instead would compound and hit the floor
    # after only a few dozen attempts.
    start_clue_fraction = env.clue_fraction

    for attempt in range(num_attempts):
        env.clue_fraction = max(0.3, start_clue_fraction - 0.0005 * attempt)
        grid, _ = env.reset()
        state = one_hot_encode(grid, device=agent.device)

        done = False
        total_reward = 0
        step_count = 0

        while not done:
            action = agent.act(state, temp_epsilon_boost=env.temp_epsilon_boost)
            next_grid, reward, done, truncated, _ = env.step(action)
            next_state = one_hot_encode(next_grid, device=agent.device)

            # Learn from memory once it holds more than one batch of transitions.
            if replay_buffer.size() > 64:
                batch, indices = replay_buffer.sample(64)
                agent.train(batch, indices, replay_buffer)

            # Only genuine termination is stored as `done`, so a time-limit
            # cut-off still bootstraps from the next state.
            replay_buffer.add((state, action, reward, next_state, done))
            state = next_state
            total_reward += reward
            step_count += 1

            # Stop on the environment's own time limit as well as on the caller's cap.
            if truncated:
                break
            if max_steps and step_count >= max_steps:
                break

            if render_frequency and step_count % render_frequency == 0:
                clear_output(wait=True)
                print(f"Attempt {attempt + 1}, Step {step_count}, Reward: {total_reward:.2f}, Clue Fraction: {env.clue_fraction:.2f}")
                env.render()

        print(f"Attempt {attempt+1} finished | Steps: {step_count} | Total Reward: {total_reward:.2f}")



In [None]:
# Build the agent, the replay memory and the environment, then launch training.
agent = DQNAgent(state_size=10 * 9 * 9, action_size=81 * 9, device=device)
replay_buffer = PrioritizedReplayBuffer(max_size=20_000)
env = SudokuEnv(clue_fraction=0.9)

train_and_resolve_multiple_times(agent, replay_buffer, env, num_attempts=500, max_steps=2500, render_frequency=50)