In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

class BrainInspiredVisualNetwork(nn.Module):
    """Convolutional classifier whose areas are named after visual-cortex regions.

    Every area is a 3x3 convolution (padding 1, so the spatial size is kept)
    followed by ReLU and BatchNorm2d. Areas are wired together by
    concatenating feature maps along the channel dimension, and the final AIT
    area ends in global average pooling, so any input height/width works.

    Args:
        input_channels: Number of channels in the input images.
        base_features: Channel width of most areas; the three stripe branches
            use ``base_features // 2`` channels. Must be at least 2.
        num_classes: Number of output logits (10 for MNIST, the default).

    Raises:
        ValueError: If ``base_features`` is smaller than 2, which would give
            the stripe branches zero channels.
    """

    def __init__(self, input_channels=1, base_features=32, num_classes=10):
        super().__init__()
        if base_features < 2:
            raise ValueError(f"base_features must be >= 2, got {base_features}")

        full = base_features
        half = base_features // 2
        block = self._conv_block

        # NOTE: submodules are created and registered in the same order as the
        # original hand-written layout, and each area is a flat Sequential, so
        # state_dict keys (e.g. "v1.5.running_mean", "ait.2.weight") and
        # seeded initialisation stay identical to previously saved checkpoints.

        # V1 (primary visual cortex): two stacked conv blocks.
        self.v1 = nn.Sequential(*block(input_channels, full), *block(full, full))

        # V1 feeds three parallel stripe branches of half width.
        self.thick_stripe = nn.Sequential(*block(full, half))
        self.interstripe = nn.Sequential(*block(full, half))
        self.thin_stripe = nn.Sequential(*block(full, half))

        # Input widths below are the sums of the concatenated areas in forward().
        self.mt = nn.Sequential(*block(full + half, full))         # [V1, thick stripe]
        self.vip = nn.Sequential(*block(full, full))               # [MT]
        self.mst = nn.Sequential(*block(2 * full, full))           # [MT, VIP]
        self.v4 = nn.Sequential(*block(full + 2 * half, full))     # [MT, interstripe, thin stripe]
        self.lip = nn.Sequential(*block(2 * full, full))           # [MST, V1]
        self.pit = nn.Sequential(*block(3 * full, full))           # [V4, MST, LIP]
        self.cit = nn.Sequential(*block(2 * full, full))           # [PIT, V4]
        self.layer_7a = nn.Sequential(*block(2 * full, full))      # [LIP, MST]
        # AIT (anterior inferotemporal) ends with global average pooling -> (B, full, 1, 1).
        self.ait = nn.Sequential(*block(2 * full, full), nn.AdaptiveAvgPool2d(1))  # [CIT, 7A]

        # Classifier: flatten the pooled features and map them to class logits.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(full, num_classes),
        )

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Return the [Conv2d 3x3 (padding 1), ReLU, BatchNorm2d] layers used by every area."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ]

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Tensor of shape ``(batch, input_channels, H, W)``
                (BatchNorm2d requires the 4-D batched layout).

        Returns:
            Unnormalised logits of shape ``(batch, num_classes)``.
        """
        # Early stage: V1 output is shared by the three stripe branches.
        v1_output = self.v1(x)
        thick_stripe_output = self.thick_stripe(v1_output)
        interstripe_output = self.interstripe(v1_output)
        thin_stripe_output = self.thin_stripe(v1_output)

        # Each later area consumes the channel-wise concatenation of its inputs;
        # the concatenation order must match the channel layout the convs were built for.
        mt_output = self.mt(torch.cat([v1_output, thick_stripe_output], dim=1))
        vip_output = self.vip(mt_output)
        mst_output = self.mst(torch.cat([mt_output, vip_output], dim=1))
        v4_output = self.v4(
            torch.cat([mt_output, interstripe_output, thin_stripe_output], dim=1)
        )
        # LIP takes MST plus a direct skip connection from V1.
        lip_output = self.lip(torch.cat([mst_output, v1_output], dim=1))
        pit_output = self.pit(torch.cat([v4_output, mst_output, lip_output], dim=1))
        cit_output = self.cit(torch.cat([pit_output, v4_output], dim=1))
        layer_7a_output = self.layer_7a(torch.cat([lip_output, mst_output], dim=1))
        ait_output = self.ait(torch.cat([cit_output, layer_7a_output], dim=1))

        return self.classifier(ait_output)

# Training and evaluation functions
def train(model, train_loader, optimizer, criterion, device, epoch):
    """Run one training epoch and return its average loss and accuracy.

    Args:
        model: Network to optimise; it is put into training mode here.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(output, target)``.
        device: Device every batch is moved to.
        epoch: Label used only in the progress messages.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``. The loss is the mean of the
        per-batch losses. Both values are ``nan`` when the loader yields no
        batches (previously this raised ``ZeroDivisionError``).
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Extract the scalar once; .item() forces a device sync.
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1

        # Predicted class = index of the largest logit along dim 1.
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {batch_loss:.4f}, Acc: {100.*correct/total:.2f}%')

    # Count batches actually seen instead of len(loader): works for any iterable
    # and avoids dividing by zero on an empty loader.
    avg_loss = running_loss / num_batches if num_batches else float('nan')
    accuracy = 100. * correct / total if total else float('nan')
    return avg_loss, accuracy

