<a href="https://colab.research.google.com/github/adeotti/sudoku-env/blob/main/M9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from IPython.display import clear_output
from puzzle import easyBoard,solution
import torch
import math

In [None]:
# (row, column) coordinates of every cell of easyBoard that holds 0,
# listed row by row; these are the cells the agent is allowed to fill.
modifiableCells = [
    (rowIndex, colIndex)
    for rowIndex, row in enumerate(easyBoard)
    for colIndex in range(9)
    if row[colIndex] == 0
]

In [None]:
def actionCheck(action:tuple):
    """Validate an (x, y, value) action before it is applied to the board.

    Args:
        action: sequence of three items; x is used as the row index and y as
            the column index of the 9x9 board, value is the digit to write.

    Raises:
        AssertionError: if x or y is not in 0..8, or value is not in 1..9.
    """
    x,y,value = action
    # Valid indices of a 9x9 board are 0..8; the previous bound (0..9) let the
    # out-of-range index 9 through. Raised explicitly (instead of `assert`) so
    # the check still runs under `python -O`, keeping the same exception type.
    if x not in range(9):
        raise AssertionError(f"row index {x} is outside the valid range 0..8")
    if y not in range(9):
        raise AssertionError(f"column index {y} is outside the valid range 0..8")
    if value not in range(1,10):
        raise AssertionError(f"value {value} is outside the valid range 1..9")

