In [None]:
# !pip install pandas transformers scikit-learn matplotlib seaborn sentencepiece accelerate -q
# !pip install protobuf

# import pandas as pd
# hate_yes_data = pd.read_csv('/root/ccs_aisf/data/yes_no/hate_vs_antagonist_yes.csv')
# hate_no_data = pd.read_csv('/root/ccs_aisf/data/yes_no/hate_vs_antagonist_no.csv')

# torch.set_default_tensor_type(torch.cuda.HalfTensor)

## **1. Datasets.**

In [None]:
import re
import pickle
import pandas as pd
from sklearn.metrics import accuracy_score
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification
# from transformers import EncoderDecoderModel, BertTokenizer, DistilBertTokenizer
# from transformers import AutoModelForCausalLM

from sklearn.linear_model import LogisticRegression
import numpy as np
from tqdm import tqdm

In [None]:
# Repository root, relative to the notebook's working directory.
ROOT = '../'


def _load_csv(relative_path):
    """Read the CSV stored at ROOT + relative_path into a DataFrame."""
    return pd.read_csv(ROOT + relative_path)


# Raw datasets.
real_vs_ideal_world_data = _load_csv('data/raw/real_vs_ideal_world.csv')
hate_data = _load_csv('data/raw/hate_vs_antagonist.csv')

# "yes" / "no" files of the real-vs-ideal-world dataset.
real_vs_ideal_world_yes_data = _load_csv('data/yes_no/real_vs_ideal_world_yes.csv')
real_vs_ideal_world_no_data = _load_csv('data/yes_no/real_vs_ideal_world_no.csv')

# "yes" / "no" files of the hate-vs-antagonist dataset.
hate_yes_data = _load_csv('data/yes_no/hate_vs_antagonist_yes.csv')
hate_no_data = _load_csv('data/yes_no/hate_vs_antagonist_no.csv')

## **2. Choose model.**

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# ===== GEMMA 3 SERIES (MARCH 2025) - LATEST MULTIMODAL! =====

# Gemma 3 - Latest multimodal models with vision capabilities
print("=== LOADING GEMMA 3 MODELS ===")

# # GEMMA 3 1B - Text-only, smallest
# YOUR_NAME = "gemma-3-1b-it"
# gemma_tokenizer = AutoTokenizer.from_pretrained(f"google/{YOUR_NAME}")
# gemma_model = AutoModelForCausalLM.from_pretrained(
#     f"google/{YOUR_NAME}",
#     torch_dtype=torch.bfloat16,  # CHANGED: Use bfloat16 for Gemma 3
#     device_map="auto"
# )
# gemma_model.eval()

# GEMMA 3 4B - Multimodal (text + images)
# YOUR_NAME is reused later as the prefix of the cached .npz files.
YOUR_NAME = "gemma-3-4b-it"
_checkpoint = f"google/{YOUR_NAME}"

# Loading options: bfloat16 weights, device placement chosen automatically.
_load_options = {"torch_dtype": torch.bfloat16, "device_map": "auto"}

gemma_tokenizer = AutoTokenizer.from_pretrained(_checkpoint)
gemma_model = AutoModelForCausalLM.from_pretrained(_checkpoint, **_load_options)

# Inference only: switch to evaluation mode.
gemma_model.eval()

# # GEMMA 3 12B - Multimodal (text + images)
# YOUR_NAME = "gemma-3-12b-it"
# gemma_tokenizer = AutoTokenizer.from_pretrained(f"google/{YOUR_NAME}")
# gemma_model = AutoModelForCausalLM.from_pretrained(
#     f"google/{YOUR_NAME}",
#     torch_dtype=torch.bfloat16,
#     device_map="auto"
# )
# gemma_model.eval()

# # GEMMA 3 27B - Largest multimodal (text + images)
# YOUR_NAME = "gemma-3-27b-it"
# gemma_tokenizer = AutoTokenizer.from_pretrained(f"google/{YOUR_NAME}")
# gemma_model = AutoModelForCausalLM.from_pretrained(
#     f"google/{YOUR_NAME}",
#     torch_dtype=torch.bfloat16,
#     device_map="auto"
# )
# gemma_model.eval()

