In [1]:
# Setup the environment
!pip install -q -U immutabledict sentencepiece 
!git clone https://github.com/google/gemma_pytorch.git
!mkdir /kaggle/working/gemma/
!mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/gemma/

Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 107, done.[K
remote: Counting objects: 100% (39/39), done.[K
remote: Compressing objects: 100% (22/22), done.[K
remote: Total 107 (delta 29), reused 17 (delta 17), pack-reused 68[K
Receiving objects: 100% (107/107), 2.14 MiB | 10.95 MiB/s, done.
Resolving deltas: 100% (56/56), done.


In [2]:
import sys
sys.path.append("/kaggle/working/gemma_pytorch/")

import contextlib
import gc
import os
import pandas as pd
import pickle
import torch

from threading import Thread

from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
from tqdm.auto import tqdm

tqdm.pandas()

torch.cuda.empty_cache()
gc.collect()

0

In [3]:
# Load the model
VARIANT = "7b-it-quant" 
MACHINE_TYPE = "cuda" 
weights_dir = '/kaggle/input/gemma/pytorch/7b-it-quant/2' 

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Sets the default torch dtype to the given dtype."""
  torch.set_default_dtype(dtype)
  yield
  torch.set_default_dtype(torch.float)

# Model Config.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")
model_config.quant = "quant" in VARIANT

models = []

with _set_default_tensor_type(model_config.get_dtype()):
  for i in range(torch.cuda.device_count()):
    model = GemmaForCausalLM(model_config)
    ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
    model.load_weights(ckpt_path)
    device = torch.device(f"cuda:{i}")
    models.append(model.to(device).eval())

  return self.fget.__get__(instance, owner)()


In [5]:
def generate(ot, rp, model, device) -> str:
    USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"

    prompt = (
        USER_CHAT_TEMPLATE.format(
            prompt=f"{ot}\n\n{rp}"
        )
        + "<start_of_turn>model\n"
    )

    return model.generate(
        USER_CHAT_TEMPLATE.format(prompt=prompt),
        device=device
    )

In [6]:
def gen_df(df, model, device):
    df['rewritten_text'] = df.progress_apply(lambda row: generate(ot=row['original_text'], rp=row['rewrite_prompt'], model=model, device=device), axis=1)

In [7]:
original_texts = [pd.read_csv("/kaggle/input/llm-prompt-recovery/test.csv").original_text[0]]
rewrite_prompts = ["Rewrite this as a poem.",
                   "Rewrite this as a rap song.",
                   "Improve that text.",
                   "Rewrite this in a more professional way.",
                   "Simplify this text.",
                   "Translate this into Shakespearean English."]

In [8]:
original_texts_df = pd.DataFrame({'original_text': original_texts})
rewrite_prompts_df = pd.DataFrame({'rewrite_prompt': rewrite_prompts})
data = original_texts_df.merge(rewrite_prompts_df, how='cross')

In [9]:
%%time

threads = []

N = len(data)
half = round(N/2)

sub1 = data[0:half].copy()
sub2 = data[half:N].copy()

t0 = Thread(target=gen_df, args=(sub1,models[0],"cuda:0"))
t1 = Thread(target=gen_df, args=(sub2,models[1],"cuda:1"))

t0.start()
t1.start()

threads.append(t0)
threads.append(t1)

# Wait all threads to finish.
for t in threads:
    t.join()

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

CPU times: user 3min 23s, sys: 30.8 s, total: 3min 54s
Wall time: 1min 59s


In [10]:
output = pd.concat([sub1, sub2])
output.to_csv("output.csv", index=False)

In [11]:
output

Unnamed: 0,original_text,rewrite_prompt,rewritten_text
0,The competition dataset comprises text passage...,Rewrite this as a poem.,"A text rewritten, a gem's delight,\nThe Gemma ..."
1,The competition dataset comprises text passage...,Rewrite this as a rap song.,"Yo, listen up to the tale of text rewritten go..."
2,The competition dataset comprises text passage...,Improve that text.,The text above describes a code competition wh...
3,The competition dataset comprises text passage...,Rewrite this in a more professional way.,The competition dataset presents text passages...
4,The competition dataset comprises text passage...,Simplify this text.,"Sure, here's simplified text:\n\nThe competiti..."
5,The competition dataset comprises text passage...,Translate this into Shakespearean English.,"Sure, here's the translation into Shakespearea..."
