# Implementing LLM Summarization

In [6]:
from transformers import pipeline, T5Tokenizer, BitsAndBytesConfig, AutoModelForSeq2SeqLM, AutoTokenizer
import torch


## Loading the model

In [2]:
# Model identifier shared by the tokenizer and the summarization pipeline below.
model_name = "t5-base"

In [None]:
# Report whether CUDA is usable and, when it is, the name of device 0.
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU device:", torch.cuda.get_device_name(0))
else:
    print("GPU device:", "None")

### Loading tokenizer to calculate input size

In [None]:
# Load the tokenizer for `model_name`; it is reused later to count tokens when
# chunking text and when sizing summaries, and is handed to the pipeline.
tokenizer = T5Tokenizer.from_pretrained(model_name)

In [None]:
# Build the summarization pipeline around the already-loaded tokenizer.
# `model_kwargs` carries the cache directory (presumably forwarded to the
# model loader -- confirm against the transformers documentation).
summarizer = pipeline(
    "summarization",
    tokenizer=tokenizer,
    model=model_name,
    model_kwargs={"cache_dir": '../models/'},
    device_map="auto",
)

# Initializing input text

In [7]:
# Sample multi-topic passage (machine learning, deep learning, NLP, computer
# vision) that the cells below chunk and summarize.
test_text = """
Machine learning (ML) is a field of inquiry devoted to understanding and building methods that 'learn', 
that is, methods that leverage data to improve performance on some set of tasks. It is seen as a part 
of artificial intelligence. Machine learning algorithms build a model based on sample data, known as 
training data, in order to make predictions or decisions without being explicitly programmed to do so. 
Machine learning algorithms are used in a wide variety of applications, such as in medicine, email 
filtering, speech recognition, and computer vision, where it is difficult or unfeasible to develop 
conventional algorithms to perform the needed tasks. Deep learning is a subset of machine learning that uses neural networks with multiple layers (deep neural networks) 
to progressively extract higher-level features from raw input. For example, in image processing, lower layers may 
identify edges, while higher layers may identify the concepts relevant to a human such as digits or letters or faces. Natural Language Processing (NLP) is a branch of artificial intelligence that helps computers understand, interpret, 
and manipulate human language. NLP draws from many disciplines, including computer science and computational linguistics, 
in its pursuit to fill the gap between human communication and computer understanding. Computer vision is an interdisciplinary scientific field that deals with how computers can gain high-level understanding 
from digital images or videos. From the perspective of engineering, it seeks to understand and automate tasks that the 
human visual system can do. Computer vision tasks include methods for acquiring, processing, analyzing and understanding 
digital images, and extraction of high-dimensional data from the real world in order to produce numerical or symbolic 
information, e.g., in the forms of decisions.
""" 

### Chunking function

In [8]:
def split_text(text, max_tokens=128):
    """Split *text* into sentence-aligned chunks within a token budget.

    Token counts come from the module-level ``tokenizer``. Text that already
    fits within ``max_tokens`` is returned unchanged as a single chunk.
    Otherwise the text is cut at sentence boundaries and the sentences are
    packed greedily, in order, into chunks.

    Args:
        text: The text to split.
        max_tokens: Token budget for each chunk.

    Returns:
        A list of chunk strings. A single sentence that is longer than
        ``max_tokens`` is kept whole in a chunk of its own rather than being
        cut mid-sentence, so such a chunk can exceed the budget.
    """
    import re  # local import: keeps this cell runnable on its own

    def count_tokens(piece):
        # A plain id list is enough for counting; no tensor is needed.
        return len(tokenizer.encode(piece))

    # Nothing to do when the whole text already fits.
    if count_tokens(text) <= max_tokens:
        return [text]

    # Split only where sentence-ending punctuation is followed by whitespace.
    # Unlike a bare split on ".", this keeps "e.g.," and decimals intact and
    # does not leave an empty trailing piece that turns into a lone ".".
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())

    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue

        sentence_length = count_tokens(sentence)

        # Close the running chunk when this sentence would overflow it. The
        # `current_chunk` guard avoids emitting an empty chunk when the very
        # first sentence is itself over budget.
        if current_chunk and current_length + sentence_length > max_tokens:
            chunks.append(' '.join(current_chunk))
            current_chunk = []
            current_length = 0

        current_chunk.append(sentence)
        current_length += sentence_length

    # Flush whatever is left after the last sentence.
    if current_chunk:
        chunks.append(' '.join(current_chunk))

    return chunks

### Min and Max length calculation function

In [9]:
def calculate_lengths(text):
    """Derive summary length bounds for *text* from its token count.

    The token count is taken from the module-level ``tokenizer``.

    Args:
        text: The text that is about to be summarized.

    Returns:
        A ``(min_length, max_length)`` tuple satisfying
        ``1 <= min_length <= max_length <= 150``.
    """
    input_length = len(tokenizer.encode(text))

    # Half of the input, capped at 150 tokens, and never below 1 so that a
    # tiny input cannot produce a zero-length generation limit.
    max_length = max(min(input_length // 2, 150), 1)

    # A quarter of the input with a floor of 5 tokens, clamped to max_length.
    # Without the clamp the minimum overtakes the maximum for very short
    # inputs (under 10 tokens) and very long ones (over about 600 tokens).
    min_length = min(max(input_length // 4, 5), max_length)

    return min_length, max_length

### Summarization function

In [10]:
def summarize_chunk(chunk):
    """Summarize one chunk of text and return the summary string.

    The generation length bounds are obtained from ``calculate_lengths`` and
    passed to the module-level ``summarizer`` with sampling disabled.
    """
    shortest, longest = calculate_lengths(chunk)
    outputs = summarizer(
        chunk,
        max_length=longest,
        min_length=shortest,
        do_sample=False,
    )
    # The text of the first result is stored under the 'summary_text' key.
    first_result = outputs[0]
    return first_result['summary_text']


In [None]:
# Summarize every chunk and stitch the partial summaries together. A single
# chunk needs no special case: joining a one-element sequence returns that
# element unchanged, and `summary` is always a string.
chunks = split_text(test_text)
summary = ' '.join(summarize_chunk(chunk) for chunk in chunks)

# Print one bullet per sentence. Splitting on "." leaves padded pieces and an
# empty piece after the final period, so strip each piece and drop the blanks
# instead of printing empty "- " bullets.
bullet_points = [point.strip() for point in summary.split(".")]
bullet_points = [point for point in bullet_points if point]
for point in bullet_points:
    print(f"- {point}")

In [None]:
# Dump each chunk followed by a divider line so chunk boundaries are visible.
for chunk in chunks:
    print(chunk, "--------------------------------", sep="\n")