# # ===== GEMMA 3N SERIES (JUNE 2025) - MOBILE-OPTIMIZED =====

# # Gemma 3n - Mobile-first architecture with multimodal capabilities
# # print("=== LOADING GEMMA 3N MODELS ===")  

# # GEMMA 3N E2B (Effective 2B, actual 5B params)
# YOUR_NAME = "gemma-3n-e2b-it"
# gemma_tokenizer = AutoTokenizer.from_pretrained(f"google/{YOUR_NAME}")
# gemma_model = AutoModelForCausalLM.from_pretrained(
#     f"google/{YOUR_NAME}",
#     torch_dtype=torch.bfloat16,
#     device_map="auto"
# )
# gemma_model.eval()

# # GEMMA 3N E4B (Effective 4B, actual 8B params)
# YOUR_NAME = "gemma-3n-e4b-it"
# gemma_tokenizer = AutoTokenizer.from_pretrained(f"google/{YOUR_NAME}")
# gemma_model = AutoModelForCausalLM.from_pretrained(
#     f"google/{YOUR_NAME}",
#     torch_dtype=torch.bfloat16,
#     device_map="auto"
# )
# gemma_model.eval()


### **3. Get hidden states**

In [None]:

import sys
import os

# Make the project's helper modules (extract_fixed, ccs, steering, ...) importable.
# The directory is taken from the CCS_CODE_DIR environment variable when set;
# otherwise it is resolved relative to ROOT, so the notebook no longer depends
# on one user's absolute home-directory path.
# NOTE(review): assumes the `code/` folder sits next to `data/` under ROOT, as the
# previous hard-coded `.../ccs_aisf/code` path suggests — confirm.
code_dir = os.environ.get('CCS_CODE_DIR') or os.path.abspath(os.path.join(ROOT, 'code'))

# Fail here with a clear message rather than with an ImportError in a later cell.
if not os.path.isdir(code_dir):
    raise FileNotFoundError(
        f"Code directory not found: {code_dir}. "
        "Set the CCS_CODE_DIR environment variable to the folder containing the helper modules."
    )

if code_dir not in sys.path:
    sys.path.insert(0, code_dir)

print(f"Added {code_dir} to Python path")



In [None]:
from extract_fixed import vectorize_df, extract_representation

# Extraction settings shared by the "yes" and the "no" run.
_extraction_options = dict(
    layer_index=None,
    strategy="last-token",   # last-token representation for a decoder model
    model_type='decoder',    # Gemma is a decoder, not an encoder
    use_decoder=False,       # not needed for a decoder-only model
    get_all_hs=True,         # the result is later unpacked as (samples, layers, features)
    device=None,
)

# Hidden states for the "yes" statements (positive side of the contrast pair).
X_pos = vectorize_df(hate_yes_data['statement'], gemma_model, gemma_tokenizer, **_extraction_options)

# Hidden states for the "no" statements (negative side of the contrast pair).
X_neg = vectorize_df(hate_no_data['statement'], gemma_model, gemma_tokenizer, **_extraction_options)

In [None]:
import numpy as np

# Cache the extracted hidden states on disk, one archive per polarity.
# The arrays are stored unnamed, so NumPy files them under the key 'arr_0'.
np.savez_compressed(f'{YOUR_NAME}_neg.npz', X_neg)
np.savez_compressed(f'{YOUR_NAME}_pos.npz', X_pos)

# Reload from disk so the rest of the notebook works on exactly what was saved.
# np.load returns a lazy archive that keeps the file open; the `with` block
# closes it as soon as the array has been read.
with np.load(f'{YOUR_NAME}_pos.npz') as X_pos_file:
    X_pos = X_pos_file['arr_0']

with np.load(f'{YOUR_NAME}_neg.npz') as X_neg_file:
    X_neg = X_neg_file['arr_0']

## **4. Plot.**

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from itertools import combinations
from sklearn.preprocessing import normalize
from sklearn.preprocessing import StandardScaler

