In [None]:
# Install required packages
!pip install -q google-cloud-storage google-cloud-aiplatform vertexai
!pip install --upgrade google-cloud-aiplatform
!pip install --upgrade vertexai

# Import required packages
import csv
import hashlib
import json
import os
import re
import sys
import time
from datetime import datetime
from io import StringIO
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
import vertexai
from google.cloud import aiplatform
from google.cloud import storage
from google.colab import auth
from vertexai.generative_models import GenerativeModel

# Global Configuration
CONFIG = {
    # Google Cloud Settings
    'PROJECT_ID': 'label-studio-424123',
    'LOCATION': 'us-central1',

    # Storage Buckets
    'INPUT_BUCKET': 'accurate-aligned-datasets',
    'OUTPUT_BUCKET': 'unranked-training-data',
    'OUTPUT_PREFIX': 'unranked-text-chunks/',
    'PROGRESS_FOLDER': 'progress-tracking/',

    # Model Settings
    'MODEL_NAME': 'gemini-1.5-pro',
    'TEMPERATURE': 0.3,
    'MAX_TOKENS': 4000,
    'TOP_P': 0.9,

    # File Processing
    'VALID_EXTENSIONS': ('.csv', '.json', '.jsonl'),

    # Batch Processing
    'BATCH_SIZE': 1000,  # Number of rows to process before logging

    # Error Handling
    'MAX_RETRIES': 3,
    'RETRY_DELAY': 1,  # seconds
}

# Initialize CSV field size limit.
# Raise the csv module's per-field size limit as high as the platform allows.
# csv.field_size_limit() rejects values that do not fit a C long with
# OverflowError (sys.maxsize can be larger, e.g. on Windows), so shrink by 10x
# until it is accepted. Integer floor-division keeps the arithmetic exact
# instead of round-tripping a 64-bit int through float.
# NOTE(review): this affects the stdlib csv module; pd.read_csv's default C
# engine presumably does not consult it - confirm if large CSV fields matter.
maxInt = sys.maxsize
while True:
    try:
        csv.field_size_limit(maxInt)
        break
    except OverflowError:
        maxInt //= 10

# Interactive Colab authentication; the Cloud Storage / Vertex AI clients
# created later presumably rely on these user credentials.
auth.authenticate_user()

# Point the Vertex AI SDK at the configured project and region.
aiplatform.init(
    location=CONFIG['LOCATION'],
    project=CONFIG['PROJECT_ID'],
)

class ProgressTracker:
    """Persist and restore resume state as a small JSON object in a GCS bucket.

    The stored object records the last source file and row index that were
    handled, plus a timestamp, so a later run can skip completed work.
    """

    # Returned (as a fresh copy) when no usable progress file exists; also
    # supplies any keys missing from a stored progress file.
    _DEFAULT_PROGRESS = {'last_processed_file': None, 'last_processed_index': -1}

    def __init__(self, bucket_name: str, tracker_prefix: str = CONFIG['PROGRESS_FOLDER']):
        """
        Args:
            bucket_name: Bucket that holds the progress file. It is fetched
                eagerly with ``get_bucket``, so lookup errors surface here.
            tracker_prefix: Folder-like object-name prefix for the progress
                file. A trailing '/' is appended when missing; an empty (or
                None) prefix stores the file at the bucket root instead of
                under an object name starting with '/'.
        """
        self.storage_client = storage.Client()
        self.bucket = self.storage_client.get_bucket(bucket_name)
        self.tracker_prefix = tracker_prefix or ''
        if self.tracker_prefix and not self.tracker_prefix.endswith('/'):
            self.tracker_prefix += '/'
        self.tracker_file = f"{self.tracker_prefix}text_chunk_processing_progress.json"

    def load_progress(self) -> Dict:
        """Load progress from storage.

        Returns:
            A dict that always contains 'last_processed_file' and
            'last_processed_index'. Stored values override the defaults
            (None / -1). When the file is missing, unreadable, or does not hold
            a JSON object, the defaults alone are returned.
        """
        default = dict(self._DEFAULT_PROGRESS)
        try:
            blob = self.bucket.blob(self.tracker_file)
            if not blob.exists():
                print(f"No existing progress file found at {self.tracker_file}")
                return default
            progress = json.loads(blob.download_as_bytes().decode('utf-8'))
        except Exception as e:
            # Best effort: a broken progress file must not stop the pipeline.
            print(f"Error loading progress: {str(e)}")
            return default

        if not isinstance(progress, dict):
            print(f"Error loading progress: unexpected content in {self.tracker_file}")
            return default
        # Fill in keys an older/partial file may lack, so callers can index
        # the result without KeyError.
        return {**default, **progress}

    def save_progress(self, current_file: str, current_index: int) -> bool:
        """Save progress to storage.

        Args:
            current_file: Name of the source blob that was just handled.
            current_index: Index of the handled record within that file.

        Returns:
            True if the upload succeeded, False otherwise. Errors are reported
            and swallowed (best effort), so existing callers that ignore the
            return value keep working.
        """
        progress = {
            'last_processed_file': current_file,
            'last_processed_index': current_index,
            'timestamp': datetime.now().isoformat()
        }
        try:
            blob = self.bucket.blob(self.tracker_file)
            blob.upload_from_string(
                json.dumps(progress, indent=2),
                content_type='application/json'
            )
        except Exception as e:
            print(f"Error saving progress: {str(e)}")
            return False
        print(f"Progress saved to {self.tracker_file}")
        return True

