In [None]:
import numpy as np
import torch
import itertools
from torch import nn
import copy
from tqdm import tqdm

In [None]:
import math
import numpy as np

def lz_complexity(s):
    """Count the phrases found by a left-to-right repeated-substring scan of ``s``.

    The scan keeps a boundary (``frontier``) where the not-yet-parsed part of
    ``s`` begins. For every earlier start position it measures how far the
    text at that position matches the text at the boundary. Once every
    earlier start has been tried, one phrase is counted and the boundary
    moves forward by the longest match seen. The scan also ends, counting one
    final phrase, as soon as a running match gets close to the end of ``s``.

    NOTE(review): presumably a Lempel-Ziv (LZ76-style) complexity count, going
    by the function name; the end-of-sequence thresholds are kept exactly as
    written and have not been checked against a reference implementation.

    Args:
        s: indexable sequence (e.g. a ``str`` of '0'/'1') with at least two
            items; shorter inputs raise ``IndexError``.

    Returns:
        int: the phrase count (always >= 2 for valid input).
    """
    last = len(s) - 1
    start, match_len, best_len = 0, 1, 1
    frontier = 1
    phrases = 1
    while True:
        same = s[start + match_len - 1] == s[frontier + match_len - 1]
        if not same:
            # Mismatch: remember the longest match, then try the next start.
            if match_len > best_len:
                best_len = match_len
            start += 1
            if start != frontier:
                match_len = 1
                continue
            # Every earlier start has been tried: close the current phrase.
            phrases += 1
            frontier += best_len
            if frontier + 1 > last:
                break
            start, match_len, best_len = 0, 1, 1
        else:
            # Match: extend it, stopping once it reaches the tail of ``s``.
            match_len += 1
            if frontier + match_len >= last - 1:
                phrases += 1
                break
    return phrases

def decimal(x):
    """Read ``x`` as base-2 digits, most significant first, and return the value.

    Computes ``sum(x[i] * 2**(n - 1 - i))`` for ``n = len(x)``; an empty
    sequence gives 0.
    """
    width = len(x)
    total = 0
    # Pair each digit with its power of two, counting down from width - 1.
    for digit, power in zip(x, range(width - 1, -1, -1)):
        total += digit * 2 ** power
    return total

def K_lemp_ziv(sequence):
    """Scaled complexity estimate of a binary sequence.

    * A constant sequence (all zeros or all ones) scores ``log2(n)``.
    * Any other sequence scores ``log2(n)`` times the mean of
      ``lz_complexity`` over the sequence and its reverse.

    The constant check used to be ``sequence == 0`` / ``sequence == 1``, which
    is a scalar ``False`` for ``str`` and ``list`` inputs. Constant strings
    such as ``"0000"`` therefore never took the first branch. Strings are now
    compared character-wise, and other inputs are converted to an array first.

    Args:
        sequence: non-empty ``str`` of '0'/'1' characters, or a 1-D array-like
            of 0/1 values. Must support ``len`` and ``[::-1]``.

    Returns:
        float: the complexity estimate.

    Raises:
        ValueError: if ``sequence`` is empty (``log2(0)`` is undefined).
    """
    n = len(sequence)
    if n == 0:
        raise ValueError("K_lemp_ziv requires a non-empty sequence")

    if isinstance(sequence, str):
        is_constant = sequence in ("0" * n, "1" * n)
    else:
        values = np.asarray(sequence)
        is_constant = bool(np.sum(values == 0) == n or np.sum(values == 1) == n)

    if is_constant:
        return math.log2(n)

    forward = sequence
    backward = sequence[::-1]
    return math.log2(n) * (lz_complexity(forward) + lz_complexity(backward)) / 2

In [None]:
# Enumerate every binary input vector of length `dim` (2**dim rows, in
# lexicographic order) and pack them into one float tensor.
dim = 7
inputs = list(itertools.product((0, 1), repeat=dim))
data = torch.Tensor(np.array(inputs))

