In [None]:
import torch
from torch.utils.data import Dataset
import random
import torch.nn as nn
import torchvision.models as models
from torch.nn import TripletMarginLoss
from torch.utils.data import DataLoader
import torch.optim as optim

In [None]:
# Load the saved training activations onto the CPU, whatever device they were
# saved from. weights_only=True restricts what the unpickler may load.
non_hallucination_train_activations = torch.load(
    "./train_normal_activations.pt",
    map_location=torch.device("cpu"),
    weights_only=True,
)

# The activations from the attack file are used as the hallucination class and
# are also bound to the name `pt2`.
hallucination_train_activations = torch.load(
    "./train_attack_activations.pt",
    map_location=torch.device("cpu"),
    weights_only=True,
)
pt2 = hallucination_train_activations

In [None]:
import torch
from torch.utils.data import Dataset

class TripletActivationDataset(Dataset):
    """Triplets of flattened activation vectors for triplet-margin training.

    Both inputs are indexed as ``data[entry][token][layer][0]``. One triplet is
    stored for every (entry, token, layer) position of ``real_data``:

    * anchor   -- the ``real_data`` activation at that layer;
    * positive -- the ``real_data`` activation at the next layer of the same
      token, wrapping around to the first layer after the last one;
    * negative -- the ``halluc_data`` activation at the same position.

    All triplets are built eagerly in ``__init__`` and kept in memory.

    Args:
        real_data: Nested sequence of activations for the non-hallucinated
            runs (annotated in the original code as 35 entries x 50 tokens x
            33 layers).
        halluc_data: Nested sequence that must be indexable at every position
            visited in ``real_data``; otherwise the ``IndexError`` propagates.
        target_shape: Length every activation vector is truncated or
            zero-padded to.
    """

    def __init__(self, real_data, halluc_data, target_shape=4096):
        self.triplets = []
        self.target_shape = target_shape

        for i in range(len(real_data)):  # entries
            for j in range(len(real_data[i])):  # tokens
                # Wrap on the actual number of layers rather than a fixed 33,
                # so inputs with a different depth neither overrun the list
                # nor wrap to the wrong layer.
                num_layers = len(real_data[i][j])
                for k in range(num_layers):  # layers
                    anchor = real_data[i][j][k][0]
                    positive = real_data[i][j][(k + 1) % num_layers][0]
                    negative = halluc_data[i][j][k][0]

                    self.triplets.append((
                        self._standardize(anchor),
                        self._standardize(positive),
                        self._standardize(negative),
                    ))

    def _standardize(self, tensor):
        """Return ``tensor`` as a 1-D float32 vector of length ``target_shape``.

        The tensor is flattened, then truncated (keeping the leading elements)
        or right-padded with zeros as needed.
        """
        # reshape(-1) flattens any rank, including single-element tensors and
        # non-contiguous views that squeeze() + view(-1) cannot handle.
        flat = tensor.float().reshape(-1)
        size = flat.numel()
        if size > self.target_shape:
            return flat[:self.target_shape]
        if size < self.target_shape:
            return torch.nn.functional.pad(flat, (0, self.target_shape - size))
        return flat

    def __len__(self):
        """Number of stored triplets."""
        return len(self.triplets)

    def __getitem__(self, idx):
        """Return the ``(anchor, positive, negative)`` tuple at ``idx``."""
        return self.triplets[idx]



