In [2]:
from interpreting_neurons_utils import *

In [4]:
import math
import random
import statistics
from dataclasses import dataclass, field
from random import choice
from typing import Any, Optional, Tuple

import numpy as np

In [None]:
@dataclass
class FuzzyTrainingArgs:
    """Hyper-parameters for fitting fuzzy-logic rules to MLP neuron activations.

    Construct with ``FuzzyTrainingArgs(run_number, layer, single_neuron=None)``.
    Every other setting is a dataclass field with a fixed default that can be
    overwritten on the instance after construction. Because all settings are
    real fields, the generated ``repr`` and ``==`` reflect the whole
    configuration. The previous version declared no fields, so any two
    instances compared equal.

    The derived fields ``length``, ``neurons_count``, ``run_name`` and
    ``full_run_name`` are computed once in ``__post_init__``. They are not
    recomputed if their inputs are changed later.
    """

    # Which layer, and which positions in a game sequence to probe.
    run_number: int
    layer: int
    single_neuron: Optional[int] = None  # index of one neuron to probe; None = all neurons

    # Rebalancing of positive vs. non-positive activation samples (single-neuron case).
    undersampling: bool = field(default=True, init=False)
    undersampling_factor: float = field(default=1, init=False)
    debug: bool = field(default=False, init=False)

    # Position slice [pos_start, pos_end) taken from each game sequence.
    pos_start: int = field(default=0, init=False)
    pos_end: int = field(default=59, init=False)
    max_epochs: int = field(default=1, init=False)

    # Rule model configuration.
    rules_count: int = field(default=10, init=False)
    # NOTE(review): presumably 64 board squares x 7 features per square — confirm.
    variables_count: int = field(default=64 * (3 + 2 + 2), init=False)
    manual: bool = field(default=False, init=False)
    fuzzy_and: str = field(default="mul", init=False)
    fuzzy_or: str = field(default="max", init=False)
    spacity_factor: float = field(default=10, init=False)
    only_0_1_factor: float = field(default=0, init=False)
    initialization: str = field(default="uniform", init=False)
    use_rule_weights: bool = field(default=True, init=False)
    manual_rules: Any = field(default=None, init=False)
    manual_weights: float = field(default=10, init=False)
    weight_decay_loss_func: str = field(default="l1", init=False)

    # Data / optimisation.
    num_games_train: int = field(default=250, init=False)
    batch_size: int = field(default=2, init=False)
    lr: float = field(default=1e-2, init=False)
    betas: Tuple[float, float] = field(default=(0.9, 0.99), init=False)
    wd: float = field(default=0.1, init=False)

    # wandb
    wandb_project: str = field(default="Train Fuzzy Logic Gates Test", init=False)

    # Derived in __post_init__.
    length: int = field(init=False)
    neurons_count: int = field(init=False)
    run_name: str = field(init=False)
    full_run_name: str = field(init=False)

    def __post_init__(self):
        """Fill in the fields that depend on the constructor arguments."""
        self.length = self.pos_end - self.pos_start
        # 2048 outputs when probing the whole MLP layer, otherwise a single neuron.
        self.neurons_count = 2048 if self.single_neuron is None else 1
        self.run_name = f"Fuzzy_{self.run_number}"
        self.full_run_name = f"Fuzzy_L{self.layer}_{self.run_number}"

