In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
import torch
import torch.nn as nn
import time
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP

# check if GPU is available
# os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load data

In [2]:
import pandas as pd

# file path
path = './nbroad-v1.csv'
# read file
df = pd.read_csv(path)
df.head()

Unnamed: 0,id,original_text,rewrite_prompt,rewritten_text
0,LNpAovroGe,"This quilt, that my mother made, \n \n Still m...",Regency Romance: Model the text on a Regency r...,"The softest brown and brightest blue quilt, cr..."
1,nnuxwwThWi,It's the job of our agency to keep track of th...,Write like Ernest Hemingway: Focus on Hemingwa...,The agency's responsibility is to track and co...
2,aYmnFCsjKl,"The first punch gets me right in the ribs, kno...",Grimm's Fairy Tales: Adapt the text to mimic t...,"In the sweltering sun, the stench of sweat and..."
3,ufIVkreRND,Some nights I lay awake staring at the ceiling...,High Fantasy Epic: Transform the essay into a ...,In the tapestry of the ethereal realm of Eldri...
4,XwLNuYdDdE,"I can hardly read the letter, because the hand...",Fairy Tale Villain: Use the menacing and craft...,"My hand quivered as I clutched the letter, the..."


In [3]:
original_text = list(df['original_text'])
rewritten_text = list(df['rewritten_text'])
rewrite_prompt = list(df['rewrite_prompt'])

In [4]:
# test
original_text[0]
# rewritten_text[0]
# rewrite_prompt[0]

"This quilt, that my mother made, \n \n Still makes me think to this day. \n \n It's softest brown, and brightest blue, \n \n The curved stitch here, reads `` made it May''. \n \n It's hard to see, but believe me it's true, \n \n That's not just a cloth but a piece of shirt. \n \n You can see a logo here, and right there, \n \n And a signature over there, someone named `` Bert''. \n \n This is my favorite part, a piece from a stuffed bear. \n \n I think it was my mother's favorite too, \n \n She always said so at least. \n \n Something from when she was two, \n \n Given by her grandad for Thanksgiving feast. \n \n My dad added this, a little button pin, \n \n Something from his mother, for being a scout. \n \n Apparently she went to a store and fished in a bin, \n \n Until night that day, to teach him what love was about. \n \n I'm sorry you had to see this, \n \n but their funeral was delayed. \n \n \n \n\n"

# Load model and tokenizer

In [5]:
model = nn.DataParallel(AutoModel.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")).cuda()

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [6]:
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")

# Generate embeddings

In [7]:
def get_embeddings(sentences, model, tokenizer, batch_size=64, max_length=128):
    list_embeddings = []
    loader = DataLoader(sentences, batch_size=batch_size)  
    for batch in tqdm(loader):
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length).to(device)
    #     inputs = inputs.to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        embeddings = outputs.last_hidden_state.mean(dim=1)
        list_embeddings.append(embeddings.cpu())
    return torch.cat(list_embeddings, dim=0)

In [8]:
original_text_embeddings = get_embeddings(original_text, model, tokenizer, batch_size=32)

100%|███████████████████████████████████████████| 68/68 [01:38<00:00,  1.45s/it]


In [9]:
original_text_embeddings[0:5]

tensor([[-0.2382,  0.6332, -0.1373,  ..., -0.7363,  0.5256,  1.0120],
        [ 0.0135,  2.1396,  0.4942,  ..., -0.9785, -2.5845, -2.2923],
        [ 0.4211,  0.9206, -1.2589,  ...,  0.6981, -1.6439,  2.0091],
        [ 2.1255, -0.5315,  1.8407,  ...,  2.6630, -0.4077,  1.7946],
        [ 0.1307,  0.1351, -1.2011,  ..., -0.1316, -0.0703,  0.9054]])

In [10]:
torch.save(original_text_embeddings, 'original_text_embeddings.pt')

In [11]:
rewritten_text_embeddings = get_embeddings(rewritten_text, model, tokenizer, batch_size=32)

100%|███████████████████████████████████████████| 68/68 [01:35<00:00,  1.40s/it]


In [12]:
rewritten_text_embeddings[0:5]

tensor([[ 1.0559,  1.5876,  1.2212,  ...,  0.0045,  0.5332, -1.1193],
        [ 2.2752,  0.8676,  2.2214,  ...,  1.5496,  0.2140, -1.3494],
        [-0.0169,  2.1337,  2.4846,  ...,  0.6680,  2.0917, -1.8365],
        [ 1.6355,  0.1889,  4.1800,  ...,  0.4820, -0.5707, -1.1147],
        [ 0.1879,  1.4658,  2.3471,  ...,  2.6584,  0.0825, -2.3326]])

In [13]:
torch.save(rewritten_text_embeddings, 'rewritten_text_embeddings.pt')

In [14]:
rewrite_prompt_embeddings = get_embeddings(rewrite_prompt, model, tokenizer, batch_size=32)

100%|███████████████████████████████████████████| 68/68 [00:50<00:00,  1.35it/s]


In [15]:
rewrite_prompt_embeddings[0:5]

tensor([[ 1.3169, -0.3529, -0.3525,  ..., -0.9691, -1.4382,  1.9463],
        [ 3.5835,  0.2232,  0.3202,  ..., -1.6859, -1.9136,  4.2549],
        [-0.2761,  0.1357,  0.8542,  ..., -1.6118, -2.8962, -2.1346],
        [ 2.4610,  0.1310, -1.2601,  ..., -2.3946, -0.8320,  2.4909],
        [ 1.6543,  0.3366,  1.1696,  ...,  0.4007, -0.0358,  1.1666]])

In [16]:
torch.save(rewrite_prompt_embeddings, 'rewrite_prompt_embeddings.pt')

# Predict prompts

In [None]:
predicted_prompts = []

In [25]:
# Forward pass through the model
output = model(inputs_embeds=torch.randn(1, 2, 4096).cuda())

In [30]:
output.last_hidden_state.shape

torch.Size([1, 2, 4096])