def region(index:tuple|list,board: torch.Tensor): # return the region (row,column,block) of a cell
    x,y = index
    xlist = board[x].tolist()
    xlist.pop(y)

    ylist = [element[y].tolist() for element in board]
    ylist.pop(x)

    #block
    n = int(math.sqrt(9))
    ix,iy = (x//n)* n , (y//n)* n
    block = torch.flatten(board[ix:ix+n , iy:iy+n]).tolist()
    local_row = x - ix
    local_col = y - iy
    action_index = local_row * n + local_col
    block_ = [num for idx, num in enumerate(block) if idx != action_index]

    #output
    Region = [xlist,ylist,block_]
    Region = [item for sublist in Region for item in sublist]
    return Region


class solver: 
    def __init__(self,state:torch.Tensor,modCells:list):
        self.board = state
        self.solution = solution
        self.modCells = modCells
        self.maxStep = len(modCells)*3
         
    def domain(self,idx:tuple|list) -> list :
        Region = region(idx,self.board)
        Region = set([item for item in Region if item != 0]) 

        domain_ = set(range(1,10)) 
        TrueDomain = list(domain_ - Region)
        return TrueDomain
    
    def collector(self):
        queu = []
        for element in self.modCells:
            queu.append({element : self.domain(element)})
        return queu
    
    def isSolvable(self) -> bool: 
        count = 0
        while True:
            self.__init__(self.board,self.modCells)
            data = self.collector()
            for dictt in data:
                for k,v in dictt.items():
                    if len(v) == 1:
                        self.board[k] = v[0]
            count+=1
            if len(data) == 0:
                break
            else:
                if count > self.maxStep:
                    break

        diff = (self.board == solution)
        diff = (diff == True).sum().item()
    
        if diff == solution.numel(): # if all True cells = 81 :
            return True
        else:
            return False
        

In [None]:
class Env:
    """Sudoku environment: validates an action, scores it and applies it to the board."""

    def __init__(self):
        # private copy, so removing solved cells does not touch the module-level list
        self.modifiableCells = modifiableCells[:]

    def step(self,action : tuple|list,state:torch.Tensor):
        """Apply ``action`` = (x, y, value) to ``state`` when its reward is positive.

        ``state`` is modified in place and is the very object returned.

        Returns:
            A list ``[state, reward, done, action]`` where reward is a float
            tensor of one element, done is a one-element tensor telling
            whether ``state`` equals ``solution``, and action is the action
            wrapped in a tensor.
        """
        actionCheck(action)
        self.action = action
        row,col,digit = action
        gained = self.rewardFunction(action,state)
        if gained > 0:
            state[row][col] = digit
        solved = torch.equal(solution,state)
        rewardTensor = torch.tensor([gained],dtype=torch.float)
        doneTensor = torch.tensor([solved])
        actionTensor = torch.tensor([action])
        return [state, rewardTensor, doneTensor, actionTensor]

    def rewardFunction(self,action:tuple|list,board:torch.Tensor):
        """Score an action by checking whether the board stays solvable after it.

        The digit is written into a clone of ``board``, so the caller's tensor
        is left untouched. When the solver can still complete the board, the
        cell (x, y) is removed from ``self.modifiableCells``.

        Returns:
            0 when (x, y) is not a modifiable cell; ``mismatches + 10`` when
            the board is still solvable; ``-mismatches`` otherwise, where
            mismatches is the number of cells of the trial board that differ
            from ``solution``.
        """
        row,col,digit = action
        trial = board.clone()
        remaining = self.modifiableCells.copy()
        trial[row][col] = digit

        if (row,col) not in remaining:
            return 0

        remaining.remove((row,col))
        checker = solver(trial.clone(),remaining)

        # number of cells where the trial board and the solution disagree
        mismatches = (trial != solution).sum().item()

        if not checker.isSolvable():
            return -mismatches

        self.modifiableCells.remove((row,col))
        return mismatches + 10
    

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

# Used by Actor and Critic as the number of output channels of each conv layer.
batchSize = 1
# Learning rate handed to the Adam optimizers built in Actor and Critic.
lr = 0.001

class Mask:
  """Blocks one entry of the actor output so the softmax gives it probability 0.

  The blocked entry is index 0 of the last row of the last leading slice,
  which keeps the sampled value in [x,y,value] from being 0.
  """
  def __init__(self):
    self.newValue = float("-inf")

  def apply(self,tensor : torch.FloatTensor):
    """Return a copy of ``tensor`` with the entry [-1,-1,0] replaced by -inf.

    The boolean mask that was used is kept in ``self.mask``.
    """
    self.mask = blocked = torch.zeros_like(tensor,dtype=torch.bool)
    blocked[-1,-1,0] = True
    return tensor.masked_fill(mask=blocked,value=self.newValue)

class Actor(nn.Module):
  """Policy network: maps a board to three probability rows of 9 entries each.

  NOTE(review): ``inputt`` is created but never called in ``forward``
  (Critic does call its ``input`` layer) - confirm whether that is intended.
  NOTE(review): each row has 9 entries and Mask blocks entry 0 of the last
  row, so an index sampled from that row lies in 1..8 - confirm whether the
  value 9 is meant to be reachable.
  """
  def __init__(self):
    super().__init__()
    self.batchsize = batchSize
    self.action_dist = 27
    self.action_spec = (3,9)
    self.mask = Mask()

    kernel = (3,3)
    self.inputt = nn.LazyLinear(9)
    self.conv1 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv2 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv3 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv4 = nn.LazyConv2d(self.batchsize,kernel)
    self.output = nn.LazyLinear(self.action_dist)

    self.optimizer = Adam(self.parameters(),lr = lr)

  def forward(self,x:torch.Tensor):
    """Return the softmax (over the last axis) of the masked network output.

    ``x`` must have shape [1,9,9]; any other shape gets one leading axis
    added and is then required to be [1,9,9].
    """
    expected = torch.Size([1,9,9])
    if x.shape != expected:
      x = x.unsqueeze(0)
    assert x.shape == expected

    hidden = F.relu(self.conv2(self.conv1(x)))
    hidden = F.relu(self.conv4(self.conv3(hidden)))
    scores = F.relu(self.output(torch.flatten(hidden,1,2)))
    # split the 27 scores into the (3,9) layout of action_spec
    scores = torch.unflatten(scores,-1,(self.action_spec))
    return F.softmax(self.mask.apply(scores),-1)

class Critic(nn.Module):
  """Value network: a linear layer followed by four 3x3 convolutions."""
  def __init__(self):
    super().__init__()
    self.batchsize = batchSize
    kernel = (3,3)
    self.input = nn.LazyLinear(9)
    self.conv1 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv2 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv3 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv4 = nn.LazyConv2d(self.batchsize,kernel)

    self.optimizer = Adam(self.parameters(),lr = lr)

  def forward(self,x):
    """Return the flattened output of the last convolution.

    ``x`` must have shape [1,9,9]; any other shape gets one leading axis
    added and is then required to be [1,9,9].
    """
    expected = torch.Size([1,9,9])
    if x.shape != expected:
      x = x.unsqueeze(0)
    assert x.shape == expected

    hidden = F.relu(self.conv1(self.input(x)))
    hidden = F.relu(self.conv3(self.conv2(hidden)))
    out = self.conv4(hidden) # -> shape [1,1,1]
    return out.flatten()
  
# Push one random [1,9,9] input through a throwaway instance of each network
# (a quick check that both forward passes run), then clear the cell output.
for network in (Actor, Critic):
  network().forward(torch.rand((1,9,9),dtype=torch.float))
clear_output()


In [None]:
import random
from torch.distributions import Categorical
import sys

class collector:
    """Rollout buffer: plays the environment and hands out minibatches.

    Intended call order per epoch: ``rollout`` -> (advantage computation on
    the returned list) -> ``extend`` -> ``sample`` repeatedly.
    """
    def __init__(self,totalFrame,batchSize):
        """
        Args:
            totalFrame: number of environment steps collected by one rollout.
            batchSize: number of transitions returned by one ``sample`` call;
                must divide ``totalFrame``.
        """
        assert totalFrame % batchSize == 0 , f"TotalFrame / batchSize should yield 0"
        assert totalFrame < len(modifiableCells)*3 ,f"The memory lenght should be less than an episodes"

        self.state = easyBoard.clone()
        self.totalFrame = totalFrame
        self.batchSize = batchSize
        self.env = Env()
        self.actor = Actor()
        self.critic = Critic()

        self.pointer = 0
        self.updatedData = []
        self.valueDAta = []
        self.envData = []
        self.networkData = []

    def rollout(self):
        """Play ``totalFrame`` steps and return them in chronological order.

        Returns:
            ``self.envData``: one list per step laid out as
            [state, reward, done, action, oldLogProb, value], where state is
            a snapshot of the board the action was sampled from.
        """
        self.clear() # clearing before each epochs
        for _ in range(self.totalFrame):
            # Snapshot taken before the step: env.step writes into self.state
            # in place, so storing self.state itself would make every stored
            # transition show the same (final) board.
            observation = self.state.clone()

            value = self.critic.forward(observation)
            dist = Categorical(self.actor.forward(observation))
            sample = dist.sample()
            logProb = dist.log_prob(sample)
            action = sample.tolist()[0]

            dataPoint = self.env.step(action,self.state)
            dataPoint[0] = observation
            dataPoint.append(logProb)
            dataPoint.append(value)

            self.envData.append(dataPoint)
            self.networkData.append(logProb)
            self.valueDAta.append(value)

        # No shuffling here: the advantage recursion needs consecutive steps
        # to stay in time order. Shuffling is done in extend().
        return self.envData  #[states,rewards,dones,actions,oldProbs,values]

    def extend(self,args):
        """Store the transitions (advantage already appended) in random order.

        The list passed in is copied first, so the caller's ordering is kept.
        """
        self.updatedData = list(args)
        random.shuffle(self.updatedData)

    def sample(self):
        """Return the next ``batchSize`` stored transitions, split by field.

        Returns:
            Seven lists: states, rewards, dones, actions, oldProbs, values,
            advantages. They are empty once the stored data is exhausted.
        """
        output = self.updatedData[self.pointer : self.pointer + self.batchSize]
        self.pointer += self.batchSize

        states = [item[0] for item in output]
        rewards = [item[1] for item in output]
        dones =  [item[2] for item in output]
        actions = [item[3] for item in output]
        oldProbs = [item[4] for item in output]
        values = [item[5] for item in output]
        advanatages = [item[6] for item in output]

        return states,rewards,dones,actions,oldProbs,values,advanatages

    def clear(self):
        """Reset the read pointer and drop everything collected so far."""
        self.pointer = 0
        self.updatedData = []
        self.networkData = []
        self.envData = []
        self.valueDAta = []


$$
 
{\large

\begin{align}

&\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) \\
& \hat{A_t} = \delta_t + (\gamma \lambda)\delta_{t+1} + (\gamma \lambda)^2\delta_{t+2} + \dots\\
&GAE : \hat{A_t} = \delta_t + (\gamma \lambda)\hat{A}_{t+1} \\[3em]

&L_\text{critic} = \frac{1}{N} \sum_t \left( V_t - V_t^\text{target} \right)^2 \\
& V_t^\text{target} = \hat{A_t} + V(s_t) \\[3em]

&L^{CPI} = \mathbb{\hat{E_{t}}}
\begin{bmatrix}
\frac{\pi_{\theta}(a_t | s_t)}{\pi_{\theta old}(a_t | s_t)}\hat{A_t}  \\
\end{bmatrix} \\

&\hspace{2em} = \mathbb{\hat{E_t}}
\begin{bmatrix}
r_t(\theta)\hat{A_t}  \\
\end{bmatrix} \\[1em]

&L^{CLIP} = \mathbb{\hat{E_t}}
\begin{bmatrix}
\min\left(r_t(\theta) \hat{A_t},\ \text{clip}(r_t(\theta),1-\epsilon,1+\epsilon) \hat{A_t}\right)
\end{bmatrix}\\[3em]

&Total Loss : L_t(\theta) = \mathbb{\hat{E_t}}
\begin{bmatrix}
L_t^{CLIP}(\theta) - c_1L_t^{critic}(\theta)
\end{bmatrix}\\

\end{align}
}
$$


