### Code to run evaluation on 2 tasks main tasks, clustering and similarity search

In [1]:
from google.colab import drive
drive.mount('/content/drive/')

%cd /content/drive/MyDrive/ReverseEmbedding/GEIA

Mounted at /content/drive/
/content/drive/MyDrive/ReverseEmbedding/GEIA


In [2]:
import transformers

In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, T5EncoderModel, GPTNeoXForCausalLM

# ==========================================
# 1. SETUP: The "Victim" Model (Sentence-T5)
# ==========================================
from sentence_transformers import SentenceTransformer
# name = 'intfloat/e5-base-v2'
# name = 'all-distilroberta-v1'
# name = 'sentence-transformers/sentence-t5-base'
victim_model = SentenceTransformer('intfloat/e5-base-v2')
victim_model.eval() # Good practice, though .encode() handles this internally

def get_clean_embedding(text_list):
    """
    Uses the easy library to get the official embeddings.
    """
    embeddings = victim_model.encode(text_list, convert_to_tensor=True)
    return embeddings


# ==========================================
# 2. THE ATTACKER: Pythia-160m Inverter
# ==========================================
class PythiaAttacker(nn.Module):
    def __init__(self, input_dim=768, device='cpu'):
        super().__init__()
        self.device = device
        print("Loading Pythia-160m...")
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
        # Ensure padding token exists for Pythia (GPT-NeoX)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-160m").to(device)

        # Projector: Maps T5 embedding dim (768) to Pythia embedding dim (768)
        # Even if dims are same, this layer learns to align the latent spaces
        self.pythia_dim = self.model.config.hidden_size
        self.projector = nn.Linear(input_dim, self.pythia_dim).to(device)

    def forward(self, sentence_embedding, target_text_list=None):
        """
        Args:
            sentence_embedding: (batch, t5_dim)
            target_text_list: List of strings (the original text)
        """
        # 1. Project the T5 embedding to Pythia's space
        # Shape: (batch, 1, pythia_dim) -> It becomes the first "virtual token"
        prompt_embeds = self.projector(sentence_embedding).unsqueeze(1)

        # 2. Prepare Target Labels (if training/evaluating)
        if target_text_list:
            # Tokenize targets with Pythia's tokenizer
            target_inputs = self.tokenizer(
                target_text_list,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=64
            ).to(self.device)

            input_ids = target_inputs['input_ids']
            attention_mask = target_inputs['attention_mask']

            # Get embeddings of the real text
            # inputs_embeds: (batch, seq_len, pythia_dim)
            text_embeds = self.model.gpt_neox.embed_in(input_ids)

            # 3. Concatenate: [Projected_Emb + Text_Embeds]
            # We prepend the sentence embedding to the sequence
            inputs_embeds = torch.cat([prompt_embeds, text_embeds], dim=1)

            # Extend attention mask (add 1 for the prepended token)
            batch_size = input_ids.shape[0]
            prefix_mask = torch.ones((batch_size, 1), device=self.device)
            attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)

            # 4. Forward Pass through Pythia
            outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
            return outputs.logits, input_ids

        return None, None

# ==========================================
# 3. THE DEFENSE: FGSM
# ==========================================
def fgsm_protect(sentence_embedding, text_list, attacker_model, epsilon=0.1):
    """
    Performs Gradient Ascent on the Reconstruction Loss.
    """
    # 1. Setup gradient tracking on the embedding
    protected_emb = sentence_embedding.clone().detach()
    protected_emb.requires_grad = True

    # 2. Forward pass through Attacker
    logits, target_ids = attacker_model(protected_emb, text_list)

    # 3. Align Logits and Labels
    # Pythia outputs logits for [VirtualToken, Token1, Token2, ...]
    # We want to predict [Token1, Token2, ...] given previous history.
    # Shift logits: Remove last logit, Start from index 0 (prediction for first real token)
    # The logits at index 0 (corresponding to VirtualToken) predict Token1 (target_ids[:, 0])

    # logits shape: (batch, seq_len + 1, vocab)
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = target_ids.contiguous()

    # 4. Calculate Loss
    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    # 5. Backward & Perturb (in the direction that maximizes noise)
    attacker_model.zero_grad()
    loss.backward()

    # Add noise to maximize loss
    noise = epsilon * protected_emb.grad.sign()
    final_emb = sentence_embedding + noise

    return final_emb.detach()

# ==========================================
# 4. MAIN EXECUTION
# ==========================================


