In [1]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base").to(device)

In [2]:
device

'cuda'

In [3]:
import pandas as pd
df = pd.read_csv('/teamspace/studios/this_studio/NLP_project/openai_reddit_final.csv', sep=';')

In [4]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 83797 entries, 0 to 83796
Data columns (total 7 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   Unnamed: 0    83797 non-null  int64 
 1   user_id       83797 non-null  object
 2   doc_id        83797 non-null  object
 3   user_profile  83797 non-null  object
 4   post/article  83797 non-null  object
 5   summary_text  83797 non-null  object
 6   confidence    83797 non-null  int64 
dtypes: int64(2), object(5)
memory usage: 4.5+ MB


In [5]:
data = df[['doc_id', 'post/article']]

In [6]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 83797 entries, 0 to 83796
Data columns (total 2 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   doc_id        83797 non-null  object
 1   post/article  83797 non-null  object
dtypes: object(2)
memory usage: 1.3+ MB


In [7]:
mapp = {}

for idx , row in data.iterrows():
    doc_id = row['doc_id']
    text = row['post/article']

    if doc_id not in mapp:
        mapp[doc_id] = text

In [8]:
len(mapp)

6218

In [9]:
ids = list(mapp.keys())
len(ids)

6218

In [10]:
unique_texts = list(mapp.values())
len(unique_texts)

6218

In [11]:
text = "summarize: " + df.iloc[0]['post/article']

In [12]:
inputs = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True).to(device)

# Generate the summary
summary_ids = model.generate(inputs, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)

# Decode the generated tokens
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

summary

"the only time i've flirted or dated was as an over-confident, hormone-riddled teenager. i'm no way in a rush to get into a new relationship, but that doesn't mean i want to be completely alone in the meantime."

In [46]:
batch_size = 7

def translate_batch(texts):
    prefixed_texts = ["summarize: " + text for text in texts]
    inputs = tokenizer(prefixed_texts, return_tensors="pt", max_length=512, truncation=True, padding=True).to(device)

    summary_ids = model.generate(inputs['input_ids'], max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
    summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
    return summaries

In [47]:
src_texts = unique_texts[3770:]
len(src_texts)

2448

In [33]:
bs = 0

In [49]:
for i in range(0, len(src_texts), batch_size):
    batch_texts = src_texts[i:i + batch_size]
    if batch_texts == ["*"]:
        results.extend(['Random bs'])
        bs += 1
        continue
    
    translated_texts = translate_batch(batch_texts)
    results.extend(translated_texts)

In [50]:
len(results)

6218

In [51]:
gen_summaries = results

In [52]:
mapp = dict(zip(ids, gen_summaries))

In [53]:
len(mapp)

6218

In [54]:
df['t5_model_summary'] = df['doc_id'].map(mapp)

In [55]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 83797 entries, 0 to 83796
Data columns (total 8 columns):
 #   Column            Non-Null Count  Dtype 
---  ------            --------------  ----- 
 0   Unnamed: 0        83797 non-null  int64 
 1   user_id           83797 non-null  object
 2   doc_id            83797 non-null  object
 3   user_profile      83797 non-null  object
 4   post/article      83797 non-null  object
 5   summary_text      83797 non-null  object
 6   confidence        83797 non-null  int64 
 7   t5_model_summary  83797 non-null  object
dtypes: int64(2), object(6)
memory usage: 5.1+ MB


In [56]:
df.to_csv('openai_reddit_final_t5.csv', sep=';')