In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import mmh3 
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.datasets import mnist
import random

In [2]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)


In [3]:
class SimpleEncoder(nn.Module):
    def __init__(self):
        super(SimpleEncoder, self).__init__()
        # A simple CNN encoder: input (batch_size,channels,height,width) -> output vector (batch_size,128)
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 14x14 -> 7x7
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc = nn.Linear(32 * 7 * 7, 128)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)
        return x  # output shape: (batch_size, 128)
    
class NeuralBloomFilter(nn.Module):
    def __init__(self, memory_slots=10, word_size=32, class_num=1):
        super().__init__()
        self.encoder = SimpleEncoder()
        self.fc_q    = nn.Linear(128, word_size)
        self.fc_w    = nn.Linear(128, word_size)
        self.A       = nn.Parameter(torch.randn(word_size, memory_slots))
        self.register_buffer('M', torch.zeros(memory_slots, word_size))
        inp_dim = memory_slots*word_size + word_size + 128
        self.mlp = nn.Sequential(
            nn.Linear(inp_dim, 128),
            nn.ReLU(),
            nn.Linear(128, class_num)
        )

    def controller(self, x):
        """
        Runs the encoder, computes query q, write word w, and normalized address a.
        Returns (a, w, z).
        """
        z = self.encoder(x)                   # (B,128)
        q = self.fc_q(z)                      # (B,word_size)
        w = self.fc_w(z)                      # (B,word_size)
        a_logits = q @ self.A                 # (B,slots)
        a = F.softmax(a_logits, dim=1)        # (B,slots)
        return a, w, z

    def write(self, x):
        """
        Write all x in one shot:
          M ← M + ∑_i w_i a_i^T
        """
        a, w, _ = self.controller(x)          # a:(B,slots), w:(B,word_size)
        # Outer product per sample: update_i[k,p] = a[i,k] * w[i,p]
        # Stack them and sum over batch:
        # shape: (B, slots, word_size)
        update = torch.einsum('bk,bp->bkp', a, w)
        # Add to memory and detach so writes don't backprop through time:
        self.M = self.M + update.sum(dim=0).detach()

    def read(self, x):
        """
        Read operation:
          r_i = flatten( M ⊙ a_i )  (componentwise)
          logits = f_out([r_i, w_i, z_i])
        """
        a, w, z = self.controller(x)          # (B,slots), (B,word_size), (B,128)
        # M ⊙ a_i: scale each row of M by a_i:
        # shape before flatten: (B, slots, word_size)
        r = (a.unsqueeze(2) * self.M.unsqueeze(0)).reshape(x.size(0), -1)  # (B, slots*word_size)
        # Concatenate r, w, z:
        concat = torch.cat([r, w, z], dim=1)                              # (B, inp_dim)
        logits = self.mlp(concat)                                          # (B, class_num)
        return logits


In [4]:
def sample_task(labels, storage_set_size, num_queries, class_num):
    class_to_indices = {}
    for idx, label in enumerate(labels):
        if label not in class_to_indices:
            class_to_indices[label] = []
        class_to_indices[label].append(idx)
    
    storage_indices = random.sample(class_to_indices[class_num], storage_set_size)

    num_in = num_queries // 2
    num_out = num_queries - num_in
    query_in_indices = random.sample(class_to_indices[class_num], num_in)
    other_classes = [c for c in class_to_indices if c != class_num]
    query_out_indices = []
    for _ in range(num_out):
        other_class = random.choice(other_classes)
        query_out_indices.append(random.choice(class_to_indices[other_class]))
    
    query_indices = query_in_indices + query_out_indices
    # define target as [1,0] for in-class and [0,1] for out-of-class
    targets = []
    for i in range(num_in):
        targets.append(1)
    for i in range(num_out):
        targets.append(0)
    
    targets = torch.tensor(targets, dtype=torch.float32).unsqueeze(1) 
    return storage_indices, query_indices, targets

