In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from nltk.tokenize import PunktSentenceTokenizer
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

In [6]:
# Token-level cross entropy: reduction='none' keeps one loss value per position
# instead of averaging over the sequence.
CROSS_ENTROPY = torch.nn.CrossEntropyLoss(reduction='none')

# Sentence splitter: the bound `tokenize` method of a default-constructed Punkt tokenizer.
sent_split = PunktSentenceTokenizer().tokenize

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def _infer_model_device(model):
    """Return the device the model's inputs should be created on (CPU if unknown)."""
    device = getattr(model, "device", None)
    if device is not None:
        return device
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device("cpu")


def extract_gpt2_features(text, tokenizer, model, sent_cut):
    """
    Compute perplexity (ppl) metrics and token-rank statistics for a text under a causal LM.

    The text is split into sentences with ``sent_cut`` and tokenized sentence by
    sentence. The token sequence is truncated to ``tokenizer.model_max_length - 2``
    ids and wrapped in BOS/EOS, then scored with a single forward pass.

    Args:
        text: The text to score.
        tokenizer: Object exposing ``tokenize``, ``convert_tokens_to_ids``,
            ``model_max_length``, ``bos_token_id`` and ``eos_token_id``.
        model: Callable that takes a ``(1, seq_len)`` tensor of token ids and returns
            an object whose ``.logits`` has shape ``(1, seq_len, vocab_size)``.
        sent_cut: Callable mapping the text to a list of sentence strings.

    Returns:
        A pair ``(rank_counters, ppl_metrics)``:

        * ``rank_counters`` -- four counts of target tokens (the final EOS included)
          whose rank in the model's prediction is ``< 10``, ``10-99``, ``100-999``
          and ``>= 1000``; rank 0 is the highest-scoring token.
        * ``ppl_metrics`` -- ``[text_ppl, max_sent_ppl, sent_ppl_avg, sent_ppl_std,
          max_step_ppl, step_ppl_avg, step_ppl_std]``. The three sentence-level
          values are NaN when the text yields no tokens.
    """
    # Two positions are reserved for the BOS and EOS ids added below.
    input_max_length = tokenizer.model_max_length - 2
    token_ids, offsets = [], []

    for sentence in sent_cut(text):
        ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentence))
        overflow = len(token_ids) + len(ids) - input_max_length

        if overflow > 0:
            ids = ids[:-overflow]

        # A sentence without tokens would yield a zero-length span and a 0/0
        # sentence perplexity, so it is not recorded.
        if ids:
            offsets.append((len(token_ids), len(token_ids) + len(ids)))
            token_ids.extend(ids)

        if overflow >= 0:
            break

    # Build the inputs on the model's own device rather than on a module-level global.
    input_ids = torch.tensor(
        [tokenizer.bos_token_id] + token_ids + [tokenizer.eos_token_id],
        dtype=torch.long,
        device=_infer_model_device(model),
    )

    # Only Python numbers are returned, so no autograd graph is needed.
    with torch.no_grad():
        # Add an explicit batch dimension for the model and drop it again afterwards.
        logits = model(input_ids.unsqueeze(0)).logits[0].float()

    # Position i predicts token i + 1: drop the last prediction and the first target.
    shift_logits = logits[:-1]
    shift_target = input_ids[1:]
    loss = torch.nn.functional.cross_entropy(shift_logits, shift_target, reduction='none')

    # Rank of each target = number of vocabulary entries with a strictly larger logit.
    # This avoids sorting the whole vocabulary and is unaffected by softmax underflow.
    target_logits = shift_logits.gather(-1, shift_target.unsqueeze(-1))
    rank = (shift_logits > target_logits).sum(dim=-1)

    # Rank distribution counters
    rank_counters = [
        (rank < 10).long().sum().item(),
        ((rank >= 10) & (rank < 100)).long().sum().item(),
        ((rank >= 100) & (rank < 1000)).long().sum().item(),
        (rank >= 1000).long().sum().item()
    ]

    # Text-level perplexity over every predicted position (the final EOS included).
    text_ppl = loss.mean().exp().item()

    # Sentence-level perplexities; offsets index token_ids, which lines up with loss
    # because loss[k] is the loss of predicting token_ids[k].
    sent_ppl = [(loss[start:end].sum() / (end - start)).exp().item() for start, end in offsets]

    if sent_ppl:
        max_sent_ppl = max(sent_ppl)
        sent_ppl_avg = sum(sent_ppl) / len(sent_ppl)
        sent_ppl_std = torch.std(torch.tensor(sent_ppl)).item() if len(sent_ppl) > 1 else 0.0
    else:
        # No sentence produced any token: sentence statistics are undefined.
        max_sent_ppl = sent_ppl_avg = sent_ppl_std = float('nan')

    # Running perplexity after each step: exp(mean of the losses seen so far).
    steps = torch.arange(1, loss.size(0) + 1, device=loss.device, dtype=loss.dtype)
    step_ppl = loss.cumsum(dim=-1).div(steps).exp()
    max_step_ppl = step_ppl.max().item()
    step_ppl_avg = step_ppl.mean().item()
    step_ppl_std = step_ppl.std().item() if step_ppl.size(0) > 1 else 0.0

    ppl_metrics = [
        text_ppl, max_sent_ppl, sent_ppl_avg, sent_ppl_std,
        max_step_ppl, step_ppl_avg, step_ppl_std
    ]

    return rank_counters, ppl_metrics


