# Membership Inference using Robust Membership Inference Attack (RMIA)

## Setup

In [40]:
import torch
import torch.nn.functional as F

from privacy_and_grokking.path_keeper import get_path_keeper
from privacy_and_grokking.utils import get_device
from privacy_and_grokking.datasets import MNIST
from privacy_and_grokking.models import MLP

In [41]:
# Resolve the project path helper and the compute device used by every later cell.
pk, device = get_path_keeper(), get_device()

## Attack

In [56]:
def compute_log_likelihood(model, inputs, labels, device='cpu'):
    """
    Return the log-probability the model assigns to each sample's true label.

    Puts ``model`` into eval mode (a persistent side effect on the module),
    moves the batch to ``device`` and computes
    ``log_softmax(model(inputs), dim=1)[i, labels[i]]`` with gradients disabled.

    Args:
        model: Classifier whose output is (batch, num_classes) logits; it must
            already live on ``device``.
        inputs: Batch of inputs accepted by ``model``.
        labels: Integer class labels, one per input.
        device: Device the inputs and labels are moved to.

    Returns:
        torch.Tensor of shape (batch,): log p(label | input) for each sample.
        The batch axis is kept even when batch == 1.
    """
    model.eval()
    inputs, labels = inputs.to(device), labels.to(device)

    with torch.no_grad():
        log_probs = F.log_softmax(model(inputs), dim=1)
        # squeeze(1) drops only the gather axis; a bare squeeze() would also
        # drop the batch axis for a single sample and return a 0-dim tensor.
        true_class_log_probs = log_probs.gather(1, labels.view(-1, 1)).squeeze(1)

    return true_class_log_probs

def rmia_score(target_model, ref_model, target_x, target_y, population_data, population_labels, device='cpu', gamma=1.0):
    """
    Compute the RMIA membership score of a single sample.

    For the candidate x and every population sample z, RMIA compares the
    target/reference likelihood ratios

        LR(x, z) = [p(y_x | x; target) / p(y_x | x; ref)]
                   / [p(y_z | z; target) / p(y_z | z; ref)]

    and returns the fraction of population samples with LR(x, z) > gamma.
    Here p(y | x; model) is the softmax probability of the true label (see
    compute_log_likelihood), and the comparison is carried out in log space.
    A single reference model stands in for an average over reference models.

    Args:
        target_model: The model under attack.
        ref_model: The reference model used to normalise the target model's
            likelihoods (presumably trained without the candidates — confirm).
        target_x: The single sample to test, without a batch dimension.
        target_y: Its label: a Python int, a 0-dim tensor or a 1-element tensor.
        population_data: Batch of population samples z (non-members).
        population_labels: Labels for population_data.
        device: Device the inputs are moved to; the models must already be there.
        gamma: Likelihood-ratio threshold, must be > 0. The default 1.0
            reproduces the plain comparison ratio(x) > ratio(z).

    Returns:
        float: Score in [0.0, 1.0]; higher means x looks more like a member.

    Raises:
        ValueError: If gamma is not positive or the population is empty.
    """
    import math

    if not gamma > 0:
        raise ValueError(f"gamma must be positive, got {gamma}")
    if len(population_data) == 0:
        raise ValueError("population_data must contain at least one sample")

    # Give the single candidate a batch dimension of 1. as_tensor also accepts
    # plain Python int labels, which .unsqueeze() would reject.
    x_batch = target_x.unsqueeze(0)
    y_batch = torch.as_tensor(target_y).reshape(1)

    # 1. Log-likelihood ratio target/ref for the candidate x.
    ll_target_x = compute_log_likelihood(target_model, x_batch, y_batch, device)
    ll_ref_x = compute_log_likelihood(ref_model, x_batch, y_batch, device)
    ratio_x = ll_target_x - ll_ref_x

    # 2. The same ratio for every population sample z.
    ll_target_z = compute_log_likelihood(target_model, population_data, population_labels, device)
    ll_ref_z = compute_log_likelihood(ref_model, population_data, population_labels, device)
    ratio_z = ll_target_z - ll_ref_z

    # 3. Pairwise test: LR(x, z) > gamma  <=>  ratio_x > ratio_z + log(gamma).
    # For gamma == 1 the offset is exactly 0.0, i.e. ratio_x > ratio_z.
    pairwise_comparisons = (ratio_x > ratio_z + math.log(gamma)).float()

    return pairwise_comparisons.mean().item()

In [None]:
import itertools

train, test = MNIST(size="small", canary="gaussian_noise")()
print(len(train), len(test))

# Number of test (held-out) samples used as the RMIA population z.
POPULATION_SIZE = 1000

# The population is the first POPULATION_SIZE test samples.
# NOTE(review): the test-set scoring cell iterates over *all* of `test`, so
# these samples are also scored as candidates. Each is compared against
# (numerically almost) itself once, which with a strict `>` slightly biases
# its score downwards (~1 / POPULATION_SIZE). Consider scoring only the rest.
population = list(itertools.islice(test, POPULATION_SIZE))
population_x = torch.stack([x for x, _ in population])
population_y = torch.tensor([y for _, y in population])

1010 10000


In [97]:
def _load_checkpoint(model_name):
    """Point `pk` at the v1.3.0 / step-100k checkpoint of `model_name` and load it into a fresh MLP on `device`."""
    pk.set_params({"run_id": "v1.3.0", "model": model_name, "step": 100_000})
    net = MLP()
    net.load_state_dict(torch.load(pk.MODEL_TORCH, weights_only=True, map_location=device))
    net.to(device)
    return net


# Model under attack (presumably trained with the noise canary, per its name)
# followed by the reference model; load order matches set_params order.
target_model = _load_checkpoint("MLP_GROK_CAN_NOISE_V1")
reference_model = _load_checkpoint("MLP_V1")
reference_model

MLP(
  (fc1): Linear(in_features=784, out_features=200, bias=True)
  (fc2): Linear(in_features=200, out_features=200, bias=True)
  (fc3): Linear(in_features=200, out_features=10, bias=True)
)

In [109]:
# Score every test sample; the fraction scoring above 0.5 is how often
# held-out samples are flagged as members at that threshold.
scores = torch.tensor([
    rmia_score(target_model, reference_model, sample_x, torch.tensor(sample_y),
               population_x, population_y, device)
    for sample_x, sample_y in test
])
(scores > 0.5).float().mean()

tensor(0.5465)

In [110]:
# Score every training sample (presumably the target model's members); the
# fraction scoring above 0.5 is the attack's hit rate at that threshold.
scores = torch.tensor([
    rmia_score(target_model, reference_model, sample_x, sample_y,
               population_x, population_y, device)
    for sample_x, sample_y in train
])
(scores > 0.5).float().mean()

tensor(0.9990)