In [None]:
import torch.nn as nn
import torch
from torchmetrics import HingeLoss, MeanSquaredError
from torchmetrics.classification import BinaryHingeLoss
import random
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from matplotlib import pyplot as plt
from datetime import datetime
from torch.autograd.functional import hessian

In [None]:
def data_generator(
        k: int = 3,
        n: int = 20,
        rng=None,
        ):
    """Build the full truth table of ``n``-bit inputs labelled by a ``k``-bit parity.

    Args:
        k: number of distinct input positions whose parity forms the label.
        n: number of bits per input; the table has ``2 ** n`` rows.
        rng: optional ``random.Random`` instance used to choose the parity
            positions (pass a seeded one for reproducibility). When omitted,
            the global ``random`` module state is used, as before.

    Returns:
        x: float32 tensor of shape ``(2 ** n, n)``; row ``i`` is the binary
            expansion of ``i``, most significant bit first.
        y: float32 tensor of shape ``(2 ** n, 1)`` holding the parity (0. or 1.)
            of the bits of ``x`` at ``parity_bits``.
        parity_bits: list of ``k`` distinct positions drawn from ``range(n)``.

    Raises:
        ValueError: if ``k`` is negative or larger than ``n`` (from ``sample``).
    """
    sampler = random if rng is None else rng
    parity_bits = sampler.sample(range(n), k)
    num = 2 ** n
    # Row i holds the bits of i, MSB first (same order as bin(i)[2:].zfill(n)).
    # Shifting every row index at once replaces a Python loop over all 2**n rows.
    codes = torch.arange(num, dtype=torch.int64).unsqueeze(1)
    shifts = torch.arange(n - 1, -1, -1, dtype=torch.int64)
    x = ((codes >> shifts) & 1).to(torch.float32)
    y = x[:, parity_bits].sum(dim=1, keepdim=True) % 2

    return x, y, parity_bits

In [None]:
class MyHingeLoss(nn.Module):
    """Squared hinge loss for binary targets encoded as 0/1.

    Each target is mapped to -1/+1, every element contributes
    ``max(0, 1 - output * signed_target) ** 2``, and the result is the mean
    over all elements. No parameters; the inherited ``nn.Module`` constructor
    is sufficient.
    """

    def forward(self, output, target):
        signed_target = target * 2 - 1
        margin_violation = torch.clamp(1 - torch.mul(output, signed_target), min=0)
        return margin_violation.pow(2).mean()

In [None]:
class FFNN(nn.Module):
    """Two-layer ReLU network ``n -> k -> 1`` with a hand-set initialisation.

    The first-layer bias and the whole second layer are fixed to hand-chosen
    values; only the first-layer weight is trainable (see ``freeze_params``).

    Args:
        n: input dimension (bits per example).
        k: hidden width.
        first_layer_weight: optional initial first-layer weight of shape
            ``(k, n)`` (tensor or nested sequence). When omitted, the
            hand-picked 2x3 matrix below is used if ``(k, n) == (2, 3)``;
            for any other shape every entry is set to ``k / n``.
    """

    # Hand-picked first-layer starting point for the k=2, n=3 experiment.
    _DEFAULT_WEIGHT_2x3 = ((0.9, 0.0, 1.2), (0.7, -0.1, 1.1))

    def __init__(self, n, k: int = 3, first_layer_weight=None):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(out_features=k, in_features=n),
            nn.ReLU(),
            nn.Linear(out_features=1, in_features=k),
        )
        self.initialize_params(k, n, first_layer_weight)
        self.freeze_params()

    def initialize_params(self, k, n, first_layer_weight=None):
        """Overwrite every parameter with the hand-set values described on the class.

        Raises:
            ValueError: if an explicit ``first_layer_weight`` is not ``(k, n)``.
        """
        with torch.no_grad():
            first_layer = self.network[0]
            # Biases -0.5, -1.5, ..., -(k - 0.5): hidden unit i is active only
            # when its weighted input exceeds i + 0.5.
            first_layer.bias.copy_(-torch.arange(k).float() - 0.5)

            if first_layer_weight is None:
                if (k, n) == (2, 3):
                    first_layer_weight = torch.tensor(self._DEFAULT_WEIGHT_2x3)
                else:
                    # Previously used (commented-out) uniform initialisation.
                    first_layer_weight = torch.ones(k, n) * k / n
            weight = torch.as_tensor(first_layer_weight, dtype=torch.float32)
            if weight.shape != (k, n):
                raise ValueError(
                    f"first_layer_weight must have shape ({k}, {n}), "
                    f"got {tuple(weight.shape)}"
                )
            first_layer.weight.copy_(weight)

            second_layer = self.network[2]
            # Alternating output weights 2, -6, 10, ... ; shape (1, k).
            weights = torch.tensor(
                [((-1) ** i) * (2 + 4 * i) for i in range(k)],
                dtype=torch.float32
                )
            second_layer.weight.copy_(weights.view(1, -1))
            second_layer.bias.zero_()

    def freeze_params(self):
        """Make the first-layer weight the only parameter that receives gradients."""
        self.network[0].bias.requires_grad = False
        self.network[0].weight.requires_grad = True
        self.network[2].weight.requires_grad = False
        self.network[2].bias.requires_grad = False

    def forward(self, x):
        """Return the network output for ``x`` (last dimension must be ``n``)."""
        return self.network(x)

