<a href="https://colab.research.google.com/github/adeotti/sudoku-env/blob/main/M9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from IPython.display import clear_output
from puzzle import easyBoard,solution
import torch
import math

In [None]:
# (row, column) coordinates of every cell holding 0 in the starting board,
# listed row by row.
modifiableCells = [
    (rowIdx, colIdx)
    for rowIdx, rowValues in enumerate(easyBoard)
    for colIdx in range(9)
    if rowValues[colIdx] == 0
]

In [None]:
def actionCheck(action:tuple):
    """Validate an ``(x, y, value)`` action for the 9x9 board.

    Args:
        action: three-item sequence ``(x, y, value)`` where ``x`` and ``y`` are
            0-based cell indices and ``value`` is the digit to write.

    Raises:
        AssertionError: if ``x`` or ``y`` is outside 0..8, or ``value`` is
            outside 1..9.
    """
    x,y,value = action
    # Valid indices on a 9x9 board are 0..8. The former bound range(0,10) let
    # index 9 through, which only failed later as an out-of-bounds write.
    assert x in range(0,9), f"{x} is too big or too small idk..."
    assert y in range(0,9), f"{y} is too big or too small idk..."
    assert value in range(1,10), f"{value} is too big or too small idk..."

def region(index:tuple|list,board: torch.Tensor): # return the region (row,column,block) of a cell
    x,y = index
    xlist = board[x].tolist()
    xlist.pop(y)

    ylist = [element[y].tolist() for element in board]
    ylist.pop(x)

    #block
    n = int(math.sqrt(9))
    ix,iy = (x//n)* n , (y//n)* n
    block = torch.flatten(board[ix:ix+n , iy:iy+n]).tolist()
    local_row = x - ix
    local_col = y - iy
    action_index = local_row * n + local_col
    block_ = [num for idx, num in enumerate(block) if idx != action_index]

    #output
    Region = [xlist,ylist,block_]
    Region = [item for sublist in Region for item in sublist]
    return Region


class solver:
    """Single-candidate constraint propagation over a Sudoku board.

    The solver works on its own copy of the board. It repeatedly fills every
    tracked cell whose candidate set has exactly one value, then reports
    whether the result equals the reference ``solution``.
    """

    def __init__(self,state:torch.Tensor,modCells:list):
        """
        Args:
            state: board tensor; it is cloned, so the caller's tensor is never
                written to.
            modCells: ``(row, column)`` tuples of the cells the solver may fill.
        """
        self.board = state.clone()
        self.solution = solution.clone()
        self.modCells = modCells
        # Upper bound on the number of propagation passes.
        self.maxStep = len(modCells)*3

    def domain(self,idx:tuple|list) -> list :
        """Return the digits 1..9 not used by any peer of cell ``idx``."""
        used = {item for item in region(idx,self.board) if item != 0}
        return list(set(range(1,10)) - used)

    def collector(self):
        """Return one ``{cell: candidates}`` dict per tracked cell."""
        return [{cell : self.domain(cell)} for cell in self.modCells]

    def isSolvable(self) -> bool:
        """Propagate single-candidate cells and compare with the solution.

        At most ``maxStep + 1`` passes are made. Each pass computes all
        candidate sets first and only then writes the forced values. The loop
        stops as soon as a pass leaves the board unchanged: further passes
        would recompute the same candidates, so the final board is the same as
        if every pass had been run.

        Returns:
            True when every cell of the propagated board equals ``solution``.
        """
        for _ in range(self.maxStep + 1):
            progressed = False
            for entry in self.collector():
                for cell,candidates in entry.items():
                    if len(candidates) != 1:
                        continue
                    if self.board[cell] != candidates[0]:
                        self.board[cell] = candidates[0]
                        progressed = True
            if not progressed:
                break

        matches = (self.board == self.solution).sum().item()
        return matches == self.solution.numel()
        

In [None]:
class Env:
    """Sudoku environment: scores an action and applies it when it is rewarded."""

    def __init__(self):
        # Private copy, so accepted moves do not shrink the module-level list.
        self.modifiableCells = modifiableCells.copy()

    def step(self,action : tuple|list,state:torch.Tensor):
        """Score ``action`` and, if the reward is positive, write it into ``state``.

        Args:
            action: ``(x, y, value)``; validated with ``actionCheck``.
            state: board tensor, optionally with a leading dimension of size 1.
                It is modified in place when the action is accepted.

        Returns:
            ``[state, reward, done, action]`` where the last three items are
            one-element tensors and ``state`` is the (possibly updated) input.
        """
        actionCheck(action)
        self.action = action
        x,y,value = self.action
        reward = self.rewardFunction(action,state)

        # squeeze(0) returns a view, so writing to it updates ``state`` itself.
        board = state.squeeze(0)
        if reward > 0:
            board[x][y] = value
        # torch.equal is False whenever the shapes differ, so the comparison is
        # made on the squeezed board; against a (1,9,9) state it could never
        # report completion.
        done = torch.equal(solution,board)
        return [
                state,
                torch.tensor([reward]),
                torch.tensor([done]),
                torch.tensor([action])
        ]

    def rewardFunction(self,action:tuple|list,board:torch.Tensor):
        """Return the reward for writing ``value`` at ``(x, y)``.

        The move is tried on a copy of ``board``, so the caller's tensor is not
        touched. Outcomes:

        * ``(x, y)`` is not a modifiable cell: reward ``0``.
        * the solver reaches the solution from the trial board:
          ``round(0.1 * mismatches + 0.5, 2)``, and ``(x, y)`` is removed from
          ``self.modifiableCells``.
        * otherwise: ``-0.1 * mismatches``.

        ``mismatches`` is the number of cells of the trial board that differ
        from ``solution``.
        """
        x,y,value = action
        if (x,y) not in self.modifiableCells:
            return 0

        trial = board.squeeze(0).clone()
        trial[x][y] = value

        remaining = self.modifiableCells.copy()
        remaining.remove((x,y))

        conflicts = (trial != solution).sum().item() * 0.1

        # solver clones the board it receives, so ``trial`` stays as written.
        if solver(trial,remaining).isSolvable():
            self.modifiableCells.remove((x,y))
            return round((conflicts + 0.5),2)
        return -conflicts
    

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

# NOTE(review): despite its name, batchSize is passed as the first positional
# argument of every nn.LazyConv2d in Actor and Critic - confirm this is intended.
batchSize = 1
# Learning rate handed to the Adam optimizers of Actor and Critic.
lr = 0.001

class Mask:
  """Forbid class 0 in the last row of the action logits.

  The last row is the ``value`` component of ``[x, y, value]``; filling its
  first entry with ``-inf`` makes the following softmax give it probability 0,
  so ``value`` can never be sampled as 0.
  """
  def __init__(self):
    self.newValue = -float("inf")

  def apply(self,tensor : torch.FloatTensor):
    """Return a copy of ``tensor`` with ``[..., -1, 0]`` set to ``-inf``.

    Args:
      tensor: logits whose last two dimensions are (component, class).

    Returns:
      A new tensor; the input is left untouched.
    """
    self.mask = torch.zeros_like(tensor,dtype=torch.bool)
    # Ellipsis indexing masks every leading (batch) entry; indexing with
    # [-1,-1,0] only masked the last one.
    self.mask[...,-1,0] = True
    return tensor.masked_fill(mask=self.mask,value=self.newValue)

class Actor(nn.Module):
  """Policy network producing three categorical distributions over 9 classes.

  The 27 outputs are reshaped to ``action_spec`` = (3, 9), one row per action
  component, and normalised with a softmax over the last dimension.

  NOTE(review): classes are indexed 0..8 and class 0 of the last row is masked
  out, so it looks like the digit 9 can never be proposed - confirm.
  NOTE(review): the optimizer is created before the lazy layers have seen any
  input - check this against the PyTorch lazy-module guidance.
  """
  def __init__(self):
    super().__init__()
    self.batchsize = batchSize
    self.action_dist = 27
    self.action_spec = (3,9)
    self.mask = Mask()

    kernel = (3,3)
    self.inputt = nn.LazyLinear(9)
    self.conv1 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv2 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv3 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv4 = nn.LazyConv2d(self.batchsize,kernel)
    self.output = nn.LazyLinear(self.action_dist)

    self.optimizer = Adam(self.parameters(),lr = lr)

  def forward(self,x):
    """Map a board tensor to action probabilities of shape ``(..., 3, 9)``."""
    hidden = F.relu(self.inputt(x))
    # ReLU is applied after every second convolution only.
    hidden = F.relu(self.conv2(self.conv1(hidden)))
    hidden = F.relu(self.conv4(self.conv3(hidden)))
    logits = F.relu(self.output(hidden.flatten(1,2)))
    logits = self.mask.apply(logits.unflatten(-1,self.action_spec))
    return logits.softmax(-1)

class Critic(nn.Module):
  """Value network: a linear layer followed by four 3x3 convolutions.

  NOTE(review): the optimizer is created before the lazy layers have seen any
  input - check this against the PyTorch lazy-module guidance.
  """
  def __init__(self):
    super().__init__()
    self.batchsize = batchSize
    kernel = (3,3)
    self.input = nn.LazyLinear(9)
    self.conv1 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv2 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv3 = nn.LazyConv2d(self.batchsize,kernel)
    self.conv4 = nn.LazyConv2d(self.batchsize,kernel)

    self.optimizer = Adam(self.parameters(),lr = lr)

  def forward(self,x):
    """Return the value estimate for ``x`` as a flattened 1-D tensor."""
    features = F.relu(self.conv1(self.input(x)))
    features = F.relu(self.conv3(self.conv2(features)))
    # For a (1,9,9) input the last convolution yields shape [1,1,1].
    return self.conv4(features).flatten()


In [None]:
# Dry run: push one random (1,9,9) input through a fresh Actor and a fresh
# Critic to check that both forward passes execute.
# These two instances are not reused: collector and Agent build their own.
a = Actor()
a(torch.rand((1,9,9),dtype=torch.float))
b = Critic()
b(torch.rand((1,9,9),dtype=torch.float))
# Wipe whatever the calls above printed in the cell.
clear_output()

In [None]:
import random
from torch.distributions import Categorical

class collector:
    """Runs the policy in the environment and stores the resulting transitions.

    Each stored record is the list
    ``[state, reward, done, action, logProb, value]``.
    """

    def __init__(self,totalFrame,batchSize):
        """
        Args:
            totalFrame: number of environment steps per rollout.
            batchSize: number of records handed out by each ``sample()`` call;
                must divide ``totalFrame``.
        """
        assert totalFrame % batchSize == 0 , "TotalFrame / batchSize should yield 0"
        assert totalFrame < len(modifiableCells)*3 ,"The memory lenght should be less than an episodes"

        # Leading dimension of size 1 added in front of the board.
        self.state = easyBoard.clone().unsqueeze(0)
        self.totalFrame = totalFrame
        self.batchSize = batchSize
        self.env = Env()
        self.actor = Actor()
        self.critic = Critic()
        self.pointer = 0
        self.reward = []
        self.valueDAta = []
        self.envData = []
        self.networkData = []

    def rollout(self):
        """Collect ``totalFrame`` transitions from the current board.

        The networks are evaluated on a snapshot of the board taken before the
        step, and that snapshot is what gets stored as the record's state.
        ``Env.step`` mutates ``self.state`` in place, so storing the live
        tensor would make every record point at the same, final board and
        would invalidate tensors autograd saved during the forward passes.

        Returns:
            ``(states, rewards, dones, actions, oldProbs, values)``: six lists
            in the order the steps were taken, which is the order advantage
            estimation needs. ``self.envData`` is shuffled afterwards, so
            ``sample()`` still hands out randomly ordered mini-batches.
        """
        self.clear()
        for _ in range(self.totalFrame):
            observation = self.state.clone()

            value = self.critic.forward(observation)
            dist = Categorical(self.actor.forward(observation))
            sample = dist.sample()
            logProb = dist.log_prob(sample)
            action = sample.tolist()[0]

            record = self.env.step(action,self.state)
            record[0] = observation # state in which the action was chosen
            record.extend((logProb,value))

            self.envData.append(record)
            self.networkData.append(logProb)
            self.valueDAta.append(value)

        # Extract the columns while the records are still in step order ...
        columns = tuple([record[field] for record in self.envData] for field in range(6))
        # ... and only then shuffle the buffer used by sample().
        random.shuffle(self.envData)
        return columns

    def sample(self):
        """Return the next ``batchSize`` records of the shuffled buffer."""
        output = self.envData[self.pointer : self.pointer + self.batchSize]
        self.pointer += self.batchSize
        return output

    def clear(self):
        """Empty the buffers and rewind the sampling pointer."""
        self.networkData = []
        self.envData = []
        self.valueDAta = []
        # Without this, sample() would read past the end of every new rollout.
        self.pointer = 0


In [None]:
# Smoke test: a collector with totalFrame=4 and batchSize=2, and one rollout.
# NOTE: `env` is a collector here, not an Env instance.
env = collector(4,2)
# rollout() returns the tuple (states, rewards, dones, actions, oldProbs, values).
data = env.rollout()
data

$$
 
{\large

\begin{align}

&\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) \\
& \hat{A_t} = \delta_t + (\gamma \lambda)\delta_{t+1} + (\gamma \lambda)^2\delta_{t+2} + \dots\\
&GAE : \hat{A_t} = \delta_t + (\gamma \lambda)\hat{A}_{t+1} \\[3em]

&L_\text{critic} = \frac{1}{N} \sum_t \left( V_t - V_t^\text{target} \right)^2 \\
& V_t^\text{target} = \hat{A_t} + V(s_t) \\[3em]

&L^{CPI} = \mathbb{\hat{E_{t}}}
\begin{bmatrix}
\frac{\pi_{\theta}(a_t | s_t)}{\pi_{\theta old}(a_t | s_t)}\hat{A_t}  \\
\end{bmatrix} \\

&\hspace{2em} = \mathbb{\hat{E_t}}
\begin{bmatrix}
r_t(\theta)\hat{A_t}  \\
\end{bmatrix} \\[1em]

&L^{CLIP} = \mathbb{\hat{E_t}}
\begin{bmatrix}
\min\left(r_t(\theta) \hat{A_t},\ \operatorname{clip}(r_t(\theta),1-\epsilon,1+\epsilon) \hat{A_t}\right)
\end{bmatrix}\\[3em]

&\text{Total Loss} : L_t(\theta) = \mathbb{\hat{E_t}}
\begin{bmatrix}
L_t^{CLIP}(\theta) - c_1L_t^{critic}(\theta)
\end{bmatrix}\\

\end{align}
}
$$


