# This time it will work

## Imports + Device setup

In [None]:
import torch
import nltk  # used to split long input text into sentences before chunking

# nltk.tokenize.sent_tokenize needs the Punkt sentence-splitter data.
# Older NLTK releases load it from the 'punkt' resource; NLTK 3.8.2 and later
# load 'punkt_tab' instead, so fetch both to work across versions.
nltk.download('punkt')
nltk.download('punkt_tab')

from transformers import AutoModelForSeq2SeqLM, pipeline, AutoTokenizer



# Release unoccupied memory held by torch's CUDA caching allocator, so a
# re-run of this cell starts from a clean slate.
torch.cuda.empty_cache()

# Sanity check: the CUDA version this torch build was compiled against.
print("torch cuda version:", torch.version.cuda)

# Report the visible CUDA hardware, or say that there is none.
if not torch.cuda.is_available():
    print("Cuda not available")
else:
    deviceCount = torch.cuda.device_count()
    currentNumber = torch.cuda.current_device()
    deviceName = torch.cuda.get_device_name(currentNumber)
    print(f"Cuda available. {deviceCount} device(s) detected.")
    print(f"Current Device: Number:{currentNumber} Name:{deviceName}")

# The model and every tensor fed to it must live on ONE device:
# the first GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

## Init Model + Tokenizer
- We select our model checkpoint (this downloads and caches a Hugging Face model and its matching tokenizer).
- We examine the maximum input length the model accepts. This limit is very important because it dictates our chunking method.

In [None]:
# Model checkpoint we're using from the Hugging Face Hub.
modelID = "google/flan-t5-base"

# Load the seq2seq model (downloaded and cached on first use) and move it to
# the device selected in the setup cell.
model = AutoModelForSeq2SeqLM.from_pretrained(modelID).to(device)

# Use the tokenizer that corresponds to this checkpoint.
tokenizer = AutoTokenizer.from_pretrained(modelID)

# Sanity checking: show the tokenizer's length limits, which dictate how the
# input text has to be chunked.

# Maximum number of tokens in one model input.
maxInput = tokenizer.model_max_length
# Maximum number of non-special tokens in ONE input sequence: the token budget
# left for actual text once the tokenizer's special tokens are added.
# ("sentence" here means a single input sequence, not a grammatical sentence.)
maxSentence = tokenizer.max_len_single_sentence
# Number of special tokens the tokenizer adds to a single input sequence.
# This depends on the model, so it is printed below rather than assumed;
# for T5-style tokenizers it is presumably 1 (just the trailing </s>).
specialTokens = tokenizer.num_special_tokens_to_add()

print(f"Max input length: {maxInput}, Max Sentence length: {maxSentence}, SpecialTokens: {specialTokens}")

## Chunking and other Helpers
- We need a way to split the input into digestible chunks for the transformer to operate on. This is called chunking.

In [None]:
# The chunking function iterates over the sentences of the text and packs as many as fit into each chunk.
# Splitting on sentence boundaries is a quick way to break the text up at coherent places.
def chunkingFunction(fullTextString):
    
    # tokenize the large amount of text into sentinces
    sentencesList = nltk.tokenize.sent_tokenize(fullTextString)

    # check max length (in tokens) of all sentences
    # if we exceed this anywhere we need to arbitrarily break it up
    maxSentenceLen = max([len(tokenizer.tokenize(sentence)) for sentence in sentencesList])
    
    # This can be handled later, but for now we will raise an exception
    if maxSentenceLen > maxSentence:
        raise Exception(f"Sentince length: {maxSentenceLen} exceeds the model's max sentince length: {maxSentence}")
    
    # The current working chunk
    workingChunk = ""
    
    # All completed chunks
    completedChunks = []

    length = 0
    count = -1

    
    # Iter over all sentences
    for sentence in sentencesList:
        count += 1
        combinedLength = len(tokenizer.tokenize(sentence)) + length
        
        # If the combined length is within permissable length
        if(combinedLength < tokenizer.max_len_single_sentence):
            
            # Then add the sentence
            workingChunk += sentence + " "
            length = combinedLength
            
            # Also if this is the last chunk: strip whitespace and save it to completedChunks
            if count == len(sentencesList) -1:
                completedChunks.append(workingChunk.strip())
        
        # Otherwise, if adding this sentence breaches the maxmimum allowed chunks
        else:
            # Save the chunk we have, we can't add any more to it.
            completedChunks.append(workingChunk.strip())
            
   
            # RESET the working chunk
            # Initialize it with the overflowing chunk and start again
            workingChunk = sentence
            length = len(tokenizer.tokenize(sentence))   