In [None]:
_SEPARATOR = "######################################################################################## "


def _describe_pair(pos, neg, suffix=""):
    """Print dtype, shape and summary statistics of the positive/negative arrays.

    `suffix` is appended to the printed variable names (e.g. "_normalized").
    """
    named = ((f"X_pos{suffix}", pos), (f"X_neg{suffix}", neg))
    for name, arr in named:
        print(f"{name} dtype:", arr.dtype)
        print(f"{name} shape:", arr.shape)

    print(_SEPARATOR)
    for name, arr in named:
        print(f"{name}.max()", arr.max())
        print(f"{name}.min()", arr.min())
        print(f"{name}.mean()", arr.mean())
        print(f"{name}.std()", arr.std())
        print(f"{name}.median()", np.median(arr))
    print(_SEPARATOR)


def _l2_normalize_per_layer(states):
    """Scale every feature vector (last axis) of `states` to unit L2 norm, keeping the shape."""
    flat = states.reshape(-1, states.shape[-1])   # (n_samples * n_layers, n_features)
    return normalize(flat, norm='l2', axis=1).reshape(states.shape)


# Check data types, shapes and value ranges of the raw hidden states first.
_describe_pair(X_pos, X_neg)

# Check for any issues with the data
print("X_pos sample:", X_pos.flat[:5])
print("X_neg sample:", X_neg.flat[:5])

# Convert to proper numeric type if needed
X_pos = X_pos.astype(np.float32)
X_neg = X_neg.astype(np.float32)

# Positive and negative states form contrast pairs, so they must line up exactly.
assert X_pos.shape == X_neg.shape, f"Shape mismatch: {X_pos.shape} vs {X_neg.shape}"
n_samples, n_layers, n_features = X_pos.shape

# Normalize the data
X_pos_normalized = _l2_normalize_per_layer(X_pos)
X_neg_normalized = _l2_normalize_per_layer(X_neg)

print("Normalization completed!")
_describe_pair(X_pos_normalized, X_neg_normalized, suffix="_normalized")

# Verify normalization worked - check L2 norms (each printed value should be ~1)
print("X_pos_normalized sample norms:", np.linalg.norm(X_pos_normalized.reshape(-1, n_features)[:5], axis=1))
print("X_neg_normalized sample norms:", np.linalg.norm(X_neg_normalized.reshape(-1, n_features)[:5], axis=1))

# Plot PCA. The title is built from YOUR_NAME so it always names the model that
# was actually loaded (it previously referred to a different Gemma model).
from format_results_fixed import plot_pca_or_tsne_layerwise
plot_pca_or_tsne_layerwise(X_pos_normalized,
                           X_neg_normalized,
                           hate_data['is_harmfull_opposition'],
                           standardize=False, n_components=5,
                           components=[0, 1], mode='pca',
                           plot_title=f'PCA clustering, {YOUR_NAME}, hate vs normal')

## **5. Find best CCS**


To find the best layers, train CCS on the hidden states of every layer. Different normalization strategies can be used; in our experiments the best one was L2 + median.

In [None]:
from ccs import CCS, train_ccs_on_hidden_states
from sklearn.model_selection import train_test_split
import random

# Seed every random number generator used below so runs are repeatable.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

# Deterministic cuDNN kernels, no auto-tuning.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# Train-test indexes (for all experiments).
# Note: the split uses its own fixed random_state (71), independent of SEED.
idx = np.arange(len(X_pos_normalized))
train_idx, test_idx = train_test_split(
    idx,
    shuffle=True,
    random_state=71,
    test_size=0.15,
)


# Train CCS on the normalized states; the result is kept for the layer-wise comparison later on.
orig_ccs = train_ccs_on_hidden_states(
    X_pos_normalized,
    X_neg_normalized,
    hate_data['is_harmfull_opposition'],
    train_idx,
    test_idx,
    normalizing='median',
)

After that, we train the CCS on the selected layer to obtain the probe and its weights.

In [None]:
from ccs import CCS
import random

