# This time it will work

## Imports + Device setup

In [1]:
import torch
import nltk # we need nltk to tokenize large inputs
from datasets import load_dataset, load_metric
nltk.download('punkt') # punkt seems nessecary for it

from transformers import AutoModelForSeq2SeqLM, pipeline, AutoTokenizer



# Good to empty the cache out to ensure its clean on each new run
torch.cuda.empty_cache()

# Sanity checking
print("torch cuda version:", torch.version.cuda)

# Check for cuda availability
if(torch.cuda.is_available()):
    deviceCount = torch.cuda.device_count()
    currentNumber = torch.cuda.current_device()
    deviceName = torch.cuda.get_device_name(currentNumber) 
    print(f"Cuda available. {deviceCount} device(s) detected.")
    print(f"Current Device: Number:{currentNumber} Name:{deviceName}")
else:
    print("Cuda not available")

# Set device variable. All tensors and models must be on the SAME device
device = "cuda:0" if torch.cuda.is_available() else "cpu"

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Adam\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!



torch cuda version: 12.1
Cuda available. 1 device(s) detected.
Current Device: Number:0 Name:NVIDIA GeForce GTX 1080 Ti


## Init Model + Tokenizer
- We select our model type (this will download and cache a huggingface model and transformer)
- We examine the maximum inputs allowed for this model. This is very important to dictate our chunking method

In [3]:
# Model type we're using from Huggingface
modelID = "google/flan-t5-base"

model = AutoModelForSeq2SeqLM.from_pretrained(modelID).to(device)

# Use the corresponding tokenizer for this model
tokenizer = AutoTokenizer.from_pretrained(modelID)

# Sanity checking, show limits of the model

maxInput = tokenizer.model_max_length # max length of input
maxSentenceTokens = tokenizer.max_len_single_sentence # max len of a single sentince
specialTokens = tokenizer.num_special_tokens_to_add() # tokenizer will add 2 special tokens for input seq

print(f"Max input length: {maxInput}, Max Sentence length: {maxSentenceTokens}, SpecialTokens: {specialTokens}")

Max input length: 512, Max Sentence length: 511, SpecialTokens: 1


## Chunking and other Helpers
- We need a way to split the input into digestable chunks for the transformer to operate on. This is chunking
- 

In [31]:
# The chunking function will iterate over sentinces of text and throw however many it can into a chunk.
# This seems to be a quick way to break up the text in coherent places
def chunkingFunction(fullTextString: str) -> [str]:
    
    # tokenize the large amount of text into sentinces
    sentencesList = nltk.tokenize.sent_tokenize(fullTextString)

    # check max length (in tokens) of all sentences
    # if we exceed this anywhere we need to arbitrarily break it up
    maxSentenceLen = max([len(tokenizer.tokenize(sentence)) for sentence in sentencesList])
    
    # This can be handled later, but for now we will raise an exception
    # if maxSentenceLen > maxSentence:
    #     #print(f"Sentince length: {maxSentenceLen} exceeds the model's max sentince length: {maxSentence}")
    
    # The current working chunk
    workingChunk = ""
    
    # All completed chunks
    completedChunks = []

    length = 0
    count = -1

    #print("sentenceList len", len(sentencesList),)
    # Iter over all sentences
    for sentence in sentencesList:
        
        sentenceTokenizedLength = len(tokenizer.tokenize(sentence))
        count += 1
        combinedLength =  sentenceTokenizedLength + len(tokenizer.tokenize(workingChunk))
        
        #print("senteince len:", sentenceTokenizedLength)
        #print("checking if combined lenth:", combinedLength, "is less than max:", tokenizer.max_len_single_sentence)
        # If the combined length is within permissable length
        if(combinedLength < tokenizer.max_len_single_sentence):
            
            # Then add the sentence
            workingChunk += sentence + " "
            length = combinedLength
            #print("added to working chunk")
            # Also if this is the last chunk: strip whitespace and save it to completedChunks
            if count == len(sentencesList) -1:
                completedChunks.append(workingChunk.strip())
                #print("saved chunk thanks to count==len check")
        
        # Otherwise, if adding this sentence breaches the maxmimum allowed chunks
        else:
            #print("this breaches allowed chunk len")
            
            # If this sentence is bigger than the max sentence lentgh, we will have to arbitrarily split it.
            # this is never good, but it's all we can do for now
            while (sentenceTokenizedLength > maxSentenceTokens):
                #print("sentnce is bigger than max len.  splitting loop")
                splitIndex = len(sentence)/2
                
                
                # keep lowering the split index until we can get a valid length
                newTokenizedSentence = tokenizer.tokenize(sentence[:splitIndex])
                while(newSentenceTokenizedLength > maxSentenceTokens):
                    #print("smaller splitting loop")
                    splitIndex = splitIndex/2
                    newTokenizedSentence = tokenizer.tokenize(sentence[:splitIndex])
                    newSentenceTokenizedLength = len(newTokenizedSentence)
                    
                    
                # by now we know the newtokenized is valid, add it
                completedChunks.append(newTokenizedSentence)
                
                sufficientSizeSentence = sentence[:splitIndex]
                sentenceRemaining = sentence[splitIndex:]
                completedChunks.append(sufficientSizeSentence.strip())
                #print("saved this thing")
            
            # Save the chunk we have, we can't add any more to it.
            completedChunks.append(workingChunk.strip())
            #print("append working chunk to completed chunk & reset workingChunk")
   
            # RESET the working chunk
            # Initialize it with the overflowing chunk and start again
            workingChunk = sentence
            length = sentenceTokenizedLength   
    
    return completedChunks