In [None]:
def GAE(rewards,values,gamma = 0.9,llambda = 0.9):
    """Generalised advantage estimation over one trajectory.

    Computes ``delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)`` and then
    ``A_t = delta_t + gamma * llambda * A_{t+1}``, walking backwards from the
    last step.

    Args:
        rewards: per-step rewards, in the order the steps were taken.
        values: value estimates; ``values[t]`` belongs to ``rewards[t]``. When
            the sequence has no entry for ``t + 1`` the next value is taken as
            0 (the previous fallback added a bare ``gamma``, i.e. assumed a
            next value of 1).
        gamma: discount factor.
        llambda: GAE smoothing factor.

    Returns:
        List of advantages, one per reward, in trajectory order.
    """
    TDList = []
    for t,reward in enumerate(rewards):
        nextValue = values[t + 1] if t + 1 < len(values) else 0
        TDList.append(reward + gamma*nextValue - values[t])

    AtList = []
    a_t = 0
    for td in reversed(TDList):
        a_t = td + (gamma * llambda) * a_t
        AtList.append(a_t)
    AtList.reverse()
    return AtList


def criticLoss(advantages = None,values = None):
    """Mean squared error between value estimates and their targets.

    The target of step ``t`` is ``advantages[t] + values[t]``, detached so it
    acts as a constant. The loss is computed with tensor operations only, so it
    stays differentiable with respect to ``values``. The previous version cast
    each term to ``int`` and wrapped the mean in a new tensor, which truncated
    the loss and cut it off from the critic's parameters.

    Args:
        advantages: sequence of per-step advantages (tensors or numbers).
        values: sequence of per-step value estimates, at least as long as
            ``advantages``.

    Returns:
        Float tensor of shape ``(1,)``, or ``None`` when either argument is
        ``None``.

    Raises:
        ValueError: if ``advantages`` is empty.
    """
    if advantages is None or values is None:
        return None
    if len(advantages) == 0:
        raise ValueError("criticLoss needs at least one advantage/value pair")

    losses = []
    for index,advantage in enumerate(advantages):
        value = torch.as_tensor(values[index],dtype=torch.float)
        vTarget = (torch.as_tensor(advantage,dtype=torch.float) + value).detach()
        losses.append(((value - vTarget)**2).mean())
    return torch.stack(losses).mean().reshape(1)


