In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
# Define a simple neural network model
class SimpleNN(nn.Module):
    """Two-layer fully connected classifier for 28x28 inputs.

    Architecture: Linear(784 -> 1024) -> ReLU -> Linear(1024 -> 10).

    ``forward`` returns the intermediate activations alongside the logits so
    that callers can inspect them (the evaluation code uses them to compute an
    energy value).
    """

    def __init__(self):
        # NOTE: the constructor must be spelled ``__init__`` (double
        # underscores). With any other name Python never calls it, the layers
        # below are never registered, and ``model.parameters()`` is empty.
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 1024)  # Input layer to hidden layer
        self.relu = nn.ReLU()  # Activation function
        self.fc2 = nn.Linear(1024, 10)  # Hidden layer to output layer

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor whose elements can be viewed as rows of 28 * 28 values
               (e.g. a ``(batch, 1, 28, 28)`` image batch).

        Returns:
            A tuple ``(logits, x1, x2)`` where ``x1`` is the output of ``fc1``
            before the activation, ``x2`` is ``relu(x1)`` and ``logits`` is
            ``fc2(x2)``.
        """
        x = x.view(-1, 28 * 28)  # Flatten the input images
        x1 = self.fc1(x)
        x2 = self.relu(x1)
        x = self.fc2(x2)
        return x, x1, x2

# Choose the compute device: CUDA when a GPU is visible, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing applied to every MNIST image: convert to a tensor, then
# normalise with mean 0.1307 and standard deviation 0.3081.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Train / test splits, stored under ./data (download=True fetches them if needed).
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Batches of 20 samples; only the training data is shuffled.
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=20, shuffle=False)

# Build the network on the chosen device, then the loss and the optimizer.
model = SimpleNN()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

# Training loop: one pass over the training data per epoch, followed by an
# evaluation pass over the test data that reports accuracy and an energy value.
num_epochs = 30
for epoch in range(num_epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output, _, _ = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 1000 == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item()}')

    # Test the model after each epoch
    model.eval()
    correct, total = 0, 0
    # Everything below is evaluation only, so keep it out of autograd. This
    # also ensures the printed energy tensor carries no grad_fn.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output, x1, x2 = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            # Activations of the current batch. These are overwritten on every
            # iteration, so after the loop they hold the LAST test batch only.
            # NOTE(review): the energy below therefore describes a single
            # batch, not the whole test set -- confirm this is intended.
            l1 = data.view(data.size(0), -1)  # flattened inputs
            u2 = x1                           # fc1 output before ReLU
            l2 = x2                           # fc1 output after ReLU
            l3 = output                       # logits

        # Detached snapshots of the current parameters.
        w1 = model.fc1.weight.detach()
        w2 = model.fc2.weight.detach()
        b1 = model.fc1.bias.detach()
        b2 = model.fc2.bias.detach()

        # Energy of the last batch, each term averaged over its samples:
        #   E = 0.5 * (|l1|^2 + |u2|^2 + |l3|^2)
        #       - (l2 . W1 l1 + l3 . W2 l2)
        #       - (b1 . l2 + b2 . l3)
        t1 = ((torch.sum(l1 * l1, dim=1)).mean() + (torch.sum(u2 * u2, dim=1)).mean() + (torch.sum(l3 * l3, dim=1)).mean()) * 0.5
        # Per-sample bilinear forms computed row-wise. This works for any batch
        # size and avoids building a batch x batch matrix just to read its
        # diagonal. The mean is taken in float64.
        t2 = (torch.mm(l2, w1) * l1).sum(dim=1).double().mean() + (torch.mm(l3, w2) * l2).sum(dim=1).double().mean()
        t3 = torch.sum(b1 * l2, dim=1).mean() + torch.sum(b2 * l3, dim=1).mean()

    print("Total Energy: ", t1 - t2 - t3)
    accuracy = correct / total
    print(f'Epoch {epoch+1}/{num_epochs}, Test Accuracy: {100 * accuracy:.2f}%')

# Final test accuracy
print(f'Final Test Accuracy: {100 * accuracy:.2f}%')