In [None]:
class SimpleNN(nn.Module):
    """Bias-free ReLU multilayer perceptron with a single scalar output.

    Layout: ``fc1`` (input -> 128), then ``nl - 1`` hidden layers
    (128 -> 128) held in ``fcs``, then ``fc2`` (128 -> 1). A ReLU follows
    every layer except ``fc2``.

    The hidden layers live in an ``nn.ModuleList``. A plain Python list would
    not register them as submodules, so they would be absent from
    ``parameters()`` and ``state_dict()`` and would not follow ``.to(device)``.

    Args:
        nl: number of ReLU-activated layers (``fc1`` plus ``nl - 1`` hidden
            layers).
        in_features: width of the input vectors. ``None`` (the default) uses
            the module-level ``dim``.
    """

    def __init__(self, nl=2, in_features=None):
        super().__init__()
        self.nl = nl
        if in_features is None:
            in_features = dim
        self.fc1 = nn.Linear(in_features, 128, bias=False)
        self.relu = nn.ReLU()
        self.fcs = nn.ModuleList(
            [nn.Linear(128, 128, bias=False) for _ in range(self.nl - 1)]
        )
        self.fc2 = nn.Linear(128, 1, bias=False)

    def forward(self, x):
        """Map a ``(batch, in_features)`` tensor to ``(batch, 1)`` outputs."""
        x = self.relu(self.fc1(x))
        # `fcs` is empty when nl <= 1, so the loop is simply skipped.
        for fc in self.fcs:
            x = self.relu(fc(x))
        return self.fc2(x)

In [None]:
# --- sweep settings read by main() ---
resolution = 20   # grid step is 1 / resolution / (input width of `layer`)
total = 100       # grid indices run over range(-total, total) on each axis
center = False    # when True, main() shifts the weights before sweeping

# --- model and the layer whose weights are perturbed ---
net = SimpleNN(nl=4)
layer = net.fc1

# Filled by main(): (x offset, y offset) -> output bit string.
data_out = {}

@torch.no_grad()
def main(data_out):
    """Sweep ``layer.weight`` over a random 2-D plane and record the outputs.

    Two random directions with the shape of ``layer.weight`` are drawn. The
    second is made orthogonal to the first and rescaled to the same norm.
    For every grid index pair ``(i, j)`` in ``range(-total, total)`` squared,
    the weights are set to ``base + i*step*x_ax + j*step*y_ax``, ``net(data)``
    is evaluated, and the sign pattern of the outputs ("1" where > 0, else
    "0") is stored under the key ``(i*step, j*step)``.

    Reads the module-level ``layer``, ``net``, ``data``, ``center``,
    ``total`` and ``resolution``.

    Fixes relative to the previous version:
      * Projections use the coefficient ``<a, b> / <a, a>`` instead of the
        cosine similarity, so ``y_ax`` is exactly orthogonal to ``x_ax``. The
        same coefficient is used when ``center`` removes the in-plane part of
        the weights.
      * With ``center=True`` the centred weights are the base for every grid
        point. Previously the snapshot was taken before centring, so only the
        first grid point was centred.
      * The original weights are restored in a ``finally`` block, so an
        interrupted sweep no longer leaves the network perturbed.

    Args:
        data_out: dict to fill; it is updated in place.

    Returns:
        The same ``data_out`` dict.
    """
    weight = layer.weight
    # Grid spacing; loop-invariant, so computed once.
    step = 1 / resolution / weight.shape[1]

    x_ax = torch.randn_like(weight)
    y_ax = torch.randn_like(weight)

    # Gram-Schmidt step: drop the component of y_ax along x_ax, then give
    # y_ax the same norm as x_ax.
    y_ax = y_ax - (torch.sum(x_ax * y_ax) / torch.sum(x_ax * x_ax)) * x_ax
    y_ax = y_ax / torch.norm(y_ax) * torch.norm(x_ax)

    original = weight.detach().clone()
    base = original.clone()
    if center:
        # x_ax and y_ax are orthogonal, so each projection can be removed
        # independently.
        base -= (torch.sum(x_ax * base) / torch.sum(x_ax * x_ax)) * x_ax
        base -= (torch.sum(y_ax * base) / torch.sum(y_ax * y_ax)) * y_ax

    try:
        grid = itertools.product(range(-total, total), range(-total, total))
        for i, j in tqdm(grid, total=4 * total**2):
            # In-place copy keeps the same Parameter object on the module.
            weight.copy_(base + x_ax * i * step + y_ax * j * step)
            y = net(data)
            positive = (y.reshape(-1) > 0).tolist()
            y_str = "".join("1" if flag else "0" for flag in positive)
            data_out[(i * step, j * step)] = y_str
    finally:
        weight.copy_(original)
    return data_out

