In [1]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

Looking in indexes: https://download.pytorch.org/whl/cu126


In [9]:
!pip install spikingjelly

Collecting spikingjelly
  Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl.metadata (15 kB)
Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl (437 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m437.6/437.6 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hInstalling collected packages: spikingjelly
Successfully installed spikingjelly-0.0.0.0.14


In [19]:
import warnings
warnings.filterwarnings('ignore')

import torch, os
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, models
from spikingjelly.clock_driven import neuron, functional, surrogate

In [12]:
# Per-split preprocessing. Every split is resized to 256x256, converted to a
# tensor and normalised with mean 0.5 / std 0.5 (the single value is broadcast
# over every channel). Only 'train' adds augmentation: a random rotation of up
# to 20 degrees and a random horizontal flip.
IMAGE_SIZE = (256, 256)


def _build_transform(augment):
    """Return the Compose pipeline for one split; `augment` inserts the train-time ops."""
    steps = [transforms.Resize(IMAGE_SIZE)]
    if augment:
        steps.extend([transforms.RandomRotation(20), transforms.RandomHorizontalFlip()])
    steps.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
    return transforms.Compose(steps)


data_transforms = {
    'train': _build_transform(augment=True),
    'test': _build_transform(augment=False),
    'val': _build_transform(augment=False),
}


In [4]:
# Expects one sub-folder per split under data_dir, each in ImageFolder layout
# (one sub-folder per class).
data_dir = '/kaggle/input/fundus-pytorch'
SPLITS = ('train', 'test', 'val')
BATCH_SIZE = 32

image_datasets = {
    split: datasets.ImageFolder(root=f"{data_dir}/{split}", transform=data_transforms[split])
    for split in SPLITS
}

# Only the training split is shuffled. A fixed order for val/test makes the
# reported (batch-averaged) loss reproducible between runs. Pinned host memory
# lets the `non_blocking=True` device copies in the training loop overlap with
# compute when CUDA is available.
dataloaders = {
    split: DataLoader(
        image_datasets[split],
        batch_size=BATCH_SIZE,
        shuffle=(split == 'train'),
        pin_memory=torch.cuda.is_available(),
    )
    for split in SPLITS
}

In [20]:
class GlaucomaSNN(nn.Module):
    """Convolutional spiking network (LIF neurons) for two-class fundus classification.

    The same input batch is presented for ``T`` simulation time steps and the
    per-step output logits are averaged.

    Args:
        T: number of simulation time steps per forward pass (default 4).
    """

    def __init__(self, T=4):
        super().__init__()
        self.T = T  # Number of time steps

        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            neuron.LIFNode(surrogate_function=surrogate.ATan()),  # Spiking activation
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            neuron.LIFNode(surrogate_function=surrogate.ATan()),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            neuron.LIFNode(surrogate_function=surrogate.ATan()),
            nn.AdaptiveAvgPool2d((8, 8)),  # Fixed 8x8 map regardless of input size
        )

        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 512),
            neuron.LIFNode(surrogate_function=surrogate.ATan()),
            nn.Dropout(0.3),
            nn.Linear(512, 2)
        )

    def forward(self, x):
        """Return the class logits averaged over ``self.T`` time steps.

        Assumes ``x`` is an image batch of shape (N, 3, H, W); the first conv
        expects 3 input channels. The result has 2 columns, one per class.
        """
        # LIFNode stores its membrane potential as state between calls. Clear it
        # so every forward pass starts from rest. Otherwise the previous batch's
        # state, and the autograd graph attached to it, leaks into this one, and
        # a batch of a different size cannot be combined with the stored state.
        functional.reset_net(self)

        logit_sum = 0
        for _ in range(self.T):
            # Direct encoding: the same image is fed at every step, and the LIF
            # layers integrate it over time.
            logit_sum = logit_sum + self.fc_layers(self.conv_layers(x))

        return logit_sum / self.T  # Average the output logits across time steps

In [21]:
# Use the GPU when one is visible. When several are visible, wrap the network
# in nn.DataParallel so each batch is split across them.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = GlaucomaSNN()

n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    print(f"Using {n_gpus} GPUs!")
    model = nn.DataParallel(model)

model = model.to(device)
print(device)

Using 2 GPUs!
cuda


