<a href="https://colab.research.google.com/github/adeotti/sudoku-env/blob/main/M9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from IPython.display import clear_output
from puzzle import easyBoard,solution
import torch
import math

In [None]:
# (row, column) coordinates of every cell of `easyBoard` whose value is 0,
# collected row by row over the 9 columns of each row.
modifiableCells = [
    (rowIdx, colIdx)
    for rowIdx, rowValues in enumerate(easyBoard)
    for colIdx in range(9)
    if rowValues[colIdx] == 0
]

In [None]:
def actionCheck(action:tuple):
    """Validate an ``(x, y, value)`` action before it is applied to the board.

    ``x`` and ``y`` are used as row/column indices into a 9x9 board, so the
    valid indices are 0..8 (9 would index past the last row/column).
    ``value`` is the digit to write and must be in 1..9.

    Args:
        action: a 3-item sequence ``(x, y, value)``.

    Raises:
        AssertionError: if any component is outside its valid range.
        ValueError: if ``action`` does not unpack into exactly three items.
    """
    x,y,value = action
    # Upper bound is exclusive: range(0,9) == 0..8, the valid indices of a 9-wide axis.
    assert x in range(0,9), f"{x} is too big or too small idk..."
    assert y in range(0,9), f"{y} is too big or too small idk..."
    assert value in range(1,10), f"{value} is too big or too small idk..."