In [None]:
#torch.autograd.set_detect_anomaly(True)
from tqdm import tqdm
import sys

class Agent:
    """Trainer using the clipped surrogate objective plus a value loss."""

    def __init__(self,epochs,epsilon):
        """
        Args:
            epochs: number of rollout/update iterations run by ``learn``.
            epsilon: clip range; probability ratios are clamped to
                ``[1 - epsilon, 1 + epsilon]``.
        """
        self.policy = Actor()
        self.value = Critic()
        self.memory = collector(totalFrame=4,batchSize=2)
        # Collect data with the networks that are being trained. The collector
        # otherwise keeps its own Actor/Critic, which no optimizer step here
        # would ever touch.
        self.memory.actor = self.policy
        self.memory.critic = self.value
        self.valueLossfunction = criticLoss
        self.epochs = epochs
        self.epsilon = epsilon
        self.c1 = 0.5 # the weight of the critic loss in the Total loss formula

    def learn(self):
        """Run ``epochs`` iterations of: collect a rollout, then update once."""
        for _ in tqdm(range(self.epochs),total=self.epochs):
            states,rewards,dones,actions,oldProbs,values = self.memory.rollout()

            # Rollout-time values are constants here: they only build targets.
            oldValues = [v.detach() for v in values]
            advantages = GAE(rewards,oldValues)

            # Fresh value estimates carry the gradient for the critic.
            newValues = [self.value.forward(s) for s in states]
            valueLoss = self.valueLossfunction(advantages,newValues)

            # Log-probability of each stored action under the current policy.
            newProbs = [
                Categorical(self.policy.forward(s)).log_prob(a)
                for s,a in zip(states,actions)
            ]

            # pi_new / pi_old, computed in log space. Old log-probs are
            # detached so the gradient flows through the new policy only.
            ratio = (torch.stack(newProbs) - torch.stack(oldProbs).detach()).exp()

            # One advantage per step, broadcast over that step's ratio entries.
            # The former loop multiplied every ratio by every advantage.
            advantage = torch.stack(advantages).detach()
            while advantage.dim() < ratio.dim():
                advantage = advantage.unsqueeze(-1)

            clipped = torch.clamp(ratio,(1-self.epsilon),(1+self.epsilon))
            actorLoss = torch.min(ratio*advantage,clipped*advantage).mean()

            # The surrogate is to be maximised, so it enters the minimised loss
            # with a minus sign; the critic loss is added with weight c1.
            totalLoss = -actorLoss + self.c1*valueLoss

            self.policy.optimizer.zero_grad()
            self.value.optimizer.zero_grad()
            totalLoss.backward()
            self.policy.optimizer.step()
            self.value.optimizer.step()

        self.memory.clear() # clear once training has finished
            

