In [1]:
!pip install scikit-learn==0.24.2 torch==1.9.0 transformers==4.8.2 PyAutoFact==0.1.17 datasets==1.10.2



In [2]:
import os, sys
from itertools import chain
import datasets
import random
import torch
from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score
from py_auto_fact import auto_fact
import numpy as np
from tqdm import tqdm
import hashlib

In [3]:
def count_param(module, trainable=False):
    """Count the scalar parameters exposed by ``module.parameters()``.

    Args:
        module: object whose ``parameters()`` yields tensors providing
            ``numel()`` and ``requires_grad``.
        trainable: when truthy, parameters whose ``requires_grad`` is falsy
            are left out of the count.

    Returns:
        The summed ``numel()`` of the selected parameters (0 if there are none).
    """
    total = 0
    for param in module.parameters():
        if trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total

# Init Model

In [4]:
# Load the pretrained `gpt2-large` language model together with its matching tokenizer.
# `model` is the unfactorized baseline that the factorized copy is compared against below.
model = GPT2LMHeadModel.from_pretrained('gpt2-large')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')

In [5]:
# Total parameter count of the unfactorized baseline (trainable or not).
count_param(model)

774030080

# Apply partial factorization to GPT2 model

In [6]:
# Only factorize the last one-third of the transformer layers of the GPT2 model.
# The slice start is computed explicitly instead of using `h[-k:]`: when k == 0
# (fewer than 3 layers) `-0 == 0`, so `h[-0:]` would silently select *every* layer.
transformer_blocks = model.transformer.h
num_factorized_layers = model.config.n_layer // 3
factorizable_submodules = list(
    transformer_blocks[len(transformer_blocks) - num_factorized_layers:]
)

In [7]:
%%time
# Factorize only the layers selected in `factorizable_submodules`.
# NOTE(review): `deepcopy=True` presumably returns a factorized copy and leaves `model`
# untouched (later cells keep using both) — confirm against the PyAutoFact docs.
# NOTE(review): the exact meaning of `rank`, `solver` and `num_iter` is defined by
# `auto_fact`; check its documentation before tuning them.
fact_model = auto_fact(model, rank=384, deepcopy=True, solver='svd', num_iter=20, submodules=factorizable_submodules)
# Parameter count of the factorized model, for comparison with the baseline count above.
count_param(fact_model)

CPU times: user 2min 30s, sys: 5.38 s, total: 2min 36s
Wall time: 17.3 s


632472320

# Speed test on CPU

### Test Inference CPU

In [8]:
%%timeit
# Time a gradient-free CPU forward pass of the baseline model on a (2, 64) tensor of
# zero token ids. The input tensor is created inside the timed body.
with torch.no_grad():
    y = model(torch.zeros(2, 64, dtype=torch.long))

252 ms ± 407 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%%timeit
# Time a gradient-free CPU forward pass of the factorized model on a (2, 64) tensor of
# zero token ids. The input tensor is created inside the timed body.
with torch.no_grad():
    y = fact_model(torch.zeros(2, 64, dtype=torch.long))

220 ms ± 598 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Speed test on GPU

### Move models to GPU

In [10]:
# Move both the baseline and the factorized model to the GPU.
# NOTE(review): this cell presumably fails on a machine without a CUDA device — confirm
# before running the notebook on CPU-only hardware.
model = model.cuda()
fact_model = fact_model.cuda()

### Test Inference GPU

In [11]:
# Dummy GPU input used by the timing cells below: a (2, 64) tensor of zero token ids.
x = torch.zeros(2,64, dtype=torch.long).cuda()

In [12]:
%%timeit
with torch.no_grad():
    y = model(x)

33.2 ms ± 6.54 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
%%timeit
with torch.no_grad():
    y = fact_model(x)

34.5 ms ± 30.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


# Prepare Dataset and DataLoader

