In [9]:
import torch
from collections import Counter


In [19]:
def unit_step_with_rand(x):
    return torch.where(x >= torch.rand(1), torch.tensor(1.0), torch.tensor(0.0))

def transform(tensor: torch.Tensor):
    return unit_step_with_rand(torch.sigmoid(tensor))

def number_from_tensor(tensor: torch.Tensor):
    powers_of_two = torch.vander(torch.Tensor([2]), N=tensor.size(dim=0)+1, increasing=True)
    scalar_mul = tensor * torch.reshape(powers_of_two, (-1,))[1:]
    return 1 + scalar_mul.sum()

In [46]:
class MyIntegerFactorizationModel:

    def __init__(self, F: int):
        self.F = F
        length = F.bit_length()
        self.P = (length - 1) // 2
        self.Q = length - 2 - self.P
        self.eternal_tensor = torch.ones(self.P + self.Q, requires_grad=True)
        self.tensor_collector = Counter()
        self.other_tensor_collector = Counter()

    def energy_function(self, tensor: torch.Tensor, i_0) -> torch.Tensor:
        return i_0 * (number_from_tensor(tensor[:self.P]) * number_from_tensor(tensor[self.P:]) - self.F) ** 2

    def calculate_gradient(self):
        energy = self.energy_function(self.eternal_tensor, i_0=0.5)
        energy.backward(torch.ones(energy.shape))
        self.eternal_tensor.retain_grad()
        return self.eternal_tensor.grad


    def evaluate(self, times: int, fitting_parameter=0.5):
        for i in range(times):
            self.manage_counting()
            gradient = self.calculate_gradient()
            index = torch.randint(self.P + self.Q, (1, ))
            trans = transform(-gradient[index])
            with torch.no_grad():
                self.eternal_tensor.data[index] = trans
    def manage_counting(self):
        first_number = number_from_tensor(self.eternal_tensor[:self.P])
        second_number = number_from_tensor(self.eternal_tensor[self.P:])
        self.tensor_collector[first_number.item()] += 1
        self.other_tensor_collector[second_number.item()] += 1
    # def bit_fluctuating(self):


In [50]:
integer_factorization = MyIntegerFactorizationModel(35)

integer_factorization.evaluate(4000)
print(integer_factorization.tensor_collector)
print(integer_factorization.other_tensor_collector)

Counter({7.0: 2391, 5.0: 779, 3.0: 653, 1.0: 177})
Counter({7.0: 2646, 3.0: 594, 5.0: 564, 1.0: 196})