modules.json:   0%|          | 0.00/387 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/57.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/650 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/314 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/200 [00:00<?, ?B/s]

In [17]:
import json

def load_personachat():
    path = 'data/personachat/processed_persona/train_merged_shuffle.txt'
    with open(path, 'r') as f:
        data = json.load(f)
        train_sentence_list = []
        for i,dict_i in enumerate(data):
            conv = dict_i['conv']
            train_sentence_list.extend(conv)
    return train_sentence_list


sent_list = load_personachat()
sent_list_unique = set(sent_list)
sent_list = list(sent_list_unique)

def load_validation_data():
    path = 'data/personachat/processed_persona/dev_merged_shuffle.txt'
    with open(path, 'r') as f:
        data = json.load(f)
        val_sentence_list = []
        for i,dict_i in enumerate(data):
            conv = dict_i['conv']
            val_sentence_list.extend(conv)
    return val_sentence_list

val_sent = load_validation_data()
val_sent_unique = set(val_sent)
val_sent = list(val_sent_unique)

def get_validation_loss(attacker, dataloader):
    attacker.eval()
    batch_size = 128
    with torch.no_grad():
        avg_loss = 0
        total = 0
        for data in dataloader:
            clean_embs = get_clean_embedding(data).clone()
            logits, target_ids = attacker(clean_embs, data)
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = target_ids.contiguous()
            loss = nn.CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            avg_loss += loss.item()
            total += 1
    return avg_loss / total

def save_model(model, path):
    torch.save(model.state_dict(), path)


from torch.utils.data import DataLoader, Dataset

# Define a custom Dataset class
class SentenceDataset(Dataset):
    def __init__(self, sentence_list):
        self.sentence_list = sentence_list

    def __len__(self):
        return len(self.sentence_list)

    def __getitem__(self, idx):
        return self.sentence_list[idx]

batch_size = 128

train_dataset = SentenceDataset(sent_list)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = SentenceDataset(val_sent)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [26]:
from tqdm import tqdm

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Running on: {device}\n")


    # Attacker (Pythia 160m)
    input_dimension = victim_model.get_sentence_embedding_dimension()
    attacker = PythiaAttacker(input_dim=input_dimension, device=device).to(device)

    # Optimization only on Attacker
    optimizer = optim.AdamW(attacker.parameters(), lr=5e-5)

    print("\n--- Phase 1: Training the Attacker (Pythia) ---")
    # # Pre-calculate victim embeddings (simulating a static database)
    # clean_embs, _ = get_clean_embedding(data)

    batch_size = 128

    attacker.train()
    best_loss = 100000000
    streak= 0
    streak_cap = 3
    for epoch in range(10):
        epoch_loss = 0
        num_batches = 0
        for data in tqdm(train_dataloader, desc=f"Processing epoch: {epoch}"):
            # data = sent_list[batch_id:batch_id+batch_size]
            optimizer.zero_grad()

            clean_embs = get_clean_embedding(data).clone()
            logits, target_ids = attacker(clean_embs, data)
            """Processing the whole sequence leads to the same final hidden state for each token"""

            # Shift for Causal LM Loss
            # Logits: [Prom, T1, T2...] -> predict [T1, T2, EOS...]
            shift_logits = logits[:, :-1, :].contiguous()  # remove the last token because we can't calculate the loss on that
            shift_labels = target_ids.contiguous()

            loss = nn.CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        if (epoch+1) % 1 == 0:
            validation_loss = get_validation_loss(attacker, val_dataloader)
            print(f"Epoch {epoch}\n loss: {epoch_loss / num_batches}\n validation loss: ", validation_loss)

            if validation_loss < best_loss:
                best_loss = validation_loss
                print("New best at epoch ", epoch, " Saving current model ....\n")
                save_model(attacker, f'FGSM/attacker_e5v2_best.pth')
            else:
                streak += 1
                if streak > streak_cap:
                    print("Early stopping triggered.")
                    break



    print("Attacker trained.\n")

    # --- Phase 2: Apply FGSM Defense ---
    print("--- Phase 2: Testing FGSM Defense ---")

    test_text = ["Privacy is a fundamental human right."]
    test_emb = get_clean_embedding(test_text).clone().detach()

    # 1. Check Attacker success on CLEAN embedding
    attacker.eval()
    with torch.no_grad():
        logits, ids = attacker(test_emb, test_text)
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = ids.contiguous()
        clean_loss = nn.CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    # 2. Generate PROTECTED embedding
    # We use a high epsilon here to visually prove the concept
    epsilon = 5.0
    protected_emb = fgsm_protect(test_emb, test_text, attacker, epsilon=epsilon)

    # 3. Check Attacker failure on PROTECTED embedding
    with torch.no_grad():
        logits, ids = attacker(protected_emb, test_text)
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = ids.contiguous()
        adv_loss = nn.CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    print(f"Original Text: {test_text[0]}")
    print(f"Attacker Loss (Clean):     {clean_loss.item():.4f}")
    print(f"Attacker Loss (Protected): {adv_loss.item():.4f}")
    print(f"Difference:                {adv_loss.item() - clean_loss.item():.4f}")

