In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import DistilBertTokenizer, DistilBertModel
import pandas as pd

# Load data
splits = {
    'train': 'data/train-00000-of-00001.parquet', 
    'validation': 'data/validation-00000-of-00001.parquet', 
    'test': 'data/test-00000-of-00001.parquet'
}
df = pd.read_parquet("hf://datasets/MichaelR207/prism_personalized_1023/" + splits["train"])


# Model for shared mapping function f
class SharedMapping(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SharedMapping, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim)
        )

    def forward(self, x):
        return self.fc(x)
    

# Transformer-based embedding
class TextEmbedder:
    def __init__(self, model_name="distilbert-base-uncased"):
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.model = DistilBertModel.from_pretrained(model_name)

    def encode(self, texts):
        tokens = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            embeddings = self.model(**tokens).last_hidden_state.mean(dim=1)  # Mean pooling
        return embeddings

# Step 1: Create a mapping from user_id strings to unique integer indices
unique_user_ids = df['user_id'].unique()
user_id_mapping = {user_id: idx for idx, user_id in enumerate(unique_user_ids)}
num_users = len(unique_user_ids)  # Update number of users
print(user_id_mapping)

# Step 2: Update the PreferenceDataset to use the user_id mapping
class PreferenceDataset(Dataset):
    def __init__(self, dataframe, text_embedder, user_id_mapping):
        self.dataframe = dataframe
        self.text_embedder = text_embedder
        self.user_id_mapping = user_id_mapping

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        
        # Extract text content
        context_text = row['context'][0]['content']
        chosen_text = row['chosen']['content']
        rejected_text = row['rejected']['content']
        
        # Get embeddings for context, chosen, rejected
        context_embed = self.text_embedder.encode([context_text])
        chosen_embed = self.text_embedder.encode([chosen_text])
        rejected_embed = self.text_embedder.encode([rejected_text])
        
        # Map user_id to index
        user_id = self.user_id_mapping[row['user_id']]
        
        # Set preference label
        label = torch.tensor(1.0 if row['chosen_score'] > row['rejected_score'] else 0.0)
        
        return context_embed.squeeze(0), chosen_embed.squeeze(0), rejected_embed.squeeze(0), user_id, label



class PAL_A(nn.Module):
    def __init__(self, input_dim, output_dim, num_prototypes, num_users):
        super(PAL_A, self).__init__()
        self.shared_mapping = SharedMapping(input_dim, output_dim)
        self.prototypes = nn.Parameter(torch.randn(num_prototypes, output_dim))
        self.user_weights = nn.Parameter(torch.rand(num_users, num_prototypes))
        self.num_users = num_users

    def forward(self, context, choice_embed, user_id):
        f_context = self.shared_mapping(context)
        f_choice = self.shared_mapping(choice_embed)  # Differentiate chosen/rejected
        if user_id is None:
            a_i = torch.matmul(self.user_weights.mean(dim=0), self.prototypes)
        else:
            a_i = torch.matmul(self.user_weights[user_id], self.prototypes)
        #print(a_i.shape)
        #print(user_id)
        #a_i = torch.matmul(self.user_weights[user_id], self.prototypes)
        # Calculate similarity for choice and context separately
        reward = torch.sum(f_choice * a_i, dim=1)  # Dot product for similarity
        #print(reward.shape)
        return reward

"""Using Context"""

# class PAL_A(nn.Module):
#     def __init__(self, input_dim, output_dim, num_prototypes, num_users):
#         super(PAL_A, self).__init__()
#         self.shared_mapping = SharedMapping(input_dim, output_dim)
#         self.prototypes = nn.Parameter(torch.randn(num_prototypes, output_dim))
#         self.user_weights = nn.Parameter(torch.rand(num_users, num_prototypes))
#         self.num_users = num_users

#     def forward(self, context, choice_embed, user_id=None):
#         f_context = self.shared_mapping(context)
#         f_choice = self.shared_mapping(choice_embed)
        
#         combined = torch.cat((f_context, f_choice), dim=1)
#         context_aware_choice = nn.Linear(2 * f_choice.size(1), f_choice.size(1)).to(context.device)(combined)
        
#         if user_id is None:
#             a_i = self.prototypes.mean(dim=0)  # For unseen users
#         else:
#             a_i = torch.matmul(self.user_weights[user_id], self.prototypes)
        
#         # Calculate reward based on context-aware choice embedding
#         reward = torch.sum(context_aware_choice * a_i, dim=1)
#         return reward


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text_embedder = TextEmbedder()
input_dim = 768  # For DistilBERT embedding dimensions
output_dim = 128
num_prototypes = 5
num_users = len(df['user_id'].unique())

model = PAL_A(input_dim=input_dim, output_dim=output_dim, num_prototypes=num_prototypes, num_users=num_users).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

def preference_loss(reward_chosen, reward_rejected, label):
    #print(reward_chosen)
    #print(reward_rejected)
    diff = reward_chosen - reward_rejected
    #print("Reward Difference:", diff) 
    return torch.nn.functional.binary_cross_entropy_with_logits(diff, label)