In [None]:
class Trainer:
    # NOTE(review): no __init__ is visible; `self.args` (presumably a
    # FuzzyTrainingArgs) must be assigned by the caller. `get_resid` also relies
    # on the notebook globals `model`, `games_int`, `device`, `t`, `einops` and
    # `get_variables`.

    def get_resid(self):
        """Collect MLP inputs/outputs of the probed layer and build training data.

        Runs ``model`` over ``games_int[:, :-1]`` and caches two hooks of layer
        ``self.args.layer``: ``ln2.hook_normalized`` (the layer-normed residual)
        and ``mlp.hook_post`` (the MLP activations). It keeps positions
        ``[pos_start, pos_end)`` and flattens batch and position into one axis.

        If ``args.single_neuron`` is set, only that neuron's column is kept.
        If ``args.undersampling`` is also on, every sample with activation > 0
        is kept, plus a random subset of the remaining samples.
        ``int(count_positive * undersampling_factor)`` of them are kept, or all
        if there are fewer.

        Returns
        -------
        (variables, neuron_activations)
            ``variables`` is ``get_variables(resid_layer_norm, layer)``.
            ``neuron_activations`` holds the matching targets, row-aligned with it.
        """
        layer = self.args.layer
        ln_hook = f"blocks.{layer}.ln2.hook_normalized"
        post_hook = f"blocks.{layer}.mlp.hook_post"
        positions = slice(self.args.pos_start, self.args.pos_end)

        with t.inference_mode():
            _, cache = model.run_with_cache(
                games_int[:, :-1].to(device),
                return_type=None,
                names_filter=lambda name: name == ln_hook or name == post_hook,
            )
            # TODO: Make undersampling work for the general (all-neuron) case.
            resid_layer_norm = cache[ln_hook][:, positions]
            neuron_activations: Float[Tensor, "batch pos neurons"] = cache["mlp_post", layer][:, positions]
            resid_layer_norm = einops.rearrange(resid_layer_norm, "batch pos d_model -> (batch pos) d_model")
            neuron_activations = einops.rearrange(neuron_activations, "batch pos neurons -> (batch pos) neurons")

            if self.args.single_neuron is not None:
                neuron_activations = neuron_activations[:, self.args.single_neuron]
                if self.args.undersampling:
                    count_positive = int((neuron_activations > 0).sum().item())
                    count_all = len(neuron_activations)
                    count_negative_kept = int(count_positive * self.args.undersampling_factor)
                    # A descending sort puts every activation > 0 first, so the
                    # tail order[count_positive:] holds exactly the other samples.
                    order = neuron_activations.argsort(descending=True)
                    positive_indices = order[:count_positive]
                    negative_pick = t.randperm(count_all - count_positive).to(device)[:count_negative_kept]
                    negative_indices = order[negative_pick + count_positive]
                    indices = t.cat([positive_indices, negative_indices])
                    neuron_activations = neuron_activations[indices]
                    resid_layer_norm = resid_layer_norm[indices]

        # Clone outside inference mode so both results are ordinary tensors
        # rather than inference tensors, which autograd refuses to save for backward.
        resid_layer_norm = resid_layer_norm.clone().detach().to(device)
        neuron_activations = neuron_activations.clone().detach().to(device)
        variables: Float[Tensor, "batch variables"] = get_variables(resid_layer_norm, layer)
        return variables, neuron_activations

In [3]:


# Cost function
def CalculateNumberOfErrors(sudoku):
    """Total number of duplicate digits over all rows and columns of a 9x9 board.

    Row ``k`` and column ``k`` are scored together for ``k`` in 0..8 via
    ``CalculateNumberOfErrorsRowColumn``; a solved board scores 0.
    """
    return sum(CalculateNumberOfErrorsRowColumn(index, index, sudoku) for index in range(9))

def CalculateNumberOfErrorsRowColumn(row, column, sudoku):
    """Count duplicate values in column ``column`` plus row ``row`` of a 9x9 board.

    A line of 9 cells should hold 9 distinct values, and each repeat lowers
    the number of distinct values by one. So ``9 - distinct`` is that line's
    duplicate count.
    """
    columnDuplicates = 9 - np.unique(sudoku[:, column]).size
    rowDuplicates = 9 - np.unique(sudoku[row, :]).size
    return columnDuplicates + rowDuplicates