In [14]:
class SSTDataset(Dataset):
    """Map-style dataset over one split of the Hugging Face ``sst`` dataset.

    Every item is a ``(sentence, label)`` pair where ``label`` is the row's
    ``label`` field passed through ``np.round``.
    """
    # Static constant variable
    NUM_LABELS = 2
    # Split names accepted by the constructor.
    VALID_SPLITS = ('train', 'validation', 'test')

    def __init__(self, data_split, exp_args, *args, **kwargs):
        """Load the requested split.

        Args:
            data_split: one of ``'train'``, ``'validation'`` or ``'test'``.
            exp_args: stored on the instance as ``self.exp_args``; not used by
                this class itself.
            *args, **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``data_split`` is not a known split name. The check
                runs before any dataset loading is attempted.
        """
        if data_split not in self.VALID_SPLITS:
            raise ValueError(f'Invalid dataset split: `{data_split}`')

        self.data_split = data_split
        self.exp_args = exp_args
        self.dataset = datasets.load_dataset('sst')[data_split]

    def __getitem__(self, index):
        """Return ``(sentence, rounded_label)`` for the row at ``index``."""
        # Fetch the row once; indexing the underlying dataset twice per item
        # doubles the lookup cost for every sample.
        row = self.dataset[index]
        # NOTE(review): the raw label is presumably a continuous sentiment score in
        # [0, 1]; np.round rounds halves to even, so exactly 0.5 becomes 0 — confirm.
        label = np.round(row['label'])
        return row['sentence'], label

    def __len__(self):
        """Number of rows in the loaded split."""
        return self.dataset.num_rows

In [15]:
def generate_prompt(texts_by_labels, labels, test_samples):
    """Build one few-shot prompt per (test sample, candidate label) pair.

    For every label in ``labels`` a prefix of demonstrations is assembled from all
    samples in ``texts_by_labels`` (that label's samples followed by the samples of
    every other label, then shuffled with ``random.shuffle``). Each demonstration is
    rendered as ``<text>=><label>=true`` or ``<text>=><label>=false`` plus a newline,
    depending on whether the sample's own label matches the candidate label.

    Args:
        texts_by_labels: dict mapping a label to a list of
            ``{"text": ..., "label": ...}`` dicts.
        labels: candidate labels; also the order of prompts inside each result row.
        test_samples: iterable of dicts providing at least ``"text"``.

    Returns:
        A list with one entry per test sample; each entry is a list of prompt
        strings (one per label) ending in ``<test text>=><label>=``.
    """
    label_prefixes = []
    for target in labels:
        # Target-label samples first, then the others in `labels` order, so the
        # list handed to the shuffle is laid out deterministically.
        demos = list(texts_by_labels[target])
        for other in labels:
            if other != target:
                demos.extend(texts_by_labels[other])
        random.shuffle(demos)

        rendered = []
        for demo in demos:
            demo_text, demo_label = demo["text"], demo["label"]
            suffix = "=false\n" if demo_label != target else "=true\n"
            rendered.append(demo_text + "=>" + target + suffix)
        label_prefixes.append(("".join(rendered), target))

    return [
        [prefix + sample["text"] + "=>" + target + "=" for prefix, target in label_prefixes]
        for sample in test_samples
    ]

def generate_sst_dataset(k_shot, seed=None):
    """Build k-shot in-context prompts for every SST test sentence.

    Training sentences are grouped by label name, each group is shuffled and cut
    down to ``k_shot`` demonstrations, and ``generate_prompt`` turns them into one
    prompt per (test sample, target label).

    Args:
        k_shot: number of demonstrations kept per label.
        seed: optional seed for the global ``random`` module. When given, the
            demonstration sampling here and the shuffling inside
            ``generate_prompt`` become reproducible across runs. The default
            ``None`` leaves the random state untouched.

    Returns:
        ``(few_shot_samples, test_samples, targets)`` where ``few_shot_samples`` is
        the output of ``generate_prompt``, ``test_samples`` is a list of
        ``{"text", "label"}`` dicts for the test split, and ``targets`` is
        ``["negative", "positive"]``.
    """
    if seed is not None:
        random.seed(seed)

    IDX_TO_LABELS = {0: "negative", 1: "positive"}
    targets = ["negative", "positive"]

    train_dataset = SSTDataset('train', None)
    test_dataset = SSTDataset('test', None)

    # Group the training sentences by label name.
    texts_by_labels = {}
    for i in range(len(train_dataset)):
        text, label = train_dataset[i]
        label_name = IDX_TO_LABELS[label]
        texts_by_labels.setdefault(label_name, []).append({"text": text, "label": label_name})

    test_samples = []
    for i in range(len(test_dataset)):
        text, label = test_dataset[i]
        test_samples.append({"text": text, "label": IDX_TO_LABELS[label]})

    # Shuffle every group first, then truncate, so the random draws happen in the
    # same order regardless of k_shot.
    for label_name in texts_by_labels:
        random.shuffle(texts_by_labels[label_name])
    for label_name in texts_by_labels:
        texts_by_labels[label_name] = texts_by_labels[label_name][:k_shot]

    few_shot_samples = generate_prompt(texts_by_labels, targets, test_samples)
    return few_shot_samples, test_samples, targets