# Training Loop
def train_model(model, dataloader, optimizer, num_epochs=1):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for context, chosen, rejected, user_id, label in dataloader:
            context, chosen, rejected, label = context.to(device), chosen.to(device), rejected.to(device), label.to(device)
            user_id = user_id.to(device)
            

            reward_chosen = model(context, chosen, user_id)
            reward_rejected = model(context, rejected, user_id)
            loss = preference_loss(reward_chosen, reward_rejected, label)
            optimizer.zero_grad()
            loss.backward()
            
            optimizer.step()
            total_loss += loss.item()
        
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader)}')


train_dataset = PreferenceDataset(df[df['split'] == 'train'], text_embedder, user_id_mapping)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)

train_model(model, train_dataloader, optimizer, 2)


{'user469': 0, 'user47': 1, 'user470': 2, 'user471': 3, 'user473': 4, 'user474': 5, 'user475': 6, 'user476': 7, 'user477': 8, 'user478': 9, 'user479': 10, 'user480': 11, 'user481': 12, 'user482': 13, 'user483': 14, 'user486': 15, 'user487': 16, 'user488': 17, 'user489': 18, 'user49': 19, 'user490': 20, 'user491': 21, 'user492': 22, 'user494': 23, 'user495': 24, 'user496': 25, 'user497': 26, 'user499': 27, 'user5': 28, 'user50': 29, 'user500': 30, 'user501': 31, 'user502': 32, 'user504': 33, 'user505': 34, 'user506': 35, 'user507': 36, 'user508': 37, 'user509': 38, 'user51': 39, 'user510': 40, 'user512': 41, 'user514': 42, 'user515': 43, 'user516': 44, 'user517': 45, 'user518': 46, 'user519': 47, 'user52': 48, 'user520': 49, 'user521': 50, 'user522': 51, 'user523': 52, 'user524': 53, 'user525': 54, 'user526': 55, 'user527': 56, 'user528': 57, 'user53': 58, 'user530': 59, 'user531': 60, 'user532': 61, 'user533': 62, 'user534': 63, 'user535': 64, 'user536': 65, 'user537': 66, 'user538': 6

In [2]:

import torch

df_test = pd.read_parquet("hf://datasets/MichaelR207/prism_personalized_1023/" + splits["test"])
print(df_test.shape)

def make_prediction(model, context_text, chosen_text, rejected_text, user_id):
    model.eval()  

    context_embed = text_embedder.encode([context_text]).to(device)
    chosen_embed = text_embedder.encode([chosen_text]).to(device)
    print(chosen_embed.shape)
    rejected_embed = text_embedder.encode([rejected_text]).to(device)

    if user_id in user_id_mapping:
        user_index = user_id_mapping[user_id]
        user_index = torch.tensor([user_index], dtype=torch.long).to(device)
    else:
        # Use average user weights for unseen users, pass None to indicate unseen user for model forward
        user_index = None
    print(user_index)

    reward_chosen = model(context_embed, chosen_embed, user_index)
    print(reward_chosen.shape)
    reward_rejected = model(context_embed, rejected_embed, user_index)
    diff = reward_chosen - reward_rejected

    # Using diff mean for reward tensors.
    prediction = "chosen" if diff.mean() > 0 else "rejected"  
    return prediction == "chosen"  

# Calculate accuracy
correct_predictions = 0
for idx, row in df_test.iterrows():
    print(idx)
    context_text = row['context'][0]['content'] 
    chosen_text = row['chosen']['content']
    rejected_text = row['rejected']['content']
    user_id = row['user_id']
    
    is_correct = make_prediction(model, context_text, chosen_text, rejected_text, user_id)
    correct_predictions += int(is_correct)  

accuracy = correct_predictions / len(df_test)
print(f"\nAccuracy on first 100 test points: {accuracy * 100:.2f}%")



(26702, 12)
0
torch.Size([1, 768])
None
torch.Size([1])
1
torch.Size([1, 768])
None
torch.Size([1])
2
torch.Size([1, 768])
None
torch.Size([1])
3
torch.Size([1, 768])
None
torch.Size([1])
4
torch.Size([1, 768])
None
torch.Size([1])
5
torch.Size([1, 768])
None
torch.Size([1])
6
torch.Size([1, 768])
None
torch.Size([1])
7
torch.Size([1, 768])
None
torch.Size([1])
8
torch.Size([1, 768])
None
torch.Size([1])
9
torch.Size([1, 768])
None
torch.Size([1])
10
torch.Size([1, 768])
None
torch.Size([1])
11
torch.Size([1, 768])
None
torch.Size([1])
12
torch.Size([1, 768])
None
torch.Size([1])
13
torch.Size([1, 768])
None
torch.Size([1])
14
torch.Size([1, 768])
None
torch.Size([1])
15
torch.Size([1, 768])
None
torch.Size([1])
16
torch.Size([1, 768])
None
torch.Size([1])
17
torch.Size([1, 768])
None
torch.Size([1])
18
torch.Size([1, 768])
None
torch.Size([1])
19
torch.Size([1, 768])
None
torch.Size([1])
20
torch.Size([1, 768])
None
torch.Size([1])
21
torch.Size([1, 768])
None
torch.Size([1])
22
torch