
# LLM Assignment 1: Representation Surgery (Coding)
**IMPORTANT: MAKE A COPY OF THIS NOTEBOOK AND WORK IN THERE.**

Welcome to the coding part of the Representation Surgery question in Assignment 1.
The goal of this part is to build a deeper intuition, through hands-on implementation, for how representation surgery works and for its strengths and limitations.

It consists of 2 parts:
1. Implementing linear concept erasure as described in the assignment, and applying it to a synthetic dataset.
2. Applying concept erasure to erase the concept of "gender" from the BERT representations of a dataset of biographies.

In this assignment, you will work through a series of tasks. For each task, you will write code and produce plots and/or accuracy results. You will need to export the notebook as a PDF and attach it to your submission. Please make sure all deliverables for each task (e.g., plots and accuracy results) are clearly visible, as we will be grading you on these outputs.

In [None]:
!git clone https://github.com/rycolab/llm-representation-surgery.git

In [None]:
# These are all the imports you should need
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.datasets import make_classification
from sklearn.cross_decomposition import CCA
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder
from typing import Tuple
np.random.seed(42)

## 0. Constructing a synthetic dataset

First, we provide a synthetic dataset $X$ that encodes a concept $Z$, given by the target labels.

In [None]:
n, d, k = 2048, 128, 2
X, Z = make_classification(
    n_samples=n,
    n_informative=k,
    n_features=d,
    n_classes=k,
    random_state=1,
)
print(f"Shape of X: {X.shape}, Shape of Z: {Z.shape}")

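What does this dataset look like? Visualizing 128-dimensional data directly is hard, so we project it to 2 dimensions with Canonical Correlation Analysis (CCA), which finds the directions in $X$ most correlated with the labels, and plot the result.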

In [None]:
def plot_cca(X, Z):
    """
    Use Canonical Correlation Analysis (CCA) to project X onto the two directions most
    correlated with the one-hot labels Z, and plot the projection colored by label.
    """
    Z_one_hot = np.zeros((len(Z), len(set(Z))))
    Z_one_hot[np.arange(len(Z)), Z] = 1
    cca = CCA(n_components=2)
    X_cca = cca.fit_transform(X, y=Z_one_hot)[0]
    plt.figure(figsize=(6, 3))
    plt.scatter(X_cca[:, 0], X_cca[:, 1], c=Z, cmap='viridis')
    plt.colorbar(label='Class')
    plt.xlabel('First canonical component')
    plt.ylabel('Second canonical component')
    plt.title('CCA visualization of X')
    plt.show()

plot_cca(X, Z)

As you can see, the two classes are fairly well separated along these two components.

### Task 0.1: Implement a logistic regression classifier to predict the concept $Z$ from $X$.
FOR THE SUBMISSION: Print the accuracy your classifier achieves on the data.

In [None]:
def logreg(X: np.ndarray, Z: np.ndarray) -> float:
    """
    Train a logistic regression classifier to predict concept Z from X.

    Args:
        X - the data
        Z - the concept/label

    Return:
        Accuracy in predicting Z from X.
    """
    # TODO:
    pass

print(f"Accuracy before erasure: {logreg(X, Z):.3f}")
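A minimal sketch of one way to fill in `logreg`, assuming accuracy is measured on a held-out split rather than on the training data (training accuracy can overstate what a probe has learned, especially with many features). The name `logreg_heldout` and its `test_size`/`seed` parameters are illustrative, not part of the assignment:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def logreg_heldout(X: np.ndarray, Z: np.ndarray, test_size: float = 0.3, seed: int = 0) -> float:
    """Fit a logistic regression probe on a training split and return held-out accuracy."""
    # Stratify so both splits keep the same class proportions.
    X_train, X_test, Z_train, Z_test = train_test_split(
        X, Z, test_size=test_size, random_state=seed, stratify=Z
    )
    # A larger max_iter helps lbfgs converge on unscaled, high-dimensional inputs.
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_train, Z_train)
    # .score returns mean accuracy on the given data.
    return float(clf.score(X_test, Z_test))
```

With this helper, the cell above would read `print(f"Accuracy before erasure: {logreg_heldout(X, Z):.3f}")`. Stratifying the split keeps the comparison stable if the labels are imbalanced.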

## 1. Implementing linear concept erasure

### Task 1.1: Validate that the whitening and unwhitening matrices returned by `compute_W` satisfy their definitions.
That is, verify (a) $WW^{-1} = I$, where $W^{-1}$ is the unwhitening matrix `W_inv`, and (b) $\operatorname{Cov}(XW) = I$.

FOR THE SUBMISSION: Include your heatmaps of $WW^{-1}$ and $\operatorname{Cov}(XW)$.
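For reference, a short derivation of why both checks should pass for ZCA whitening. It writes the sample covariance as $\Sigma = V\Lambda V^{\top}$ and assumes $\Sigma$ is full rank; when it is not, `compute_W` zeroes the null-space directions, so both products equal the identity only on the span of the data:

$$
\begin{aligned}
W &= \Sigma^{-1/2} = V\Lambda^{-1/2}V^{\top}, \qquad W^{-1} = \Sigma^{1/2} = V\Lambda^{1/2}V^{\top},\\
WW^{-1} &= V\Lambda^{-1/2}\,(V^{\top}V)\,\Lambda^{1/2}V^{\top} = VV^{\top} = I,\\
\operatorname{Cov}(XW) &= W^{\top}\Sigma W = V\Lambda^{-1/2}V^{\top}\,V\Lambda V^{\top}\,V\Lambda^{-1/2}V^{\top} = V\Lambda^{-1/2}\Lambda\Lambda^{-1/2}V^{\top} = I.
\end{aligned}
$$

Since $W$ is symmetric, $WW^{\top} = \Sigma^{-1}$, which is generally not the identity; the identity check belongs to the pair $W$, $W^{-1}$.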

In [None]:
def compute_W(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute the whitening and unwhitening matrices for the data using ZCA whitening
    (see Appendix A of https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf for more details).
    We give this to you as a numerically stable implementation: eigenvalues below a relative
    threshold are treated as zero, so W and W_inv only act on the numerically nonzero eigendirections.
    """
    sigma = np.cov(X.T)

    L, V = np.linalg.eigh(sigma)

    # Assuming PSD; account for numerical error
    L = L.clip(0.0)

    # Threshold used by torch.linalg.pinv
    mask = L > (L[-1] * sigma.shape[-1] * np.finfo(float).eps)

    with np.errstate(divide='ignore'):
        # (V * d) @ V.T computes V diag(d) V^T
        W = (V * np.where(mask, 1/np.sqrt(L), 0.0)) @ V.T
        W_inv = (V * np.where(mask, np.sqrt(L), 0.0)) @ V.T

    return W, W_inv

def heatmap(arr, title):
    plt.figure(figsize=(10, 8))
    sns.heatmap(arr, cmap='viridis')
    plt.title(title)
    plt.tight_layout()
    plt.show()

W, W_inv = compute_W(X)

In [None]:
def verify_W_W_inv(W, W_inv):
    """
    Plot a heatmap of W @ W_inv; if W_inv inverts W, this should be (close to) the identity.
    """
    # TODO
    pass

verify_W_W_inv(W, W_inv)

In [None]:
def verify_cov_XW(X, W):
    """
    Plot a heatmap of the covariance of XW to verify it satisfies the definition of whitening.
    """
    # TODO
    pass

verify_cov_XW(X, W)
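Heatmaps are the requested deliverable, but a numerical tolerance check is a useful complement, since small off-diagonal errors are hard to see on a color scale. A minimal sketch, assuming a self-contained re-implementation of ZCA; `zca_whitening` is a hypothetical, simplified stand-in for `compute_W` (no rank threshold, so it assumes a full-rank covariance), and `whitening_errors` is likewise an illustrative helper:

```python
import numpy as np


def zca_whitening(X: np.ndarray):
    """Simplified ZCA: W = Sigma^{-1/2}, W_inv = Sigma^{1/2}; assumes full-rank covariance."""
    sigma = np.cov(X.T)
    L, V = np.linalg.eigh(sigma)
    W = (V / np.sqrt(L)) @ V.T      # V diag(L^{-1/2}) V^T
    W_inv = (V * np.sqrt(L)) @ V.T  # V diag(L^{1/2}) V^T
    return W, W_inv


def whitening_errors(X: np.ndarray, W: np.ndarray, W_inv: np.ndarray):
    """Return (max |W W_inv - I|, max |Cov(XW) - I|)."""
    I = np.eye(W.shape[0])
    err_inverse = np.abs(W @ W_inv - I).max()
    # np.cov centers the data itself, so no explicit mean subtraction is needed.
    err_cov = np.abs(np.cov((X @ W).T) - I).max()
    return float(err_inverse), float(err_cov)
```

Calling `whitening_errors(X, W, W_inv)` on the notebook's own matrices should give values on the order of floating-point error in directions where the covariance is full rank; larger deviations confined to a few entries point to rank deficiency rather than a bug.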

### Task 1.2: Implement linear concept erasure
Use the equations derived in the written part of the assignment.

FOR THE SUBMISSION: Include the CCA plot and the logistic regression accuracy, both before and after erasure.

In [None]:
def erase(X: np.ndarray, Z: np.ndarray) -> np.ndarray:
    """
    Implement linear concept erasure.

    Args:
        X - the data, shape (n, d)
        Z - the concept/label, shape (n,)

    Return:
        The data with the concept erased. Shape (n, d).
    """
    # TODO:
    pass

erased_X = erase(X, Z)
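One possible shape for `erase`, assuming the written part arrives at the closed-form LEACE eraser of Belrose et al. (2023), $r(x) = x - W^{+} P_{W\Sigma_{XZ}} W (x - \mathbb{E}[X])$, where $W$ is the ZCA whitening matrix, $\Sigma_{XZ}$ is the cross-covariance between $X$ and the one-hot $Z$, and $P_{W\Sigma_{XZ}}$ is the orthogonal projection onto the column space of $W\Sigma_{XZ}$. If your derivation differs, follow your derivation. The name `leace_erase` is hypothetical, and the whitening step is inlined (mirroring the threshold in `compute_W`) so the block runs on its own:

```python
import numpy as np


def leace_erase(X: np.ndarray, Z: np.ndarray) -> np.ndarray:
    """Sketch of closed-form linear concept erasure (LEACE-style); rows of X are samples."""
    n, d = X.shape
    mu = X.mean(axis=0)
    Xc = X - mu

    # One-hot encode the concept labels and center them.
    classes, z_idx = np.unique(Z, return_inverse=True)
    Z_onehot = np.eye(len(classes))[z_idx]
    Zc = Z_onehot - Z_onehot.mean(axis=0)

    # ZCA whitening and unwhitening matrices from the eigendecomposition of Cov(X).
    sigma = Xc.T @ Xc / (n - 1)
    L, V = np.linalg.eigh(sigma)
    L = L.clip(0.0)
    mask = L > (L[-1] * d * np.finfo(float).eps)
    safe_L = np.where(mask, L, 1.0)  # avoids division by zero in masked-out directions
    W = (V * np.where(mask, 1.0 / np.sqrt(safe_L), 0.0)) @ V.T
    W_inv = (V * np.where(mask, np.sqrt(L), 0.0)) @ V.T

    # Cross-covariance Sigma_XZ (d x k) and orthogonal projection onto col(W Sigma_XZ).
    sigma_xz = Xc.T @ Zc / (n - 1)
    A = W @ sigma_xz
    P = A @ np.linalg.pinv(A)

    # r(x) = x - W^+ P W (x - mu). W, P and W_inv are symmetric, so for
    # row-stacked data this becomes X - Xc @ W @ P @ W_inv.
    return X - Xc @ W @ P @ W_inv
```

Used as `erased_X = leace_erase(X, Z)`, the sketch leaves the mean of the data unchanged and drives the cross-covariance between the erased data and the centered one-hot labels to zero up to floating-point error. The LEACE paper shows that, under its assumptions, zero cross-covariance is what prevents any linear classifier from beating a constant predictor.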

Now, let's validate that the concept has been erased.

In [None]:
plot_cca(erased_X, Z)
print(f"Accuracy after erasure: {logreg(erased_X, Z):.3f}")

## 2. Erasing gender bias
In this part, we apply the same method to a real dataset of short biographies. Your goal is to (a) compute sentence embeddings of the biographies with a pretrained language model and (b) erase the gender concept from these embeddings. Use the **`hard_text_untokenized`** column as the input text.

In [None]:
df = pd.read_csv("llm-representation-surgery/data/biographies.csv")
df.head()
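A minimal sketch for `extract_gender_labels`, assuming gender is stored as a categorical column. The default column name `"g"` is hypothetical, so check `df.columns` (or the `df.head()` output above) for the actual one. `LabelEncoder` maps each distinct value to an integer in sorted order, which is the format `plot_cca` expects for its one-hot encoding:

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder


def encode_labels(df: pd.DataFrame, column: str = "g") -> np.ndarray:
    """Integer-encode a categorical column; the default column="g" is a hypothetical name."""
    if column not in df.columns:
        raise KeyError(f"Column {column!r} not found; available: {list(df.columns)}")
    encoder = LabelEncoder()
    # LabelEncoder sorts the distinct values, so e.g. "f" -> 0 and "m" -> 1.
    return encoder.fit_transform(df[column].astype(str).to_numpy())
```

The same helper can encode the profession column for Task 2.2 by passing that column's name.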

### Task 2.1: Compute embeddings and use the `erase` function defined above to erase the gender concept from the embeddings.

FOR THE SUBMISSION: Include (a) the CCA plot of the embeddings before erasure, (b) the CCA plot of the embeddings after erasure, (c) the logistic regression accuracy before erasure, and (d) the logistic regression accuracy after erasure.

In [None]:
def extract_gender_labels(df: pd.DataFrame) -> np.ndarray:
    """
    Extract the gender labels from the dataframe.
    """
    # TODO:
    pass

def compute_embeddings(df: pd.DataFrame) -> np.ndarray:
    """
    Compute the embeddings using a pretrained model for the "hard_text_untokenized" column of the given dataframe.
    Feel free to experiment with different models; a good starting point is the classic BERT model, "google-bert/bert-base-uncased".
    Also feel free to use any libraries which would be helpful, such as `transformers` or `sentence_transformers`.
    """
    # TODO:
    pass

def erase_gender(embeddings: np.ndarray, gender_labels: np.ndarray) -> np.ndarray:
    """
    This function should do 3 things:
        1. Erase the gender concept from the embeddings and return these new embeddings.
        2. Visualize the embeddings before erasure and print the accuracy before erasure.
        3. Visualize the embeddings after erasure and print the accuracy after erasure.

    Args:
        embeddings: The embeddings to erase the gender concept from.
        gender_labels: The gender labels for the embeddings.

    Returns:
        The embeddings with the gender concept erased.
    """
    # TODO:
    pass

gender_labels = extract_gender_labels(df)
embeddings = compute_embeddings(df)
embeddings_erased = erase_gender(embeddings, gender_labels)
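One way to approach `compute_embeddings`, assuming mean pooling over the last hidden layer of `google-bert/bert-base-uncased` (taking the `[CLS]` vector is a common alternative; the assignment prescribes neither). The hypothetical helper `masked_mean_pool` averages only the non-padding token vectors, and `embed_texts` shows how it could plug into the Hugging Face `transformers` `AutoTokenizer`/`AutoModel` interface; `batch_size` and `max_length` are illustrative defaults:

```python
import numpy as np


def masked_mean_pool(hidden: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors, ignoring padding. hidden: (batch, seq, dim); mask: (batch, seq)."""
    mask = attention_mask[..., None].astype(hidden.dtype)  # (batch, seq, 1)
    summed = (hidden * mask).sum(axis=1)                   # (batch, dim)
    counts = np.clip(mask.sum(axis=1), 1.0, None)          # avoid division by zero
    return summed / counts


def embed_texts(texts, model_name="google-bert/bert-base-uncased", batch_size=32, max_length=256):
    """Embed a sequence of strings with a pretrained encoder (requires torch and transformers)."""
    import torch
    from transformers import AutoModel, AutoTokenizer

    texts = [str(t) for t in texts]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).eval()
    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = tokenizer(
                texts[start:start + batch_size],
                padding=True, truncation=True, max_length=max_length, return_tensors="pt",
            )
            hidden = model(**batch).last_hidden_state  # (batch, seq, dim)
            chunks.append(masked_mean_pool(hidden.numpy(), batch["attention_mask"].numpy()))
    return np.concatenate(chunks, axis=0)
```

For example, `embed_texts(df["hard_text_untokenized"])` would return one 768-dimensional vector per biography for BERT-base. Embedding the full dataset on a CPU can be slow, so caching the result with `np.save` avoids recomputing it.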

### Task 2.2: Verifying minimally-destructive edits.
One key requirement for concept erasure is that the edit be *minimal*, i.e., it should not destroy other information in the representations. To test how well the representations are preserved, use profession as a proxy: show that profession information survives erasure by reporting the accuracy of a profession classifier (a) before erasure and (b) after erasure.

FOR THE SUBMISSION: Report the accuracy before and after erasure.

In [None]:
def visualize_profession(df: pd.DataFrame) -> None:
    """
    Visualize the label distribution of professions, sorted from highest to lowest.
    """
    # TODO:
    pass

visualize_profession(df)
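A minimal sketch for `visualize_profession`, assuming the profession labels live in a single column whose name is not given here, so the hypothetical helper `profession_distribution` takes a `pd.Series`. `value_counts` already sorts from most to least frequent:

```python
import matplotlib.pyplot as plt
import pandas as pd


def profession_distribution(labels: pd.Series, ax=None) -> pd.Series:
    """Return label counts sorted from most to least frequent and draw them as a bar chart."""
    counts = labels.value_counts(sort=True, ascending=False)
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 4))
    counts.plot.bar(ax=ax)
    ax.set_xlabel("Profession")
    ax.set_ylabel("Number of biographies")
    ax.set_title("Profession label distribution")
    return counts
```

Usage would be `profession_distribution(df[profession_column]); plt.show()`, with `profession_column` a placeholder for the real column name. If the distribution turns out to be skewed, the accuracies in this task are best read against the share of the most frequent profession.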

In [None]:
def verify_profession_intact(embeddings: np.ndarray, embeddings_erased: np.ndarray, df: pd.DataFrame) -> Tuple[float, float]:
    """
    Verify that profession information is preserved in the embeddings: return (and print)
    the accuracy of predicting profession before and after erasure.
    """
    # TODO:
    pass

verify_profession_intact(embeddings, embeddings_erased, df)
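A minimal sketch for `verify_profession_intact`, assuming encoded profession labels and a single held-out split shared by both feature sets, so the before/after numbers differ only in the features. The helper `probe_before_after` and its majority-class baseline are illustrative additions, not part of the assignment:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def probe_before_after(before: np.ndarray, after: np.ndarray, labels, seed: int = 0) -> dict:
    """Held-out logistic regression accuracy on original vs. erased features, plus a majority baseline."""
    before, after, labels = np.asarray(before), np.asarray(after), np.asarray(labels)
    # One shared split so both probes are evaluated on exactly the same examples.
    idx_train, idx_test = train_test_split(
        np.arange(len(labels)), test_size=0.3, random_state=seed, stratify=labels
    )
    results = {}
    for name, feats in (("before", before), ("after", after)):
        clf = LogisticRegression(max_iter=2000)
        clf.fit(feats[idx_train], labels[idx_train])
        results[name] = float(clf.score(feats[idx_test], labels[idx_test]))
    # Accuracy of always predicting the most frequent training label.
    values, counts = np.unique(labels[idx_train], return_counts=True)
    majority = values[np.argmax(counts)]
    results["majority_baseline"] = float(np.mean(labels[idx_test] == majority))
    return results
```

Because `stratify` needs at least two examples per class, very rare professions may need to be dropped (or the argument removed) before calling it.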

### Task 2.3: Show that using a nonlinear classifier, you can recover the concept from the erased embeddings.

Although gender is no longer *linearly* decodable from the erased embeddings, it may still be recoverable by a nonlinear model. Your goal is to recover the concept from the erased embeddings with such a model.

FOR THE SUBMISSION: Report the accuracy before and after erasure.

In [None]:
def nonlinear_classifier(embeddings: np.ndarray, gender_labels: np.ndarray) -> float:
    """
    Train a nonlinear classifier on the embeddings to predict the gender labels and return the accuracy.
    """
    # TODO
    pass
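A minimal sketch of one nonlinear probe, assuming a single-hidden-layer MLP from scikit-learn (`MLPClassifier`) is acceptable; kernel SVMs or gradient-boosted trees are other options. The helper name `mlp_heldout_accuracy` and the hyperparameters (`hidden_layer_sizes=(128,)`, `max_iter=500`) are illustrative, not tuned for the biography embeddings:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def mlp_heldout_accuracy(X: np.ndarray, y: np.ndarray, seed: int = 0) -> float:
    """Held-out accuracy of a small MLP probe (one hidden ReLU layer)."""
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=seed, stratify=y
    )
    # Standardizing inputs usually helps MLP optimization.
    model = make_pipeline(
        StandardScaler(),
        MLPClassifier(hidden_layer_sizes=(128,), max_iter=500, random_state=seed),
    )
    model.fit(X_train, y_train)
    return float(model.score(X_test, y_test))
```

Calling `mlp_heldout_accuracy(embeddings, gender_labels)` and `mlp_heldout_accuracy(embeddings_erased, gender_labels)` gives the before/after pair. Held-out accuracy clearly above the majority-class rate on the erased embeddings would indicate that gender is still encoded nonlinearly, which linear erasure does not rule out.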