class DataReformatter:
    """Rewrite raw records from GCS into structured JSON documents with Gemini.

    Pipeline:
      * load CSV / JSON / JSONL files from the input bucket into "facts"
        (one text record each);
      * ask the model to rewrite each fact into the JSON schema in
        ``prompt_template``;
      * score the rewrite with a second model call (``verify_accuracy``) and
        keep it only above a threshold;
      * upload accepted results to the output bucket and record progress so
        an interrupted run can resume.
    """

    def __init__(self):
        self.storage_client = storage.Client()
        vertexai.init(project=CONFIG['PROJECT_ID'], location=CONFIG['LOCATION'])
        self.model = GenerativeModel(CONFIG['MODEL_NAME'])
        self.progress_tracker = ProgressTracker(
            CONFIG['OUTPUT_BUCKET'],
            CONFIG['PROGRESS_FOLDER']
        )

        # Filled with str.format(): only {data} is substituted, which is why
        # the literal braces of the JSON schema are doubled.
        # NOTE(review): the task asks for "Concepts" extraction but the output
        # schema has no concepts field - confirm which is intended.
        self.prompt_template = """
        === CONTEXT ===
        {data}

        === TASK ===
        Analyze and rewrite the above information with these specific goals:

        1. ACCURACY:
        - Preserve all key facts and details exactly as presented
        - Do not add interpretations or assumptions
        - Maintain technical precision where present

        2. CLARITY:
        - Use clear, straightforward language
        - Present information in a logical order
        - Break complex ideas into digestible parts
        - Remove redundancy without losing meaning

        3. CONCISENESS:
        - Be brief but complete
        - Use precise vocabulary
        - Eliminate unnecessary words
        - Keep sentences focused and direct

        4. ENTITY EXTRACTION:
        - Groups: Any collective entity including:
          * Companies
          * Organizations
          * Institutions
          * Industries
          * Categories of people
          * Any other collective entity
        - Individuals: Specific named people
        - Locations: Geographical references
        - Concepts: Abstract ideas, terms, metrics
        - Species: Any animals or species mentioned
        - Events: Specific occurrences or dates

        For example, in "PETA and The Humane Society filed a joint lawsuit against Tyson Foods after whistleblower Jane Smith documented chickens and pigs being abused at their North Carolina facility during the 2023 holiday season", we would have:
        - mentioned_group: ["PETA", "The Humane Society", "Tyson Foods"]
        - mentioned_individual: ["Jane Smith"]
        - mentioned_location: ["North Carolina"]
        - mentioned_species: ["chickens", "pigs"]
        - mentioned_concept: ["animal abuse", "whistleblowing"]
        - mentioned_events: ["2023 holiday season lawsuit"]

        Or in "Local activists from Direct Action Everywhere disrupted a McDonald's board meeting in Chicago to protest their continued use of caged hens in their egg supply chain, while Mercy For Animals released undercover footage of dairy cows at California factory farms", we would have:
        - mentioned_group: ["Direct Action Everywhere", "McDonald's", "Mercy For Animals"]
        - mentioned_location: ["Chicago", "California"]
        - mentioned_species: ["hens", "dairy cows"]
        - mentioned_concept: ["animal confinement", "undercover investigation", "supply chain", "factory farming"]
        - mentioned_events: ["board meeting disruption", "footage release"]

        Your task is to rewrite the information maintaining perfect accuracy while maximizing clarity and conciseness. Extract all relevant entities.

        === OUTPUT FORMAT ===
        Return a JSON object following this exact schema:
        {{
            "summary": "Concise 1-2 sentence summary",
            "main_text": "Complete rewritten text preserving all key details",
            "content_type": "article|report|news|blog|social_media|research|other",
            "source_url": [],  # List of URLs found in the content. If none found, leave empty. Look for:
                # - Web addresses (http://, https://, www.)
                # - Social media links
                # - Document links
                # - Any other referenced URLs or sources
                For example:
                Input: "According to a new study published at https://example.com/study123, while Facebook posts at facebook.com/vegangroup show..."
                url: ["https://example.com/study123", "https://facebook.com/vegangroup"]
            "individual_authors": [],  # List of individual author names
            "group_authors": [],  # List of group/organization authors
            "language": ["en"],  # Language codes
            "date": "1985-04-12T23:20:50.52Z",  # Use RFC 3339 timestamp in the date-time format.
            "related_entities": {{
                "individuals": [],  # List of mentioned individuals
                "groups": [],      # List of mentioned groups/organizations
                "species": [],     # List of mentioned species
                "locations": [],   # List of mentioned locations
                "events": []       # List of mentioned events
            }},
            "tags": []  # List of relevant topic tags
        }}

        Return only the JSON object, no additional text or explanation."""

    # ------------------------------------------------------------------
    # Model interaction
    # ------------------------------------------------------------------

    def extract_json_from_response(self, response_text: str) -> str:
        """Return the part of a model response that should contain the JSON.

        Strips a leading ```json / ``` fence (the ``json`` tag is matched
        case-insensitively) and a trailing ``` fence. If the remainder does
        not start with '{' or '[', the span from the first '{' to the last
        '}' is used when present. This copes with a sentence of prose before
        the payload. Otherwise the stripped text is returned unchanged so
        ``json.loads`` reports the real problem.
        """
        cleaned_text = response_text.strip()
        if cleaned_text[:7].lower() == '```json':
            cleaned_text = cleaned_text[7:]
        elif cleaned_text.startswith('```'):
            cleaned_text = cleaned_text[3:]
        if cleaned_text.endswith('```'):
            cleaned_text = cleaned_text[:-3]
        cleaned_text = cleaned_text.strip()

        if not cleaned_text.startswith(('{', '[')):
            start = cleaned_text.find('{')
            end = cleaned_text.rfind('}')
            if start != -1 and end > start:
                cleaned_text = cleaned_text[start:end + 1]
        return cleaned_text

    def verify_accuracy(self, original_data: str, reformatted_text: Any) -> float:
        """
        Verify the accuracy of reformatted text compared to original data.

        Args:
            original_data: The source fact text.
            reformatted_text: Either the model's parsed JSON dict (only its
                'summary' and 'main_text' are compared) or a plain string.

        Returns:
            A score clamped to [0, 1]. The first number found in the grader's
            reply is used, so answers such as "Score: 0.9" still parse.
            Returns 0.0 when no number is found or the model call fails.
        """
        verification_prompt = """
        === ORIGINAL TEXT ===
        {original}

        === REFORMATTED TEXT ===
        {reformatted}

        === TASK ===
        Compare the original text with the reformatted version and score the accuracy of information preservation on a scale of 0 to 1.

        Focus on:
        1. Factual accuracy - are all key facts preserved exactly?
        2. Completeness - is any important information missing?
        3. Precision - are technical details maintained accurately?
        4. No distortion - is anything misrepresented or changed in meaning?
        5. No additions - is there any information added that wasn't in the original?

        Return only a single number between 0 and 1, where:

        1.0 = Perfect preservation
        - Every single fact, detail, and nuance is perfectly preserved
        - No information loss whatsoever
        - Maintains exact technical precision
        - Perfect preservation of context and relationships between facts

        0.9 = Near-perfect preservation
        - All key facts and important details preserved
        - Technical precision maintained

        0.8 = Very good preservation
        - All core facts and most details preserved
        - Technical accuracy maintained for important concepts
        - Some very minor details might be simplified, but still accurate

        0.7 = Good preservation with minor issues
        - All main points preserved
        - Some secondary details slightly modified
        - Technical precision slightly reduced in a way that isn't entirely accurate
        - Context mostly maintained, but simplified in a way that loses some meaningful details

        0.6 = Adequate preservation with noticeable gaps
        - Core message preserved but some details lost
        - Several secondary points missing or modified
        - Some technical precision lost
        - Context partially simplified

        0.5 = Partial preservation with significant issues
        - Main points present but some misrepresented
        - Important secondary details missing
        - Technical precision significantly reduced
        - Context partially lost or modified

        0.4 = Problematic preservation
        - Some main points misrepresented
        - Many important details missing
        - Technical accuracy compromised
        - Context often unclear or modified

        0.3 = Poor preservation
        - Multiple main points missing or wrong
        - Most important details lost
        - Technical aspects mostly incorrect
        - Context largely lost

        0.2 = Very poor preservation
        - Most main points missing or incorrect
        - Almost all details lost
        - Technical aspects wrong
        - Context missing or wrong

        0.1 = Severe distortion
        - Almost all information wrong or missing
        - Complete loss of important details
        - Technical aspects completely wrong
        - Context entirely lost

        0.0 = Complete failure
        - No accurate information preserved
        - Content completely different from original
        - No technical accuracy
        - Wrong context entirely

        Score must be exactly one of these values: 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, or 1.0
        Return only the number, no explanation or other text.
        """

        try:
            if isinstance(reformatted_text, dict):
                # Only the prose fields are comparable with the original text.
                reformatted_content = f"{reformatted_text.get('summary', '')} {reformatted_text.get('main_text', '')}"
            else:
                reformatted_content = reformatted_text

            response = self.model.generate_content(
                verification_prompt.format(
                    original=original_data,
                    reformatted=reformatted_content
                ),
                generation_config={
                    "temperature": 0.1,
                    "max_output_tokens": 128,
                    "top_p": 0.9,
                }
            )

            score_text = response.text.strip()
            match = re.search(r'\d*\.?\d+', score_text)
            if match is None:
                print(f"Error parsing accuracy score: {score_text}")
                return 0.0
            return max(0.0, min(1.0, float(match.group())))

        except Exception as e:
            # Best effort: any model/API failure counts as "not verified".
            print(f"Error in accuracy verification: {str(e)}")
            return 0.0

    def reformat_data(self, data: str, source_file: str, item_id: int,
                      min_accuracy: float = 0.8) -> Dict[str, Any]:
        """Rewrite one fact with the model and return the accepted JSON object.

        Up to ``CONFIG['MAX_RETRIES']`` attempts are made, waiting
        ``CONFIG['RETRY_DELAY']`` seconds between them. An attempt is rejected
        in any of these cases:
          * the model call fails;
          * the response is not valid JSON;
          * the JSON is not an object;
          * ``verify_accuracy`` scores it below ``min_accuracy``.

        Args:
            data: Fact text substituted into the prompt.
            source_file: Originating blob name (passed to the template).
            item_id: Fact number (passed to the template).
            min_accuracy: Minimum verification score to accept a rewrite.

        Returns:
            The parsed JSON object (without the accuracy score), or ``{}``
            when every attempt was rejected.
        """
        max_attempts = CONFIG['MAX_RETRIES']
        # Extra keyword args are ignored by str.format when the template has
        # no matching placeholder.
        prompt = self.prompt_template.format(
            data=data,
            source=source_file,
            item_id=item_id
        )

        for attempt in range(1, max_attempts + 1):
            if attempt > 1:
                time.sleep(CONFIG['RETRY_DELAY'])

            cleaned_response = ''
            try:
                response = self.model.generate_content(
                    prompt,
                    generation_config={
                        "temperature": CONFIG['TEMPERATURE'],
                        "max_output_tokens": CONFIG['MAX_TOKENS'],
                        "top_p": CONFIG['TOP_P'],
                    }
                )
                cleaned_response = self.extract_json_from_response(response.text)
                json_response = json.loads(cleaned_response)
            except json.JSONDecodeError as e:
                print(f"JSON parsing error: {str(e)}")
                print(f"Problematic text: {cleaned_response}")
                continue
            except Exception as e:
                print(f"Error reformatting data: {str(e)}")
                continue

            # Downstream code reads fields with .get() and spreads the result
            # into a dict, so anything but a JSON object is unusable.
            if not isinstance(json_response, dict):
                print(f"Attempt {attempt}: expected a JSON object, got "
                      f"{type(json_response).__name__}, retrying...")
                continue

            accuracy_score = self.verify_accuracy(data, json_response)
            print(f"Accuracy score: {accuracy_score}")
            if accuracy_score >= min_accuracy:
                return json_response
            print(f"Attempt {attempt}: Accuracy score too low ({accuracy_score}), retrying...")

        print("Max attempts reached")
        return {}

    # ------------------------------------------------------------------
    # Output naming
    # ------------------------------------------------------------------

    @staticmethod
    def _first_text(values: Any) -> Optional[str]:
        """Return the first non-blank string in a model field (list or str), else None."""
        if isinstance(values, str):
            return values.strip() or None
        if isinstance(values, (list, tuple)):
            for value in values:
                if isinstance(value, str) and value.strip():
                    return value.strip()
        return None

    def create_meaningful_filename(self, data: Dict, fact_num: int,
                                   unique_key: Optional[str] = None) -> str:
        """Create a filename that reflects the content.

        Shape: ``<group>_<species>_<tag>_<YYYYMMDD>_<fact_num>[_<hash>].json``.

        Each descriptive part is the first non-blank entry of
        ``related_entities.groups``, ``related_entities.species`` or ``tags``.
        It is lower-cased, spaces become '_', it is truncated to 30 chars and
        limited to alphanumerics, '_', '-' and '.'. Missing parts are omitted.
        If none exist, the first three words of ``summary`` are used instead.

        Args:
            data: Parsed model output.
            fact_num: Position of the fact within the current run.
            unique_key: Optional stable identifier of the source record (e.g.
                "<source_file>:<index>"). When given, an 8-char SHA-1 prefix of
                it is appended. ``fact_num`` restarts at 0 on every run, so
                this keeps same-day reruns from overwriting earlier outputs.
        """
        def clean(text: str) -> str:
            text = text.replace(' ', '_').lower()[:30]
            return ''.join(c for c in text if c.isalnum() or c in '_-.')

        entities = data.get('related_entities')
        if not isinstance(entities, dict):
            entities = {}

        elements = []
        for values in (entities.get('groups'), entities.get('species'), data.get('tags')):
            first = self._first_text(values)
            part = clean(first) if first else ''
            if part:
                elements.append(part)

        summary = data.get('summary')
        if not elements and isinstance(summary, str):
            part = clean('_'.join(summary.split()[:3]))
            if part:
                elements.append(part)

        timestamp = datetime.now().strftime('%Y%m%d')
        parts = elements + [timestamp, clean(str(fact_num))]
        if unique_key is not None:
            parts.append(hashlib.sha1(unique_key.encode('utf-8')).hexdigest()[:8])
        return '_'.join(parts) + '.json'

    # ------------------------------------------------------------------
    # Orchestration
    # ------------------------------------------------------------------

    def process_bucket(self, input_bucket: str = CONFIG['INPUT_BUCKET'],
                      output_bucket: str = CONFIG['OUTPUT_BUCKET'],
                      output_prefix: str = CONFIG['OUTPUT_PREFIX']):
        """
        Process all files in the input bucket, transform them, and save to output bucket.

        Per-fact failures are reported and skipped. Progress is saved only
        after a successful upload. A failed fact is therefore not recorded,
        but a later success moves the checkpoint past it, so it is not retried
        on resume. Errors outside the per-fact loop (bucket access, loading)
        are reported and re-raised.
        """
        try:
            print("\nAccessing storage buckets...")
            output_bucket_obj = self.storage_client.get_bucket(output_bucket)
            all_facts = self.load_all_data(input_bucket)

            if not all_facts:
                print("No facts to process")
                return

            total = len(all_facts)
            print(f"\nStarting processing of {total} facts...")
            saved_count = 0

            for fact_num, fact in enumerate(all_facts):
                try:
                    source_file = fact['source_file']
                    index = fact.get('index', fact_num)
                    print(f"\nProcessing fact {fact_num + 1}/{total}")
                    print(f"Source file: {source_file}")
                    print(f"Index: {fact.get('index', 'N/A')}")

                    reformatted = self.reformat_data(
                        fact['fact_text'],
                        source_file,
                        fact_num
                    )
                    if not reformatted:
                        print(f"Warning: Failed to reformat fact {fact_num + 1}")
                        continue

                    try:
                        filename = self.create_meaningful_filename(
                            reformatted, fact_num, unique_key=f"{source_file}:{index}"
                        )
                        # GCS object names always use '/', whatever the host OS.
                        output_path = os.path.join(output_prefix, filename).replace('\\', '/')
                        output_bucket_obj.blob(output_path).upload_from_string(
                            json.dumps(reformatted, indent=2),
                            content_type='application/json'
                        )
                    except Exception as e:
                        print(f"Error saving output file: {str(e)}")
                        continue

                    saved_count += 1
                    # save_progress reports its own success/failure.
                    try:
                        self.progress_tracker.save_progress(source_file, index)
                    except Exception as e:
                        print(f"Warning: Failed to save progress: {str(e)}")

                except Exception as e:
                    print(f"Error processing fact {fact_num + 1}: {str(e)}")
                    print("Continuing with next fact...")
                    continue

            print("\nProcessing complete!")
            print(f"Processed {total} facts")
            print(f"Saved {saved_count} of {total} facts")

        except Exception as e:
            print(f"Critical error in process_bucket: {str(e)}")
            raise

    # ------------------------------------------------------------------
    # Input loading
    # ------------------------------------------------------------------

    def list_all_blobs(self, bucket: storage.bucket.Bucket) -> List[storage.blob.Blob]:
        """Return every blob in the bucket, root and sub-folders alike, whose
        name ends with one of ``CONFIG['VALID_EXTENSIONS']``, sorted by name.

        Listing without a delimiter already returns objects in all
        sub-folders, so a single pass is used and each blob appears once.
        """
        valid_extensions = tuple(CONFIG['VALID_EXTENSIONS'])
        blobs_by_name = {}
        folders = set()

        for blob in bucket.list_blobs():
            name = blob.name
            if '/' in name:
                folders.add(name.split('/')[0] + '/')
            if name.lower().endswith(valid_extensions):
                blobs_by_name[name] = blob

        print(f"Found folders: {folders}")

        all_blobs = sorted(blobs_by_name.values(), key=lambda b: b.name)
        print(f"\nTotal valid files found: {len(all_blobs)}")
        for blob in all_blobs:
            print(f"- {blob.name}")

        return all_blobs

    @staticmethod
    def _record_to_text(record: Dict[str, Any]) -> str:
        """Render a record as space-separated 'key: value' pairs, skipping
        missing scalars (None / NaN / NaT). Non-scalar values such as lists
        and dicts are kept and rendered with str()."""
        return " ".join(
            f"{key}: {value}"
            for key, value in record.items()
            if not (pd.api.types.is_scalar(value) and pd.isna(value))
        )

    def _iter_records(self, blob_name: str, content: str):
        """Yield ``(index, fact_text)`` for each record in a decoded file.

        .jsonl: one record per line (index = line number, text = stripped line).
        .json:  one record per element of a top-level array, otherwise a
                single record (index 0); objects become 'key: value' text.
        other:  parsed as CSV, one record per row (index = 0-based row).
        """
        lowered = blob_name.lower()
        if lowered.endswith('.jsonl'):
            for i, line in enumerate(content.splitlines()):
                yield i, line.strip()
        elif lowered.endswith('.json'):
            parsed = json.loads(content)
            records = parsed if isinstance(parsed, list) else [parsed]
            for i, record in enumerate(records):
                if isinstance(record, dict):
                    yield i, self._record_to_text(record)
                elif record is None:
                    yield i, ""
                else:
                    yield i, str(record)
        else:  # CSV files
            df = pd.read_csv(StringIO(content))
            # to_dict keeps each column's dtype, unlike iterrows(), which
            # upcasts an all-numeric row (e.g. ints rendered as 5.0).
            for i, record in enumerate(df.to_dict(orient='records')):
                yield i, self._record_to_text(record)

    def load_all_data(self, input_bucket: str) -> List[Dict[str, Any]]:
        """Load every not-yet-processed record from the input bucket.

        Resume rules, based on the saved progress:
          * files sorting before ``last_processed_file`` are skipped;
          * within that file, records with index <= ``last_processed_index``
            are skipped and the remainder is loaded;
          * files sorting after it are loaded in full.

        Comparing names instead of searching for an exact match means a
        renamed or deleted checkpoint file does not stall the run.

        Returns:
            A list of dicts with keys 'source_file', 'fact_text' and 'index'.
            Blank records are omitted. A file that fails to load is reported
            and skipped.
        """
        print("\nLoading data from files...")
        all_facts: List[Dict[str, Any]] = []
        input_bucket_obj = self.storage_client.get_bucket(input_bucket)

        progress = self.progress_tracker.load_progress()
        last_file = progress.get('last_processed_file')
        try:
            last_index = int(progress.get('last_processed_index', -1))
        except (TypeError, ValueError):
            last_index = -1

        # Name order must match the ordering used by the resume test below.
        blobs = self.list_all_blobs(input_bucket_obj)
        blobs.sort(key=lambda x: x.name)

        for blob in blobs:
            if last_file and blob.name < last_file:
                continue
            skip_through = last_index if blob.name == last_file else -1

            print(f"\nLoading {blob.name}")
            try:
                # utf-8-sig drops a leading BOM, which would otherwise break
                # json.loads and pollute the first CSV column name.
                content = blob.download_as_bytes().decode('utf-8-sig')
                for i, fact_text in self._iter_records(blob.name, content):
                    if i <= skip_through or not fact_text.strip():
                        continue
                    all_facts.append({
                        "source_file": blob.name,
                        "fact_text": fact_text,
                        "index": i
                    })
            except Exception as e:
                print(f"Error processing {blob.name}: {str(e)}")
                continue

        print(f"Total facts loaded: {len(all_facts)}")
        return all_facts

def main():
    """Build a DataReformatter and run it over the buckets named in CONFIG.

    Any exception is printed and then re-raised so the failure stays visible
    to the caller (or the notebook cell).
    """
    try:
        reformatter = DataReformatter()
        print("Starting data reformatting...")

        # Bucket names and the output prefix are passed explicitly from CONFIG.
        reformatter.process_bucket(
            input_bucket=CONFIG['INPUT_BUCKET'],
            output_bucket=CONFIG['OUTPUT_BUCKET'],
            output_prefix=CONFIG['OUTPUT_PREFIX'],
        )
    except Exception as err:
        print(f"Error in main execution: {str(err)}")
        raise
    else:
        print("Data reformatting complete!")


# Run only when executed as a script / notebook cell, not when imported.
if __name__ == '__main__':
    main()
