In [None]:
from torchsummary import summary

# Define the MLP model
class MLP_HSIC_BP(nn.Module):
    """Four-layer fully connected network that returns every layer's output.

    Layer widths are 784 -> 512 -> 256 -> 128 -> 10.  ReLU is applied after
    the first three linear layers; the last layer's output is returned raw.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.layer1 = nn.Linear(28*28, 512)
        self.layer2 = nn.Linear(512, 256)
        self.layer3 = nn.Linear(256, 128)
        self.layer4 = nn.Linear(128, 10)

    def forward(self, x):
        """Run the network and return the activations of all four layers.

        Args:
            x: Batched input; everything after the first dimension is
                flattened and must total 784 features per sample.

        Returns:
            Tuple ``(out1, out2, out3, out4)``: the three post-ReLU hidden
            activations followed by the un-activated output of ``layer4``.
        """
        # nn.Flatten falls back to a copy when a view is impossible, so
        # non-contiguous inputs (e.g. transposed images) are accepted;
        # ``x.view(x.size(0), -1)`` raises on them.
        x = self.flatten(x)
        out1 = torch.relu(self.layer1(x))
        out2 = torch.relu(self.layer2(out1))
        out3 = torch.relu(self.layer3(out2))
        out4 = self.layer4(out3)
        return out1, out2, out3, out4

# Define the HSIC bottleneck class
class HSICBottleneck:
    """Trains a model with an HSIC-based loss on its hidden activations.

    The wrapped model must return four values from its forward call; the
    first and third are used as the hidden representations.  ``kernel_matrix``
    and ``HSIC`` are helpers defined elsewhere in the file.

    Attributes are plain values copied from the constructor arguments, plus:
        opt: SGD optimizer over all of ``model.parameters()``.
        remember: list of the scalar loss of every ``step`` call so far.
        count: initialised to 0; not modified by this class.
    """

    def __init__(self, model, batch_size, lambda_0, sigma, lr=0.001):
        """Store the hyper-parameters and build the SGD optimizer.

        Args:
            model: Module whose forward returns four tensors.
            batch_size: Value passed as the third argument of ``HSIC``.
            lambda_0: Weight of the label-dependence term in the loss.
            sigma: Value passed as the second argument of ``kernel_matrix``.
            lr: Learning rate of the SGD optimizer.
        """
        self.model = model
        self.batch_size = batch_size
        self.lambda_0 = lambda_0
        self.sigma = sigma
        self.lr = lr
        self.opt = optim.SGD(model.parameters(), lr)
        self.remember = []
        self.count = 0

    def step(self, input_data, labels):
        """Run one optimizer step on the HSIC loss and return its value.

        The loss is ``HSIC(Z1, X) - lambda_0 * HSIC(Z3, Y)`` where ``Z1`` and
        ``Z3`` are the first and third model outputs, ``X`` the input batch
        and ``Y`` the labels.

        Args:
            input_data: Batch fed to the model and to ``kernel_matrix``.
            labels: Targets for the batch, fed to ``kernel_matrix``.

        Returns:
            The loss of this step as a Python float (also appended to
            ``self.remember``).
        """
        self.opt.zero_grad()

        Kx = kernel_matrix(input_data, self.sigma)
        Ky = kernel_matrix(labels, self.sigma)

        out1, _, out3, _ = self.model(input_data)
        Kz1 = kernel_matrix(out1, self.sigma)
        Kz3 = kernel_matrix(out3, self.sigma)

        # NOTE(review): self.batch_size is fixed at construction; if HSIC
        # uses it as the sample count, a smaller final batch is normalised
        # with the wrong value -- confirm against HSIC's definition.
        input_dependence = HSIC(Kz1, Kx, self.batch_size)
        # The label kernel Ky was previously computed but never used: this
        # term was measured against Kx, so labels never entered the loss.
        label_dependence = HSIC(Kz3, Ky, self.batch_size)
        loss = input_dependence - self.lambda_0 * label_dependence

        loss.backward()
        self.opt.step()

        loss_value = loss.item()
        self.remember.append(loss_value)
        return loss_value

# Initialize the model, HSIC bottleneck, and other components
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_hsic_bp = MLP_HSIC_BP().to(device)
summary(model_hsic_bp, input_size=(1,784 ))
hsic_bottleneck = HSICBottleneck(model_hsic_bp, batch_size=128, lambda_0=1.0, sigma=1.0, lr=0.001)
criterion = nn.CrossEntropyLoss()

# Build the back-propagation optimizer once.  Adam keeps running moment
# estimates and a step counter per parameter; constructing it inside the
# batch loop would reset that state on every iteration.
optimizer_bp = optim.Adam(
    list(model_hsic_bp.layer2.parameters()) + list(model_hsic_bp.layer4.parameters()),
    lr=0.001,
)

num_epochs = 6
train_accuracy_history_hsic_bp = []
test_accuracy_history_hsic_bp = []
train_loss_history_hsic_bp = []

# Training loop: each batch gets one HSIC update followed by one
# cross-entropy update of layer2 and layer4.
for epoch in range(num_epochs):
    model_hsic_bp.train()
    correct = 0
    total = 0
    total_loss = 0

    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)

        # HSIC-driven update (uses the optimizer owned by hsic_bottleneck)
        hsic_bottleneck.step(data, targets)

        # Fresh forward pass so the cross-entropy loss sees the weights
        # produced by the HSIC update above.
        outputs = model_hsic_bp(data)
        logits = outputs[-1]
        loss = criterion(logits, targets)

        optimizer_bp.zero_grad()
        loss.backward()
        optimizer_bp.step()

        total_loss += loss.item()

        # Track training accuracy (class with the largest logit)
        predicted = logits.argmax(dim=1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

    train_accuracy = 100 * correct / total
    avg_loss = total_loss / len(train_loader)
    train_accuracy_history_hsic_bp.append(train_accuracy)
    train_loss_history_hsic_bp.append(avg_loss)

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%')

    # Evaluation on the test set; no gradients are needed here
    model_hsic_bp.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model_hsic_bp(data)
            predicted = outputs[-1].argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    test_accuracy = 100 * correct / total
    test_accuracy_history_hsic_bp.append(test_accuracy)

    print(f'Test Accuracy: {test_accuracy:.2f}%')