In [None]:
class ResNetEmbedding(nn.Module):
    """Embed flat activation vectors with a ResNet-18 backbone.

    A linear layer projects each ``input_dim`` vector to ``3 * 16 * 16``
    values, which are reshaped into a 3x16x16 image-like tensor and passed
    through a ResNet-18 created without pretrained weights. The network's
    final fully connected layer is replaced so it outputs ``embedding_dim``
    features.

    Args:
        input_dim: Length of each input vector.
        embedding_dim: Size of the produced embedding.
    """

    def __init__(self, input_dim=4096, embedding_dim=512):
        super().__init__()
        self.fc_in = nn.Linear(input_dim, 3 * 16 * 16)
        # `weights=None` is the current spelling of the deprecated
        # `pretrained=False`: build the architecture with random weights.
        self.resnet = models.resnet18(weights=None)
        # NOTE(review): these arguments look identical to the stock ResNet-18
        # stem, so this mainly swaps in a freshly initialised layer -- confirm
        # whether that is intended before removing it.
        self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embedding_dim)

    def forward(self, x):
        """Map a ``(batch, input_dim)`` tensor to ``(batch, embedding_dim)``."""
        x = self.fc_in(x)
        # Reinterpret the 768 projected values as a 3-channel 16x16 "image".
        x = x.view(x.size(0), 3, 16, 16)
        return self.resnet(x)


In [None]:
# Seed torch's RNG so weight initialisation and batch shuffling are repeatable
# after Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

