In [None]:
# Setup and data loading
!pip install -q transformers torch scikit-learn matplotlib seaborn nltk accelerate bitsandbytes google-generativeai

import json
import os
import random

import google.generativeai as genai
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from google.colab import userdata
from huggingface_hub import snapshot_download
from nltk.corpus import gutenberg
from scipy.spatial.distance import cosine
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score, confusion_matrix
from sklearn.model_selection import cross_val_score
from transformers import AutoTokenizer, AutoModel, AutoConfig, GPT2LMHeadModel, GPT2TokenizerFast

# Configuration
SEED = 42
# Seed every RNG the notebook touches: Python's `random` (used for chunk
# subsampling below), numpy and torch.
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")

# Data Prep
nltk.download('gutenberg', quiet=True)
# NOTE(review): 'punkt' is not used by create_dataset (it splits on whitespace);
# kept in case sentence tokenization is added later.
nltk.download('punkt', quiet=True)

def create_dataset(file_id, label, chunk_size=128, max_samples=300, min_chars=200):
    """Split an NLTK Gutenberg text into fixed-size word chunks and subsample them.

    The raw text is split on whitespace (no linguistic tokenization) and grouped
    into consecutive chunks of `chunk_size` words. Chunks whose joined length is
    `min_chars` characters or fewer are dropped.

    Args:
        file_id: Gutenberg corpus file id, e.g. 'austen-emma.txt'.
        label: Value written to the 'label' column of every row.
        chunk_size: Number of whitespace-separated words per chunk.
        max_samples: Maximum number of chunks to keep. If there are more, a
            sample without replacement is drawn with an RNG seeded by SEED.
        min_chars: Minimum length (exclusive) of a chunk, in characters.

    Returns:
        pd.DataFrame with columns 'text' and 'label'.
    """
    words = gutenberg.raw(file_id).split()

    chunks = []
    for start in range(0, len(words), chunk_size):
        chunk = " ".join(words[start:start + chunk_size])
        if len(chunk) > min_chars:
            chunks.append(chunk)

    if len(chunks) > max_samples:
        # A private RNG yields the same sample as `random.seed(SEED); random.sample(...)`
        # but does not reset the global `random` state as a side effect.
        chunks = random.Random(SEED).sample(chunks, max_samples)

    return pd.DataFrame({'text': chunks, 'label': label})

# Load Corpus: one labelled frame per author, stacked under a fresh 0..n-1 index
author_frames = [
    create_dataset('austen-emma.txt', 'Austen'),
    create_dataset('melville-moby_dick.txt', 'Melville'),
]
df = pd.concat(author_frames, ignore_index=True)

print(f"Dataset loaded: {len(df)} samples")
print(df['label'].value_counts())

In [None]:
# --- Semantic encoder ---------------------------------------------------------
# BAAI/bge-small-en-v1.5 provides the embeddings used for clustering, the
# centroid distances and the style classifier further down.
SEM_MODEL_NAME = "BAAI/bge-small-en-v1.5"
print(f"Loading {SEM_MODEL_NAME}...")

sem_tokenizer = AutoTokenizer.from_pretrained(SEM_MODEL_NAME)
sem_model = AutoModel.from_pretrained(SEM_MODEL_NAME)
sem_model = sem_model.to(device)

def get_semantic_embeddings(text_list, batch_size=32):
    """Encode texts with the BGE model as L2-normalised CLS-token embeddings.

    Args:
        text_list: Iterable of strings to encode (list, tuple, pandas Series, ...).
        batch_size: Number of texts per forward pass.

    Returns:
        np.ndarray with one row per input text and `hidden_size` columns. An
        empty input returns an empty (0, hidden_size) array instead of failing
        inside np.vstack.
    """
    sem_model.eval()
    texts = list(text_list)  # materialise so the tokenizer always receives a plain list
    if not texts:
        return np.empty((0, sem_model.config.hidden_size), dtype=np.float32)

    embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = sem_tokenizer(batch, padding=True, truncation=True,
                               max_length=512, return_tensors="pt").to(device)

        with torch.no_grad():
            out = sem_model(**inputs)

        # BGE uses the CLS token (position 0), L2-normalised, as the sentence embedding
        cls_emb = out.last_hidden_state[:, 0]
        cls_emb = torch.nn.functional.normalize(cls_emb, p=2, dim=1)
        embeddings.append(cls_emb.cpu().numpy())

    return np.vstack(embeddings)

