In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import DistilBertTokenizer, DistilBertModel
import pandas as pd

# Load data
splits = {
    'train': 'data/train-00000-of-00001.parquet', 
    'validation': 'data/validation-00000-of-00001.parquet', 
    'test': 'data/test-00000-of-00001.parquet'
}
df = pd.read_parquet("hf://datasets/MichaelR207/prism_personalized_1023/" + splits["train"]).sample(3000)


# Model for shared mapping function f
class SharedMapping(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SharedMapping, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim)
        )

    def forward(self, x):
        return self.fc(x)
    

# Transformer-based embedding
class TextEmbedder:
    def __init__(self, model_name="distilbert-base-uncased"):
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.model = DistilBertModel.from_pretrained(model_name)

    def encode(self, texts):
        tokens = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            embeddings = self.model(**tokens).last_hidden_state.mean(dim=1)  # Mean pooling
        return embeddings

# Step 1: Create a mapping from user_id strings to unique integer indices
unique_user_ids = df['user_id'].unique()
user_id_mapping = {user_id: idx for idx, user_id in enumerate(unique_user_ids)}
num_users = len(unique_user_ids)  # Update number of users
print(user_id_mapping)

# Step 2: Update the PreferenceDataset to use the user_id mapping
class PreferenceDataset(Dataset):
    def __init__(self, dataframe, text_embedder, user_id_mapping):
        self.dataframe = dataframe
        self.text_embedder = text_embedder
        self.user_id_mapping = user_id_mapping

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        
        # Extract text content
        context_text = row['context'][0]['content']
        chosen_text = row['chosen']['content']
        rejected_text = row['rejected']['content']
        
        # Get embeddings for context, chosen, rejected
        context_embed = self.text_embedder.encode([context_text])
        chosen_embed = self.text_embedder.encode([chosen_text])
        rejected_embed = self.text_embedder.encode([rejected_text])
        
        # Map user_id to index
        user_id = self.user_id_mapping[row['user_id']]
        
        # Set preference label
        label = torch.tensor(1.0 if row['chosen_score'] > row['rejected_score'] else 0.0)
        
        return context_embed.squeeze(0), chosen_embed.squeeze(0), rejected_embed.squeeze(0), user_id, label



class PAL_A(nn.Module):
    def __init__(self, input_dim, output_dim, num_prototypes, num_users):
        super(PAL_A, self).__init__()
        self.shared_mapping = SharedMapping(input_dim, output_dim)
        self.prototypes = nn.Parameter(torch.randn(num_prototypes, output_dim))
        self.user_weights = nn.Parameter(torch.rand(num_users, num_prototypes))
        self.num_users = num_users

    def forward(self, context, choice_embed, user_id):
        f_context = self.shared_mapping(context)
        f_choice = self.shared_mapping(choice_embed)  # Differentiate chosen/rejected
        if user_id is None:
            a_i = torch.matmul(self.user_weights.mean(dim=0), self.prototypes)
        else:
            a_i = torch.matmul(self.user_weights[user_id], self.prototypes)
        #print(a_i.shape)
        #print(user_id)
        #a_i = torch.matmul(self.user_weights[user_id], self.prototypes)
        # Calculate similarity for choice and context separately
        reward = torch.sum(f_choice * a_i, dim=1)  # Dot product for similarity
        #print(reward.shape)
        return reward

"""Using Context"""

# class PAL_A(nn.Module):
#     def __init__(self, input_dim, output_dim, num_prototypes, num_users):
#         super(PAL_A, self).__init__()
#         self.shared_mapping = SharedMapping(input_dim, output_dim)
#         self.prototypes = nn.Parameter(torch.randn(num_prototypes, output_dim))
#         self.user_weights = nn.Parameter(torch.rand(num_users, num_prototypes))
#         self.num_users = num_users

#     def forward(self, context, choice_embed, user_id=None):
#         f_context = self.shared_mapping(context)
#         f_choice = self.shared_mapping(choice_embed)
        
#         combined = torch.cat((f_context, f_choice), dim=1)
#         context_aware_choice = nn.Linear(2 * f_choice.size(1), f_choice.size(1)).to(context.device)(combined)
        
#         if user_id is None:
#             a_i = self.prototypes.mean(dim=0)  # For unseen users
#         else:
#             a_i = torch.matmul(self.user_weights[user_id], self.prototypes)
        
#         # Calculate reward based on context-aware choice embedding
#         reward = torch.sum(context_aware_choice * a_i, dim=1)
#         return reward


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text_embedder = TextEmbedder()
input_dim = 768  # For DistilBERT embedding dimensions
output_dim = 128
num_prototypes = 5
num_users = len(df['user_id'].unique())

model = PAL_A(input_dim=input_dim, output_dim=output_dim, num_prototypes=num_prototypes, num_users=num_users).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

