In [None]:
# Install required packages
!pip uninstall -y fsspec gcsfs
!pip install -q fsspec[http]==2024.6.1 gcsfs==2024.6.1
!pip install -q datasets transformers torch accelerate bitsandbytes

import json
import os
from datetime import datetime
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from google.colab import drive

class MIRFlanT5DualTrainer:
    """Fine-tune two independent FLAN-T5 models from one base checkpoint.

    * ``'qa'`` model: answers a multiple-choice question with a single option
      letter (A-D).
    * ``'fulltext'`` model: generates the free-text answer to the question.

    Both models share one tokenizer and are saved, together with it, to their
    own output directory.
    """

    # Identifiers accepted by ``prepare_training_data`` and ``train``.
    _VALID_TYPES = ('qa', 'fulltext')
    # Option letters rendered in the QA prompt; also the only valid QA labels.
    _ANSWER_LETTERS = ('A', 'B', 'C', 'D')

    def __init__(self,
                 base_model: str = "google/flan-t5-base",
                 qa_output_dir: str = "mir_flan_t5_qa_v2",
                 fulltext_output_dir: str = "mir_flan_t5_fulltext_v2"):
        """Load the shared tokenizer and two copies of the base model.

        Args:
            base_model: Hugging Face model id or local path of a seq2seq checkpoint.
            qa_output_dir: Directory where the QA model and tokenizer are saved.
            fulltext_output_dir: Directory where the full-text model and tokenizer
                are saved.

        Both output directories are created if missing. Both models are moved
        to CUDA when it is available, otherwise they stay on CPU.
        """
        print("Initializing MIR FLAN-T5 Dual Trainer...")
        self.base_model = base_model
        self.qa_output_dir = qa_output_dir
        self.fulltext_output_dir = fulltext_output_dir

        os.makedirs(qa_output_dir, exist_ok=True)
        os.makedirs(fulltext_output_dir, exist_ok=True)

        print(f"Loading base model: {base_model}")
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)

        # Separate model instances so the two tasks never share fine-tuned weights.
        self.qa_model = AutoModelForSeq2SeqLM.from_pretrained(base_model)
        self.fulltext_model = AutoModelForSeq2SeqLM.from_pretrained(base_model)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.qa_model = self.qa_model.to(self.device)
        self.fulltext_model = self.fulltext_model.to(self.device)

    @classmethod
    def _check_type(cls, kind: str) -> None:
        """Raise ``ValueError`` unless *kind* is one of ``_VALID_TYPES``."""
        if kind not in cls._VALID_TYPES:
            raise ValueError(f"Unknown type {kind!r}; expected one of {cls._VALID_TYPES}")

    def format_qa_question(self, item: Dict) -> Optional[Dict]:
        """Build a multiple-choice prompt/label pair for QA training.

        Two input layouts are accepted:

        * structured: ``question``, ``options`` (a mapping from ``'A'``-``'D'``
          to option text) and ``correct_answer`` (the letter);
        * training: ``input`` (optionally containing ``"Medical Question: "``)
          and ``context`` with ``all_options`` (at least four option texts) and
          ``numeric_answer`` (1-based index of the correct option).

        Returns:
            ``{'input': prompt, 'output': letter}``, or ``None`` when the item has
            no options, is malformed, or its answer does not map to A-D.
        """
        try:
            if 'options' in item:
                # Structured layout: options keyed by letter, answer given as a letter.
                question_text = item['question']
                options = item['options']
                answer_letter = str(item['correct_answer']).strip().upper()
                options_dict = {letter: options.get(letter, '') for letter in self._ANSWER_LETTERS}
            else:
                # Training layout: positional option list plus a 1-based answer index.
                question_text = item['input'].replace("Medical Question: ", "").strip()
                context = item.get('context') or {}
                if 'all_options' not in context:
                    return None
                options = context['all_options']
                answer_num = int(context['numeric_answer'])
                answer_letter = chr(ord('A') + answer_num - 1)
                # Indexing (not zip) so fewer than four options raises and the item is skipped.
                options_dict = {letter: options[idx] for idx, letter in enumerate(self._ANSWER_LETTERS)}

            # Reject answers outside A-D instead of training on labels like "@" or "E".
            if answer_letter not in self._ANSWER_LETTERS:
                print(f"Skipping QA question with invalid answer {answer_letter!r}")
                return None

            input_text = f"""Task: Medical Multiple Choice Question
Instructions: Select the correct answer (A, B, C, or D).
Respond only with the letter of the correct answer.

Question: {question_text}

Options:
A) {options_dict['A']}
B) {options_dict['B']}
C) {options_dict['C']}
D) {options_dict['D']}

Your answer (A/B/C/D):"""

            return {
                'input': input_text,
                'output': answer_letter
            }
        except Exception as e:
            # Best-effort: a malformed item is reported and skipped, not fatal.
            print(f"Error formatting QA question: {str(e)}")
            print(f"Item structure: {json.dumps(item, indent=2, ensure_ascii=False, default=str)}")
            return None

    def format_fulltext_question(self, item: Dict) -> Optional[Dict]:
        """Build an open-answer prompt/label pair for full-text training.

        Expects ``input`` (optionally containing ``"Medical Question: "``) and
        ``output`` (the reference answer).

        Returns:
            ``{'input': prompt, 'output': answer}``, or ``None`` when the item is
            malformed or its answer is missing/empty.
        """
        try:
            question_text = item['input'].replace("Medical Question: ", "").strip()
            correct_answer = item['output']
            # An empty target teaches nothing; a non-string one would break batch tokenization.
            if correct_answer is None or not str(correct_answer).strip():
                return None
            correct_answer = str(correct_answer)

            input_text = f"""Task: Medical Question and Answer
Instructions: Provide the correct answer to the following medical question.

Question: {question_text}

Complete answer:"""

            return {
                'input': input_text,
                'output': correct_answer
            }
        except Exception as e:
            print(f"Error formatting fulltext question: {str(e)}")
            print(f"Item structure: {json.dumps(item, indent=2, ensure_ascii=False, default=str)}")
            return None

    def prepare_training_data(self, questions: List[Dict], format_type: str) -> Tuple["Dataset", "Dataset"]:
        """Format, tokenize and split *questions* into train/eval datasets.

        Args:
            questions: Raw question dicts as loaded from the training JSON.
            format_type: ``'qa'`` or ``'fulltext'``; selects the prompt formatter
                and the maximum label length (16 vs 128 tokens).

        Returns:
            ``(train_dataset, eval_dataset)`` with unpadded ``input_ids``,
            ``attention_mask`` and ``labels`` columns. The first 90% of the valid
            examples (in input order) go to training, the rest to evaluation;
            with two or more examples each split gets at least one.

        Raises:
            ValueError: If *format_type* is unknown or no question could be formatted.
        """
        self._check_type(format_type)
        print(f"Preparing {format_type} training data...")
        format_func = self.format_qa_question if format_type == 'qa' else self.format_fulltext_question

        training_data = []
        for i, question in enumerate(questions):
            try:
                formatted = format_func(question)
            except Exception as e:
                print(f"Error processing question {i}: {str(e)}")
                formatted = None
            if formatted:
                training_data.append(formatted)
            if (i + 1) % 100 == 0:
                print(f"Processed {i + 1} questions...")

        print(f"Created {len(training_data)} training examples")

        if not training_data:
            raise ValueError("No valid training examples were created!")

        inputs = [example['input'] for example in training_data]
        outputs = [example['output'] for example in training_data]

        # Tokenize WITHOUT padding: DataCollatorForSeq2Seq pads each batch dynamically
        # and pads labels with -100, which the loss ignores. Pre-padding here would
        # fill labels with the pad-token id, and the model would be trained on padding.
        label_max_length = 16 if format_type == 'qa' else 128
        tokenized_inputs = self.tokenizer(inputs, truncation=True, max_length=512)
        tokenized_outputs = self.tokenizer(outputs, truncation=True, max_length=label_max_length)

        tokenized_data = {
            "input_ids": tokenized_inputs["input_ids"],
            "attention_mask": tokenized_inputs["attention_mask"],
            "labels": tokenized_outputs["input_ids"],
        }

        # 90/10 split, but never leave the train split empty, and keep at least
        # one evaluation example whenever there are two or more examples.
        total_examples = len(inputs)
        train_size = max(1, int(0.9 * total_examples))
        if total_examples > 1:
            train_size = min(train_size, total_examples - 1)

        train_dataset = Dataset.from_dict({k: v[:train_size] for k, v in tokenized_data.items()})
        eval_dataset = Dataset.from_dict({k: v[train_size:] for k, v in tokenized_data.items()})

        return train_dataset, eval_dataset

    def train(self, train_dataset: "Dataset", eval_dataset: "Dataset", model_type: str):
        """Fine-tune the *model_type* model, then save it and the tokenizer.

        Args:
            train_dataset: Tokenized training split from ``prepare_training_data``.
            eval_dataset: Tokenized evaluation split. When it is ``None`` or empty,
                evaluation and best-checkpoint selection are disabled.
            model_type: ``'qa'`` or ``'fulltext'``.

        Raises:
            ValueError: If *model_type* is unknown.
        """
        self._check_type(model_type)
        print(f"Starting {model_type} training...")

        model = self.qa_model if model_type == 'qa' else self.fulltext_model
        output_dir = self.qa_output_dir if model_type == 'qa' else self.fulltext_output_dir

        has_eval = eval_dataset is not None and len(eval_dataset) > 0

        # NOTE: FLAN-T5 checkpoints are widely reported to overflow (NaN loss) under
        # fp16 mixed precision. Use bf16 on GPUs with native support (compute
        # capability >= 8) and plain fp32 otherwise.
        use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=5,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            # For small datasets (e.g. ~600 examples at an effective batch of 8 * 4 = 32)
            # a full 5-epoch run is under 100 optimizer steps. Fixed step counts
            # (warmup 200, eval/save every 100) would never finish warmup or evaluate,
            # so scale warmup to the run length and evaluate/save once per epoch.
            warmup_ratio=0.1,
            weight_decay=0.01,
            logging_dir=f'{output_dir}/logs',
            logging_steps=10,
            eval_strategy="epoch" if has_eval else "no",
            save_strategy="epoch",
            load_best_model_at_end=has_eval,
            save_total_limit=2,
            bf16=use_bf16,
            learning_rate=2e-5,
            gradient_accumulation_steps=4,
            max_grad_norm=0.5,
            # Disables W&B (and every other reporting integration); replaces the
            # deprecated WANDB_DISABLED environment variable.
            report_to="none",
        )

        # Dynamic per-batch padding; labels are padded with -100 so the loss skips them.
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=self.tokenizer,
            model=model,
            padding=True,
            label_pad_token_id=-100,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset if has_eval else None,
            data_collator=data_collator,
            tokenizer=self.tokenizer
        )

        print(f"Training {model_type} model...")
        trainer.train()

        print(f"Saving {model_type} model...")
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)
        print(f"Model saved to {output_dir}")