print("Generating semantic embeddings...")
corpus_texts = df['text'].tolist()
vectors_sem = get_semantic_embeddings(corpus_texts)
print(f"Output shape: {vectors_sem.shape}")

In [None]:
# Dimensionality Reduction
pca = PCA(n_components=2, random_state=SEED)
pca_res = pca.fit_transform(vectors_sem)

# t-SNE runs for its default 1000 iterations. The iteration keyword was renamed
# from `n_iter` to `max_iter` in scikit-learn 1.5 (`n_iter` is deprecated and
# scheduled for removal), so spelling either one out breaks on some versions.
tsne = TSNE(n_components=2, perplexity=30, random_state=SEED)
tsne_res = tsne.fit_transform(vectors_sem)

# Clustering Metrics
# Unsupervised 2-way K-means; compared against the author labels in the crosstab below.
kmeans = KMeans(n_clusters=2, random_state=SEED, n_init=10)
clusters = kmeans.fit_predict(vectors_sem)
# Silhouette is computed on the ground-truth author labels (not the K-means
# clusters), i.e. how well the two authors separate in BGE space.
sil_score = silhouette_score(vectors_sem, df['label'])

# Plots: one scatter per 2-D projection, coloured by author
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
projections = (('PCA', pca_res), ('t-SNE', tsne_res))
for ax, (method, coords) in zip(axes, projections):
    sns.scatterplot(ax=ax, x=coords[:, 0], y=coords[:, 1], hue=df['label'], alpha=0.7)
    ax.set_title(f'{method}: Austen vs Melville')
    ax.set_xlabel(f'{method} dimension 1')
    ax.set_ylabel(f'{method} dimension 2')
fig.tight_layout()
plt.show()

print(f"Silhouette Score: {sil_score:.4f}")
print("\nConfusion Matrix (Cluster vs Label):")
print(pd.crosstab(df['label'], clusters))

In [None]:
# Generation model setup
# Read the key from Colab Secrets; if that fails, fall back to the
# GOOGLE_API_KEY environment variable. The key itself is never printed.
try:
    api_key = userdata.get('GOOGLE_API_KEY')
except Exception as e:
    print(f"Could not read GOOGLE_API_KEY from Colab Secrets ({type(e).__name__}: {e}).")
    api_key = os.environ.get('GOOGLE_API_KEY')

if api_key:
    genai.configure(api_key=api_key)
else:
    print("Error: Please set GOOGLE_API_KEY in Colab Secrets.")

# Using Flash for efficiency
gen_model = genai.GenerativeModel('gemini-2.5-flash')

def generate_rewrite(text, target_author, style_example=None, temperature=0.7):
    """Rewrite `text` in the style of `target_author` with Gemini (in-context learning).

    Args:
        text: Passage to rewrite.
        target_author: Author name inserted into the instruction prompt.
        style_example: Optional passage by the target author, appended to the
            prompt as a style reference (one-shot demonstration). Skipped when
            falsy (None or "").
        temperature: Sampling temperature passed to the Gemini generation config.

    Returns:
        The stripped response text. This function does not raise on API
        failures: any exception is converted into a string starting with
        "Generation Error: ", so callers must check for that prefix before
        treating the result as a rewrite.
    """
    prompt = f"""
    You are an expert literary editor.
    TASK: Rewrite the text below to strictly mimic the writing style, vocabulary, and syntax of {target_author}.
    CONSTRAINTS: Keep original meaning/narrative. Do not add intro/outro text.
    """

    if style_example:
        prompt += f'\nSTYLE REFERENCE:\n"{style_example}"\n'

    prompt += f'\nORIGINAL TEXT:\n"{text}"\n\nREWRITTEN TEXT:'

    try:
        response = gen_model.generate_content(
            prompt,
            generation_config=genai.types.GenerationConfig(temperature=temperature)
        )
        # `.text` can itself raise (e.g. when no candidate is returned), so it stays inside the try.
        return response.text.strip()
    except Exception as e:
        return f"Generation Error: {str(e)}"

