In [2]:
import os
import json
import tiktoken

from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.docstore.document import Document
from langchain.chains.summarize import load_summarize_chain
from dotenv import find_dotenv, load_dotenv

# Find a .env file with find_dotenv() and load its variables into the process
# environment. NOTE(review): presumably this is where OPENAI_API_KEY for
# ChatOpenAI comes from — confirm against the .env contents.
dotenv_path = find_dotenv()
load_dotenv(dotenv_path=dotenv_path)

True

In [6]:
# Summarization prompt with a single {text} placeholder (filled by PromptTemplate).
# A 10-bullet summary necessarily compresses, so ask for fidelity instead of
# "no compression", and forbid facts that are not in the transcript.
prompt_template = """Create a detailed and in-depth summary of this section of the Huberman Lab Podcast. Keep as many of the specific facts, figures, mechanisms, and recommended protocols from the transcript as the bullet points allow, and include only information that is stated in the transcript. Here is the transcript:


{text}


SUMMARY IN 10 BULLET POINTS:"""

In [7]:
def truncate_text_to_max_tokens(text: str, max_tokens: int = 4000, encoding_name: str = "gpt-3.5-turbo") -> str:
    """Truncate ``text`` so that it is at most ``max_tokens`` tokens long.

    Args:
        text: The text to (possibly) truncate.
        max_tokens: Maximum number of tokens to keep. Must be non-negative.
        encoding_name: Despite its name, this is a *model* name (e.g.
            ``"gpt-3.5-turbo"``); it is passed to
            ``tiktoken.encoding_for_model`` to pick that model's tokenizer.

    Returns:
        ``text`` unchanged if it already fits, otherwise the decoded first
        ``max_tokens`` tokens.

    Raises:
        ValueError: If ``max_tokens`` is negative.
    """
    if max_tokens < 0:
        # A negative slice bound would silently drop tokens from the *end*
        # rather than enforce a limit.
        raise ValueError(f"max_tokens must be non-negative, got {max_tokens}")

    # Look up the tokenizer and encode once; the same token list serves both
    # the length check and the truncation.
    encoding = tiktoken.encoding_for_model(encoding_name)
    # disallowed_special=() treats strings like "<|endoftext|>" in arbitrary
    # input as plain text instead of raising.
    token_list = encoding.encode(text, disallowed_special=())
    current_num_tokens = len(token_list)

    if current_num_tokens <= max_tokens:
        print(f'Text not truncated, num tokens: {current_num_tokens}')
        return text

    print(f'Text truncated, num tokens: {current_num_tokens}')
    # The cut may fall inside a multi-byte UTF-8 character split across
    # tokens; errors="ignore" drops that partial character instead of
    # emitting U+FFFD.
    return encoding.decode(token_list[:max_tokens], errors="ignore")

In [None]:
def summarize_files_from_directory(input_directory, output_directory, prompt_template, model_name="gpt-3.5-turbo"):
    """Summarize every transcript file in ``input_directory`` with an LLM.

    Each regular, non-hidden file is read as UTF-8 and passed through
    ``truncate_text_to_max_tokens`` with its default limit. Only the beginning
    of a long transcript is therefore summarized. The text then goes through
    a "stuff" summarize chain, and the result is written to
    ``output_directory`` as ``"(Summary) <filename>"``. Files whose summary
    already exists are skipped, so re-running resumes an interrupted batch.

    Args:
        input_directory: Directory containing the transcript files.
        output_directory: Directory for the summaries; created if missing.
        prompt_template: Prompt string with a single ``{text}`` placeholder.
        model_name: Chat model name passed to ``ChatOpenAI``.
    """
    llm = ChatOpenAI(model_name=model_name)
    bullet_point_prompt = PromptTemplate(template=prompt_template,
                                         input_variables=["text"])
    chain = load_summarize_chain(llm,
                                 chain_type="stuff",
                                 prompt=bullet_point_prompt)

    # Without this, writing the first summary fails with FileNotFoundError
    # when the output directory does not exist yet. An empty string means
    # the current directory, which always exists.
    if output_directory:
        os.makedirs(output_directory, exist_ok=True)

    # sorted() gives a deterministic processing order across platforms.
    for filename in sorted(os.listdir(input_directory)):
        full_path = os.path.join(input_directory, filename)
        # Skip sub-directories and hidden files such as .DS_Store: opening
        # them raises or feeds binary junk to the model.
        if filename.startswith('.') or not os.path.isfile(full_path):
            continue

        save_path = os.path.join(output_directory, f'(Summary) {filename}')
        if os.path.exists(save_path):
            print(f"{filename} already summarised!")
            continue

        print(f'Summarizing: {filename}')
        # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) can
        # fail on transcript text.
        with open(full_path, encoding="utf-8") as f:
            text = f.read()
        text = truncate_text_to_max_tokens(text)
        doc = [Document(page_content=text)]
        output_summary = chain.run(doc)
        with open(save_path, "w", encoding="utf-8") as f:
            f.write(output_summary)

In [None]:
# Paths are relative to the notebook's working directory: transcripts are read
# from data/transcripts and summaries are written to data/summaries.
input_dir, output_dir = (os.path.join('data', sub) for sub in ('transcripts', 'summaries'))
summarize_files_from_directory(
    input_dir,
    output_dir,
    prompt_template=prompt_template,
)