## Let's implement CCS from scratch.
This will deliberately be a simple (but less efficient) implementation to make everything as clear as possible.

In [1]:
from tqdm import tqdm
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMaskedLM, AutoModelForCausalLM
from sklearn.linear_model import LogisticRegression

# Let's just try IMDB for simplicity
data = load_dataset("imdb")["test"]

# Here are a few different model options you can play around with:
# model_name = "deberta"
model_name = "gpt2-xl"
# model_name = "t5"

# if you want to cache the model weights somewhere, you can specify that here
cache_dir = None

if model_name == "deberta":
    model_type = "encoder"
    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v2-xxlarge", cache_dir=cache_dir)
    model = AutoModelForMaskedLM.from_pretrained("microsoft/deberta-v2-xxlarge", cache_dir=cache_dir)
elif model_name == "gpt2-xl":
    model_type = "decoder"
    tokenizer = AutoTokenizer.from_pretrained("gpt2-xl", cache_dir=cache_dir)
    model = AutoModelForCausalLM.from_pretrained("gpt2-xl", cache_dir=cache_dir)
elif model_name == "t5":
    model_type = "encoder_decoder"
    tokenizer = AutoTokenizer.from_pretrained("t5-11b", cache_dir=cache_dir)
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-11b", cache_dir=cache_dir)
else:
    print("Not implemented!")

tokenizer.pad_token = tokenizer.eos_token

torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Found cached dataset imdb (/home/ubuntu/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)


  0%|          | 0/3 [00:00<?, ?it/s]

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1600)
    (wpe): Embedding(1024, 1600)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout)

## First let's write code for extracting hidden states given a model and text. 
How we do this exactly will depend on the type of model.

In [2]:
def get_encoder_hidden_states(model, tokenizer, input_text, layer=-1):
    """
    Given an encoder model and some text, gets the encoder hidden states (in a given layer, by default the last) 
    on that input text (where the full text is given to the encoder).

    Returns a numpy array of shape (hidden_dim,)
    """
    # tokenize
    encoder_text_ids = tokenizer(input_text, truncation=True, return_tensors="pt").input_ids.to(model.device)

    # forward pass
    with torch.no_grad():
        output = model(encoder_text_ids, output_hidden_states=True)

    # get the appropriate hidden states
    hs_tuple = output["hidden_states"]
    
    hs = hs_tuple[layer][0, -1].detach().cpu().numpy()

    return hs

def get_encoder_decoder_hidden_states(model, tokenizer, input_text, layer=-1):
    """
    Given an encoder-decoder model and some text, gets the encoder hidden states (in a given layer, by default the last) 
    on that input text (where the full text is given to the encoder).

    Returns a numpy array of shape (hidden_dim,)
    """
    # tokenize
    encoder_text_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
    decoder_text_ids = tokenizer("", return_tensors="pt").input_ids.to(model.device)

    # forward pass
    with torch.no_grad():
        output = model(encoder_text_ids, decoder_input_ids=decoder_text_ids, output_hidden_states=True)

    # get the appropriate hidden states
    hs_tuple = output["encoder_hidden_states"]
    hs = hs_tuple[layer][0, -1].detach().cpu().numpy()

    return hs

def get_decoder_hidden_states(model, tokenizer, input_text, layer=-1):
    """
    Given a decoder model and some text, gets the hidden states (in a given layer, by default the last) on that input text

    Returns a numpy array of shape (hidden_dim,)
    """
    # tokenize (adding the EOS token this time)
    input_ids = tokenizer(input_text + tokenizer.eos_token, return_tensors="pt").input_ids.to(model.device)

    # forward pass
    with torch.no_grad():
        output = model(input_ids, output_hidden_states=True)

    # get the last layer, last token hidden states
    hs_tuple = output["hidden_states"]
    hs = hs_tuple[layer][0, -1].detach().cpu().numpy()

    return hs

def get_hidden_states(model, tokenizer, input_text, layer=-1, model_type="encoder"):
    fn = {"encoder": get_encoder_hidden_states, "encoder_decoder": get_encoder_decoder_hidden_states,
          "decoder": get_decoder_hidden_states}[model_type]

    return fn(model, tokenizer, input_text, layer=layer)

## Now let's write code for formatting data and for getting all the hidden states.