In [22]:
# Loss and optimizer: cross-entropy on the two-class logits, optimised with
# Adam at a fixed learning rate.
LEARNING_RATE = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [23]:
# Training loop
def _run_epoch(model, loader, criterion, optimizer=None, log_every=None):
    """Make one pass over `loader`: train if `optimizer` is given, otherwise evaluate.

    Args:
        model: the network (optionally wrapped in nn.DataParallel).
        loader: DataLoader yielding (inputs, labels) batches.
        criterion: loss applied to (outputs, labels).
        optimizer: when not None, gradients are computed and a step is taken per batch.
        log_every: if set, print the batch loss every `log_every` batches.

    Returns:
        (mean per-batch loss, accuracy in percent).

    Raises:
        ValueError: if `loader` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # model.train(False) is the same as model.eval()
    running_loss, correct, total = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for batch_idx, (inputs, labels) in enumerate(loader):
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # LIF neurons keep their membrane potential between calls. Reset it
            # so each batch starts from rest, instead of inheriting the previous
            # batch's state and the autograd graph attached to it.
            functional.reset_net(model)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            if log_every and (batch_idx + 1) % log_every == 0:
                print(f"Batch {batch_idx+1}/{len(loader)} | Loss: {loss.item():.4f}")

    if total == 0:
        raise ValueError("loader yielded no samples")
    return running_loss / len(loader), 100. * correct / total


def train_model(model, dataloaders, criterion, optimizer, num_epochs=10):
    """Train `model` for `num_epochs`, running a validation pass after each epoch.

    Args:
        model: the network to train (optionally wrapped in nn.DataParallel).
        dataloaders: dict with 'train' and 'val' DataLoaders.
        criterion: loss function.
        optimizer: optimizer over `model`'s parameters.
        num_epochs: number of passes over the training set.

    Returns:
        A list with one dict per epoch holding train/val loss and accuracy.
    """
    history = []
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print("-" * 30)

        train_loss, train_acc = _run_epoch(
            model, dataloaders['train'], criterion, optimizer, log_every=10
        )
        print(f"Epoch {epoch+1} Loss: {train_loss:.4f} | Accuracy: {train_acc:.2f}%\n")

        # Validation phase
        val_loss, val_acc = _run_epoch(model, dataloaders['val'], criterion)
        print(f"Validation Loss: {val_loss:.4f} | Validation Accuracy: {val_acc:.2f}%\n")

        history.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
        })
    return history


history = train_model(model, dataloaders, criterion, optimizer, num_epochs=50)

Epoch 1/50
------------------------------
Batch 10/270 | Loss: 0.7704
Batch 20/270 | Loss: 0.6718
Batch 30/270 | Loss: 0.5533
Batch 40/270 | Loss: 0.6580
Batch 50/270 | Loss: 0.6096
Batch 60/270 | Loss: 0.7000
Batch 70/270 | Loss: 0.6421
Batch 80/270 | Loss: 0.6982
Batch 90/270 | Loss: 0.6372
Batch 100/270 | Loss: 0.6140
Batch 110/270 | Loss: 0.6289
Batch 120/270 | Loss: 0.5838
Batch 130/270 | Loss: 0.5187
Batch 140/270 | Loss: 0.6735
Batch 150/270 | Loss: 0.7201
Batch 160/270 | Loss: 0.5441
Batch 170/270 | Loss: 0.5451
Batch 180/270 | Loss: 0.6894
Batch 190/270 | Loss: 0.6296
Batch 200/270 | Loss: 0.7258
Batch 210/270 | Loss: 0.6961
Batch 220/270 | Loss: 0.5332
Batch 230/270 | Loss: 0.5997
Batch 240/270 | Loss: 0.4393
Batch 250/270 | Loss: 0.5173
Batch 260/270 | Loss: 0.5135
Batch 270/270 | Loss: 0.4991
Epoch 1 Loss: 0.5934 | Accuracy: 69.59%

Validation Loss: 0.5457 | Validation Accuracy: 73.99%

Epoch 2/50
------------------------------
Batch 10/270 | Loss: 0.5141
Batch 20/270 | Los

In [25]:
# Continue training the same model for 5 more epochs. This reuses `train_model`
# from the training-loop cell above rather than pasting a second copy of its
# definition here, because a second `def` would silently shadow the first.
train_model(model, dataloaders, criterion, optimizer, num_epochs=5);

Epoch 1/5
------------------------------
Batch 10/270 | Loss: 0.3293
Batch 20/270 | Loss: 0.2029
Batch 30/270 | Loss: 0.2180
Batch 40/270 | Loss: 0.1726
Batch 50/270 | Loss: 0.3503
Batch 60/270 | Loss: 0.4824
Batch 70/270 | Loss: 0.2497
Batch 80/270 | Loss: 0.2939
Batch 90/270 | Loss: 0.4331
Batch 100/270 | Loss: 0.1668
Batch 110/270 | Loss: 0.2951
Batch 120/270 | Loss: 0.2473
Batch 130/270 | Loss: 0.1971
Batch 140/270 | Loss: 0.3398
Batch 150/270 | Loss: 0.4505
Batch 160/270 | Loss: 0.2370
Batch 170/270 | Loss: 0.2318
Batch 180/270 | Loss: 0.3375
Batch 190/270 | Loss: 0.3058
Batch 200/270 | Loss: 0.1428
Batch 210/270 | Loss: 0.2063
Batch 220/270 | Loss: 0.1731
Batch 230/270 | Loss: 0.3268
Batch 240/270 | Loss: 0.1944
Batch 250/270 | Loss: 0.3154
Batch 260/270 | Loss: 0.4840
Batch 270/270 | Loss: 0.7212
Epoch 1 Loss: 0.2868 | Accuracy: 87.50%

Validation Loss: 0.2953 | Validation Accuracy: 86.36%

Epoch 2/5
------------------------------
Batch 10/270 | Loss: 0.2503
Batch 20/270 | Loss:

In [26]:
def evaluate_model(model, dataloader, criterion):
    """Evaluate `model` on `dataloader`, print the results, and return them.

    Args:
        model: the network (optionally wrapped in nn.DataParallel).
        dataloader: DataLoader yielding (inputs, labels) batches.
        criterion: loss applied to (outputs, labels).

    Returns:
        (test_loss, test_acc): mean per-batch loss and accuracy in percent.

    Raises:
        ValueError: if `dataloader` yields no samples.
    """
    model.eval()
    total = 0
    correct = 0
    running_loss = 0.0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Spiking neurons keep membrane state between calls; start each batch from rest.
            functional.reset_net(model)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("dataloader yielded no samples")

    test_loss = running_loss / len(dataloader)
    test_acc = 100. * correct / total
    print(f"Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.2f}%")
    return test_loss, test_acc

test_loss, test_acc = evaluate_model(model, dataloaders['test'], criterion)

Test Loss: 0.3144 | Test Accuracy: 86.08%


In [27]:
# Unwrap nn.DataParallel before saving. The wrapper's state_dict prefixes every
# key with "module.", which would stop the checkpoint from loading directly into
# a plain GlaucomaSNN().
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(model_to_save.state_dict(), "glaucoma_model_SNN_86TA.pth")
print("Model saved successfully!")

Model saved successfully!
