In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, T5ForConditionalGeneration, T5Tokenizer
import torch
import re
import random
import json
from functools import lru_cache

In [None]:
# Pick the compute device first so that model placement and dtype can depend on it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only requested on GPU; on CPU the model stays in float32.
t5_dtype = torch.float16 if device == "cuda" else torch.float32

# Load the T5 mask-filling model and place it with a single explicit `.to(device)`.
# (The previous combination of `device_map="auto"` and `.to(device)` asked two
# different mechanisms to place the same weights.)
t5_model = T5ForConditionalGeneration.from_pretrained("t5-large", torch_dtype=t5_dtype)
t5_model = t5_model.to(device).eval()
t5_tokenizer = T5Tokenizer.from_pretrained("t5-large")
print(device)

In [None]:
# Load the GPT-2 scoring model and its tokenizer.
model_name = "openai-community/gpt2-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 ships without a padding token; reuse EOS so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token
# The scoring code aligns targets with `attention_mask[:, 1:]`, which assumes
# padding is appended on the right.
tokenizer.padding_side = "right"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
# Assign the result so the cell does not dump the whole module repr as output,
# and switch to inference mode explicitly.
model = model.to(device).eval()

In [None]:
file_path = "./SubtaskB/subtaskB_train.jsonl"
with open(file_path, "r", encoding="utf-8") as file:
    # Parse each non-blank line exactly once, then keep the human-written records.
    records = (json.loads(line) for line in file if line.strip())
    data_human = [record for record in records if record.get("model") == "human"]

# Fail early with a readable message instead of an IndexError further down.
assert data_human, f"no records with model == 'human' found in {file_path}"

# Print the first record as a sanity check
print(data_human[0])

In [None]:
def batch_mask_text(texts, mask_ratio=0.15, max_words=370):
    """Replace random two-word spans of each text with T5-style sentinel tokens.

    Each text is split on whitespace and truncated to ``max_words`` words.
    ``int(len(words) * mask_ratio)`` start positions are then sampled; the word
    at each start becomes ``<extra_id_i>`` (numbered in left-to-right order) and
    the following word is dropped. If two sampled starts are adjacent, the
    first span shrinks to a single word because the next sentinel takes the
    slot after it.

    Args:
        texts: Iterable of input strings.
        mask_ratio: Fraction of the (truncated) word count to use as the
            number of masked spans.
        max_words: Maximum number of words kept from each text.

    Returns:
        Tuple ``(masked_texts, mask_indices_list)``: the masked strings
        (single-space separated, without gaps for the dropped words) and, per
        text, the sorted start indices into the truncated word list.

    Note:
        NOTE(review): T5 tokenizers normally provide a limited number of
        sentinel tokens (100 by default); very high ``mask_ratio`` /
        ``max_words`` combinations could exceed that -- confirm before raising
        the defaults.
    """
    masked_texts = []
    mask_indices_list = []

    for text in texts:
        words = text.split()[:max_words]

        # The last word is never a span start, so every span has a following
        # word to absorb. Empty and single-word texts have no valid start.
        candidate_starts = range(max(len(words) - 1, 0))

        # Clamp the request so random.sample cannot raise for large (or
        # negative) ratios or very short texts.
        requested = int(len(words) * mask_ratio)
        num_masks = max(0, min(requested, len(candidate_starts)))

        mask_indices = sorted(random.sample(candidate_starts, num_masks))
        mask_indices_list.append(mask_indices)

        for i, idx in enumerate(mask_indices):
            words[idx] = f"<extra_id_{i}>"
            # Mark the second word of the span for removal; if idx + 1 is
            # itself a start, the next iteration overwrites this marker.
            words[idx + 1] = ""

        # Skip removed words instead of joining them as empty strings, which
        # would leave double spaces in the text that is later scored.
        masked_texts.append(" ".join(word for word in words if word))

    return masked_texts, mask_indices_list

def batch_replace_masks(texts, batch_size=8):
    """Run the T5 model over masked texts and return its raw decoded outputs.

    The texts are processed in chunks of ``batch_size``. For each chunk,
    generation samples with ``top_p=0.9`` up to ``max_length=150`` tokens and
    uses ``<extra_id_N>`` as the end-of-sequence token, where ``N`` is the
    largest number of ``<extra_id_`` occurrences found in any text of the chunk.

    Args:
        texts: List of strings containing ``<extra_id_i>`` sentinels.
        batch_size: Number of texts sent to the model per generate call.

    Returns:
        List of decoded strings (special tokens kept), one per input text.
    """
    decoded_outputs = []

    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]

        # The stop sentinel is derived from the most heavily masked text in the chunk.
        sentinel_counts = [entry.count("<extra_id_") for entry in chunk]
        stop_token = f"<extra_id_{max(sentinel_counts)}>"
        stop_id = t5_tokenizer.encode(stop_token)[0]

        encoded = t5_tokenizer(chunk, return_tensors="pt", padding=True)

        sampling_options = dict(
            max_length=150,
            do_sample=True,
            top_p=0.9,
            num_return_sequences=1,
            eos_token_id=stop_id,
        )

        with torch.no_grad():
            # Inputs are moved to whatever device the T5 model reports.
            generated = t5_model.generate(
                input_ids=encoded["input_ids"].to(t5_model.device),
                attention_mask=encoded["attention_mask"].to(t5_model.device),
                **sampling_options,
            )

        # Bring the generated ids back to the CPU before decoding.
        generated = generated.detach().cpu()
        decoded_outputs += t5_tokenizer.batch_decode(generated, skip_special_tokens=False)

    return decoded_outputs

