In [1]:
import jax
# Show the devices JAX can see, so the saved notebook records the hardware used for this run.
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [2]:
import os
from huggingface_hub import hf_hub_download

# Hugging Face Hub repository and file name of the pickled model weights.
repo_id = "divyapatel4/StableBeluga-13B-jax"
filename = "StableBeluga-13B.pickle"

# Download the checkpoint into the given cache directory and keep the returned
# local path; the parameter-loading cell reads the weights from it.
local_filepath = hf_hub_download(
    repo_id=repo_id,
    filename=filename,
    cache_dir='/mnt/mydisk/models',
)


In [3]:
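# Optional one-time setup: uncomment and run once to fetch local copies.
# NOTE: the `Microsoft-PeNS` dataset clone provides the CSV that is read later in this notebook.
# NOTE: the `rm` target below names the 13B folder while the model clone fetches the 7B repo.
# !rm -rf llama2-13B-jax
# !git lfs install
# !git clone https://huggingface.co/divyapatel4/llama2-7B-jax
# !git lfs install
# !git clone https://huggingface.co/datasets/divyapatel4/Microsoft-PeNS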


In [4]:
from huggingface_hub import hf_hub_download
import os
import jax
import jax.numpy as jnp
import jax.random as rand
from transformers import LlamaTokenizer, AutoTokenizer
from tqdm import tqdm
from lib.LLM import Llama
from lib.logits_processing import PresencePenaltyProcessor, TopKSampler, TopPSampler, make_logits_processor
from lib.param_utils import load_params
from lib.multihost_utils import shard_model_params
from lib.seeding import BEST_INTEGER

In [5]:
def load_params_from_disk() -> Llama:
    """Load the checkpoint at ``local_filepath``, cast it to bfloat16 and shard it.

    Loading and the dtype cast run with the first CPU device set as JAX's
    default device; only afterwards is the parameter tree handed to
    ``shard_model_params``.

    Returns:
        The parameter tree returned by ``shard_model_params``.
    """
    # NOTE(review): the checkpoint is a `.pickle` file, presumably unpickled by
    # `load_params` -- only point this at files from a trusted source.
    cpu_device = jax.devices('cpu')[0]
    with jax.default_device(cpu_device):
        params = load_params(local_filepath)
        # `jax.tree_map` is a deprecated alias that newer JAX releases removed;
        # `jax.tree_util.tree_map` is available on old and new versions alike.
        params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
    params = shard_model_params(params)
    return params


# Number of candidates kept by the top-k sampler when the logits processor is built.
top_k = 6

print("..... START LOADING .....")
params = load_params_from_disk()
print('Successfully loaded model parameters!')

..... START LOADING .....


tcmalloc: large alloc 2097160192 bytes == 0x9dc76000 @  0x7f0300442680 0x7f0300463824 0x5e4640 0x63e74d 0x6a71b2 0x550866 0x4738f6 0x5ed4cb 0x63b015 0x58e2e0 0x6e019f 0x6e0427 0x6e2053 0x591890 0x70e39c 0x645bf4 0x5911dc 0x70e39c 0x645bf4 0x5911dc 0x70e39c 0x70e637 0x629a97 0x63b015 0x58e2e0 0x56766e 0x636cc9 0x639a74 0x592245 0x70e39c 0x645bf4
tcmalloc: large alloc 2097160192 bytes == 0x11b4ea000 @  0x7f0300442680 0x7f0300463824 0x5e4640 0x63e74d 0x6a71b2 0x550866 0x4738f6 0x5ed4cb 0x63b015 0x58e2e0 0x6e019f 0x6e0427 0x6e2053 0x591890 0x70e39c 0x645bf4 0x5911dc 0x70e39c 0x645bf4 0x5911dc 0x70e39c 0x70e637 0x629a97 0x63b015 0x58e2e0 0x56766e 0x636cc9 0x639a74 0x592245 0x70e39c 0x645bf4
tcmalloc: large alloc 2097160192 bytes == 0x1984ec000 @  0x7f0300442680 0x7f0300463824 0x5e4640 0x63e74d 0x6a71b2 0x550866 0x4738f6 0x5ed4cb 0x63b015 0x58e2e0 0x6e019f 0x6e0427 0x6e2053 0x591890 0x70e39c 0x645bf4 0x5911dc 0x70e39c 0x645bf4 0x5911dc 0x70e39c 0x70e637 0x629a97 0x63b015 0x58e2e0 0x56766e 0x

Successfully loaded model parameters!


In [6]:
# Root PRNG key for sampling, created with the 'rbg' implementation.
key = rand.key(BEST_INTEGER, impl='rbg')

# Left padding keeps every prompt's last real token in the final column, which
# is the position `generate` reads logits from.
tokenizer = AutoTokenizer.from_pretrained(
    'stabilityai/StableBeluga-13B',
    padding_side='left',
)
# Pad with the EOS token.
tokenizer.pad_token = tokenizer.eos_token

# Processors are chained in the order given: presence penalty first, then top-k sampling.
logits_processor = make_logits_processor(
    PresencePenaltyProcessor(penalty=0.05),
    TopKSampler(top_k=top_k),
)

tokenizer_config.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/411 [00:00<?, ?B/s]

In [7]:
from functools import partial
from typing import NamedTuple

import einops as op
import jax
from jax import Array
import jax.numpy as jnp
import jax.random as rand
from transformers import LlamaTokenizer
from typing import Callable

from lib.LLM import KVCache, Llama, RotaryValues, forward_llama_model, get_rotary_values_at_position, make_rotary_values, model_config_llama2_7B, shift_left_kv_cache
from lib.LLM import check_llama, model_config_llama1_7B, model_config_llama2_13B, model_config_llama2_70B, model_config_llama2_7B, model_config_orca2_13B, model_config_orca2_7B, model_config_llama1_30B, model_config_llama1_13B, model_config_solar_10_7B, zephyr_config_3B,mistral_config_7B

# Model configuration shared by `_generate_first`, `_generate_rest` and `generate`.
model_conf = model_config_llama2_13B

@partial(jax.jit, static_argnames=('logits_processor',))
def _generate_first(params: Llama, seq: Array, attn_mask: Array, logits_processor: Callable, *, rotary_values: RotaryValues, key: Array) -> tuple[Array, Array, Array, KVCache]:
    """Run the full prompt through the model once and pick the first new token.

    Args:
        params: model parameters; `params.model` and `params.lm_head` are used.
        seq: token ids of the padded prompts.
        attn_mask: boolean mask marking the real (non-padding) positions of `seq`.
        logits_processor: callable that selects token ids from the last-position
            logits (static under `jax.jit`).
        rotary_values: rotary values forwarded to the model.
        key: PRNG key handed to `logits_processor`.

    Returns:
        `(seq, attn_mask, selected_token_ids, kv_cache)`, where `seq` and
        `attn_mask` have been rolled one step to the left with the new token
        (respectively `True`) written into the last column, and `kv_cache` has
        been passed through `shift_left_kv_cache`.
    """
    # Pairwise validity of (query, key) positions, restricted to the lower
    # triangle to make it causal, then given the two extra singleton axes.
    pairwise_mask = op.einsum(attn_mask, attn_mask, 'B L1, B L2 -> B L1 L2')
    causal_mask = jnp.tril(pairwise_mask)
    qk_mask = op.rearrange(causal_mask, 'B L1 L2 -> B 1 1 L1 L2')

    config = model_conf._replace(return_kv_cache=True)
    hidden, cache = forward_llama_model(params.model, seq, qk_mask, rotary_values=rotary_values, model_config=config)

    # Only the last position is projected onto the vocabulary.
    last_logits = hidden[:, -1] @ params.lm_head
    next_ids = logits_processor(last_logits, seq=seq, attn_mask=attn_mask, key=key)

    # Drop the left-most column and append the new token on the right.
    shifted_seq = jnp.roll(seq, -1, axis=-1).at[:, -1].set(next_ids)
    shifted_mask = jnp.roll(attn_mask, -1, axis=-1).at[:, -1].set(True)

    return shifted_seq, shifted_mask, next_ids, shift_left_kv_cache(cache)

class GenerationState(NamedTuple):
    """Loop carry threaded through the `jax.lax.while_loop` in `_generate_rest`.

    The field order matters: `_generate_rest` unpacks the state positionally.
    """
    # Token buffer; rolled left each iteration with the newest token written last.
    seq: Array
    # Attention mask; rolled left each iteration with `True` written last.
    attn_mask: Array
    # Tokens chosen in the previous step; they are the model input of the next step.
    selected_token_ids: Array
    # Remaining iteration count; decremented each step, the loop stops at zero.
    max_n_iters: Array
    # Rotary values, carried through the loop unchanged.
    rotary_values: RotaryValues
    # Position passed to `get_rotary_values_at_position`; starts at 0, +1 per step.
    rotary_values_position: Array
    # Key/value cache returned by the model and shifted left after each step.
    kv_cache: KVCache
    # PRNG key; split once per iteration to obtain the sampling subkey.
    key: Array

@partial(jax.jit, static_argnames=('logits_processor',))
def _generate_rest(params: Llama, seq: Array, attn_mask: Array, selected_token_ids: Array, max_n_iters: Array, logits_processor: Callable, *, rotary_values: RotaryValues, kv_cache: KVCache, key: Array) -> Array:
    """Generate the remaining tokens one at a time inside a `jax.lax.while_loop`.

    Args:
        params: model parameters; `params.model` and `params.lm_head` are used.
        seq: token buffer as returned by `_generate_first`.
        attn_mask: attention mask as returned by `_generate_first`.
        selected_token_ids: tokens chosen by the previous step.
        max_n_iters: number of loop iterations to run.
        logits_processor: callable that selects token ids from logits (static
            under `jax.jit`).
        rotary_values: rotary values; one position is looked up per iteration.
        kv_cache: key/value cache as returned by `_generate_first`.
        key: PRNG key, split once per iteration.

    Returns:
        The `seq` field of the final loop state.
    """
    def cond_fun(state: GenerationState) -> Array:
        # Keep looping while the remaining-iterations counter is non-zero.
        return state.max_n_iters.astype(jnp.bool_)

    def body_fun(state: GenerationState) -> GenerationState:
        seq, attn_mask, selected_token_ids, max_n_iters, rotary_values, rotary_values_position, kv_cache, key = state

        # Only the most recently selected token is fed to the model; the
        # earlier positions are supplied through `kv_cache`.
        seq_ = op.rearrange(selected_token_ids, 'B -> B 1')
        qk_mask = op.rearrange(attn_mask, 'B L -> B 1 1 1 L')
        rotary_values_ = get_rotary_values_at_position(rotary_values, rotary_values_position)
        outputs, kv_cache = forward_llama_model(params.model, seq_, qk_mask, rotary_values=rotary_values_, kv_cache=kv_cache, model_config=model_conf._replace(return_kv_cache=True))

        logits = outputs[:, -1] @ params.lm_head
        # Fresh subkey per step; the carried key is replaced by the other half of the split.
        key, subkey = rand.split(key)
        selected_token_ids = logits_processor(logits, seq=seq, attn_mask=attn_mask, key=subkey)

        # Drop the left-most column and append the new token / `True` on the right.
        seq = jnp.roll(seq, -1, axis=-1).at[:, -1].set(selected_token_ids)
        attn_mask = jnp.roll(attn_mask, -1, axis=-1).at[:, -1].set(True)
        kv_cache = shift_left_kv_cache(kv_cache)

        rotary_values_position += 1
        max_n_iters -= 1
        # TODO: early stopping (ayaka's comment). Since the generation continues until it reaches the maximum length,
        # we have to include the EOS token to determine the end of generation.
        return GenerationState(seq, attn_mask, selected_token_ids, max_n_iters, rotary_values, rotary_values_position, kv_cache, key)

    # The rotary position counter starts at 0 and is advanced by `body_fun`.
    rotary_values_position = jnp.array(0, jnp.uint16)
    initial_state = GenerationState(seq, attn_mask, selected_token_ids, max_n_iters, rotary_values, rotary_values_position, kv_cache, key)
    final_state = jax.lax.while_loop(cond_fun, body_fun, initial_state)
    return final_state.seq

def generate(sentences: list[str], tokenizer: LlamaTokenizer, params: Llama, logits_processor: Callable, *, max_len: int, key: Array) -> list[str]:
    """Generate continuations for a batch of prompts and decode them to text.

    Prompts are tokenized with padding/truncation to `max_len`. The first new
    token comes from `_generate_first`; `_generate_rest` then runs for as many
    iterations as the smallest left-padding length in the batch.

    Args:
        sentences: prompts to continue.
        tokenizer: tokenizer used for encoding and decoding; the left-padding
            length is read from the first `True` in each attention-mask row.
        params: model parameters.
        logits_processor: callable that selects the next token ids.
        max_len: padded sequence length.
        key: PRNG key for sampling.

    Returns:
        One decoded string per prompt, special tokens included.

    Raises:
        AssertionError: if any tokenized prompt fills all `max_len` positions.
    """
    n_sentences = len(sentences)

    encoded = tokenizer(sentences, padding='max_length', truncation=True, max_length=max_len, return_tensors='jax')
    token_ids = encoded.input_ids.astype(jnp.uint16)
    mask = encoded.attention_mask.astype(jnp.bool_)
    assert not mask.all(axis=-1).any(), 'No room for generation since the length of a sentence is greater than `max_length`.'

    # Index of the first real token in each row, i.e. the number of left pads.
    pad_counts = mask.argmax(axis=-1).astype(jnp.uint16)
    rotary = make_rotary_values(pad_counts, n_sentences, max_len, model_config=model_conf)
    n_remaining = pad_counts.min()

    key, first_key = rand.split(key)
    token_ids, mask, next_ids, cache = _generate_first(params, token_ids, mask, logits_processor, rotary_values=rotary, key=first_key)

    _, rest_key = rand.split(key)
    token_ids = _generate_rest(params, token_ids, mask, next_ids, n_remaining, logits_processor, rotary_values=rotary, kv_cache=cache, key=rest_key)

    # Special tokens are kept in the decoded text; that is the only reason this
    # function is written by hand.
    return tokenizer.batch_decode(token_ids, skip_special_tokens=False)

In [8]:
import ast
import pandas as pd

# Tab-separated file under the local `Microsoft-PeNS` directory.
pers_csv_path = 'Microsoft-PeNS/personalization/pers_preprocessed.csv'
pers = pd.read_csv(pers_csv_path, sep='\t').assign(
    # `context` holds the text of a Python literal; parse it back into an object.
    context=lambda frame: frame['context'].apply(ast.literal_eval)
)
pers.head()

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


Unnamed: 0.1,Unnamed: 0,userID,clicknewsID,posnewID,rewrite_titles,context,News body,Category,Topic,Headline,Title entity
0,0,NT1,"['N108480', 'N38238', 'N35068', 'N110487', 'N9...",N24110,Legal battle looms over Trump EPA's rule chang...,[Nike faces backlash after pulling 'Betsy Ross...,Democratic state attorney generals and environ...,news,newspolitics,High-stakes legal fight looms over Trump pollu...,{'Trump': 'Donald Trump'}
1,1,NT1,"['N108480', 'N38238', 'N35068', 'N110487', 'N9...",N62769,Wise choices for stylish updating of old homes,[Nike faces backlash after pulling 'Betsy Ross...,We love old houses. Their architectural styles...,lifestyle,lifestylehomeandgarden,The One Thing That Immediately Makes Your Hous...,{}
2,2,NT1,"['N108480', 'N38238', 'N35068', 'N110487', 'N9...",N36186,Verlander may be reconsidering his stance on M...,[Nike faces backlash after pulling 'Betsy Ross...,Justin Verlander made headlines earlier in the...,sports,baseball_mlb,Justin Verlander got 'chewed out' by MLB befor...,"{'Verlander': 'Justin Verlander', 'MLB': 'Nati..."
3,3,NT1,"['N108480', 'N38238', 'N35068', 'N110487', 'N9...",N101669,Infamous o.j. Simpson launching official Twitt...,[Nike faces backlash after pulling 'Betsy Ross...,LOS ANGELES O.J. Simpson launched a Twitter ...,tv,tvnews,OJ Simpson on Twitter: 'I got a little gettin'...,{}
4,4,NT1,"['N108480', 'N38238', 'N35068', 'N110487', 'N9...",N19241,15 year old cori gauff beats Venus Williams at...,[Nike faces backlash after pulling 'Betsy Ross...,"WIMBLEDON, England (AP) Coco Gauff grew up a...",sports,tennis,"Gauff, just 15, shocks 5-time champ Venus, 39,...",{'Venus': 'Venus Williams'}


In [9]:
import random

def prompt_generator_withcontext(news_body, context_list, context_no = 20, max_tokens=2000):
    """Build the personalised-headline prompt, including the user's reading history.

    The prompt is `instructions + user preferences + article + task`. The
    article is truncated so that the token counts of the fixed parts plus the
    article stay within `max_tokens`, as measured with the module-level
    `tokenizer`.

    Args:
        news_body: article text.
        context_list: headlines the user read; only the first `context_no`
            entries are used, joined with ", ".
        context_no: maximum number of context headlines to include.
        max_tokens: token budget for the prompt.

    Returns:
        The assembled prompt string.
    """
    selected_context = context_list[:context_no]
    context_body = ', '.join(selected_context)
    prompt_start = f"""Use the preferences of the user from the news headlines user is interested in and then see the news body of the article and generate a personalized headline for the given article news body.
    
    User preferences (Given as the titles of news articles the user reads or prefers): {context_body}
    
    News Article: """

    prompt_end = "\n Your task is to generate a professional and personalized news article headline for the above article, using the user preferences given at the beginning. The headline should be concise, accurate, and engaging. Do not provide any explanations or reasons for the headline. JUST GIVE ONE SINGLE HEADLINE NOTHING ELSE. \n\n Personalized News Headline for above article :"
    start_tokens = len(tokenizer.encode(prompt_start))
    end_tokens = len(tokenizer.encode(prompt_end))
    # Clamp at zero: a negative bound would make the slice below drop tokens
    # from the END of the article, keeping most of it instead of none of it.
    remaining_tokens = max(0, max_tokens - start_tokens - end_tokens)
    news_body_tokens = tokenizer.encode(news_body)[:remaining_tokens]
    truncated_news_body = tokenizer.decode(news_body_tokens)
    # Remove the '<s>' / '</s>' markers from the decoded article text.
    truncated_news_body = truncated_news_body.replace('<s>', '').replace('</s>', '')

    return prompt_start + truncated_news_body + prompt_end


def prompt_generator_nocontext(news_body, context_list, context_no = 20, max_tokens=2000):
    """Build the headline prompt WITHOUT the user's reading history.

    The prompt is `article header + article + task`. The article is truncated
    so that the token counts of the fixed parts plus the article stay within
    `max_tokens`, as measured with the module-level `tokenizer`.

    Args:
        news_body: article text.
        context_list: unused; accepted so the signature matches
            `prompt_generator_withcontext`.
        context_no: unused; accepted for the same reason.
        max_tokens: token budget for the prompt.

    Returns:
        The assembled prompt string.
    """
    prompt_start = """
    News Article: """

    prompt_end = "\n Your task is to generate a professional and personalized news article headline for the above article, using the user preferences given at the beginning. The headline should be concise, accurate, and engaging. Do not provide any explanations or reasons for the headline. JUST GIVE ONE SINGLE HEADLINE NOTHING ELSE. \n\n Personalized News Headline for above article : "
    start_tokens = len(tokenizer.encode(prompt_start))
    end_tokens = len(tokenizer.encode(prompt_end))
    # Clamp at zero: a negative bound would make the slice below drop tokens
    # from the END of the article, keeping most of it instead of none of it.
    remaining_tokens = max(0, max_tokens - start_tokens - end_tokens)
    news_body_tokens = tokenizer.encode(news_body)[:remaining_tokens]
    truncated_news_body = tokenizer.decode(news_body_tokens)
    # Remove the '<s>' / '</s>' markers from the decoded article text.
    truncated_news_body = truncated_news_body.replace('<s>', '').replace('</s>', '')

    return prompt_start + truncated_news_body + prompt_end


In [10]:
import re
def extract_headline(text):
    """Return the text between the headline marker and the next `</s>`.

    Searches `text` for "Personalized News Headline for above article :" and
    captures everything (newlines included) up to the first following `</s>`.

    Returns:
        The captured text with surrounding whitespace stripped, or `None` when
        the marker or the closing `</s>` is missing.
    """
    found = re.search(
        r"Personalized News Headline for above article :(.*?)</s>",
        text,
        flags=re.DOTALL,
    )
    return found.group(1).strip() if found else None

In [11]:
from huggingface_hub import hf_hub_download
import os
import jax
import jax.numpy as jnp
import jax.random as rand
from transformers import LlamaTokenizer, AutoTokenizer
from tqdm import tqdm
from lib.LLM import Llama
from lib.logits_processing import PresencePenaltyProcessor, TopKSampler, TopPSampler, make_logits_processor
from lib.param_utils import load_params
from lib.multihost_utils import shard_model_params
from lib.seeding import BEST_INTEGER

In [12]:
# Preview the first rows of `pers` again before the generation loop fills in the new columns.
pers.head()

Unnamed: 0.1,Unnamed: 0,userID,clicknewsID,posnewID,rewrite_titles,context,News body,Category,Topic,Headline,Title entity
0,0,NT1,"['N108480', 'N38238', 'N35068', 'N110487', 'N9...",N24110,Legal battle looms over Trump EPA's rule chang...,[Nike faces backlash after pulling 'Betsy Ross...,Democratic state attorney generals and environ...,news,newspolitics,High-stakes legal fight looms over Trump pollu...,{'Trump': 'Donald Trump'}
1,1,NT1,"['N108480', 'N38238', 'N35068', 'N110487', 'N9...",N62769,Wise choices for stylish updating of old homes,[Nike faces backlash after pulling 'Betsy Ross...,We love old houses. Their architectural styles...,lifestyle,lifestylehomeandgarden,The One Thing That Immediately Makes Your Hous...,{}
2,2,NT1,"['N108480', 'N38238', 'N35068', 'N110487', 'N9...",N36186,Verlander may be reconsidering his stance on M...,[Nike faces backlash after pulling 'Betsy Ross...,Justin Verlander made headlines earlier in the...,sports,baseball_mlb,Justin Verlander got 'chewed out' by MLB befor...,"{'Verlander': 'Justin Verlander', 'MLB': 'Nati..."
3,3,NT1,"['N108480', 'N38238', 'N35068', 'N110487', 'N9...",N101669,Infamous o.j. Simpson launching official Twitt...,[Nike faces backlash after pulling 'Betsy Ross...,LOS ANGELES O.J. Simpson launched a Twitter ...,tv,tvnews,OJ Simpson on Twitter: 'I got a little gettin'...,{}
4,4,NT1,"['N108480', 'N38238', 'N35068', 'N110487', 'N9...",N19241,15 year old cori gauff beats Venus Williams at...,[Nike faces backlash after pulling 'Betsy Ross...,"WIMBLEDON, England (AP) Coco Gauff grew up a...",sports,tennis,"Gauff, just 15, shocks 5-time champ Venus, 39,...",{'Venus': 'Venus Williams'}


In [None]:
batch_size = 16

# One entry per experimental condition, processed in this order:
# (output column in `pers`, prompt builder used to create the prompts).
headline_runs = (
    ('generated_headline_withcontext', prompt_generator_withcontext),
    ('generated_headline_nocontext', prompt_generator_nocontext),
)

# Create/reset both output columns up front, before any generation starts.
for column, _ in headline_runs:
    pers[column] = ''

for column, build_prompt in headline_runs:
    # Positional column index, so results can be written with `.iloc`.
    column_pos = pers.columns.get_loc(column)

    for i in range(0, len(pers), batch_size):
        batch = pers.iloc[i:i+batch_size]

        prompts = [
            build_prompt(news_body, context, context_no=25, max_tokens=1900)
            for news_body, context in zip(batch['News body'], batch['context'])
        ]

        # Fresh subkey per batch; `key` itself keeps advancing across batches and runs.
        key, subkey = rand.split(key)

        outputs = generate(prompts, tokenizer, params, logits_processor, max_len=2048, key=subkey)
        for j, output in enumerate(outputs):
            # Extract once and reuse for both the log line and the stored value.
            headline = extract_headline(output)
            print(headline)
            print('\n')
            print('============================================================================')
            # `i + j` is a row POSITION, so write with `.iloc`; `.loc` would treat
            # it as an index label and could hit the wrong row (or append a new
            # one) whenever `pers` does not have the default 0..n-1 index.
            pers.iloc[i + j, column_pos] = headline

        # Checkpoint after every batch so an interrupted run keeps its progress.
        pers.to_csv('StableBeluga-13B.csv', sep='~', index=False)
    

Token indices sequence length is longer than the specified maximum sequence length for this model (5506 > 4096). Running this sequence through the model will result in indexing errors


Trump Administration Sued Over Rollback of Power Plant Pollution Rule


"Update Your Outdated Kitchen Countertops Without Breaking the Bank"


MLB officials scold Justin Verlander for accusing MLB of juicing balls


"O.J. Simpson Launches Twitter Account, Plans to 'Straighten Some Things Out'"


Teen Tennis Sensation Cori "Coco" Gauff Stuns Venus Williams at Wimbledon


Vermont Leads the Way in Promoting Literacy Skills Across the United States


Eagles' Front Office Announcements, Offensive Weapons, and More Discussed in BGN Radio Episode 59


Retirees, Maximize Your Golden Years with Smart Money Moves and Travel Experiences


Theme Park Express Pass vs. Line Wait Times


CBS News Anchor Angie Miles Says Goodbye for Second Time


Apple's Jony Ive Leaves Apple to Launch Independent Design Company LoveFrom


"Iranian Ships Vanish and Transfer Oil to Skirt Sanctions, Re-Appearing Weeks Later"


Baltimore Housing Complex Residents Suffer from Continuing Water Issues


"Astronauts Prepare 