In [1]:
!pip3 install snntorch

Collecting snntorch
  Downloading snntorch-0.9.4-py2.py3-none-any.whl.metadata (15 kB)
Downloading snntorch-0.9.4-py2.py3-none-any.whl (125 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: snntorch
Successfully installed snntorch-0.9.4


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
import numpy as np
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import pandas as pd
from transformers import AutoTokenizer
from datasets import load_dataset
import re
from collections import Counter
import math

# Reproducibility: seed the RNGs that drive weight initialisation, dropout,
# the training-time noise injection and DataLoader shuffling.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
batch_size = 16        # samples per optimisation step
num_epochs = 10        # upper bound; train_model() may stop earlier (early stopping)
learning_rate = 1e-4   # AdamW base learning rate
num_steps = 50         # simulation steps: the network reads at most this many leading tokens
vocab_size = 30522     # BERT vocab size (train_model() takes the size from the tokenizer instead)
embedding_dim = 256
hidden_dim = 512
num_classes = 2
beta_stm = 0.8         # STM membrane decay (fast)
beta_ltm = 0.95        # LTM membrane decay (slow)
threshold = 1.0        # LIF firing threshold
max_length = 128       # tokenizer pad/truncate length; only the first num_steps tokens are processed
dropout_rate = 0.4     # shared by every dropout layer in the model

class IMDBDataset(Dataset):
    """Map-style dataset that tokenizes one review per item.

    Args:
        texts: Indexable sequence of review texts.
        labels: Indexable sequence of integer class labels, aligned with ``texts``.
        tokenizer: Callable accepting ``(text, max_length=, truncation=, padding=,
            return_tensors=)`` and returning a mapping with an ``'input_ids'``
            tensor of shape ``(1, max_length)``.
        max_length: Length every sample is padded / truncated to.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(input_ids, label)``: a 1-D id tensor and a 0-d long tensor."""
        text = str(self.texts[idx])
        label = self.labels[idx]

        # Tokenize using pretrained tokenizer
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        # Drop only the leading batch dimension. A bare squeeze() would also
        # collapse the sequence dimension whenever max_length == 1.
        input_ids = encoding['input_ids'].squeeze(0)

        return input_ids, torch.tensor(label, dtype=torch.long)

class AdaptiveSTDPLayer(nn.Module):
    """Leaky integrate-and-fire layer with an additional online STDP weight nudge.

    One call to ``forward`` advances the layer by a single time step:
    LayerNorm -> dropout -> linear projection -> small Gaussian noise ->
    ``snn.Leaky`` neuron. While the module is in training mode the linear
    weights are also adjusted in place by a trace-based STDP rule, on top of
    whatever the optimizer does with the back-propagated gradients.

    Args:
        input_dim: Number of input features.
        output_dim: Number of LIF neurons.
        beta: Membrane decay passed to ``snn.Leaky``.
        threshold: Firing threshold passed to ``snn.Leaky``.

    Note:
        The dropout probability is read from the module-level ``dropout_rate``.
    """

    def __init__(self, input_dim, output_dim, beta, threshold=1.0):
        super(AdaptiveSTDPLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.beta = beta
        self.threshold = threshold

        # Learnable parameters with proper initialization
        self.linear = nn.Linear(input_dim, output_dim)
        # The norm is applied to the layer *input* (see forward), so it must be
        # sized by input_dim; sizing it by output_dim only works when both match.
        self.layer_norm = nn.LayerNorm(input_dim)

        # Initialize weights properly
        nn.init.kaiming_uniform_(self.linear.weight, a=math.sqrt(5))
        nn.init.constant_(self.linear.bias, 0)

        # Use ATan surrogate for better gradient flow
        self.spike_grad = surrogate.atan(alpha=2.0)
        self.lif = snn.Leaky(beta=beta, threshold=threshold, spike_grad=self.spike_grad)

        # Dropout on the normalised input
        self.dropout = nn.Dropout(dropout_rate)

        # STDP time constants and amplitudes
        self.register_buffer('tau_plus', torch.tensor(15.0))
        self.register_buffer('tau_minus', torch.tensor(30.0))
        self.A_plus = 0.001
        self.A_minus = 0.001

        # Batch-averaged eligibility traces, carried across calls
        self.register_buffer('pre_trace', torch.zeros(1, input_dim))
        self.register_buffer('post_trace', torch.zeros(1, output_dim))

    def forward(self, x, mem=None):
        """Advance the layer by one time step.

        Args:
            x: Input activity; expected as ``(batch, input_dim)``.
            mem: Membrane potential from the previous step, or ``None`` to
                start from a freshly initialised state.

        Returns:
            ``(spk, mem)``: output spikes and the updated membrane potential.
        """
        # Initialize membrane potential if needed
        if mem is None:
            mem = self.lif.init_leaky()

        # Apply layer normalization before linear transformation
        x_norm = self.layer_norm(x) if x.dim() > 1 else x

        # Apply dropout during training
        if self.training:
            x_norm = self.dropout(x_norm)

        # Linear transformation
        cur = self.linear(x_norm)

        # Add noise during training for regularization
        if self.training:
            noise = torch.randn_like(cur) * 0.01
            cur = cur + noise

        # LIF neuron processing
        spk, mem = self.lif(cur, mem)

        # Apply STDP learning during training
        if self.training:
            self.apply_stdp(x_norm, spk)

        return spk, mem

    def apply_stdp(self, pre_spikes, post_spikes):
        """Update the STDP traces and nudge ``linear.weight`` in place.

        Args:
            pre_spikes: Pre-synaptic activity, ``(batch, input_dim)``. ``forward``
                passes the normalised, dropout-masked input here, not binary spikes.
            post_spikes: Spikes emitted by this layer, ``(batch, output_dim)``.

        The traces and the weight update are computed entirely outside autograd:
        storing graph-attached tensors in the trace buffers would keep a growing
        chain of graph nodes alive across time steps and batches.
        """
        if not self.training:
            return

        with torch.no_grad():
            pre = pre_spikes.detach()
            post = post_spikes.detach()

            # Decay the stored (1, dim) traces, then add the current activity;
            # broadcasting expands the traces to one row per sample.
            pre_trace = self.pre_trace * torch.exp(-1.0 / self.tau_plus) + pre
            post_trace = self.post_trace * torch.exp(-1.0 / self.tau_minus) + post

            # Persist the batch-averaged traces for the next call
            self.pre_trace = pre_trace.mean(dim=0, keepdim=True)
            self.post_trace = post_trace.mean(dim=0, keepdim=True)

            # LTP: post-synaptic spike increases weights
            ltp_update = torch.outer(post.mean(0), pre_trace.mean(0))

            # LTD: pre-synaptic spike decreases weights
            ltd_update = torch.outer(post_trace.mean(0), pre.mean(0))

            weight_update = self.A_plus * ltp_update - self.A_minus * ltd_update

            # NOTE: the update goes through .data on purpose. The weight has
            # already been used by this step's linear op, so a version-tracked
            # in-place write would make the later backward pass raise.
            self.linear.weight.data += weight_update * 0.0001
            self.linear.weight.data = torch.clamp(self.linear.weight.data, -2.0, 2.0)

class MemoryGatingLayer(nn.Module):
    """Learned element-wise blend of the STM and LTM feature vectors.

    A sigmoid gate computed from both inputs decides, per feature, how much of
    the STM value versus the LTM value reaches the output.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # The gate looks at both memories side by side, hence the doubled input width.
        self.gate_linear = nn.Linear(2 * hidden_dim, hidden_dim)
        self.gate_activation = nn.Sigmoid()

    def forward(self, stm_output, ltm_output):
        """Return ``g * stm_output + (1 - g) * ltm_output`` with ``g`` in (0, 1)."""
        joint = torch.cat((stm_output, ltm_output), dim=-1)
        mix = self.gate_activation(self.gate_linear(joint))
        return mix * stm_output + (1 - mix) * ltm_output

class NeuromorphicSentimentNet(nn.Module):
    """Spiking sentiment classifier with parallel fast (STM) and slow (LTM) pathways.

    Forward pipeline: token ids -> embedding -> rate encoder -> two
    ``AdaptiveSTDPLayer`` branches stepped token by token -> one shared
    multi-head self-attention applied to each branch's spike sequence ->
    softmax-weighted temporal pooling -> ``MemoryGatingLayer`` ->
    LayerNorm / dropout -> MLP classifier.

    Args:
        vocab_size: Size of the token vocabulary.
        embedding_dim: Width of the token embeddings.
        hidden_dim: Width of the spiking layers and of the attention block.
        num_classes: Number of output logits.
        beta_stm: Membrane decay of the STM branch.
        beta_ltm: Membrane decay of the LTM branch.
        threshold: Firing threshold shared by both branches.

    Note:
        The number of simulated steps and the dropout probability are read from
        the module-level ``num_steps`` and ``dropout_rate``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes,
                 beta_stm, beta_ltm, threshold):
        super(NeuromorphicSentimentNet, self).__init__()

        self.hidden_dim = hidden_dim
        self.num_steps = num_steps

        # Embedding layer with proper initialization
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        nn.init.normal_(self.embedding.weight, 0, 0.1)
        # normal_ re-draws the whole table, including the padding row that
        # nn.Embedding had zeroed. That row never receives gradient, so without
        # this reset every pad token would carry a fixed random vector.
        with torch.no_grad():
            self.embedding.weight[self.embedding.padding_idx].zero_()

        # Rate encoding with batch normalization
        self.rate_encoder = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Tanh()
        )

        # STM Layer (fast dynamics) - processes immediate context
        self.stm_layer = AdaptiveSTDPLayer(hidden_dim, hidden_dim, beta_stm, threshold)

        # LTM Layer (slow dynamics) - processes long-term patterns
        self.ltm_layer = AdaptiveSTDPLayer(hidden_dim, hidden_dim, beta_ltm, threshold)

        # Memory gating mechanism
        self.memory_gate = MemoryGatingLayer(hidden_dim)

        # Attention mechanism for sequence processing (shared by both branches)
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True
        )

        # Final integration layers
        self.integration_norm = nn.LayerNorm(hidden_dim)
        self.integration_dropout = nn.Dropout(dropout_rate)

        # Classification head with regularization
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        # Initialize classifier weights
        for layer in self.classifier:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_uniform_(layer.weight)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: Integer tensor of shape ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        batch_size, seq_len = x.shape

        # Initialize memory states
        stm_mem = self.stm_layer.lif.init_leaky()
        ltm_mem = self.ltm_layer.lif.init_leaky()

        # Store outputs over time
        stm_outputs = []
        ltm_outputs = []

        # Only the first num_steps tokens are simulated; later tokens are ignored.
        effective_seq_len = min(seq_len, self.num_steps)

        for t in range(effective_seq_len):
            # Get embeddings for current time step
            emb = self.embedding(x[:, t])

            # Rate encoding - convert to neural activity
            rate_input = self.rate_encoder(emb)

            # Squash the encoder output into (-1, 1) before it drives the LIF layers
            spike_input = torch.sigmoid(rate_input) * 2.0 - 1.0

            # STM processing (fast adaptation)
            stm_spk, stm_mem = self.stm_layer(spike_input, stm_mem)
            stm_outputs.append(stm_spk)

            # LTM processing (slow adaptation)
            ltm_spk, ltm_mem = self.ltm_layer(spike_input, ltm_mem)
            ltm_outputs.append(ltm_spk)

        # Aggregate outputs over time
        if stm_outputs and ltm_outputs:
            # Stack and process temporal information
            stm_sequence = torch.stack(stm_outputs, dim=1)  # [batch, time, hidden]
            ltm_sequence = torch.stack(ltm_outputs, dim=1)  # [batch, time, hidden]

            # Apply attention to capture important temporal patterns
            stm_attended, _ = self.attention(stm_sequence, stm_sequence, stm_sequence)
            ltm_attended, _ = self.attention(ltm_sequence, ltm_sequence, ltm_sequence)

            # Temporal pooling: softmax over the step index, so the weights grow
            # exponentially with t and the last few steps dominate the pooled vector.
            time_weights = torch.softmax(torch.arange(effective_seq_len, dtype=torch.float32, device=x.device), dim=0)
            stm_pooled = (stm_attended * time_weights.view(1, -1, 1)).sum(dim=1)
            ltm_pooled = (ltm_attended * time_weights.view(1, -1, 1)).sum(dim=1)

        else:
            # No step was simulated (zero-length sequence): fall back to zeros
            stm_pooled = torch.zeros(batch_size, self.hidden_dim, device=x.device)
            ltm_pooled = torch.zeros(batch_size, self.hidden_dim, device=x.device)

        # Apply memory gating
        gated_output = self.memory_gate(stm_pooled, ltm_pooled)

        # Final normalization and dropout
        integrated = self.integration_norm(gated_output)
        integrated = self.integration_dropout(integrated)

        # Classification
        output = self.classifier(integrated)

        return output

