pip install --force-reinstall openai==1.8

In [1]:
from datasets import load_dataset
import pandas as pd
from rouge import Rouge
from openai import OpenAI
import json
from settings import *
from utils import preprocess, prompt, score, utils
import os 

from langchain.llms import OpenAI
from langchain import HuggingFaceHub, LLMChain
from langchain.prompts import load_prompt, PromptTemplate
from tqdm import tqdm

from langchain.prompts import PromptTemplate
from langchain import LLMChain
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
import warnings
warnings.filterwarnings("ignore")
                                                             
# Load the run configuration; later cells read config['data_name'].
# NOTE(review): CONFIG_DIR presumably comes from `from settings import *` above —
# confirm (the star import hides where it is defined).
config = utils.load_json(CONFIG_DIR)

In [3]:
# Load the evaluation dataset named in the config. Column 0 (`article`) is the
# document fed to the model and column 1 (`abstract`) is the reference summary.
pubmed = utils.load_data(config['data_name'])
pubmed

Unnamed: 0,article,abstract
0,a review of the literature and an extensive me...,backgrounda review of the literature and an ex...
1,"nathan , as an oncology fellow , knew well tha...",t cells tell macrophages when to start making ...
2,temporary henna tattoos or pseudotattoo have b...,temporary henna tattoos or pseudotattoos have ...
3,care coordination is an important aspect of nu...,introductioncare coordination is an important ...
4,the laparoscopic removal of a cervical stump f...,"a 43-year - old , who underwent a subtotal hys..."
...,...,...
2964,to assess attitudes and practices of documenta...,purposes : to assess attitudes and practices o...
2965,"hearts , isolated from wildtype zebrafish embr...",electrical gradients are critical for many bio...
2966,both f-18 fluorodeoxyglucose ( fdg ) and c-11 ...,"a 10-year - boy post - operative , post - radi..."
2967,monilethrix is an autosomal dominant disorder ...,congenital hypotrichosis may be due to a numbe...


In [4]:
# Load the Mistral-7B instruct checkpoint and its tokenizer.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
# NOTE(review): absolute, machine-specific download cache path — consider making
# it a configurable constant like OUT_DIR / CONFIG_DIR.
cache_dir = "/data/ephemeral/Youtube-Short-Generator/mistral"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME,cache_dir=cache_dir)
# Reuse EOS as the padding token (presumably the tokenizer ships without one — TODO confirm).
tokenizer.pad_token_id = tokenizer.eos_token_id
# NOTE(review): no torch_dtype is passed, so weights presumably load in default
# (float32) precision; confirm this fits in the target GPU's memory.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,cache_dir=cache_dir)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Wrap the model in a text-generation pipeline (at most 300 new tokens per call)
# and expose it to LangChain as `hf`, which the summarization cells below use.
# NOTE(review): device=0 presumably selects the first CUDA GPU — confirm.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=300, device = 0, pad_token_id=tokenizer.eos_token_id)
hf = HuggingFacePipeline(pipeline=pipe)

In [24]:
template = """
<s>[INST]<>You are an abstractive summarizer that follows the output pattern.
Please revise the extracted summary based on the document. The revised summary should include the information in the extracted summary. Original Document: {document}<>[/INST]<\s>.
"""
doc = pubmed.iloc[0, 0]
prompt = PromptTemplate(template=template, input_variables=["document"])
llm_chain = LLMChain(prompt=prompt, llm=hf)
response = llm_chain.invoke(input = doc)

In [6]:
# Preview a Mistral-specific results path. The value is only displayed, not stored;
# get_summarization below writes directly into OUT_DIR, not into this subdirectory.
os.path.join(OUT_DIR, 'mistral')

'/data/ephemeral/Youtube-Short-Generator/results/mistral'

