In [None]:
# Mount Google Drive
# Everything this notebook reads or writes under /content/drive/MyDrive
# (model folder, spell-check dictionary, novel text files) needs this mount.
# NOTE(review): force_remount=True presumably re-mounts even when Drive is
# already mounted in this runtime - confirm against the Colab docs.

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
# Install runtime dependencies: bitsandbytes (the model below is loaded with
# load_in_8bit=True) and symspellpy (imported below for spell correction).
!pip install -U bitsandbytes
!pip install symspellpy

In [None]:
# This version generates the novel in chunks (one per chapter) and can save/resume after a crash (the resume path has not been thoroughly tested here).

import gc
import os
import string

import torch
from symspellpy import SymSpell, Verbosity
from transformers import AutoTokenizer, BitsAndBytesConfig, LlamaForCausalLM, LogitsProcessor

# Memory cleanup before loading
gc.collect()
torch.cuda.empty_cache()

# Path to model folder - make sure to download this first
model_dir = "/content/drive/MyDrive/my-model-folder"

# Load tokenizer (local files only, nothing is fetched from the Hub)
tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
print("Tokenizer successfully loaded.")

# Device that prompt tensors are moved to before generation
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device.upper()}.")

# Let the loader decide layer placement; layers that do not fit are swapped
# to the offload folder.
device_map = "auto"
offload_folder = "/content/offload"  # temporary folder for swapped layers
os.makedirs(offload_folder, exist_ok=True)

# 8-bit quantization settings. These are passed through a BitsAndBytesConfig
# because giving load_in_8bit / llm_int8_enable_fp32_cpu_offload directly to
# from_pretrained() is deprecated.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

# Load model
model = LlamaForCausalLM.from_pretrained(
    model_dir,
    device_map=device_map,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
    local_files_only=True,
    offload_folder=offload_folder
)

# Force materialization (no meta tensors)
# NOTE(review): any parameter still on the "meta" device is replaced with
# zeros on the CPU, which discards its real weights - confirm this never
# triggers for layers that are actually used, or load them from disk instead.
for _, p in model.named_parameters():
    if p.device.type == "meta":
        p.data = torch.zeros_like(p, device="cpu")

model.eval()
torch.cuda.empty_cache()

# Token masking setup
# Remove non-English letters and most punctuation: a token is allowed only if
# every character of its decoded text is in allowed_chars.
allowed_chars = string.ascii_letters + string.digits + "!\"$\',.;?" + " \n"
allowed_tokens = []

for i in range(tokenizer.vocab_size):
    token_str = tokenizer.decode([i])
    if all(ch in allowed_chars for ch in token_str):
        allowed_tokens.append(i)

# Sets give O(1) membership tests; testing against the lists for every
# vocabulary id would be quadratic in the vocabulary size.
_allowed_token_set = set(allowed_tokens)
_special_id_set = set(tokenizer.all_special_ids)

# Every non-special token that contains at least one character outside
# allowed_chars.
disallowed_tokens = [
    i for i in range(tokenizer.vocab_size)
    if i not in _allowed_token_set and i not in _special_id_set
]

# Vocabulary facts reused by the masking code below
n_tokens = tokenizer.vocab_size
special_tokens = tokenizer.all_special_ids

# Letters to omit (matched in both upper and lower case by
# detect_unwanted_letters); alternative candidate sets kept for experiments:
# omit_list = "zqjxkvbpgyfmw" # least frequent in texts
# omit_list = "qjxzwkvfybhmp" # least frequent in dictionaries
# omit_list = "zqxjkvcdugmpy" # chat gpt recommendation
omit_list = "e" # letters to mask/ eliminate from LLM output

def detect_unwanted_letters(token_str):
    """Return True if token_str contains any letter from omit_list, in either case."""
    # Collect the upper- and lower-case characters of every omitted letter.
    banned_chars = set()
    for letter in omit_list:
        banned_chars.update(letter.upper())
        banned_chars.update(letter.lower())

    for ch in token_str:
        if ch in banned_chars:
            return True
    return False