# We want to "batch" the inputs (split them into multiple smaller sets to process simultaneously
# We also want to not hold all the dataset elements in memory, so we create a generator to return batches of them
def batchGenerator(elementList, batchSize):
    for i in range(0, len(elementList), batchSize):
        yield elementList[i : i + batchSize]

     


## import data

In [5]:
# Load all the data
datasetDict = load_dataset("ccdv/govreport-summarization")
# this dataset is a dict with train, validation, and test

# get lengths of dataset splits
split_lengths = [len(datasetDict[split])for split in datasetDict]

# Print our dataset for sanity checking
print("Our dataset:")
trainLen = datasetDict["train"].num_rows
testLen =  datasetDict["test"].num_rows
validationLen = datasetDict["validation"].num_rows
splits = [f"Set '{split}': rows:{datasetDict[split].num_rows}, features:{datasetDict[split].column_names}" for split in datasetDict]
print("\n".join(splits))

# Print a a sample report and summary
print("\nSample Report:")

print(datasetDict["test"][1]["report"])

print("Sample Summary:")

print(datasetDict["test"][1]["summary"])


No config specified, defaulting to: govreport-summarization/document
Found cached dataset govreport-summarization (C:/Users/Adam/.cache/huggingface/datasets/ccdv___govreport-summarization/document/1.0.0/57ca3042de9c40c218cc94084cbc80a99a161036134bfc88112c57d251443590)


  0%|          | 0/3 [00:00<?, ?it/s]

Our dataset:
Set 'train': rows:17517, features:['report', 'summary']
Set 'validation': rows:973, features:['report', 'summary']
Set 'test': rows:973, features:['report', 'summary']

Sample Report:
A variety of federal laws, regulations, and policies establish requirements and guidance for EPA to follow when appointing members to serve on advisory committees. For example, one purpose of FACA is to ensure that uniform procedures govern the establishment and operation of advisory committees. Also under FACA, an agency establishing an advisory committee must, among other things, require the committee’s membership to be balanced in terms of the points of view represented and the functions to be performed by the committee. In addition, federal ethics regulations establish when and how federal officials should review financial disclosure forms to identify and prevent conflicts of interest prohibited by federal law for any prospective committee members required to file these forms in connectio

In [16]:
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader

# Create a dataloader
train_dataloader = DataLoader(
  datasetDict["train"], shuffle=True, batch_size=8
)

# for batch in train_dataloader:
#   print({k: len(v) for k, v in batch.items()})
#   break
# for batch in train_dataloader:
#   reports = batch["report"]
#   print(len(reports))
#   # reports = items['report']
  
#   # reports = items.report
#   print("type", type(items))
#   print({k: len(v) for k, v in items})


## Training Loop

In [33]:
def generateLongSummary(reports:[str], summaries:[str]):
    
    # chunk the reports into digestable sizes for the model
    chunkedReports = [chunkingFunction(report) for report in reports]
    # tokenize all these chunks
    tokenizedChunkedReports = [tokenizer.tokenize(chunk) for chunk in chunkedReports]
    
    
    
    # print(len(chunkedReports), len(summaries))
    
    # Iter over all report/summary pairs
    for idx in range(len(summaries)):
        
        generatedSummaries = []
        # Iter over all reports
        
        summaries = model.generate(input_ids=tokenChunk["input_ids"].to(device),
                        attention_mask=tokenChunk["attention_mask"].to(device), 
                        length_penalty=0.8, num_beams=8, max_length=128)
            
            
        print("chunks len:", len(chunkedReports[idx]))

for x in range(1):
    data_iter = iter(train_dataloader)
    single_batch = next(data_iter)

    reports = single_batch["report"]
    summaries = single_batch["summary"]

    generateLongSummary(reports, summaries)

8 8
chunks len: 5
chunks len: 29
chunks len: 10
chunks len: 29
chunks len: 12
chunks len: 31
chunks len: 5
chunks len: 16


In [8]:
from tqdm.auto import tqdm
from transformers import AdamW
from transformers import get_scheduler

# Num of training epochs we'll perform
num_epochs = 1

optimizer = AdamW(model.parameters(), lr=5e-5) # optimier for our training


#num training steps to perform
num_training_steps = num_epochs * trainLen

# init progress bar
progress_bar = tqdm(range(num_training_steps))


lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)


model.train()
for epoch in range(1):
    # Get each batch of data
    for batch in train_dataloader:
        reports = batch["report"]
        summaries = batch["summary"]
        
        chunkedReports = [chunkingFunction(report) for report in reports]
        print("chunkedReport len:", len(chunkedReports))
        
        # batch = {k: v.to(device) for k, v in batch.items()}
        # outputs = model(**batch)
        # loss = outputs.loss
        # loss.backward()
        
        # optimizer.step()
        # lr_scheduler.step()
        # optimizer.zero_grad()
        progress_bar.update(1)



  0%|          | 0/17517 [00:00<?, ?it/s]

chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8


Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors


chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport len: 8
chunkedReport

KeyboardInterrupt: 