In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset
import json
import os

In [4]:
# !pip install peft

In [5]:
with open("../keys.json", "r") as file:
    token_data = json.load(file)
HUGGINGFACE_TOKEN = token_data["huggingface_access_token"]

os.environ["HF_TOKEN"] = HUGGINGFACE_TOKEN 

In [6]:
# Define the instruct-tuned model checkpoint
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
scratch_dir = "../models"

# Load the tokenizer (force using the slow tokenizer)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=scratch_dir)

# Load the model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=scratch_dir,
    device_map="auto",
    torch_dtype=torch.float16
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
# Define the LoRA configuration.
# You can adjust r, alpha, dropout, and target modules according to your needs.
lora_config = LoraConfig(
    r=8,                      # rank of update matrices
    lora_alpha=32,            # scaling factor
    target_modules=["q_proj", "v_proj"],  # which modules to adapt (example modules)
    lora_dropout=0.1,         # dropout probability for lora layers
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

# Wrap the model with LoRA
model = get_peft_model(model, lora_config)

# Load and preprocess a dataset (here we use Wikitext-2 as an example)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")


In [6]:
sample_dataset = dataset.select(range(100))

In [7]:
# Set the pad_token if it isn't already defined
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    
def tokenize_function(example):
    outputs = tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=512  # adjust max_length as required
    )
    outputs["labels"] = outputs["input_ids"].copy()
    return outputs

tokenized_dataset = sample_dataset.map(tokenize_function, batched=True)
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])


In [8]:

# Define training arguments.
training_args = TrainingArguments(
    output_dir="./lora_mistral",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # simulate larger batch sizes
    learning_rate=1e-4,
    fp16=True,                     # mixed precision training
    logging_steps=10,
    save_steps=100,
    evaluation_strategy="no"
)





In [9]:
# Create the Trainer instance
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

# Start training
trainer.train()

# Save the LoRA adapter weights (or the entire model if desired)
model.save_pretrained("./lora_mistral_adapter")

Step,Training Loss


config.json:   0%|          | 0.00/601 [00:00<?, ?B/s]

## Predict pipeline for Mistral-7b

In [1]:
import os
import json
from langgraph.prebuilt import create_react_agent
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from IPython.display import Markdown, display
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Dataset can be downloaded from
# !wget https://s3.amazonaws.com/ai2-vision-textbook-dataset/dataset_releases/tqa/tqa_train_val_test.zip
# !unzip tqa_train_val_test.zip

# Load the dataset JSON file
data_path = "../data/tqa_train_val_test/train/tqa_v1_train.json"
with open(data_path, "r", encoding="utf-8") as f:
    data = json.load(f)

# For a quick check, print the first record
print(data[0])