def CreateList3x3Blocks():
    """List the cells of each of the nine 3x3 blocks.

    Block ``k`` covers rows ``3*(k % 3)`` .. ``+2`` and columns ``3*(k // 3)``
    .. ``+2``. Its cells are given as ``[row, col]`` lists in row-major order,
    so the first entry is the block's top-left cell and the last is its
    bottom-right.
    """
    blocks = []
    for blockIndex in range(9):
        firstRow = 3 * (blockIndex % 3)
        firstCol = 3 * (blockIndex // 3)
        blocks.append([
            [row, col]
            for row in range(firstRow, firstRow + 3)
            for col in range(firstCol, firstCol + 3)
        ])
    return blocks

def RandomlyFill3x3Blocks(sudoku, listOfBlocks):
    """Randomly fill every empty cell (value 0) with a digit missing from its block.

    Fills ``sudoku`` in place and also returns it. Each block's bounding box
    is taken from its first and last cell, which assumes the ordering produced
    by ``CreateList3x3Blocks``. If a block's givens are already distinct, the
    block ends up containing 1..9 exactly once.
    """
    for block in listOfBlocks:
        (topRow, leftCol), (bottomRow, rightCol) = block[0], block[-1]
        for box in block:
            if sudoku[box[0], box[1]] != 0:
                continue
            # Slicing gives a view, so digits placed earlier in this block are seen.
            blockView = sudoku[topRow:bottomRow + 1, leftCol:rightCol + 1]
            missingDigits = [digit for digit in range(1, 10) if digit not in blockView]
            sudoku[box[0], box[1]] = choice(missingDigits)
    return sudoku

def SumOfOneBlock(sudoku, oneBlock):
    """Sum ``sudoku`` over the ``[row, col]`` cells listed in ``oneBlock`` (0 if empty).

    Applied to the fixed-cell mask, this counts a block's fixed cells,
    assuming the mask marks givens with 1.
    """
    return sum(sudoku[box[0], box[1]] for box in oneBlock)

def TwoRandomBoxesWithinBlock(fixedSudoku, block):
    """Pick two distinct non-fixed cells of ``block`` uniformly at random.

    Parameters
    ----------
    fixedSudoku : 2-D array
        Mask in which a value of 1 marks a given (fixed) cell.
    block : list of [row, col]
        The cells of one block.

    Returns
    -------
    [firstBox, secondBox]
        Two different entries of ``block``. Neither is fixed.

    Raises
    ------
    ValueError
        If fewer than two cells of the block are free. This replaces the
        earlier retry loop, which never terminated in that case.
    """
    freeBoxes = [box for box in block if fixedSudoku[box[0], box[1]] != 1]
    if len(freeBoxes) < 2:
        raise ValueError(
            f"need at least two free cells to swap, block has {len(freeBoxes)}"
        )
    # Sampling without replacement gives the same uniform distribution over
    # ordered pairs of distinct free cells as the old rejection loop did.
    firstBox, secondBox = random.sample(freeBoxes, 2)
    return [firstBox, secondBox]

def FlipBoxes(sudoku, boxesToFlip):
    """Return a copy of ``sudoku`` with the values of two cells exchanged.

    ``boxesToFlip[0]`` and ``boxesToFlip[1]`` are ``[row, col]`` cells. The
    input board is left untouched.
    """
    first, second = boxesToFlip[0], boxesToFlip[1]
    swapped = np.copy(sudoku)
    swapped[first[0], first[1]] = sudoku[second[0], second[1]]
    swapped[second[0], second[1]] = sudoku[first[0], first[1]]
    return swapped

def ProposedState(sudoku, fixedSudoku, listOfBlocks):
    """Propose a neighbouring board by swapping two free cells in one random block.

    Returns ``[newBoard, [boxA, boxB]]``, where ``newBoard`` is a copy of
    ``sudoku`` with the two cells exchanged.

    If the chosen block's sum in ``fixedSudoku`` exceeds 6 (presumably more
    than 6 fixed cells), no swap is attempted. The unchanged board is returned
    as the 3-tuple ``(sudoku, 1, 1)``; callers that read element 1 as a pair
    of boxes must handle this case.
    """
    chosenBlock = random.choice(listOfBlocks)
    tooFewFreeCells = SumOfOneBlock(fixedSudoku, chosenBlock) > 6
    if tooFewFreeCells:
        return (sudoku, 1, 1)
    swapPair = TwoRandomBoxesWithinBlock(fixedSudoku, chosenBlock)
    return [FlipBoxes(sudoku, swapPair), swapPair]

def ChooseNewState(currentSudoku, fixedSudoku, listOfBlocks, sigma):
    """Perform one Metropolis step of the annealer.

    Asks ``ProposedState`` for a swap and scores only the rows and columns
    through the two swapped cells. The proposal is accepted with probability
    ``min(1, exp(-costDifference / sigma))``.

    Returns
    -------
    [board, costDifference]
        The accepted proposal and its cost change, or ``[currentSudoku, 0]``
        when the proposal is rejected or no swap was possible.
    """
    proposal = ProposedState(currentSudoku, fixedSudoku, listOfBlocks)
    # ProposedState returns the 3-tuple (sudoku, 1, 1) when the chosen block is
    # too full to swap in. There are no boxes to score, so keep the current board
    # (indexing the int 1 as a box list used to raise TypeError here).
    if len(proposal) != 2:
        return [currentSudoku, 0]
    newSudoku, boxesToCheck = proposal

    def localCost(board):
        # Row + column duplicates through each swapped cell. A row or column
        # containing both cells keeps the same values, so counting it twice adds
        # 0 to the difference.
        return sum(CalculateNumberOfErrorsRowColumn(box[0], box[1], board) for box in boxesToCheck)

    costDifference = localCost(newSudoku) - localCost(currentSudoku)

    # Downhill or flat moves have rho >= 1 and are always accepted. Deciding them
    # without exp() avoids OverflowError once sigma has decayed to a tiny value.
    if costDifference <= 0:
        return [newSudoku, costDifference]
    # Uphill move at sigma <= 0: the acceptance probability tends to 0, and the
    # division below would raise ZeroDivisionError.
    if sigma <= 0:
        return [currentSudoku, 0]
    rho = math.exp(-costDifference / sigma)
    if np.random.uniform() < rho:
        return [newSudoku, costDifference]
    return [currentSudoku, 0]


def ChooseNumberOfItterations(fixed_sudoku):
    """Count the non-zero entries in the 9x9 ``fixed_sudoku`` mask.

    The solver uses this count of fixed cells as the number of proposals
    tried per temperature step.
    """
    return sum(
        1
        for row in range(9)
        for col in range(9)
        if fixed_sudoku[row, col] != 0
    )

def CalculateInitialSigma(sudoku, fixedSudoku, listOfBlocks):
    """Estimate a starting temperature for the annealer.

    Applies nine successive random proposals, starting from ``sudoku``, and
    returns the population standard deviation of the board cost after each
    one. The result is 0 if all nine costs are equal.
    """
    costs = []
    board = sudoku
    for _ in range(9):
        board = ProposedState(board, fixedSudoku, listOfBlocks)[0]
        costs.append(CalculateNumberOfErrors(board))
    return statistics.pstdev(costs)


def solveSudoku(sudoku, logPath="demofile2.txt"):
    """Solve a 9x9 Sudoku by simulated annealing.

    Every 3x3 block of a working copy is filled randomly with the digits it is
    missing. The annealer then swaps pairs of free cells within a block,
    scoring the board by duplicate digits across rows and columns. Sigma
    decays by 0.99 after each batch of proposals and is bumped by 2 whenever
    the score has not improved for more than 80 consecutive batches.

    Parameters
    ----------
    sudoku : 2-D numpy array
        The puzzle, presumably 9x9 with 0 for empty cells. It is not modified;
        the solver works on a copy.
    logPath : str
        File to which the running score is appended after every proposal.

    Returns
    -------
    The working board once its error count reaches 0.
    NOTE: there is no iteration cap, so this may run indefinitely.
    """
    solutionFound = 0
    with open(logPath, "a") as logFile:
        while solutionFound == 0:
            decreaseFactor = 0.99
            stuckCount = 0
            # RandomlyFill3x3Blocks fills its argument in place. Work on a copy so
            # the caller's puzzle is untouched, and a restart or re-run does not
            # treat the random fill as givens.
            workingSudoku = np.copy(sudoku)
            fixedSudoku = np.copy(workingSudoku)
            PrintSudoku(sudoku)
            # NOTE(review): FixSudokuValues is defined elsewhere; presumably it marks
            # givens with 1 and empty cells with 0 — confirm.
            FixSudokuValues(fixedSudoku)
            listOfBlocks = CreateList3x3Blocks()
            tmpSudoku = RandomlyFill3x3Blocks(workingSudoku, listOfBlocks)
            sigma = CalculateInitialSigma(workingSudoku, fixedSudoku, listOfBlocks)
            score = CalculateNumberOfErrors(tmpSudoku)
            # At least one proposal per step; with no givens the count would be 0
            # and the loop below would never change the board.
            itterations = max(ChooseNumberOfItterations(fixedSudoku), 1)
            if score <= 0:
                solutionFound = 1

            while solutionFound == 0:
                previousScore = score
                for _ in range(itterations):
                    tmpSudoku, scoreDiff = ChooseNewState(tmpSudoku, fixedSudoku, listOfBlocks, sigma)
                    score += scoreDiff
                    print(score)
                    logFile.write(str(score) + '\n')
                    if score <= 0:
                        solutionFound = 1
                        break

                sigma *= decreaseFactor
                if score <= 0:
                    solutionFound = 1
                    break
                if score >= previousScore:
                    stuckCount += 1
                else:
                    stuckCount = 0
                if stuckCount > 80:
                    sigma += 2
                if CalculateNumberOfErrors(tmpSudoku) == 0:
                    PrintSudoku(tmpSudoku)
                    # Mark success, otherwise the outer loop would restart from scratch.
                    solutionFound = 1
                    break
    return tmpSudoku

# NOTE(review): `sudoku` is not defined in any visible cell; presumably it comes
# from the star import or an earlier (deleted) cell. TODO define it explicitly
# so the notebook survives Restart & Run All.
solution = solveSudoku(sudoku)
remainingErrors = CalculateNumberOfErrors(solution)
print(remainingErrors)
PrintSudoku(solution)



0 2 4 | 0 0 7 | 0 0 0 
6 0 0 | 0 0 0 | 0 0 0 
0 0 3 | 6 8 0 | 4 1 5 
---------------------
4 3 1 | 0 0 5 | 0 0 0 
5 0 0 | 0 0 0 | 0 3 2 
7 9 0 | 0 0 0 | 0 6 0 
---------------------
2 0 9 | 7 1 0 | 8 0 0 
0 4 0 | 0 9 3 | 0 0 0 
3 1 0 | 0 0 4 | 7 5 0 
40
40
38
38
37
37
35
34
34
33
33
35
36
37
37
35
35
34
34
34
35
35
35
33
35
33
33
33
33
35
35
35
35
33
35
35
37
37
38
37
37
35
35
35
35
34
34
34
34
34
31
29
29
25
25
28
28
28
27
31
31
31
31
32
30
29
30
30
30
30
27
27
27
27
27
27
25
25
25
26
27
27
27
27
28
27
28
28
27
27
27
27
28
28
27
26
26
26
26
26
26
26
26
26
26
26
26
26
26
26
26
26
27
27
27
27
30
29
27
28
28
29
27
29
29
29
28
28
31
31
31
31
32
32
32
32
33
33
33
36
36
40
42
43
41
41
41
41
41
41
42
42
41
43
42
42
43
43
43
41
40
41
41
41
41
42
42
41
40
38
37
37
39
39
38
38
36
36
36
35
35
33
32
33
33
33
33
33
32
32
32
31
30
30
29
29
31
31
32
32
33
32
31
31
30
30
30
30
30
31
31
31
31
32
32
32
32
31
34
34
35
35
35
35
35
33
32
32
32
32
32
32
33
34
33
34
34
34
36
38
37
35
35
35
34
35
34
34
35


In [None]:
# objective function
def objective(x):
    """Placeholder objective: ignores ``x`` and always returns 0.

    TODO: replace with the real function to minimise.
    """
    return 0

# define range for input: one row per input dimension, as [lower, upper].
# Uses the `np` alias imported earlier; bare `asarray` is not imported in any
# visible cell and would fail on a fresh kernel.
bounds = np.asarray([[-5.0, 5.0]])