<a href="https://colab.research.google.com/github/adeotti/sudoku-env/blob/main/M9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# python version 3.9.11 
%pip install numpy==1.25.2 tensorboard==2.17.0 torchrl==0.4.0 gymnasium==0.29.1 tensordict==0.4.0

from IPython.display import clear_output
def clear():
    """Erase the output shown under the currently running notebook cell."""
    # wait=False clears right away instead of waiting for replacement output.
    clear_output(wait=False)

import math,sys 
import torch

# Wipe the install/import output so the cell finishes with a clean display.
clear()

In [None]:
# Starting puzzle: 0 marks an empty cell, 1-9 are the given clues.
easyBoard = torch.tensor([
    [0, 0, 0, 5, 3, 1, 0, 0, 0],
    [0, 0, 0, 0, 4, 0, 3, 0, 1],
    [1, 0, 0, 8, 0, 0, 0, 0, 0],
    [0, 0, 4, 0, 0, 5, 6, 0, 0],
    [0, 0, 3, 9, 0, 2, 1, 4, 0],
    [6, 1, 5, 0, 7, 0, 0, 9, 8],
    [0, 2, 0, 0, 9, 6, 0, 1, 0],
    [0, 5, 7, 2, 0, 8, 0, 0, 6],
    [0, 6, 1, 7, 5, 3, 0, 2, 4]])

# Completed grid for `easyBoard`. Every row, column and 3x3 box holds 1-9
# exactly once and every clue of the puzzle is preserved.
solution = torch.tensor([
    [8, 4, 9, 5, 3, 1, 7, 6, 2],
    [5, 7, 2, 6, 4, 9, 3, 8, 1],
    [1, 3, 6, 8, 2, 7, 4, 5, 9],
    [2, 9, 4, 1, 8, 5, 6, 7, 3],
    [7, 8, 3, 9, 6, 2, 1, 4, 5],
    [6, 1, 5, 3, 7, 4, 2, 9, 8],
    # This row previously read [3, 2, 8, 7, 9, 5, 1, 6, 7]: it contained two
    # 7s and contradicted the clues 6 and 1 at columns 5 and 7.
    [3, 2, 8, 4, 9, 6, 5, 1, 7],
    [4, 5, 7, 2, 1, 8, 9, 3, 6],
    [9, 6, 1, 7, 5, 3, 8, 2, 4]])

Game and utility classes

In [None]:
from dataclasses import dataclass

@dataclass(frozen=True)
class Board_specs:
    """Immutable constants describing the Sudoku grid.

    The fields carry defaults, so they can be read straight from the class
    (e.g. ``Board_specs.high``) without creating an instance.
    """

    # Grid shape.
    size: tuple = (9, 9)
    # Lower and upper bound of the digit range (presumably inclusive).
    low: int = 1
    high: int = 9

