<h1 style="text-align: center; font-size: 50px;">Text Summarization with LangChain</h1>

This notebook demonstrates how to build a semantic chunking and summarization pipeline for texts using LangChain, sentence transformers for semantic chunking, and LLMs for generating summaries.

# Notebook Overview
- Start Execution
- Install and Import Libraries
- Configure Settings
- Verify Assets
- Data Loading
- Semantic Chunking
- Model Setup
- Summarization Chain Creation
- Model Service

# Start Execution

In [1]:
import logging
import time

# Configure logger
logger: logging.Logger = logging.getLogger("run_workflow_logger")
logger.setLevel(logging.INFO)
logger.propagate = False  # Prevent duplicate logs from parent loggers

# Set formatter
formatter: logging.Formatter = logging.Formatter(
    fmt="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

# Configure and attach stream handler
stream_handler: logging.StreamHandler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)

In [2]:
start_time = time.time()  

logger.info("Notebook execution started.")

2025-09-16 15:40:50 - INFO - Notebook execution started.


# Install and Import Libraries

Most of the libraries that are necessary for the development of this example are built-in on the GenAI workspace, available in AI Studio. More specific libraries to handle the type of input will be added here. In this case, we are giving support to texts in the webvtt format, used to store texts, which require the webvtt-py library.

In [3]:
%%time

%pip install -r ../requirements.txt --quiet

Note: you may need to restart the kernel to use updated packages.
CPU times: user 23.2 ms, sys: 18.6 ms, total: 41.9 ms
Wall time: 1.3 s


In [4]:
import os
import sys
import logging

# Define the relative path to the 'src' directory (two levels up from current working directory)
src_path = os.path.abspath(os.path.join(os.getcwd(), ".."))

# Add 'src' directory to system path for module imports (e.g., utils)
if src_path not in sys.path:
    sys.path.append(src_path)

# === Standard Library Imports ===
import os
import sys
import logging
import json
import time
import warnings
from datetime import datetime
from pathlib import Path

# === Third-Party Imports ===
import numpy as np
import pandas as pd
import webvtt
import mlflow
from sentence_transformers import SentenceTransformer
from scipy.spatial.distance import cosine
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from operator import itemgetter

# === Project-Specific Imports (from src.utils) ===
from src.utils import (
    load_config,
    load_secrets,
    load_secrets_to_env,
    configure_proxy,
    initialize_llm,
    configure_hf_cache
)
from src.prompt_templates import format_chunk_summarization_prompt

  from tqdm.autonotebook import tqdm, trange


# Configure Settings

In [5]:
warnings.filterwarnings("ignore")

In [6]:
# === Constants ===
# Model and experiment configuration
SENTENCE_TRANSFORMER_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
RUN_NAME = "Text_Summarization_Service"
PROJECT_NAME = "AIStudio_template_code_summarization"
EVALUATION_RUN_NAME = "textrization_evaluation"

# Path configuration
CONFIG_PATH = "../configs/config.yaml"
SECRETS_PATH = "../configs/secrets.yaml"
DATA_PATH = "../data"
MODEL_PATH = "/home/jovyan/datafabric/meta-llama3.1-8b-Q8/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf"

# Text processing configuration
CHUNK_SEPARATOR = "\n\n"

# Summarization of Texts with Langchain

In this example, we intend to create a summarizer for long texts. The main goal is to break the original text into different chunks based on context - i.e. using an unsupervised approach to identify the different topics throughout the text (somehow similarly to Topic Modelling) - and summarize each of these chunks. in the end, the different summaries are returned to the user.

### Configuration of Hugging face caches

In the next cell, we configure HuggingFace cache, so that all the models downloaded from them are persisted locally, even after the workspace is closed. This is a future desired feature for AI Studio and the GenAI addon.

In [7]:
# Add the src directory to the path to import utils
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../..")))

# Configure HuggingFace cache
configure_hf_cache()

## Configuration and Secrets Loading

In this section, we load configuration parameters and API keys from separate YAML files. This separation helps maintain security by keeping sensitive information (API keys) separate from configuration settings.

