In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from vonenet.utils import gabor_kernel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GaborFilterBank(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, device, stride=4):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (kernel_size, kernel_size)
        self.stride = (stride, stride)
        self.padding = (kernel_size // 2, kernel_size//2)

        self.weight = torch.zeros((out_channels, in_channels, kernel_size, kernel_size)).to(device)
        

    def initialize(self, sf, theta, sigx, sigy, phase, seed=None):
        if seed is not None:
            torch.manual_seed(seed)
        
        random_channel = torch.randint(0, self.in_channels, (self.out_channels,))
        for i in range(self.out_channels):
            self.weight[i, random_channel[i]] = gabor_kernel(
                frequency=sf[i], sigma_x=sigx[i], sigma_y=sigy[i],
                theta=theta[i], offset=phase[i], ks=self.kernel_size[0]
            )
        self.weight = nn.Parameter(self.weight, requires_grad=False)

    def forward(self, x):
        # [256, 3, 15, 15] 
        #print("x",x.shape, x.device)
        #print("weight", self.weight.shape, self.weight.device)
        return F.conv2d(x, self.weight, None, stride=self.stride, padding=self.padding)

class NCRFModule(nn.Module):
    def __init__(self, gabor_bank, modulation_strength=0.5, surround_kernel_size=3):
        super(NCRFModule, self).__init__()
        self.gabor_bank = gabor_bank
        self.modulation_strength = modulation_strength
        self.surround_kernel_size = surround_kernel_size

    def forward(self, x):
        #print("here")
        #center_response = self.gabor_bank(x)
        #print("center_response")
        surround_response = F.avg_pool2d(
            x, kernel_size=self.surround_kernel_size, stride=1, padding=self.surround_kernel_size // 2
        ) * self.modulation_strength
        modulated_response = x / (1 + surround_response)
        return modulated_response

class V1Processing(nn.Module):
    def __init__(self, in_channels=3, out_channels=256, kernel_size=9, stride=4, noise_mode='neuronal', noise_level=0.07, noise_scale=0.35, device='cpu'):
        super(V1Processing, self).__init__()
        # self.gabor_bank = GaborFilterBank(in_channels, out_channels, kernel_size)
        self.simple_conv_q0 = GaborFilterBank(in_channels, out_channels, kernel_size, device, stride)
        self.simple_conv_q1 = GaborFilterBank(in_channels, out_channels, kernel_size, device, stride)
        self.ncrf_q0 = NCRFModule(self.simple_conv_q0)
        self.ncrf_q1 = NCRFModule(self.simple_conv_q1)
        
        # Simple and Complex Cells

        # ---- FIX: Initialize Gabor filters with random parameters ----
        sf    = torch.rand(out_channels) * 0.5 + 0.1    # frequencies between 0.1 and 0.6
        theta = torch.rand(out_channels) * np.pi          # orientations between 0 and pi
        sigx  = torch.rand(out_channels) * 2 + 1.0          # sigma between 1 and 3
        sigy  = torch.rand(out_channels) * 2 + 1.0
        phase = torch.rand(out_channels) * 2 * np.pi        # phase between 0 and 2pi
        self.simple_conv_q0.initialize(sf, theta, sigx, sigy, phase)
        self.simple_conv_q0.initialize(sf, theta, sigx, sigy, phase + np.pi/2)

        self.simple = nn.ReLU()
        self.complex = nn.Identity()
        self.gabors = nn.Identity()
        self.noise = nn.ReLU()
        self.output = nn.Identity()
        
        # Noise settings (noise level reduced from 0.1 to 0.01)
        self.noise_mode = noise_mode
        self.noise_level = noise_level
        self.noise_scale = noise_scale
        self.fixed_noise = None
        self.simple_channels = 128 
        self.complex_channels = 128 
        self.k_exc = 25


    def set_noise_mode(self, mode='gaussian', level=0.01):
        self.noise_mode = mode
        self.noise_level = level

    def fix_noise(self, batch_size, shape):
        self.fixed_noise = torch.randn(batch_size, *shape, device=device) * self.noise_level
    
    def noise_f(self, x):
        if self.noise_mode == 'gaussian':
            return x + torch.randn_like(x) * self.noise_level
        elif self.noise_mode == 'neuronal':
            eps = 10e-5
            x *= self.noise_scale
            x += self.noise_level
            if self.fixed_noise is not None:
                x += self.fixed_noise * torch.sqrt(F.relu(x.clone()) + eps)
            else:
                x += torch.distributions.normal.Normal(torch.zeros_like(x), scale=1).rsample() * \
                     torch.sqrt(F.relu(x.clone()) + eps)
            x -= self.noise_level
            x /= self.noise_scale
            return self.noise(x)
        else:
            return x

    def forward(self, x):
        '''
        s_q0 = self.ncrf_q0(self.simple_conv_q0(x))
        s_q1 = self.ncrf_q1(self.simple_conv_q1(x))
        c = self.complex(torch.sqrt(s_q0[:, self.simple_channels:, :, :] ** 2 +
                                    s_q1[:, self.simple_channels:, :, :] ** 2) / np.sqrt(2))
        s = self.simple(s_q0[:, 0:self.simple_channels, :, :])
        response = self.gabors(self.k_exc * torch.cat((s, c), 1))
        
        simple_response = self.noise_f(response)
        output = self.output(simple_response)
        print(output.shape) # torch.Size([32, 3, 14, 14]) 
        return output
        '''
        #print(f"Input shape: {x.shape}")  # Print initial input shape
        #print(self.simple_conv_q0)
        #print(self.ncrf_q0)
        s_q0 = self.simple_conv_q0(x)
        #print(f"Shape after simple_conv_q0 + NCRF: {s_q0.shape}")
    
        s_q1 = self.simple_conv_q1(x)
        #print(f"Shape after simple_conv_q1 + NCRF: {s_q1.shape}")
    
        c = self.complex(torch.sqrt(s_q0[:, self.simple_channels:, :, :] ** 2 +
                                    s_q1[:, self.simple_channels:, :, :] ** 2) / np.sqrt(2))
        #print(f"Shape of complex cell response: {c.shape}")
    
        s = self.simple(s_q0[:, 0:self.simple_channels, :, :])
        #print(f"Shape of simple cell response: {s.shape}")
    
        response = self.gabors(self.k_exc * torch.cat((s, c), 1))
        #print(f"Shape after concatenating simple & complex responses: {response.shape}")
    
        simple_response = self.noise_f(response)
        #print(f"Shape after noise function: {simple_response.shape}")
    
        output = self.output(simple_response)
        #print(f"Final output shape: {output.shape}")

        output = self.ncrf_q0(simple_response)
        #print(f"Final1 output shape: {output.shape}")
    
        return output

In [2]:
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.models import alexnet
from collections import OrderedDict

# Import custom Gabor filter bank and processing module from above

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
class AlexNetBackEnd(nn.Module):
    def __init__(self, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(64, 192, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


class AlexNetWithGabor(nn.Module):
    def __init__(self, in_channels=3, num_classes=10):
        super(AlexNetWithGabor, self).__init__()
        gabor = V1Processing(in_channels, out_channels=256, kernel_size=15, stride=4, device='cuda:0').to(device)
        bottleneck = nn.Conv2d(256, 64, kernel_size=1, stride=1, bias=False)
        print(bottleneck)
        nn.init.kaiming_normal_(bottleneck.weight, mode='fan_out', nonlinearity='relu')
        #self.alexnet = alexnet(num_classes=num_classes).to(device)
        
        # Modify first layer to accept 64 Gabor channels instead of 3
        #self.alexnet.features[0] = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)
        model_back_end = AlexNetBackEnd() 
        print(model_back_end)
        self.model = nn.Sequential(OrderedDict([
            ('ncrf', gabor),
            ('bottleneck',bottleneck),
            ('model', model_back_end)
        ]))

    def forward(self, x):
        x = self.model(x)
        return x

model = AlexNetWithGabor().to(device)
print(model)
#model= nn.DataParallel(model).to(device)
criterion = nn.CrossEntropyLoss()
# ---- FIX: Lowered learning rate for Adam (from 0.1 to 1e-3) ----
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train(model, trainloader, criterion, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in trainloader:
            images, labels = images.cuda(), labels.cuda()
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            # ---- FIX: Gradient clipping to prevent exploding gradients ----
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(trainloader):.4f}")

def test(model, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Test Accuracy: {100 * correct / total:.2f}%")

Files already downloaded and verified
Files already downloaded and verified


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
AlexNetBackEnd(
  (features): Sequential(
    (0): Conv2d(64, 192, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (3): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=12544, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, b

In [None]:
train(model, trainloader, criterion, optimizer, epochs=30)
test(model, testloader)