# FSL-Net: Feature Shift Localization Demo

This notebook demonstrates how to use a pre-trained Feature Shift Localization Network (FSL-Net) to compare a reference dataset with a manipulated query dataset and localize which features have been modified. We will:

1. Load a pre-trained FSL-Net model
2. Load the reference and manipulated query datasets
3. Run inference
4. Evaluate the model's performance using F1 Score and runtime
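To make the task concrete, the sketch below builds a toy reference/query pair in which a few feature columns of the query are shifted. It records their indices, which play the same role as `C_positions` later in this notebook. The helper `make_shifted_pair`, the Gaussian data, and the mean-shift manipulation are illustrative assumptions. The sketch does not reproduce how the `covid_0.1_E1` files were produced.

```python
import numpy as np


def make_shifted_pair(n_ref=500, n_que=500, d=20, n_shifted=3, shift=2.0, seed=0):
    """Hypothetical helper: toy reference/query pair with a known set of shifted features.

    Both datasets are drawn from a standard normal distribution. A random subset
    of `n_shifted` query columns is then mean-shifted by `shift`, so only those
    features differ in distribution between reference and query.
    Returns (ref, que, c_positions) with c_positions sorted.
    """
    rng = np.random.default_rng(seed)
    ref = rng.standard_normal((n_ref, d)).astype(np.float32)
    que = rng.standard_normal((n_que, d)).astype(np.float32)
    # Pick distinct feature indices to manipulate (the ground truth)
    c_positions = np.sort(rng.choice(d, size=n_shifted, replace=False))
    que[:, c_positions] += shift  # apply a mean shift to the chosen columns only
    return ref, que, c_positions


toy_ref, toy_que, toy_c_positions = make_shifted_pair()
print(toy_ref.shape, toy_que.shape, toy_c_positions)
```

A localization method succeeds on this toy pair if it flags exactly the columns listed in `toy_c_positions`.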

```python
import sys
import time

import numpy as np
import torch
from sklearn.metrics import f1_score

sys.path.append('../')  # make the local fslnet package importable from demos/
from fslnet.fslnet import FSLNet
```

### 1 Setup

Select the computation device (GPU if available) and define dataset paths:

```python
# Automatically select GPU if available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", device)

# Paths to data files, relative to the demos/ directory the notebook runs from
ref_path = '../data/references/covid_0.1_E1_ref.npy'
que_path = '../data/queries/covid_0.1_E1_que.npy'
C_positions_path = '../data/C_positions/covid_0.1_E1_C_positions.npy'
```

Device: cpu


### 2 Load Pre-trained FSL-Net

Load the pre-trained FSL-Net weights onto the selected device:

```python
fslnet = FSLNet.from_pretrained(device=device)
```

--> Loading FSLNet weights from '/private/home/mbarrabe/FSL-Net/demos/../fslnet/checkpoints/fslnet.pth' onto cpu ...
--> FSLNet loaded and set to eval().


### 3 Load Reference, Query, and Shifted Features

Load the reference and query datasets, together with the indices of the shifted features (the ground truth):

```python
# Load the arrays as float32 tensors on the same device as the model
ref = torch.tensor(np.load(ref_path), dtype=torch.float32, device=device)
que = torch.tensor(np.load(que_path), dtype=torch.float32, device=device)

# Indices of the features that were actually shifted (ground truth)
C_positions = torch.tensor(np.load(C_positions_path))
```
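Before running inference, it can help to check that the inputs are consistent. The sketch below assumes, as the evaluation code does when it uses `que.shape[1]` as the number of features, that both datasets are 2-D arrays of shape (samples, features) and that `C_positions` holds integer feature indices. The function `check_inputs` is a hypothetical helper, not part of the FSL-Net package.

```python
import numpy as np


def check_inputs(ref, que, c_positions):
    """Hypothetical sanity check for a reference/query pair and its ground truth.

    Returns (number of features, number of shifted features) or raises ValueError.
    """
    ref = np.asarray(ref)
    que = np.asarray(que)
    c_positions = np.asarray(c_positions).ravel()
    if ref.ndim != 2 or que.ndim != 2:
        raise ValueError(f"expected 2-D arrays, got shapes {ref.shape} and {que.shape}")
    if ref.shape[1] != que.shape[1]:
        raise ValueError(f"feature mismatch: {ref.shape[1]} vs {que.shape[1]}")
    d = ref.shape[1]
    if c_positions.size:
        if not np.issubdtype(c_positions.dtype, np.integer):
            raise ValueError("C_positions must hold integer feature indices")
        if c_positions.min() < 0 or c_positions.max() >= d:
            raise ValueError("C_positions contains out-of-range feature indices")
    return d, c_positions.size
```

With the tensors loaded above, CPU copies can be passed directly, e.g. `check_inputs(ref.cpu(), que.cpu(), C_positions)`.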

### 4 Inference & Evaluation

Run FSL-Net on the reference and query sets, threshold the per-feature corruption probabilities at 0.5, and evaluate the F1 Score against the ground truth:

```python
with torch.no_grad():
    start_time = time.perf_counter()
    soft_predictions, _ = fslnet(ref, que)      # per-feature corruption probabilities
    hard_predictions = soft_predictions > 0.5   # boolean mask: True = shifted
    if device == 'cuda':
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the timer
    end_time = time.perf_counter()

runtime = end_time - start_time

# Build the ground-truth mask on the CPU (scikit-learn expects CPU arrays)
target = torch.zeros(1, que.shape[1])
target[0, C_positions] = 1

# Compute F1 Score on flattened CPU NumPy arrays
f1 = f1_score(target.squeeze().numpy(),
              hard_predictions.squeeze().cpu().numpy(),
              zero_division=1)

print("F1 Score:", f1)
print("Runtime (seconds):", runtime)
```

F1 Score: 1.0
Runtime (seconds): 0.18433022499084473
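The F1 Score above compares the thresholded mask with the ground-truth mask. The NumPy sketch below computes the same quantity directly from index sets, which makes the metric explicit: with predicted set P and true set T, F1 = 2|P ∩ T| / (|P| + |T|), which equals 2TP / (2TP + FP + FN). The helper `f1_from_soft` is hypothetical and assumes the soft predictions flatten to one probability per feature. Returning 1.0 when both sets are empty is meant to mirror `zero_division=1`. For binary labels it should agree with scikit-learn's `f1_score`.

```python
import numpy as np


def f1_from_soft(soft_predictions, c_positions, threshold=0.5):
    """Hypothetical helper: F1 of thresholded per-feature probabilities vs. true shifted indices."""
    soft = np.asarray(soft_predictions, dtype=float).ravel()
    predicted = set(np.flatnonzero(soft > threshold).tolist())  # features flagged as shifted
    true = set(np.asarray(c_positions).ravel().tolist())        # features actually shifted
    tp = len(predicted & true)
    denom = len(predicted) + len(true)  # equals 2TP + FP + FN
    if denom == 0:
        return 1.0  # nothing predicted and nothing shifted
    return 2 * tp / denom


# Toy check: features 0 and 2 are predicted, features 0 and 1 are truly shifted
print(f1_from_soft([0.9, 0.2, 0.7, 0.1], [0, 1]))  # 0.5
```

In the notebook, `f1_from_soft(soft_predictions.cpu(), C_positions)` should reproduce the `f1` value computed above.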
