In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BitsAndBytesConfig
import os
from accelerate import PartialState


# if torch.cuda.is_available():
#     torch.set_default_device("cuda")
# else:
#     torch.set_default_device("cpu")

os.environ['HF_TOKEN'] = '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'
os.environ['HF_HOME'] = '/data_vault/hexai/huggingface/hub/'

model_type = 't5-small' # orca13b
model_id = "google-t5/t5-small"

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_quant_type="nf8",
    bnb_8bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'], cache_dir=os.environ['HF_HOME'], use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id, quantization_config=bnb_config, device_map="auto", token=os.environ['HF_TOKEN'], cache_dir=os.environ['HF_HOME']
)

# model = AutoModelForCausalLM.from_pretrained(f"nlp/model/{model_type}", device_map="cuda:1", torch_dtype=torch.float16)
# tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'], cache_dir=os.environ['HF_HOME'])

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [2]:
import pandas as pd
data = "/data_vault/hexai/Biolaysum/biolaysumm2024_data/eLife_val.jsonl"
elife_train = pd.read_json(path_or_buf=data, lines=True)

In [3]:
from transformers import pipeline
from langchain import HuggingFacePipeline, PromptTemplate, LLMChain

text_generation_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text2text-generation",
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    batch_size=4,
    repetition_penalty=1.1,
    max_new_tokens=600,
    temperature = 0.3,
    do_sample=True,
)
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

In [4]:
import tiktoken
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens

In [5]:
from langchain.chains import MapReduceDocumentsChain, LLMChain, ReduceDocumentsChain, StuffDocumentsChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate

In [6]:
from langchain.document_loaders import JSONLoader

def metadata_func(record: dict, metadata: dict) -> dict:

    metadata["lay_summary"] = record.get("lay_summary")

    return metadata




In [7]:
def load_json():
    # Load the pdf file
    loader = JSONLoader(
        file_path="/data_vault/hexai/Biolaysum/biolaysumm2024_data/eLife_val.jsonl",
        jq_schema='.',
        content_key="article",
        metadata_func=metadata_func,
        json_lines=True
    )

    documents = loader.load()

    token_count = num_tokens_from_string(str(documents), "cl100k_base")
    print(f'JSON Token Count: {token_count}')
    return documents, token_count


In [8]:
docs, counts = load_json()

JSON Token Count: 3475695


In [10]:
19585/512

38.251953125

In [85]:
text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=3000)

In [107]:
splits = text_splitter.create_documents([docs[40].page_content])

In [108]:
len(splits)

8

In [109]:
from transformers import T5Tokenizer, T5EncoderModel
import torch

tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
t5_enc_model = T5EncoderModel.from_pretrained("google-t5/t5-small").to("cuda")
extractor = pipeline(
    model=t5_enc_model,
    tokenizer=tokenizer,
    task="feature-extraction"
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [110]:
embedd = extractor(splits[0].page_content)

Token indices sequence length is longer than the specified maximum sequence length for this model (1515 > 512). Running this sequence through the model will result in indexing errors


In [111]:
document_embeddings = []

In [112]:
import numpy as np
for doc in splits:
    embedd = extractor(doc.page_content)
    mean_embedd = np.array(embedd).mean(axis=1).squeeze(axis=0)
    document_embeddings.append(mean_embedd)

In [113]:
nump_embedd = np.array(document_embeddings)

In [114]:
nump_embedd.shape

(8, 512)

In [116]:
# Assuming 'embeddings' is a list or array of 1536-dimensional embeddings

# Choose the number of clusters, this can be adjusted based on the book's content.
# I played around and found ~10 was the best.
# Usually if you have 10 passages from a book you can tell what it's about
num_clusters = 5

from sklearn.cluster import KMeans
# Perform K-means clustering
kmeans = KMeans(n_clusters=num_clusters, random_state=42).fit(nump_embedd)


In [117]:
# Find the closest embeddings to the centroids

# Create an empty list that will hold your closest points
closest_indices = []

# Loop through the number of clusters you have
for i in range(num_clusters):
    
    # Get the list of distances from that particular cluster center
    distances = np.linalg.norm(nump_embedd - kmeans.cluster_centers_[i], axis=1)
    
    # Find the list position of the closest one (using argmin to find the smallest distance)
    closest_index = np.argmin(distances)
    
    # Append that position to your closest indices list
    closest_indices.append(closest_index)

In [118]:
selected_indices = sorted(closest_indices)
selected_indices

[0, 1, 3, 4, 6]

In [129]:
map_prompt  = """ Summarize:
```{text}```

"""

In [130]:
map_prompt_template = PromptTemplate(template=map_prompt, input_variables=["text"])


In [131]:
from langchain.chains import load_summarize_chain
map_chain = load_summarize_chain(llm=llm,
                                 chain_type="stuff",
                                 
                                 prompt=map_prompt_template)

In [132]:
selected_docs = [splits[doc] for doc in selected_indices]


In [None]:
# Make an empty list to hold your summaries
summary_list = []

# Loop through a range of the lenght of your selected docs
for i, doc in enumerate(selected_docs):
    
    # Go get a summary of the chunk
    chunk_summary = map_chain.run([doc])
    
    # Append that summary to your list
    summary_list.append(chunk_summary)
    
    print (f"Summary #{i} (chunk #{selected_indices[i]}) - Preview: {chunk_summary[:250]} \n")

Summary #0 (chunk #0) - Preview: . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 

Summary #1 (chunk #1) - Preview: compared to the control monkeys . a neurologist and two neuroscientists compared the results to those obtained . monkey S performed a brain scan of the lGN . a scan of the lGN showed an almost complete loss of primary visual cortex . 



In [124]:
from langchain.schema import Document

summaries = "\n".join(summary_list)

# Convert it back to a document
summaries = Document(page_content=summaries)

print (f"Your total summary has {llm.get_num_tokens(summaries.page_content)} tokens")

Your total summary has 232 tokens


In [125]:
combine_prompt = """
Summarize:
"""
combine_prompt_template = PromptTemplate(template=combine_prompt, input_variables=["text"])

In [126]:
reduce_chain = load_summarize_chain(llm=llm,
                             chain_type="stuff",
                             prompt=combine_prompt_template)

In [127]:
output = reduce_chain.run([summaries])

In [128]:
print (output)

a series of summaries from a medical article. The summaries will be enclosed in triple backticks () Your goal is to give a verbose summary of what happened in the book . 'a single passage of a medical article. a single passage of a medical article. a single passage of a medical article. a single passage of a medical article. a single passage of a medical article. a single passage of a medical article. a