In [None]:
# Seed every RNG the notebook touches so the parity positions picked by
# data_generator (random.sample) and any torch randomness are reproducible
# under Restart & Run All.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Problem size: inputs have `length` bits; the label is the parity of `k` of them.
length = 3
k = 2

In [None]:
def test(model, x, y):
    pred = model(x)
    predicted_classes = (pred >= 0.5).float()
    correct_predictions = (predicted_classes == y).sum()
    accuracy = correct_predictions / y.size(0)
    return accuracy.item()

In [None]:

# Data: the complete truth table over `length` bits, labelled by the parity
# of `k` randomly chosen positions (returned as `bits`).
x, y, bits = data_generator(k=k, n=length)

# Optimisation settings. Other criteria tried in this notebook:
# BinaryHingeLoss(squared=True), MyHingeLoss(), MeanSquaredError().
loss_fn = HingeLoss(task="binary")
epochs = 30000

In [None]:
# Parity positions sampled by data_generator (column indices of x that define y).
bits

In [None]:
model = FFNN(length, k)

print(f"the weight in the first layer is {model.network[0].weight.data}")
print(f"the bias in the first layer is {model.network[0].bias.data}")

print(f"the weight in the second layer is {model.network[2].weight.data}")
print(f"the bias in the second layer is {model.network[2].bias.data}")