# Ids of non-special tokens whose decoded text contains an omitted letter
mask_indices = []
for token_id in range(n_tokens):
    decoded_token = tokenizer.decode([token_id])
    if detect_unwanted_letters(decoded_token) and token_id not in special_tokens:
        mask_indices.append(token_id)

# all characters to mask: omitted-letter tokens plus tokens with disallowed characters
final_mask = list(set(mask_indices) | set(disallowed_tokens))

masked_count = len(final_mask)
print(f"Model  total number of tokens: {n_tokens:,}")
print(f"Masked total number of tokens: {masked_count:,}")

class MaskLogits(LogitsProcessor):
    """Logits processor that forbids a fixed set of token ids.

    The scores of every id in ``mask_indices`` are set to ``-inf`` at each
    decoding step so those tokens can never be chosen.
    """

    def __init__(self, mask_indices):
        """Store the token ids to suppress.

        Args:
            mask_indices: iterable of integer token ids to mask.
        """
        self.mask_indices = mask_indices
        # Index tensor built lazily on the device of the first scores tensor.
        # The id list is treated as fixed after construction.
        self._index_tensor = None

    def __call__(self, input_ids, scores):
        """Set the masked ids' scores to ``-inf`` in place and return ``scores``."""
        index = self._index_tensor
        if index is None or index.device != scores.device:
            # Converting a large Python list on every decoding step is slow;
            # build the index tensor once per device and reuse it.
            index = torch.as_tensor(
                list(self.mask_indices), dtype=torch.long, device=scores.device
            )
            self._index_tensor = index
        scores[:, index] = -float("inf")
        return scores

# spell checker (sometimes model will output words that don't exist - some of them can be corrected without adding omitted letters)
sym_spell = SymSpell(max_dictionary_edit_distance=2)
dictionary_path = "/content/drive/MyDrive/word_dictionary.txt"
# load_dictionary reports a missing file through its return value rather than
# raising; fail fast so a long generation run does not silently go without
# spell correction.
if not sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1):
    raise FileNotFoundError(f"Spell-check dictionary not found: {dictionary_path}")

def clean_text(text):
    """Spell-correct ``text`` word by word without introducing omitted letters.

    The text is split on whitespace and re-joined with single spaces, so
    newlines and repeated spaces are collapsed. Words that are not purely
    alphabetic (for example words with attached punctuation) are kept as-is.
    Each remaining word is replaced by the most frequent closest dictionary
    suggestion that contains no letter from ``omit_list`` (case-insensitive).
    If no such suggestion exists the word is left unchanged.

    Args:
        text: raw model output.

    Returns:
        The corrected text as a single space-separated string.
    """
    # Match omitted letters regardless of case, like detect_unwanted_letters.
    banned_chars = set(omit_list.lower()) | set(omit_list.upper())

    corrected_words = []
    for word in text.split():
        if not word.isalpha():
            corrected_words.append(word)
            continue

        suggestions = sym_spell.lookup(word.lower(), Verbosity.CLOSEST)
        # Try every closest suggestion, not only the first, so a usable
        # correction is not lost when the top hit contains an omitted letter.
        term = next(
            (s.term for s in suggestions if not banned_chars.intersection(s.term)),
            None,
        )
        if term is None:
            corrected_words.append(word)
        else:
            # Keep the capitalization of the original word's first letter.
            corrected_words.append(term.capitalize() if word[0].isupper() else term)
    return " ".join(corrected_words)

# Path to saved file (i.e. file in which novel is saved)
# Surrounding whitespace is stripped so a stray space is not mistaken for a
# file path; an empty answer means a new novel is started.
existing_file_prompt = input("Enter path to existing novel file (hit ENTER if new novel): ").strip()