In [None]:
# Train for 4 rollout/update iterations with a clip range of 0.2.
z = Agent(epochs=4,epsilon=0.2)
z.learn()

In [None]:
# First field of the rollout tuple stored in `data`: the list of states.
data[0]

In [None]:
# collector.rollout() returns the column tuple
# (states, rewards, dones, actions, oldProbs, values), so each field is picked
# by position. Iterating over `data` and taking item[1] would instead grab the
# second entry of every column.
rewards = data[1] # list of tensors
values = data[5] # list of tensors
advantages = GAE(rewards,values) # list of tensors
CLoss = criticLoss(advantages,values) # tensor

In [None]:
# `data` is the column tuple returned by collector.rollout():
# index 0 holds the states, index 3 the actions.
actions = data[3]
state = data[0]
state_action = list(zip(state,actions))
newProb = []
# NOTE(review): this is a freshly initialised Actor, not the network that
# produced the stored log-probabilities (env.actor) - confirm that is intended.
a = Actor()

# Log-probability of each stored action under `a`.
for el in state_action:
    probs = a.forward(el[0])
    dist = Categorical(probs)

    nP = dist.log_prob(el[1])
    newProb.append(nP)

In [None]:
# Index 4 of the rollout tuple holds the log-probabilities recorded at
# collection time.
oldProbs = data[4]

In [None]:
# Both operands are log-probabilities, so both are exponentiated before the
# division (same expression as in Agent.learn).
ratio = torch.stack(newProb).exp()/torch.stack(oldProbs).exp()
epsilon = 0.2
actorLoss = []
# Pair each step's ratio with that step's advantage.
for stepRatio,stepAdvantage in zip(ratio,advantages):
    ratioAdvantage = stepRatio*stepAdvantage
    clippedRatio = torch.clamp(stepRatio,(1-epsilon),(1+epsilon))*stepAdvantage
    policyLoss = torch.min(ratioAdvantage,clippedRatio)
    actorLoss.append(policyLoss)

In [None]:
# Collapse the list of per-step surrogate terms into a single scalar
# (rebinds actorLoss from a list to a tensor).
actorLoss = torch.mean(torch.stack(actorLoss))