def region(index:tuple|list,board: torch.Tensor): # return the region (row,column,block) of a cell
    x,y = index
    xlist = board[x].tolist()
    xlist.pop(y)

    ylist = [element[y].tolist() for element in board]
    ylist.pop(x)

    #block
    n = int(math.sqrt(9))
    ix,iy = (x//n)* n , (y//n)* n
    block = torch.flatten(board[ix:ix+n , iy:iy+n]).tolist()
    local_row = x - ix
    local_col = y - iy
    action_index = local_row * n + local_col
    block_ = [num for idx, num in enumerate(block) if idx != action_index]

    #output
    Region = [xlist,ylist,block_]
    Region = [item for sublist in Region for item in sublist]
    return Region


class solver: 
    """Check by constraint propagation whether a board reaches the known solution.

    The solver repeatedly fills every modifiable cell that has exactly one
    candidate digit left, then compares the resulting board with the solution.
    It works on its own copy of the board, so the caller's tensor is untouched.
    """

    def __init__(self,state:torch.Tensor,modCells:list):
        """
        Args:
            state: the board to analyse; cloned, never modified.
            modCells: ``(row, column)`` tuples of the cells the solver may fill.
        """
        self.board = state.clone()
        self.solution = solution.clone()
        self.modCells = modCells
        # Upper bound on the number of propagation passes.
        self.maxStep = len(modCells)*3
         
    def domain(self,idx:tuple|list) -> list :
        """Return the digits 1..9 not already used by the peers of cell ``idx``.

        The cell's own current value is ignored, because ``region`` excludes it.
        """
        taken = {item for item in region(idx,self.board) if item != 0}
        return list(set(range(1,10)) - taken)
    
    def collector(self):
        """Return one ``{cell: candidates}`` dict per modifiable cell, in order."""
        return [{cell : self.domain(cell)} for cell in self.modCells]
    
    def isSolvable(self) -> bool: 
        """Propagate single-candidate cells and report whether the solution is reached.

        Each pass takes a snapshot of all candidate lists, then writes every
        cell that has exactly one candidate. Propagation stops as soon as a
        pass changes nothing (further passes would be identical), or after
        ``maxStep + 1`` passes at the latest.

        Returns:
            True if every cell of the propagated board equals the solution.
        """
        for _ in range(self.maxStep + 1):
            changed = False
            for entry in self.collector():
                for cell,candidates in entry.items():
                    if len(candidates) != 1:
                        continue
                    if self.board[cell].item() != candidates[0]:
                        self.board[cell] = candidates[0]
                        changed = True
            if not changed:
                # Fixed point: the next snapshot would be the same as this one.
                break

        return bool((self.board == self.solution).all().item())
        

In [None]:
class Env:
    """Sudoku environment: applies ``(x, y, value)`` actions and scores them."""

    def __init__(self):
        # Private copy, so removing solved cells does not touch the module-level list.
        self.modifiableCells = modifiableCells.copy()

    def step(self,action : tuple|list,state:torch.Tensor):
        """Score ``action`` and, if the reward is positive, write it into ``state``.

        Args:
            action: ``(x, y, value)``; validated with ``actionCheck``.
            state: the board tensor, optionally with a leading dimension of
                size 1. It is modified IN PLACE when the action is accepted.

        Returns:
            ``[state, reward, done, action]`` where ``state`` is the same
            tensor object that was passed in.
        """
        actionCheck(action)
        self.action = action
        x,y,value = self.action
        reward = self.rewardFunction(action,state)

        # squeeze(0) returns a view, so writing through it updates `state` itself.
        board = state.squeeze(0)
        if reward > 0:
            board[x][y] = value
        # Compare against the squeezed board: torch.equal is False whenever the
        # sizes differ, so a (1, 9, 9) state would never match the solution.
        done = torch.equal(solution,board)
        return [state,reward,done,action]
           
    def rewardFunction(self,action:tuple|list,board:torch.Tensor):
        """Score an action by checking whether the board stays solvable.

        The value is written into a copy of ``board``, so the caller's tensor
        is not affected. If the solver can still reach the solution, the cell
        ``(x, y)`` is removed from ``self.modifiableCells``.

        Returns:
            0 if ``(x, y)`` is not a modifiable cell; otherwise
            ``0.1 * mismatches + 0.5`` (rounded to 2 decimals) when the board
            is solvable and ``-0.1 * mismatches`` when it is not, where
            ``mismatches`` counts cells that differ from the solution after
            the move.
        """
        x,y,value = action

        # Reject cells that are not modifiable before touching the board.
        if (x,y) not in self.modifiableCells:
            return 0

        board = board.squeeze(0).clone()
        board[x][y] = value

        remaining = self.modifiableCells.copy()
        remaining.remove((x,y))
        Solver = solver(board.clone(),remaining)

        # Number of cells that still differ from the solution, scaled by 0.1.
        conflicts = (board != solution).sum().item() * 0.1

        if Solver.isSolvable():
            self.modifiableCells.remove((x,y))
            return round((conflicts + 0.5),2)
        return -conflicts
    

In [None]:
import torch.nn as nn
import torch.nn.functional as F

# Module-level default read by Actor: stored as `Actor.batchsize` and passed as
# the first argument of each LazyConv2d layer.
batchSize = 1

class Mask: 
  """Mask entry 0 of the last row so softmax gives it zero probability.

  In an ``[x, y, value]`` action the last row is the value distribution, so
  masking its index 0 means the sampled value is never 0.
  """

  def __init__(self):
    # Fill value for masked entries; softmax maps -inf to probability 0.
    self.newValue = -float("inf")

  def apply(self,tensor : torch.FloatTensor):
    """Return a copy of ``tensor`` with ``[..., -1, 0]`` set to ``-inf``.

    The ellipsis covers every leading (batch) dimension, so each element of
    a batch is masked, and an unbatched 2-D input works as well. The input
    tensor is not modified; the boolean mask used is kept in ``self.mask``.
    """
    self.mask = torch.zeros_like(tensor,dtype=torch.bool)
    self.mask[...,-1,0] = True
    return tensor.masked_fill(mask=self.mask,value=self.newValue)

class Actor(nn.Module):
  """Policy network: a lazy linear layer, four lazy 3x3 convolutions and a
  lazy linear head with 27 outputs, reshaped to ``action_spec`` and turned
  into probabilities by a masked softmax over the last dimension.
  """

  def __init__(self):
    super().__init__()
    self.batchsize = batchSize
    self.action_dist = 27
    self.action_spec = (3,9)
    self.mask = Mask()

    kernel = (3,3)
    self.input = nn.LazyLinear(9)
    self.conv1 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv2 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv3 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv4 = nn.LazyConv2d(self.batchsize,kernel)
    self.output = nn.LazyLinear(self.action_dist)
    
  def forward(self,x):
    """Return the softmax over the last dimension of the masked head output.

    NOTE(review): presumably ``x`` is a float board of shape (1, 9, 9) - confirm
    against the caller.
    """
    hidden = F.relu(self.input(x))
    # ReLU follows only the second convolution of each pair.
    hidden = F.relu(self.conv2(self.conv1(hidden)))
    hidden = F.relu(self.conv4(self.conv3(hidden)))
    scores = F.relu(self.output(hidden.flatten(1,2)))
    # Split the 27 outputs into action_spec, then mask before normalising.
    scores = self.mask.apply(scores.unflatten(-1,self.action_spec))
    return F.softmax(scores,dim=-1)


class Critic(nn.Module):
  # Placeholder: the body is only `...`, so no layers or forward pass are defined here.
  ...

In [None]:
import random
from torch.distributions import Categorical

class collector:
    """Roll the actor out in the environment and serve the data in batches.

    Each stored data point is ``[state, reward, done, action, logProb]``.
    """

    def __init__(self,totalFrame,batchSize):
        """
        Args:
            totalFrame: number of environment steps per rollout.
            batchSize: number of data points returned by each ``sample()`` call.

        Raises:
            AssertionError: if ``totalFrame`` is not a multiple of
                ``batchSize`` or is not below ``len(modifiableCells) * 3``.
        """
        assert totalFrame % batchSize == 0 , "totalFrame must be a multiple of batchSize"
        assert totalFrame < len(modifiableCells)*3 , "The memory length should be less than an episode"
        
        self.state = easyBoard.clone().unsqueeze(0)
        self.totalFrame = totalFrame
        self.batchSize = batchSize
        self.env = Env()
        self.actor = Actor()
        self.pointer = 0
        self.envData = []
        self.networkData = []
        
    def rollout(self):
        """Run ``totalFrame`` steps, store them, and return the shuffled buffer.

        ``self.state`` is updated in place by ``Env.step`` as actions are
        accepted. New data points are appended to ``self.envData``; entries
        from earlier rollouts are kept until ``clear()`` is called.
        """
        for _ in range(self.totalFrame):
            # Call the module itself rather than .forward() so nn.Module hooks run.
            dist = Categorical(self.actor(self.state))
            sample = dist.sample()
            logProb = dist.log_prob(sample)
            action = sample.tolist()[0]
            self.networkData.append(logProb)

            dataPoint = self.env.step(action,self.state)
            # The environment hands back the very tensor it mutates in place;
            # store a snapshot so later steps do not rewrite earlier entries.
            dataPoint[0] = dataPoint[0].clone()
            # Attach the log-prob right away: pairing by position afterwards
            # breaks once the buffer has been shuffled by a previous rollout.
            dataPoint.append(logProb)
            self.envData.append(dataPoint)

        random.shuffle(self.envData)
        return self.envData

    def sample(self):
        """Return the next ``batchSize`` data points and advance the read pointer.

        Returns a shorter (possibly empty) list once the buffer is exhausted.
        """
        output = self.envData[self.pointer : self.pointer + self.batchSize]
        self.pointer += self.batchSize
        return output
    
    def clear(self):
        """Empty both buffers and rewind the read pointer."""
        self.networkData = []
        self.envData = []
        # Without this, sample() would keep reading past the start of a fresh buffer.
        self.pointer = 0


In [None]:
# Smoke run: a collector with totalFrame=4 and batchSize=2, then one rollout.
# `x` holds the list returned by rollout().
t = collector(4,2)
x = t.rollout()
 

In [None]:
# Read the next batchSize-sized slice of the collected rollout data.
t.sample()