- **config.yaml**: Contains non-sensitive configuration parameters like model sources and URLs
- **secrets.yaml**: Contains sensitive API keys for services like HuggingFace
- *(Optional for Premium users)* Secrets such as API keys for services like HuggingFace can be stored as environment variables for the project and loaded into the notebook (see the project's README file for steps on how to save secrets in Secrets Manager).

In [8]:
# Load secrets from secrets.yaml file (if it exists) into environment
if Path(SECRETS_PATH).exists():
    load_secrets_to_env(SECRETS_PATH)
else:
    print(f"No secrets file found at {SECRETS_PATH}; relying on preexisting environment")

# Retrieve secrets from environment
try:
    secrets = load_secrets()
except ValueError:
    secrets = {}

# Load configuration and secrets
config = load_config(CONFIG_PATH)

print("✅ Configuration loaded successfully")
print("✅ Secrets loaded successfully")

No secrets file found at ../configs/secrets.yaml; relying on preexisting environment
✅ Configuration loaded successfully
✅ Secrets loaded successfully


### Proxy Configuration

For certain enterprise networks, accessing external services might require an explicit setup of the proxy configuration. If this is your case, set up the "proxy" field on your config.yaml and the following cell will configure the necessary environment variable.

In [9]:
configure_proxy(config)

# Verify Assets

In [10]:
def log_asset_status(asset_path: str, asset_name: str, success_message: str, failure_message: str) -> None:
    """
    Logs the status of a given asset based on its existence.

    Parameters:
        asset_path (str): File or directory path to check.
        asset_name (str): Name of the asset for logging context.
        success_message (str): Message to log if asset exists.
        failure_message (str): Message to log if asset does not exist.
    """
    if Path(asset_path).exists():
        logger.info(f"{asset_name} is properly configured. {success_message}")
    else:
        logger.info(f"{asset_name} is not properly configured. {failure_message}")


# Check and log status for BERT model, embeddings file, and tokenizer
log_asset_status(
    asset_path=MODEL_PATH,
    asset_name="Local Llama model",
    success_message="",
    failure_message="Please create and download the required assets in your project on AI Studio if you want to use local model."
)

2025-09-16 15:40:58 - INFO - Local Llama model is properly configured. 


## Step 1: Loading the data from the text

At first, we need to read the data from the text. As our text is in the .vtt format, we use a library called webvtt-py to read the content. As the text is a trancript of audio/video, it is organized in small chunks of conversation, each containing a sequential id, the time of the start and end of the chunk, and the text content (often in the form speaker:content).

From this data, we expect to extract the actual content,  while keeping reference to the other metadata - for this reason, we are loading all the data into a Pandas dataset. 

In [11]:
if not os.path.exists(DATA_PATH):
    raise FileNotFoundError(f"'data' folder not found in path: {os.path.abspath(DATA_PATH)}")

file_path = os.path.join(DATA_PATH, "I_have_a_dream.vtt")

data = {
    "id": [],
    "speaker": [],
    "content": [],
    "start": [],
    "end": []
}

for caption in webvtt.read(file_path):
    line = caption.text.split(":")
    while len(line) < 2:
        line = [''] + line
    data["id"].append(caption.identifier)
    data["speaker"].append(line[0].strip())
    data["content"].append(line[1].strip())
    data["start"].append(caption.start)
    data["end"].append(caption.end)
    
df = pd.DataFrame(data)

df.head()

Unnamed: 0,id,speaker,content,start,end
0,1,,I am happy to join with you today,00:00:00.880,00:00:03.920
1,2,,in what will go down in history,00:00:06.500,00:00:09.360
2,3,,as the greatest demonstration for freedom in t...,00:00:11.720,00:00:16.460
3,4,,nation.,00:00:16.460,00:00:17.293
4,5,,"Five score years ago,",00:00:26.410,00:00:28.740


As a second option, we provide here a code to load the same structure from a plain text document, which only contains the actual content of the speech/conversation, without extra metadata. For the sake of simplicity and reuse of code, we keep the same Data Frame structure as the previous version, by filling the remaining fields with empty strings.

In [12]:
with open(file_path) as file:
    lines = file.read()

data = {
    "id": [],
    "speaker": [],
    "content": [],
    "start": [],
    "end": []
}

for line in lines.split("\n"):
    if line.strip() != "":
        data["id"].append("")
        data["speaker"].append("")
        data["content"].append(line.strip())
        data["start"].append("")
        data["end"].append("")        
        
df = pd.DataFrame(data)

df.head()

Unnamed: 0,id,speaker,content,start,end
0,,,﻿WEBVTT,,
1,,,1,,
2,,,00:00:00.880 --> 00:00:03.920,,
3,,,<v 0>I am happy to join with you today</v>,,
4,,,2,,


## Step 2: Semantic chunking of the text
Having the information content loaded according to the text format - with the text split into audio blocks, or into paragraphs, we now want to group these small blocks into relevant topics - so we can summarize each topic individually. Here, we are using a very simple approach for that, by using a semantic embedding of each sentence (using an embedding model from Hugging Face Sentence Transformers), and identifying the "breaks" among chunks as the ones with higher semantic distance. Notice that this method can be parameterized, to inform the number of topics or the best method to identify the breaks.

In [13]:
embedding_model = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_NAME)
embeddings = embedding_model.encode(df.content)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [14]:
class SemanticSplitter():
    """
    A class for semantically splitting text into coherent chunks based on embeddings.
    This class uses embedding-based distance metrics to identify topic transitions in text.
    """
    def __init__(self, content, embedding_model, method="number", partition_count=10, quantile=0.9, clustering_method=None, n_clusters=None):
        """
        Initialize the SemanticSplitter.
        
        Args:
            content: List of text segments to process and split
            embedding_model: Model to use for generating text embeddings
            method: Chunking method - 'number' (fixed number of breaks), 'quantiles' (threshold-based), or 'clustering'
            partition_count: Number of breaks to create when using 'number' method
            quantile: Threshold quantile to use when using 'quantiles' method
            clustering_method: Which clustering algorithm to use ('kmeans', 'hierarchical', None)
            n_clusters: Number of clusters to create when using clustering method
        """
        try:
            self.content = content
            self.embedding_model = embedding_model
            self.partition_count = partition_count
            self.quantile = quantile
            self.clustering_method = clustering_method
            self.n_clusters = n_clusters if n_clusters is not None else partition_count
            
            logger.info(f"Encoding {len(content)} content items with embedding model")
            self.embeddings = embedding_model.encode(content)
            logger.info(f"Generated embeddings with shape: {self.embeddings.shape}")
            
            # Calculate distances between consecutive embeddings
            self.distances = [cosine(self.embeddings[i - 1], self.embeddings[i]) for i in range(1, len(self.embeddings))]
            self.breaks = []
            self.centroids = []
            
            # Load break points using the specified method
            self.load_breaks(method=method)
            logger.info(f"Created {len(self.breaks)} breaks using method '{method}'")
        except Exception as e:
            logger.error(f"Error initializing SemanticSplitter: {str(e)}")
            raise

    def centroid_distance(self, embedding_id, centroid_id):
        """
        Calculate cosine distance between an embedding and a centroid.
        
        Args:
            embedding_id: Index of the embedding to compare
            centroid_id: Index of the centroid to compare
            
        Returns:
            Cosine distance between the embedding and centroid
        """
        if not self.centroids:
            logger.warning("Centroids haven't been loaded. Call load_centroids() first.")
            return 1.0  # Return max distance if no centroids
            
        try:
            return cosine(self.embeddings[embedding_id], self.centroids[centroid_id])
        except IndexError as e:
            logger.error(f"Invalid index in centroid_distance: {str(e)}")
            return 1.0  # Return max distance on error
        except Exception as e:
            logger.error(f"Error in centroid_distance: {str(e)}")
            return 1.0  # Return max distance on error

    def adjust_neighbors(self, window_size=3, distance_threshold=0.7):
        """
        Adjust break points by examining neighboring segments to improve coherence.
        This helps avoid breaking semantic units that should stay together.
        
        Args:
            window_size: Number of neighboring segments to consider
            distance_threshold: Threshold for merging nearby segments
        """
        if not self.breaks:
            logger.info("No breaks to adjust")
            return
            
        logger.info(f"Adjusting {len(self.breaks)} breaks with window size {window_size}")
        
        try:
            adjusted_breaks = []
            # Sort breaks to process them in order
            sorted_breaks = sorted(self.breaks)
            
            for i, break_pos in enumerate(sorted_breaks):
                # Skip if this break is too close to the previous adjusted break
                if adjusted_breaks and break_pos - adjusted_breaks[-1] < window_size:
                    continue
                    
                # Check surrounding context for better break point
                best_pos = break_pos
                best_dist = self.distances[break_pos]
                
                # Look at nearby positions for potentially better break points
                start = max(0, break_pos - window_size)
                end = min(len(self.distances) - 1, break_pos + window_size)
                
                for j in range(start, end + 1):
                    if j != break_pos and self.distances[j] > best_dist:
                        best_pos = j
                        best_dist = self.distances[j]
                        
                # Add the optimized break position
                adjusted_breaks.append(best_pos)
                
            self.breaks = sorted(list(set(adjusted_breaks)))
            logger.info(f"Adjusted breaks count: {len(self.breaks)}")
        except Exception as e:
            logger.error(f"Error adjusting neighbors: {str(e)}")
            # Keep original breaks on error

    def load_breaks(self, method='number'):
        """
        Load break points based on the specified method.
        
        Args:
            method: Method to determine breaks - 'number', 'quantiles', or 'clustering'
        """
        try:
            if method == 'number':
                # Ensure we don't request more breaks than possible
                if self.partition_count > len(self.distances):
                    logger.warning(f"Requested {self.partition_count} breaks but only {len(self.distances)} positions available.")
                    self.partition_count = len(self.distances)
                    
                # Find the partition_count highest distance positions
                self.breaks = np.sort(np.argpartition(self.distances, len(self.distances) - self.partition_count)[-self.partition_count:])
                logger.info(f"Created {len(self.breaks)} breaks using fixed number method")
                
            elif method == 'quantiles':
                # Find positions with distance above the quantile threshold
                threshold = np.quantile(self.distances, self.quantile)
                self.breaks = [i for i, v in enumerate(self.distances) if v >= threshold]
                logger.info(f"Created {len(self.breaks)} breaks using quantile method with threshold {threshold:.4f}")
                
            elif method == 'clustering':
                # Use clustering algorithms to group similar segments
                self._cluster_embeddings()
                logger.info(f"Created {len(self.breaks)} breaks using clustering method")
                
            else:
                logger.warning(f"Unknown method: {method}. No breaks created.")
                self.breaks = []
        except Exception as e:
            logger.error(f"Error loading breaks with method '{method}': {str(e)}")
            self.breaks = []

    def _cluster_embeddings(self):
        """
        Cluster embeddings using the specified clustering method.
        Sets breaks at the boundaries between clusters.
        """
        try:
            from sklearn.cluster import KMeans, AgglomerativeClustering
            
            if len(self.embeddings) < self.n_clusters:
                logger.warning(f"Not enough samples ({len(self.embeddings)}) for {self.n_clusters} clusters.")
                self.n_clusters = max(2, len(self.embeddings) // 2)
            
            # Choose clustering algorithm based on configuration
            if self.clustering_method == 'kmeans':
                logger.info(f"Performing KMeans clustering with {self.n_clusters} clusters")
                clustering = KMeans(n_clusters=self.n_clusters, random_state=42, n_init=10)
            else:  # Default to hierarchical clustering
                logger.info(f"Performing Hierarchical clustering with {self.n_clusters} clusters")
                clustering = AgglomerativeClustering(n_clusters=self.n_clusters)
            
            # Fit and predict cluster labels
            labels = clustering.fit_predict(self.embeddings)
            logger.info(f"Clustering complete, found {len(set(labels))} clusters")
            
            # Find transitions between clusters
            all_transitions = []
            for i in range(1, len(labels)):
                if labels[i] != labels[i-1]:
                    all_transitions.append((i-1, self.distances[i-1] if i-1 < len(self.distances) else 0))
            
            logger.info(f"Found {len(all_transitions)} raw transitions between clusters")
            
            # Filter transitions to avoid excessive fragmentation
            
            # Apply minimum chunk size (at least 3 segments per chunk)
            min_chunk_size = 3
            valid_transitions = []
            last_break = -1
            
            for idx, (break_pos, _) in enumerate(all_transitions):
                if break_pos - last_break >= min_chunk_size:
                    valid_transitions.append((break_pos, all_transitions[idx][1]))
                    last_break = break_pos
            
            # Sort by significance (distance) and limit to maximum number of breaks
            max_breaks = min(self.partition_count, len(valid_transitions))
            valid_transitions.sort(key=lambda x: x[1], reverse=True)
            significant_breaks = [x[0] for x in valid_transitions[:max_breaks]]
            
            # Sort breaks in sequential order
            self.breaks = sorted(significant_breaks)
            logger.info(f"After filtering: using {len(self.breaks)} breaks between clusters")
            
        except ImportError:
            logger.error("scikit-learn not available. Install it for clustering support.")
            self.breaks = []
        except Exception as e:
            logger.error(f"Error in clustering: {str(e)}")
            self.breaks = []

    def get_centroid(self, beginning, end):
        """
        Calculate centroid embedding for a range of content.
        
        Args:
            beginning: Start index (inclusive)
            end: End index (exclusive)
            
        Returns:
            Centroid embedding for the specified content range
        """
        try:
            if beginning >= end or beginning < 0 or end > len(self.content):
                logger.warning(f"Invalid range: {beginning}-{end}")
                return np.zeros(self.embeddings[0].shape)
                
            text = '\n'.join(self.content[beginning:end])
            return self.embedding_model.encode(text)
        except Exception as e:
            logger.error(f"Error calculating centroid: {str(e)}")
            if len(self.embeddings) > 0:
                return np.zeros(self.embeddings[0].shape)
            return np.zeros(384)  # Default embedding size if unknown
    
    def load_centroids(self):
        """
        Load centroids for each chunk after breaks have been calculated.
        """
        logger.info("Loading centroids for chunks")
        try:
            if len(self.breaks) == 0:
                self.centroids = [self.get_centroid(0, len(self.content))]
                logger.info("Created 1 centroid for the entire content")
            else:
                self.centroids = []
                beginning = 0
                for break_position in sorted(self.breaks):
                    self.centroids.append(self.get_centroid(beginning, break_position + 1))
                    beginning = break_position + 1
                self.centroids.append(self.get_centroid(beginning, len(self.content)))
                logger.info(f"Created {len(self.centroids)} centroids")
        except Exception as e:
            logger.error(f"Error loading centroids: {str(e)}")
            self.centroids = []

    def get_chunk(self, beginning, end):
        """
        Get content chunk between specified indices.
        
        Args:
            beginning: Start index (inclusive)
            end: End index (exclusive)
            
        Returns:
            Chunk of content as a single string
        """
        try:
            if beginning >= end or beginning < 0 or end > len(self.content):
                logger.warning(f"Invalid chunk range: {beginning}-{end}")
                return ""
            return '\n'.join(self.content[beginning:end])
        except Exception as e:
            logger.error(f"Error getting chunk: {str(e)}")
            return ""
    
    def get_chunks(self):
        """
        Get all content chunks based on calculated breaks.
        
        Returns:
            List of content chunks
        """
        try:
            if len(self.breaks) == 0:
                logger.info("No breaks found, returning entire content as single chunk")
                return [self.get_chunk(0, len(self.content))]
            else:
                chunks = []
                beginning = 0
                sorted_breaks = sorted(self.breaks)
                for break_position in sorted_breaks:
                    chunk = self.get_chunk(beginning, break_position + 1)
                    chunks.append(chunk)
                    beginning = break_position + 1
                # Add the last chunk after the final break
                chunks.append(self.get_chunk(beginning, len(self.content)))
                logger.info(f"Generated {len(chunks)} chunks from content")
                return chunks
        except Exception as e:
            logger.error(f"Error getting chunks: {str(e)}")
            return [self.get_chunk(0, len(self.content))]

## Topic Segmentation with Clustering

While the basic chunking method using cosine distances can be effective, it may produce noisy results for complex documents. To improve topic identification, we can use clustering algorithms like KMeans and Hierarchical Clustering to group semantically related content.

The implementation offers two main clustering approaches:

1. **KMeans clustering** - Groups embeddings into k clusters based on vector similarity
2. **Hierarchical clustering** - Creates a tree of clusters by progressively merging similar groups

These methods identify natural topic boundaries in the text by finding transitions between semantic clusters, which often produces more coherent topical chunks than simple distance-based approaches.

In [15]:
# Create a new instance of the SemanticSplitter with cosine distance method
# splitter = SemanticSplitter(df.content, embedding_model, method="number", partition_count=6)

# Create a new instance of the SemanticSplitter with KMeans clustering
splitter = SemanticSplitter(
    content=df.content, 
    embedding_model=embedding_model, 
    method="clustering",
    clustering_method="kmeans", 
    n_clusters=6,
    partition_count=6
)

# Get chunks using KMeans clustering
chunks = splitter.get_chunks()
text = CHUNK_SEPARATOR.join(chunks)

2025-09-16 15:41:06 - INFO - Encoding 250 content items with embedding model
2025-09-16 15:41:06 - INFO - Generated embeddings with shape: (250, 384)
2025-09-16 15:41:06 - INFO - Performing KMeans clustering with 6 clusters
2025-09-16 15:41:08 - INFO - Clustering complete, found 6 clusters
2025-09-16 15:41:08 - INFO - Found 248 raw transitions between clusters
2025-09-16 15:41:08 - INFO - After filtering: using 6 breaks between clusters
2025-09-16 15:41:08 - INFO - Created 6 breaks using clustering method
2025-09-16 15:41:08 - INFO - Created 6 breaks using method 'clustering'
2025-09-16 15:41:08 - INFO - Generated 7 chunks from content


## Step 3: Using a LLM model to Summarize each chunk
In our example, we are going to summarize each individual chunk separately. This solution might be advantageous in different situations:
 * When the original text is too big , or the loaded model works with a context that is too small. In this scenario, breaking information into chunks are necessary to allow the model to be applied
 * When the user wants to make sure that all the separate topics of a conversation are covered into the summarized version. An extra step could be added to allow some verification or manual configuration of the chunks to allow the user to customize the output

In this notebook, we provide three different options for loading the model:
 * **local**: by loading the Meta Llama 3.1 model with 8B parameters from the asset downloaded on the project
 * **hugging-face-local** by downloading a DeepSeek model from Hugging Face and running locally
 * **hugging-face-cloud** by accessing the Mistral model through Hugging Face cloud API (requires HuggingFace API key saved on secrets.yaml)

This choice can be set in the config.yaml file. The model deployed on the bottom cells of this notebook will load the choice from the config file.

In [16]:
model_source = config["model_source"]

In [17]:
%%time

llm = initialize_llm(model_source, secrets)

CPU times: user 1.14 s, sys: 2.27 s, total: 3.41 s
Wall time: 1min 10s


In [18]:
prompt_template = format_chunk_summarization_prompt(model_source)

## Step 4: Create parallel chain to summarize the text

In the following cell, we create a chain that will receive a single string with multiple chunks (separated by the declared separator), than:
  * Break the input into separated chains - using the break_chunks function embedded in a RunnableLambda to be used in LangChain
  * Run a Parallel Chain with the following elements for each chunk:
    * Get an individual element
    * Personalize the prompt template to create an individual prompt for each chunk
    * Use the LLM inference to summarize the chunk
  * Merge the individual summaries into a single one




In [19]:
# Converts prompt_template to LangChain object
prompt = ChatPromptTemplate.from_template(prompt_template)

def break_chunks(text):
    """
    Split text into chunks using the predefined separator.
    """
    return text.split(CHUNK_SEPARATOR)

def process_chunk(chunk_text):
    # Create a proper runnable chain for each chunk
    chunk_chain = (
        RunnablePassthrough.assign(context=lambda _: chunk_text)
        | prompt 
        | llm
    )
    return chunk_chain.invoke({})

def process_chunks(text):
    chunks_list = break_chunks(text)
    results = []
    
    logger.info(f"Processing {len(chunks_list)} chunks")
    
    for i, chunk in enumerate(chunks_list):
        try:
            result = process_chunk(chunk)
            results.append(result)
        except Exception as e:
            error_msg = f"Error processing chunk {i+1}: {str(e)}"
            logger.error(error_msg)
            logger.error(f"Exception type: {type(e).__name__}")
            import traceback
            logger.error(f"Traceback: {traceback.format_exc()}")
            results.append(f"Error: {str(e)}")
            
    return "\n\n".join(results)

lambda_break = RunnableLambda(break_chunks)

def join_summaries(summaries_dict):
    # Extract values from the dictionary and join them
    joined_summaries = "\n\n".join([str(v) for v in summaries_dict.values()])
    logger.info(f"Joined {len(summaries_dict)} summaries")
    return joined_summaries

lambda_join = RunnableLambda(join_summaries)

# Create the complete chain
chain = RunnableLambda(process_chunks) | StrOutputParser()

## Step 5: Run the chain and evaluate quality metrics

In this section, we call the created chain and implement local quality metrics evaluation. We create a local metric evaluation using HuggingFace's implementation of ROUGE (using the evaluate library) to measure the quality of the summarization.

This approach provides:
- **Local evaluation**: No external service dependencies required
- **ROUGE metrics**: Industry-standard evaluation for summarization tasks
- **Performance tracking**: Execution time monitoring for optimization

In [20]:
def evaluate_rouge_local(reference_text: str, prediction_text: str) -> dict:
    """
    Calculate ROUGE metrics locally using HuggingFace evaluate library.
    
    Args:
        reference_text: Original input text (reference)
        prediction_text: Generated summary (prediction)
    
    Returns:
        Dictionary containing ROUGE scores
    """
    try:
        if not reference_text or not prediction_text:
            logger.warning("Empty reference or prediction text")
            return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}

        # Calculate ROUGE metrics
        rouge = evaluate.load("rouge")
        rouge_values = rouge.compute(
            predictions=[prediction_text], 
            references=[reference_text]
        )
        
        logger.info(f"ROUGE Scores - ROUGE-1: {rouge_values.get('rouge1', 0.0):.4f}, "
                   f"ROUGE-2: {rouge_values.get('rouge2', 0.0):.4f}, "
                   f"ROUGE-L: {rouge_values.get('rougeL', 0.0):.4f}")
        
        return rouge_values
    except Exception as e:
        logger.error(f"Error calculating ROUGE metrics: {e}")
        return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}