def get_existing_novel(file_path):
    """Parse a saved novel file into the pieces needed to resume generation.

    The file is expected to consist of sections separated by blank lines:
    ``Prompt: ``, ``Characters: ``, ``Setting: ``, ``Tone: `` and
    ``Chapters: `` sections, then ``CHAPTER n: ...`` headings each followed by
    a section of story text.

    Args:
        file_path: path to the saved novel text file.

    Returns:
        A dict with:
          - "prompt", "characters", "setting", "tone", "chapters": the text
            of each labelled section without its label (only present when the
            section exists in the file);
          - "current-chunk": last 1000 characters of the last story-text
            section ("" if there is none);
          - "prev-chunk": the same for the story-text section before it;
          - "chunk": zero-based index of the last story-text section (0 if
            there is none).
    """
    labelled_sections = {
        "Prompt: ": "prompt",
        "Characters: ": "characters",
        "Setting: ": "setting",
        "Tone: ": "tone",
        "Chapters: ": "chapters",
    }
    parts = {"prev-chunk": "", "current-chunk": "", "chunk": 0}

    with open(file_path, "r", encoding="utf-8") as f:
        sections = f.read().split("\n\n")

    text_sections_seen = 0
    for raw_section in sections:
        # A chapter's text ends with "\n" followed by a blank line, so the
        # next heading arrives as "\nCHAPTER ..."; drop leading newlines
        # before looking at the label.
        section = raw_section.lstrip("\n")
        if not section.strip():
            continue  # blank separator, not story text

        print(f"Section Retrieved: {section}")

        if section.startswith("CHAPTER "):
            continue  # chapter heading, not story text

        for label, key in labelled_sections.items():
            if section.startswith(label):
                # Store the value without its label so a resumed run sees the
                # same strings as the run that generated them.
                parts[key] = section[len(label):]
                break
        else:
            # the actual text, not including the chapter title
            parts["prev-chunk"] = parts["current-chunk"]
            parts["current-chunk"] = section[-1000:]
            parts["chunk"] = text_sections_seen
            text_sections_seen += 1

    return parts

# Gets existing novel pieces (if they exist/ resuming story)
existing_novel_parts = get_existing_novel(existing_file_prompt) if existing_file_prompt else None

# Text generation
print("Model is ready with runtime token masking.")
# When resuming, reuse the saved prompt; otherwise ask the user for the idea.
prompt = existing_novel_parts["prompt"] if existing_file_prompt else input("> ")

# Save prompt
if existing_file_prompt:
    # Resuming: keep appending to the same file.
    file_path = existing_file_prompt
else:
    # The prompt is free text, so it may contain "/" or other characters that
    # are invalid in a file name, or be too long for one. Keep letters,
    # digits, spaces, "-" and "_", replace everything else, and cap the length.
    safe_name = "".join(
        ch if (ch.isalnum() or ch in " -_") else "_" for ch in prompt
    )[:100].strip() or "untitled"
    file_path = f"/content/drive/MyDrive/novel_{safe_name}.txt"
    with open(file_path, "a", encoding="utf-8") as f:
        f.write("Prompt: " + prompt + "\n\n")

# Generation
def generation(gen_prompt):
    """Run one sampled chat completion for gen_prompt and return the cleaned reply.

    The prompt is wrapped as a single user message, the model samples up to
    5000 new tokens with the token mask applied, and only the newly generated
    tokens are decoded and passed through clean_text.
    """
    conversation = [{"role": "user", "content": gen_prompt}]
    prompt_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,  # append the assistant turn marker
        return_tensors="pt",
    ).to(device)
    prompt_length = prompt_ids.shape[-1]

    sampling_options = dict(
        max_new_tokens=5000,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        repetition_penalty=1.2,   # penalize repeated tokens
        no_repeat_ngram_size=4,   # prevents same 4-token phrase repetition
        logits_processor=[MaskLogits(final_mask)],
    )

    with torch.no_grad():
        generated = model.generate(prompt_ids, **sampling_options)

    # Keep just the new part, not the prompt tokens echoed at the start
    new_token_ids = generated[0][prompt_length:]
    reply = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
    return clean_text(reply)

