In [1]:
!pip install -q transformers accelerate peft bitsandbytes datasets gradio

In [1]:
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, PeftModel
from datasets import load_dataset
import gradio as gr

In [2]:
# -------------------
# 1. Load dataset (SQuAD for demo)
# -------------------
flashcards = load_dataset("squad")

def format_example(example):
    context = example["context"]
    question = example["question"]
    answer = example["answers"]["text"][0] if len(example["answers"]["text"]) > 0 else "N/A"
    return {
        "input": f"Context: {context}",
        "output": f"Q: {question}\nA: {answer}"
    }

flashcards = flashcards.map(format_example)

In [3]:
# -------------------
# 2. Tokenizer + preprocessing
# -------------------
MODEL_NAME = "mistralai/Mistral-7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(example):
    text = f"{example['input']}\n{example['output']}"
    tokenized = tokenizer(
        text,
        max_length=320,
        truncation=True,
        padding="max_length"
    )
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

tokenized_train = flashcards["train"].map(tokenize_function, remove_columns=flashcards["train"].column_names)
tokenized_val = flashcards["validation"].map(tokenize_function, remove_columns=flashcards["validation"].column_names)

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

In [4]:
# -------------------
# 3. Model + QLoRA setup
# -------------------
print("Loading model with QLoRA...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto"
)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(base_model, lora_config)

Loading model with QLoRA...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# -------------------
# 4. Training setup
# -------------------
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="./flashcard-lora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    num_train_epochs=1.0,         # small demo run (increase for better results)
    logging_steps=50,
    save_total_limit=1,
    eval_strategy="no",
    save_strategy="no",
    fp16=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train.shuffle().select(range(2000)),   # subset for Colab
    eval_dataset=tokenized_val.shuffle().select(range(500)),
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("Training...")
trainer.train()
model.save_pretrained("./flashcard-lora-adapter")

  trainer = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 2}.


Training...


[34m[1mwandb[0m: Currently logged in as: [33mvaaruni-desai[0m ([33mvaaruni-desai-n-a[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference, mcp] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


Step,Training Loss
50,1.8106
100,1.7323
150,1.728
200,1.7084
250,1.7285
300,1.7353
350,1.7423
400,1.6699
450,1.7348
500,1.6939


In [6]:
# -------------------
# 5. Reload model with adapters for inference
# -------------------
print("Reloading trained adapters...")
inference_base = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto"
)
inference_model = PeftModel.from_pretrained(inference_base, "./flashcard-lora-adapter")
inference_model.eval()



Reloading trained adapters...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj)

In [16]:
# -------------------
# 6. Gradio app for flashcards
# -------------------
def generate_flashcards(text, num_cards=5):
    prompt = f"""
You are a flashcard generator.
Create exactly {num_cards} separate flashcards from the context.
Each flashcard must be in this format:

Q: <question>
A: <answer>

Context:
{text}
"""

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(inference_model.device)
    with torch.no_grad():
        outputs = inference_model.generate(
            **inputs,
            max_new_tokens=512,        # allow longer outputs
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
            num_return_sequences=1     # keep one long output
        )
    input_length = inputs["input_ids"].shape[1]

    # Slice off the prompt part
    result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
    # ✅ Post-process: split into flashcards
    flashcards = []
    lines = [line.strip() for line in result.splitlines() if line.strip()]
    current_q, current_a = None, None
    for line in lines:
        if line.startswith("Q:"):
            current_q = line
        elif line.startswith("A:"):
            current_a = line
            if current_q and current_a:
                flashcards.append(f"{current_q}\n{current_a}")
                current_q, current_a = None, None

    # return only the requested number of cards
    return "\n\n".join(flashcards[:num_cards])

In [17]:
text = "The invention of the printing press by Johannes Gutenberg in the mid-15th century was a revolutionary development in human history. Before its creation, books had to be copied by hand, which was time-consuming and expensive. Gutenberg’s press used movable type, allowing pages to be mass-produced quickly and accurately. This innovation dramatically reduced the cost of books, making them accessible to a much wider audience. As a result, literacy rates began to rise across Europe, and ideas spread more rapidly than ever before. The printing press is often credited with fueling the Renaissance, the Reformation, and the Scientific Revolution, as it enabled the exchange of knowledge on an unprecedented scale."
generate_flashcards(text, num_cards=5)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


'Q: Who invented the printing press?\nA: Johannes Gutenberg\n\nQ: What was the main advantage of the printing press?\nA: It allowed pages to be mass-produced quickly and accurately.\n\nQ: What did the printing press do to literacy rates?\nA: It fueled the Renaissance, the Reformation, and the Scientific Revolution.\n\nQ: What is the printing press often credited with?\nA: It enabled the exchange of knowledge on an unprecedented scale.\n\nQ: What did the printing press do to books?\nA: It made them accessible to a much wider audience.'

In [18]:
with gr.Blocks() as demo:
    gr.Markdown("## 📚 Flashcard Generator (Mistral-7B + QLoRA Fine-Tuned on SQuAD)")
    text_input = gr.Textbox(placeholder="Paste study text...", label="Input Text", lines=6)
    btn = gr.Button("Generate Flashcard")
    output = gr.Textbox(label="Flashcard Q&A")
    btn.click(generate_flashcards, inputs=text_input, outputs=output)

demo.launch()

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://17c48e16aadd0ec58f.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


