In [4]:
import transformers
import torch
import os
import json
import random
import numpy as np
import argparse
from typing import Dict, List, Optional, Iterator, Callable, Union, Tuple
import datasets
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import sys
import re
from rouge_score import rouge_scorer
from collections import defaultdict
import re

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def get_policy_and_tokenizer(model_name: str):
    """Load a causal LM, its tokenizer, and the prompt template for ``model_name``.

    Args:
        model_name: Short alias of a supported checkpoint, i.e. one of the keys
            of ``model_specs`` below (for example ``'llama2-7b'`` or
            ``'mistral-7b-chat'``).

    Returns:
        A ``(policy, tokenizer, chat_template)`` tuple:
        the model loaded with ``torch_dtype=torch.bfloat16``,
        ``device_map='balanced'`` and ``attn_implementation="flash_attention_2"``;
        a tokenizer with ``padding_side='left'`` whose pad token id is set to
        its EOS token id; and a ``str.format`` template with a single ``{}``
        slot for the prompt text.

    Raises:
        ValueError: If ``model_name`` is not one of the supported aliases.
    """
    # The llama2 and mistral instruction-tuned variants share the same prompt wrapper.
    inst_template = "<s>[INST] {} End your answer as, 'Hence the answer is: '. [/INST]"

    # alias -> (Hugging Face hub path, prompt template)
    model_specs = {
        'llama2-7b': ('NousResearch/Llama-2-7b-hf', "<s> {}"),
        'llama2-7b-chat': ('NousResearch/Llama-2-7b-chat-hf', inst_template),
        'mistral-7b': ('mistralai/Mistral-7B-v0.1', "<s> {}"),
        'mistral-7b-chat': ('mistralai/Mistral-7B-Instruct-v0.1', inst_template),
        'gemma-7b': ('google/gemma-7b', "{}"),
        'gemma-7b-chat': (
            'google/gemma-7b-it',
            "<bos><start_of_turn>user\n {} <end_of_turn>\n<start_of_turn>model\n",
        ),
    }

    # Fail fast with a readable message instead of an UnboundLocalError further down.
    if model_name not in model_specs:
        supported = ", ".join(sorted(model_specs))
        raise ValueError(f"Unknown model_name {model_name!r}; expected one of: {supported}")

    model_path, chat_template = model_specs[model_name]

    policy = transformers.AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map='balanced',
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, padding_side='left')
    # Reuse EOS as the padding token so padded batches can be encoded.
    tokenizer.pad_token_id = tokenizer.eos_token_id

    return policy, tokenizer, chat_template

In [23]:
# Print the parameter count of every supported checkpoint.
# One loop replaces six copy-pasted stanzas. After the loop, `model_name`, `policy`,
# `tokenizer`, `chat_template` and `num_params` stay bound to the last entry
# ('mistral-7b-chat'), and the printed lines appear in the same order as before.
model_names = [
    'llama2-7b',
    'llama2-7b-chat',
    'gemma-7b',
    'gemma-7b-chat',
    'mistral-7b',
    'mistral-7b-chat',
]

for model_name in model_names:
    # Drop the reference to the previously loaded model before loading the next one,
    # so this cell does not keep two checkpoints alive at the same time.
    policy = None
    policy, tokenizer, chat_template = get_policy_and_tokenizer(model_name)
    num_params = policy.num_parameters()
    print(f"Number of parameters in {model_name}: {num_params}")

Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.88s/it]


Number of parameters in llama2-7b: 6738415616


Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.71s/it]


Number of parameters in llama2-7b-chat: 6738415616


Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.24it/s]


Number of parameters in gemma-7b: 8537680896


Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.25it/s]


Number of parameters in gemma-7b-chat: 8537680896


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.28s/it]


Number of parameters in mistral-7b: 7241732096


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.32s/it]


Number of parameters in mistral-7b-chat: 7241732096


In [7]:
# Example statements used as prompts: the generation cells insert each one into
# `chat_template` before tokenizing.
sentences = [
    "Harry Potter is a famous wizard.",
    "Hermione Granger is known for her intelligence.",
    "Ron Weasley is Harry's best friend.",
    "The Hogwarts School of Witchcraft and Wizardry is where Harry studied.",
]

