# In [None]:
## mean pooling 

'''
Script for getting embeddings with mean pooling for better generalization.
IMPORTANT: The script currently loads the data from Google Drive.
You can change the input path to match where your data is stored.

'''



import json
import torch
import numpy as np
import pandas as pd
from transformers import DistilBertTokenizer, DistilBertModel
from tqdm import tqdm

# One checkpoint name is shared by the tokenizer and the encoder.
model_name = "distilbert/distilbert-base-cased-distilled-squad"
tokenizer = DistilBertTokenizer.from_pretrained(model_name)

# Load the encoder, then cast it to half precision and move it to the GPU.
model = DistilBertModel.from_pretrained(model_name)
model = model.half().to("cuda")
# Evaluation mode: the model is only used for inference in this script.
model.eval()

def mean_pooling(model_output, attention_mask):
    """Average the token embeddings over the positions kept by the mask.

    Args:
        model_output: Object exposing ``last_hidden_state``, a tensor whose
            dimension 1 is the token axis (the callers in this file pass
            ``(1, seq_len, hidden_size)``).
        attention_mask: Tensor with one entry per token (no hidden axis);
            non-zero entries mark the tokens that take part in the average.

    Returns:
        A NumPy array on the CPU holding the masked mean, with every
        size-1 dimension squeezed away (a 1-D vector for a single sequence).
    """
    hidden_states = model_output.last_hidden_state

    # Stretch the mask over the hidden axis so it can weight every feature.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()

    weighted_total = (hidden_states * weights).sum(1)
    # The clamp avoids a division by zero when a sequence has no kept tokens.
    kept_tokens = weights.sum(1).clamp(min=1e-9)

    pooled = weighted_total / kept_tokens
    return pooled.squeeze().cpu().numpy()


def embed_comment(comment_text):
    """Embed a text in a single forward pass, truncating it to 512 tokens.

    Args:
        comment_text: Text to embed.

    Returns:
        The ``mean_pooling`` result for the model output (a NumPy array).
    """
    inputs = tokenizer(
        comment_text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    ).to(model.device)  # follow the model instead of hard-coding the device

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)
    return mean_pooling(outputs, inputs["attention_mask"])

def embed_long_comment(text, chunk_size=510, stride=128):
    """Embed a long text by averaging the embeddings of overlapping windows.

    The text is tokenized once, cut into windows of ``chunk_size`` content
    tokens starting every ``stride`` tokens, and each window is wrapped in
    the tokenizer's special tokens, run through the model and mean-pooled.

    Args:
        text: Text to embed.
        chunk_size: Content tokens per window, not counting the special
            tokens added around each window. The default of 510 leaves room
            for the two special tokens DistilBERT's tokenizer adds, keeping
            each window within 512 positions.
        stride: Distance between the starts of consecutive windows.

    Returns:
        The element-wise mean of the per-window ``mean_pooling`` results.
    """
    # Tokenize WITHOUT special tokens: every window gets its own pair from
    # build_inputs_with_special_tokens below. Keeping the tokenizer's own
    # pair here would duplicate them in the first and last window and push
    # the last window past the intended length.
    token_ids = tokenizer(
        text, return_tensors="pt", truncation=False, add_special_tokens=False
    )["input_ids"][0].tolist()
    total = len(token_ids)

    chunk_embeddings = []
    # max(total, 1) keeps one (special-tokens-only) window for empty text so
    # the final mean is never taken over an empty list.
    for start in range(0, max(total, 1), stride):
        window = tokenizer.build_inputs_with_special_tokens(
            token_ids[start:start + chunk_size]
        )
        input_ids = torch.tensor([window]).to(model.device)
        attention_mask = (input_ids != tokenizer.pad_token_id).long()

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        chunk_embeddings.append(mean_pooling(outputs, attention_mask))

        # This window already reached the end of the text; later starts
        # would only re-embed a suffix of it.
        if start + chunk_size >= total:
            break

    return np.mean(chunk_embeddings, axis=0)

def get_embedding(text):
    """Embed ``text``, picking the single-pass or the chunked path by length.

    Args:
        text: Text to embed.

    Returns:
        The result of ``embed_comment`` when the untruncated tokenization has
        at most 512 ids, otherwise the result of ``embed_long_comment``.
    """
    # Tokenize without truncation purely to measure the full length.
    token_ids = tokenizer(text, return_tensors="pt", truncation=False)["input_ids"][0]

    if len(token_ids) > 512:
        return embed_long_comment(text)
    return embed_comment(text)

# The input JSON is currently read from Google Drive, so the drive is mounted
# first; change json_path if your data lives somewhere else.

from google.colab import drive
drive.mount('/content/drive')
json_path = "/content/drive/MyDrive/filtered_pandora.json"

with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)


rows = []

# Only the first 10% of the authors are processed.
sample_size = int(len(data["authors"]) * 0.10)
sampled_authors = data["authors"][:sample_size]
for author in tqdm(sampled_authors):
    author_id = author["id"]
    labels = author["labels"]
    comments = author.get("comments", [])

    # Embed every comment; a failure is reported and that comment is skipped.
    embeddings = []
    for comment in comments:
        try:
            vec = get_embedding(comment)
        except Exception as e:
            print(f"Error embedding comment for {author_id}: {e}")
        else:
            embeddings.append(vec)

    # Authors without any usable embedding produce no row.
    if not embeddings:
        continue

    # One row per author: id, the label fields, then one column per
    # component of the averaged embedding.
    avg_embedding = np.mean(embeddings, axis=0)
    row = {"id": author_id, **labels}
    for i, component in enumerate(avg_embedding):
        row[f"embed_{i}"] = component
    rows.append(row)

df = pd.DataFrame(rows)
df.to_csv("author_embeddings_meanpool.csv", index=False)
print("Saved mean pooled embeddings to author_embeddings_meanpool.csv")