In [1]:
import pandas as pd
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

### Preparing the data

1. Fetching the GPT-wiki-intro dataset
2. Extracting the first 1000 prompts
3. Creating a dataset with the 1000 prompts, their ids and the generated text. The generated text is empty for now.

In [None]:
# Load the GPT-wiki-intro dataset through `datasets.load_dataset`.
dataset = load_dataset("aadityaubhat/GPT-wiki-intro")

In [6]:
# Peek at the first record of the train split to see which fields are available.
dataset['train'][0]

{'id': 63064638,
 'url': 'https://en.wikipedia.org/wiki/Sexhow%20railway%20station',
 'title': 'Sexhow railway station',
 'wiki_intro': "Sexhow railway station was a railway station built to serve the hamlet of Sexhow in North Yorkshire, England. The station was on the North Yorkshire and Cleveland's railway line between  and , which opened in 1857. The line was extended progressively until it met the Whitby & Pickering Railway at . Sexhow station was closed in 1954 to passengers and four years later to goods. The station was located  south of Stockton, and  west of Battersby railway station. History\nThe station was opened in April 1857, when the line from Picton was opened up as far as . Mapping shows the station to have had three sidings in the goods yard, coal drops and a crane. The main station buildings were on the westbound (Picton direction) side of the station. The station was south of the village that it served, and was actually in the parish of Carlton in Cleveland, which ha

In [27]:
# Convert the train split into a pandas DataFrame for the row/column filtering below.
base_df = dataset['train'].to_pandas()

In [29]:
# Keep only the rows whose index label is below 1000 (i.e. the first 1000 prompts).
base_df = base_df.loc[base_df.index < 1000]

In [31]:
# Narrow the frame to the two columns needed for generation: the record id and the prompt.
base_df = base_df.loc[:, ['id', 'prompt']]

In [None]:
# Add an empty placeholder column that will later hold the model output.
base_df = base_df.assign(generated='')

In [36]:
# Persist the prompt table so later sections can start from the file.
# index=False: the index is just the row position; writing it would come back
# as a spurious "Unnamed: 0" column when the file is read again.
base_df.to_csv('base.csv', index=False)

### Parameters

In [27]:
# Model identifier passed to `from_pretrained` for both the model and the tokenizer.
# Identifiers noted so far (presumably the candidates to run this notebook with):
#   facebook/opt-1.3b
#   facebook/opt-2.7b
#   facebook/opt-125m
#   meta-llama/Llama-2-7b-chat-hf
#   meta-llama/Llama-2-13b-chat-hf
# NOTE(review): the value is empty here - set it before running the loading cells below.
MODEL_PATH = ""

## Preparing quantization configuration

In [28]:
# 4-bit quantization settings handed to `from_pretrained` via `quantization_config`.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,  # load the weights in 4-bit precision
    bnb_4bit_quant_type="nf4",  # use the 4-bit NormalFloat quantization type
    bnb_4bit_use_double_quant=True, # enable double (nested) quantization
    bnb_4bit_compute_dtype=torch.float16,  # run computations in float16 rather than 4-bit
)

## Loading model

In [30]:
# Load the causal LM named by MODEL_PATH with the quantization config defined above.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    quantization_config=quant_config,
    device_map="auto",  # let the library decide device placement
    # use_flash_attention_2=True  (currently disabled)
)

### Loading tokenizer

In [31]:
# Load the tokenizer that matches MODEL_PATH.
# NOTE(review): trust_remote_code=True is set here but not on the model load above;
# confirm it is actually needed for the models this notebook targets.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH, trust_remote_code=True
)

Loading the tokenizer from the `special_tokens_map.json` and the `added_tokens.json` will be removed in `transformers 5`,  it is kept for forward compatibility, but it is recommended to update your `tokenizer_config.json` by uploading it again. You will see the new `added_tokens_decoder` attribute that will store the relevant information.


### Load base dataframe

In [32]:
# Reload the prompt table from disk.
# The `generated` column is completely empty in the file, so pandas would infer
# float64 (all NaN) for it. Writing strings into a float column later triggers
# pandas' incompatible-dtype upcast warning (rejected outright by newer pandas),
# so the column is pinned to a string dtype up front.
base_df = pd.read_csv('base.csv', dtype={'generated': str})

### Helper functions

- `generate_text` - generates text for a given prompt
- `extract_response` - extracts the response from the generated text

### Start generating text

In [33]:
def generate_text(promt: str, max_length: int = 500):
    """Generate text for ``promt`` with the notebook-level model and tokenizer.

    Uses the module-level ``tokenizer`` and ``base_model`` objects.

    Args:
        promt: The prompt text to feed to the model.
        max_length: Value forwarded to ``generate`` as ``max_length``.
            Defaults to 500, the value that used to be hard-coded.

    Returns:
        The first returned sequence decoded to a string, with special tokens
        skipped.
    """
    input_ids = tokenizer.encode(promt, return_tensors="pt")
    # The model is placed by `device_map="auto"`, so move the encoded prompt to
    # the model's device instead of leaving it wherever the tokenizer put it.
    input_ids = input_ids.to(base_model.device)
    generated_ids = base_model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
        # TODO: sampling options left for a later experiment:
        # do_sample=True,
        # top_k=50,
        # top_p=0.95,
    )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

In [34]:
def extract_response(text: str):
    """Return the segment between the first and second ``'\\n    '`` separators.

    If the separator occurs only once, everything after it is returned.

    Raises:
        IndexError: if ``text`` does not contain the separator at all.
    """
    separator = '\n    '
    _, found, remainder = text.partition(separator)
    if not found:
        # Same failure the list-indexing form produces when there is no second piece.
        raise IndexError("list index out of range")
    response, _, _ = remainder.partition(separator)
    return response

In [None]:
# Generate a completion for every prompt and store the extracted response in
# the `generated` column. A failing row does not abort the run: its exception
# is collected in `errors` and a summary is printed at the end.
errors = []
for index, row in tqdm(
    iterable=base_df.iterrows(),
    desc=f"Processing model {MODEL_PATH}",
    total=base_df.shape[0],
):
    try:
        # generate_text tokenizes the prompt and runs the model itself. The
        # extra encode/generate pass that used to sit here was never used and
        # only doubled the per-row generation cost.
        generated_text = generate_text(row['prompt'])
        base_df.loc[index, 'generated'] = extract_response(generated_text)
    except Exception as e:
        # Best effort: remember the failure and move on to the next prompt.
        errors.append(e)

if errors:
    # Surface swallowed failures so empty `generated` cells are not a surprise.
    print(f"{len(errors)} of {base_df.shape[0]} prompts failed; inspect `errors` for details")

In [13]:
# Name the results file after the model that produced it (last path component
# of MODEL_PATH) so runs with different models do not overwrite or mislabel
# each other. For "meta-llama/Llama-2-13b-chat-hf" this yields the previously
# hard-coded "Llama-2-13b-chat-hf.csv".
model_name = MODEL_PATH.rstrip('/').split('/')[-1] or 'generated'
base_df.to_csv(f'{model_name}.csv')