def load_imdb_data():
    """Fetch the IMDB reviews from the Hugging Face hub.

    Returns:
        ``(train_texts, train_labels, test_texts, test_labels)`` taken from the
        complete ``train`` and ``test`` splits (no sub-sampling is applied).
    """
    print("Loading IMDB dataset...")

    imdb = load_dataset('imdb')
    train_split, test_split = imdb['train'], imdb['test']

    train_texts, train_labels = train_split['text'], train_split['label']
    test_texts, test_labels = test_split['text'], test_split['label']

    print(f"Training samples: {len(train_texts)}")
    print(f"Test samples: {len(test_texts)}")

    return train_texts, train_labels, test_texts, test_labels

def validate_model(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Module mapping an input batch to class logits. It is switched to
            eval mode and left in that mode.
        val_loader: Iterable yielding ``(data, targets)`` tensor batches.
        criterion: Loss callable returning the mean loss of a batch.

    Returns:
        ``(avg_loss, accuracy)`` where ``avg_loss`` is the per-sample mean loss
        and ``accuracy`` is a percentage in ``[0, 100]``. For a loader that
        yields no samples the result is ``(nan, 0.0)``.
    """
    model.eval()

    # Run on whatever device the model lives on rather than on a module global.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device("cpu")

    loss_sum = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(eval_device), targets.to(eval_device)
            outputs = model(data)
            loss = criterion(outputs, targets)

            # Weight each batch mean by its size so that a short final batch
            # does not count as much as a full one.
            n_batch = targets.size(0)
            loss_sum += loss.item() * n_batch
            _, predicted = outputs.max(1)
            val_total += n_batch
            val_correct += predicted.eq(targets).sum().item()

    if val_total == 0:
        # Nothing was evaluated: report "no loss" instead of dividing by zero.
        return float("nan"), 0.0

    val_accuracy = 100. * val_correct / val_total
    avg_val_loss = loss_sum / val_total

    return avg_val_loss, val_accuracy

def train_model():
    """Train the sentiment network on IMDB and evaluate the best checkpoint.

    Steps: load the data and the ``bert-base-uncased`` tokenizer, carve a
    stratified 20% validation split off the training set, train with AdamW,
    cosine learning-rate annealing and gradient clipping, checkpoint on every
    validation-accuracy improvement, stop after 3 epochs without one, then
    reload the best checkpoint and report test accuracy.

    Returns:
        ``(model, tokenizer)`` with the best checkpoint's weights loaded.
    """
    train_texts, train_labels, test_texts, test_labels = load_imdb_data()

    print("Loading pretrained tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    # Fall back to the EOS token for tokenizers that ship without a pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Stratified hold-out so both classes keep their proportions.
    fit_texts, holdout_texts, fit_labels, holdout_labels = train_test_split(
        train_texts, train_labels, test_size=0.2, random_state=42, stratify=train_labels
    )

    train_loader = DataLoader(
        IMDBDataset(fit_texts, fit_labels, tokenizer, max_length=max_length),
        batch_size=batch_size,
        shuffle=True,
    )
    val_loader = DataLoader(
        IMDBDataset(holdout_texts, holdout_labels, tokenizer, max_length=max_length),
        batch_size=batch_size,
        shuffle=False,
    )
    test_loader = DataLoader(
        IMDBDataset(test_texts, test_labels, tokenizer, max_length=max_length),
        batch_size=batch_size,
        shuffle=False,
    )

    model = NeuromorphicSentimentNet(
        vocab_size=tokenizer.vocab_size,
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        num_classes=num_classes,
        beta_stm=beta_stm,
        beta_ltm=beta_ltm,
        threshold=threshold,
    ).to(device)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Early-stopping bookkeeping
    checkpoint_path = 'best_neuromorphic_sentiment_model.pth'
    best_val_accuracy = 0
    patience = 3
    epochs_without_gain = 0

    print("Starting training...")
    for epoch in range(num_epochs):
        model.train()
        running_loss, n_correct, n_seen = 0, 0, 0

        for batch_idx, (data, targets) in enumerate(train_loader):
            data = data.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()

            # Clip before stepping to keep the update magnitude bounded.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

            if batch_idx % 50 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

        # One scheduler step per epoch
        scheduler.step()

        train_accuracy = 100. * n_correct / n_seen
        avg_train_loss = running_loss / len(train_loader)
        val_loss, val_accuracy = validate_model(model, val_loader, criterion)

        print(f'Epoch {epoch}: Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')

        if val_accuracy > best_val_accuracy:
            # New best: reset the patience window and checkpoint the weights.
            best_val_accuracy = val_accuracy
            epochs_without_gain = 0
            torch.save(model.state_dict(), checkpoint_path)
            continue

        epochs_without_gain += 1
        if epochs_without_gain >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

    # Evaluate the best checkpoint, not the last epoch's weights.
    model.load_state_dict(torch.load(checkpoint_path))
    _, test_accuracy = validate_model(model, test_loader, criterion)
    print(f'Final Test Accuracy: {test_accuracy:.2f}%')

    return model, tokenizer

def predict_sentiment(model, tokenizer, text):
    """Classify one piece of text.

    Args:
        model: Trained network producing two logits per sample.
        tokenizer: Tokenizer used to encode ``text``.
        text: Raw text to classify.

    Returns:
        ``(sentiment, confidence)``: ``"Positive"`` or ``"Negative"`` (ties
        resolve to ``"Negative"``) and the larger of the two softmax
        probabilities as a float.
    """
    model.eval()
    with torch.no_grad():
        encoded = tokenizer(
            text,
            max_length=max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        logits = model(encoded['input_ids'].to(device))
        probabilities = torch.softmax(logits, dim=1)[0]

        # Index 1 is the positive class, index 0 the negative one.
        if probabilities[1] > probabilities[0]:
            sentiment = "Positive"
        else:
            sentiment = "Negative"

        return sentiment, probabilities.max().item()

if __name__ == "__main__":
    print("Training Enhanced Neuromorphic Sentiment Analysis Model...")

    # Full pipeline: data loading, training, early stopping and test evaluation.
    model, tokenizer = train_model()

    # A handful of hand-written reviews as a smoke test of the trained model.
    test_texts = [
        "This movie is absolutely amazing! The cinematography was breathtaking and the acting was superb.",
        "I hate this film, it's terrible. The plot made no sense and the dialogue was awful.",
        "Not bad, but could be better. The story was okay but felt rushed.",
        "Outstanding performance by all actors! This is definitely one of the best films I've seen this year.",
        "The movie was boring and predictable. I fell asleep halfway through.",
        "Incredible storytelling and beautiful visuals. A masterpiece of modern cinema!",
    ]

    print("\nTesting predictions:")
    for text in test_texts:
        sentiment, confidence = predict_sentiment(model, tokenizer, text)
        # Only the first 50 characters of each review are echoed.
        preview = text[:50]
        print(f"Text: '{preview}...' -> {sentiment} (confidence: {confidence:.3f})")

    print("\nModel training completed successfully!")

  from snntorch import backprop


Using device: cuda
Training Enhanced Neuromorphic Sentiment Analysis Model...
Loading IMDB dataset...


README.md: 0.00B [00:00, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

unsupervised-00000-of-00001.parquet:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Training samples: 25000
Test samples: 25000
Loading pretrained tokenizer...


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Starting training...
Epoch: 0, Batch: 0, Loss: 1.2335
Epoch: 0, Batch: 50, Loss: 0.8168
Epoch: 0, Batch: 100, Loss: 0.7877
Epoch: 0, Batch: 150, Loss: 0.7026
Epoch: 0, Batch: 200, Loss: 0.7388
Epoch: 0, Batch: 250, Loss: 0.6252
Epoch: 0, Batch: 300, Loss: 0.6522
Epoch: 0, Batch: 350, Loss: 0.8687
Epoch: 0, Batch: 400, Loss: 0.6657
Epoch: 0, Batch: 450, Loss: 0.7649
Epoch: 0, Batch: 500, Loss: 0.6524
Epoch: 0, Batch: 550, Loss: 0.5995
Epoch: 0, Batch: 600, Loss: 0.6311
Epoch: 0, Batch: 650, Loss: 0.7365
Epoch: 0, Batch: 700, Loss: 0.6433
Epoch: 0, Batch: 750, Loss: 0.6755
Epoch: 0, Batch: 800, Loss: 0.7112
Epoch: 0, Batch: 850, Loss: 0.6084
Epoch: 0, Batch: 900, Loss: 0.5392
Epoch: 0, Batch: 950, Loss: 0.6591
Epoch: 0, Batch: 1000, Loss: 0.4868
Epoch: 0, Batch: 1050, Loss: 0.7371
Epoch: 0, Batch: 1100, Loss: 0.6663
Epoch: 0, Batch: 1150, Loss: 0.7411
Epoch: 0, Batch: 1200, Loss: 0.5828
Epoch 0: Train Loss: 0.7017, Train Acc: 59.44%, Val Loss: 0.5926, Val Acc: 71.12%
Epoch: 1, Batch: 0, 