In [5]:
def meta_train(model, dataset, labels, optimizer, criterion, device, meta_epochs=10, storage_set_size=60, num_queries=10):
    
    model.train()
    total_loss = 0.0
    classes = [0,1,2,3,4,5,6]

    for epoch in range(meta_epochs):
        
        class_num = random.choice(classes)
        storage_indices, query_indices, targets= sample_task(labels, storage_set_size, num_queries, class_num)
        storage_images = dataset[storage_indices]
        storage_images = torch.tensor(storage_images, dtype=torch.float32).unsqueeze(1)
        query_images = dataset[query_indices]
        query_images = torch.tensor(query_images, dtype=torch.float32).unsqueeze(1)
        
        model.M.zero_()
        
        _ = model.write(storage_images)  # Expected shape: (storage_set_size, word_size)
        
        logits = model.read(query_images)  # Expected shape: (num_queries, class_num)
        
        
        probs = torch.sigmoid(logits) 
        
        # Convert probabilities to binary predictions
        predictions = (probs > 0.5).float()
        
        fnr = 0
        fpr = 0
        for i in range(len(predictions)):
            if predictions[i] == 0 and targets[i] == 1:
                fnr += 1
            elif predictions[i] == 1 and targets[i] == 0:
                fpr += 1
        
        loss = criterion(logits, targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        
        total_loss += loss.item()
        if (epoch + 1) % 100 == 0:
            avg_loss = total_loss / (epoch + 1)
            print(f"Epoch [{epoch+1}/{meta_epochs}], Loss: {loss.item():.4f}, False Positive Rate: {fpr}, False Negative Rate: {fnr}")
    
    return total_loss / meta_epochs



In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import random

if __name__ == '__main__':
    # Hyperparameters
    meta_epochs = 2000
    storage_set_size = 1000
    num_queries = 200
    learning_rate = 1e-5
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model = NeuralBloomFilter(memory_slots=64, word_size=64, class_num=1).to(device)
    
    # Use CrossEntropyLoss for binary classification (2 classes)
    criterion = nn.BCEWithLogitsLoss()
    
    # Use Adam optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Start meta-training
    avg_loss = meta_train(model, x_train, y_train,optimizer, criterion, device,
                          meta_epochs=meta_epochs,
                          storage_set_size=storage_set_size,
                          num_queries=num_queries)
    print(f"Meta-training completed. Average Loss: {avg_loss:.4f}")

Epoch [100/2000], Loss: 0.9922, False Positive Rate: 17, False Negative Rate: 65
Epoch [200/2000], Loss: 0.6182, False Positive Rate: 21, False Negative Rate: 44
Epoch [300/2000], Loss: 0.6598, False Positive Rate: 35, False Negative Rate: 33
Epoch [400/2000], Loss: 0.5731, False Positive Rate: 25, False Negative Rate: 37
Epoch [500/2000], Loss: 0.6020, False Positive Rate: 23, False Negative Rate: 29
Epoch [600/2000], Loss: 0.4749, False Positive Rate: 25, False Negative Rate: 18
Epoch [700/2000], Loss: 0.7805, False Positive Rate: 11, False Negative Rate: 25
Epoch [800/2000], Loss: 0.6146, False Positive Rate: 26, False Negative Rate: 34
Epoch [900/2000], Loss: 0.5845, False Positive Rate: 22, False Negative Rate: 29
Epoch [1000/2000], Loss: 0.5398, False Positive Rate: 21, False Negative Rate: 29
Epoch [1100/2000], Loss: 0.5020, False Positive Rate: 18, False Negative Rate: 20
Epoch [1200/2000], Loss: 0.4031, False Positive Rate: 17, False Negative Rate: 16
Epoch [1300/2000], Loss: 

In [7]:
x_test = torch.tensor(x_test, dtype=torch.float32).unsqueeze(1)
y_labels = [0]*len(x_test)


for i in range(len(x_test)):
    if y_test[i] in [0]:
        y_labels[i] = 1
    else:
        y_labels[i] = 0
        
y_labels = torch.tensor(y_labels, dtype=torch.float32).unsqueeze(1)

logits = model.read(x_test)  # Expected shape: (num_queries, class_num)
probs = F.softmax(logits, dim=1)  # Convert logits to probabilities of shape (num_queries, 2)
predicted_class = torch.argmax(probs, dim=1)
accuracy = (predicted_class == y_labels[:, 0].long()).float().mean().item()
print(f"Test Accuracy: {accuracy:.4f}")

Test Accuracy: 0.9020