In [None]:
# Experiment execution

# 1. Select Samples
sample_idx = 5
sample_austen = df[df['label'] == 'Austen'].iloc[sample_idx]['text']

# ICL style demonstration (Pan et al., 2024): a fixed (not random) Melville
# passage, so the same demonstration is used on every run.
demo_idx = 10
style_demo = df[df['label'] == 'Melville'].iloc[demo_idx]['text']

print(f"--- ORIGINAL (Austen) ---\n{sample_austen[:200]}...\n")

# 2. Generate Experimental Sample
# Passing the demo helps the model disentangle style from content
rewrite_result = generate_rewrite(sample_austen, "Herman Melville", style_example=style_demo)
# generate_rewrite returns "Generation Error: ..." instead of raising. Stop here
# rather than embedding and scoring that message as if it were the rewrite;
# rewritten_melville is only (re)bound when generation succeeded.
if rewrite_result.startswith("Generation Error:"):
    raise RuntimeError(f"Style transfer failed, aborting experiment. {rewrite_result}")
rewritten_melville = rewrite_result
print(f"--- REWRITTEN (Target: Melville) ---\n{rewritten_melville[:200]}...\n")

# 3. Embed Experiment
# Reuse the BGE model from Cell 2 (ensure Cell 2 has run)
exp_vectors = get_semantic_embeddings([sample_austen, rewritten_melville])

# 4. Centroid Analysis: mean BGE vector per author
is_austen = (df['label'] == 'Austen').to_numpy()
is_melville = (df['label'] == 'Melville').to_numpy()
c_austen = vectors_sem[is_austen].mean(axis=0)
c_melville = vectors_sem[is_melville].mean(axis=0)

# Cosine distance (0 = same direction); scale-invariant, so unnormalised centroids are fine.
d_orig_austen = cosine(exp_vectors[0], c_austen)
d_new_melville = cosine(exp_vectors[1], c_melville)
d_new_austen = cosine(exp_vectors[1], c_austen)

print("-" * 30)
print("SEMANTIC SPACE (BGE) RESULTS:")
print(f"Original -> Austen Centroid:   {d_orig_austen:.4f}")
print(f"Rewritten -> Melville Centroid: {d_new_melville:.4f}")
print(f"Rewritten -> Austen Centroid:   {d_new_austen:.4f}")

In [None]:
# Second copy of the same BGE checkpoint, configured to return every hidden
# state so that an intermediate layer (layer 2) can be pooled instead of only
# the final-layer CLS vector.
print(f"Configuring {SEM_MODEL_NAME} for layer access...")
config = AutoConfig.from_pretrained(SEM_MODEL_NAME)
setattr(config, 'output_hidden_states', True)

layer_tokenizer = AutoTokenizer.from_pretrained(SEM_MODEL_NAME)
layer_model = AutoModel.from_pretrained(SEM_MODEL_NAME, config=config).to(device)

