In [None]:
!pip install -U datasets
!pip install rouge-score


In [None]:

# Imports
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback
)
import pandas as pd
import json
import numpy as np
from sklearn.model_selection import train_test_split
import wandb
import os
import re
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict, List, Tuple, Optional, Any
import matplotlib.pyplot as plt
from datasets import load_dataset
from tqdm.auto import tqdm
import random
tqdm.pandas()

from google.colab import userdata


In [None]:

def display_dataset_statistics(data: List[Dict[str, Any]]):
    """Print the dataset size, plot question lengths and show up to three entries.

    Args:
        data: Processed entries; each one is read through its 'question',
            'answer' and 'distractors' keys.
    """
    total = len(data)
    print(f"\n {total} examples loaded")

    # Character count of every question is the histogram input.
    lengths_in_chars = list(map(lambda entry: len(entry['question']), data))

    plt.figure(figsize=(10, 6))
    plt.hist(lengths_in_chars, bins=30, edgecolor='black')
    plt.title('Distribution of Question Lengths')
    plt.xlabel('Question Length (characters)')
    plt.ylabel('Frequency')
    plt.tight_layout()
    plt.show()

    print("\n## Example entries:")
    for number, entry in enumerate(data[:3], start=1):
        print(f"Example {number}:")
        print(f"Question: {entry['question']}")
        print(f"Answer: {entry['answer']}")
        print(f"Distractors: {entry['distractors']}")
        print()


def format_options(options_field: Any) -> List[str]:
    """Extract the bare option texts from a labelled multiple-choice field.

    Args:
        options_field: Either one string such as
            ``"a ) 38 , b ) 27.675 , c ) 30"`` or a list with one entry per
            option, each optionally prefixed by a label such as ``"a)"``.

    Returns:
        The option texts in their original order, without labels. An empty
        list is returned when a string contains no recognisable option.

    Raises:
        ValueError: If ``options_field`` is neither a string nor a list.
    """
    if isinstance(options_field, str):
        # An option ends only where the next label (", b )") begins or at the
        # end of the string. The text itself may contain commas ("1,000");
        # excluding commas would silently drop such options and shift the
        # index of every following option.
        found = re.findall(
            r'[a-e]\s*\)\s*(.*?)(?=\s*,\s*[a-e]\s*\)|\s*$)',
            options_field,
        )
        return [opt.strip() for opt in found]

    if isinstance(options_field, list):
        # The delimiter after the label is mandatory: with an optional
        # delimiter, an unlabelled option that merely starts with a-e
        # ("data inadequate") would lose its first letter.
        return [
            re.sub(r'^\s*[a-e]\s*[\):.\-]\s*', '', str(opt).strip())
            for opt in options_field
        ]

    raise ValueError(f"Unexpected format in options: {options_field}")


def parse_options(options_list: List[str], correct_option: str) -> tuple:
    """Split an option list into the correct answer and the remaining distractors.

    Args:
        options_list: Option texts in label order (index 0 is option 'a').
        correct_option: Single letter naming the correct option; case and
            surrounding whitespace are ignored.

    Returns:
        A ``(correct_answer, distractors)`` tuple where ``distractors`` holds
        every other option in its original order.

    Raises:
        ValueError: If ``correct_option`` is not a single character or does
            not point at an existing entry of ``options_list``.
    """
    try:
        correct_index = ord(str(correct_option).strip().lower()) - ord('a')
    except TypeError:
        # ord() rejects empty and multi-character strings.
        raise ValueError(f"Bad option '{correct_option}'") from None

    # Explicit range check: a negative index would otherwise wrap around and
    # silently return an option from the end of the list.
    if not 0 <= correct_index < len(options_list):
        raise ValueError(f"Bad option '{correct_option}'")

    correct_answer = options_list[correct_index]
    distractors = [opt for i, opt in enumerate(options_list) if i != correct_index]
    return correct_answer, distractors