def evaluate(model, test_loader, criterion, device):
    """Evaluate ``model`` without gradient tracking and report loss/accuracy.

    Args:
        model: Network to evaluate; it is put into eval mode here.
        test_loader: Iterable of ``(data, target)`` batches.
        criterion: Loss called as ``criterion(output, target)``.
        device: Device every batch is moved to.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``. The loss is the mean of the
        per-batch losses. Both values are ``nan`` when the loader yields no
        batches (previously this raised ``ZeroDivisionError``).
    """
    model.eval()
    test_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            num_batches += 1
            # Predicted class = index of the largest logit along dim 1.
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

    # Count batches actually seen instead of len(loader): works for any iterable
    # and avoids dividing by zero on an empty loader.
    test_loss = test_loss / num_batches if num_batches else float('nan')
    accuracy = 100. * correct / total if total else float('nan')

    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')
    return test_loss, accuracy

def _plot_history(train_losses, test_losses, train_accuracies, test_accuracies,
                  path='training_results.png'):
    """Plot per-epoch loss and accuracy curves side by side, save them to *path*, then show them."""
    # The console log numbers epochs from 1, so the x-axis does too.
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Train Loss')
    plt.plot(epochs, test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss over epochs')

    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, label='Train Accuracy')
    plt.plot(epochs, test_accuracies, label='Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.title('Accuracy over epochs')

    plt.tight_layout()
    plt.savefig(path)
    plt.show()


def main(batch_size=64, learning_rate=0.001, num_epochs=10):
    """Train BrainInspiredVisualNetwork on MNIST, plot the curves and save the weights.

    Downloads MNIST into ``./data`` if needed, trains with Adam and
    cross-entropy, evaluates on the test split after every epoch, writes the
    curves to ``training_results.png`` and the weights to
    ``brain_inspired_model.pth``.

    Args:
        batch_size: Mini-batch size for both loaders.
        learning_rate: Adam learning rate.
        num_epochs: Number of training epochs.
    """
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load MNIST; the Normalize constants are the usual MNIST mean/std.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = BrainInspiredVisualNetwork(input_channels=1, base_features=32).to(device)
    print(model)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Per-epoch training history
    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []

    for epoch in range(num_epochs):
        train_loss, train_acc = train(model, train_loader, optimizer, criterion, device, epoch)
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)

        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        test_losses.append(test_loss)
        test_accuracies.append(test_acc)

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

    _plot_history(train_losses, test_losses, train_accuracies, test_accuracies)

    # Save the model
    torch.save(model.state_dict(), 'brain_inspired_model.pth')
    print("Model saved successfully!")


if __name__ == "__main__":
    main()

Using device: cpu
BrainInspiredVisualNetwork(
  (v1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (thick_stripe): Sequential(
    (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (interstripe): Sequential(
    (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (thin_stripe): Sequential(
    (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(16,