In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from vonenet.utils import gabor_kernel

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class GaborFilterBank(nn.Module):
    """Fixed (non-trainable) convolution whose kernels are Gabor filters.

    Every output channel reads exactly one randomly chosen input channel; the
    kernel slices for all other input channels stay at zero.

    Args:
        in_channels: number of channels in the input.
        out_channels: number of Gabor filters, i.e. output channels.
        kernel_size: side length of the square kernels. Padding is
            ``kernel_size // 2``.
        stride: stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = kernel_size // 2

        # Registered as a Parameter so it follows .to()/state_dict(), but frozen:
        # the filter bank is never updated by the optimizer.
        self.weight = nn.Parameter(
            torch.zeros((out_channels, in_channels, kernel_size, kernel_size)), requires_grad=False
        )

    def initialize(self, sf, theta, sigx, sigy, phase, seed=None):
        """Fill the bank with one Gabor kernel per output channel.

        Args:
            sf, theta, sigx, sigy, phase: per-filter spatial frequency,
                orientation, x/y envelope widths and phase offset. Each is
                indexed with ``0 .. out_channels - 1``.
            seed: optional seed for the input-channel assignment. It drives a
                local generator, so the global torch RNG state is not touched.
                With ``None`` the global RNG is used.
        """
        generator = None
        if seed is not None:
            generator = torch.Generator()
            generator.manual_seed(seed)

        random_channel = torch.randint(
            0, self.in_channels, (self.out_channels,), generator=generator
        )

        with torch.no_grad():
            # Clear kernels left by an earlier call; otherwise a filter could end
            # up reading several input channels after re-initialization.
            self.weight.zero_()
            for i in range(self.out_channels):
                # NOTE(review): gabor_kernel is assumed to return a
                # (kernel_size, kernel_size) array/tensor - confirm in vonenet.utils.
                self.weight[i, random_channel[i]] = gabor_kernel(
                    frequency=sf[i], sigma_x=sigx[i], sigma_y=sigy[i],
                    theta=theta[i], offset=phase[i], ks=self.kernel_size
                )

    def forward(self, x):
        """Convolve ``x`` with the filter bank (no bias)."""
        return F.conv2d(x, self.weight, None, stride=self.stride, padding=self.padding)

class NCRFModule(nn.Module):
    """Divisive surround modulation applied to the output of a filter bank.

    The centre response is divided by ``1 + strength * surround``, where the
    surround is a local average of the response *magnitude*.

    Args:
        gabor_bank: module producing the centre response.
        modulation_strength: initial value of the learnable surround gain.
        surround_kernel_size: side length of the averaging window.
        eps: lower bound for the denominator. Only takes effect if the learned
            gain becomes negative enough to push the denominator towards zero.
    """

    def __init__(self, gabor_bank, modulation_strength=0.5, surround_kernel_size=3, eps=1e-6):
        super(NCRFModule, self).__init__()
        self.gabor_bank = gabor_bank
        self.modulation_strength = nn.Parameter(torch.tensor(modulation_strength, dtype=torch.float32))
        self.surround_kernel_size = surround_kernel_size
        self.eps = eps

    def forward(self, x):
        """Return the surround-normalized centre response for ``x``."""
        center_response = self.gabor_bank(x)
        # Pool magnitudes, not signed values: filter responses can be negative,
        # and a signed surround lets the denominator hit zero or change sign,
        # which yields inf/NaN or flips the sign of the response.
        surround_response = F.avg_pool2d(
            center_response.abs(),
            kernel_size=self.surround_kernel_size,
            stride=1,
            padding=self.surround_kernel_size // 2,
        ) * self.modulation_strength
        denominator = (1 + surround_response).clamp_min(self.eps)
        return center_response / denominator

class V1Processing(nn.Module):
    """Front end: Gabor filter bank -> surround modulation -> energy -> noise.

    Args:
        in_channels: number of channels in the input.
        out_channels: number of Gabor filters / output channels.
        kernel_size: side length of the Gabor kernels.
        noise_mode: ``'gaussian'`` (additive), ``'neuronal'`` (divisive); any
            other value disables noise.
        noise_level: scale of the injected noise.
    """

    # Added under the square root in forward(). d/du sqrt(u) is infinite at
    # u == 0, which turns the gradient of the learnable modulation strength into
    # NaN whenever a response is exactly zero.
    _SQRT_EPS = 1e-12

    def __init__(self, in_channels, out_channels, kernel_size=9, noise_mode='gaussian', noise_level=0.01):
        super(V1Processing, self).__init__()
        self.gabor_bank = GaborFilterBank(in_channels, out_channels, kernel_size)
        self.ncrf = NCRFModule(self.gabor_bank)

        # Simple-cell rectification and (currently pass-through) complex-cell stage.
        self.simple = nn.ReLU()
        self.complex = nn.Identity()

        # Noise configuration; fixed_noise stays None until fix_noise() is called.
        self.noise_mode = noise_mode
        self.noise_level = noise_level
        self.fixed_noise = None

        # Draw random Gabor parameters; without this the bank stays all-zero.
        sf    = torch.rand(out_channels) * 0.5 + 0.1   # spatial frequency in [0.1, 0.6)
        theta = torch.rand(out_channels) * np.pi       # orientation in [0, pi)
        sigx  = torch.rand(out_channels) * 2 + 1.0     # envelope width in [1, 3)
        sigy  = torch.rand(out_channels) * 2 + 1.0
        phase = torch.rand(out_channels) * 2 * np.pi   # phase in [0, 2*pi)
        self.gabor_bank.initialize(sf, theta, sigx, sigy, phase)

    def set_noise_mode(self, mode='gaussian', level=0.01):
        """Select the noise model and its scale."""
        self.noise_mode = mode
        self.noise_level = level

    def fix_noise(self, batch_size, shape):
        """Sample a noise tensor once and reuse it on every later forward pass.

        The tensor has shape ``(batch_size, *shape)``, is scaled by the current
        ``noise_level`` and is created on the device of the filter bank. It must
        broadcast against the response it is applied to. Set
        ``self.fixed_noise = None`` to return to fresh noise.
        """
        target = self.gabor_bank.weight.device
        self.fixed_noise = torch.randn(batch_size, *shape, device=target) * self.noise_level

    def noise_f(self, x):
        """Apply the configured noise model to ``x``.

        Uses the tensor stored by fix_noise() when there is one; otherwise
        samples fresh noise on every call.
        """
        frozen = self.fixed_noise
        if frozen is not None:
            # Replicas (e.g. under DataParallel) may run on a different device.
            frozen = frozen.to(x.device)

        if self.noise_mode == 'gaussian':
            noise = frozen if frozen is not None else torch.randn_like(x) * self.noise_level
            return x + noise
        elif self.noise_mode == 'neuronal':
            gain = frozen.abs() if frozen is not None else self.noise_level * torch.randn_like(x).abs()
            return x / (1 + gain)
        else:
            return x

    def forward(self, x):
        """Return the noisy energy response for ``x``."""
        # Run the filter bank + surround stage once and reuse the result for
        # both the rectified and the raw term.
        response = self.ncrf(x)
        simple_response = self.simple(response)
        energy = simple_response ** 2 + response ** 2
        complex_response = self.complex(torch.sqrt(energy + self._SQRT_EPS))
        return self.noise_f(complex_response)

In [None]:
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.models import alexnet

# GaborFilterBank, NCRFModule and V1Processing are defined in the previous cell; run it first.

# Preprocessing: upsample the CIFAR-10 images to 224x224, convert them to
# tensors and normalize each channel with the mean/std values listed below.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Train split first, then test split; both are downloaded to ./data if missing.
trainset, testset = (
    torchvision.datasets.CIFAR10(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader shuffles; evaluation order stays fixed.
trainloader = DataLoader(trainset, shuffle=True, batch_size=64, num_workers=2)
testloader = DataLoader(testset, shuffle=False, batch_size=64, num_workers=2)

class AlexNetWithGabor(nn.Module):
    """AlexNet classifier fed by a V1Processing front end.

    The module does not place itself on a device. Move the finished model with
    ``.to(...)`` (or wrap it in ``nn.DataParallel``) so that every submodule,
    including the replaced first convolution, ends up on the same device.

    Args:
        in_channels: number of channels in the input image.
        num_classes: number of output classes.
        gabor_channels: number of channels produced by the front end; also the
            input width of AlexNet's first convolution.
        gabor_kernel_size: side length of the Gabor kernels.
    """

    def __init__(self, in_channels=3, num_classes=10, gabor_channels=256, gabor_kernel_size=15):
        super(AlexNetWithGabor, self).__init__()
        self.gabor = V1Processing(
            in_channels, out_channels=gabor_channels, kernel_size=gabor_kernel_size
        )
        self.alexnet = alexnet(num_classes=num_classes)

        # Replace the first convolution so it accepts the front-end channels
        # instead of 3 image channels; the other hyper-parameters are unchanged.
        self.alexnet.features[0] = nn.Conv2d(gabor_channels, 64, kernel_size=11, stride=4, padding=2)

    def forward(self, x):
        """Return class logits for a batch of images."""
        x = self.gabor(x)
        x = self.alexnet(x)
        return x


# Build the network, wrap it for multi-GPU execution and move it to `device`.
model = nn.DataParallel(AlexNetWithGabor()).to(device)
criterion = nn.CrossEntropyLoss()
# Adam with a learning rate of 1e-4.
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train(model, trainloader, criterion, optimizer, epochs=10):
    """Train ``model`` in place and print the mean batch loss after each epoch.

    Args:
        model: network to optimize; it is switched to training mode.
        trainloader: iterable yielding ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the parameters of ``model``.
        epochs: number of passes over ``trainloader``.

    Batches are moved to the device holding the model's first parameter, so the
    loop also runs on CPU-only machines. If the loader yields no batches the
    reported loss is ``nan``.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in trainloader:
            images, labels = images.to(target_device), labels.to(target_device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            # Cap the global gradient norm before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")

def test(model, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Test Accuracy: {100 * correct / total:.2f}%")

In [None]:
# Train for 10 epochs, then report top-1 accuracy on the test loader.
# Order matters: test() evaluates the weights left behind by train().
train(model, trainloader, criterion, optimizer, epochs=10)
test(model, testloader)