# Execute the summarization chain and measure performance
logger.info("Starting text summarization and evaluation...")

# Measure execution time
start_time = time.time()
response = chain.invoke(text)
total_time = time.time() - start_time

logger.info(f"✅ Summarization completed in {total_time:.2f} seconds")

# Evaluate the summarization quality using ROUGE metrics
rouge_scores = evaluate_rouge_local(text, response)

# Display results
print("\n" + "="*50)
print("SUMMARIZATION RESULTS")
print("="*50)
print(f"Original text length: {len(text)} characters")
print(f"Summary length: {len(response)} characters")
print(f"Compression ratio: {len(response)/len(text):.2%}")
print(f"Processing time: {total_time:.2f} seconds")
print("\nROUGE Scores:")
for metric, score in rouge_scores.items():
    print(f"  {metric.upper()}: {score:.4f}")
print("="*50)
print("\nGenerated Summary:")
print("-"*50)
print(response)
print("-"*50)

2025-09-16 15:42:19 - INFO - Starting text summarization and evaluation...
2025-09-16 15:42:19 - INFO - Processing 7 chunks
2025-09-16 15:42:29 - INFO - ✅ Summarization completed in 10.46 seconds
2025-09-16 15:42:29 - ERROR - Error calculating ROUGE metrics: name 'evaluate' is not defined