# Common opening sentence that is prepended to the character, setting, tone and chapter-outline prompts
starter_prompt = f"I want to write a full-length novel based on the following idea: {prompt}"

# used for enforcing who the characters are, setting, and tone
def generate_story_background():
    """Return the (characters, setting, tone) background texts for the novel.

    When resuming and all three pieces were saved, they are returned from
    ``existing_novel_parts``. Otherwise each piece is generated in turn
    (setting sees the characters, tone sees both), printed, and appended to
    the novel file as "Characters: ", "Setting: " and "Tone: " sections.

    Returns:
        Tuple of three strings: characters, setting, tone.
    """
    if existing_file_prompt:
        saved = existing_novel_parts
        # Only reuse the saved background when it is complete; a file that
        # was interrupted before this step falls through to generation.
        if all(key in saved for key in ("characters", "setting", "tone")):
            return saved["characters"], saved["setting"], saved["tone"]

    character_length = "This generation should be no longer than 200 words."
    setting_length = "This generation should contain no more than 40 words."
    tone_length = "This generation should be no longer than 80 words."

    character_prompt = f"{starter_prompt} Create a list of characters that fit this story, including demographic information, personality traits, and how they physically look."
    # The apostrophe in "character’s" must not be backslash-escaped: "\’" is
    # not a valid escape and would put a literal backslash into the prompt.
    character_add = "Every generated sentence should describe exactly one character.  Each character should only have one sentence describing them.  There should be an equal number of sentences as there are characters. Each sentence should list the character’s name, two to three pieces of demographic information, and two to three pieces of information that describe their personality.  The character should be described in this order and there should be no additional filler or unnecessary numbering of the demographic or personality traits. Do not make them optional choices to choose from.  Instead, each character must have a specific role that makes them integral to the story. An individual character description should take up no more than 35 words."
    setting_prompt = f"{starter_prompt} Create the general setting for this story, including the year (or range of years) and location."
    setting_add = "The generated output should contain exactly two sentences.  The first sentence should describe the time period in which the story takes place.  The second sentence should describe the general location or locations that the story takes place.  For example, if the story is about a man traveling to every country, the setting will be the world. Do not add any filler context. Do not give a list of possible settings.  Only generate one setting that the story takes place in."
    tone_prompt = f"{starter_prompt} Generate the tone that this story should follow, which should dictate the emotional and intellectual stance of the story."
    tone_add = "Directly answer this prompt by outlining the type of word choice, sentence structure, and style of writing. Do not add any filler context.  Do not generate different options for the tone.  Only generate one style of tone."

    characters = generation(f"{character_prompt} {character_add} {character_length}")
    print(f"\nCharacters: {characters}\n")
    setting = generation(f"{setting_prompt} {setting_add} Here are a list of characters in this story. Do not reference them in the generation: {characters} {setting_length}")
    print(f"\nSetting: {setting}\n")
    tone = generation(f"{tone_prompt} {tone_add} Here are a list of characters and the setting in this story. Do not reference them in the generation: Characters: {characters} Setting: {setting} {tone_length}")
    print(f"\nTone: {tone}\n")

    # Persist the background so a resumed run can reuse it.
    with open(file_path, "a", encoding="utf-8") as f:
        f.write("Characters: " + characters + "\n\n")
        f.write("Setting: " + setting + "\n\n")
        f.write("Tone: " + tone + "\n\n")

    return characters, setting, tone

