<a href="https://colab.research.google.com/github/DishaKushwah/custom-quiz-generator/blob/main/mcq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
## MCQs: multiple-choice question generation from a text passage
import torch
from transformers import (
    T5ForConditionalGeneration, T5Tokenizer,
    pipeline, AutoTokenizer, AutoModel
)
from sentence_transformers import SentenceTransformer
import spacy
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import re
import random
from typing import List, Dict, Tuple
import nltk
from nltk.corpus import wordnet
import string

class MultipleChoiceQuestionGenerator:
    def __init__(self):
        """Load every model and resource the generator relies on.

        In order, this sets up: the torch device, the T5 question-generation
        model and tokenizer, the question-answering pipeline used to validate
        questions, the sentence-embedding model, the spaCy English pipeline
        (``self.nlp`` is left as ``None`` when the model is not installed),
        the fill-mask pipeline used for distractors, and the NLTK WordNet
        corpora.
        """
        # Query CUDA once and reuse the answer for every component below.
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        print(f"Using device: {self.device}")

        # Device argument handed to the pipelines: 0 when CUDA is available,
        # otherwise -1.
        pipeline_device = 0 if use_cuda else -1

        # Load T5 model for question generation
        self.qg_model_name = "valhalla/t5-base-qg-hl"
        self.qg_tokenizer = T5Tokenizer.from_pretrained(self.qg_model_name)
        self.qg_model = T5ForConditionalGeneration.from_pretrained(self.qg_model_name).to(self.device)

        # Load question-answering pipeline for answer validation
        self.qa_pipeline = pipeline(
            "question-answering",
            model="deepset/roberta-large-squad2",
            tokenizer="deepset/roberta-large-squad2",
            device=pipeline_device
        )

        # Load sentence transformer for semantic similarity (distractor generation)
        self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Load spaCy for NLP processing
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            print("Please install spaCy English model: python -m spacy download en_core_web_sm")
            self.nlp = None

        # Load fill-mask pipeline for generating distractors
        self.fill_mask = pipeline(
            "fill-mask",
            model="roberta-large",
            tokenizer="roberta-large",
            device=pipeline_device
        )

        # Download NLTK data. This stays best-effort (the generator still works
        # without WordNet), but the failure is reported instead of being
        # silently swallowed, and KeyboardInterrupt/SystemExit now propagate.
        for corpus in ('wordnet', 'omw-1.4'):
            try:
                nltk.download(corpus, quiet=True)
            except Exception as exc:
                print(f"Warning: could not download NLTK corpus '{corpus}': {exc}")

    def extract_key_information(self, text: str) -> Dict:
        """Extract key information from text for question generation."""
        if not self.nlp:
            return {"entities": [], "noun_chunks": [], "sentences": []}

        doc = self.nlp(text)

        # Extract named entities
        entities = []
        for ent in doc.ents:
            if ent.label_ in ['PERSON', 'ORG', 'GPE', 'DATE', 'EVENT', 'WORK_OF_ART', 'CARDINAL', 'ORDINAL']:
                entities.append({
                    'text': ent.text,
                    'label': ent.label_,
                    'start': ent.start_char,
                    'end': ent.end_char
                })

        # Extract noun chunks
        noun_chunks = [chunk.text for chunk in doc.noun_chunks if len(chunk.text.split()) <= 4]

        # Extract sentences
        sentences = [sent.text.strip() for sent in doc.sents if len(sent.text.split()) > 5]

        return {
            "entities": entities,
            "noun_chunks": noun_chunks,
            "sentences": sentences
        }

    def generate_question_from_context(self, context: str, answer_text: str) -> str:
        """Generate a question given context and answer."""
        # Highlight the answer in the context for T5
        highlighted_context = context.replace(answer_text, f"<hl>{answer_text}<hl>")
        input_text = f"generate question: {highlighted_context}"

        inputs = self.qg_tokenizer.encode_plus(
            input_text,
            max_length=512,
            truncation=True,
            padding=True,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            outputs = self.qg_model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=64,
                num_beams=4,
                temperature=0.8,
                do_sample=True,
                early_stopping=True
            )

        question = self.qg_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return question

    def generate_distractors_semantic(self, correct_answer: str, context: str, num_distractors: int = 3) -> List[str]:
        """Generate distractors using semantic similarity and context understanding."""
        distractors = []

        # Method 1: Use fill-mask to generate contextually similar options
        try:
            # Replace answer with mask in context
            masked_context = context.replace(correct_answer, "<mask>")
            if "<mask>" in masked_context:
                predictions = self.fill_mask(masked_context, top_k=20)
                for pred in predictions:
                    candidate = pred['token_str'].strip()
                    if (candidate != correct_answer and
                        candidate.lower() != correct_answer.lower() and
                        len(candidate) > 1 and
                        candidate not in distractors):
                        distractors.append(candidate)
                        if len(distractors) >= num_distractors:
                            break
        except:
            pass

        # Method 2: Extract similar entities from context
        if self.nlp and len(distractors) < num_distractors:
            doc = self.nlp(context)
            answer_doc = self.nlp(correct_answer)

            # Get answer entity type
            answer_label = None
            for ent in answer_doc.ents:
                answer_label = ent.label_
                break

            # Find similar entities
            for ent in doc.ents:
                if (ent.label_ == answer_label and
                    ent.text != correct_answer and
                    ent.text not in distractors):
                    distractors.append(ent.text)
                    if len(distractors) >= num_distractors:
                        break

        # Method 3: Generate using WordNet synonyms and related words
        if len(distractors) < num_distractors:
            try:
                words = correct_answer.split()
                for word in words:
                    synsets = wordnet.synsets(word)
                    for synset in synsets[:3]:
                        for lemma in synset.lemmas()[:2]:
                            candidate = lemma.name().replace('_', ' ')
                            if (candidate != correct_answer and
                                candidate.lower() != correct_answer.lower() and
                                candidate not in distractors):
                                distractors.append(candidate)
                                if len(distractors) >= num_distractors:
                                    break
                        if len(distractors) >= num_distractors:
                            break
                    if len(distractors) >= num_distractors:
                        break
            except:
                pass

        # Method 4: Generate plausible distractors based on answer type
        if len(distractors) < num_distractors:
            distractors.extend(self.generate_type_based_distractors(correct_answer, context))

        # Remove duplicates and return
        unique_distractors = []
        seen = set()
        for d in distractors:
            if d.lower() not in seen and d.lower() != correct_answer.lower():
                seen.add(d.lower())
                unique_distractors.append(d)

        return unique_distractors[:num_distractors]

    def generate_type_based_distractors(self, correct_answer: str, context: str) -> List[str]:
        """Generate distractors based on answer type patterns."""
        distractors = []

        # Check if answer is a number
        if re.match(r'^\d+$', correct_answer):
            base_num = int(correct_answer)
            variations = [
                str(base_num + random.randint(1, 10)),
                str(base_num - random.randint(1, 10)),
                str(base_num * 2),
                str(base_num // 2) if base_num > 1 else str(base_num + 1)
            ]
            distractors.extend([v for v in variations if v != correct_answer])

        # Check if answer is a year
        elif re.match(r'^\d{4}$', correct_answer):
            year = int(correct_answer)
            year_variations = [
                str(year + random.randint(1, 20)),
                str(year - random.randint(1, 20)),
                str(year + random.randint(50, 100)),
                str(year - random.randint(50, 100))
            ]
            distractors.extend([y for y in year_variations if y != correct_answer])

        # Check if answer is a percentage
        elif '%' in correct_answer:
            try:
                num = float(correct_answer.replace('%', ''))
                percent_variations = [
                    f"{num + random.randint(5, 25)}%",
                    f"{num - random.randint(5, 25)}%",
                    f"{num * 2}%" if num <= 50 else f"{num / 2}%"
                ]
                distractors.extend([p for p in percent_variations if p != correct_answer])
            except:
                pass

        return distractors[:3]

    def validate_mcq_quality(self, question: str, correct_answer: str, distractors: List[str], context: str) -> Dict:
        """Validate the quality of generated MCQ."""
        # Check if the question can be answered correctly
        try:
            qa_result = self.qa_pipeline(question=question, context=context)
            predicted_answer = qa_result['answer']
            confidence = qa_result['score']

            # Check if predicted answer matches or is similar to correct answer
            similarity_threshold = 0.7
            correct_embedding = self.sentence_model.encode([correct_answer])
            predicted_embedding = self.sentence_model.encode([predicted_answer])
            similarity = cosine_similarity(correct_embedding, predicted_embedding)[0][0]

            is_answerable = similarity > similarity_threshold or correct_answer.lower() in predicted_answer.lower()

        except:
            is_answerable = False
            confidence = 0.0
            similarity = 0.0

        # Check distractor quality
        if len(distractors) > 0:
            distractor_embeddings = self.sentence_model.encode(distractors)
            correct_embedding = self.sentence_model.encode([correct_answer])

            # Calculate similarity between distractors and correct answer
            similarities = cosine_similarity(correct_embedding, distractor_embeddings)[0]
            avg_distractor_similarity = np.mean(similarities)

            # Good distractors should be somewhat similar but not too similar
            distractor_quality = "good" if 0.3 < avg_distractor_similarity < 0.8 else "poor"
        else:
            distractor_quality = "poor"
            avg_distractor_similarity = 0.0

        return {
            "is_answerable": is_answerable,
            "confidence": confidence,
            "answer_similarity": similarity,
            "distractor_quality": distractor_quality,
            "avg_distractor_similarity": avg_distractor_similarity
        }

    def generate_mcq(self, context: str, num_questions: int = 5) -> List[Dict]:
        """Generate multiple choice questions from context."""
        mcqs = []

        # Extract key information
        key_info = self.extract_key_information(context)

        # Generate questions from entities
        for entity in key_info["entities"][:num_questions]:
            correct_answer = entity["text"]

            # Generate question
            question = self.generate_question_from_context(context, correct_answer)

            # Generate distractors
            distractors = self.generate_distractors_semantic(correct_answer, context, 3)

            # Skip if not enough distractors
            if len(distractors) < 2:
                continue

            # Validate quality
            quality = self.validate_mcq_quality(question, correct_answer, distractors, context)

            # Create options and shuffle
            options = [correct_answer] + distractors[:3]
            random.shuffle(options)
            correct_option = chr(65 + options.index(correct_answer))  # A, B, C, D

            mcq = {
                "question": question,
                "options": {
                    "A": options[0],
                    "B": options[1],
                    "C": options[2] if len(options) > 2 else "None of the above",
                    "D": options[3] if len(options) > 3 else "All of the above"
                },
                "correct_answer": correct_option,
                "correct_text": correct_answer,
                "entity_type": entity["label"],
                "quality_score": quality["confidence"],
                "is_answerable": quality["is_answerable"]
            }

            # Only include high-quality MCQs
            if quality["is_answerable"] and quality["confidence"] > 0.3:
                mcqs.append(mcq)

        # Generate additional questions from noun chunks if needed
        if len(mcqs) < num_questions:
            for chunk in key_info["noun_chunks"][:num_questions - len(mcqs)]:
                question = self.generate_question_from_context(context, chunk)
                distractors = self.generate_distractors_semantic(chunk, context, 3)

                if len(distractors) >= 2:
                    quality = self.validate_mcq_quality(question, chunk, distractors, context)

                    if quality["is_answerable"] and quality["confidence"] > 0.2:
                        options = [chunk] + distractors[:3]
                        random.shuffle(options)
                        correct_option = chr(65 + options.index(chunk))

                        mcq = {
                            "question": question,
                            "options": {
                                "A": options[0],
                                "B": options[1],
                                "C": options[2] if len(options) > 2 else "None of the above",
                                "D": options[3] if len(options) > 3 else "All of the above"
                            },
                            "correct_answer": correct_option,
                            "correct_text": chunk,
                            "entity_type": "NOUN_CHUNK",
                            "quality_score": quality["confidence"],
                            "is_answerable": quality["is_answerable"]
                        }
                        mcqs.append(mcq)

        # Sort by quality score and return
        mcqs.sort(key=lambda x: x["quality_score"], reverse=True)
        return mcqs[:num_questions]

def main():
    """Run an interactive console demo of the MCQ generator.

    Prompts for a passage (falling back to a built-in sample about the
    Renaissance) and for the number of questions, then prints each generated
    question with its options and the correct answer.
    """
    quiz_builder = MultipleChoiceQuestionGenerator()

    # Fallback passage used when the user just presses Enter
    sample_context = """
    The Renaissance was a period of cultural, artistic, political and economic rebirth following the Middle Ages.
    It began in Italy in the 14th century and spread throughout Europe. Leonardo da Vinci, born in 1452, was one
    of the most famous Renaissance artists and inventors. He created masterpieces like the Mona Lisa and The Last Supper.
    Michelangelo, another renowned artist, painted the ceiling of the Sistine Chapel between 1508 and 1512.
    The Renaissance emphasized humanism, scientific inquiry, and artistic innovation. The printing press,
    invented by Johannes Gutenberg around 1440, helped spread Renaissance ideas across Europe.
    This period lasted approximately 300 years, from the 14th to the 17th century.
    """

    separator = "=" * 50
    print("Multiple Choice Question Generator")
    print(separator)

    # Ask for the passage
    passage = input("Enter your context (or press Enter to use sample): ").strip()
    if passage == "":
        passage = sample_context
        print("Using sample context about the Renaissance...")

    # Ask for the question count; anything that is not an integer means 5
    try:
        reply = input("Number of MCQs to generate (default 5): ")
        num_questions = int(reply or "5")
    except ValueError:
        num_questions = 5

    print(f"\nGenerating {num_questions} multiple choice questions...")
    print(separator)

    mcqs = quiz_builder.generate_mcq(passage, num_questions)

    # Show the questions, or a hint when nothing passed the quality checks
    if not mcqs:
        print("No high-quality MCQs could be generated from the provided context.")
        print("Try providing a longer, more detailed context with specific facts and entities.")
    else:
        for i, mcq in enumerate(mcqs, 1):
            print(f"\nQuestion {i}: ")
            print(f"Q: {mcq['question']}")
            print()
            for option, text in mcq['options'].items():
                print(f"{option}) {text}")
            print(f"\nCorrect Answer: {mcq['correct_answer']}) {mcq['correct_text']}")

    print("\nGeneration complete!")

# Start the interactive demo only when this module is the entry point
# (__name__ is "__main__"); importing the module does not run it.
if __name__ == "__main__":
    main()