Running on: cuda

Loading Pythia-160m...

--- Phase 1: Training the Attacker (Pythia) ---


Processing epoch: 0: 100%|██████████| 985/985 [03:32<00:00,  4.64it/s]


Epoch 0
 loss: 1.4693829783933416
 validation loss:  1.055190668042217
New best at epoch  0  Saving current model ....



Processing epoch: 1: 100%|██████████| 985/985 [03:32<00:00,  4.64it/s]


Epoch 1
 loss: 0.868660680775715
 validation loss:  0.8983886736844268
New best at epoch  1  Saving current model ....



Processing epoch: 2: 100%|██████████| 985/985 [03:32<00:00,  4.63it/s]


Epoch 2
 loss: 0.6485107631852784
 validation loss:  0.8618634247354099
New best at epoch  2  Saving current model ....



Processing epoch: 3: 100%|██████████| 985/985 [03:32<00:00,  4.64it/s]


Epoch 3
 loss: 0.5007818928527348
 validation loss:  0.8892493077686855


Processing epoch: 4: 100%|██████████| 985/985 [03:31<00:00,  4.65it/s]


Epoch 4
 loss: 0.3889142135375647
 validation loss:  0.9399081782570907


Processing epoch: 5: 100%|██████████| 985/985 [03:31<00:00,  4.65it/s]


Epoch 5
 loss: 0.2883714701167218
 validation loss:  1.0332311556807585


Processing epoch: 6: 100%|██████████| 985/985 [03:31<00:00,  4.65it/s]


Epoch 6
 loss: 0.21993919626105254
 validation loss:  1.0845743848809175
Early stopping triggered.
Attacker trained.

--- Phase 2: Testing FGSM Defense ---
Original Text: Privacy is a fundamental human right.
Attacker Loss (Clean):     6.7112
Attacker Loss (Protected): 16.3863
Difference:                9.6751


In [7]:
import torch
import os

# Define the device (same as during training)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the attacker model (ensure PythiaAttacker class is defined)
loaded_attacker = PythiaAttacker(input_dim=768, device=device).to(device)

# Load the saved state dictionary
save_directory = 'FGSM'
model_path = os.path.join(save_directory, 'attacker_distilroberta_best.pth') # Assuming it was saved as 'attacker_model_best.pth'
loaded_attacker.load_state_dict(torch.load(model_path, map_location=device))

# Set the model to evaluation mode for inference
loaded_attacker.eval()

print(f"Pythia Attacker model loaded from {model_path} and set to evaluation mode.")

Loading Pythia-160m...
Pythia Attacker model loaded from FGSM/attacker_distilroberta_best.pth and set to evaluation mode.


In [19]:
!ls FGSM/

attacker_model_1.pth


In [15]:
test_text = ["I have an exam on next week tuesday"]
test_emb = get_clean_embedding(test_text).clone().detach()

loaded_attacker.eval()
logits, ids = loaded_attacker(test_emb, test_text)
predicted_ids = torch.argmax(logits, dim=-1)
loaded_attacker.tokenizer.batch_decode(predicted_ids, skip_special_tokens=False)

['see have a exam on satur week souesday<|endoftext|>']

In [28]:
loaded_attacker.tokenizer.eos_token

'<|endoftext|>'

In [14]:
protected_embd = fgsm_protect(test_emb, test_text, loaded_attacker, epsilon=0.5)
protected_logits, protected_ids = loaded_attacker(protected_embd, test_text)
predicted_ids = torch.argmax(protected_logits, dim=-1)
loaded_attacker.tokenizer.batch_decode(predicted_ids, skip_special_tokens=False)

[' . . . cra . . . .hs .']