In [9]:
import torch
from collections import Counter


In [19]:
def unit_step_with_rand(x):
    return torch.where(x >= torch.rand(1), torch.tensor(1.0), torch.tensor(0.0))

def transform(tensor: torch.Tensor):
    return unit_step_with_rand(torch.sigmoid(tensor))

def number_from_tensor(tensor: torch.Tensor):
    powers_of_two = torch.vander(torch.Tensor([2]), N=tensor.size(dim=0)+1, increasing=True)
    scalar_mul = tensor * torch.reshape(powers_of_two, (-1,))[1:]
    return 1 + scalar_mul.sum()

In [71]:
class MyIntegerFactorizationModel:

    def __init__(self, F: int, fitting_parameter: float):
        self.F = F
        length = F.bit_length()
        self.P = (length - 1) // 2
        self.Q = length - 2 - self.P
        self.eternal_tensor = torch.ones(self.P + self.Q, requires_grad=True)
        self.tensor_collector = Counter()
        self.fitting_parameter = fitting_parameter

    def energy_function(self, tensor: torch.Tensor) -> torch.Tensor:
        return self.fitting_parameter * (number_from_tensor(tensor[:self.P]) * number_from_tensor(tensor[self.P:]) - self.F) ** 2

    def calculate_gradient(self):
        energy = self.energy_function(self.eternal_tensor)
        energy.backward(torch.ones(energy.shape))
        self.eternal_tensor.retain_grad()
        return self.eternal_tensor.grad


    def evaluate(self):
        self.manage_counting()
        gradient = self.calculate_gradient()
        index = torch.randint(self.P + self.Q, (1, ))
        trans = transform(-gradient[index])
        with torch.no_grad():
            self.eternal_tensor.data[index] = trans

    def manage_counting(self):
        first_number = number_from_tensor(self.eternal_tensor[:self.P])
        second_number = number_from_tensor(self.eternal_tensor[self.P:])
        self.tensor_collector[(first_number.item(), second_number.item())] += 1


In [84]:
integer_factorization = MyIntegerFactorizationModel(35, 5)

for i in range(4000):
    integer_factorization.evaluate()
print(integer_factorization.tensor_collector)

Counter({(7.0, 7.0): 1311, (7.0, 5.0): 548, (5.0, 7.0): 534, (7.0, 3.0): 492, (3.0, 7.0): 446, (1.0, 7.0): 199, (7.0, 1.0): 179, (3.0, 5.0): 73, (5.0, 5.0): 72, (5.0, 3.0): 44, (3.0, 3.0): 41, (3.0, 1.0): 29, (1.0, 5.0): 11, (5.0, 1.0): 11, (1.0, 3.0): 7, (1.0, 1.0): 3})