def batch_extract_fills(texts, max_sentinels=100):
    """Extract the generated fills from T5's raw output, aligned by sentinel id.

    ``<pad>`` and ``</s>`` markers are removed, then every
    ``<extra_id_K> some text`` segment is parsed. The fill for sentinel ``K``
    is stored at position ``K`` of the returned list, so a consumer that pairs
    ``fills[i]`` with ``<extra_id_i>`` stays correct even when the model skips
    or repeats a sentinel. Missing ids are filled with ``""``; when a sentinel
    appears more than once, its first occurrence wins.

    Args:
        texts: Iterable of decoded model outputs (special tokens kept).
        max_sentinels: Sentinel ids at or above this value are ignored, which
            bounds the length of each returned list. NOTE(review): 100 matches
            the usual T5 tokenizer default -- confirm for the tokenizer in use.

    Returns:
        List with one list of fill strings per input text; an output without
        any recognised sentinel yields an empty list.
    """
    # DOTALL lets a fill span a line break instead of making the match fail.
    sentinel_fill = re.compile(r"<extra_id_(\d+)>\s*(.*?)\s*(?=<extra_id_\d+>|$)", re.DOTALL)

    extracted_fills = []
    for text in texts:
        text = text.replace("<pad>", "").replace("</s>", "").strip()

        fills_by_id = {}
        for sentinel, fill in sentinel_fill.findall(text):
            sentinel_id = int(sentinel)
            if sentinel_id < max_sentinels:
                # setdefault keeps the first fill seen for a repeated sentinel.
                fills_by_id.setdefault(sentinel_id, fill.strip())

        slot_count = max(fills_by_id) + 1 if fills_by_id else 0
        extracted_fills.append([fills_by_id.get(slot, "") for slot in range(slot_count)])

    return extracted_fills

def batch_apply_extracted_fills(masked_texts, extracted_fills):
    """Replace mask tokens in the masked texts with generated fills.

    For each text, ``fills[i]`` replaces the first occurrence of
    ``<extra_id_i>``. Any sentinel that received no fill (too few fills, or no
    fills at all) is removed rather than left in the text as a literal
    ``<extra_id_k>`` string, and runs of whitespace are collapsed to single
    spaces so that empty fills or removed sentinels do not leave gaps.

    Args:
        masked_texts: Strings containing ``<extra_id_i>`` sentinels.
        extracted_fills: One list of fill strings per masked text.

    Returns:
        List of filled strings, one per ``(masked_text, fills)`` pair.
    """
    leftover_sentinel = re.compile(r"<extra_id_\d+>")
    filled_texts = []

    for masked_text, fills in zip(masked_texts, extracted_fills):
        filled_text = masked_text
        for i, fill in enumerate(fills):
            # The closing ">" keeps <extra_id_1> from matching inside <extra_id_10>.
            filled_text = filled_text.replace(f"<extra_id_{i}>", fill, 1)

        # Remove masks the model did not fill so they are not passed on as text.
        filled_text = leftover_sentinel.sub("", filled_text)

        # Normalise spacing left behind by empty fills or removed sentinels.
        filled_texts.append(" ".join(filled_text.split()))

    return filled_texts

def batch_average_log_prob(texts, batch_size=8):
    """Calculate the average per-token log probability of each text.

    Texts are tokenized and scored in chunks of ``batch_size``. For every text
    the token-level negative log-likelihood of each next-token prediction is
    computed from the model logits, padding positions are excluded via the
    attention mask, and the negated mean over the remaining tokens is returned.

    The loss is always computed per sample. The scalar ``outputs.loss`` that
    the model returns when given ``labels`` is an average over the whole
    batch, so using it would give every text in a batch the same score.

    Args:
        texts: List of strings to score.
        batch_size: Number of texts per forward pass.

    Returns:
        List of floats, one per input text, in input order. A text with no
        predictable token (fewer than two real tokens) scores ``0.0``.
    """
    all_log_probs = []

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]

        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # With fewer than two positions there is no next-token target at all.
        if input_ids.size(1) < 2:
            all_log_probs.extend([0.0] * len(batch_texts))
            continue

        with torch.no_grad():
            logits = model(input_ids, attention_mask=attention_mask).logits

        # Position t predicts token t + 1. Cast to float32 so the loss is not
        # accumulated in the model's reduced-precision dtype.
        shift_logits = logits[:, :-1, :].float()
        shift_labels = input_ids[:, 1:]

        token_nll = torch.nn.functional.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            reduction="none",
        ).view(shift_labels.size())

        # Shift the mask to line up with the targets; padded targets count as 0.
        target_mask = attention_mask[:, 1:].to(token_nll.dtype)
        token_counts = target_mask.sum(dim=1)
        mean_nll = (token_nll * target_mask).sum(dim=1) / token_counts.clamp(min=1)

        # Negative mean loss is the average log probability; rows without any
        # real target get exactly 0.0.
        sample_log_probs = torch.where(
            token_counts > 0, -mean_nll, torch.zeros_like(mean_nll)
        )
        all_log_probs.extend(sample_log_probs.tolist())

    return all_log_probs

