In [6]:
#!/usr/bin/env python3

import json
import os
import shutil
import random
import argparse
from typing import Dict, List, Any, Tuple
from collections import defaultdict
import logging

# Configure root logging for this notebook: INFO level, timestamped records.
# NOTE(review): basicConfig() does nothing when the root logger already has
# handlers (some notebook runtimes pre-install one) — confirm that the INFO
# lines emitted by DatasetProcessor actually show up in the cell output.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by the DatasetProcessor methods below.
logger = logging.getLogger(__name__)

In [7]:
class DatasetProcessor:
    """Processes RADQA and ORCA datasets with upload functionality"""

    def __init__(self, radqa_ratio: float = 0.75, eval_split: float = 0.1, seed: int = 42):
        self.radqa_ratio = radqa_ratio
        self.orca_ratio = 1.0 - radqa_ratio
        self.eval_split = eval_split
        self.seed = seed
        random.seed(seed)

        # Schema for unified output format
        self.unified_schema = {
            "instruction": "",
            "input": "",
            "output": ""
        }

    def upload_files_to_drive(self):
        """Handle file uploads from local computer to Google Drive"""
        print("📤 UPLOAD FILES TO GOOGLE DRIVE")
        print("=" * 50)

        try:
            from google.colab import files
            uploaded_files = {}

            print("Please upload your RADQA dataset files:")
            print("- Expected files: train.json, dev.json, or test.json")
            print("- Upload one or more RADQA JSON files")

            # Upload RADQA files
            print("\n🔄 Click 'Choose Files' and select your RADQA JSON file(s)...")
            radqa_files = files.upload()

            for filename, content in radqa_files.items():
                if filename.endswith('.json'):
                    uploaded_files[f'radqa_{filename}'] = filename
                    print(f"✅ Uploaded RADQA file: {filename}")
                else:
                    print(f"⚠️ Skipping non-JSON file: {filename}")

            # Ask about ORCA files
            upload_orca = input("\n📁 Do you want to upload ORCA dataset? (y/n): ").strip().lower()

            if upload_orca in ['y', 'yes']:
                print("\nPlease upload your ORCA dataset file:")
                print("- Expected: JSON file with instruction-following data")
                print("\n🔄 Click 'Choose Files' and select your ORCA JSON file...")

                orca_files = files.upload()

                for filename, content in orca_files.items():
                    if filename.endswith('.json'):
                        uploaded_files[f'orca_{filename}'] = filename
                        print(f"✅ Uploaded ORCA file: {filename}")
                    else:
                        print(f"⚠️ Skipping non-JSON file: {filename}")

            # Move files to organized location in Drive with better error handling
            drive_folder = '/content/drive/MyDrive/datasets'

            try:
                os.makedirs(drive_folder, exist_ok=True)
                print(f"📁 Created/verified directory: {drive_folder}")
            except Exception as dir_error:
                print(f"⚠️ Could not create directory {drive_folder}: {dir_error}")
                print("📁 Will use current directory instead")
                drive_folder = '/content'

            organized_files = {}
            for key, filename in uploaded_files.items():
                try:
                    new_path = os.path.join(drive_folder, filename)

                    if os.path.exists(filename):  # File exists in current directory
                        if drive_folder != '/content':  # Only copy if we're moving to Drive
                            # Use shutil.copy2 instead of os.rename to handle cross-device links
                            shutil.copy2(filename, new_path)
                            print(f"📁 Copied {filename} to {new_path}")

                            # Verify the copy was successful
                            if os.path.exists(new_path):
                                organized_files[key] = new_path
                                # Clean up original file after successful copy
                                try:
                                    os.remove(filename)
                                    print(f"🗑️ Cleaned up original file: {filename}")
                                except Exception as cleanup_error:
                                    print(f"⚠️ Could not remove original file {filename}: {cleanup_error}")
                            else:
                                print(f"❌ Copy verification failed for {filename}")
                                organized_files[key] = filename
                        else:
                            # Using current directory
                            organized_files[key] = filename
                            print(f"📁 Using file in current location: {filename}")
                    else:
                        print(f"⚠️ File not found: {filename}")

                except Exception as file_error:
                    print(f"❌ Error processing file {filename}: {file_error}")
                    # Fallback: try to use the file in current location
                    if os.path.exists(filename):
                        organized_files[key] = filename
                        print(f"📁 Fallback: using file in current location: {filename}")

            if organized_files:
                print(f"\n✅ Successfully organized {len(organized_files)} files:")
                for key, path in organized_files.items():
                    print(f"   - {key}: {path}")
            else:
                print(f"\n❌ No files were successfully organized")

            return organized_files

        except ImportError:
            print("❌ File upload is only available in Google Colab")
            return {}
        except Exception as e:
            print(f"❌ Error during file upload: {e}")
            return {}

    def find_uploaded_files(self, uploaded_files: Dict[str, str]) -> Tuple[str, str]:
        """Identify RADQA and ORCA files from uploaded files"""
        radqa_path = None
        orca_path = None

        # Find RADQA file
        for key, path in uploaded_files.items():
            if key.startswith('radqa_') and os.path.exists(path):
                radqa_path = path
                print(f"🎯 Selected RADQA file: {path}")
                break

        # Find ORCA file
        for key, path in uploaded_files.items():
            if key.startswith('orca_') and os.path.exists(path):
                orca_path = path
                print(f"🎯 Selected ORCA file: {path}")
                break

        return radqa_path, orca_path

    def load_radqa_dataset(self, file_path: str) -> List[Dict[str, Any]]:
        """Load RADQA dataset from JSON file"""
        logger.info(f"Loading RADQA dataset from {file_path}")

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            samples = []

            if 'data' in data:
                for article in data['data']:
                    title = article.get('title', '')

                    for paragraph in article.get('paragraphs', []):
                        context = paragraph.get('context', '')
                        document_id = paragraph.get('document_id', '')

                        for qa in paragraph.get('qas', []):
                            question = qa.get('question', '')
                            qa_id = qa.get('id', '')
                            is_impossible = qa.get('is_impossible', False)
                            answers = qa.get('answers', [])

                            sample = {
                                'title': title,
                                'document_id': document_id,
                                'context': context,
                                'question': question,
                                'qa_id': qa_id,
                                'is_impossible': is_impossible,
                                'answers': answers
                            }
                            samples.append(sample)

            logger.info(f"Loaded {len(samples)} RADQA samples.")
            return samples

        except Exception as e:
            logger.error(f"Error loading RADQA dataset: {e}")
            raise

    def load_orca_dataset(self, file_path: str) -> List[Dict[str, Any]]:
        """Load ORCA dataset from JSON file"""
        logger.info(f"Loading Orca dataset from {file_path}")

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if isinstance(data, dict):
                for key in ['data', 'samples', 'examples', 'train']:
                    if key in data:
                        data = data[key]
                        break
                else:
                    data = [data]

            logger.info(f"Loaded {len(data)} Orca samples.")
            return data

        except Exception as e:
            logger.error(f"Error loading Orca dataset: {e}")
            raise

    def clean_text(self, text: str) -> str:
        """Clean and normalize text"""
        if not isinstance(text, str):
            text = str(text)
        text = ' '.join(text.split())
        text = text.replace('\r', ' ').replace('\n', ' ')
        return text.strip()

    def standardize_radqa_sample(self, sample: Dict[str, Any]) -> Dict[str, str]:
        """Convert RADQA sample to unified format"""
        question = sample.get('question', '')
        context = sample.get('context', '')
        is_impossible = sample.get('is_impossible', False)
        answers = sample.get('answers', [])

        # Clean text
        question = self.clean_text(question)
        context = self.clean_text(context)

        # Create instruction for reading comprehension
        instruction = f"Context: {context}\n\nQuestion: {question}\n\nPlease provide a comprehensive answer based on the given radiology report."

        # Create output based on answers
        if is_impossible or not answers:
            output = "Based on the provided radiology report, this question cannot be answered as the required information is not available in the given context."
        else:
            if len(answers) == 1:
                answer_text = answers[0].get('text', '')
                answer_text = self.clean_text(answer_text)
                output = f"Based on the radiology report: {answer_text}"
            else:
                answer_texts = []
                for ans in answers:
                    ans_text = self.clean_text(ans.get('text', ''))
                    if ans_text and ans_text not in answer_texts:
                        answer_texts.append(ans_text)

                if answer_texts:
                    output = f"Based on the radiology report: {' / '.join(answer_texts)}"
                else:
                    output = "Based on the provided radiology report, this question cannot be answered as the required information is not available in the given context."

        return {
            "instruction": instruction,
            "input": "",
            "output": output
        }

    def standardize_orca_sample(self, sample: Dict[str, Any]) -> Dict[str, str]:
        """Convert ORCA sample to unified format"""
        if 'instruction' in sample:
            instruction = sample.get('instruction', '')
            input_text = sample.get('input', '')
            output_text = sample.get('output', sample.get('response', ''))
        elif 'question' in sample:
            instruction = sample.get('question', '')
            input_text = sample.get('input', '')
            output_text = sample.get('response', sample.get('answer', ''))
        else:
            instruction = sample.get('prompt', sample.get('text', ''))
            input_text = ''
            output_text = sample.get('completion', sample.get('target', ''))

        # Clean text
        instruction = self.clean_text(instruction)
        input_text = self.clean_text(input_text)
        output_text = self.clean_text(output_text)

        return {
            "instruction": instruction,
            "input": input_text,
            "output": output_text
        }

    def filter_valid_samples(self, samples: List[Dict[str, str]], dataset_name: str) -> List[Dict[str, str]]:
        """Filter out invalid samples"""
        valid_samples = []

        for sample in samples:
            if (sample.get('instruction', '').strip() and
                sample.get('output', '').strip()):
                valid_samples.append(sample)

        filtered_count = len(samples) - len(valid_samples)
        if filtered_count > 0:
            logger.warning(f"Filtered {filtered_count} invalid samples from {dataset_name}")

        logger.info(f"Valid {dataset_name} samples: {len(valid_samples)}")
        return valid_samples

    def stratified_split(self, samples: List[Dict[str, str]], eval_ratio: float) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
        """Split samples into train and eval sets"""
        shuffled_samples = samples.copy()
        random.shuffle(shuffled_samples)

        eval_size = int(len(shuffled_samples) * eval_ratio)
        eval_samples = shuffled_samples[:eval_size]
        train_samples = shuffled_samples[eval_size:]

        return train_samples, eval_samples

    def process_datasets(self, radqa_path: str, orca_path: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
        """Process both datasets and create train/eval splits"""
        logger.info("Starting dataset processing.")

        # Load RADQA dataset
        radqa_raw = self.load_radqa_dataset(radqa_path)

        # Load ORCA dataset
        orca_raw = []
        if orca_path and os.path.exists(orca_path):
            try:
                orca_raw = self.load_orca_dataset(orca_path)
            except Exception as e:
                logger.warning(f"Could not load Orca dataset: {e}")
                orca_raw = []

        if not orca_raw:
            logger.warning("The Orca dataset is empty, using 100% RADQA.")
            self.radqa_ratio = 1.0
            self.orca_ratio = 0.0

        # Standardize RADQA samples
        logger.info("Standardizing RADQA samples.")
        if not radqa_raw:
            logger.error("No RADQA samples loaded.")
            raise ValueError("RADQA dataset is empty.")

        radqa_standardized = []
        for i, sample in enumerate(radqa_raw):
            try:
                standardized = self.standardize_radqa_sample(sample)
                radqa_standardized.append(standardized)
            except Exception as e:
                logger.warning(f"Error standardizing RADQA sample {i}: {e}")

        if not radqa_standardized:
            logger.error("No RADQA samples could be standardized.")
            raise ValueError("All RADQA samples failed standardization.")

        radqa_valid = self.filter_valid_samples(radqa_standardized, "RADQA")

        # Standardize ORCA samples
        if orca_raw:
            logger.info("Standardizing Orca samples.")
            orca_standardized = [self.standardize_orca_sample(sample) for sample in orca_raw]
            orca_valid = self.filter_valid_samples(orca_standardized, "Orca")
        else:
            logger.info("Skipping Orca processing.")
            orca_valid = []

        # Split datasets
        radqa_train, radqa_eval = self.stratified_split(radqa_valid, self.eval_split)

        if orca_valid:
            orca_train, orca_eval = self.stratified_split(orca_valid, self.eval_split)
        else:
            orca_train, orca_eval = [], []

        logger.info(f"RADQA splits - Train: {len(radqa_train)}, Eval: {len(radqa_eval)}")
        logger.info(f"Orca splits - Train: {len(orca_train)}, Eval: {len(orca_eval)}")

        # Apply ratios
        if orca_train:
            total_train_target = len(radqa_train) + len(orca_train)
            radqa_train_target = int(total_train_target * self.radqa_ratio)
            orca_train_target = int(total_train_target * self.orca_ratio)

            if len(radqa_train) > radqa_train_target:
                radqa_train = random.sample(radqa_train, radqa_train_target)
            if len(orca_train) > orca_train_target:
                orca_train = random.sample(orca_train, orca_train_target)

        # Mix datasets
        mixed_train = radqa_train + orca_train
        random.shuffle(mixed_train)

        mixed_eval = radqa_eval + orca_eval
        random.shuffle(mixed_eval)

        logger.info(f"Final mixed dataset - Train: {len(mixed_train)}, Eval: {len(mixed_eval)}")
        if orca_train:
            logger.info(f"Training composition - RADQA: {len(radqa_train)} ({len(radqa_train)/len(mixed_train)*100:.1f}%), "
                       f"Orca: {len(orca_train)} ({len(orca_train)/len(mixed_train)*100:.1f}%)")
        else:
            logger.info("Training composition - RADQA: 100%")

        return mixed_train, mixed_eval

    def save_dataset(self, samples: List[Dict[str, str]], output_path: str):
        """Save dataset to JSON file"""
        logger.info(f"Saving {len(samples)} samples to {output_path}")
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(samples, f, indent=2, ensure_ascii=False)

        logger.info(f"Dataset saved successfully to {output_path}")

    def generate_dataset_stats(self, train_samples: List[Dict[str, str]], eval_samples: List[Dict[str, str]]) -> Dict[str, Any]:
        """Generate dataset statistics"""
        all_samples = train_samples + eval_samples
        total_samples = len(all_samples)

        if total_samples == 0:
            logger.warning("No samples found in datasets.")
            return {
                "total_samples": 0,
                "train_samples": 0,
                "eval_samples": 0,
                "error": "No valid samples found"
            }

        avg_instruction_length = sum(len(s['instruction'].split()) for s in all_samples) / total_samples
        avg_output_length = sum(len(s['output'].split()) for s in all_samples) / total_samples
        samples_with_input = sum(1 for s in all_samples if s['input'].strip())

        stats = {
            "total_samples": total_samples,
            "train_samples": len(train_samples),
            "eval_samples": len(eval_samples),
            "train_eval_ratio": f"{len(train_samples)}/{len(eval_samples)}",
            "avg_instruction_length": avg_instruction_length,
            "avg_output_length": avg_output_length,
            "samples_with_input": samples_with_input,
            "processing_config": {
                "radqa_ratio": self.radqa_ratio,
                "orca_ratio": self.orca_ratio,
                "eval_split": self.eval_split,
                "seed": self.seed
            }
        }

        return stats


def setup_colab_environment():
    """Detect Google Colab, install helper packages and mount Google Drive.

    Returns:
        ``True`` when running in Colab and Drive was mounted at
        ``/content/drive``; ``False`` when ``google.colab`` cannot be imported
        or mounting fails.
    """
    try:
        import google.colab
    except ImportError:
        print("❌ Not running in Google Colab")
        return False
    else:
        print("🚀 Running in Google Colab environment")

    print("📦 Installing required packages...")
    os.system("pip install -q pandas scikit-learn")

    try:
        from google.colab import drive
        print("📂 Mounting Google Drive...")
        drive.mount('/content/drive')
        print("✅ Google Drive mounted successfully!")
    except Exception as e:
        print(f"❌ Could not mount Google Drive: {e}")
        return False

    return True


def run_upload_and_process():
    """Run the interactive Colab workflow: upload, preview, configure, process.

    Steps: set up the Colab environment, upload the RADQA (and optionally
    ORCA) files, print a short preview, prompt for the mixing configuration,
    build the train/eval sets, write ``train.json``, ``eval.json`` and
    ``dataset_stats.json`` to the chosen output directory and trigger browser
    downloads.

    Returns:
        ``(train_path, eval_path, stats_path)`` on success; ``None`` when the
        workflow stops early (not in Colab, nothing uploaded, no RADQA file,
        or invalid configuration input).

    Raises:
        Exception: Errors during processing or saving are printed with a
            traceback and re-raised.
    """
    print("🏥 RADQA + ORCA Dataset Processor with Upload")
    print("=" * 70)

    # Setup environment
    IN_COLAB = setup_colab_environment()

    if not IN_COLAB:
        print("❌ This script requires Google Colab for file upload functionality.")
        return

    # Initialize processor (only used for the upload helpers; the configured
    # processor is created after the prompts below)
    temp_processor = DatasetProcessor()

    # Upload files
    uploaded_files = temp_processor.upload_files_to_drive()

    if not uploaded_files:
        print("❌ No files were uploaded. Exiting.")
        return

    # Find uploaded files
    radqa_path, orca_path = temp_processor.find_uploaded_files(uploaded_files)

    if not radqa_path:
        print("❌ No RADQA file found. Please upload a valid RADQA JSON file.")
        return

    # Preview files
    print("\n" + "="*50)
    print("📋 FILE PREVIEW")
    print("="*50)

    try:
        # Explicit UTF-8, matching the dataset loaders, so the preview does
        # not depend on the platform's default encoding.
        with open(radqa_path, 'r', encoding='utf-8') as f:
            radqa_preview = json.load(f)
        print(f"✅ RADQA file: {os.path.basename(radqa_path)}")
        if isinstance(radqa_preview, dict) and 'data' in radqa_preview:
            total_samples = sum(len(p.get('qas', [])) for article in radqa_preview['data']
                              for p in article.get('paragraphs', []))
            print(f"   📊 Total QA pairs: {total_samples}")
    except Exception as e:
        print(f"❌ Could not preview RADQA file: {e}")

    if orca_path:
        try:
            with open(orca_path, 'r', encoding='utf-8') as f:
                orca_preview = json.load(f)
            print(f"✅ ORCA file: {os.path.basename(orca_path)}")
            if isinstance(orca_preview, list):
                print(f"   📊 Total samples: {len(orca_preview)}")
        except Exception as e:
            print(f"❌ Could not preview ORCA file: {e}")

    # Configuration
    print("\n" + "="*50)
    print("⚙️ CONFIGURATION")
    print("="*50)

    # An empty answer selects the default shown in the prompt. A non-numeric
    # answer is reported and ends the run, like the range checks below,
    # instead of surfacing as an unhandled ValueError.
    try:
        radqa_ratio = float(input("RADQA ratio (0.75 for 75% RADQA, 25% Orca): ").strip() or "0.75")
        eval_split = float(input("Evaluation split (0.1 for 10%): ").strip() or "0.1")
        seed = int(input("Random seed (42): ").strip() or "42")
    except ValueError as e:
        print(f"❌ Invalid numeric input: {e}")
        return
    output_dir = input("Output directory (./data): ").strip() or "./data"

    # Validate inputs
    if not 0 < radqa_ratio <= 1:
        print(f"❌ RADQA ratio must be between 0 and 1, got {radqa_ratio}")
        return

    if not 0 < eval_split < 1:
        print(f"❌ Eval split must be between 0 and 1, got {eval_split}")
        return

    # Process datasets
    print(f"\n🚀 Initializing dataset processor...")
    processor = DatasetProcessor(
        radqa_ratio=radqa_ratio,
        eval_split=eval_split,
        seed=seed
    )

    try:
        print("📊 Processing datasets...")
        train_samples, eval_samples = processor.process_datasets(radqa_path, orca_path)

        print("💾 Saving datasets...")
        train_path = os.path.join(output_dir, 'train.json')
        eval_path = os.path.join(output_dir, 'eval.json')

        processor.save_dataset(train_samples, train_path)
        processor.save_dataset(eval_samples, eval_path)

        # Generate statistics (output_dir already exists: save_dataset made it)
        stats = processor.generate_dataset_stats(train_samples, eval_samples)
        stats_path = os.path.join(output_dir, 'dataset_stats.json')

        with open(stats_path, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2, ensure_ascii=False)

        # Display results
        print("\n🎉" + "="*58 + "🎉")
        print("   DATASET PREPARATION COMPLETED SUCCESSFULLY!")
        print("🎉" + "="*58 + "🎉")
        print(f"📈 Total samples: {stats['total_samples']:,}")
        print(f"🏋️ Training samples: {stats['train_samples']:,}")
        print(f"🧪 Evaluation samples: {stats['eval_samples']:,}")
        print(f"📝 Average instruction length: {stats['avg_instruction_length']:.1f} words")
        print(f"💬 Average output length: {stats['avg_output_length']:.1f} words")
        print(f"📁 Files saved to: {output_dir}")
        print(f"   - 📄 train.json: {stats['train_samples']:,} samples")
        print(f"   - 📄 eval.json: {stats['eval_samples']:,} samples")
        print(f"   - 📊 dataset_stats.json: Processing statistics")

        # Download files (best effort; the files stay in output_dir either way)
        print("\n📥 Downloading processed files...")
        try:
            from google.colab import files
            files.download(train_path)
            files.download(eval_path)
            files.download(stats_path)
            print("✅ Files downloaded successfully!")
        except Exception as e:
            print(f"⚠️ Could not download files automatically: {e}")
            print(f"Files are available at: {output_dir}")

        print("\n🚀 Ready for model fine-tuning!")
        return train_path, eval_path, stats_path

    except Exception as e:
        print(f"❌ Error during dataset preparation: {e}")
        import traceback
        traceback.print_exc()
        raise


# Entry point: start the interactive upload-and-process workflow when this
# code is executed directly (this includes running the notebook cell).
if __name__ == "__main__":
    run_upload_and_process()

🏥 RADQA + ORCA Dataset Processor with Upload
🚀 Running in Google Colab environment
📦 Installing required packages...
📂 Mounting Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Google Drive mounted successfully!
📤 UPLOAD FILES TO GOOGLE DRIVE
Please upload your RADQA dataset files:
- Expected files: train.json, dev.json, or test.json
- Upload one or more RADQA JSON files

🔄 Click 'Choose Files' and select your RADQA JSON file(s)...


Saving train_radqa.json to train_radqa (1).json
✅ Uploaded RADQA file: train_radqa (1).json

📁 Do you want to upload ORCA dataset? (y/n): y

Please upload your ORCA dataset file:
- Expected: JSON file with instruction-following data

🔄 Click 'Choose Files' and select your ORCA JSON file...


Saving orca_subset.json to orca_subset (2).json
✅ Uploaded ORCA file: orca_subset (2).json
📁 Created/verified directory: /content/drive/MyDrive/datasets
📁 Copied train_radqa (1).json to /content/drive/MyDrive/datasets/train_radqa (1).json
🗑️ Cleaned up original file: train_radqa (1).json
📁 Copied orca_subset (2).json to /content/drive/MyDrive/datasets/orca_subset (2).json
🗑️ Cleaned up original file: orca_subset (2).json

✅ Successfully organized 2 files:
   - radqa_train_radqa (1).json: /content/drive/MyDrive/datasets/train_radqa (1).json
   - orca_orca_subset (2).json: /content/drive/MyDrive/datasets/orca_subset (2).json
🎯 Selected RADQA file: /content/drive/MyDrive/datasets/train_radqa (1).json
🎯 Selected ORCA file: /content/drive/MyDrive/datasets/orca_subset (2).json

📋 FILE PREVIEW
✅ RADQA file: train_radqa (1).json
   📊 Total QA pairs: 4878
✅ ORCA file: orca_subset (2).json
   📊 Total samples: 1000

⚙️ CONFIGURATION
RADQA ratio (0.75 for 75% RADQA, 25% Orca): 0.75
Evaluation spli

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

✅ Files downloaded successfully!

🚀 Ready for model fine-tuning!