In [11]:
def get_summarization(df,save_name, iter_num = 5, llm = None):
    """Summarize every document in ``df`` ``iter_num`` times and save each pass to CSV.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data. Column 0 (by position) is the document fed to the model;
        column 1 is the reference abstract copied into the output.
    save_name : str
        File prefix. Pass ``i`` is written to ``OUT_DIR/{save_name}_{i}.csv``.
    iter_num : int, default 5
        Number of passes over ``df``. Whether passes differ depends on the
        generation settings of the underlying pipeline.
    llm : optional
        LangChain LLM to use; defaults to the notebook-level ``hf`` pipeline.

    Each CSV has columns ``generate`` (model output) and ``abstract`` (reference).
    Rows whose generated text is empty are dropped.
    """
    template = """
    <s>[INST]<>You are an abstractive summarizer that follows the output pattern.
    Please revise the extracted summary based on the document. The revised summary should include the information in the extracted summary. Original Document: {document}<>[/INST]<\\s>.
    """
    if llm is None:
        llm = hf
    # The chain does not depend on the row, so build it once.
    prompt_template = PromptTemplate(template=template, input_variables=["document"])
    llm_chain = LLMChain(prompt=prompt_template, llm=llm)
    os.makedirs(OUT_DIR, exist_ok=True)
    for i in range(iter_num):
        response_list = []
        for idx in tqdm(range(len(df))):
            doc = df.iloc[idx, 0]
            response = llm_chain.invoke(input = doc)['text']
            if len(response) > 0:
                response_list.append([response, df.iloc[idx, 1]])
        # Keep the results under their own name. Rebinding `df` here would make every
        # pass after the first summarize the previous pass's generated text
        # instead of the source documents.
        results = pd.DataFrame(response_list, columns = ['generate', 'abstract'])
        results.to_csv(os.path.join(OUT_DIR, f"{save_name}_{i}.csv"), index = False)

In [13]:
# Quick end-to-end run on the first two articles; writes OUT_DIR/test_{i}.csv.
test = pubmed.head(2)
get_summarization(test, save_name='test')

100%|██████████| 2/2 [00:28<00:00, 14.39s/it]
100%|██████████| 2/2 [00:25<00:00, 12.77s/it]
100%|██████████| 2/2 [00:25<00:00, 12.98s/it]
100%|██████████| 2/2 [00:25<00:00, 12.83s/it]
100%|██████████| 2/2 [00:25<00:00, 12.76s/it]


In [46]:
def get_rouge_from_df(generate_df, rouge_type = 'rouge-l', metric = 'f'):
    """Return the mean ROUGE score of one generated-summary CSV.

    Parameters
    ----------
    generate_df : str
        File name, relative to ``OUT_DIR``, of a CSV with ``generate`` and
        ``abstract`` columns (the format written by ``get_summarization``).
    rouge_type : str, default 'rouge-l'
        Key selecting the ROUGE variant in the dict returned by
        ``score.get_Rouge_score``.
    metric : str, default 'f'
        Key selecting the statistic within that variant.

    Returns
    -------
    float
        Average of the selected score over all rows, or NaN if the file has no rows.
    """
    df = pd.read_csv(os.path.join(OUT_DIR, generate_df))
    if len(df) == 0:
        return float("nan")
    total = sum(
        score.get_Rouge_score(generated, reference)[rouge_type][metric]
        for generated, reference in zip(df['generate'], df['abstract'])
    )
    # Average over rows. `generate_df` is the file name, so len(generate_df)
    # would be the length of that string, not the number of rows.
    return total / len(df)
        

In [33]:
# Scratch check: show result files whose name contains '4' (e.g. the fifth run, index 4).
list(filter(lambda name: '4' in name, os.listdir(OUT_DIR)))

['test_4.csv']

In [49]:
def get_rouge_from_all_df(save_name, rouge_type = 'rouge-l', metric = 'f'):
    """Average the per-file ROUGE score over every run saved under ``save_name``.

    Matches files named ``{save_name}_{i}.csv`` in ``OUT_DIR``, the naming used by
    ``get_summarization``. Each file's mean score comes from ``get_rouge_from_df``
    with ``rouge_type``/``metric`` forwarded. The file means are then averaged
    with equal weight.

    Raises
    ------
    FileNotFoundError
        If no matching file exists in ``OUT_DIR``.
    """
    prefix = f"{save_name}_"
    # Require `{save_name}_<digits>.csv` exactly. A plain substring test would also
    # match unrelated files (e.g. 'latest_0.csv' for 'test') or non-CSV files.
    file_list = sorted(
        name for name in os.listdir(OUT_DIR)
        if name.startswith(prefix)
        and name.endswith('.csv')
        and name[len(prefix):-len('.csv')].isdigit()
    )
    if not file_list:
        raise FileNotFoundError(f"no '{prefix}<n>.csv' files found in {OUT_DIR}")
    value_list = [
        get_rouge_from_df(file, rouge_type=rouge_type, metric=metric)
        for file in file_list
    ]
    return sum(value_list) / len(value_list)

In [50]:
# Mean ROUGE-L F (the defaults) over all saved 'test' runs; display it so the
# result is visible in the notebook instead of only being assigned.
a = get_rouge_from_all_df('test')
a