# N-shot learning with Bloom LLM
This notebook contains code for testing n-shot learning for MGT detection using the LLM Bloom.

## Prepare data (WikiGPT-dataset)

In [1]:
from datasets import load_dataset
import pandas as pd

ds = load_dataset("aadityaubhat/GPT-wiki-intro", split="train")

Using custom data configuration aadityaubhat--GPT-wiki-intro-10ad8b711a5f3880
Found cached dataset csv (C:/Users/andre/.cache/huggingface/datasets/aadityaubhat___csv/aadityaubhat--GPT-wiki-intro-10ad8b711a5f3880/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


In [2]:
ds = ds.train_test_split(test_size=0.33333)
ds

DatasetDict({
    train: Dataset({
        features: ['id', 'url', 'title', 'wiki_intro', 'generated_intro', 'title_len', 'wiki_intro_len', 'generated_intro_len', 'prompt', 'generated_text', 'prompt_tokens', 'generated_text_tokens'],
        num_rows: 100000
    })
    test: Dataset({
        features: ['id', 'url', 'title', 'wiki_intro', 'generated_intro', 'title_len', 'wiki_intro_len', 'generated_intro_len', 'prompt', 'generated_text', 'prompt_tokens', 'generated_text_tokens'],
        num_rows: 50000
    })
})

## Prepare LLM (bloomz-560m)

In [3]:
from transformers import BloomForCausalLM, AutoTokenizer

model = BloomForCausalLM.from_pretrained("bigscience/bloomz-560m", num_labels=2).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")

In [4]:
model_config = {
    'do_sample': True,
    'max_new_tokens': 1,
    'top_k': 1000,
    'top_p': 0.999,
    'temperature': 0.5, 
    'repetition_penalty': 2.4,
}

def complete(prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(inputs["input_ids"], **model_config)    
    return tokenizer.decode(outputs[0][-1])

In [9]:
ds['test'][0]['wiki_intro']

'William Simon U\'Ren (January 10, 1859 – March 8, 1949) was an American lawyer and political activist. U\'Ren promoted and helped pass a corrupt practices act, the presidential primary, and direct election of U.S. senators. As a progressive, U\'Ren championed the initiative, referendum, and recall systems in an effort to bring about a Georgist "Single Tax" on the unimproved value of land, but these measures were also designed to promote democracy and weaken the power of backstage elites.  His reforms in Oregon were widely copied in other states. He supported numerous other reforms, such as the interactive model of proportional representation, which was not enacted. Early life\nWilliam Simon U\'Ren (accent the last syllable) was born on January 10, 1859 in Lancaster, Wisconsin, the son of immigrants from Cornwall, England. Their surname was originally spelled Uren. U\'Ren\'s father, William Richard U\'Ren was a socialist who worked as a blacksmith and emigrated to America owing to diff

In [6]:
import random
max_len_example = 1000

sample = ""
while len(sample) > max_len_example or len(sample) == 0:
    sample = random.choice(ds['test'])['wiki_intro']
len(sample)

873

## n-shot testing

In [7]:
from tqdm import tqdm
from time import sleep
import random

init_prompt = """Following is a text-classification model classifying if a text is human-written or machine-generated. The following text will be classified as either 'True' or 'False', where 'True' means human-written and 'False' is machine-generated:
[EXAMPLES]

Predict:
"[TEXT]"

Prediction (True/False): """

false_pos = 0
false_neg = 0
true_pos = 0
true_neg = 0
na_pred = 0

n = 100     # how many tests to run
n_shot = 2  # how many examples of each label to provide
max_len_example = 1000 # max-length of each example (so we don't run out of VRAM) - should be cleaned in dataset instead but I'm lazy atm

for i in tqdm(range(n)):
    # Sample examples for n-shot learning
    prompt = init_prompt
    examples = "Examples:\n" if n_shot >0 else ""
    for s in range(n_shot*2):
        is_true = s>=n_shot
        label = 'wiki_intro' if is_true else 'generated_intro'
        sample = ""
        while len(sample) > max_len_example or len(sample) == 0:
            sample = random.choice(ds['test'])[label]
        examples += f"\nExample {s+1}\nInput: \"" + sample + "\"\nExpected output: " + str(is_true) + "\n"
    prompt = prompt.replace("[EXAMPLES]", examples)

    # Sample an example for prediction
    is_true = random.randint(0,1) == 1 #(i%2 == 0)
    label = 'wiki_intro' if is_true else 'generated_intro'
    sample = random.choice(ds['test'])[label]
    prompt = prompt.replace("[TEXT]", sample)
    # print(prompt)
    # Predict and evaluate
    prediction = complete(prompt)

    if ('true' in prediction.lower() or 'yes' in prediction.lower()):
        if (is_true):
            true_pos += 1
        else:
            false_pos += 1
    elif ('false' in prediction.lower() or 'none' in prediction.lower() or 'no' in prediction.lower()):
        if (is_true):
            false_neg += 1
        else:
            true_neg += 1
    else:
        print("N/A prediction: " + prediction)
        na_pred += 1

sleep(0.1)
print(f"""
n-shot: {n_shot}
n_samples: {n}

Acc: {(true_pos + true_neg) / n}
False-positive: {false_pos / n}
False-negative: {false_neg / n}
N/A-predictions: {na_pred / n}

Acc (non-N/A): {(true_pos + true_neg) / (n-na_pred)}
False-positive (non-N/A): {false_pos / (n-na_pred)}
False-negative (non-N/A): {false_neg / (n-na_pred)}
""")


100%|██████████| 100/100 [03:32<00:00,  2.12s/it]


n-shot: 2
n_samples: 100

Acc: 0.49
False-positive: 0.19
False-negative: 0.32
N/A-predictions: 0.0

Acc (non-N/A): 0.49
False-positive (non-N/A): 0.19
False-negative (non-N/A): 0.32






In [8]:
n = i
print(f"""
n-shot: {n_shot}
n_samples: {n}

Acc: {(true_pos + true_neg) / n}
False-positive: {false_pos / n}
False-negative: {false_neg / n}
N/A-predictions: {na_pred / n}

Acc (non-N/A): {(true_pos + true_neg) / (n-na_pred)}
False-positive (non-N/A): {false_pos / (n-na_pred)}
False-negative (non-N/A): {false_neg / (n-na_pred)}
""")


n-shot: 2
n_samples: 99

Acc: 0.494949494949495
False-positive: 0.1919191919191919
False-negative: 0.32323232323232326
N/A-predictions: 0.0

Acc (non-N/A): 0.494949494949495
False-positive (non-N/A): 0.1919191919191919
False-negative (non-N/A): 0.32323232323232326