data_out = main(data_out)
# The sweep stores 4 * total**2 entries; print a summary instead of dumping
# the whole dict into the cell output.
print(f"{len(data_out)} grid points sampled, "
      f"{len(set(data_out.values()))} distinct output patterns")

In [None]:
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
# from complexities2 import K_lemp_ziv

# Score every distinct output pattern once, then order the patterns from the
# lowest to the highest score. Ties are broken by the pattern itself so the
# ordering (and hence the colouring) is repeatable between runs.
unique_strings_lz = {s: K_lemp_ziv(s) for s in set(data_out.values())}
unique_strings_lz = dict(sorted(unique_strings_lz.items(), key=lambda item: (item[1], item[0])))
unique_strings = list(unique_strings_lz.keys())
n_unique = len(unique_strings)

# One colour per distinct pattern, indexed by the pattern's rank.
# plt.get_cmap is used because plt.cm.get_cmap was removed from newer
# matplotlib releases.
color_map = plt.get_cmap('viridis', n_unique)
colors = [color_map(i) for i in range(n_unique)]
string_color_dict = dict(zip(unique_strings, colors))
string_color_dict_lz = dict(zip(unique_strings_lz.values(), colors))

# Extract coordinates and corresponding strings
coordinates, strings = zip(*data_out.items())

# Map each grid point to the rank (and colour) of its pattern
string_rank = {s: rank for rank, s in enumerate(unique_strings)}
point_ranks = [string_rank[s] for s in strings]
point_colors = [string_color_dict[s] for s in strings]

# Unpack coordinates
x, y = zip(*coordinates)

fig, ax = plt.subplots()

# Colour by rank through the colormap (instead of passing raw RGBA values) so
# that the colour bar is built from the same mapping as the points.
scatter = ax.scatter(
    x, y,
    c=point_ranks, cmap=color_map, vmin=0, vmax=max(n_unique - 1, 1),
    s=100, linewidths=1, marker='s',
)

# Add legend
# legend_elements = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=string) for string, color in string_color_dict_lz.items()]
# plt.legend(handles=legend_elements, loc='upper right')

cbar = fig.colorbar(scatter, ax=ax)
cbar.set_label('K_lemp_ziv')

# Several patterns can share one complexity value. Put that value's tick at
# the mean rank of the patterns that share it, so each label sits on the
# colours it describes.
ranks_by_lz = {}
for rank, s in enumerate(unique_strings):
    ranks_by_lz.setdefault(unique_strings_lz[s], []).append(rank)
lz = list(ranks_by_lz)
colours = [string_color_dict_lz[value] for value in lz]
tick_positions = [sum(ranks) / len(ranks) for ranks in ranks_by_lz.values()]

# Keep roughly ten ticks. `skip` is always defined and at least 1, which also
# covers the cases of ten or fewer (or a single) distinct values.
num = len(lz)
skip = max(1, num // 10)
cbar.set_ticks(tick_positions[::skip])
cbar.set_ticklabels([f'{value:g}' for value in lz[::skip]])

# Show the plot
# fig.set_size_inches(10,10)
ax.set_box_aspect(1)
ax.set_xlabel(r'$\theta_1$')
ax.set_ylabel(r'$\theta_2$')
# ax.set_yticks([-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3])
# ax.set_xticks([-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3])
fig.set_size_inches(5,4)
ax.set_title(f'{len(string_color_dict)} unique functions')
# plt.savefig('wow.png', bbox_inches='tight', dpi=300)