In [None]:
import sys 
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch
import pandas as pd
from tqdm.notebook import tqdm
# seed = 42
import random
# # set seed for reproducibility
# torch.manual_seed(seed)

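## Load 7B-it Model (disabled: the cell below is commented out; the quantized 7B-it-quant model loaded next is used instead)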

In [None]:
# # Load the model
# VARIANT = "7b-it" 
# MACHINE_TYPE = "cuda" 
# weights_dir = f'D:\LLMs\gemma-{VARIANT}-weights' 

# @contextlib.contextmanager
# def _set_default_tensor_type(dtype: torch.dtype):
#   """Sets the default torch dtype to the given dtype."""
#   torch.set_default_dtype(dtype)
#   yield
#   torch.set_default_dtype(torch.float)

# model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
# model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

# device = torch.device(MACHINE_TYPE)
# with _set_default_tensor_type(model_config.get_dtype()):
#   model = GemmaForCausalLM(model_config)
#   ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
#   model.load_weights(ckpt_path)
#   model = model.to(device).eval()

## Load 7B-it-quant Model

In [None]:
# Model selection: which Gemma checkpoint to load, on which device, and the
# local directory that holds its tokenizer and checkpoint files.
VARIANT = "7b-it-quant"
MACHINE_TYPE = "cuda"
weights_dir = "gemma-" + VARIANT + "-weights"

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Sets the default torch dtype to the given dtype."""
  torch.set_default_dtype(dtype)
  yield
  torch.set_default_dtype(torch.float)

# Pick the architecture config by variant name, then point it at the local
# tokenizer file and set its `quant` flag when the variant name says "quant".
config_builder = get_config_for_2b if "2b" in VARIANT else get_config_for_7b
model_config = config_builder()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")
model_config.quant = "quant" in VARIANT

device = torch.device(MACHINE_TYPE)
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')

# Build the model while the config's dtype is the torch default, so floating-point
# tensors created in the constructor use that dtype. Then load the checkpoint,
# move the model to the target device and switch it to eval mode.
with _set_default_tensor_type(model_config.get_dtype()):
  model = GemmaForCausalLM(model_config)
  model.load_weights(ckpt_path)
  model = model.to(device).eval()

In [None]:
PATH_TO_DATASET = 'gen_data.xlsx'
# An .xlsx file is a zipped XML workbook, which pd.read_csv cannot parse.
# Choose the reader from the extension so plain CSV inputs still work too.
_dataset_ext = os.path.splitext(PATH_TO_DATASET)[1].lower()
if _dataset_ext in ('.xlsx', '.xlsm', '.xls'):
    df = pd.read_excel(PATH_TO_DATASET)
else:
    df = pd.read_csv(PATH_TO_DATASET)

In [None]:
# The rewrite loop reads the `text` and `prompt` columns. Fail fast here with a
# clear message rather than an AttributeError partway through generation.
_missing_columns = {'text', 'prompt'} - set(df.columns)
assert not _missing_columns, f"dataset is missing required columns: {sorted(_missing_columns)}"
df.head()

In [None]:
# Hand cached-but-unused CUDA memory back before the long generation run.
# On a machine without CUDA there is nothing cached, so the call is skipped.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Gemma chat-turn prompt: the user turn carries the rewrite instruction followed
# by the source text. The model turn is pre-filled with an opening ``` fence so
# the completion can be cut at the closing fence.
new_template = (
    "<start_of_turn>user\n"
    "{rewrite_prompt}\n"
    "{original_text}<end_of_turn>\n"
    "<start_of_turn>model\n"
    "Rewritten Text: ```"
)

In [None]:
# Column-oriented accumulator: one entry per processed example in each list,
# later turned into a DataFrame.
rewrite_data = {
    column: [] for column in ('original_text', 'rewrite_prompt', 'rewritten_text')
}
BATCH_SIZE = 2  # prompts passed to each model.generate call

def check_newline(text):
    """Remove one newline marker from the start and from the end of `text`.

    Two markers are recognised: an escaped newline written out as the two
    characters backslash and 'n', and a real newline character. In this
    notebook the text has already been passed through str.strip(), so real
    newlines are normally gone and the escaped form is the one that matters.

    Args:
        text: a single generated string.

    Returns:
        `text` with at most one of each marker removed from each end.
    """
    # startswith/endswith always compare against the marker's own length. The
    # earlier slice check compared a 2-char slice with a 1-char string and never matched.
    for marker in ('\\n', '\n'):
        if text.startswith(marker):
            text = text[len(marker):]
        if text.endswith(marker):
            text = text[:-len(marker)]
    return text
    
def _clean_generation(generated):
    """Reduce one raw model completion to just the rewritten text.

    The prompt template ends by opening a ``` fence, so everything before the
    first ``` in the completion is the answer. That part is then whitespace-
    stripped and passed through check_newline.
    """
    return check_newline(generated.split('```')[0].strip())


def _rewrite_batch(batch_texts, batch_instructions, output_len=300):
    """Generate rewrites for one batch and append the results to rewrite_data.

    Args:
        batch_texts: source texts for this batch.
        batch_instructions: rewrite instructions, aligned with `batch_texts`.
        output_len: `output_len` passed to model.generate.
    """
    batch_prompts = [
        new_template.format(rewrite_prompt=instruction, original_text=text)
        for text, instruction in zip(batch_texts, batch_instructions)
    ]
    generations = model.generate(batch_prompts, device=device, output_len=output_len)
    rewrite_data['original_text'].extend(batch_texts)
    rewrite_data['rewrite_prompt'].extend(batch_instructions)
    rewrite_data['rewritten_text'].extend(_clean_generation(g) for g in generations)
    # Free cached GPU memory between batches.
    torch.cuda.empty_cache()


# Slice the dataset into consecutive BATCH_SIZE chunks. The last chunk is simply
# shorter, so no separate "leftover" pass is needed.
source_texts = list(df.text.values)
source_instructions = list(df.prompt.values)
for start in tqdm(range(0, len(source_texts), BATCH_SIZE), desc="rewriting batches"):
    stop = start + BATCH_SIZE
    _rewrite_batch(source_texts[start:stop], source_instructions[start:stop])

In [None]:
# Tabulate the generations (one row per example). The first row is shown as a
# plain array, presumably to read the full texts rather than truncated cells.
rewrite_data_df = pd.DataFrame(rewrite_data)
rewrite_data_df.iloc[:1].to_numpy()

In [None]:
# Persist the generated dataset, leaving out the integer row index.
PATH_TO_OUTPUT = 'dataset.xlsx'
rewrite_data_df.to_excel(PATH_TO_OUTPUT, index=False)