# Layer whose hidden states are used to train the final probe.
LAYER_IDX = 4

# Hidden states of the selected layer, one row per statement.
X_pos_normalized_data = pd.DataFrame(X_pos_normalized[:, LAYER_IDX])
X_neg_normalized_data = pd.DataFrame(X_neg_normalized[:, LAYER_IDX])


# Train-Test split. `.astype` returns a copy, so the in-place centering below
# does not modify the DataFrames above.
X_pos_norm_train = X_pos_normalized_data.loc[train_idx, :].values.astype(np.float32)
X_pos_norm_test  = X_pos_normalized_data.loc[test_idx, :].values.astype(np.float32)

X_neg_norm_train = X_neg_normalized_data.loc[train_idx, :].values.astype(np.float32)
X_neg_norm_test  = X_neg_normalized_data.loc[test_idx, :].values.astype(np.float32)

# train_idx / test_idx are positions, so index the labels positionally.
y_train = hate_data['is_harmfull_opposition'].iloc[train_idx]
y_test  = hate_data['is_harmfull_opposition'].iloc[test_idx]

# Median normalizing.
# The medians MUST be computed before the train arrays are centred: centring in
# place first and calling np.median(train) afterwards yields ~0, which would
# leave the test split un-centred and on a different scale from the train split.
# NOTE(review): np.median without `axis` is one scalar over all features; check
# whether train_ccs_on_hidden_states(normalizing='median') centres per feature
# (axis=0) instead — that would also make the two accuracies differ.
pos_train_median = np.median(X_pos_norm_train)
neg_train_median = np.median(X_neg_norm_train)

X_pos_norm_train -= pos_train_median
X_pos_norm_test -= pos_train_median

X_neg_norm_train -= neg_train_median
X_neg_norm_test -= neg_train_median

# Train CCS without labels first NEG, after POS!
ccs = CCS(X_neg_norm_train, X_pos_norm_train, y_train.values, var_normalize=False, lambda_classification=0, predict_normalize=False)
ccs.repeated_train()

# Evaluate (same NEG-then-POS argument order as in training)
ccs_acc = ccs.get_acc(X_neg_norm_test, X_pos_norm_test, y_test.values)
print("CCS accuracy: {}".format(ccs_acc))

**Open question (TODO):** I don't understand why the accuracy here is different, and I have not been able to find the bug.

# **6. Steering**

In [None]:
from steering import plot_steering_power, plot_boundary
from steering import PatchHook

In [None]:
# Steering offsets to sweep.
deltas = np.linspace(-0.05, 0.05, 30)

# First 257 rows — presumably the harmful statements, judging by the "[harm]" labels (confirm).
_harm_rows = slice(None, 257)
_tensor_options = dict(dtype=torch.float32, device=ccs.device)

X_pos_tensor = torch.tensor(X_pos_normalized[_harm_rows], **_tensor_options)
X_neg_tensor = torch.tensor(X_neg_normalized[_harm_rows], **_tensor_options)

plot_steering_power(
    ccs,
    X_pos_tensor,
    X_neg_tensor,
    deltas,
    labels=["POS (statement + ДА) [harm]", "NEG (statement + НЕТ) [harm]"],
    title="Steering along opinion direction [harm]",
)

In [None]:
# Steering offsets to sweep (same grid as for the harmful subset).
deltas = np.linspace(-0.05, 0.05, 30)

# Rows from 257 onwards — presumably the safe statements, judging by the "[safe]" labels (confirm).
X_pos_tensor = torch.tensor(X_pos_normalized[257:], dtype=torch.float32, device=ccs.device)
X_neg_tensor = torch.tensor(X_neg_normalized[257:], dtype=torch.float32, device=ccs.device)

# The title now says "[safe]" to match the labels and the plotted subset
# (it was copy-pasted as "[harm]" from the previous cell).
plot_steering_power(ccs, X_pos_tensor, X_neg_tensor, deltas, labels=["POS (statement + ДА) [safe]", "NEG (statement + НЕТ) [safe]"], 
                    title="Steering along opinion direction [safe]")

