In [1]:
from pathlib import Path
import matplotlib.pyplot as plt
import lightning as L
import torch
import torch.nn as nn
from lit_llama import model
import random
from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup

from datasets import load_dataset
from transformers import AutoTokenizer



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import wandb

In [3]:
fabric = L.Fabric(devices=1)
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model")
tokenizer = Tokenizer(tokenizer_path)



In [4]:
squad = load_dataset("squad", split="train[:5000]")
squad = squad.train_test_split(test_size=0.2)



Found cached dataset squad (/home/andrew/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


{'id': '56ce7bf4aab44d1400b887f5', 'title': 'IPod', 'context': 'The name iPod was proposed by Vinnie Chieco, a freelance copywriter, who (with others) was called by Apple to figure out how to introduce the new player to the public. After Chieco saw a prototype, he thought of the movie 2001: A Space Odyssey and the phrase "Open the pod bay door, Hal!", which refers to the white EVA Pods of the Discovery One spaceship. Chieco saw an analogy to the relationship between the spaceship and the smaller independent pods in the relationship between a personal computer and the music player. Apple researched the trademark and found that it was already in use. Joseph N. Grasso of New Jersey had originally listed an "iPod" trademark with the U.S. Patent and Trademark Office (USPTO) in July 2000 for Internet kiosks. The first iPod kiosks had been demonstrated to the public in New Jersey in March 1998, and commercial use began in January 2000, but had apparently been discontinued by 2001. The tradema

In [5]:
import json
with open('datasets/alpaca_data_cleaned.json') as f:
    alpaca_json = json.load(f)

# Create tokenized j
squad_train = []
squad_test = []

for item in squad['train']:
    squad_train.append(
        {
            'instruction': tokenizer.encode(item['context'], bos=True, eos=False, device=fabric.device),
            'input': tokenizer.encode(item['question'], bos=False, eos=False, device=fabric.device),
            'output':tokenizer.encode(item['answers']['text'][0], bos=False, eos=True, device=fabric.device)
        }
    )


for item in squad['test']:
    squad_test.append(
        {
            'instruction': tokenizer.encode(item['context'], bos=True, eos=False, device=fabric.device),
            'input': tokenizer.encode(item['question'], bos=False, eos=False, device=fabric.device),
            'output':tokenizer.encode(item['answers']['text'][0], bos=False, eos=True, device=fabric.device)
        }
    )

In [6]:
def get_single_example(dataset, index=None):
    if(index is None):
        index = random.sample(range(len(dataset)), k=1)[0]
    # IST
    IST = IST_generator(LLamaModel(dataset[index]['instruction'].unsqueeze(0).to(fabric.device))[1])[:,-1,:]

    # Question
    question = LLamaModel.transformer.wte(dataset[index]['input'].unsqueeze(0).to(fabric.device)).squeeze()

    # Answer fragment
    answer_len = dataset[index]['output'].size(0)
    trunc_len = random.randint(0,answer_len-1)
    #print(answer_len)
    #print(trunc_len)

    truncated_answer = dataset[index]['output'][:trunc_len]
    truncated_answer = LLamaModel.transformer.wte(truncated_answer)
    
    target_tokens = torch.cat([dataset[index]['input'], dataset[index]['output'][:trunc_len+1]])
    #print(tokenizer.decode(target_tokens))

    if(question.dim() == 1):
        question = question.unsqueeze(0)

    if(truncated_answer.dim() == 1):
        truncated_answer = truncated_answer.unsqueeze(0)

    llama_input = torch.cat([IST,question,truncated_answer])
    return llama_input.unsqueeze(0), target_tokens.type(torch.LongTensor).unsqueeze(0)
    
    

In [7]:
train_losses = []
test_losses = []

In [8]:
checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth")
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model")


def load_LLaMA(checkpoint_path):
    with lazy_load(checkpoint_path) as checkpoint:
        name = llama_model_lookup(checkpoint)

        with EmptyInitOnDevice(
                device=fabric.device, dtype=dtype, quantization_mode=None # We won't quantize the weights
        ):
            model = LLaMA.from_name(name)

        model.load_state_dict(checkpoint)
    return model

In [10]:

dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

LLaMA_config = model.LLaMAConfig.from_name('7B')
print('Loading models...')
# Load the LLaMa model and the IST generator (also a LLaMA model)
LLamaModel = load_LLaMA(checkpoint_path).to(fabric.device)
#LLamaModel = LLaMA(LLaMA_config).to(fabric.device)
print('Finished loading the first model')
print('Finished loading models')
tokenizer = Tokenizer(tokenizer_path)

IST_schemes = ['vanilla', 'last 4', '2nd to last', 'all layers']
scheme_losses = {}

IST_generator = model.Block(LLaMA_config)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(IST_generator.parameters(), lr=1e-4)
IST_generator = IST_generator.to(fabric.device)

for param in LLamaModel.parameters():
    param.requires_grad=False

Loading models...
Finished loading the first model
Finished loading models


In [11]:
optimizer = torch.optim.Adam(IST_generator.parameters(), lr=1e-5)

In [12]:
learning_rate = 1e-5
batch_size=32
trainset_size=4000
testset_size=1000

config = {
    'lr': learning_rate,
    'batch_size': batch_size,
    'trainset_size': trainset_size,
    'testset_size':testset_size,
}

In [13]:
# init wandb
wandb.init(
    project='IST QA',
    config=config,
    name="Training LLama on the SQuAD dataset"
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mandrew-zeng[0m ([33msmalllanguagemodels[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [14]:
import re


def filter_string(input_string):
    input_string = input_string.lower()
    filtered_string = re.sub(r'[^a-zA-Z0-9\s]', '', input_string)
    filtered_string = re.sub(r'\n', ' ', filtered_string)
    filtered_string = re.sub(r' +', ' ', filtered_string)
    return filtered_string

# Example usage

def calculate_F1_score(
        model_output: str,
        ground_truth_output: str,
):
    model_output_words = set(filter_string(model_output).split(' '))
    ground_truth_words = set(filter_string(ground_truth_output).split(' '))
    
    shared_words = model_output_words & ground_truth_words
    if(len(shared_words) == 0):
        return 0
    #print(shared_words)

    precision = len(shared_words) / len(model_output_words)
    recall = len(shared_words) / len(ground_truth_words)

    return 2 / (1/recall + 1/precision)


In [15]:
def get_IST(string):
    tokens = tokenizer.encode(string).unsqueeze(0).type(torch.LongTensor).to(fabric.device)
    x = LLamaModel(tokens)[1]
    x = IST_generator(x)
    return x[:,-1,:]

In [16]:
def generate(model, tokenizer, prompt, IST=None, max_new_tokens=200):
  
    generated = ''
    tokenized_input = tokenizer.encode(prompt).to(fabric.device)
    with torch.no_grad():

        for _ in range(max_new_tokens):
            last_logits = model(tokenized_input.unsqueeze(0), IST.type(torch.bfloat16))[0][:,-1,:]
            new_token = torch.argmax(last_logits, dim=1)
            if(new_token == 2 and _ > 0): #eos
                break
            generated += tokenizer.decode(new_token)
            tokenized_input = torch.cat([tokenized_input, new_token])

    return tokenized_input, tokenizer.decode(tokenized_input)[len(prompt)+1:]

In [21]:


total_f1s = 0
for index, item in enumerate(list(squad['test'])[:10] ):

    context = item['context']
    question = item['question']
    answer = item['answers']['text'][0]
    _, out = generate(LLamaModel, tokenizer, question, IST=get_IST(context),max_new_tokens=len(tokenizer.encode(answer)))
    total_f1s += calculate_F1_score(out, answer)


In [27]:
loss_fn = nn.CrossEntropyLoss()
for param in LLamaModel.parameters():
    param.requires_grad=False

batch_size=32


cumulative_batch_num = 0
LLamaModel.eval()

for epoch in range(10):

    indices = list(range(4000))
    random.shuffle(indices)
    epoch_train_loss = 0

    for batch in range(len(indices) // batch_size):
        
        IST_generator.train()
        wandb_log_dict = {'batch': cumulative_batch_num}
        cumulative_batch_num += 1
        
        batch_indices = indices[:batch_size]
        indices = indices[batch_size:]
        batch_loss = 0

        optimizer.zero_grad()
        for i in range(batch_size):
            input, target = get_single_example(squad_train, index=batch_indices[i])
            llama_output = LLamaModel.forward_embeddings(input.type(torch.bfloat16))[0]
            loss = loss_fn(llama_output.squeeze().to(fabric.device), target.squeeze().to(fabric.device))
            loss.backward()
            batch_loss += loss.item()
            del llama_output

        batch_loss /= batch_size

        optimizer.step()
        train_losses.append(batch_loss)
        epoch_train_loss += batch_loss
        wandb_log_dict['batch train loss'] = batch_loss

        # validation:
        IST_generator.eval()
        with torch.no_grad():
            batch_loss = 0

            for i in range(batch_size):
                input, target = get_single_example(squad_test, index=i)
                llama_output = LLamaModel.forward_embeddings(input.type(torch.bfloat16))[0]
                loss = loss_fn(llama_output.squeeze().to(fabric.device), target.squeeze().to(fabric.device))
                del llama_output
                batch_loss += loss.item()
            batch_loss /= batch_size

            test_losses.append(batch_loss)
            wandb_log_dict['batch validation loss'] = batch_loss

            if(batch % 10 == 0):
                total_f1s = 0
                for index, item in enumerate(list(squad['test'])[:10] ):
                    context = item['context']
                    question = item['question']
                    answer = item['answers']['text'][0]
                    _, out = generate(LLamaModel, tokenizer, question, IST=get_IST(context),max_new_tokens=len(tokenizer.encode(answer)))
                    total_f1s += calculate_F1_score(out, answer)
                wandb_log_dict['F1 score'] = total_f1s / 10



        print(wandb_log_dict)
        
        wandb.log(wandb_log_dict)

    
        

{'batch': 0, 'batch train loss': 7.98193359375, 'batch validation loss': 7.71435546875, 'F1 score': 0.045}
{'batch': 1, 'batch train loss': 7.45166015625, 'batch validation loss': 6.794921875}
{'batch': 2, 'batch train loss': 6.45703125, 'batch validation loss': 5.427734375}
{'batch': 3, 'batch train loss': 6.18212890625, 'batch validation loss': 4.7080078125}
{'batch': 4, 'batch train loss': 5.138671875, 'batch validation loss': 4.3916015625}
{'batch': 5, 'batch train loss': 5.0234375, 'batch validation loss': 4.1669921875}
{'batch': 6, 'batch train loss': 4.525390625, 'batch validation loss': 4.12353515625}
{'batch': 7, 'batch train loss': 4.40283203125, 'batch validation loss': 3.9921875}
{'batch': 8, 'batch train loss': 4.251953125, 'batch validation loss': 4.11962890625}
{'batch': 9, 'batch train loss': 4.43017578125, 'batch validation loss': 4.1435546875}
{'batch': 10, 'batch train loss': 4.2275390625, 'batch validation loss': 4.0556640625, 'F1 score': 0.05333333333333333}
{'batc

In [28]:
testing_json = []
exact_match = 0

for index, item in enumerate(squad['test']):
    context = item['context']
    question = item['question']
    answer = item['answers']['text'][0]
    _, out = generate(LLamaModel, tokenizer, question, IST=get_IST(context),max_new_tokens=20)
    #print(out)
    testing_json.append({'context':context, 'question':question, 'model_output': out, 'ground_truth': answer})
    print(out, answer)
    if(out == answer):
        exact_match += 1
        print('exact match found')

his sister and his friends Jane Stirling
The 2018-2019 school year is off to a great The Tai Situpa
$500,000 $250,000.
1869 1883


In [40]:
testing_json = []
exact_match = 0

total_f1s = 0
for index, item in enumerate(squad['test']):

    if(index == 100):
        break
    context = item['context']
    question = item['question']
    answer = item['answers']['text'][0]
    _, out = generate(LLamaModel, tokenizer, question, IST=get_IST(context),max_new_tokens=20)
    #print(out)
    testing_json.append({'context':context, 'question':question, 'model_output': out, 'ground_truth': answer})
    total_f1s += calculate_F1_score(out, answer)
    print(out, answer)
    if(out == answer):
        exact_match += 1
        print('exact match found')
    

his sister and his friends Jane Stirling
The 2018-2019 school year is off to a great The Tai Situpa
$500,000 $250,000.
1869 1883
10,000 135,000
October 29, 2012 October 29, 2012
exact match found
20% 28%
The 2019-2020 school year is off to a great Rise Up
The 2018-2019 school year is off to a great Health care
The 2018-19 school year is off to a great start! not
The Funeral March Revolutionary Étude
iPod nano (1st generation) Nano
2001 2001
exact match found
Bey Hive Beyontourage
"At the Ballet" America the Beautiful
1830 September 1829
1876 1876
exact match found
10 5,335
The 2018-2019 school year is off to a great soldering tools
The 2018-2019 school year is off to a great George Clooney and Wyclef Jean
jazz
What was the name of the first black newspaper in New York? The New York Age
 jazz
St. Mary's College South Bend
The 2018-2019 school year is off to a great The British Library
Rockefeller Center Rockefeller Center
exact match found
100 60
100,000 10,000
100,000 Over 200,000
Beij

In [39]:
len(squad_test)

1000

In [37]:
total_f1s/100

0.19050144300144298

In [32]:
with open('outputs.json', 'r') as f:
    j = json.load(f)

In [34]:
total_f1s = 0
for item in j:
    total_f1s += calculate_F1_score(item['model_output'], item['ground_truth'])

print(total_f1s / len(j))

0.20301104765261133


In [27]:
testing_json

[{'context': "Chopin's public popularity as a virtuoso began to wane, as did the number of his pupils, and this, together with the political strife and instability of the time, caused him to struggle financially. In February 1848, with the cellist Auguste Franchomme, he gave his last Paris concert, which included three movements of the Cello Sonata Op. 65.",
  'question': 'Who did Chopin have at his last Parisian concert in 1848?',
  'model_output': 'Franz Liszt',
  'ground_truth': 'Auguste Franchomme'},
 {'context': 'In 2001, she became the first African-American woman and second woman songwriter to win the Pop Songwriter of the Year award at the American Society of Composers, Authors, and Publishers Pop Music Awards. Beyoncé was the third woman to have writing credits on three number one songs ("Irreplaceable", "Grillz" and "Check on It") in the same year, after Carole King in 1971 and Mariah Carey in 1991. She is tied with American songwriter Diane Warren at third with nine songwrit

In [30]:
exact_match/index

0.16716716716716717

In [41]:
import json

with open('outputs2.json', 'w') as f:
    json.dump(testing_json, f, indent=2)

In [11]:
out = generate(LLamaModel, tokenizer, question, IST=get_IST(context))

In [115]:
squad['test'][index]

{'id': '56bf76ef3aeaaa14008c9667',
 'title': 'Beyoncé',
 'context': 'Beyoncé attended St. Mary\'s Elementary School in Fredericksburg, Texas, where she enrolled in dance classes. Her singing talent was discovered when dance instructor Darlette Johnson began humming a song and she finished it, able to hit the high-pitched notes. Beyoncé\'s interest in music and performing continued after winning a school talent show at age seven, singing John Lennon\'s "Imagine" to beat 15/16-year-olds. In fall of 1990, Beyoncé enrolled in Parker Elementary School, a music magnet school in Houston, where she would perform with the school\'s choir. She also attended the High School for the Performing and Visual Arts and later Alief Elsik High School. Beyoncé was also a member of the choir at St. John\'s United Methodist Church as a soloist for two years.',
 'question': 'Which song did Beyonce sing to win a competition at age 7?',
 'answers': {'text': ['Imagine'], 'answer_start': [385]}}