In [16]:
# k_shot=10: keep 10 training demonstrations per label, so every prompt contains
# 20 labelled examples (10 negative + 10 positive) before the test sentence.
few_shot_samples, test_samples, targets = generate_sst_dataset(10)

No config specified, defaulting to: sst/default
Reusing dataset sst (/home/samuel/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff)
No config specified, defaulting to: sst/default
Reusing dataset sst (/home/samuel/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff)


# Run In-Context Learning

In [17]:
def score_next(model, tokenizer, encoded, token):
    """Return the log-softmax score of ``token`` at the last position of ``encoded``.

    Args:
        model: callable whose result exposes a ``logits`` tensor.
        tokenizer: unused; kept for interface compatibility.
        encoded: input passed straight to ``model``.
            (assumes logits come back as (batch, seq, vocab) with batch == 1 — TODO confirm)
        token: vocabulary index (anything ``int()`` accepts) to score.

    Returns:
        The NumPy scalar log-probability of ``token`` under a softmax over the
        last position's logits.
    """
    with torch.no_grad():
        all_logits = model(encoded).logits
        # Keep the final position only and drop singleton dimensions.
        last_step = all_logits[:, -1].squeeze().cpu().detach().numpy()

    # Numerically stable log-softmax: shift by the maximum before exponentiating.
    shifted = last_step - np.max(last_step)
    log_probs = shifted - np.log(np.sum(np.exp(shifted)))
    return log_probs[int(token)]

def argmax(array):
    """argmax with deterministic pseudorandom tie breaking.

    When several entries share the maximum value, the winner is chosen by hashing
    the raw bytes of the array with SHA-256, so the same input always yields the
    same index.

    Args:
        array: 1-D array-like of comparable numbers.

    Returns:
        Index (NumPy integer) of one of the maximal entries.
    """
    values = np.asarray(array)
    candidates = np.flatnonzero(values == np.max(values))
    digest = hashlib.sha256(values.tobytes()).hexdigest()
    pick = int(digest, 16) % len(candidates)
    return candidates[pick]

def logsumexp(x):
    """Numerically stable ``log(sum(exp(x)))``.

    The maximum of ``x`` is subtracted before exponentiating and added back
    afterwards, which avoids overflow for large entries.

    Args:
        x: array exposing ``.max()`` (e.g. a NumPy array).
    """
    peak = x.max()
    shifted_total = np.sum(np.exp(x - peak))
    return peak + np.log(shifted_total)

def normalize(x):
    """Turn log-scores into values that sum to one (a softmax over ``x``).

    Args:
        x: array-like of log-scores; copied into a new NumPy array.

    Returns:
        NumPy array ``exp(x - logsumexp(x))``.
    """
    log_scores = np.array(x)
    log_total = logsumexp(log_scores)
    return np.exp(log_scores - log_total)

