In [97]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.datasets import mnist
import random
import mmh3
from bitarray import bitarray

In [98]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)


In [99]:
class SimpleEncoder(nn.Module):
    def __init__(self):
        super(SimpleEncoder, self).__init__()
        # A simple CNN encoder: input (batch_size,channels,height,width) -> output vector (batch_size,128)
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 14x14 -> 7x7
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc = nn.Linear(32 * 7 * 7, 128)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)
        return x  # output shape: (batch_size, 128)
    
class NeuralBloomFilter(nn.Module):
    def __init__(self, memory_slots=10, word_size=32, class_num=1):
        super().__init__()
        self.encoder = SimpleEncoder()
        self.fc_q    = nn.Linear(128, word_size)
        self.fc_w    = nn.Linear(128, word_size)
        self.A       = nn.Parameter(torch.randn(word_size, memory_slots))
        self.register_buffer('M', torch.zeros(memory_slots, word_size))
        inp_dim = memory_slots*word_size + word_size + 128
        self.mlp = nn.Sequential(
            nn.Linear(inp_dim, 128),
            nn.ReLU(),
            nn.Linear(128, class_num)
        )

    def controller(self, x):
        """
        Runs the encoder, computes query q, write word w, and normalized address a.
        Returns (a, w, z).
        """
        z = self.encoder(x)                   # (B,128)
        q = self.fc_q(z)                      # (B,word_size)
        w = self.fc_w(z)                      # (B,word_size)
        a_logits = q @ self.A                 # (B,slots)
        a = F.softmax(a_logits, dim=1)        # (B,slots)
        return a, w, z

    def write(self, x):
        """
        Write all x in one shot:
          M ← M + ∑_i w_i a_i^T
        """
        a, w, _ = self.controller(x)          # a:(B,slots), w:(B,word_size)
        # Outer product per sample: update_i[k,p] = a[i,k] * w[i,p]
        # Stack them and sum over batch:
        # shape: (B, slots, word_size)
        update = torch.einsum('bk,bp->bkp', a, w)
        # Add to memory and detach so writes don't backprop through time:
        self.M = self.M + update.sum(dim=0).detach()

    def read(self, x):
        """
        Read operation:
          r_i = flatten( M ⊙ a_i )  (componentwise)
          logits = f_out([r_i, w_i, z_i])
        """
        a, w, z = self.controller(x)          # (B,slots), (B,word_size), (B,128)
        # M ⊙ a_i: scale each row of M by a_i:
        # shape before flatten: (B, slots, word_size)
        r = (a.unsqueeze(2) * self.M.unsqueeze(0)).reshape(x.size(0), -1)  # (B, slots*word_size)
        # Concatenate r, w, z:
        concat = torch.cat([r, w, z], dim=1)                              # (B, inp_dim)
        logits = self.mlp(concat)                                          # (B, class_num)
        return logits


In [100]:
class BloomFilter:
    """
    Standard Bloom Filter with customizable size and hash count.
    """
    def __init__(self, m: int, k: int):
        """
        Args:
            m: Number of bits in the filter bit array.
            k: Number of hash functions.
        """
        self.m = m
        self.k = k
        self.bits = bitarray(m)
        self.bits.setall(0)

    def add(self, key_bytes: bytes):
        """Insert a key (bytes) into the Bloom filter."""
        for i in range(self.k):
            idx = mmh3.hash(key_bytes, i) % self.m
            self.bits[idx] = 1

    def __contains__(self, key_bytes: bytes) -> bool:
        """Check membership of a key (bytes)."""
        return all(self.bits[mmh3.hash(key_bytes, i) % self.m]
                   for i in range(self.k))

class SimHashLSH:
    """
    Locality-Sensitive Hashing via random hyperplanes (SimHash).
    """
    def __init__(self, dim: int = 128, bits: int = 64):
        """
        Args:
            dim: Dimensionality of input embeddings.
            bits: Number of bits in the SimHash fingerprint.
        """
        # Random hyperplanes drawn from Gaussian distribution
        self.hyperplanes = np.random.randn(bits, dim).astype(np.float32)

    def compute(self, emb: np.ndarray) -> bytes:
        """Compute a `bits`-bit SimHash fingerprint for the embedding."""
        projections = np.dot(self.hyperplanes, emb)  # shape (bits,)
        bits_arr = (projections >= 0).astype(np.uint8)  # 1 if >=0, else 0
        packed = np.packbits(bits_arr)  # pack each 8 bits into a byte
        return packed.tobytes()         # result length = bits//8