dataset = TripletActivationDataset(non_hallucination_train_activations, hallucination_train_activations)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Use the GPU when one is visible; otherwise fall back to the CPU instead of
# failing inside .cuda().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ResNetEmbedding().to(device)
criterion = TripletMarginLoss(margin=1.0)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(10):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for anchor, positive, negative in loader:
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)

        anchor_emb = model(anchor)
        positive_emb = model(positive)
        negative_emb = model(negative)

        loss = criterion(anchor_emb, positive_emb, negative_emb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch; the loss of the final batch alone
    # is noisy and is undefined when the loader yields nothing.
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")


In [None]:
# Load the held-out activations onto the CPU, mirroring how the training
# files are loaded.
test_normal = torch.load(
    "./test_normal_activations.pt",
    map_location=torch.device("cpu"),
    weights_only=True,
)

# The attack-file activations are also bound to the name `pt2`, replacing any
# earlier binding of that name.
test_hall = torch.load(
    "./test_attack_activations.pt",
    map_location=torch.device("cpu"),
    weights_only=True,
)
pt2 = test_hall

In [None]:
import torch
import numpy as np

# ---------------------------
# Step 1: Helper to flatten the nested activation structure into one tensor
# ---------------------------
def extract_all_embeddings(data, index=0, device=None):
    """Collect ``record[index]`` from every innermost record into one tensor.

    ``data`` is walked three levels deep (group -> list -> record). Each
    selected tensor contributes to the result as follows:

    * a 2-D tensor with more than one row contributes every row separately;
    * a tensor whose leading axis has length 1 has that axis dropped;
    * anything else is used as is.

    No padding or truncation is performed: all collected tensors must already
    share one shape, otherwise ``torch.stack`` raises.

    Args:
        data: Three-level nested sequence of indexable records.
        index: Position of the tensor to take from each record.
        device: Device for the result. Defaults to CUDA when available and to
            the CPU otherwise.

    Returns:
        A float32 tensor stacking all collected tensors along a new leading
        axis, placed on ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    activations = []
    for group in data:
        for records in group:
            for record in records:
                act = record[index]

                if act.ndim == 2 and act.shape[0] > 1:
                    # Batched entry such as [16, 4096]: one item per row.
                    activations.extend(act.unbind(0))
                elif act.ndim > 0 and act.shape[0] == 1:
                    # Single entry with a leading singleton axis, e.g. [1, 4096].
                    activations.append(act.squeeze(0))
                else:
                    # Already a bare vector such as [4096] (or a 0-dim tensor).
                    activations.append(act)

    return torch.stack(activations).float().to(device)


# ---------------------------
# Step 2: Convert test data
# ---------------------------

# index=0 takes the first item of every innermost record in each split.
test_normal_tensor = extract_all_embeddings(test_normal, index=0)  # normal class
test_hall_tensor = extract_all_embeddings(test_hall, index=0)      # hallucination class

# ---------------------------
# Step 3: Compute embeddings for test data
# ---------------------------

# Switch to evaluation mode and skip gradient tracking for inference.
model.eval()
with torch.no_grad():
    emb_normal, emb_hall = model(test_normal_tensor), model(test_hall_tensor)



In [None]:
# ---------------------------
# Step 4: Compute training class centroids
# ---------------------------

# The centroids must be derived from the *training* activations. Building them
# from the test tensors would score the test set against statistics computed
# from itself and inflate every metric below.
model.eval()  # do not rely on an earlier cell having set evaluation mode
model_device = next(model.parameters()).device

with torch.no_grad():
    normal_train_tensor = extract_all_embeddings(
        non_hallucination_train_activations, index=0
    ).to(model_device)
    hall_train_tensor = extract_all_embeddings(
        hallucination_train_activations, index=0
    ).to(model_device)

    # One centroid per class: the mean embedding over all training samples.
    normal_centroid = model(normal_train_tensor).mean(dim=0)
    hall_centroid = model(hall_train_tensor).mean(dim=0)

# ---------------------------
# Step 5: Predict class based on nearest centroid
# ---------------------------

def predict_class(embedding, centroid_normal, centroid_hall):
    """Label embeddings by their nearest class centroid (Euclidean distance).

    Args:
        embedding: Tensor whose last axis is the embedding dimension, e.g.
            ``(N, D)`` for a batch or ``(D,)`` for a single vector.
        centroid_normal: Centroid of the normal class, broadcastable against
            ``embedding``.
        centroid_hall: Centroid of the hallucination class, broadcastable
            against ``embedding``.

    Returns:
        An int64 tensor with the leading shape of ``embedding``: 1 where the
        normal centroid is strictly closer, 0 otherwise (ties count as
        hallucination).
    """
    # Reduce over the last axis so a single 1-D embedding works as well as a
    # batch; for 2-D input this is the same axis as dim=1.
    dist_normal = torch.linalg.vector_norm(embedding - centroid_normal, ord=2, dim=-1)
    dist_hall = torch.linalg.vector_norm(embedding - centroid_hall, ord=2, dim=-1)
    return (dist_normal < dist_hall).long()  # 1 = normal, 0 = hallucination

# Predict for both sets
pred_normal = predict_class(emb_normal, normal_centroid, hall_centroid)  # should be all 1
pred_hall = predict_class(emb_hall, normal_centroid, hall_centroid)      # should be all 0

# Reference labels: every normal-split sample is 1, every hallucination-split
# sample is 0 (same dtype and device as the predictions).
true_normal = pred_normal.new_ones(pred_normal.shape)
true_hall = pred_hall.new_zeros(pred_hall.shape)

# Concatenate both splits in the same order so predictions and labels align.
y_pred = torch.cat((pred_normal, pred_hall))
y_true = torch.cat((true_normal, true_hall))

# ---------------------------
# Step 6: Compute accuracy
# ---------------------------

# Fraction of samples whose predicted label equals the reference label.
accuracy = y_pred.eq(y_true).float().mean().item()
print(f"Test Accuracy: {accuracy * 100:.2f}%")


In [None]:
# Assumes y_true and y_pred are 1D tensors of 0/1 labels.
# Label 1 (normal) is treated as the positive class, so precision and recall
# below describe how well *normal* samples are recognised.

# Confusion-matrix counts: true/false positives and negatives
TP = ((y_pred == 1) & (y_true == 1)).sum().item()
FP = ((y_pred == 1) & (y_true == 0)).sum().item()
FN = ((y_pred == 0) & (y_true == 1)).sum().item()
TN = ((y_pred == 0) & (y_true == 0)).sum().item()
total = TP + TN + FP + FN

# Every ratio falls back to 0.0 when its denominator is zero.

# Precision = TP / (TP + FP)
precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0

# Recall = TP / (TP + FN)
recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0

# F1 score = 2 * (precision * recall) / (precision + recall)
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

# Accuracy = (TP + TN) / Total, guarded like the other ratios
accuracy = (TP + TN) / total if total > 0 else 0.0

print(f"Test Accuracy: {accuracy * 100:.2f}%")
print(f"Precision: {precision * 100:.2f}%")
print(f"Recall: {recall * 100:.2f}%")
print(f"F1 Score: {f1 * 100:.2f}%")