In [3]:
def format_imdb(text, label):
    """
    Given an imdb example ("text") and corresponding label (0 for negative, or 1 for positive), 
    returns a zero-shot prompt for that example (which includes that label as the answer).
    
    (This is just one example of a simple, manually created prompt.)
    """
    return "The following movie review expresses a " + ["negative", "positive"][label] + " sentiment:\n" + text


def get_hidden_states_many_examples(model, tokenizer, data, model_type, n=100):
    """
    Given an encoder-decoder model, a list of data, computes the contrast hidden states on n random examples.
    Returns numpy arrays of shape (n, hidden_dim) for each candidate label, along with a boolean numpy array of shape (n,)
    with the ground truth labels
    
    This is deliberately simple so that it's easy to understand, rather than being optimized for efficiency
    """
    # setup
    model.eval()
    all_neg_hs, all_pos_hs, all_gt_labels = [], [], []

    # loop
    for i in tqdm(range(n)):
        # for simplicity, sample a random example until we find one that's a reasonable length
        # (most examples should be a reasonable length, so this is just to make sure)
    
        while True:
            idx = np.random.randint(len(data))
            text, true_label = data[idx]["text"], data[idx]["label"]
            # the actual formatted input will be longer, so include a bit of a marign
            if len(tokenizer(text)) < 400:  
                break
                
        # get hidden states
        try: # some examples are too long (despite the check above????)
            neg_hs = get_hidden_states(model, tokenizer, format_imdb(text, 0), model_type=model_type)
            pos_hs = get_hidden_states(model, tokenizer, format_imdb(text, 1), model_type=model_type)
        except:
            continue

        # collect
        all_neg_hs.append(neg_hs)
        all_pos_hs.append(pos_hs)
        all_gt_labels.append(true_label)

        
    all_neg_hs = np.stack(all_neg_hs)
    all_pos_hs = np.stack(all_pos_hs)
    all_gt_labels = np.stack(all_gt_labels)

    return all_neg_hs, all_pos_hs, all_gt_labels

In [4]:
neg_hs, pos_hs, y = get_hidden_states_many_examples(model, tokenizer, data, model_type,n=40)

100%|██████████| 40/40 [00:08<00:00,  4.83it/s]