def _load_questions(path):
    """Return the decoded JSON content of the UTF-8 file at *path*."""
    with open(path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


def main():
    """Mount Google Drive, then fine-tune and save the QA and full-text models.

    Both models are written under ``<base_dir>/models/`` with a
    ``YYYYmmdd_HHMM`` timestamp suffix. Any exception is printed and re-raised.
    """
    print("Starting MIR FLAN-T5 Dual Training Pipeline...")

    try:
        drive.mount('/content/drive')

        root = '/content/drive/MyDrive/TFM2'
        # Both models currently read the same file; each formats it differently.
        qa_source = f"{root}/meli-training-content/full_context/flan_t5_training.json"
        fulltext_source = f"{root}/meli-training-content/full_context/flan_t5_training.json"

        stamp = datetime.now().strftime("%Y%m%d_%H%M")
        qa_dir = f"{root}/models/mir_flan_t5_qa_v2_{stamp}"
        fulltext_dir = f"{root}/models/mir_flan_t5_fulltext_v2_{stamp}"

        pipeline = MIRFlanT5DualTrainer(
            base_model="google/flan-t5-base",
            qa_output_dir=qa_dir,
            fulltext_output_dir=fulltext_dir,
        )

        print("\nStarting QA Model Training...")
        qa_items = _load_questions(qa_source)
        print(f"Loaded {len(qa_items)} QA questions")
        qa_train, qa_eval = pipeline.prepare_training_data(qa_items, 'qa')
        pipeline.train(qa_train, qa_eval, 'qa')

        print("\nStarting Full Text Model Training...")
        fulltext_items = _load_questions(fulltext_source)
        print(f"Loaded {len(fulltext_items)} full text questions")
        ft_train, ft_eval = pipeline.prepare_training_data(fulltext_items, 'fulltext')
        pipeline.train(ft_train, ft_eval, 'fulltext')

        print("\nTraining complete!")
        print(f"QA Model saved to: {qa_dir}")
        print(f"Full Text Model saved to: {fulltext_dir}")

    except Exception as err:
        print(f"An error occurred: {str(err)}")
        raise

if __name__ == "__main__":
    # main() already prints "An error occurred: ..." and re-raises, so an extra
    # handler here would only print the same message a second time.
    main()

Found existing installation: fsspec 2024.6.1
Uninstalling fsspec-2024.6.1:
  Successfully uninstalled fsspec-2024.6.1
Found existing installation: gcsfs 2024.6.1
Uninstalling gcsfs-2024.6.1:
  Successfully uninstalled gcsfs-2024.6.1
Starting MIR FLAN-T5 Dual Training Pipeline...
Mounted at /content/drive
Initializing MIR FLAN-T5 Dual Trainer...
Loading base model: google/flan-t5-base


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Using device: cuda

Starting QA Model Training...
Loaded 619 QA questions
Preparing qa training data...
Processed 100 questions...
Processed 200 questions...
Processed 300 questions...
Processed 400 questions...
Processed 500 questions...
Processed 600 questions...
Created 619 training examples


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Starting qa training...


  trainer = Trainer(


Training qa model...


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss,Validation Loss


Saving qa model...
Model saved to /content/drive/MyDrive/TFM2/models/mir_flan_t5_qa_v2_20250119_1623

Starting Full Text Model Training...
Loaded 619 full text questions
Preparing fulltext training data...
Processed 100 questions...
Processed 200 questions...
Processed 300 questions...
Processed 400 questions...
Processed 500 questions...
Processed 600 questions...
Created 619 training examples


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Starting fulltext training...
Training fulltext model...


  trainer = Trainer(


Step,Training Loss,Validation Loss


Saving fulltext model...
Model saved to /content/drive/MyDrive/TFM2/models/mir_flan_t5_fulltext_v2_20250119_1623

Training complete!
QA Model saved to: /content/drive/MyDrive/TFM2/models/mir_flan_t5_qa_v2_20250119_1623
Full Text Model saved to: /content/drive/MyDrive/TFM2/models/mir_flan_t5_fulltext_v2_20250119_1623