In [None]:
def GAE(data):
    """Append a generalized advantage estimate to every transition of ``data``.

    Args:
        data: list of transitions in chronological order; each transition is
            a list whose item 1 is the reward, item 2 the done flag and
            item 5 the value estimate of the state.
            Assumes each done flag converts with ``float()`` (a bool or a
            one-element tensor).

    Returns:
        The same ``data`` list, each transition extended in place with its
        advantage as a new last item.
    """
    gamma = 0.9
    llambda = 0.9

    rewards = [item[1] for item in data]
    dones = [item[2] for item in data]
    values = [item[5] for item in data]
    lastIndex = len(data) - 1

    # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    TDList = []
    for t, reward in enumerate(rewards):
        notDone = 1.0 - float(dones[t])
        # The last step has no stored successor: bootstrap from 0 rather than
        # from the constant 1 that `gamma` alone used to stand for. A finished
        # step does not bootstrap either.
        nextValue = values[t+1] if t < lastIndex else 0.0
        TDList.append(reward + gamma*nextValue*notDone - values[t])

    # A_t = delta_t + (gamma * lambda) * A_{t+1}, accumulated from the end
    AtList = []
    a_t = 0
    for t in range(lastIndex, -1, -1):
        notDone = 1.0 - float(dones[t])
        a_t = TDList[t] + (gamma * llambda) * notDone * a_t
        AtList.append(a_t)
    AtList.reverse()

    for sub,v in zip(data,AtList):
        sub.append(v)
    return data