# Full-batch training: every step uses the whole truth table `x`, so no
# DataLoader is needed. Only the first-layer weight has requires_grad=True
# (FFNN.freeze_params); SGD skips parameters whose .grad is None, so passing
# all parameters updates just that weight.
optimizer = torch.optim.SGD(model.parameters(), lr=3e-2)
model.train()
weight_history = []         # first-layer weight after every step
loss_history = []           # loss, sampled every 200 steps
gradient_norm_history = []  # first-layer gradient norm, sampled every 200 steps
for i in range(epochs):
    pred = model(x)
    loss = loss_fn(pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    current_weight = model.network[0].weight.detach().cpu().clone()
    weight_history.append(current_weight)
    if i % 200 == 0:
        grad_norm = model.network[0].weight.grad.norm().item()
        gradient_norm_history.append(grad_norm)
        loss_history.append(loss.item())
        print(f"epoch {i}: loss = {loss.item():.6f}")

model.eval()
print(test(model, x, y))

# W[t, i, j] is first-layer weight (i, j) after training step t.
W = torch.stack(weight_history).numpy()

# squeeze=False keeps `axes` 2-D, so axes[i, j] also works when k == 1 or length == 1.
fig, axes = plt.subplots(k, length, figsize=(10, 5), squeeze=False)
for i in range(k):
    for j in range(length):
        axes[i, j].plot(W[:, i, j])
        axes[i, j].set_title(f"W[{i},{j}]")
        axes[i, j].set_xlabel("Epoch")
        axes[i, j].set_ylabel("Value")

plt.tight_layout()
plt.show()

In [None]:
# Re-display the parity positions to compare with the learned weights above.
bits

In [None]:

# The training loop appends to loss_history / gradient_norm_history only when
# `i % 200 == 0`, so sample t corresponds to training step 200 * t.
log_interval = 200
epochs_range = range(0, log_interval * len(loss_history), log_interval)

fig, ax1 = plt.subplots()

ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss', color='tab:blue')
ax1.plot(epochs_range, loss_history, label='Loss', color='tab:blue')
ax1.tick_params(axis='y', labelcolor='tab:blue')

ax2 = ax1.twinx()
ax2.set_ylabel('Gradient Norm', color='tab:red')
ax2.plot(epochs_range, gradient_norm_history, label='Grad Norm', color='tab:red')
ax2.tick_params(axis='y', labelcolor='tab:red')

# Add the title before tight_layout so the layout reserves space for it.
ax1.set_title("Loss and Gradient Norm over Epochs")
fig.tight_layout()
plt.show()

In [None]:
# First-layer weights flattened in row-major order, one 0-d tensor per entry;
# position i here is the flat index that plot_loss_landscape calls weight_idx.
params = [w_entry for w_entry in model.network[0].weight.view(-1)]

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

def _loss_surface(ffnn_model, inputs, targets, weight_idx1, weight_idx2,
                  offsets, criterion):
    """Loss over a grid of offsets added to two entries of the flattened first-layer weight.

    Returns ``Z`` with ``Z[j, i]`` = loss after adding ``offsets[i]`` to flat
    weight ``weight_idx1`` and ``offsets[j]`` to flat weight ``weight_idx2``,
    matching the layout of ``np.meshgrid(offsets, offsets)``. The model's
    original weights are written back even if a loss evaluation raises.
    """
    weight = ffnn_model.network[0].weight
    base_tensor = weight.detach().reshape(-1).clone()
    Z = np.zeros((len(offsets), len(offsets)))
    try:
        # Only loss values are needed, so skip building autograd graphs.
        with torch.no_grad():
            for i, dx in enumerate(offsets):
                for j, dy in enumerate(offsets):
                    modified_weights = base_tensor.clone()
                    modified_weights[weight_idx1] += dx
                    modified_weights[weight_idx2] += dy
                    weight.view(-1).copy_(modified_weights)
                    Z[j, i] = criterion(ffnn_model(inputs), targets).item()
    finally:
        with torch.no_grad():
            weight.view(-1).copy_(base_tensor)
    return Z


def plot_loss_landscape(ffnn_model, inputs, targets, weight_idx1, weight_idx2,
                        weight_range=0.5, steps=30, criterion=None):
    """Plot a 3-D loss surface around the current first-layer weights.

    Two entries of ``ffnn_model.network[0].weight`` (indexed in the flattened,
    row-major order) are each shifted over ``steps`` evenly spaced offsets in
    ``[-weight_range, weight_range]``; all other parameters stay fixed. The
    model is left with its original weights afterwards.

    Args:
        ffnn_model: module exposing ``network[0].weight`` (e.g. ``FFNN``).
        inputs, targets: data used to evaluate the loss.
        weight_idx1, weight_idx2: flat indices of the two weights to vary.
        weight_range: half-width of the offset range.
        steps: number of offsets per axis.
        criterion: loss callable ``criterion(preds, targets)``; defaults to the
            notebook-global ``loss_fn`` as it is at call time.
    """
    if criterion is None:
        criterion = loss_fn
    offsets = np.linspace(-weight_range, weight_range, steps)
    Z = _loss_surface(ffnn_model, inputs, targets, weight_idx1, weight_idx2,
                      offsets, criterion)
    X, Y = np.meshgrid(offsets, offsets)

    fig = plt.figure(figsize=(6, 5))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(X, Y, Z, cmap='viridis')
    ax.set_xlabel(f'Weight {weight_idx1} offset')
    ax.set_ylabel(f'Weight {weight_idx2} offset')
    ax.set_zlabel('Loss')
    ax.set_title(f'Loss landscape for weights {weight_idx1} and {weight_idx2}')
    plt.show()


In [None]:
# One surface per consecutive pair of flat first-layer weight indices
# (with k=2 and length=3, indices 0..5 cover the whole 2x3 weight matrix).
for first_idx, second_idx in [(0, 1), (2, 3), (4, 5)]:
    plot_loss_landscape(model, x, y, first_idx, second_idx)

In [None]:
import torch
from torch.autograd.functional import hessian
import numpy as np

def classify_critical_point(model, inputs, targets, criterion=None, tol=1e-6):
    """Hessian spectrum of the loss with respect to the first-layer weight.

    The loss is evaluated as an explicit, differentiable function of a flat
    copy of ``model.network[0].weight``; the model's parameters themselves are
    not modified. Prints the extreme eigenvalues and a description of the
    curvature (which classifies a critical point only if the gradient there is
    zero; the gradient is not checked here).

    Args:
        model: module exposing ``network``, an ``nn.Sequential`` whose first
            layer is an ``nn.Linear``. It is switched to eval mode.
        inputs, targets: data passed to the network and the loss.
        criterion: loss callable ``criterion(preds, targets)``; defaults to the
            notebook-global ``loss_fn``.
        tol: eigenvalues with absolute value ``<= tol`` count as zero.

    Returns:
        NumPy array of Hessian eigenvalues in ascending order.
    """
    if criterion is None:
        criterion = loss_fn
    model.eval()

    first_layer = model.network[0]
    downstream_layers = model.network[1:]
    weight_shape = first_layer.weight.shape
    weight_flat = first_layer.weight.detach().reshape(-1).clone()

    def loss_fn_for_hessian(w_flat):
        # Run the first layer with `w` itself. Copying `w` into the parameter
        # under torch.no_grad() would cut the graph and give an all-zero Hessian.
        w_tensor = w_flat.reshape(weight_shape)
        hidden = torch.nn.functional.linear(inputs, w_tensor, first_layer.bias)
        preds = downstream_layers(hidden)
        return criterion(preds, targets)

    H = hessian(loss_fn_for_hessian, weight_flat)
    H_np = H.detach().cpu().numpy()

    eigenvalues = np.linalg.eigvalsh(H_np)

    print(f"Top Hessian eigenvalue: {eigenvalues.max():.4e}")
    print(f"Smallest Hessian eigenvalue: {eigenvalues.min():.4e}")

    smallest, largest = eigenvalues.min(), eigenvalues.max()
    if smallest > tol:
        curvature = "positive definite (local minimum if the gradient vanishes)"
    elif largest < -tol:
        curvature = "negative definite (local maximum if the gradient vanishes)"
    elif smallest < -tol and largest > tol:
        curvature = "indefinite (saddle point if the gradient vanishes)"
    else:
        curvature = "semi-definite / degenerate (second-order test inconclusive)"
    print(f"Hessian curvature: {curvature}")
    return eigenvalues


In [None]:
# Curvature of the loss around the trained first-layer weights.
classify_critical_point(model, x, y)

In [None]:
# Recompute predictions from the trained model instead of reusing the `pred`
# left over from the training loop, which was computed before the final
# optimizer.step().
with torch.no_grad():
    pred = model(x)
pred[:10]

In [None]:
# Parity labels for the same first 10 rows of the truth table.
y[:10]

In [None]:
# Squared hinge loss (MyHingeLoss) of the trained model on the full dataset.
# A separate name keeps the notebook-wide `loss_fn` (used as the default by
# plot_loss_landscape and classify_critical_point) from being replaced.
squared_hinge_fn = MyHingeLoss()
with torch.no_grad():
    squared_hinge_value = squared_hinge_fn(model(x), y)
print(squared_hinge_value.item())