# Chunk into pieces - roughly 91,000 words total (~6,500 words generated per 10,000 tokens)
# num_story_pieces: number of chapters / story beats requested and generated
num_story_pieces = 14
# tokens_per_chunk: a chapter keeps being extended until its text tokenizes to at least this many tokens
tokens_per_chunk = 10000
# Generates the chapters
def chunk_text(background_prompt):
    """Return the one-sentence-per-chapter outline of the novel.

    When resuming and an outline was saved, it is returned from
    ``existing_novel_parts``. Otherwise the model is asked for
    ``num_story_pieces`` story beats, and the result is printed and appended
    to the novel file as a "Chapters: " section.

    Args:
        background_prompt: characters/setting/tone instructions appended to
            the outline request.

    Returns:
        The outline as one string; beats are expected to be separated by
        periods.
    """
    if existing_file_prompt:
        saved_chapters = existing_novel_parts.get("chapters")
        # A file interrupted before the outline was written falls through to
        # generation.
        if saved_chapters:
            return saved_chapters

    # "told as N ... chapter summaries" must use the number of chapters
    # (num_story_pieces), not the per-chapter token budget.
    general_prompt = f"{starter_prompt} Create a list of {num_story_pieces} one-sentence story beats that together outline the entire novel from beginning to end. Each sentence should be broad enough to represent about 6,000–7,000 words of story, meaning each one covers many scenes or major developments, not just a single moment. The {num_story_pieces} sentences should form a complete, coherent story arc (setup, rising action, climax, resolution), introduce and evolve characters, conflicts, and themes logically, maintain consistent tone and style, and flow naturally from one sentence to the next. Think of it as a condensed version of the entire novel, told as {num_story_pieces} long, detailed chapter summaries — one sentence per chapter.  Each chapter should consist of exactly one sentence and each sentence should use exactly one period. Do not number, label, or list any sentence. Separate each chapter with exactly one period and one space. There should be no other periods anywhere in the text."
    all_chapters = generation(f"{general_prompt} {background_prompt}")
    print(f"\nAll Chapters: {all_chapters}\n")

    with open(file_path, "a", encoding="utf-8") as f:
        f.write("Chapters: " + all_chapters + "\n\n")

    return all_chapters

def chunk_generation(prev_chunk, chapter_number, chapter_sentence, background_prompt, prior_story=''):
    """Generate one chapter, appending each response to the novel file.

    The model is called repeatedly until the chapter text tokenizes to at
    least ``tokens_per_chunk`` tokens. The first request of a new chapter is
    built from the chapter's story beat and the tail of the previous chapter;
    later requests continue from the tail of the chapter written so far.

    Args:
        prev_chunk: text of the previous chapter ("" for the first chapter).
        chapter_number: 1-based chapter number used in the heading.
        chapter_sentence: the story beat this chapter is based on.
        background_prompt: characters/setting/tone instructions.
        prior_story: text of this chapter that already exists in the file
            (used when resuming); "" for a new chapter.

    Returns:
        The chapter text: ``prior_story`` followed by every new response,
        each written as a tab-indented line.
    """
    resuming_chapter = bool(prior_story.strip())

    # Write the chapter heading. When resuming part-way through a chapter the
    # heading is already in the file, so it is not written a second time.
    if not resuming_chapter:
        with open(file_path, "a", encoding="utf-8") as f:
            f.write(f"CHAPTER {chapter_number}: {chapter_sentence}\n\n")

    # Generate text
    chunk_whole_text = prior_story
    constant_text = "Continue this story exactly where it left off. Do not summarize or restart. Avoid repeating earlier parts of the story. Focus only on new developments."
    chapter_text = f"Base the story on the following sentence: {chapter_sentence}"
    additional_prompting = "Make sure to keep point of view (first, second, or third) consistent throughout the whole generation. Keep characters consistent with any existing story that you are given. Do not add new characters unless absolutely necessary. Make sure linking verbs (such as is and are) are used when needed."
    # Slice the previous chapter itself (last 1,000 characters) before adding
    # the lead-in; slicing the combined string would cut the lead-in off.
    continuation_text = f"Here is the first part of the story: {prev_chunk[-1000:]}" if prev_chunk else "" # text from previous chunk
    english_prompt = "Ensure the story continues naturally and does not repeat the same phrase or description. Avoid using invented or non-existent words (i.e. words that rarely appear in training data). If no suitable synonym exists, rephrase the thought naturally using other words.  Do not use too many symbol characters in a generation.  Only use English letters, words, and punctuation. Even though some letters are missing, ensure your sentences remain grammatical. Always include linking verbs like is, was, are, and were where appropriate and if possible. Write clearly and naturally, even if word choice feels limited."
    present_tense = "All text should be in the present tense, if possible." # might remove later

    # Instructions appended to every request, in a fixed order.
    shared_suffix = " ".join([english_prompt, additional_prompting, background_prompt, present_tense])

    # A resumed chapter continues from its own saved text rather than from
    # the previous chapter.
    initial_prompt = not resuming_chapter

    while len(tokenizer(chunk_whole_text).input_ids) < tokens_per_chunk:
        if initial_prompt:
            full_prompt = f"{prompt} {chapter_text} {continuation_text} {constant_text}"
        else:
            # uses the last 1,000 characters of the chapter so far - might want to change logic for more accuracy
            full_prompt = f"{prompt} Here is the first part of the story: {chunk_whole_text[-1000:]} {constant_text}"
        full_prompt += " " + shared_suffix

        response = generation(full_prompt)
        print(f"Current Response: \n{response}\n")

        # Append to the file immediately so progress survives a crash
        with open(file_path, "a", encoding="utf-8") as f:
            f.write("\t" + response + "\n")

        # Update variables
        initial_prompt = False
        chunk_whole_text += ("\t" + response + "\n")

        torch.cuda.empty_cache()
        gc.collect()

    # Blank line marks the end of the chapter's text section in the file
    with open(file_path, "a", encoding="utf-8") as f:
        f.write("\n\n")
    return chunk_whole_text