class Game:
    """Sudoku board together with the rules for applying one move.

    Args:
        action: Optional ``(row, column, value)`` triple (tuple, list or 1-D
            tensor) consumed by the next call to :meth:`Updated_board`.
        puzzle: Optional square starting grid where 0 marks an empty cell.
            Defaults to the module-level ``easyBoard``.
        answer: Optional solved grid used to detect completion. Defaults to
            the module-level ``solution``.

    Attributes:
        board: Current grid; moves are written into it in place.
        action: The pending move, or ``None``.
        reward: Running reward total. It is never cleared by
            :meth:`Updated_board`, only by :meth:`reset`.
        done: ``True`` once ``board`` equals the answer grid.
        modifiableCells: ``(row, column)`` pairs that were empty in the
            starting grid, in row-major order.
    """

    def __init__(self, action=None, puzzle=None, answer=None):
        # Clone so in-place writes can never leak into the caller's tensors
        # (or into the module-level puzzle).
        self._initial = (easyBoard if puzzle is None else puzzle).clone()
        self._answer = (solution if answer is None else answer).clone()

        self.board = self._initial.clone()
        self.action = action
        self.reward = 0
        self.done = torch.equal(self._answer, self.board)

        # torch.nonzero yields indices in row-major order.
        self.modifiableCells = [
            (row, column)
            for row, column in torch.nonzero(self._initial == 0).tolist()
        ]

    def _has_conflict(self, row, column, value):
        """Return True if ``value`` already sits in the cell's row, column or box."""
        box = math.isqrt(self.board.shape[0])
        top, left = (row // box) * box, (column // box) * box

        # Blank the target cell on a copy so its current content is ignored.
        probe = self.board.clone()
        probe[row, column] = 0
        return bool(
            (probe[row] == value).any()
            or (probe[:, column] == value).any()
            or (probe[top:top + box, left:left + box] == value).any()
        )

    def Updated_board(self):
        """Apply ``self.action`` (if any) and report the resulting state.

        A move is accepted when it targets a cell that was empty in the
        starting grid, its value lies in ``1..side`` and it does not collide
        with the same value in the row, column or box. Accepted moves add 10
        to ``reward`` (plus 20 when they complete the grid); rejected moves
        subtract 2 and leave the board untouched.

        Returns:
            tuple: ``(board, reward, done)``.
        """
        if self.action is None:
            return self.board, self.reward, self.done

        # Plain ints make the membership and range checks unambiguous even
        # when the action arrives as a tensor.
        row, column, value = (int(item) for item in self.action)
        side = self.board.shape[0]

        legal = (
            (row, column) in self.modifiableCells
            and 1 <= value <= side
            and not self._has_conflict(row, column, value)
        )
        if not legal:
            self.reward -= 2
            return self.board, self.reward, self.done

        self.board[row, column] = value
        self.reward += 10
        # Re-evaluate completion after the write; checking it only at
        # construction time would make the bonus unreachable.
        self.done = torch.equal(self._answer, self.board)
        if self.done:
            self.reward += 20
        return self.board, self.reward, self.done

    def reset(self):
        """Restore the starting grid, clear reward/done and return the board."""
        self.board = self._initial.clone()
        self.reward = 0
        self.done = torch.equal(self._answer, self.board)
        return self.board

Environment

In [None]:
from torchrl.envs import EnvBase
from torchrl.data import BoundedTensorSpec,CompositeSpec
from tensordict import TensorDictBase,TensorDict

class environment(EnvBase):
    """TorchRL environment that exposes the Sudoku ``Game`` to a policy.

    Observations are the current board as a float tensor of shape
    ``(1, 9, 9)`` where 0 marks an empty cell. Actions are integer tensors of
    shape ``(1, 3)`` holding ``[row, column, value]``. One ``Game`` instance
    lives for a whole episode so accepted moves accumulate on the board.
    """

    def __init__(self):
        super().__init__()

        self.action = None
        self.game = Game(self.action)
        # With no action this only reports the untouched starting board.
        self.updatedBoard, self.reward, self.done = self.game.Updated_board()

        # NOTE(review): rows/columns are indexed 0-8, yet the upper bound
        # here is 9 — confirm whether the bound is meant to be exclusive.
        self.action_spec = BoundedTensorSpec(
            low=[[0, 0, 1]],
            high=[[9, 9, 9]],
            shape=torch.Size([1, 3]),
            dtype=torch.int,
        )

        self.observation_format = BoundedTensorSpec(
            low=0.0,  # empty cells are encoded as 0
            high=9.0,
            shape=easyBoard.unsqueeze(0).shape,
            dtype=torch.float32,
        )
        self.observation_spec = CompositeSpec(observation=self.observation_format)

    def _observation(self):
        """Return a detached float copy of the board with a leading unit axis."""
        return self.updatedBoard.clone().detach().unsqueeze(0).float()

    def _step(self, tensordict) -> TensorDictBase:
        """Apply the action stored under ``"action"`` and return the outcome."""
        self.action = tensordict["action"][0]  # (1, 3) -> (3,)

        # Reuse the episode's game so the board keeps earlier moves, and zero
        # its running total so the value read back is this step's reward.
        self.game.action = self.action
        self.game.reward = 0
        self.updatedBoard, self.reward, done = self.game.Updated_board()
        # Also compare against the solution here so termination is reported
        # even if the game object did not refresh its own flag.
        self.done = bool(done) or torch.equal(self.updatedBoard, solution)

        return TensorDict(
            {
                "observation": self._observation(),
                "reward": torch.tensor([self.reward], dtype=torch.float32),
                "done": torch.tensor([self.done], dtype=torch.bool),
            },
            batch_size=torch.Size([]),
        )

    def _reset(self, tensordict, **kwargs) -> TensorDictBase:
        """Start a new episode from the initial puzzle."""
        self.action = None
        self.game = Game(self.action)
        self.updatedBoard, self.reward, self.done = self.game.Updated_board()
        return TensorDict(
            {"observation": self._observation()},
            batch_size=torch.Size([]),
        )

    def _set_seed(self, seed=None):
        """Accept the seed passed by ``EnvBase.set_seed``; nothing here is random."""
        pass

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

from tensordict.nn import TensorDictModule

from torchrl.modules import ValueOperator,ProbabilisticActor
from torchrl.objectives.value import GAE
from torchrl.collectors import SyncDataCollector

In [None]:
# Prefer the GPU when torch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Optimiser settings ---
l_rate = 0.01
sdg_momentum = 0.9

# --- Data collection / batching ---
frames = 100          # frames gathered per collector batch
batchsize = 20        # minibatch size sampled in the innermost training loop
total_frames = 10000  # total frames the collector may gather

# --- Advantage estimation (passed to GAE) ---
gamma = 0.80
lmbda = 0.99

# Shared environment and one sample observation from its reset.
env = environment()
dummy_observation = env._reset(None)["observation"]


In [None]:
class Mask:
    """Rule out index 0 of the last action component before the softmax.

    For logits shaped ``(..., components, categories)`` the entry
    ``[..., -1, 0]`` is replaced by ``-inf`` so that a softmax over the last
    axis assigns it probability zero, i.e. value in ``[x, y, value]`` != 0.
    """

    def __init__(self):
        self.newValue = -float("inf")

    def apply(self, tensor: torch.FloatTensor):
        """Return a copy of ``tensor`` with the forbidden logit set to ``-inf``.

        The ellipsis index covers every leading (batch) dimension; indexing
        with ``[-1, -1, 0]`` would only touch the last sample of a batch and
        fail on un-batched input.
        """
        self.mask = torch.zeros_like(tensor, dtype=torch.bool)
        self.mask[..., -1, 0] = True
        return tensor.masked_fill(mask=self.mask, value=self.newValue)

class ActorNetwork(nn.Module):
    """Policy MLP: flattens the observation and emits a ``(3, 9)`` probability grid.

    The last axis of the output is a softmax, so each of the three rows sums
    to one.
    """

    def __init__(self):
        super().__init__()
        self.size = 81             # hidden width
        self.action_dist = 27      # flat output width, reshaped to (3, 9)
        self.action_spec = (3, 9)
        self.mask = Mask()

        # Layers are created in a fixed order so parameter names and the
        # random initialisation sequence stay stable.
        self.input_layer = nn.LazyLinear(81)
        self.flat = nn.Flatten()
        self.dense_one = nn.Linear(self.size, self.size)
        self.dense_two = nn.Linear(self.size, self.size)
        self.dense_3 = nn.Linear(self.size, self.size)
        self.dense_4 = nn.Linear(self.size, self.size)
        self.output = nn.Linear(self.size, self.action_dist)

    def forward(self, x):
        """Map a batch of observations to masked action probabilities."""
        features = self.flat(x)
        hidden = F.relu(self.input_layer(features))
        # NOTE(review): dense_one and dense_4 have no activation, and the
        # logits pass through ReLU before the softmax — confirm this is
        # intended.
        hidden = self.dense_one(hidden)
        hidden = F.relu(self.dense_two(hidden))
        hidden = F.relu(self.dense_3(hidden))
        hidden = self.dense_4(hidden)
        logits = F.relu(self.output(hidden))
        logits = logits.unflatten(-1, self.action_spec)
        masked = self.mask.apply(logits)
        return F.softmax(masked, dim=-1)

In [None]:
# Raise the print limits so large tensors are shown in full on long lines.
torch.set_printoptions(threshold=100000,linewidth=10000) 
@torch.no_grad()
def weights_init(w):
    """Re-initialise a linear layer in place; meant for ``Module.apply``.

    ``nn.Linear`` modules get Kaiming-uniform weights (fan-in, ReLU gain) and
    a zeroed bias. Any other module is left untouched.

    Args:
        w: The module handed in by ``Module.apply``.
    """
    if not isinstance(w, nn.Linear):
        return
    nn.init.kaiming_uniform_(w.weight, mode="fan_in", nonlinearity="relu")
    # Layers built with bias=False expose `bias` as None.
    if w.bias is not None:
        nn.init.zeros_(w.bias)
    
actor = ActorNetwork()
# One throw-away forward pass gives the lazy input layer its real shape so
# that weights_init can see materialised parameters. Call the module rather
# than .forward() so registered hooks run, and skip autograd bookkeeping for
# the discarded result.
with torch.no_grad():
    actor(torch.rand((20, 9, 9)))
actor.apply(weights_init)

In [None]:
# Wrap the actor so it reads "observation" from a TensorDict and writes its
# output under "probs".
Policy = TensorDictModule(
    in_keys=["observation"],
    out_keys=["probs"],
    module=actor,
)

# Turn the probabilities into sampled actions via a Categorical distribution
# and also return the log-probability of each sample.
PolicyModule = ProbabilisticActor(
    module=Policy,
    in_keys=["probs"],
    spec=env.action_spec,
    distribution_class=Categorical,
    return_log_prob=True,
)

In [None]:
# Data collector: runs PolicyModule in env, yielding `frames` frames per
# batch up to `total_frames` in total.
Collector = SyncDataCollector(
    create_env_fn=env,
    policy=PolicyModule,
    total_frames=total_frames,
    frames_per_batch=frames,
)
# One rollout up front as a smoke test; its result is discarded.
Collector.rollout()
clear()

In [None]:
class ValueNetwork(nn.Module):
    """State-value MLP: flattens the input and outputs one scalar per sample."""

    def __init__(self):
        super().__init__()
        self.size = 81  # hidden width
        # Every linear layer is lazy: its input width is inferred on the
        # first forward pass.
        self.input_layer = nn.LazyLinear(self.size)
        self.flat = nn.Flatten()
        self.dense_one = nn.LazyLinear(self.size)
        self.dense_two = nn.LazyLinear(self.size)
        self.output = nn.LazyLinear(1)

    def forward(self, x):
        """Return the value estimate with shape ``(batch, 1)``."""
        hidden = self.flat(x)
        for layer in (self.input_layer, self.dense_one, self.dense_two):
            hidden = F.relu(layer(hidden))
        return self.output(hidden)

# Build the critic; one forward pass on a sample observation materialises
# its lazy layers before the custom initialisation is applied.
Critic = ValueNetwork()
Critic(dummy_observation)
Critic.apply(weights_init)

# TensorDict wrapper that feeds "observation" to the critic.
ValueModule = ValueOperator(in_keys=["observation"], module=Critic)

# Generalised advantage estimator driven by the critic.
Advantage = GAE(
    value_network=ValueModule,
    gamma=gamma,
    lmbda=lmbda,
    device=device,
    average_gae=True,
)

Training loop

In [None]:
from torchrl.data import ReplayBuffer,SamplerWithoutReplacement,LazyTensorStorage
from torchrl.objectives import ClipPPOLoss
from tqdm import tqdm
from collections import deque
from torch.utils.tensorboard import SummaryWriter
 
class Training:
    """Run the PPO optimisation loop and log progress to TensorBoard.

    Relies on the notebook-level objects ``Collector``, ``PolicyModule``,
    ``ValueModule`` and ``Advantage`` and on the hyper-parameters
    ``batchsize``, ``frames``, ``total_frames`` and ``l_rate``.
    """

    def __init__(self):
        self.collector = Collector
        self.epochs = 10
        self.batchsize = batchsize
        self.valuemodule = ValueModule
        self.advantage = Advantage

        self.policy = PolicyModule

        self.lossfunction = ClipPPOLoss(
            actor_network=PolicyModule,
            critic_network=ValueModule,
            entropy_coef=0.01,
        )
        self.optimizer = torch.optim.Adam(
            params=self.lossfunction.parameters(),
            lr=l_rate,
        )
        self.memory = ReplayBuffer(
            storage=LazyTensorStorage(max_size=frames),
            sampler=SamplerWithoutReplacement(),
        )

    def save_logs(self):
        """Create the TensorBoard writer that logs under ``Data/``."""
        log_dir = "Data/"
        self.writer = SummaryWriter(log_dir)

    def save_weight(self):
        """Write the policy's ``state_dict`` to ``Data/actor_100k.pth``."""
        import os  # local import: only this method needs it

        path = "Data/actor_100k.pth"
        # Do not depend on the log writer having created the folder.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(self.policy.state_dict(), path)

    def _optimise(self, data_tensordict):
        """Run ``self.epochs`` PPO epochs on one collected batch.

        Returns:
            The loss TensorDict of the last minibatch, or ``None`` when no
            minibatch was processed.
        """
        loss_val = None
        minibatches = frames // self.batchsize
        for _ in range(self.epochs):
            # The advantage module writes its outputs into data_tensordict.
            self.advantage(data_tensordict)
            data_tensordict["advantage"] = data_tensordict["advantage"].unsqueeze(-1)
            self.memory.extend(data_tensordict)

            for _ in range(minibatches):
                subdata = self.memory.sample(self.batchsize)

                loss_val = self.lossfunction(subdata)
                loss_value = (
                    loss_val["loss_objective"]
                    + loss_val["loss_critic"]
                    + loss_val["loss_entropy"]
                )
                self.optimizer.zero_grad()
                loss_value.backward()
                self.optimizer.step()
        return loss_val

    def _log_losses(self, loss_val, step):
        """Log the individual loss terms and their sum at ``step``."""
        loss_sum = (
            loss_val["loss_objective"]
            + loss_val["loss_critic"]
            + loss_val["loss_entropy"]
        )
        self.writer.add_scalar("main/Loss_sum", loss_sum.item(), step)
        self.writer.add_scalar("main/Loss_entropy", loss_val["loss_entropy"].item(), step)
        self.writer.add_scalar("main/Loss_critic", loss_val["loss_critic"].item(), step)
        self.writer.add_scalar("main/Loss_objective", loss_val["loss_objective"].item(), step)

    def _training_loop(self):
        """Iterate over collector batches: optimise, log, evaluate."""
        rewardHistory = deque(maxlen=10)
        cumulative_test_reward = 0.0
        eval_env = environment()

        batches = tqdm(enumerate(self.collector), total=total_frames // frames)
        for i, data_tensordict in batches:
            loss_val = self._optimise(data_tensordict)
            # Passing the batch index as global_step keeps every point from
            # landing on the same x position in TensorBoard.
            if loss_val is not None:
                self._log_losses(loss_val, i)

            # Mean over the whole collected batch; indexing [0] first would
            # only look at the first frame.
            currentReward = data_tensordict["next"]["reward"].float().mean().item()
            rewardHistory.append(currentReward)
            averageReward = sum(rewardHistory) / len(rewardHistory)
            self.writer.add_scalar("Reward/collector average reward", averageReward, i)
            self.writer.add_scalar("Reward/collector reward", currentReward, i)

            with torch.no_grad():
                test_rewards = eval_env.rollout(200, self.policy)["next"]["reward"]
            clear()
            # NOTE(review): this tag tracks the running SUM of per-rollout
            # mean rewards, as before — confirm a cumulative curve is wanted.
            cumulative_test_reward += torch.flatten(test_rewards).to(float).mean().item()
            self.writer.add_scalar("Reward/Test reward", cumulative_test_reward, i)

    def train(self, start: bool = None):
        """Train the policy and save its weights; a no-op unless ``start`` is truthy."""
        if not start:
            return

        self.save_logs()
        try:
            self._training_loop()
        finally:
            # Flush and release the event file even if training aborts.
            self.writer.close()
        self.save_weight()
      
# Launch training; without a truthy `start`, train() does nothing.
Training().train(start=True)