def criticLoss(advantages = None,values = None):
    """Mean squared error between value estimates and their targets.

    The target of each estimate is ``advantage + value``, taken as a constant.

    Args:
        advantages: sequence of advantage tensors.
        values: sequence of value tensors, indexed like ``advantages``.

    Returns:
        A scalar tensor, or None when either argument is missing.
    """
    if advantages is None or values is None:
        return None

    losses = []
    for index, advantage in enumerate(advantages):
        value = values[index]
        # Detached so the target is a constant: otherwise value cancels out
        # of (value - (advantage + value)) and receives no gradient.
        vTarget = (advantage + value).detach()
        losses.append((value - vTarget)**2)
    return torch.mean(torch.stack(losses))


In [None]:
# Debugging aid: switches autograd anomaly detection on for the whole session.
# NOTE(review): PyTorch documents this mode as slowing execution down -
# consider turning it off once training runs cleanly.
torch.autograd.set_detect_anomaly(True)
from tqdm import tqdm
from torchviz import make_dot
from torch.utils.tensorboard import SummaryWriter

class Agent:
    """Trains the actor and the critic with a clipped-surrogate (PPO style) loss."""
    def __init__(self):
        self.totalFrame = 100
        self.batchSize = 20
        self.epochs = 50
        self.epsilon = 0.2
        self.c1 = 0.5 # the weight of the critic loss in the Total loss formula

        self.memory = collector(self.totalFrame,self.batchSize )
        # Train the very networks that generate the rollouts. With separate
        # instances the rollouts would come from a policy that never changes
        # and the critic optimizer would step a network that has no gradient.
        self.policy = self.memory.actor
        self.value = self.memory.critic

        # Kept for backward compatibility; learn() computes the value loss
        # itself against detached targets.
        self.valueLossfunction = criticLoss
        self.debug = False
        self.writter = SummaryWriter("tdata/")

    def save(self):
        """Write the policy weights to tdata/policy.pth."""
        torch.save(self.policy.state_dict(),"tdata/policy.pth")

    def learn(self):
        """Run ``self.epochs`` rounds of rollout, advantage estimation and updates.

        Every epoch collects ``totalFrame`` steps and updates both networks
        once per minibatch. A minibatch that contains a finished step saves
        the policy and ends the epoch. The last loss and mean reward of each
        epoch are logged under the epoch index.
        """
        for epoch in tqdm(range(self.epochs),total=self.epochs):
            self.memory.clear()
            rollout = self.memory.rollout()
            self.memory.extend(GAE(rollout))

            # stay None when the epoch ends before any update
            totalLoss = None
            meanReward = None

            for _ in range(self.totalFrame//self.batchSize):
                states,rewards,dones,actions,oldProbs,values,advantages = self.memory.sample()
                meanReward = torch.mean(torch.stack(rewards))
                if torch.any(torch.flatten(torch.stack(dones))).item():
                    self.save()
                    print("weights saved")
                    break

                # Everything produced during the rollout is treated as a
                # constant. Its graph goes through parameters that
                # optimizer.step() changes in place, so it cannot be
                # back-propagated through after the first update.
                oldLogProbs = torch.stack(oldProbs).detach()
                advantage = torch.stack(advantages).detach()
                valueTarget = advantage + torch.stack(values).detach()

                # fresh forward passes with the current parameters
                newLogProbs = torch.stack([
                    Categorical(self.policy.forward(s)).log_prob(a)
                    for s,a in zip(states,actions)
                ])
                newValues = torch.stack([self.value.forward(s) for s in states])

                valueLoss = torch.mean((newValues.reshape(-1) - valueTarget.reshape(-1))**2)

                ratio = torch.exp(newLogProbs - oldLogProbs)
                # One advantage per transition, broadcast over that transition's
                # own ratio entries. Multiplying the whole batch of ratios by
                # each advantage in turn paired samples with foreign advantages.
                weight = advantage.reshape(-1, *([1]*(ratio.dim()-1)))
                clipped = torch.clamp(ratio,(1-self.epsilon),(1+self.epsilon))
                actorLoss = -torch.mean(torch.min(ratio*weight,clipped*weight))

                totalLoss = actorLoss + self.c1*valueLoss # actorLoss + (weight critic loss  * critic loss)

                if self.debug: # print the computation graph, for debugging purposes
                    graph = make_dot(totalLoss,params=dict(list(self.policy.named_parameters())))
                    graph.render()
                    self.debug = False

                self.policy.optimizer.zero_grad()
                self.value.optimizer.zero_grad()
                totalLoss.backward()
                self.policy.optimizer.step()
                self.value.optimizer.step()

            # log against the epoch index so the points form a curve
            if totalLoss is not None:
                self.writter.add_scalar("main/Loss",totalLoss.item(),epoch)
            if meanReward is not None:
                self.writter.add_scalar("main/reward",meanReward.item(),epoch)

        self.save()
        

In [None]:
# Build the agent and start training; learn() saves the policy weights when it finishes.
z = Agent()
z.learn()