{'adjunctTopics': {'Apply Concepts': {'content': {'figures': [], 'mediaLinks': [], 'text': '5. A glacier is melting. What are all of the scientists you can think of who might be involved in studying this glacier? What would each of them do? '}, 'orderID': 't_12'}, 'Introduction': {'content': {'figures': [{'caption': 'FIGURE 1.10 Earth as seen from Apollo 17.', 'imagePath': 'textbook_images/earth_science_and_its_branches_20010.png'}], 'mediaLinks': [], 'text': 'Earth Science is the study of all aspects of our planet Earth. Earth Science is not just about the molten lava, icy mountain peaks, steep canyons and towering waterfalls of the continents. Earth Science includes the atmosphere and oceans. The field also looks out into the solar system, galaxy, and universe. Earth scientists seek to understand the beautiful planet on which we depend (Figure 1.10). Different branches of Earth Science study one particular part of Earth. Since all of the branches are connected, specialists work toget

In [3]:
def extract_textbook_text(entry):
    """
    Extract textbook text from a TQA dataset entry.
    Retrieves text from the 'lessonName', 'adjunctTopics', and 'topics' sections.
    """
    texts = []
    
    # Include lesson title
    lesson_name = entry.get("lessonName", "")
    if lesson_name:
        texts.append(f"Lesson: {lesson_name.strip()}")
    
    # Extract text from adjunctTopics
    adjunct_topics = entry.get("adjunctTopics", {})
    for topic in adjunct_topics.values():
        content = topic.get("content", {})
        text = content.get("text", "")
        if text:
            texts.append(text.strip())
            
    # Extract text from topics
    topics = entry.get("topics", {})
    for topic in topics.values():
        content = topic.get("content", {})
        text = content.get("text", "")
        if text:
            texts.append(text.strip())
    
    return texts

In [16]:
import os

def extract_textbook_text_with_image_explainations(entry):
    """
    Extract textbook text from a TQA dataset entry.
    Retrieves text from:
      - 'lessonName'
      - 'adjunctTopics' and their content text and figures (with LLava explanations)
      - 'topics' and their content text and figures (with LLava explanations)
    
    For figures, this function assumes that:
      - Figures with image paths containing "textbook_images" have their LLava explanations 
        stored in the "textbook_images_llava_captions" folder.
      - Figures with image paths containing "teaching_images" have their LLava explanations 
        stored in the "teaching_images_llava_captions" folder.
    
    Explanation files are assumed to be named as the image file name (without its extension)
    plus a ".txt" extension, and are located under the base directory:
    "../data/tqa_train_val_test/train/"
    """
    texts = []
    
    # Include lesson title
    lesson_name = entry.get("lessonName", "")
    if lesson_name:
        texts.append(f"Lesson: {lesson_name.strip()}")
    
    # Base directory for explanation files
    base_dir = "../data/tqa_train_val_test/train/"
    
    def extract_content_and_explanations(section):
        section_texts = []
        content = section.get("content", {})
        
        # Append the text content if available
        text = content.get("text", "")
        if text:
            section_texts.append(text.strip())
        
        # Process figures for captions and explanations
        figures = content.get("figures", [])
        for fig in figures:
            caption = fig.get("caption", "").strip()
            image_path = fig.get("imagePath", "")
            if image_path:
                # Determine the appropriate explanation folder based on the image path
                if "textbook_images" in image_path:
                    explanation_folder = "textbook_images_llava_captions"
                elif "teaching_images" in image_path:
                    explanation_folder = "teaching_images_llava_captions"
                else:
                    explanation_folder = None
                
                if explanation_folder:
                    file_name = os.path.basename(image_path)
                    # Remove file extension and add .txt
                    base_name = os.path.splitext(file_name)[0]
                    explanation_file = os.path.join(base_dir, explanation_folder, base_name + ".txt")
                    print("Looking for explanation file:", explanation_file)
                    if os.path.exists(explanation_file):
                        with open(explanation_file, 'r') as f:
                            explanation = f.read().strip()
                        if explanation:
                            # Combine caption and explanation in a single string
                            combined_text = ""
                            if caption:
                                combined_text += f"Image Caption: {caption}. "
                            combined_text += f"Image Explanation: {explanation}"
                            section_texts.append(combined_text)
        return section_texts

    # Extract text and explanations from adjunctTopics
    adjunct_topics = entry.get("adjunctTopics", {})
    for topic in adjunct_topics.values():
        texts.extend(extract_content_and_explanations(topic))
    
    # Extract text and explanations from topics
    topics = entry.get("topics", {})
    for topic in topics.values():
        texts.extend(extract_content_and_explanations(topic))
    
    return texts


In [17]:
# Combine the extracted text from the first record
extracted_texts = extract_textbook_text_with_image_explainations(data[0])
combined_text = "\n\n".join(extracted_texts)
print("Combined Text:\n")
print(combined_text)

Looking for explanation file: ../data/tqa_train_val_test/train/textbook_images_llava_captions/earth_science_and_its_branches_20010.txt
Looking for explanation file: ../data/tqa_train_val_test/train/textbook_images_llava_captions/earth_science_and_its_branches_20017.txt
Looking for explanation file: ../data/tqa_train_val_test/train/textbook_images_llava_captions/earth_science_and_its_branches_20011.txt
Looking for explanation file: ../data/tqa_train_val_test/train/textbook_images_llava_captions/earth_science_and_its_branches_20012.txt
Looking for explanation file: ../data/tqa_train_val_test/train/textbook_images_llava_captions/earth_science_and_its_branches_20013.txt
Looking for explanation file: ../data/tqa_train_val_test/train/textbook_images_llava_captions/earth_science_and_its_branches_20014.txt
Looking for explanation file: ../data/tqa_train_val_test/train/textbook_images_llava_captions/earth_science_and_its_branches_20015.txt
Looking for explanation file: ../data/tqa_train_val_tes

In [18]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

In [19]:
# Define the instruct-tuned model checkpoint
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
scratch_dir = "../models"

# Load the tokenizer (force using the slow tokenizer)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=scratch_dir)

# Load the model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=scratch_dir,
    device_map="auto",
    torch_dtype=torch.float16
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [20]:
########################
# 1. Generate Mindmap  #
########################
def generate_mindmap(text, model, tokenizer, max_tokens=300):
    """
    Generate a mind map in Mermaid syntax using Mistral.
    The mind map should capture the main topics and subtopics hierarchically.
    The model is instructed to output concise and correctly formatted Mermaid syntax.
    """

    prompt = (
        "Generate a mind map in Mermaid syntax for the following textbook text. "
        "The mind map should capture the main topics and their subtopics clearly in a hierarchical structure. "
        "Do NOT include a 'root' node; use a meaningful central topic. "
        "Include minimal Mermaid syntax so it renders correctly, and keep it concise.\n\n"
        f"{text}\n\n"
        "Mermaid Mind Map:"
    )

    # Tokenize the input
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")

    # Generate text with Mistral
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,  # Sampling for variation
            temperature=0.7,  # Adjust randomness
            top_p=0.9  # Nucleus sampling for diversity
        )

    # Decode output
    mindmap = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return mindmap.strip()

In [21]:
import re

def clean_mermaid_code(code):
    """
    Extract the text between ```mermaid and ``` fences only.
    Returns the extracted Mermaid syntax, or the original code if not found.
    """
    match = re.search(r"```mermaid(.*?)```", code, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        return code.strip()


In [22]:
##############################
# 3. Render Mermaid Diagram  #
##############################
def render_mermaid(mindmap_code, output_filename='mermaid_output.svg', theme='forest'):
    """
    Render Mermaid diagram as an SVG file using the Mermaid CLI (mmdc).
    You can specify a theme for better visualization (e.g., 'forest', 'dark', 'neutral', etc.).
    """
    # Write Mermaid code to a temporary file
    temp_file = "temp_mermaid.mmd"
    with open(temp_file, "w", encoding="utf-8") as f:
        f.write(mindmap_code)
    
    # Call the Mermaid CLI with a chosen theme
    subprocess.run(
        ["mmdc", "-i", temp_file, "-o", output_filename, "-t", theme],
        check=True
    )
    
    # Display the generated SVG inline (Jupyter/IPython)
    display(SVG(output_filename))
    
    # Clean up the temporary file
    os.remove(temp_file)


In [23]:
# Generate the Mermaid mind map for the combined textbook text
mindmap_code = generate_mindmap(combined_text, model, tokenizer)
print("\nGenerated Mind Map (Mermaid format):\n")
print(mindmap_code)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Generated Mind Map (Mermaid format):

Generate a mind map in Mermaid syntax for the following textbook text. The mind map should capture the main topics and their subtopics clearly in a hierarchical structure. Do NOT include a 'root' node; use a meaningful central topic. Include minimal Mermaid syntax so it renders correctly, and keep it concise.

Lesson: earth science and its branches

5. A glacier is melting. What are all of the scientists you can think of who might be involved in studying this glacier? What would each of them do?

Earth Science is the study of all aspects of our planet Earth. Earth Science is not just about the molten lava, icy mountain peaks, steep canyons and towering waterfalls of the continents. Earth Science includes the atmosphere and oceans. The field also looks out into the solar system, galaxy, and universe. Earth scientists seek to understand the beautiful planet on which we depend (Figure 1.10). Different branches of Earth Science study one particular pa

In [28]:
import subprocess

# 2) Clean code to remove markdown fences
cleaned_code = clean_mermaid_code(mindmap_code)
print("Cleaned Mermaid code:")
print(cleaned_code)

# 3) Render with the desired theme (e.g., 'forest', 'dark', 'neutral')
render_mermaid(cleaned_code, 'mermaid_output.svg', theme='forest')

Cleaned Mermaid code:
graph LR
    Earth Science
    Geology --> Earth Material & Structures
    Oceanography --> Ocean Basins & Water Movement
    Meteorology --> Atmosphere & Climate
    Astronomy --> Solar System & Universe
    Environmental Science --> Humans & Planet
    Paleontology --> Fossils & History
    Volcanology --> Molten Lava & Volcanoes
    Seismology --> Earthquakes
    Mineralogy --> Minerals
    Planetary Geology --> Other Planets
    Petroleum Geology --> Oil
    Soil Geology --> Soil
    Geography --> Earth Surface & Maps
    Marine Geology --> Ocean Basins
    Marine Biology --> Marine Life
    Chemical Oceanography --> Ocean Water Chemistry
    Climatology --> Climate
    Global Warming --> Atmosphere Heating
    Ecology --> Ecosystems
    Spacecraft --> Space Travel & Information
    Optical Telescopes --> Stars & Galaxies
    Radio Telescopes --> Radio Waves
    Black Holes --> Space Objects
    Universe Creation --> Cosmology
    Life on Other Planets --> Ast

FileNotFoundError: [Errno 2] No such file or directory: 'mmdc'

## Prediction pipeline (LORA)

In [14]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

In [16]:
# Define the instruct-tuned model checkpoint
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
scratch_dir = "../models"

# Load the tokenizer (force using the slow tokenizer)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=scratch_dir)

# Load the model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=scratch_dir,
    device_map="auto",
    torch_dtype=torch.float16
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [17]:
# Load the LoRA adapter weights into the base model
model = PeftModel.from_pretrained(model, "./lora_mistral_adapter")
model.eval()  # set model to evaluation mode

# Example prompt for inference
prompt = "Once upon a time"

# Tokenize the prompt
inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")

# Generate output (adjust parameters like max_new_tokens as desired)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,   # maximum number of new tokens to generate
        do_sample=True,      # use sampling; set to False for greedy decoding
        temperature=0.7      # adjust temperature for randomness
    )

# Decode and print the generated text
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Once upon a time, a friend gave me a book with a simple title, The Bible. I had heard of it before, but I had never read it. I had no idea what I was in for. I never imagined it would become an essential part of my