SUMMARIZATION RESULTS
Original text length: 6560 characters
Summary length: 1715 characters
Compression ratio: 26.14%
Processing time: 10.46 seconds

ROUGE Scores:
  ROUGE1: 0.0000
  ROUGE2: 0.0000
  ROUGEL: 0.0000

Generated Summary:
--------------------------------------------------
The excerpt appears to be a speech, likely the "I Have a Dream" speech by Martin Luther King Jr. The speaker reflects on the progress made since the Emancipation Proclamation 100 years prior, but notes that despite this progress, African Americans still face significant challenges and inequalities in society.

The excerpt is a poetic passage that describes the state of Mississippi as being sweltering with the heat of injustice.

The excerpt is a passage from Martin Luther King Jr.'s famous "I Have a Dream" speech. In the passage, King expresses his vision of a future where people are judged not by the color of their skin but by the content of their character. He envisions a world where children of all co

In [21]:
end_time: float = time.time()
elapsed_time: float = end_time - start_time
elapsed_minutes: int = int(elapsed_time // 60)
elapsed_seconds: float = elapsed_time % 60

logger.info(f"⏱️ Total execution time: {elapsed_minutes}m {elapsed_seconds:.2f}s")
logger.info("✅ Notebook execution completed successfully.")

2025-09-16 15:42:29 - INFO - ⏱️ Total execution time: 0m 10.46s
2025-09-16 15:42:29 - INFO - ✅ Notebook execution completed successfully.


Built with ❤️ using [**HP AI Studio**](https://hp.com/ai-studio).