In [None]:
# Feature column names, in the order the values are concatenated below:
# the seven perplexity metrics first, then the four rank-bucket counters.
cols = [
    'text_ppl', 'max_sent_ppl', 'sent_ppl_avg', 'sent_ppl_std', 'max_step_ppl', 
    'step_ppl_avg', 'step_ppl_std', 'rank_0', 'rank_10', 'rank_100', 'rank_1000'
]

# The tokenizer and model do not depend on the chunk, so load them once
# instead of on every loop iteration.
TOKENIZER_EN = AutoTokenizer.from_pretrained("../../gpt2-large/tokenizer")
MODEL_EN = AutoModelForCausalLM.from_pretrained("../../gpt2-large/model").to(DEVICE)

# Create the output directory up front so a missing folder is not discovered
# only after the whole extraction loop has run.
OUTPUT_DIR = Path("Data/gpt2-large-feature/Split3")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

for i in range(100, 101):
    curr_num = i
    # NOTE(review): the input path is fixed to chunk 98 while the output file is
    # named after curr_num -- confirm this pairing is intended.
    train          = pd.read_csv('../../chunk_test_98.csv')
    # Binary target: 0 for rows whose source is 'Human', 1 for everything else.
    train['label'] = np.where(train['source'] == 'Human', 0, 1)
    models_train_feats = []

    train_ppl_feats  = []
    train_gltr_feats = []
    print(f"Chunk {curr_num} In Process")
    with torch.no_grad():
        for text in tqdm(train.text.values):
            gltr, ppl = extract_gpt2_features(text, TOKENIZER_EN, MODEL_EN, sent_split)
            train_ppl_feats.append(ppl)
            train_gltr_feats.append(gltr)

    # One row per text: perplexity metrics followed by rank counters, matching `cols`.
    X_train = pd.DataFrame(
        np.concatenate((train_ppl_feats, train_gltr_feats), axis=1), 
        columns=[f'gpt2-large-{col}' for col in cols]
    )
    models_train_feats.append(X_train)

    # Append the feature columns to the original rows and write the chunk out.
    train_feats = pd.concat(models_train_feats, axis=1)
    train_feats = pd.concat([train, train_feats], axis=1)
    train_feats.to_csv(OUTPUT_DIR / f"chunk_{curr_num}.csv", index=False)
    print(f"Chunk {curr_num} extracted")