def get_layer_embeddings(text_list, layer_idx=2, batch_size=32):
    """Mean-pooled, L2-normalised embeddings taken from one hidden layer of BGE.

    Args:
        text_list: Iterable of strings to encode.
        layer_idx: Index into the model's `hidden_states` tuple. 0 is the
            embedding-layer output and i >= 1 is the output of transformer
            block i. Negative indices count from the end (-1 = last layer).
        batch_size: Number of texts per forward pass.

    Returns:
        np.ndarray with one row per input text and `hidden_size` columns.
        An empty input returns an empty (0, hidden_size) array.

    Raises:
        ValueError: if `layer_idx` is outside the available hidden states.
    """
    layer_model.eval()
    texts = list(text_list)

    n_states = layer_model.config.num_hidden_layers + 1  # embedding output + one per block
    if not -n_states <= layer_idx < n_states:
        raise ValueError(f"layer_idx must be in [{-n_states}, {n_states - 1}], got {layer_idx}")
    if not texts:
        return np.empty((0, layer_model.config.hidden_size), dtype=np.float32)

    embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = layer_tokenizer(batch, padding=True, truncation=True,
                                 max_length=512, return_tensors="pt").to(device)

        with torch.no_grad():
            # Request hidden states explicitly so this does not depend on the config flag.
            out = layer_model(**inputs, output_hidden_states=True)

        hidden_state = out.hidden_states[layer_idx]

        # Mean pooling over real (non-padding) tokens only
        mask = inputs['attention_mask'].unsqueeze(-1).float()
        token_sum = (hidden_state * mask).sum(dim=1)
        token_count = mask.sum(dim=1).clamp(min=1e-9)
        mean_emb = token_sum / token_count

        # L2 Normalize
        mean_emb = torch.nn.functional.normalize(mean_emb, p=2, dim=1)
        embeddings.append(mean_emb.cpu().numpy())

    return np.vstack(embeddings)

# 1. Corpus vectors from BGE layer 2 (mean-pooled)
print("Generating Layer 2 embeddings...")
vectors_layer2 = get_layer_embeddings(df['text'].tolist(), layer_idx=2)

# 2. Per-author centroids in layer-2 space
author_labels = df['label']
c_austen_l2 = vectors_layer2[author_labels == 'Austen'].mean(axis=0)
c_melville_l2 = vectors_layer2[author_labels == 'Melville'].mean(axis=0)

# 3. Compare the original and rewritten passage with both centroids
if not ('sample_austen' in locals() and 'rewritten_melville' in locals()):
    print("Skipping experiment: input variables not found (Run Cell 5 first).")
else:
    exp_vecs_l2 = get_layer_embeddings([sample_austen, rewritten_melville], layer_idx=2)
    orig_vec, new_vec = exp_vecs_l2[0], exp_vecs_l2[1]

    d_orig = cosine(orig_vec, c_austen_l2)
    d_new_mel = cosine(new_vec, c_melville_l2)
    d_new_aus = cosine(new_vec, c_austen_l2)

    report_lines = [
        "-" * 30,
        "LAYER 2 RESULTS (Syntax/Structure Focus):",
        f"Original -> Austen Centroid:   {d_orig:.4f}",
        f"Rewritten -> Melville Centroid: {d_new_mel:.4f}",
        f"Rewritten -> Austen Centroid:   {d_new_aus:.4f}",
    ]
    for line in report_lines:
        print(line)

    # The rewrite "moved" if it is now closer to the Melville centroid than to Austen's.
    moved_to_target = d_new_mel < d_new_aus
    print("\nResult: Success. Layer 2 vector shifted to Target Cluster." if moved_to_target
          else "\nResult: Failure. Layer 2 vector remained in Source Cluster.")

In [None]:
# Fluency Evaluation
# Perplexity (PPL) under GPT-2 serves as a proxy for whether the
# style-transferred text reads as fluent English.
print("Loading GPT-2 for Perplexity calculation...")
ppl_model_id = "gpt2"  # standard GPT-2 acts as the fluency "judge"
ppl_tokenizer = GPT2TokenizerFast.from_pretrained(ppl_model_id)
ppl_model = GPT2LMHeadModel.from_pretrained(ppl_model_id).to(device)