def preprocess_data(
    data: List[Dict[str, Any]],
    max_distractors: int = 3,
    print_examples: bool = False,
) -> List[Dict[str, Any]]:
    """Convert raw MathQA / MMLU-Pro rows into question/answer/distractors records.

    Rows carrying a 'Problem' key are handled as MathQA (labelled option
    string plus a 'correct' letter); rows carrying a 'question' key are
    handled as MMLU-Pro (option list plus 'answer_index'). Anything else, any
    row with a missing field, and any row whose parsing raises is counted as
    skipped.

    Args:
        data: Raw dataset rows.
        max_distractors: Exact number of distractors per record; longer lists
            are cut, shorter ones are padded with "No distractor available".
        print_examples: When true, print up to five rows before and after
            processing.

    Returns:
        Records with the keys 'question', 'answer' and 'distractors'.
    """
    if print_examples:
        print("\n 5 BEFORE processing: ")
        for number, raw_row in enumerate(data[:5], start=1):
            print(f"Raw Example {number}:")
            print(json.dumps(raw_row, indent=2))
            print()

    records = []
    n_skipped = 0

    for raw_row in tqdm(data, desc="Processing data"):
        try:
            if 'Problem' in raw_row:  # MathQA
                question_text = raw_row['Problem']
                raw_options = raw_row.get('options', '')
                correct_letter = raw_row.get('correct', '')

                if not question_text or not raw_options or not correct_letter:
                    n_skipped += 1
                    continue

                gold_answer, wrong_answers = parse_options(
                    format_options(raw_options), correct_letter
                )

            elif 'question' in raw_row:  # MMLU-Pro
                question_text = raw_row['question']
                choice_list = raw_row.get('options', [])
                gold_position = raw_row.get('answer_index')

                if not question_text or not choice_list or gold_position is None:
                    n_skipped += 1
                    continue

                gold_answer = choice_list[gold_position]
                wrong_answers = [
                    choice
                    for position, choice in enumerate(choice_list)
                    if position != gold_position
                ]

            else:
                n_skipped += 1
                continue

            # Cut to the requested size, then pad with the placeholder text.
            wrong_answers = wrong_answers[:max_distractors]
            missing = max_distractors - len(wrong_answers)
            wrong_answers.extend(["No distractor available"] * missing)

            records.append({
                "question": question_text,
                "answer": gold_answer,
                "distractors": wrong_answers,
            })

        except Exception:
            # Best effort: a row that cannot be parsed is only counted.
            n_skipped += 1
            continue

    print(f"Processed {len(records)}, skipped {n_skipped}")

    if print_examples:
        print("\n## 5 AFTER processing:")
        for number, record in enumerate(records[:5], start=1):
            print(f"Processed Example {number}:")
            print(json.dumps(record, indent=2))
            print()

    return records

def load_dataset_from_huggingface(dataset_name: str, split: str = "train"):
    """Load one split via ``load_dataset`` and return its rows as plain dicts.

    Any failure is reported on stdout and turned into an empty list.
    """
    try:
        rows = load_dataset(dataset_name, split=split)
        return list(map(dict, rows))
    except Exception as e:
        print(f"Failed to load {dataset_name}: {e}")
        return []