def generate_novel():
    """Generate (or resume) the whole novel chapter by chapter.

    Obtains the story background and the chapter outline, splits the outline
    into one beat per chapter, then calls ``chunk_generation`` for each
    chapter. When resuming, the chapter that was in progress is finished
    first. The text generated during this run is printed at the end.
    """
    whole_text = ""
    resuming = bool(existing_file_prompt)
    current_chunk = existing_novel_parts.get("chunk", 0) if resuming else 0

    characters, setting, tone = generate_story_background()
    background_prompt = f"Make sure to use the following characters: {characters} Make sure to use the following setting: {setting} Make sure to use the following tone: {tone}"

    chapters_text = chunk_text(background_prompt)
    # A saved outline may still carry its file label.
    if chapters_text.startswith("Chapters: "):
        chapters_text = chapters_text[len("Chapters: "):]
    # One beat per sentence; drop the empty pieces split(".") leaves behind
    # (for example after the final period) and surrounding spaces.
    parts = [beat.strip() for beat in chapters_text.split(".") if beat.strip()]

    # Never index past the beats that actually exist.
    total_chunks = min(num_story_pieces, len(parts))
    if total_chunks < num_story_pieces:
        print(f"Warning: outline has only {len(parts)} chapter sentences; expected {num_story_pieces}.")

    previous_text_chunk = existing_novel_parts.get("prev-chunk", "") if resuming else ""
    current_text_chunk = existing_novel_parts.get("current-chunk", "") if resuming else ""

    # This condition is if we are reloading the existing novel; we are finishing off the chapter
    if resuming and current_chunk < total_chunks:
        regen_text = current_text_chunk
        current_text_chunk = chunk_generation(previous_text_chunk, current_chunk + 1, parts[current_chunk], background_prompt, regen_text)

        # Make sure not to include bit of that chapter that was already added to the text file
        whole_text += current_text_chunk[len(regen_text):]
        current_chunk += 1

    while current_chunk < total_chunks:
        previous_text_chunk = current_text_chunk

        current_text_chunk = chunk_generation(previous_text_chunk, current_chunk + 1, parts[current_chunk], background_prompt)
        whole_text += current_text_chunk
        current_chunk += 1

        print(f"Chunk {current_chunk} complete!")

    print("Whole Generation: " + whole_text)

# Entry point: start (or resume) the full novel generation run
generate_novel()