def preference_loss(reward_chosen, reward_rejected, label):
    #print(reward_chosen)
    #print(reward_rejected)
    diff = reward_chosen - reward_rejected
    #print("Reward Difference:", diff) 
    return torch.nn.functional.binary_cross_entropy_with_logits(diff, label)

# Training Loop
def train_model(model, dataloader, optimizer, num_epochs=1):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for context, chosen, rejected, user_id, label in dataloader:
            context, chosen, rejected, label = context.to(device), chosen.to(device), rejected.to(device), label.to(device)
            user_id = user_id.to(device)
            

            reward_chosen = model(context, chosen, user_id)
            reward_rejected = model(context, rejected, user_id)
            loss = preference_loss(reward_chosen, reward_rejected, label)
            optimizer.zero_grad()
            loss.backward()
            
            optimizer.step()
            total_loss += loss.item()
        
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader)}')


train_dataset = PreferenceDataset(df[df['split'] == 'train'], text_embedder, user_id_mapping)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)

train_model(model, train_dataloader, optimizer)


{'user859': 0, 'user710': 1, 'user549': 2, 'user727': 3, 'user844': 4, 'user702': 5, 'user931': 6, 'user939': 7, 'user778': 8, 'user571': 9, 'user509': 10, 'user771': 11, 'user95': 12, 'user992': 13, 'user510': 14, 'user976': 15, 'user546': 16, 'user648': 17, 'user811': 18, 'user622': 19, 'user898': 20, 'user764': 21, 'user674': 22, 'user981': 23, 'user812': 24, 'user66': 25, 'user925': 26, 'user576': 27, 'user742': 28, 'user889': 29, 'user826': 30, 'user679': 31, 'user886': 32, 'user748': 33, 'user92': 34, 'user699': 35, 'user682': 36, 'user517': 37, 'user802': 38, 'user822': 39, 'user84': 40, 'user5': 41, 'user776': 42, 'user749': 43, 'user750': 44, 'user921': 45, 'user542': 46, 'user961': 47, 'user754': 48, 'user665': 49, 'user600': 50, 'user539': 51, 'user598': 52, 'user880': 53, 'user860': 54, 'user486': 55, 'user706': 56, 'user566': 57, 'user738': 58, 'user728': 59, 'user794': 60, 'user777': 61, 'user878': 62, 'user77': 63, 'user955': 64, 'user715': 65, 'user739': 66, 'user480': 

In [49]:
example_row = df.iloc[78]
context_text = example_row['context'][0]['content']  # Assuming context is a list of dicts with a 'content' key
chosen_text = example_row['chosen']['content']       # Assuming chosen and rejected are dicts with a 'content' key
rejected_text = example_row['rejected']['content']

print("\nExample Data Point:")
print("Context:", context_text)
print("Chosen:", chosen_text)
print("Rejected:", rejected_text)


def make_prediction(model, context_text, chosen_text, rejected_text):
    model.eval()  # Set model to evaluation mode
    
    # Convert text data to embeddings
    context_embed = text_embedder.encode(context_text).to(device)
    chosen_embed = text_embedder.encode(chosen_text).to(device)
    rejected_embed = text_embedder.encode(rejected_text).to(device)

    # Choose an example user_id (in practice, use the specific user_id mapping)
    user_id = torch.tensor([78], dtype=torch.long).to(device)  # Example user_id 0

    # Get rewards for chosen and rejected
    reward_chosen = model(context_embed, chosen_embed, user_id)
    reward_rejected = model(context_embed, rejected_embed, user_id)

    # Prediction based on reward values
    prediction = "chosen" if reward_chosen > reward_rejected else "rejected"
    print(f"\nPrediction: {prediction}")
    print(f"Reward for Chosen: {reward_chosen.item()}, Reward for Rejected: {reward_rejected.item()}")

# Test prediction with an example data point
make_prediction(model, context_text, chosen_text, rejected_text)


Example Data Point:
Context: red or blue
Chosen: Sure, I'd be happy to help! Would you like to know the difference between red and blue, or perhaps something else?
Rejected: As Coral, an AI chatbot designed to assist humans in a conversational manner, I do not have personal preferences or opinions. I can, however, provide information or generate responses based on my training to help answer your queries to the best of my knowledge and abilities. If you have any specific questions or requests, feel free to ask me, and I'll do my best to assist you!

Prediction: chosen
Reward for Chosen: 1.3204188346862793, Reward for Rejected: 0.15617117285728455


In [19]:
# Step 1: Check for Consistent Mapping
unique_user_ids = df['user_id'].unique()
mapped_indices = set(user_id_mapping.values())

if len(user_id_mapping) != len(unique_user_ids):
    print("Error: Inconsistent mapping. The number of unique user_ids does not match the size of user_id_mapping.")
else:
    print("User ID mapping is consistent with the number of unique user IDs.")

# Step 2: Verify Mapping Coverage
missing_user_ids = [user_id for user_id in unique_user_ids if user_id not in user_id_mapping]
if missing_user_ids:
    print(f"Error: Missing user IDs in user_id_mapping: {missing_user_ids}")
else:
    print("All user IDs in the dataset are covered in user_id_mapping.")

# Step 3: Ensure No Out-of-Bounds Indices
if max(user_id_mapping.values()) >= num_users or min(user_id_mapping.values()) < 0:
    print("Error: Out-of-bounds indices in user_id_mapping.")
else:
    print("All mapped indices are within bounds.")

# Step 4: Spot-check Random Samples to Confirm Accuracy
import random
sample_user_ids = random.sample(list(unique_user_ids), min(5, len(unique_user_ids)))
print("Sample check of user_id_mapping:")
for user_id in sample_user_ids:
    print(f"User ID: {user_id} -> Mapped Index: {user_id_mapping[user_id]}")

# Additional Check: Ensure each user_id maps to a unique integer index
inverse_mapping = {v: k for k, v in user_id_mapping.items()}
if len(inverse_mapping) != len(user_id_mapping):
    print("Error: Duplicate indices in user_id_mapping.")
else:
    print("Each user ID maps to a unique index.")


User ID mapping is consistent with the number of unique user IDs.
All user IDs in the dataset are covered in user_id_mapping.
All mapped indices are within bounds.
Sample check of user_id_mapping:
User ID: user925 -> Mapped Index: 0
User ID: user509 -> Mapped Index: 1
User ID: user84 -> Mapped Index: 17
User ID: user941 -> Mapped Index: 31
User ID: user870 -> Mapped Index: 34
Each user ID maps to a unique index.


In [15]:

import torch

df_test = pd.read_parquet("hf://datasets/MichaelR207/prism_personalized_1023/" + splits["test"]).sample(1000)
print(df_test.shape)

def make_prediction(model, context_text, chosen_text, rejected_text, user_id):
    model.eval()  

    context_embed = text_embedder.encode([context_text]).to(device)
    chosen_embed = text_embedder.encode([chosen_text]).to(device)
    print(chosen_embed.shape)
    rejected_embed = text_embedder.encode([rejected_text]).to(device)

    if user_id in user_id_mapping:
        user_index = user_id_mapping[user_id]
        user_index = torch.tensor([user_index], dtype=torch.long).to(device)
    else:
        # Use average user weights for unseen users, pass None to indicate unseen user for model forward
        user_index = None
    print(user_index)

    reward_chosen = model(context_embed, chosen_embed, user_index)
    print(reward_chosen.shape)
    reward_rejected = model(context_embed, rejected_embed, user_index)
    diff = reward_chosen - reward_rejected

    # Using diff mean for reward tensors.
    prediction = "chosen" if diff.mean() > 0 else "rejected"  
    return prediction == "chosen"  

# Calculate accuracy
correct_predictions = 0
for idx, row in df_test.iterrows():
    print(idx)
    context_text = row['context'][0]['content'] 
    chosen_text = row['chosen']['content']
    rejected_text = row['rejected']['content']
    user_id = row['user_id']
    
    is_correct = make_prediction(model, context_text, chosen_text, rejected_text, user_id)
    correct_predictions += int(is_correct)  

accuracy = correct_predictions / len(df_test)
print(f"\nAccuracy on first 100 test points: {accuracy * 100:.2f}%")



(1000, 12)
14881
torch.Size([1, 768])
None
torch.Size([1, 128])
12596
torch.Size([1, 768])
None
torch.Size([1, 128])
26419
torch.Size([1, 768])
None
torch.Size([1, 128])
25419
torch.Size([1, 768])
None
torch.Size([1, 128])
5838
torch.Size([1, 768])
None
torch.Size([1, 128])
8829
torch.Size([1, 768])
None
torch.Size([1, 128])
327
torch.Size([1, 768])
None
torch.Size([1, 128])
17754
torch.Size([1, 768])
None
torch.Size([1, 128])
11467
torch.Size([1, 768])
None
torch.Size([1, 128])
1254
torch.Size([1, 768])
None
torch.Size([1, 128])
14952
torch.Size([1, 768])
None
torch.Size([1, 128])
21521
torch.Size([1, 768])
None
torch.Size([1, 128])
15893
torch.Size([1, 768])
None
torch.Size([1, 128])
25631
torch.Size([1, 768])
None
torch.Size([1, 128])
24339
torch.Size([1, 768])
None
torch.Size([1, 128])
13713
torch.Size([1, 768])
None
torch.Size([1, 128])
11801
torch.Size([1, 768])
None
torch.Size([1, 128])
2178
torch.Size([1, 768])
None
torch.Size([1, 128])
2165
torch.Size([1, 768])
None
torch.Size