Manual calibration

In [None]:
# Row label 1 of the layer-LAYER_IDX frame, i.e. the SECOND sample (labels start at 0).
idx = 1
h_orig = torch.tensor(X_pos_normalized_data.loc[idx].to_numpy(), dtype=torch.float32, device=ccs.device)

# Probe weights returned by CCS (the second return value is not needed here).
weights, _ = ccs.get_weights()

# Unit-length steering direction. It is converted to a tensor on h_orig's device:
# adding a NumPy array to a torch tensor does not reliably give a tensor back,
# and the `.unsqueeze` calls below need one.
direction = torch.as_tensor(
    weights / (np.linalg.norm(weights) + 1e-6),
    dtype=torch.float32,
    device=h_orig.device,
)
h_steered = h_orig + 0.025 * direction

# Probe output before and after moving the sample along the direction.
p_orig = ccs.best_probe(h_orig.unsqueeze(0)).item()
p_steered = ccs.best_probe(h_steered.unsqueeze(0)).item()

print(f"Original: {p_orig:.4f}, Steered: {p_steered:.4f}")

In [None]:
# Decision-boundary plot for the probe trained on the selected layer.
# NOTE(review): the trailing `3` and `[0, 1]` are presumably the number of
# components and the pair of components to draw — confirm against steering.plot_boundary.
plot_boundary(
    X_pos_normalized_data,
    X_neg_normalized_data,
    hate_data['is_harmfull_opposition'],
    ccs,
    3,
    [0, 1],
)

# 6.1. Steering

Very important: `alpha_neg = -alpha_pos`

