<a href="https://colab.research.google.com/github/adeotti/sudoku-env/blob/main/M9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from IPython.display import clear_output
from puzzle import easyBoard,solution,testBoard,testSolution
import torch
import math

In [None]:
def modi(tensor) -> list:
    """Return the (row, col) index of every empty cell of a board.

    A cell is empty when its value equals 0. Cells are listed in row-major
    order. Every column of each row is inspected, so boards of any width
    work; for a 9x9 board this matches the earlier hard-coded ``range(9)``.

    Args:
        tensor: 2-D board (a torch tensor or a nested sequence of numbers).

    Returns:
        list of ``(row, col)`` tuples. No mask is returned.
    """
    return [
        (i, j)
        for i, row in enumerate(tensor)
        for j, value in enumerate(row)
        if value == 0
    ]

def region(index:tuple|list,board: torch.Tensor): # return the region (row,column,block) of a cell
    x,y = index
    xlist = board[x].tolist()
    xlist.pop(y)

    ylist = [element[y].tolist() for element in board]
    ylist.pop(x)

    #block
    n = int(math.sqrt(9))
    ix,iy = (x//n)* n , (y//n)* n
    block = torch.flatten(board[ix:ix+n , iy:iy+n]).tolist()
    local_row = x - ix
    local_col = y - iy
    action_index = local_row * n + local_col
    block_ = [num for idx, num in enumerate(block) if idx != action_index]

    #output
    Region = [xlist,ylist,block_]
    Region = [item for sublist in Region for item in sublist]
    return Region

class solver: 
    """Constraint-propagation checker for a partially filled sudoku board.

    ``isSolvable`` repeatedly writes into the board every open cell whose
    candidate set (``domain``) has exactly one value, then reports whether the
    resulting board equals the target solution.
    """
    def __init__(self,state:torch.Tensor,modCells:list,target:torch.Tensor|None=None):
        """
        Args:
            state: board tensor; ``isSolvable`` fills it IN PLACE (pass a clone).
            modCells: ``(row, col)`` tuples of the open cells to propagate over.
            target: board to compare against; defaults to the module-level
                ``solution``.
        """
        self.board = state
        self.solution = solution if target is None else target
        self.modCells = modCells
        self.maxStep = len(modCells)*3  # upper bound on propagation passes

    def domain(self,idx:tuple|list) -> list :
        """Return, sorted, the values 1..9 not already used by a non-zero peer of ``idx``."""
        used = {value for value in region(idx,self.board) if value != 0}
        return sorted(set(range(1,10)) - used)

    def collector(self):
        """Return ``[{cell: domain(cell)}, ...]`` for every cell in ``modCells``."""
        return [{cell : self.domain(cell)} for cell in self.modCells]

    def isSolvable(self) -> bool: 
        """Propagate single-candidate cells, then test the board against the target.

        Each pass computes all domains first and only then assigns. So a pass
        that leaves the board unchanged is a fixed point: every further pass
        would do nothing. Stopping there gives the same result as running all
        ``maxStep + 1`` passes.
        """
        for _ in range(self.maxStep + 1):
            before = self.board.clone()
            for entry in self.collector():
                for cell, candidates in entry.items():
                    if len(candidates) == 1:
                        self.board[cell] = candidates[0]
            if torch.equal(before, self.board):
                break
        return bool((self.board == self.solution).all())
        

In [12]:
# (row, col) of every empty (0) cell of the starting board; Env and collector work on copies of this list.
modifiableCells: list = modi(easyBoard)

In [None]:
class Env:
    """Sudoku environment over the module-level ``easyBoard``.

    Attributes:
        modifiableCells: ``(row, col)`` cells the agent may still fill; a cell
            is removed once a move on it earns a positive reward.
        solution: reference board that moves are scored against.
        state: current board, a clone of ``easyBoard`` mutated in place by ``step``.
    """
    def __init__(self):
        self.modifiableCells = modifiableCells.copy()
        self.solution = solution
        self.state = easyBoard.clone()

    def reset(self):
        """Restore the starting board and the full list of open cells."""
        self.state  = easyBoard.clone()
        self.modifiableCells = modifiableCells.copy()

    def step(self,action : tuple|list):
        """Play ``action = (x, y, value)``; the board changes only on a positive reward.

        Returns:
            ``[state, reward, done, action, conflicts]`` where
            - ``state`` is the live board tensor, not a copy;
            - ``reward`` is a float tensor of shape (1,);
            - ``done`` is a bool tensor of shape (1,), True when the board
              equals ``self.solution``;
            - ``action`` is a tensor of shape (1, 3);
            - ``conflicts`` is the count returned by ``rewardFunction``.
        """
        self.action = action
        x,y,value = self.action
        reward,conflicts = self.rewardFunction(action,self.state)
        if reward > 0:
            self.state[x][y] = value
            self.modifiableCells.remove((x,y))
        # Use this env's own solution, consistent with rewardFunction.
        done = torch.equal(self.solution,self.state)
        return [
            self.state,
            torch.tensor([reward],dtype=torch.float),
            torch.tensor([done]),
            torch.tensor([action]),
            conflicts,
        ]

    def rewardFunction(self,action:tuple|list,board:torch.Tensor):
        """Score ``action = (x, y, value)`` on a clone of ``board``.

        Neither ``board`` nor ``self.modifiableCells`` is modified here; ``step``
        is what writes the move and removes the cell.

        If ``(x, y)`` is not an open cell, the reward is 0 and the unchanged
        board is scored. Otherwise ``value`` is written into the clone, and
        ``solver(...).isSolvable()`` runs on the remaining open cells. The
        result decides the sign of the reward: ``conflicts / 2 * 0.1 + 5``
        when solvable, and its negation otherwise.

        Returns:
            ``(reward, conflicts)``. ``reward`` is a Python float.
            ``conflicts`` is a 0-d float64 tensor counting the cells of the
            scored board that differ from ``self.solution``. It has the same
            type on both paths so callers can ``torch.stack`` it.
        """
        x,y,value = action
        board = board.clone()
        if (x,y) not in self.modifiableCells:
            return 0.0, (board != self.solution).sum().to(torch.float64)
        board[x][y] = value
        conflicts = (board != self.solution).sum().to(torch.float64)
        openCells = self.modifiableCells.copy()
        openCells.remove((x,y)) # (x, y) is now filled, so the solver must not revisit it
        # NOTE(review): the magnitude grows with the number of conflicts, so the same
        # move earns more on an emptier board -- confirm this shaping is intended.
        magnitude = float(conflicts) / 2 * 0.1 + 5
        if solver(board.clone(),openCells).isSolvable():
            return magnitude, conflicts
        return -magnitude, conflicts
    

In [None]:
import torch.nn as nn
import torch.nn.functional as F

# Module-level hyperparameters: `lr` is the Adam learning rate used by `network`.
batchSize = 1  # NOTE(review): not read by the visible code; collector/Agent use their own batch size
lr = 3e-4

class mask:
  """Blocks out-of-range choices in the (3, 10) policy logits produced by `network`.

  Row 0 (x) and row 1 (y) may only pick 0..8, so their last column is blocked.
  Row 2 (the value) may only pick 1..9, so its first column is blocked.
  Blocked entries are filled with -inf, so a following softmax gives them
  probability 0.
  """
  def __init__(self):
    self.newValue = float("-inf")  # fill value for blocked entries

  def apply(self,tensor : torch.FloatTensor):
    """Return a copy of ``tensor`` with the blocked entries set to ``newValue``.

    The boolean mask used is stored on ``self.mask``.
    """
    blocked = torch.zeros_like(tensor, dtype=torch.bool)
    for row, col in ((0, -1), (1, -1), (-1, 0)):
      blocked[row, col] = True
    self.mask = blocked
    return tensor.masked_fill(blocked, self.newValue)

class network(nn.Module):
  """Actor-critic network for one sudoku board.

  ``forward(x)`` returns ``(probs, value)``:
  - ``probs`` is a (3, 10) tensor with one softmax distribution per row, over
    x (row index), y (column index) and the cell value;
  - ``value`` is the output of ``value_head``.
  """
  def __init__(self):
    super().__init__()
    self.conv1 = nn.LazyConv2d(1,(1,1))
    self.conv2 = nn.LazyConv2d(1,(1,1))
    self.conv3 = nn.LazyConv2d(1,(1,1))

    self.linear1 = nn.LazyLinear(9)
    self.linear2 = nn.LazyLinear(9)
    self.linear3 = nn.LazyLinear(9)

    self.policy_mask = mask()
    self.policy_head = nn.LazyLinear(30)  # 3 heads x 10 choices
    self.value_head = nn.LazyLinear(1)

    # NOTE(review): the optimizer is built while the Lazy* parameters are still
    # uninitialized; this relies on lazy parameters being materialized in place
    # by the first forward pass -- confirm against the LazyModuleMixin docs.
    self.optimizer = torch.optim.Adam(self.parameters(),lr=lr)

  def forward(self,x):
    # Boards may arrive as integer tensors; the convolutions need floating point.
    x = x.float()
    x = F.relu(self.conv1(x))
    x = self.conv2(x)
    x = F.relu(self.conv3(x))
    x = torch.flatten(x,start_dim=1)
    x = self.linear1(x)
    x = F.relu(self.linear2(x))
    x = self.linear3(x)
    # Keep raw logits. A ReLU here would pass no gradient to any negative logit,
    # so those actions could never be learned up or down.
    # reshape(3, 10) assumes exactly one board per call.
    logits = self.policy_head(x).reshape(3,10)
    logits = self.policy_mask.apply(logits)
    value = self.value_head(x)
    return F.softmax(logits,-1),value

# Dry run: push one random (1, 9, 9) board through `d` so its Lazy* layers infer their sizes.
# NOTE(review): this only materializes `d`; networks created later (e.g. inside
# `collector` / `Agent`) are materialized by their own first forward call.
d = network()
d.forward(torch.rand((1,9,9),dtype=torch.float))
clear_output()  # clear this cell's notebook output

In [None]:
import random
from torch.distributions import Categorical
import gc

class collector:
    """Rollout buffer that plays ``self.network`` in an ``Env`` and stores transitions.

    Each stored transition is a list
        ``[state, reward, action, done, conflicts, log_prob, value]``
    where ``state`` is the network input the action was sampled from, captured
    before the move is applied. ``GAE`` appends an advantage to each entry, and
    ``sample`` unpacks those 8-element entries.
    """
    def __init__(self,totalFrame,batchSize):
        assert totalFrame % batchSize == 0 , "totalFrame % batchSize should be 0"
        assert totalFrame < len(modifiableCells)*3 ,"The memory length should be less than an episode"
        self.totalFrame = totalFrame
        self.batchSize = batchSize
        self.env = Env()
        self.network = network()
        self.pointer = 0
        self.data = []

    def rollout(self):
        """Collect ``totalFrame`` transitions on open cells and return them in time order.

        A sampled action whose (x, y) is not an open cell is discarded without
        stepping the env. The list is NOT shuffled here, because GAE needs
        consecutive transitions; ``extend`` shuffles after the advantages are
        attached.
        """
        self.clear()
        while len(self.data) < self.totalFrame:
            if len(self.env.modifiableCells) < 5 :
                self.env.reset()
            # Snapshot the pre-move board: Env.step mutates env.state in place.
            state = self.env.state.unsqueeze(0).clone().float()
            # Rollout outputs are fixed targets for the PPO update, so build no autograd graph.
            with torch.no_grad():
                softmax_dist,value = self.network.forward(state)
            dist = Categorical(softmax_dist)
            sample = dist.sample()
            log_prob = dist.log_prob(sample)
            action = sample.tolist()
            assert len(action) == 3 , f" action is {action}"
            x,y,_ = action
            if (x,y) in self.env.modifiableCells:
                _,reward,done,actionTensor,conflicts = self.env.step(action)
                self.data.append([state,reward,actionTensor,done,conflicts,log_prob,value])
        return self.data

    def extend(self,args):
        """Replace the buffer with ``args`` (e.g. GAE output), shuffle it for minibatching, and rewind ``sample``."""
        self.data = list(args)
        random.shuffle(self.data)
        self.pointer = 0

    def sample(self):
        """Return the next ``batchSize`` entries as column tuples.

        The order is (states, rewards, dones, actions, logs, values,
        advantages, conflicts).
        """
        output = self.data[self.pointer : self.pointer + self.batchSize]
        self.pointer += self.batchSize
        states,rewards,actions,dones,conflicts,logs,values,advantages = zip(*output)
        return states,rewards,dones,actions,logs,values,advantages,conflicts

    def clear(self):
        """Empty the buffer and rewind the read pointer before a new rollout."""
        self.pointer = 0
        self.data = []
        # Not used by this class; kept in case other code reads them.
        self.updatedData = []
        self.networkData = []
        self.envData = []
        self.valueDAta = []
        gc.collect()


$$
{\large
\begin{align}
&\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) \\
&\hat{A}_t = \delta_t + (\gamma \lambda)\delta_{t+1} + (\gamma \lambda)^2\delta_{t+2} + \dots \\
&\text{GAE: } \hat{A}_t = \delta_t + (\gamma \lambda)\hat{A}_{t+1} \\[3em]
&L^{\text{critic}} = \frac{1}{N} \sum_t \left( V_\theta(s_t) - V_t^{\text{target}} \right)^2 \\
&V_t^{\text{target}} = \hat{A}_t + V_{\theta_{\text{old}}}(s_t) \\[3em]
&L^{CPI}(\theta) = \hat{\mathbb{E}}_t \left[ \frac{\pi_{\theta}(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}\hat{A}_t \right] \\
&\hspace{2em} = \hat{\mathbb{E}}_t \left[ r_t(\theta)\hat{A}_t \right] \\[1em]
&L^{CLIP}(\theta) = \hat{\mathbb{E}}_t \left[ \min\left(r_t(\theta) \hat{A}_t,\ \operatorname{clip}\left(r_t(\theta),1-\epsilon,1+\epsilon\right) \hat{A}_t\right) \right] \\[3em]
&\text{Total objective (maximized): } L_t(\theta) = \hat{\mathbb{E}}_t \left[ L_t^{CLIP}(\theta) - c_1 L_t^{\text{critic}}(\theta) \right]
\end{align}
}
$$


In [None]:
def GAE(data, gamma=0.99, llambda=0.99):
    """Append a Generalized Advantage Estimate to every transition, in place.

    ``data`` is a time-ordered list of mutable transitions laid out as
    ``[state, reward, action, done, conflicts, log_prob, value]``. The value
    estimate is at index 6; index 5 holds the log-probability.

        delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        A_t     = delta_t + gamma * llambda * (1 - done_t) * A_{t+1}

    The step after the last transition is bootstrapped with V = 0. Values are
    detached, so the advantages carry no autograd graph. Episode boundaries
    are only cut where ``done`` is truthy.

    Returns:
        ``data`` itself, each entry extended by its advantage.
    """
    rewards = [item[1] for item in data]
    values = [torch.as_tensor(item[6]).detach() for item in data]
    dones = [bool(item[3]) for item in data]

    advantages = [None] * len(data)
    nextValue = 0.0
    nextAdvantage = 0.0
    for t in reversed(range(len(data))):
        notDone = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * nextValue * notDone - values[t]
        nextAdvantage = delta + gamma * llambda * notDone * nextAdvantage
        advantages[t] = nextAdvantage
        nextValue = values[t]

    for transition, advantage in zip(data, advantages):
        transition.append(advantage)
    return data

def criticLoss(advantages = None,values = None,newValues = None):
    """Mean squared error between value predictions and ``V_target = A_t + V(s_t)``.

    Args:
        advantages: sequence of advantage estimates ``A_t``.
        values: sequence of value estimates used to build the target.
        newValues: optional predictions to train, e.g. the current network's
            re-evaluation. Defaults to ``values``.

    The target is detached. Otherwise, with ``newValues`` omitted, the
    prediction would cancel inside ``(V - (A + V))**2`` and no gradient would
    reach the critic.

    Returns:
        The scalar loss tensor, or None when ``advantages`` or ``values`` is None.
    """
    if advantages is None or values is None:
        return None
    predictions = values if newValues is None else newValues
    losses = []
    for adv, base, pred in zip(advantages, values, predictions):
        vTarget = (torch.as_tensor(adv) + torch.as_tensor(base)).detach()
        losses.append((torch.as_tensor(pred) - vTarget) ** 2)
    return torch.mean(torch.stack(losses))



In [None]:
# Debug aid: autograd records forward-pass tracebacks so backward errors name the offending op.
# PyTorch documents this mode as debug-only because it slows execution; disable it for real training runs.
torch.autograd.set_detect_anomaly(mode=True)
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

class Agent:
    """PPO trainer for ``network`` on the sudoku ``Env``.

    Each epoch:
    - roll out ``totalFrame`` transitions with ``collector``;
    - attach GAE advantages;
    - run several clipped-PPO updates on one minibatch;
    - log conflicts, loss and reward to TensorBoard under ./data/.
    """
    def __init__(self):
        self.network = network()
        self.env = Env()
        self.totalFrame = 20
        self.batchSize = 20
        self.epochs = 100
        self.epsilon = 0.2  # PPO clipping range
        self.c1 = 0.5 # Weight of the critic loss in the total loss 
        self.memory = collector(self.totalFrame,self.batchSize )
        # Roll out with the network being optimized; otherwise the collector
        # samples from its own network, which is never trained.
        self.memory.network = self.network
        self.valueLossfunction = criticLoss  # not used by learn(); kept for code that reads it
        self.writter = SummaryWriter("./data/")

    def save(self):
        """Write the network's state_dict to ./data/policy.pth."""
        torch.save(self.network.state_dict(),"./data/policy.pth")

    @staticmethod
    def _stackScalars(items) -> torch.Tensor:
        """Stack single-element tensors (or numbers) into a detached 1-D float tensor."""
        return torch.stack([torch.as_tensor(item, dtype=torch.float).detach().reshape(()) for item in items])

    def learn(self):
        """Run ``self.epochs`` rounds of rollout + PPO updates, then save the weights."""
        for epoch in tqdm(range(self.epochs),total=self.epochs):
            self.memory.clear()
            rollout = self.memory.rollout()
            self.memory.extend(GAE(rollout))
            for _ in range(1): # minibatches
                states,rewards,dones,actions,oldProbs,values,advantages,conf = self.memory.sample()
                # Rollout quantities are constants during the update; detaching them keeps
                # backward from reaching the stale rollout graph, so no retain_graph is needed.
                adv = self._stackScalars(advantages)
                oldValues = self._stackScalars(values)
                # Joint log-probability of the (x, y, value) triple = sum over the three heads.
                oldLogProbs = torch.stack([torch.as_tensor(lp).detach().float().sum() for lp in oldProbs])
                meanConflicts = self._stackScalars(conf).mean().floor()
                meanReward = self._stackScalars(rewards).mean()
                for _ in range(10): # optimisation passes over the minibatch
                    newLogProbs, newValues = [], []
                    for s,a in zip(states,actions):
                        s = torch.as_tensor(s)
                        probs, v = self.network.forward(s.reshape(1, *s.shape[-2:]).float())
                        dist = Categorical(probs)
                        newLogProbs.append(dist.log_prob(torch.as_tensor(a).reshape(-1)).sum())
                        newValues.append(v.reshape(()))
                    # One probability ratio per transition (element-wise with its own advantage).
                    ratio = torch.exp(torch.stack(newLogProbs) - oldLogProbs)
                    unclipped = ratio * adv
                    clipped = torch.clamp(ratio,(1-self.epsilon),(1+self.epsilon)) * adv
                    actorLoss = -torch.min(unclipped,clipped).mean()
                    # Critic regresses the current estimate onto V_target = A_t + V_old(s_t).
                    valueLoss = torch.mean((torch.stack(newValues) - (adv + oldValues)) ** 2)
                    totalLoss = actorLoss + self.c1*valueLoss
                    self.network.optimizer.zero_grad()
                    totalLoss.backward()
                    self.network.optimizer.step()
            self.writter.add_scalar("main/conflicts",meanConflicts.item(),epoch)
            self.writter.add_scalar("main/Loss",totalLoss.item(),epoch)
            self.writter.add_scalar("main/reward",meanReward.item(),epoch)
        self.writter.flush()
        self.save()


In [None]:
# Build the PPO agent and train it; `learn` saves the weights to ./data/policy.pth when it finishes.
z = Agent()
z.learn()

In [None]:
#%load_ext tensorboard
#%tensorboard --logdir "./"