In [20]:
# Load the 'mistral-7b-chat' checkpoint. The returned model, tokenizer and prompt
# template are bound to the notebook globals `policy`, `tokenizer` and `chat_template`.
model_name = 'mistral-7b-chat'
policy, tokenizer, chat_template = get_policy_and_tokenizer(model_name)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.39s/it]


In [21]:
# Greedy-decode (do_sample=False) up to 128 new tokens for every example sentence
# with the currently loaded model, then show the decoded text as the cell result.
tokenizer.pad_token = tokenizer.eos_token  # padding=True below needs a pad token

# NOTE(review): the llama/mistral templates already start with a literal "<s>". If the
# tokenizer also prepends its BOS token, every prompt begins with two BOS tokens.
# Confirm by inspecting encoded['input_ids'] and consider add_special_tokens=False.
prompts = [chat_template.format(x) for x in sentences]

# Call the tokenizer directly: batch_encode_plus is documented as deprecated in favour of __call__.
encoded = tokenizer(prompts, padding=True, return_tensors='pt', max_length=128, truncation=True)
encoded = {k: v.to(policy.device) for k, v in encoded.items()}

generation = policy.generate(**encoded, max_new_tokens=128, do_sample=False, pad_token_id=tokenizer.eos_token_id)
tokenizer.batch_decode(generation, skip_special_tokens=True)

["[INST] Harry Potter is a famous wizard. End your answer as, 'Hence the answer is: '. [/INST] Harry Potter is a famous wizard. He is known for his bravery, intelligence, and leadership skills. He is also known for his friendship with his best friends, Hermione Granger and Ron Weasley. Harry has faced many challenges throughout his life, including battling against the dark lord Voldemort and saving the world from evil. Despite these challenges, he never gave up and always did what was right. Hence the answer is: Harry Potter is a legendary wizard known for his bravery, intelligence, and leadership skills.",
 "[INST] Hermione Granger is known for her intelligence. End your answer as, 'Hence the answer is: '. [/INST] Hermione Granger is known for her intelligence. She consistently outperforms her peers in academic subjects and uses her knowledge to solve complex problems. Her dedication to learning and her ability to think critically have helped her overcome numerous challenges throughou

In [22]:
# Repeat the greedy generation after casting the model to float16.
# policy.half() converts the model's floating-point parameters in place, so `policy`
# stays in float16 for any cell run after this one (it was loaded as bfloat16).
# Presumably this cell exists to compare fp16 output against the bf16 run — confirm.
policy.half()
tokenizer.pad_token = tokenizer.eos_token  # padding=True below needs a pad token

# NOTE(review): the llama/mistral templates already start with a literal "<s>". If the
# tokenizer also prepends its BOS token, every prompt begins with two BOS tokens — confirm.
prompts = [chat_template.format(x) for x in sentences]

# Call the tokenizer directly: batch_encode_plus is documented as deprecated in favour of __call__.
encoded = tokenizer(prompts, padding=True, return_tensors='pt', max_length=128, truncation=True)
encoded = {k: v.to(policy.device) for k, v in encoded.items()}

generation = policy.generate(**encoded, max_new_tokens=128, num_return_sequences=1, do_sample=False, pad_token_id=tokenizer.eos_token_id)
tokenizer.batch_decode(generation, skip_special_tokens=True)

["[INST] Harry Potter is a famous wizard. End your answer as, 'Hence the answer is: '. [/INST] Harry Potter is a famous wizard. He is known for his bravery, intelligence, and leadership skills. He is also known for his friendship with his best friends, Hermione Granger and Ron Weasley. Harry has faced many challenges throughout his life, including the death of his parents, the rise of Voldemort, and the final battle against the Dark Lord. Despite these challenges, Harry has always remained true to his values and has become an inspiration to millions of people around the world. Hence the answer is: Harry Potter is a famous wizard known for his bravery, intelligence, leadership, friendship,",
 "[INST] Hermione Granger is known for her intelligence. End your answer as, 'Hence the answer is: '. [/INST] Hermione Granger is known for her intelligence. She consistently outperforms her peers in academic subjects and uses her knowledge to solve complex problems. Her dedication to learning and h