In [None]:
# Unit-norm CCS direction as a tensor (weights fetched once instead of twice).
_ccs_weights = ccs.get_weights()[0]
direction = torch.tensor(
    _ccs_weights / np.linalg.norm(_ccs_weights),
    dtype=torch.float32,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

alpha = 0.049      # steering coefficient for the "Yes" run
token_idx = 0      # position of the token whose hidden state is patched

LAYER_IDX = 4

true = hate_data['is_harmfull_opposition']
texts = hate_data['statement']
text_yes = texts + " Yes."

inputs_yes = gemma_tokenizer(list(text_yes), return_tensors="pt", padding=True)
inputs_yes = {k: v.to(direction.device) for k, v in inputs_yes.items()}

true_tensor = torch.tensor(true.values, dtype=torch.long, device=direction.device)

# Create and configure the hook object
hook_obj = PatchHook(token_idx=token_idx,    # which token to adjust
                     direction=direction,    # CCS vector weights (normalize before)
                     character=true_tensor,  # y tensor
                     alpha=alpha)            # steering coef


print(f"[MAIN] hook_obj id: {id(hook_obj)}")
print(f"[MAIN] character shape: {hook_obj.character.shape}")

# Hook to specific layer
# NOTE(review): `gemma_model` is loaded from a Gemma 3 checkpoint via
# AutoModelForCausalLM; the `.deberta.encoder.layer[...].output` path looks like a
# leftover from a DeBERTa experiment and presumably has to point at the matching
# Gemma decoder layer instead — confirm, together with what PatchHook expects as
# the hooked module's output.
h = gemma_model.deberta.encoder.layer[LAYER_IDX].output.register_forward_hook(hook_obj)

# Forward pass with the hook active. The hook is removed in `finally`, so a
# failing forward pass cannot leave the model permanently steered.
try:
    with torch.no_grad():
        outputs_patched_yes = gemma_model(**inputs_yes, output_hidden_states=True)
finally:
    h.remove()

In [None]:
# Removing hooks   

def remove_all_forward_hooks(model):
    """Remove every forward hook from `model` and from all of its sub-modules.

    The previous version only visited the children of `model`, so a hook
    registered directly on `model` itself was left in place. `model.modules()`
    yields the root module as well as every descendant.

    Args:
        model: a `torch.nn.Module`.

    Returns:
        None. The hook registries are cleared in place.
    """
    # `_forward_hooks` holds the callables. The other two dicts are per-hook
    # bookkeeping keyed by the same ids; they only exist in newer torch
    # versions, hence the getattr guard.
    registries = ("_forward_hooks", "_forward_hooks_with_kwargs", "_forward_hooks_always_called")
    for module in model.modules():
        for attr in registries:
            registry = getattr(module, attr, None)
            if registry is not None:
                registry.clear()
    
def remove_all_backward_hooks(model):
    """Remove every backward hook from `model` and from all of its sub-modules.

    The previous version only visited the children of `model`, so a hook
    registered directly on `model` itself was left in place. `model.modules()`
    yields the root module as well as every descendant.

    Args:
        model: a `torch.nn.Module`.

    Returns:
        None. The hook registries are cleared in place.
    """
    for module in model.modules():
        registry = getattr(module, "_backward_hooks", None)
        if registry is not None:
            registry.clear()


In [None]:
# Clears the forward hooks of every sub-module of the model, which already
# includes the steered layer and its `.output` module. The former second line
# repeated this through `gemma_model.deberta...`, an attribute a Gemma model
# does not have, so it raised AttributeError; it has been dropped.
remove_all_forward_hooks(gemma_model)

Negative texts with `-1*alpha`

In [None]:
# Negative texts
text_no = texts + " No."
inputs_no = gemma_tokenizer(list(text_no), return_tensors="pt", padding=True)
inputs_no = {k: v.to(direction.device) for k, v in inputs_no.items()}

# Hook with the opposite steering sign (alpha_neg = -alpha_pos)
hook_obj2 = PatchHook(token_idx=token_idx, direction=direction, character=true_tensor, alpha=-1*alpha)


# NOTE(review): `.deberta.encoder.layer[...].output` looks like a leftover from a
# DeBERTa experiment; a Gemma model presumably needs a different module path — confirm.
h = gemma_model.deberta.encoder.layer[LAYER_IDX].output.register_forward_hook(hook_obj2)

# Run. The handle is removed afterwards (it was previously never removed),
# so later forward passes of `gemma_model` are not silently steered.
try:
    with torch.no_grad():
        outputs_patched_no = gemma_model(**inputs_no, output_hidden_states=True)
finally:
    h.remove()

# **7. PCA results and steered probe**

In [None]:
def _first_token_states(model_outputs):
    """Collect the token-0 hidden state of every layer as a float32 NumPy array.

    Each entry of `model_outputs.hidden_states` is indexed as
    (batch, token, feature); the result has shape (batch, n_entries, feature),
    the same layout the original `np.array(...)[:, :, 0, :].transpose(1, 0, 2)`
    produced.

    The token is sliced first and each slice is cast to float32 on the CPU
    before stacking: NumPy cannot convert bfloat16 or CUDA tensors directly,
    and with `device_map="auto"` the layers may sit on different devices.
    """
    per_layer = [
        layer_states[:, 0, :].detach().float().cpu()
        for layer_states in model_outputs.hidden_states
    ]
    return torch.stack(per_layer, dim=1).numpy()


X_pos_st = _first_token_states(outputs_patched_yes)
X_neg_st = _first_token_states(outputs_patched_no)

In [None]:
# L2-normalise the steered hidden states of the probed layer, row by row.
_steered_pos_layer = X_pos_st[:, LAYER_IDX, :]
_steered_neg_layer = X_neg_st[:, LAYER_IDX, :]
X_pos_st_norm = normalize(_steered_pos_layer, axis=1, norm='l2')
X_neg_st_norm = normalize(_steered_neg_layer, axis=1, norm='l2')

# Probe predictions on the steered states ...
# NOTE(review): this call passes (POS, NEG) while the call on the original
# states below passes (NEG, POS) — confirm that the different order is intended.
classes_st, probas_st = ccs.predict(X_pos_st_norm, X_neg_st_norm)

# ... and on the original (un-steered) states of the same layer.
_orig_neg_layer = X_neg_normalized[:, LAYER_IDX, :]
_orig_pos_layer = X_pos_normalized[:, LAYER_IDX, :]
classes_or, probas_or = ccs.predict(_orig_neg_layer, _orig_pos_layer)

In [None]:
def _unit_norm_per_layer(states):
    """Flatten to one feature vector per row, L2-normalise each row, restore the shape."""
    rows = states.reshape(-1, states.shape[-1])
    return normalize(rows, norm='l2', axis=1).reshape(states.shape)


# L2-normalise the steered hidden states of every layer.
X_pos_st_normalized = _unit_norm_per_layer(X_pos_st)
X_neg_st_normalized = _unit_norm_per_layer(X_neg_st)

# Components 1, 3

plot_pca_or_tsne_layerwise(
    X_pos_st_normalized,
    X_neg_st_normalized,
    hate_data['is_harmfull_opposition'],
    standardize=False,
    n_components=5,
    components=[1, 3],
)

In [None]:
# Components 0, 1
_pca_options = dict(standardize=False, n_components=5, components=[0, 1])

plot_pca_or_tsne_layerwise(
    X_pos_st_normalized,
    X_neg_st_normalized,
    hate_data['is_harmfull_opposition'],
    **_pca_options,
)

In [None]:
# Layer-wise CCS training on the steered states, with the same train/test split
# and the same 'median' normalisation as the run on the original states.
steered_ccs = train_ccs_on_hidden_states(
    X_pos_st_normalized,
    X_neg_st_normalized,
    hate_data['is_harmfull_opposition'],
    train_idx,
    test_idx,
    normalizing='median',
)

## **8. Plot results tables.**

In [None]:
from format_results import get_results_table

# Result tables for the original and the steered CCS runs, in that order.
orig_ccs_data, st_ccs_data = map(get_results_table, (orig_ccs, steered_ccs))

In [None]:
fig, ax = plt.subplots(figsize=(18, 5))

# Accuracy per layer: original vs. steered hidden states.
ax.plot(orig_ccs_data['accuracy'], label='Orig_CCS')
ax.plot(st_ccs_data['accuracy'], label='St_CCS')

# Reference line at the ideal value. axhline spans the whole x-axis, whatever
# the number of layers (the old hlines call stopped at x = 7).
ax.axhline(1, label='ideal', color='red', linewidth=3, linestyle='--')

ax.set_title('CCS accuracy per layer: original vs. steered')
ax.set_xlabel('Layer_number')
ax.set_ylabel('accuracy score ')
ax.legend(loc='upper right');

In [None]:
from format_results import get_results_table

# Recomputed here so that this cell can be run on its own.
orig_ccs_data = get_results_table(orig_ccs)
st_ccs_data = get_results_table(steered_ccs)

fig, ax = plt.subplots(figsize=(18, 5))

# Contradiction index per layer (lower is better): original vs. steered hidden states.
ax.plot(orig_ccs_data['contradiction idx ↓'], label='Orig_CCS')
ax.plot(st_ccs_data['contradiction idx ↓'], label='St_CCS')

# Reference line at the ideal value. axhline spans the whole x-axis, whatever
# the number of layers (the old hlines call stopped at x = 7).
ax.axhline(0, label='ideal', color='red', linewidth=3, linestyle='--')

ax.set_title('CCS contradiction index per layer: original vs. steered')
ax.set_xlabel('Layer_number')
ax.set_ylabel('Contradiction idx score ')
ax.legend(loc='upper right');

In [None]:
fig, ax = plt.subplots(figsize=(18, 5))

# Agreement score per layer (lower is better): original vs. steered hidden states.
ax.plot(orig_ccs_data['agreement_score ↓'], label='Orig_CCS')
ax.plot(st_ccs_data['agreement_score ↓'], label='St_CCS')

# Reference line at the ideal value. axhline spans the whole x-axis, whatever
# the number of layers (the old hlines call stopped at x = 7).
ax.axhline(0, label='ideal', color='red', linewidth=3, linestyle='--')

ax.set_title('CCS agreement score per layer: original vs. steered')
ax.set_xlabel('Layer_number')
ax.set_ylabel('Agreement score ')
ax.legend(loc='upper right');