[array([ 0.16242853,  0.12774429,  0.5089752 , ..., -2.7614143 ,
        0.39502737, -0.43636805], dtype=float32), array([ 0.05473407,  0.09184736,  0.57591033, ..., -2.7621608 ,
        0.39023775, -0.372198  ], dtype=float32), array([ 0.21584935, -0.00651491,  0.650196  , ..., -2.749702  ,
        0.41944247, -0.38815033], dtype=float32), array([ 0.19097316,  0.11692121,  0.5743857 , ..., -2.715699  ,
        0.41009545, -0.35605434], dtype=float32), array([ 0.17253505,  0.1422063 ,  0.6351348 , ..., -2.7112942 ,
        0.4632127 , -0.3151917 ], dtype=float32), array([ 0.09781828,  0.12510385,  0.6803259 , ..., -2.7654235 ,
        0.4056766 , -0.4035735 ], dtype=float32), array([ 0.24013755,  0.10803954,  0.6503092 , ..., -2.755596  ,
        0.37842858, -0.3674418 ], dtype=float32), array([ 0.20026268,  0.04173487,  0.5700549 , ..., -2.7377753 ,
        0.35521492, -0.3898255 ], dtype=float32), array([ 0.13118097,  0.07552929,  0.6259619 , ..., -2.7407672 ,
        0.3600795 , -0.




## Let's verify that the model's representations are good

Before trying CCS, let's make sure there exists a direction that classifies examples as true vs false with high accuracy; if logistic regression accuracy is bad, there's no hope of CCS doing well.

In [5]:
# let's create a simple 50/50 train split (the data is already randomized)
n = len(y)
neg_hs_train, neg_hs_test = neg_hs[:n//2], neg_hs[n//2:]
pos_hs_train, pos_hs_test = pos_hs[:n//2], pos_hs[n//2:]
y_train, y_test = y[:n//2], y[n//2:]

# for simplicity we can just take the difference between positive and negative hidden states
# (concatenating also works fine)
x_train = neg_hs_train - pos_hs_train
x_test = neg_hs_test - pos_hs_test

lr = LogisticRegression(class_weight="balanced")
lr.fit(x_train, y_train)
print("Logistic regression accuracy: {}".format(lr.score(x_test, y_test)))

Logistic regression accuracy: 0.8


## Now let's try CCS

In [6]:
class MLPProbe(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.linear1 = nn.Linear(d, 100)
        self.linear2 = nn.Linear(100, 1)

    def forward(self, x):
        h = F.relu(self.linear1(x))
        o = self.linear2(h)
        return torch.sigmoid(o)

class CCS(object):
    def __init__(self, x0, x1, nepochs=1000, ntries=10, lr=1e-3, batch_size=-1, 
                 verbose=False, device="cuda", linear=True, weight_decay=0.01, var_normalize=False):
        # data
        self.var_normalize = var_normalize

        # TODO: allow using training normalization mean and std for predictions
        self.x0 = self.normalize(x0)
        self.x1 = self.normalize(x1)
        self.d = self.x0.shape[-1]

        # training
        self.nepochs = nepochs
        self.ntries = ntries
        self.lr = lr
        self.verbose = verbose
        self.device = device
        self.batch_size = batch_size
        self.weight_decay = weight_decay
        
        # probe
        self.linear = linear
        self.probe = self.initialize_probe()
        self.best_probe = copy.deepcopy(self.probe)

        
    def initialize_probe(self):
        if self.linear:
            self.probe = nn.Linear(self.d, 1, device=self.device)
        else:
            self.probe = MLPProbe(self.d)
        self.probe.to(self.device)    


    def normalize(self, x):
        """
        Mean-normalizes the data x (of shape (n, d))
        If self.var_normalize, also divides by the standard deviation
        """
        normalized_x = x - x.mean(axis=0, keepdims=True)
        if self.var_normalize:
            normalized_x /= normalized_x.std(axis=0, keepdims=True)

        return normalized_x

    def tensor_normalize(self, x): 
        normalized_x = x - x.mean(dim=0, keepdims=True)
        if self.var_normalize:
            normalized_x /= normalized_x.std(dim=0, keepdims=True)

        return normalized_x
        
    def get_tensor_data(self):
        """
        Returns x0, x1 as appropriate tensors (rather than np arrays)
        """
        x0 = torch.tensor(self.x0, dtype=torch.float, requires_grad=False, device=self.device)
        x1 = torch.tensor(self.x1, dtype=torch.float, requires_grad=False, device=self.device)
        return x0, x1
    

    def get_loss(self, p0, p1):
        """
        Returns the CCS loss for two probabilities each of shape (n,1) or (n,)
        """
        informative_loss = (torch.min(p0, p1)**2).mean(0)
        consistent_loss = ((p0 - (1-p1))**2).mean(0)
        return informative_loss + consistent_loss

    def get_prediction(self, x0, x1): 
        """
        Returns a tensor of of shape (len(x0), 2) with the predicted probabilities of x0 and x1 respectively. 
        """
        p0 = self.probe(x0)
        p1 = self.probe(x1)
        return torch.cat([p0, p1], dim=1)


    def get_acc(self, x0_test, x1_test, y_test):
        """
        Computes accuracy for the current parameters on the given test inputs
        """
        x0 = torch.tensor(self.normalize(x0_test), dtype=torch.float, requires_grad=False, device=self.device)
        x1 = torch.tensor(self.normalize(x1_test), dtype=torch.float, requires_grad=False, device=self.device)
        with torch.no_grad():
            p0, p1 = self.best_probe(x0), self.best_probe(x1)
        avg_confidence = 0.5*(p0 + (1-p1))
        predictions = (avg_confidence.detach().cpu().numpy() < 0.5).astype(int)[:, 0]
        acc = (predictions == y_test).mean()
        acc = max(acc, 1 - acc)

        return acc
    
        
    def train(self):
        """
        Does a single training run of nepochs epochs
        """
        x0, x1 = self.get_tensor_data()
        permutation = torch.randperm(len(x0))
        x0, x1 = x0[permutation], x1[permutation]
        
        # set up optimizer
        optimizer = torch.optim.AdamW(self.probe.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        
        batch_size = len(x0) if self.batch_size == -1 else self.batch_size
        nbatches = len(x0) // batch_size

        # Start training (full batch)
        for epoch in range(self.nepochs):
            for j in range(nbatches):
                x0_batch = x0[j*batch_size:(j+1)*batch_size]
                x1_batch = x1[j*batch_size:(j+1)*batch_size]
            
            # probe
            p0, p1 = self.probe(x0_batch), self.probe(x1_batch)

            # get the corresponding loss
            loss = self.get_loss(p0, p1)

            # update the parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return loss.detach().cpu().item()
    
    def repeated_train(self):
        best_loss = np.inf
        for train_num in range(self.ntries):
            self.initialize_probe()
            loss = self.train()
            if loss < best_loss:
                self.best_probe = copy.deepcopy(self.probe)
                best_loss = loss

        return best_loss

In [7]:
# Train CCS without any labels
ccs = CCS(neg_hs_train, pos_hs_train)
ccs.repeated_train()

# Evaluate
ccs_acc = ccs.get_acc(neg_hs_test, pos_hs_test, y_test)
print("CCS accuracy: {}".format(ccs_acc))

CCS accuracy: 0.8


## Test zero-shot performance

### IMDB

In [10]:
def format_imdb_zero_shot(text, label):
    return "Find the sentiment of the following text.\n" + text + "\nThe sentiment of the preceding text is"
    
def get_decoder_logits(model, tokenizer, input_text, layer=-1):
    """
    output["logits"] has shape (batch_size, seq_len, vocab_size)
    output["hidden_states"] is a length 49 tuple (number of layres + original embedding?) each of shape (batch_size, seq_len, hidden_size)
    """
    # tokenize (adding the EOS token this time)
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)

    # forward pass
    with torch.no_grad():
        output = model(input_ids, output_hidden_states=True)

    # get the last layer, last token hidden states
    # print(output["logits"].shape)
    # print(len(output["hidden_states"]))
    # print(output["hidden_states"][-1].shape)
    hs_tuple = output["hidden_states"]
    hs = hs_tuple[layer][0, -1].detach().cpu().numpy()

    return output["logits"][0, -1]

def get_decoder_logits_many_examples(model, tokenizer, data, n=100):
    # setup
    model.eval()
    all_logits, all_labels = [],[]

    # loop
    for i in tqdm(range(n)):
        # for simplicity, sample a random example until we find one that's a reasonable length
        # (most examples should be a reasonable length, so this is just to make sure)
    
        while True:
            idx = np.random.randint(len(data))
            text, true_label = data[idx]["text"], data[idx]["label"]
            # the actual formatted input will be longer, so include a bit of a marign
            if len(tokenizer(text)) < 400:  
                break

        # get hidden states
        try: # some examples are too long (despite the check above????)
            logits = get_decoder_logits(model, tokenizer, format_imdb_zero_shot(text, 1))
        except:
            continue

        # collect
        all_logits.append(logits)
        all_labels.append(true_label)

        
    all_logits = np.stack(all_logits)
    all_labels = np.stack(all_labels)

    return all_logits, all_labels

In [11]:
zs_logits, zs_labels = get_decoder_logits_many_examples(model, tokenizer, data, n=40)

100%|██████████| 40/40 [00:00<00:00, 154.49it/s]


In [10]:
def calculate_sentiment_probs(logits): 
    """
    Isolate yes / no logits from the model, normalize, and compute the cross entropy loss with the 
    probe out (which is a distribution over yes / no that sums to 1). Note, yes_token is 3763 and 
    no_token is 645. positive_token is 3967 and negative_token is 4633 (for)
    """
    yes_logits = torch.tensor(logits[:, 3967])
    no_logits = torch.tensor(logits[:, 4633])
    logits = torch.stack([yes_logits, no_logits], dim=-1)
    probs = torch.softmax(logits, dim=-1)
    return probs

### Testing

In [11]:
probs = calculate_sentiment_probs(zs_logits)
est_labels = (probs[:,0]>0.5).float().numpy()
zs_performance = (est_labels == zs_labels).astype(float)
print(zs_performance.sum()/len(zs_performance))

0.717948717948718


In [12]:
est_labels

array([0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1.,
       1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1.,
       1., 0., 0., 1., 1.], dtype=float32)

### Finetuning script

In [12]:
def get_train_data(data, n=100): 
    # for decoder only model rn
    # output is an array of len n where each entry is {"text": . "label": }
    train_data = []
    for i in tqdm(range(n)):
        while True:
            idx = np.random.randint(len(data))
            text, true_label = data[idx]["text"], data[idx]["label"]
            # the actual formatted input will be longer, so include a bit of a marign
            if len(tokenizer(text)) < 400:  
                break
        train_data.append({"text": text, "label": true_label})

    return train_data

In [13]:
def finetune_loss(batched_logits, probe_out): 
    """
    batched_logits.shape = (batch_size, vocab_size)
    probe_out.shape = (batch_size, 2)
    """
    yes_logits = torch.tensor(logits[:, 3967])
    no_logits = torch.tensor(logits[:, 4633])
    logits = torch.stack([yes_logits, no_logits], dim=-1) # batch_size x 2
    probs = torch.softmax(logits, dim=-1) # batch_size x 2
    return F.cross_entropy(probs, probe_out) 

In [14]:
def finetune(model, tokenizer, train_data, ccs, layer=-1, epochs=100, lr=0.001, batch_size=32): 
    """
    Finetune model output on train_data (same format as above) by minimizing cross entropy between 
    probe out distribution and model yes / no logit distribution. Assumes probe has already been
    trained. Assume train_data is array with each entry being a dict of {"text": , "label": }. 0 
    means negative and 1 means positive. 
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    csl = nn.CrossEntropyLoss()
    model.train()

    for epoch in range(epochs): 
        for i in range(0,len(train_data), batch_size):
            optimizer.zero_grad()

            # prep data (assumes formatted)
            sentences = [elem["text"] for elem in train_data[i:i+batch_size]]
            labels = [elem["label"] for elem in train_data[i:i+batch_size]]
            zs_sentences = [format_imdb_zero_shot(sentences[i], labels[i]) for i in range(len(sentences))]
            yes_sentences = [format_imdb(sentences[i], 1) + tokenizer.eos_token for i in range(len(sentences))]
            no_sentences = [format_imdb(sentences[i], 0) + tokenizer.eos_token for i in range(len(sentences))]
            print("FORMATTED BATCH")

            # get logits
            # TODO: wrap in try except loop to handle big inputs
            batched_input_ids = tokenizer(zs_sentences, return_tensors="pt", padding=True).input_ids.to(model.device)
            print("TOKENIZED ZERO SHOT INPUTS")
            print(batched_input_ids.shape)
            batched_output = model(batched_input_ids, output_hidden_states=False)
            batched_logits = batched_output["logits"][:,-1] # (batch_size, vocab_size) tensor
            print("CALCULATED ZERO SHOT LOGITS")
            print(batched_logits.shape)

            # get activations
            neg_input_ids = tokenizer(no_sentences, return_tensors="pt", padding=True).input_ids.to(model.device)
            print("B")
            output = model(neg_input_ids, output_hidden_states=True)
            print("B")
            neg_hs = output["hidden_states"][layer][:, -1].detach().cpu()
            print("CALCULATED NEGATIVE HIDDEN STATES")

            pos_input_ids = tokenizer(yes_sentences, return_tensors="pt", padding=True).input_ids.to(model.device)
            print("B")
            output = model(pos_input_ids, output_hidden_states=True)
            print("B")
            pos_hs = output["hidden_states"][layer][:, -1].detach().cpu()
            print("CALCULATED POSITIVE HIDDEN STATES")

            # get probe prediction
            x0 = ccs.tensor_normalize(neg_hs)
            x1 = ccs.tensor_normalize(pos_hs)
            probe_out = ccs.get_prediction(x0, x1)
            print(probe_out.shape)

            # compute loss
            loss = finetune_loss(batched_logits, probe_out)
            print(loss)
            loss.backward()
            optimizer.step()
        
        print("Epoch {} loss: {}".format(epoch, loss.item()))

    model.eval()

In [15]:
finetune(model, tokenizer, get_train_data(data, n=10), ccs, epochs=10, lr=0.001, batch_size=10)

100%|██████████| 10/10 [00:00<00:00, 1274.05it/s]


FORMATTED BATCH
TOKENIZED ZERO SHOT INPUTS
torch.Size([10, 609])