class BackupBloomFilter:
    """
    Combines a learned classifier or embedding-based filter with a backup Bloom filter.
    """
    def __init__(self, encoder, bbf_size: int, bbf_hashes: int,
                 lsh_bits: int = 64):
        """
        Args:
            encoder: Function or model that maps an input image -> 1D embedding (np.ndarray, shape (dim,)).
            classifier: Function or model that maps embedding -> bool (returns True if "in set").
            bbf_size: Number of bits in the backup Bloom filter.
            bbf_hashes: Number of hash functions for the BBF.
            lsh_bits: Number of bits in SimHash fingerprint.
        """
        self.encoder = encoder
        self.lsh = SimHashLSH(dim=encoder.fc.out_features, bits=lsh_bits)
        self.bbf = BloomFilter(m=bbf_size, k=bbf_hashes)

    def insert(self, image):
        """Insert an image into the backup filter (and optionally train classifier)."""
        emb_t = self.encoder(image)          # expects np.ndarray shape (B, dim)
        emb = emb_t.squeeze(0).detach().cpu().numpy()
        key = self.lsh.compute(emb)
        self.bbf.add(key)
        # Optionally: classifier.fit or online update here

    def query(self, image) -> bool:
        """Query the learned filter, fallback to BBF for zero false negatives."""
        emb_t = self.encoder(image) # expects np.ndarray shape (B, dim)
        emb = emb_t.squeeze(0).detach().cpu().numpy()
        key = self.lsh.compute(emb)
        return (key in self.bbf)


In [101]:
def sample_task(labels, storage_set_size, num_queries, class_num):
    class_to_indices = {}
    for idx, label in enumerate(labels):
        if label not in class_to_indices:
            class_to_indices[label] = []
        class_to_indices[label].append(idx)
    
    storage_indices = random.sample(class_to_indices[class_num], storage_set_size)

    num_in = num_queries // 2
    num_out = num_queries - num_in
    query_in_indices = random.sample(class_to_indices[class_num], num_in)
    other_classes = [c for c in class_to_indices if c != class_num]
    query_out_indices = []
    for _ in range(num_out):
        other_class = random.choice(other_classes)
        query_out_indices.append(random.choice(class_to_indices[other_class]))
    
    query_indices = query_in_indices + query_out_indices
    # define target as [1,0] for in-class and [0,1] for out-of-class
    targets = []
    for i in range(num_in):
        targets.append(1)
    for i in range(num_out):
        targets.append(0)
    
    targets = torch.tensor(targets, dtype=torch.float32).unsqueeze(1) 
    return storage_indices, query_indices, targets