def calculate_log_prob_gpt(model, tokenizer, prefix, targets):
    """Score each candidate answer in ``targets`` as the continuation of ``prefix``.

    Despite the name, the returned scores are plain softmax probabilities (not
    log-probabilities) taken at the position right after the prompt, and they are
    not renormalised over ``targets``.

    Args:
        model: causal language model (``torch.nn.Module``) whose forward call
            accepts ``input_ids`` / ``attention_mask`` and returns an object with
            ``logits``. Inputs are moved to the device of the model's parameters.
        tokenizer: tokenizer matching ``model``.
        prefix: prompt string.
        targets: list of candidate answer strings. Only the FIRST token of each
            target is scored, so targets should be single-token words such as
            "true" / "false".

    Returns:
        ``(pred, scores)``: the target with the highest probability (ties broken
        by ``argmax``) and a NumPy array with one probability per target, in the
        order of ``targets``.
    """
    # works for single token label e.g., true or false, yes or no
    label2id = {
        target: tokenizer(target, truncation=True)["input_ids"][0]  # only take the first token
        for target in targets
    }

    # NOTE(review): truncation=True presumably cuts from the right by default, which
    # would drop the query at the end of an over-long prompt — confirm prompt lengths.
    tokenized = tokenizer([prefix], truncation=True, return_tensors="pt")

    # Follow the model's own device instead of assuming CUDA, so the same function
    # works for models kept on the CPU.
    device = next(model.parameters()).device
    input_ids = tokenized.input_ids.to(device)
    attention_mask = tokenized.attention_mask.to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Index batch 0 / last position explicitly. `.squeeze()[-1]` would also drop a
        # length-1 sequence axis and then select a single vocabulary logit instead of
        # the last position's logit vector.
        logits = outputs.logits[0, -1]
        prob = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()

    normalized_scores = [prob[label2id[c]] for c in targets]
    pred = targets[argmax(normalized_scores)]
    return pred, np.array(normalized_scores)

In [18]:
# In-context evaluation of the baseline model on the SST test set.
# For every test sentence there is one prompt per candidate label (same order as
# `targets`); the predicted label is the one whose prompt gives the highest
# probability to the answer "true".
answer_tokens = ["true", "false"]
golds, preds = [], []
pbar = tqdm(zip(few_shot_samples, test_samples), leave=True, total=len(few_shot_samples))
# Iterate the paired data directly: the previous `for id, batch in enumerate(...)`
# shadowed the builtin `id` for the rest of the notebook and never used `batch`.
for prompts, test_sample in pbar:
    # Index 0 of the returned scores is the probability of "true".
    true_probs = [
        calculate_log_prob_gpt(model, tokenizer, prompt, answer_tokens)[1][0]
        for prompt in prompts
    ]

    # np.argmax returns the first maximal index, matching a strict `>` scan from 0.
    pred = targets[int(np.argmax(true_probs))]
    gold = test_sample["label"]

    golds.append(gold)
    preds.append(pred)

# scikit-learn metrics take (y_true, y_pred) in that order.
acc = accuracy_score(golds, preds) * 100
f1 = f1_score(golds, preds, average='macro') * 100
print(f"EVAL SCORE | ACC: {acc} F1: {f1}")

100%|█████████████████████████████████████████████████| 2210/2210 [16:30<00:00,  2.23it/s]

EVAL SCORE | ACC: 63.98190045248868 F1: 60.091651542649736





In [19]:
# In-context evaluation of the factorized model on the SST test set.
# For every test sentence there is one prompt per candidate label (same order as
# `targets`); the predicted label is the one whose prompt gives the highest
# probability to the answer "true".
answer_tokens = ["true", "false"]
golds, preds = [], []
pbar = tqdm(zip(few_shot_samples, test_samples), leave=True, total=len(few_shot_samples))
# Iterate the paired data directly: the previous `for id, batch in enumerate(...)`
# shadowed the builtin `id` for the rest of the notebook and never used `batch`.
for prompts, test_sample in pbar:
    # Index 0 of the returned scores is the probability of "true".
    true_probs = [
        calculate_log_prob_gpt(fact_model, tokenizer, prompt, answer_tokens)[1][0]
        for prompt in prompts
    ]

    # np.argmax returns the first maximal index, matching a strict `>` scan from 0.
    pred = targets[int(np.argmax(true_probs))]
    gold = test_sample["label"]

    golds.append(gold)
    preds.append(pred)

# scikit-learn metrics take (y_true, y_pred) in that order.
acc = accuracy_score(golds, preds) * 100
f1 = f1_score(golds, preds, average='macro') * 100
print(f"EVAL SCORE | ACC: {acc} F1: {f1}")

100%|█████████████████████████████████████████████████| 2210/2210 [15:10<00:00,  2.43it/s]

EVAL SCORE | ACC: 68.68778280542986 F1: 66.80748287980441