def process_and_save_dataset(
    dataset_name: str,
    output_path: Optional[str] = None,
    max_distractors: int = 3,
    num_samples: Optional[int] = None
) -> Optional[List[Dict[str, Any]]]:
    """Load a dataset, preprocess it, optionally save it as JSON and show statistics.

    Args:
        dataset_name: Hub identifier. "allenai/math_qa" and
            "TIGER-Lab/MMLU-Pro" have dedicated settings; any other name is
            loaded from its "train" split.
        output_path: Destination JSON file; nothing is written when falsy.
        max_distractors: Forwarded to ``preprocess_data``.
        num_samples: When truthy, only the first ``num_samples`` rows (after
            category filtering) are processed.

    Returns:
        The processed records.
    """
    known_datasets = {
        "allenai/math_qa": {"split": "train", "source": "mathqa"},
        "TIGER-Lab/MMLU-Pro": {
            "split": "test",
            "source": "mmlu_pro",
            "filter_categories": ["mathematics", "math", "arithmetic", "geometry", "algebra"]
        }
    }
    settings = known_datasets.get(dataset_name, {})

    rows = load_dataset_from_huggingface(dataset_name, settings.get('split', 'train'))

    # MMLU-Pro rows are kept only when their category contains a math keyword.
    category_keywords = settings.get('filter_categories')
    if dataset_name == "TIGER-Lab/MMLU-Pro" and category_keywords:
        def _is_math_row(row):
            category = str(row.get('category', '')).lower()
            return any(keyword in category for keyword in category_keywords)

        rows = list(filter(_is_math_row, rows))

    if num_samples:
        rows = rows[:num_samples]

    records = preprocess_data(
        rows,
        max_distractors=max_distractors,
        print_examples=True,
    )

    if output_path:
        target_dir = os.path.dirname(output_path) or '.'
        os.makedirs(target_dir, exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as handle:
            json.dump(records, handle, ensure_ascii=False, indent=2)
        print(f"Dataset saved to {output_path}")

    display_dataset_statistics(records)

    return records

def combine_datasets(
    datasets: List[str],
    output_path: Optional[str] = None,
    max_distractors: int = 3,
    seed: Optional[int] = None
) -> List[Dict[str, Any]]:
    """Process several datasets, merge and shuffle their records, optionally save them.

    Args:
        datasets: Dataset names; each is passed to ``process_and_save_dataset``.
        output_path: Destination JSON file; nothing is written when falsy.
        max_distractors: Forwarded to ``process_and_save_dataset``.
        seed: Optional shuffle seed. With a seed the order of the combined
            records is reproducible; with ``None`` (the default) the global
            ``random`` state is used, as before.

    Returns:
        The shuffled, combined records.
    """
    combined_data = []
    for dataset_name in datasets:
        data = process_and_save_dataset(dataset_name, max_distractors=max_distractors)
        # The callee is annotated as Optional; skip a missing/empty result
        # instead of failing inside extend().
        if data:
            combined_data.extend(data)

    if seed is None:
        random.shuffle(combined_data)
    else:
        # A private generator keeps the global random state untouched.
        random.Random(seed).shuffle(combined_data)

    if output_path:
        os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(combined_data, f, ensure_ascii=False, indent=2)
        print(f"Combined dataset saved to {output_path}")

    display_dataset_statistics(combined_data)

    return combined_data

# Script entry point: write the processed MathQA file, the processed MMLU-Pro
# file, and the combined (shuffled) file under ./processed_data/.
if __name__ == "__main__":
    # MathQA, written to its own JSON file.
    mathqa_data = process_and_save_dataset(
        "allenai/math_qa",
        output_path="./processed_data/mathqa_processed.json",
        max_distractors=3
    )

    # MMLU-Pro restricted to math-related categories; only the first 1351
    # filtered rows are processed.
    mmlu_data = process_and_save_dataset(
        "TIGER-Lab/MMLU-Pro",
        output_path="./processed_data/mmlu_processed.json",
        max_distractors=3,
        num_samples=1351
    )

    # NOTE(review): combine_datasets calls process_and_save_dataset again for
    # each name, so both datasets are loaded and processed a second time, and
    # it passes no num_samples, so the 1351-row cap above does not apply to
    # the combined file. Confirm this is intended.
    combined_data = combine_datasets(
        ["allenai/math_qa", "TIGER-Lab/MMLU-Pro"],
        output_path="./processed_data/combined_data.json"
    )


In [None]:

def load_data(file_path):
    """Read a JSON file and return its parsed content.

    The processed datasets in this notebook are written as UTF-8 with
    ``ensure_ascii=False``, so the file is read as UTF-8 explicitly instead of
    relying on the platform's default encoding.

    Args:
        file_path: Path of the JSON file.

    Returns:
        The decoded JSON value.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


In [None]:

class DistractorDatasetSeq2Seq(Dataset):
    """Pre-tokenized encoder/decoder examples for distractor generation.

    Each example maps ``"Question: ... Answer: ..."`` to a target of the form
    ``"<distractor1> d1 <distractor2> d2 <distractor3> d3"``. All examples are
    tokenized once in the constructor and padded to ``max_length``.

    Args:
        data: Items with the keys 'question', 'answer' and 'distractors'.
        tokenizer: Callable tokenizer returning 'input_ids' and
            'attention_mask' tensors; it is mutated by registering the three
            ``<distractorN>`` marker tokens.
        max_length: Length every input and target is padded/truncated to.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.examples = []

        # Registering the markers here does not resize any model embedding
        # matrix; the caller is responsible for that.
        special_tokens = {"additional_special_tokens": ["<distractor1>", "<distractor2>", "<distractor3>"]}
        tokenizer.add_special_tokens(special_tokens)

        for item in data:
            question = item["question"]
            answer = item["answer"]
            distractors = item["distractors"]

            input_text = f"Question: {question} Answer: {answer}"
            # Built from however many distractors exist (at most three) so an
            # item with fewer than three no longer raises IndexError. For
            # exactly three the text is identical to the fixed template.
            target_text = " ".join(
                f"<distractor{number}> {distractor}"
                for number, distractor in enumerate(distractors[:3], start=1)
            )

            input_encodings = self.tokenizer(input_text,
                                     max_length=self.max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")

            target_encodings = self.tokenizer(target_text,
                                      max_length=self.max_length,
                                      padding="max_length",
                                      truncation=True,
                                      return_tensors="pt")

            # Padding positions must not contribute to the loss. -100 is the
            # label value the cross-entropy loss ignores; leaving the pad ids
            # in place would train the model mostly on padding.
            labels = target_encodings["input_ids"][0].clone()
            labels[target_encodings["attention_mask"][0] == 0] = -100

            self.examples.append({
                "input_ids": input_encodings["input_ids"][0],
                "attention_mask": input_encodings["attention_mask"][0],
                "labels": labels,
            })

    def __len__(self):
        """Return the number of tokenized examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the example dict ('input_ids', 'attention_mask', 'labels') at ``idx``."""
        return self.examples[idx]


In [None]:

class DistractorDatasetCausalLM(Dataset):
    """Pre-tokenized decoder-only examples for distractor generation.

    Each example is the text ``"Question: ... Answer: ... Distractors: d1, d2, ..."``
    padded to ``max_length``. The labels copy the input ids, with the prompt
    part and the padding set to -100 so that only the distractor continuation
    is scored.

    Args:
        data: Items with the keys 'question', 'answer' and 'distractors'.
        tokenizer: Callable tokenizer returning 'input_ids' and
            'attention_mask'.
        max_length: Length every sequence is padded/truncated to.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.examples = []

        for item in data:
            question = item["question"]
            answer = item["answer"]
            distractors = item["distractors"]

            formatted_input = f"Question: {question} Answer: {answer} Distractors:"
            formatted_output = f" {', '.join(str(d) for d in distractors)}"

            full_text = formatted_input + formatted_output

            encoded = self.tokenizer(full_text,
                                     max_length=self.max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")

            # Number of prompt tokens, counted without special tokens.
            # NOTE(review): this assumes the tokenizer pads on the right and
            # prepends no BOS token to full_text; confirm for tokenizers other
            # than the ones used in this notebook.
            input_ids_len = len(self.tokenizer(formatted_input,
                                              add_special_tokens=False)['input_ids'])

            input_ids = encoded["input_ids"][0]
            attention_mask = encoded["attention_mask"][0]

            labels = input_ids.clone()
            labels[:input_ids_len] = -100        # do not score the prompt
            labels[attention_mask == 0] = -100   # do not score the padding

            self.examples.append({
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels
            })

    def __len__(self):
        """Return the number of tokenized examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the example dict ('input_ids', 'attention_mask', 'labels') at ``idx``."""
        return self.examples[idx]


In [None]:

def _collate_causal_lm(features):
    """Stack pre-tokenized causal-LM examples into a batch, keeping their labels.

    Every tensor field of the examples is stacked as-is, and label positions
    whose attention mask is 0 (padding) are set to -100 so they are ignored by
    the loss.
    """
    batch = {key: torch.stack([example[key] for example in features]) for key in features[0]}
    batch["labels"] = batch["labels"].masked_fill(batch["attention_mask"] == 0, -100)
    return batch


def train_distractor_model(data_path, test_data_path, model_name, output_dir, epochs=3, batch_size=4, lr=5e-5, gradient_accumulation_steps=1,
                           weight_decay=0.01, warmup_ratio=0.1, lr_scheduler_type="linear", early_stopping_patience=3):
    """Fine-tune a distractor generator, log it to W&B and score it on test data.

    The training file is split 80/20 into train and validation parts. Models
    whose name contains "bart" or "t5" are trained as encoder/decoder models,
    everything else as a causal language model. After training, the model and
    tokenizer are saved to ``<output_dir>/final``, uploaded as a W&B artifact
    and evaluated with ``evaluate_model``.

    Args:
        data_path: JSON file with the training records.
        test_data_path: JSON file with the held-out test records.
        model_name: Hugging Face model identifier.
        output_dir: Directory for checkpoints, logs and the final model.
        epochs: Maximum number of training epochs.
        batch_size: Per-device batch size for training and evaluation.
        lr: Learning rate.
        gradient_accumulation_steps: Number of accumulation steps.
        weight_decay: Weight decay.
        warmup_ratio: Fraction of steps used for warm-up.
        lr_scheduler_type: Scheduler name passed to ``TrainingArguments``.
        early_stopping_patience: Epochs without eval-loss improvement before
            stopping; values <= 0 disable early stopping.

    Returns:
        The summary dict produced by ``evaluate_model``.
    """
    data = load_data(data_path)

    # Loaded up front so a missing test file fails before training starts.
    test_data = load_data(test_data_path)
    train_data, eval_data = train_test_split(data, test_size=0.2, random_state=42)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_name_lower = model_name.lower()
    is_seq2seq = "bart" in model_name_lower or "t5" in model_name_lower

    if is_seq2seq:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Add special tokens for seq2seq and grow the embeddings to match.
        special_tokens = {"additional_special_tokens": ["<distractor1>", "<distractor2>", "<distractor3>"]}
        tokenizer.add_special_tokens(special_tokens)
        model.resize_token_embeddings(len(tokenizer))
        train_dataset = DistractorDatasetSeq2Seq(train_data, tokenizer)
        eval_dataset = DistractorDatasetSeq2Seq(eval_data, tokenizer)
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            if hasattr(model, 'resize_token_embeddings'):
                model.resize_token_embeddings(len(tokenizer))
        train_dataset = DistractorDatasetCausalLM(train_data, tokenizer)
        eval_dataset = DistractorDatasetCausalLM(eval_data, tokenizer)
        # DataCollatorForLanguageModeling(mlm=False) rebuilds the labels from
        # input_ids, which would discard the prompt masking the dataset
        # prepared. The examples are already padded to a fixed length, so a
        # plain stacking collator is sufficient and keeps those labels.
        data_collator = _collate_causal_lm

    training_args = TrainingArguments(
        output_dir=output_dir,
        run_name=f"{model_name.split('/')[-1]}-lr{lr}-bs{batch_size}-ep{epochs}",
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_dir=f"{output_dir}/logs",
        logging_steps=100,
        learning_rate=lr,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        lr_scheduler_type=lr_scheduler_type,
        save_total_limit=2,
        fp16=torch.cuda.is_available(),
        gradient_accumulation_steps=gradient_accumulation_steps,
        report_to="wandb",
        disable_tqdm=False,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )

    callbacks = []
    if early_stopping_patience > 0:
        early_stopping_callback = EarlyStoppingCallback(
            early_stopping_patience=early_stopping_patience
        )
        callbacks.append(early_stopping_callback)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        callbacks=callbacks
    )

    trainer.train()

    final_model_dir = f"{output_dir}/final"
    model.save_pretrained(final_model_dir)
    tokenizer.save_pretrained(final_model_dir)

    eval_results = trainer.evaluate()
    wandb.log({"final_eval_loss": eval_results["eval_loss"]})

    artifact = wandb.Artifact(name=f"model-{wandb.run.id}",
                              type="model",
                              description=f"trained model: {model_name}")
    artifact.add_dir(final_model_dir)
    wandb.log_artifact(artifact)

    print(f"\nEvaluating model on {len(test_data)} test examples...")
    evaluation_results = evaluate_model(model, tokenizer, is_seq2seq, test_data, f"{model_name}-{wandb.run.id}")

    wandb.log({
        "avg_bert_score": evaluation_results["avg_bert_score"],
        "avg_rouge1_score": evaluation_results["avg_rouge1_score"],
        "avg_rouge2_score": evaluation_results["avg_rouge2_score"],
        "avg_rougeL_score": evaluation_results["avg_rougeL_score"]
    })

    # Log the first (up to ten) evaluated examples as a browsable table.
    examples_table = wandb.Table(columns=["Question", "Answer", "Gold Distractors", "Generated Distractors", "BERT Score"])
    for i in range(min(10, len(evaluation_results["results"]))):
        result = evaluation_results["results"][i]
        examples_table.add_data(
            result["question"],
            result["answer"],
            str(result["gold_distractors"]),
            str(result["generated_distractors"]),
            result["bert_score"]
        )
    wandb.log({"evaluation_examples": examples_table})

    return evaluation_results


In [None]:

# Sentence-embedding models keyed by name, so each one is loaded only once.
_SENTENCE_MODEL_CACHE = {}


def _get_sentence_model(model_name='all-MiniLM-L6-v2'):
    """Return a cached SentenceTransformer, loading it on first use."""
    if model_name not in _SENTENCE_MODEL_CACHE:
        _SENTENCE_MODEL_CACHE[model_name] = SentenceTransformer(model_name)
    return _SENTENCE_MODEL_CACHE[model_name]


def compute_bert_score(generated_distractors, gold_distractors):
    """Score generated distractors by embedding similarity to the gold ones.

    Every generated distractor is matched with its most similar gold
    distractor (cosine similarity of 'all-MiniLM-L6-v2' sentence embeddings);
    the result is the mean of these best similarities.

    Args:
        generated_distractors: Generated candidates; falsy entries are ignored.
        gold_distractors: Reference distractors; falsy entries are ignored.

    Returns:
        The mean best similarity as a float, or 0.0 when either list is empty
        or scoring fails.
    """
    if not generated_distractors or not gold_distractors:
        print("Warning: Empty distractor list received")
        return 0.0

    gen_distractors = [str(d) for d in generated_distractors if d]
    gold_distractors = [str(d) for d in gold_distractors if d]

    if not gen_distractors or not gold_distractors:
        return 0.0

    try:
        # The model is cached: this function runs once per test example and
        # re-instantiating the encoder every time dominates evaluation time.
        model = _get_sentence_model()

        gen_embeddings = model.encode(gen_distractors)
        gold_embeddings = model.encode(gold_distractors)

        # Rows are generated distractors, columns are gold distractors.
        similarity_matrix = cosine_similarity(gen_embeddings, gold_embeddings)

        max_similarities = np.max(similarity_matrix, axis=1)

        return float(np.mean(max_similarities))
    except Exception as e:
        print(f"BERT scoring fail: {e}")
        return 0.0

def compute_rouge_scores(generated_distractors, gold_distractors):
    """Average the best ROUGE-1/2/L F-measure of each generated distractor.

    For every generated distractor the highest F-measure against any gold
    distractor is taken per metric; the per-metric means are returned. Falsy
    entries are ignored, and any failure yields all-zero scores.

    Returns:
        Dict with the float entries 'rouge1', 'rouge2' and 'rougeL'.
    """
    metric_names = ('rouge1', 'rouge2', 'rougeL')

    try:
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

        candidates = [str(entry) for entry in generated_distractors if entry]
        references = [str(entry) for entry in gold_distractors if entry]

        if not (candidates and references):
            return {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0}

        best_per_metric = {name: [] for name in metric_names}

        for candidate in candidates:
            # One scorer call per (candidate, reference) pair covers all metrics.
            pair_scores = [scorer.score(candidate, reference) for reference in references]
            for name in metric_names:
                # The leading 0 is the floor the best score starts from.
                best = max([0] + [pair[name].fmeasure for pair in pair_scores])
                best_per_metric[name].append(best)

        return {
            name: float(np.mean(best_per_metric[name])) if best_per_metric[name] else 0.0
            for name in metric_names
        }
    except Exception as e:
        print(f"ROUGE scoring failed: {e}")
        return {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0}


In [None]:
# Matches the marker tokens that separate distractors in seq2seq targets.
_DISTRACTOR_MARKER = re.compile(r"<distractor\d+>")


def _split_free_text(text, split_sentences=False):
    """Heuristically split unmarked model output into distractor candidates."""
    text = text.strip()
    if not text:
        return []
    if ',' in text:
        return [part.strip() for part in text.split(',')]
    tokens = text.split()
    if all(token.replace('.', '').isdigit() for token in tokens):
        # Space-separated numbers: every number is its own candidate.
        return tokens
    if split_sentences:
        sentences = [s.strip() for s in re.split(r'[.!?]\s+', text) if s.strip()]
        if len(sentences) > 1:
            # Only add a period where the split removed one; the final
            # sentence keeps its own punctuation instead of ending in "..".
            return [s if s[-1] in '.!?' else s + '.' for s in sentences]
    return [text]


def _decode_keeping_markers(tokenizer, sequence):
    """Decode ``sequence`` and drop every special token except the distractor markers."""
    text = tokenizer.decode(sequence, skip_special_tokens=False)
    for token in getattr(tokenizer, "all_special_tokens", []):
        if token and not _DISTRACTOR_MARKER.fullmatch(token):
            text = text.replace(token, " ")
    return text


def generate_distractors(question, answer, model, tokenizer, is_seq2seq=False, num_distractors=3, max_length=100):
    """Generate up to ``num_distractors`` unique wrong answers for a question.

    Args:
        question: Question text.
        answer: Correct answer; it is excluded from the result
            (case-insensitively).
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching the model.
        is_seq2seq: Whether the model is an encoder/decoder model trained on
            ``<distractorN>``-marked targets, as opposed to a causal LM
            prompted with a trailing "Distractors:".
        num_distractors: Maximum number of distractors to return.
        max_length: Generation budget added to the prompt length; the total
            is capped at 512.

    Returns:
        A list of at most ``num_distractors`` distinct strings. When nothing
        usable was generated and the answer is a plain number, three nearby
        numeric values are used as a fallback.
    """
    answer_text = str(answer)

    prompt = f"Question: {question} Answer: {answer}" if is_seq2seq else f"Question: {question} Answer: {answer} Distractors:"

    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)

    with torch.no_grad():
        # Twice as many sequences as requested are generated so that
        # duplicates and copies of the answer can be discarded afterwards.
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=min(len(input_ids[0]) + max_length, 512),
            num_return_sequences=num_distractors*2,
            num_beams=num_distractors*2,
            temperature=0.8,
            top_p=0.92,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
            no_repeat_ngram_size=2,
            early_stopping=True
        )

    all_distractors = []
    for seq in output:
        if is_seq2seq:
            # The markers are registered as special tokens, so decoding with
            # skip_special_tokens=True would remove them and make the marked
            # output impossible to split. Keep them and strip only the others.
            generated_text = _decode_keeping_markers(tokenizer, seq)
            if _DISTRACTOR_MARKER.search(generated_text):
                # Text before the first marker is not a distractor.
                pieces = _DISTRACTOR_MARKER.split(generated_text)[1:]
                all_distractors.extend(piece.strip() for piece in pieces)
            else:
                all_distractors.extend(_split_free_text(generated_text))
        else:
            generated_text = tokenizer.decode(seq, skip_special_tokens=True)
            if "Distractors:" in generated_text:
                distractors_text = generated_text.split("Distractors:")[1]
                all_distractors.extend(_split_free_text(distractors_text, split_sentences=True))

    unique_distractors = []
    seen = set()
    answer_lower = answer_text.lower()

    for distractor in all_distractors:
        distractor_clean = distractor.strip()
        if (distractor_clean and
            distractor_clean not in seen and
            distractor_clean.lower() != answer_lower):
            seen.add(distractor_clean)
            unique_distractors.append(distractor_clean)

    # Fallback for numeric answers: perturb the value by a few percent.
    if not unique_distractors and answer_text.replace('.', '').isdigit():
        try:
            base = float(answer_text)
        except ValueError:
            # e.g. "1.2.3" passes the digit check but is not a number.
            base = None
        if base is not None:
            unique_distractors = [
                str(round(base * 0.9, 2)),
                str(round(base * 1.1, 2)),
                str(round(base * 0.95, 2))
            ]

    return unique_distractors[:num_distractors]

In [None]:

def evaluate_model(model, tokenizer, is_seq2seq, test_data, model_name):
    """Generate distractors for every test item and aggregate similarity metrics.

    Args:
        model: Model handed to ``generate_distractors``.
        tokenizer: Tokenizer handed to ``generate_distractors``.
        is_seq2seq: Whether the model is an encoder/decoder model.
        test_data: Items with the keys 'question', 'answer' and 'distractors'.
        model_name: Label used in progress output and in the summary.

    Returns:
        Dict with 'model_name', the four averaged scores ('avg_bert_score',
        'avg_rouge1_score', 'avg_rouge2_score', 'avg_rougeL_score') and the
        per-example 'results'. Examples whose generation or scoring raises
        are reported and left out of the averages.
    """
    def _mean_or_zero(values):
        return np.mean(values) if values else 0

    per_example = []
    bert_history = []
    rouge1_history = []
    rouge2_history = []
    rougeL_history = []

    print(f"\nEvaluating {model_name} on {len(test_data)} test examples...")

    progress = tqdm(test_data, desc=f"Evaluating {model_name}")
    for position, example in enumerate(progress):
        question = example["question"]
        answer = example["answer"]
        gold = example["distractors"]

        try:
            generated = generate_distractors(
                question, answer, model, tokenizer, is_seq2seq,
                num_distractors=len(gold)
            )

            embedding_score = compute_bert_score(generated, gold)
            bert_history.append(embedding_score)

            overlap = compute_rouge_scores(generated, gold)
            rouge1_history.append(overlap['rouge1'])
            rouge2_history.append(overlap['rouge2'])
            rougeL_history.append(overlap['rougeL'])

            per_example.append({
                "question": question,
                "answer": answer,
                "gold_distractors": gold,
                "generated_distractors": generated,
                "bert_score": embedding_score,
                "rouge1_score": overlap['rouge1'],
                "rouge2_score": overlap['rouge2'],
                "rougeL_score": overlap['rougeL']
            })

            # Rolling report after every tenth test item.
            if position % 10 == 9:
                recent_bert = np.mean(bert_history[-10:])
                recent_rouge1 = np.mean(rouge1_history[-10:])
                print(f"Last 10 examples - BERT: {recent_bert:.4f}, ROUGE-1: {recent_rouge1:.4f}")

        except Exception as e:
            print(f"skipped example {position}: {e}")
            continue

    avg_bert = _mean_or_zero(bert_history)
    avg_rouge1 = _mean_or_zero(rouge1_history)
    avg_rouge2 = _mean_or_zero(rouge2_history)
    avg_rougeL = _mean_or_zero(rougeL_history)

    print("\n==== EVALUATION SUMMARY ====")
    print(f"Total examples evaluated: {len(per_example)}")
    print(f"Average BERT score: {avg_bert:.4f}")
    print(f"Average ROUGE-1 score: {avg_rouge1:.4f}")
    print(f"Average ROUGE-2 score: {avg_rouge2:.4f}")
    print(f"Average ROUGE-L score: {avg_rougeL:.4f}")

    return {
        "model_name": model_name,
        "avg_bert_score": avg_bert,
        "avg_rouge1_score": avg_rouge1,
        "avg_rouge2_score": avg_rouge2,
        "avg_rougeL_score": avg_rougeL,
        "results": per_example
    }


In [None]:

def create_sweep_config():
    """Return the W&B sweep definition: Bayesian search maximizing 'avg_bert_score'."""
    def choices(*options):
        # Sweep syntax for a discrete set of candidate values.
        return {'values': list(options)}

    search_space = {
        'learning_rate': choices(1e-6, 1e-5, 5e-5, 1e-4),
        'epochs': choices(3, 4, 5, 6, 7, 8, 9),
        'batch_size': choices(8, 16, 32),
        'model_name': choices('facebook/bart-base', 't5-small', 'gpt2'),
        'data_path': choices('processed_data/combined_data.json'),
        'test_data_path': choices('test_data_dir/arithmetik_questions.json'),
        'model_output_dir': choices('distractor_model'),
        'gradient_accumulation_steps': choices(2, 4),
        'weight_decay': choices(0.01, 0.05, 0.1),
        'warmup_ratio': choices(0.05, 0.1, 0.15),
        'lr_scheduler_type': choices('linear', 'cosine', 'cosine_with_restarts'),
        'early_stopping_patience': choices(3, 5),
    }

    return {
        'method': 'bayes',
        'metric': {'name': 'avg_bert_score', 'goal': 'maximize'},
        'parameters': search_space,
    }



In [None]:

def train_evaluate_sweep():
    """Run one sweep trial: train with the W&B-provided config and print the scores."""
    wandb.init()
    cfg = wandb.config

    chosen_model = cfg.model_name

    # Every hyperparameter comes from the sweep configuration.
    evaluation_results = train_distractor_model(
        data_path=cfg.data_path,
        test_data_path=cfg.test_data_path,
        model_name=chosen_model,
        output_dir=cfg.model_output_dir,
        epochs=cfg.epochs,
        batch_size=cfg.batch_size,
        lr=cfg.learning_rate,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        weight_decay=cfg.weight_decay,
        warmup_ratio=cfg.warmup_ratio,
        lr_scheduler_type=cfg.lr_scheduler_type,
        early_stopping_patience=cfg.early_stopping_patience,
    )

    print(f"Training and evaluation complete for {chosen_model}")
    report_lines = (
        ("BERT Score", 'avg_bert_score'),
        ("ROUGE-1", 'avg_rouge1_score'),
        ("ROUGE-2", 'avg_rouge2_score'),
        ("ROUGE-L", 'avg_rougeL_score'),
    )
    for label, key in report_lines:
        print(f"{label}: {evaluation_results[key]:.4f}")



In [None]:

def main():
    """Log in to W&B with the Colab secret 'WANDB_KEY' and run a ten-trial sweep."""
    wandb.login(key=userdata.get('WANDB_KEY'))

    # Register the sweep, then let one agent execute ten trials.
    sweep_id = wandb.sweep(create_sweep_config(), project="bachelorprojekt")
    wandb.agent(sweep_id, train_evaluate_sweep, count=10)


In [None]:
# Start the W&B sweep when this cell is executed as the main module.
if __name__ == "__main__":
    main()