In [102]:
def meta_train(model, dataset, labels, optimizer, backupBloomfilter, criterion, device, meta_epochs=10, storage_set_size=60, num_queries=10):
    
    model.train()
    total_loss = 0.0
    classes = [0]

    for epoch in range(meta_epochs):
        
        class_num = random.choice(classes)
        storage_indices, query_indices, targets= sample_task(labels, storage_set_size, num_queries, class_num)
        storage_images = dataset[storage_indices]
        storage_images = torch.tensor(storage_images, dtype=torch.float32).unsqueeze(1)
        query_images = dataset[query_indices]
        query_images = torch.tensor(query_images, dtype=torch.float32).unsqueeze(1)
        
        model.M.zero_()
        backupBloomfilter.bbf.bits.setall(0)
        _ = model.write(storage_images)  # Expected shape: (storage_set_size, word_size)
        
        # Write to the backup Bloom filter
        for i in range(storage_set_size):
            
            backupBloomfilter.insert(storage_images[i].unsqueeze(0))
        
        
        logits = model.read(query_images)  # Expected shape: (num_queries, class_num)
        
        
        probs = torch.sigmoid(logits) 
        
        # Convert probabilities to binary predictions
        predictions = (probs > 0.5).float()
        
        
        for i in range(len(predictions)):
            if predictions[i] == 0 and targets[i] == 1:
                predictions[i] = backupBloomfilter.query(query_images[i].unsqueeze(0))
        
        fnr = 0
        fpr = 0
        for i in range(len(predictions)):
            if predictions[i] == 0 and targets[i] == 1:
                fnr += 1
            elif predictions[i] == 1 and targets[i] == 0:
                fpr += 1
        
        loss = criterion(logits, targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        
        total_loss += loss.item()
        if (epoch + 1) % 10 == 0:
            avg_loss = total_loss / (epoch + 1)
            print(f"Epoch [{epoch+1}/{meta_epochs}], Loss: {loss.item():.4f}, False Positive Rate: {fpr}, False Negative Rate: {fnr}")
    
    return total_loss / meta_epochs



In [103]:
import torch
import torch.nn as nn
import torch.optim as optim
import random

if __name__ == '__main__':
    # Hyperparameters
    meta_epochs = 100
    storage_set_size = 100
    num_queries = 20
    learning_rate = 1e-5
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model = NeuralBloomFilter(memory_slots=64, word_size=64, class_num=1).to(device)
    
    BackupBloomFilter = BackupBloomFilter(
        encoder=model.encoder,
        bbf_size=2000,  # Size of the backup Bloom filter
        bbf_hashes=16,   # Number of hash functions for the BBF
        lsh_bits=32    # Number of bits in the SimHash fingerprint
    )
    
    # Use CrossEntropyLoss for binary classification (2 classes)
    criterion = nn.BCEWithLogitsLoss()
    
    # Use Adam optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Start meta-training
    avg_loss = meta_train(model, x_train, y_train,optimizer, BackupBloomFilter,criterion, device,
                          meta_epochs=meta_epochs,
                          storage_set_size=storage_set_size,
                          num_queries=num_queries)
    print(f"Meta-training completed. Average Loss: {avg_loss:.4f}")
    # print the backup bloom filter
    print(BackupBloomFilter.bbf.bits)

Epoch [10/100], Loss: 1.2477, False Positive Rate: 8, False Negative Rate: 1
Epoch [20/100], Loss: 1.0488, False Positive Rate: 4, False Negative Rate: 3
Epoch [30/100], Loss: 0.5281, False Positive Rate: 2, False Negative Rate: 2
Epoch [40/100], Loss: 0.4766, False Positive Rate: 2, False Negative Rate: 0
Epoch [50/100], Loss: 0.7468, False Positive Rate: 5, False Negative Rate: 2
Epoch [60/100], Loss: 0.5182, False Positive Rate: 0, False Negative Rate: 4
Epoch [70/100], Loss: 0.6591, False Positive Rate: 2, False Negative Rate: 2
Epoch [80/100], Loss: 0.2813, False Positive Rate: 0, False Negative Rate: 0
Epoch [90/100], Loss: 0.3495, False Positive Rate: 2, False Negative Rate: 2
Epoch [100/100], Loss: 0.4146, False Positive Rate: 2, False Negative Rate: 2
Meta-training completed. Average Loss: 0.5564
bitarray('01001110101011101111110011010001101011001101111101100111110101011110101101111101100101111101110010101101111000010110100110011110111010100000000010101100100101111011011001011

In [104]:
x_test = torch.tensor(x_test, dtype=torch.float32).unsqueeze(1)
y_labels = [0]*len(x_test)


for i in range(len(x_test)):
    if y_test[i] in [0]:
        y_labels[i] = 1
    else:
        y_labels[i] = 0
        
y_labels = torch.tensor(y_labels, dtype=torch.float32).unsqueeze(1)

logits = model.read(x_test)  # Expected shape: (num_queries, class_num)
probs = F.softmax(logits, dim=1)  # Convert logits to probabilities of shape (num_queries, 2)
predicted_class = torch.argmax(probs, dim=1)
accuracy = (predicted_class == y_labels[:, 0].long()).float().mean().item()
print(f"Test Accuracy: {accuracy:.4f}")

Test Accuracy: 0.9020