def calculate_perplexity(text, stride=512):
    """
    Calculates PPL (Perplexity) of `text` under the GPT-2 judge model.

    Texts longer than the model context (`n_positions`) are scored with a sliding
    window advanced by `stride` tokens. Tokens already scored by an earlier window
    are masked with -100, so each token contributes to the average at most once.
    Lower is better; the evaluation cell uses PPL < 100 as a rough, uncalibrated
    fluency cut-off.

    Args:
        text: String to score.
        stride: Window step in tokens, must be >= 1. Values larger than the
            context length leave gaps of unscored tokens.

    Returns:
        float: exp(mean negative log-likelihood per predicted token).

    Raises:
        ValueError: if `stride` < 1, or if `text` tokenizes to fewer than 2
            tokens (nothing can be predicted, so perplexity is undefined).
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    all_ids = ppl_tokenizer(text, return_tensors="pt").input_ids.to(device)
    seq_len = all_ids.size(1)
    if seq_len < 2:
        raise ValueError("Perplexity needs at least 2 tokens of text.")
    max_length = ppl_model.config.n_positions

    nll_sum = torch.zeros((), device=device)
    n_scored = 0
    prev_end_loc = 0
    for begin_loc in range(0, seq_len, stride):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # tokens this window is responsible for
        input_ids = all_ids[:, begin_loc:end_loc]
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100  # context-only tokens are not scored

        # The LM shifts labels left by one internally, so the window's first
        # position is never a target. Count the targets the mean loss covers.
        n_targets = int((target_ids[:, 1:] != -100).sum())
        if n_targets > 0:
            with torch.no_grad():
                outputs = ppl_model(input_ids, labels=target_ids)
            # outputs.loss is a mean; weight it by its token count before pooling windows.
            nll_sum += outputs.loss * n_targets
            n_scored += n_targets

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    # n_scored >= 1: the first window covers >= 2 tokens, hence >= 1 target.
    return torch.exp(nll_sum / n_scored).item()

# Evaluation of generated texts
FLUENCY_PPL_THRESHOLD = 100  # heuristic cut-off for the fluent / incoherent verdict

if 'sample_austen' in locals() and 'rewritten_melville' in locals():
    ppl_orig, ppl_new = [calculate_perplexity(t) for t in (sample_austen, rewritten_melville)]

    print("-" * 30)
    print("FLUENCY METRICS (PPL - Lower is better):")
    print(f"Original Text PPL:  {ppl_orig:.2f}")
    print(f"Rewritten Text PPL: {ppl_new:.2f}")

    is_fluent = ppl_new < FLUENCY_PPL_THRESHOLD
    print("Result: The rewritten text is fluent (PPL < 100)." if is_fluent
          else "Result: The rewritten text shows signs of incoherence.")

In [None]:
# Visualizing the attention of Layer 2 vs Layer 12 to see if
# they focus on different parts of the sentence (Syntax vs Content).

# Loads a separate copy of the same BGE checkpoint, configured to return
# attention maps (it does not reuse the model objects of earlier cells).
VIZ_MODEL_NAME = "BAAI/bge-small-en-v1.5"
print(f"Loading {VIZ_MODEL_NAME} with attention outputs...")

viz_config = AutoConfig.from_pretrained(VIZ_MODEL_NAME)
viz_config.output_attentions = True
# Attention probabilities are only materialised by the "eager" attention
# implementation; fused SDPA/flash kernels do not return them.
viz_model = AutoModel.from_pretrained(
    VIZ_MODEL_NAME, config=viz_config, attn_implementation="eager"
).to(device)
viz_tokenizer = AutoTokenizer.from_pretrained(VIZ_MODEL_NAME)

def plot_attention_heatmap(text, layer_num, title):
    """Plot the head-averaged self-attention matrix of one BGE layer for `text`.

    Args:
        text: Input string. Keep it short: the heatmap has one row and one
            column per token.
        layer_num: Index into `outputs.attentions`, which has one entry per
            transformer block (0 = first block, so layer_num=1 is "layer 2").
        title: Figure title.

    Raises:
        RuntimeError: if the model did not return attention weights.
        IndexError: if `layer_num` is out of range for the model's layers.
    """
    inputs = viz_tokenizer(text, return_tensors="pt").to(device)
    tokens = viz_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

    with torch.no_grad():
        outputs = viz_model(**inputs, output_attentions=True)

    attentions = outputs.attentions
    if not attentions:
        raise RuntimeError(
            "Model returned no attention weights; load it with attn_implementation='eager'."
        )

    # Each entry is (batch, heads, seq, seq); take the single batch item
    layer_attention = attentions[layer_num][0]

    # Average across all heads to get a general "focus map"
    avg_attention = torch.mean(layer_attention, dim=0).cpu().numpy()

    # Strip WordPiece continuation markers for readable tick labels
    clean_tokens = [t.replace('##', '') for t in tokens]

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(avg_attention, xticklabels=clean_tokens, yticklabels=clean_tokens,
                cmap="viridis", ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Attended-to token (key)")
    ax.set_ylabel("Attending token (query)")
    plt.show()

if 'rewritten_melville' in locals():
    # First 20 whitespace-separated words only, so the heatmap tick labels stay legible
    short_text = " ".join(rewritten_melville.split()[:20])
    print(f"Visualizing attention for: '{short_text}...'")

    # `attentions` is indexed per transformer block: index 1 -> layer 2, index 11 -> layer 12.
    # Layer 2 is the hypothesised structural/stylistic layer (Pan et al.); layer 12 is the final one.
    layer_plots = (
        (1, "Layer 2 Attention (Structural/Stylistic)"),
        (11, "Layer 12 Attention (Semantic/Content)"),
    )
    for block_idx, plot_title in layer_plots:
        plot_attention_heatmap(short_text, layer_num=block_idx, title=plot_title)

In [None]:
# Style classifier (Accuracy Metric)
# Implements the "ACC" metric from Pan et al. (2024) using a lightweight
# Logistic Regression on top of the BGE embeddings.

# 1. Training the Classifier
# Uses the final-layer CLS vectors (vectors_sem) from the semantic-embedding cell.
print("Training Style Classifier...")
# max_iter raised from the default 100 to avoid ConvergenceWarnings on 384-d inputs.
clf = LogisticRegression(random_state=SEED, max_iter=1000)

# Training accuracy alone is optimistic (the model is scored on the data it was
# fit on), so also report 5-fold cross-validated accuracy as a held-out estimate.
# cross_val_score clones `clf`, so the estimator fitted below is unaffected.
cv_acc = cross_val_score(clf, vectors_sem, df['label'], cv=5)
print(f"Classifier 5-fold CV Accuracy: {cv_acc.mean()*100:.2f}% (+/- {cv_acc.std()*100:.2f})")

clf.fit(vectors_sem, df['label'])

# Checking the accuracy on the original (training) data
train_acc = clf.score(vectors_sem, df['label'])
print(f"Classifier Accuracy on Training Data: {train_acc*100:.2f}%")

# 2. Evaluating the Experiment
if 'sample_austen' in locals() and 'rewritten_melville' in locals():
    # Embed the samples (using the same BGE model)
    exp_vecs = get_semantic_embeddings([sample_austen, rewritten_melville])

    # One row of class probabilities per text; columns follow clf.classes_ order.
    probs_orig, probs_new = clf.predict_proba(exp_vecs)
    classes = clf.classes_

    def _format_probs(probs):
        """Render 'p% label' pairs in clf.classes_ order."""
        return ", ".join(f"{p*100:.1f}% {c}" for p, c in zip(probs, classes))

    print("-" * 30)
    print("STYLE PROBABILITY SCORES:")
    print(f"Original Text:   {_format_probs(probs_orig)}")
    print(f"Rewritten Text:  {_format_probs(probs_new)}")

    # Determine "Accuracy" (Did it flip?)
    target_author = "Melville"
    target_idx = list(classes).index(target_author)
    # Most probable class; for two classes this is the non-target one whenever
    # the target probability is <= 0.5.
    predicted_author = classes[int(np.argmax(probs_new))]

    if probs_new[target_idx] > 0.5:
        print(f"\nRESULT: Success. Classified as {target_author}.")
    else:
        print(f"\nRESULT: Failure. Still classified as {predicted_author}.")