# Imports

In [45]:
from sklearn.linear_model import LogisticRegression
import torch
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import numpy as np

# Constants

In [46]:
DATASET = "com2sense"
MODEL_NAME = "gemma-2-2b"
DATA_DIR = f"./experimental_data/{MODEL_NAME}/{DATASET}/"
WEIGHTS_DIR = f"./weights/linear_analysis/{MODEL_NAME}/{DATASET}/"
TRAIN_SIZE = 0.8
TOP_K = 10  # Set to -1 to see every element
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data

Load data

In [47]:
acts_exp_resid = torch.load(DATA_DIR + "acts_exp_resid.pt", map_location=device)
acts_resid = torch.load(DATA_DIR + "acts_resid.pt", map_location=device)

  acts_exp_resid = torch.load(DATA_DIR + "acts_exp_resid.pt", map_location=device)
  acts_resid = torch.load(DATA_DIR + "acts_resid.pt", map_location=device)


In [48]:
print(acts_exp_resid.shape)

torch.Size([1874, 26, 2304])


Split data

In [49]:
X_train_index = int(TRAIN_SIZE * acts_exp_resid.shape[0])
X_train_exp_resid = acts_exp_resid[:X_train_index, :, :]
X_train_resid = acts_resid[:X_train_index, :, :]
X_test_exp_resid = acts_exp_resid[X_train_index:, :, :]
X_test_resid = acts_resid[X_train_index:, :, :]

In [50]:
print(X_train_resid.shape, X_test_resid.shape)

torch.Size([1499, 26, 2304]) torch.Size([375, 26, 2304])


1 indicates CoT, 0 is Non-CoT

In [51]:
y_train_exp_resid = torch.ones(X_train_exp_resid.shape[0])
y_train_resid = torch.zeros(X_train_resid.shape[0])
y_test_exp_resid = torch.ones(X_test_exp_resid.shape[0])
y_test_resid = torch.zeros(X_test_resid.shape[0])

In [52]:
print(y_train_resid.shape, y_test_resid.shape)

torch.Size([1499]) torch.Size([375])


Concatenate the data

In [53]:
X_train = torch.cat((X_train_exp_resid, X_train_resid), dim=0)
X_test = torch.cat((X_test_exp_resid, X_test_resid), dim=0)
y_train = torch.cat((y_train_exp_resid, y_train_resid), dim=0)
y_test = torch.cat((y_test_exp_resid, y_test_resid), dim=0)

In [54]:
print(X_train.shape, y_train.shape)

torch.Size([2998, 26, 2304]) torch.Size([2998])


# Initialize and train classifiers

In [55]:
classifiers = [LogisticRegression(fit_intercept=False)] * acts_resid.shape[1]

In [56]:
X_train, y_train, X_test, y_test = X_train.cpu().numpy(), y_train.cpu().numpy(), X_test.cpu().numpy(), y_test.cpu().numpy()
for i, classifier in enumerate(classifiers):
    classifier.fit(X_train[:, i, :], y_train)

# Evaluate classifiers

In [57]:
weights_list = []
for i, classifier in enumerate(classifiers):
    y_pred = classifier.predict(X_test[:, i, :])
    print(f"layer {i}")
    accuracy = accuracy_score(y_test, y_pred)
    print("Accuracy:", accuracy)
    
    conf_matrix = confusion_matrix(y_test, y_pred)
    print("Confusion Matrix:\n", conf_matrix)
    
    class_report = classification_report(y_test, y_pred)
    print("Classification Report:\n", class_report)

    weights = classifier.coef_
    print(f"Weights: {weights}")
    weights_list.append(weights)

    top_k = np.argsort(weights[0, :])[::-1][:TOP_K]
    print(f"Top k indeces: {top_k}")

    print("\n")

layer 0
Accuracy: 0.5026666666666667
Confusion Matrix:
 [[  2 373]
 [  0 375]]
Classification Report:
               precision    recall  f1-score   support

         0.0       1.00      0.01      0.01       375
         1.0       0.50      1.00      0.67       375

    accuracy                           0.50       750
   macro avg       0.75      0.50      0.34       750
weighted avg       0.75      0.50      0.34       750

Weights: [[ 0.00319615 -0.00143483 -0.00128562 ... -0.00271136  0.00181736
  -0.00137571]]
Top k indeces: [1393  689 2269 1227 1824  714 1413 2257  331 1788]


layer 1
Accuracy: 0.5053333333333333
Confusion Matrix:
 [[  4 371]
 [  0 375]]
Classification Report:
               precision    recall  f1-score   support

         0.0       1.00      0.01      0.02       375
         1.0       0.50      1.00      0.67       375

    accuracy                           0.51       750
   macro avg       0.75      0.51      0.35       750
weighted avg       0.75      0.51  

Starting from layer 7 we are able to obtain perfect prediction. However, it seems like certain neurons are always activated.

# Save weights

In [58]:
for i, weights in enumerate(weights_list):
    weights = torch.tensor(weights)
    torch.save(weights, WEIGHTS_DIR + f"layer_{i}.pt")