# Main optimized processing loop
def optimized_processing(data_human, num_samples=50, iterations=25, batch_size=8, max_words=50):
    """Score original texts and repeated mask-and-fill perturbations of them.

    The first ``num_samples`` records are truncated to ``max_words`` words and
    scored once with ``batch_average_log_prob``. Then, ``iterations`` times,
    all texts are masked, refilled by the T5 pipeline and scored again. A
    summary is printed for every text.

    Args:
        data_human: List of records; each record's ``"text"`` entry is used.
        num_samples: Number of records to process. Capped at
            ``len(data_human)`` so a short dataset does not raise IndexError.
        iterations: Number of perturbation rounds per text.
        batch_size: Batch size forwarded to the generation and scoring helpers.
        max_words: Number of leading words kept from each text.

    Returns:
        Tuple ``(base_log_probs, transformed_log_probs)``: one score per text,
        and one list of ``iterations`` scores per text.
    """
    num_samples = max(0, min(num_samples, len(data_human)))

    original_texts = [
        " ".join(record["text"].split()[:max_words])
        for record in data_human[:num_samples]
    ]
    log_probs_per_text_base = batch_average_log_prob(original_texts, batch_size)

    # One result list per text, created up front so that zero iterations still
    # yields a well-formed (empty) entry for every text.
    log_probs_per_text_transformed = [[] for _ in range(num_samples)]

    for _ in range(iterations):
        # Step 1: mask all texts at once
        all_masked_texts, _ = batch_mask_text(original_texts)

        # Step 2: generate replacements in batches
        all_raw_fills = batch_replace_masks(all_masked_texts, batch_size)

        # Step 3: extract fills
        all_extracted_fills = batch_extract_fills(all_raw_fills)

        # Step 4: apply fills
        all_perturbed_texts = batch_apply_extracted_fills(all_masked_texts, all_extracted_fills)

        # Step 5: score the perturbed texts
        all_log_probs = batch_average_log_prob(all_perturbed_texts, batch_size)

        # Results come back in input order, so pair them up by position.
        for per_text_scores, log_prob in zip(log_probs_per_text_transformed, all_log_probs):
            per_text_scores.append(log_prob)

    # Print results
    for j in range(num_samples):
        avg_log_prob_not = log_probs_per_text_base[j]
        log_probs = log_probs_per_text_transformed[j]

        print(f"Average per-token log probability for base sentence {j + 1}: {avg_log_prob_not:.4f}")
        if log_probs:  # nothing to summarise when iterations == 0
            print(f"Average per-token log probability for transformed sentence {j + 1}: {(sum(log_probs) / len(log_probs)):.4f}, "
                  f"the minimum is {min(log_probs)} and the maximum is {max(log_probs)}")

    return log_probs_per_text_base, log_probs_per_text_transformed

# Memory management utilities
def clear_cuda_cache():
    """Ask PyTorch to empty its CUDA cache; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()


# Add caching for tokenization
@lru_cache(maxsize=1024)
def cached_tokenize(text, is_t5=False):
    """Tokenize ``text`` once and reuse the result on later identical calls.

    Results are memoised on ``(text, is_t5)`` with up to 1024 entries, so
    ``text`` must be hashable. A cache hit returns the same object as the
    earlier call, so callers should not modify the result in place.

    Args:
        text: Input passed straight to the selected tokenizer.
        is_t5: When true use ``t5_tokenizer`` (padding only); otherwise use
            ``tokenizer`` (padding and truncation).

    Returns:
        The selected tokenizer's output with PyTorch tensors.
    """
    if not is_t5:
        return tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    return t5_tokenizer(text, return_tensors="pt", padding=True)

# Example usage
# log_probs_base, log_probs_transformed = optimized_processing(data_human, num_samples=50, iterations=25, batch_size=8)

In [None]:
# The pipeline is stochastic (random mask positions, sampled T5 fills), so seed
# both RNGs right before the run to make Restart & Run All reproducible.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

log_probs_base, log_probs_transformed = optimized_processing(data_human, num_samples=10, iterations=25, batch_size=2)