# In [None]:
# ====================================================================
# COMPLETE GREENTEXT ML DATASET SCRAPER PIPELINE
# One-cell solution for Google Colab
# Purpose: Scrape r/wholesomegreentext for ML training dataset
# ====================================================================

# ============== INSTALLATION & SETUP ==============
import subprocess
import sys
import os

# Install required packages
def install_packages():
    """Install the third-party packages this notebook depends on.

    Runs ``pip install <package> -q`` with the current interpreter for each
    package. Installation is best-effort: a failing package is reported and
    the remaining packages are still attempted.
    """
    packages = ['praw', 'pillow', 'tqdm', 'scikit-learn', 'requests', 'pandas']
    for package in packages:
        try:
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', package, '-q'])
        except (subprocess.CalledProcessError, OSError):
            # Only swallow real install failures (non-zero pip exit or the
            # process failing to start). A bare ``except`` would also hide
            # KeyboardInterrupt and SystemExit.
            print(f"‚ö†Ô∏è Could not install {package}")
    print("‚úÖ Package installation complete!")

install_packages()

# Mount Google Drive when running inside Colab; anywhere else the import
# (or the mount itself) fails and the pipeline falls back to local storage.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_MOUNTED = True
    print("‚úÖ Google Drive mounted!")
except Exception:
    # ``Exception`` rather than a bare ``except`` so that Ctrl+C and
    # SystemExit are not mistaken for "Drive not available".
    DRIVE_MOUNTED = False
    print("‚ÑπÔ∏è Google Drive not available")

# ============== IMPORTS ==============
import praw
import pandas as pd
import requests
from PIL import Image
import io
import time
from datetime import datetime
import json
from pathlib import Path
import hashlib
from typing import List, Dict, Optional
from tqdm.notebook import tqdm
import zipfile
from sklearn.model_selection import train_test_split

# ============== CONFIGURATION ==============
# üîß CONFIGURE YOUR CREDENTIALS HERE üîß
# NOTE(security): the literals below are real-looking API credentials that
# are committed with the source. They are kept only as fallbacks so the
# notebook keeps working; prefer supplying REDDIT_CLIENT_ID,
# REDDIT_CLIENT_SECRET and REDDIT_USER_AGENT through the environment, and
# rotate any secret that has been shared.
REDDIT_CONFIG = {
    # Reddit client ID
    'client_id': os.environ.get('REDDIT_CLIENT_ID', 'iJnKy41u_V_xeQ9kz7tkWQ'),
    # Reddit client secret
    'client_secret': os.environ.get('REDDIT_CLIENT_SECRET', 'fb0VRa1Nua7blD_AO22oor2aFYPUuA'),
    # User agent string sent with every API request
    'user_agent': os.environ.get('REDDIT_USER_AGENT', 'GreentextML/1.0 by BRArjun_890'),
}

DATASET_CONFIG = {
    'subreddit': 'wholesomegreentext',
    'hot_posts': 500,           # Number of hot posts to scrape
    'top_posts': 500,           # Number of top posts to scrape
    'min_score': 50,            # Minimum upvotes for quality
    'save_to_drive': DRIVE_MOUNTED,
    'create_archive': True,
}

print("üîß Configuration loaded!")

# ============== MAIN SCRAPER CLASS ==============
class CompleteGreentextPipeline:
    def __init__(self, reddit_config: dict, dataset_config: dict):
        """Create the Reddit client, verify it works, and prepare output folders.

        Args:
            reddit_config: Keyword arguments forwarded unchanged to
                ``praw.Reddit``.
            dataset_config: Scraping options, stored on the instance and read
                by the other methods.

        Raises:
            Exception: If the test request to Reddit fails. The original
                error is chained as ``__cause__``.
        """
        self.reddit_config = reddit_config
        self.dataset_config = dataset_config

        # Initialize Reddit connection
        self.reddit = praw.Reddit(**reddit_config)

        # Test connection. PRAW objects are lazy: building a Subreddit sends
        # nothing over the network, so pull one listing item to force a real
        # request and surface bad credentials here instead of mid-scrape.
        try:
            next(iter(self.reddit.subreddit('test').hot(limit=1)), None)
            print(f"‚úÖ Reddit API connected successfully!")
        except Exception as e:
            raise Exception(f"‚ùå Reddit API connection failed: {e}") from e

        # Setup paths
        self.setup_directories()

    def setup_directories(self):
        """Choose the dataset root and create it with its two sub-folders.

        The root is placed under Google Drive when ``save_to_drive`` is set in
        the dataset config and Drive is mounted; otherwise under ``/content``.
        Sets ``self.base_path``, ``self.images_path`` and
        ``self.metadata_path``.
        """
        dataset_dir_name = "greentext_ml_dataset"

        store_on_drive = self.dataset_config['save_to_drive'] and DRIVE_MOUNTED
        self.base_path = (
            Path(f"/content/drive/MyDrive/{dataset_dir_name}")
            if store_on_drive
            else Path(f"/content/{dataset_dir_name}")
        )
        self.images_path = self.base_path / "images"
        self.metadata_path = self.base_path / "metadata"

        # Create root first, then the children (all calls are idempotent).
        for directory in (self.base_path, self.images_path, self.metadata_path):
            directory.mkdir(exist_ok=True, parents=True)

        print(f"üìÅ Dataset directory: {self.base_path}")

    def is_valid_greentext_image(self, url: str) -> bool:
        """Check if URL is a valid greentext image"""
        url_lower = url.lower()

        valid_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp')
        valid_domains = ['i.redd.it', 'i.imgur.com', 'imgur.com', 'preview.redd.it']

        return (url_lower.endswith(valid_extensions) or
                any(domain in url_lower for domain in valid_domains))

    def download_image(self, url: str, filename: str, max_retries: int = 3) -> Optional[Dict]:
        """Download an image, validate it, and save it as a JPEG.

        The file is written to ``self.images_path / filename``. Timeouts,
        connection errors and HTTP 429 responses are retried up to
        *max_retries* attempts; every other failure returns immediately.

        Args:
            url: Address of the image to fetch.
            filename: File name to save under ``self.images_path``.
            max_retries: Maximum number of request attempts.

        Returns:
            A dict that always contains ``download_success`` and ``error``.
            On success it also contains ``filename``, ``size_bytes`` (size of
            the downloaded payload, not of the saved JPEG) and ``dimensions``
            (``(width, height)`` of the saved image).
        """
        headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}

        # Modes saved as-is; anything else (RGBA, P, LA, I;16, ...) is
        # converted to RGB first, since the JPEG writer rejects such modes.
        jpeg_safe_modes = ('L', 'RGB', 'CMYK')

        for attempt in range(max_retries):
            is_last_attempt = attempt >= max_retries - 1
            try:
                response = requests.get(url, headers=headers, timeout=30)

                # Handle different HTTP status codes
                if response.status_code == 404:
                    print(f"üîç Image not found (404): {url}")
                    return {'download_success': False, 'error': '404_not_found'}

                if response.status_code == 403:
                    print(f"üîí Access forbidden (403): {url}")
                    return {'download_success': False, 'error': '403_forbidden'}

                if response.status_code == 429:
                    print(f"‚è≥ Rate limited (429): {url} - Waiting...")
                    if is_last_attempt:
                        # No retry left: waiting a minute would only delay
                        # the 'max_retries_exceeded' result below.
                        break
                    time.sleep(60)  # Wait 1 minute for rate limit
                    continue

                response.raise_for_status()

                # Verify content is actually an image
                content_type = response.headers.get('content-type', '').lower()
                if not any(img_type in content_type for img_type in ['image/', 'jpeg', 'png', 'gif', 'webp']):
                    print(f"‚ùå Not an image file: {url} (Content-Type: {content_type})")
                    return {'download_success': False, 'error': 'not_image'}

                # Process image
                try:
                    img = Image.open(io.BytesIO(response.content))
                    img.verify()  # Verify image integrity

                    # Re-open for processing (the image cannot be used after verify())
                    img = Image.open(io.BytesIO(response.content))

                    # Convert modes the JPEG writer cannot store
                    if img.mode not in jpeg_safe_modes:
                        img = img.convert('RGB')

                    # Check image size (avoid tiny or huge images)
                    width, height = img.size
                    if width < 50 or height < 50:
                        print(f"üìè Image too small: {url} ({width}x{height})")
                        return {'download_success': False, 'error': 'too_small'}

                    if width > 5000 or height > 5000:
                        print(f"üìè Image too large: {url} ({width}x{height}) - Resizing...")
                        img.thumbnail((2000, 2000), Image.Resampling.LANCZOS)

                    # Save image
                    image_path = self.images_path / filename
                    img.save(image_path, 'JPEG', quality=95, optimize=True)

                    return {
                        'filename': filename,
                        'size_bytes': len(response.content),
                        'dimensions': img.size,
                        'download_success': True,
                        'error': None
                    }

                except (OSError, IOError) as img_error:
                    print(f"üñºÔ∏è Image processing failed: {url} - {img_error}")
                    return {'download_success': False, 'error': 'image_processing_failed'}

            except requests.exceptions.Timeout:
                print(f"‚è∞ Timeout (attempt {attempt + 1}/{max_retries}): {url}")
                if not is_last_attempt:
                    time.sleep(2 ** attempt)  # Exponential backoff
                    continue
                return {'download_success': False, 'error': 'timeout'}

            except requests.exceptions.ConnectionError:
                print(f"üåê Connection error (attempt {attempt + 1}/{max_retries}): {url}")
                if not is_last_attempt:
                    time.sleep(2 ** attempt)
                    continue
                return {'download_success': False, 'error': 'connection_error'}

            except requests.exceptions.RequestException as req_error:
                print(f"üì° Request failed: {url} - {req_error}")
                return {'download_success': False, 'error': f'request_failed: {req_error}'}

            except Exception as e:
                print(f"‚ùå Unexpected error downloading {url}: {e}")
                return {'download_success': False, 'error': f'unexpected: {e}'}

        return {'download_success': False, 'error': 'max_retries_exceeded'}

    def extract_post_features(self, submission) -> Dict:
        """Extract ML features from Reddit post"""
        filename = f"{submission.id}.jpg"
        created_time = datetime.fromtimestamp(submission.created_utc)

        return {
            'post_id': submission.id,
            'filename': filename,
            'title': submission.title,
            'url': submission.url,
            'permalink': f"https://www.reddit.com{submission.permalink}",
            'score': submission.score,
            'upvote_ratio': submission.upvote_ratio,
            'num_comments': submission.num_comments,
            'title_word_count': len(submission.title.split()),
            'title_char_count': len(submission.title),
            'domain': submission.domain,
            'author': str(submission.author) if submission.author else '[deleted]',
            'created_utc': submission.created_utc,
            'created_date': created_time.strftime('%Y-%m-%d %H:%M:%S'),
            'is_nsfw': submission.over_18,
            'flair': submission.link_flair_text if submission.link_flair_text else '',
            'quality_score': submission.score * submission.upvote_ratio,
            'engagement_ratio': submission.num_comments / max(submission.score, 1)
        }

    def get_subreddit_stats(self, subreddit_name: str) -> Dict:
        """Get comprehensive subreddit statistics before scraping"""
        try:
            print(f"üìä Analyzing r/{subreddit_name}...")
            subreddit = self.reddit.subreddit(subreddit_name)

            # Basic subreddit info
            stats = {
                'name': subreddit.display_name,
                'title': subreddit.title,
                'subscribers': subreddit.subscribers,
                'description': subreddit.public_description[:200] + "..." if len(subreddit.public_description) > 200 else subreddit.public_description,
                'created_utc': subreddit.created_utc,
                'over18': subreddit.over18
            }

            # Count posts in different categories
            print("üîç Counting posts in different categories...")

            # Count hot posts
            hot_count = 0
            try:
                for _ in subreddit.hot(limit=1000):
                    hot_count += 1
                    if hot_count >= 1000:
                        break
            except:
                hot_count = "Unable to count"

            # Count top posts (this month)
            top_month_count = 0
            try:
                for _ in subreddit.top(time_filter='month', limit=1000):
                    top_month_count += 1
                    if top_month_count >= 1000:
                        break
            except:
                top_month_count = "Unable to count"

            # Count top posts (all time) - sample first 1000
            top_all_count = 0
            try:
                for _ in subreddit.top(time_filter='all', limit=1000):
                    top_all_count += 1
                    if top_all_count >= 1000:
                        break
            except:
                top_all_count = "Unable to count"

            # Count new posts
            new_count = 0
            try:
                for _ in subreddit.new(limit=1000):
                    new_count += 1
                    if new_count >= 1000:
                        break
            except:
                new_count = "Unable to count"

            stats.update({
                'hot_posts_sample': hot_count,
                'top_month_posts_sample': top_month_count,
                'top_all_posts_sample': top_all_count,
                'new_posts_sample': new_count
            })

            return stats

        except Exception as e:
            print(f"‚ùå Error getting subreddit stats: {e}")
            return {'error': str(e)}

    def display_subreddit_info(self, stats: Dict):
        """Display subreddit information in a user-friendly format"""
        if 'error' in stats:
            print(f"‚ùå Could not get subreddit information: {stats['error']}")
            return False

        print("\n" + "="*60)
        print("üìä SUBREDDIT ANALYSIS")
        print("="*60)
        print(f"üè∑Ô∏è  Name: r/{stats['name']}")
        print(f"üìù Title: {stats['title']}")
        print(f"üë• Subscribers: {stats['subscribers']:,}")
        print(f"üìÖ Created: {datetime.fromtimestamp(stats['created_utc']).strftime('%Y-%m-%d')}")
        print(f"üîû NSFW: {'Yes' if stats['over18'] else 'No'}")
        print(f"üìÑ Description: {stats['description']}")

        print("\nüìà POST COUNTS (Sample of up to 1000 each):")
        print(f"üî• Hot posts: {stats['hot_posts_sample']}")
        print(f"‚≠ê Top (this month): {stats['top_month_posts_sample']}")
        print(f"üèÜ Top (all time): {stats['top_all_posts_sample']}")
        print(f"üÜï New posts: {stats['new_posts_sample']}")

        print("\nüí° SCRAPING PLAN:")
        print(f"üì• Hot posts to scrape: {self.dataset_config['hot_posts']}")
        print(f"üì• Top posts to scrape: {self.dataset_config['top_posts']}")
        print(f"üìä Minimum score filter: {self.dataset_config['min_score']}")
        print(f"üéØ Expected total: ~{self.dataset_config['hot_posts'] + self.dataset_config['top_posts']} posts")

        return True

    def get_user_confirmation(self) -> bool:
        """Get user confirmation before starting scraping"""
        print("\n" + "="*60)
        print("‚ö†Ô∏è  BEFORE WE START:")
        print("‚Ä¢ This will download images and may take 15-45 minutes")
        print("‚Ä¢ Large datasets will use significant storage space")
        print("‚Ä¢ Reddit API has rate limits - scraping may be slow")
        print("‚Ä¢ Some images may fail to download (404, etc.)")
        print("="*60)

        while True:
            try:
                choice = input("\nü§î Do you want to continue with scraping? (y/n): ").lower().strip()
                if choice in ['y', 'yes']:
                    print("‚úÖ Starting scraping process...")
                    return True
                elif choice in ['n', 'no']:
                    print("‚ùå Scraping cancelled by user.")
                    return False
                else:
                    print("Please enter 'y' for yes or 'n' for no.")
            except (EOFError, KeyboardInterrupt):
                print("\n‚ùå Operation cancelled.")
                return False

    def scrape_posts(self, sort_method: str, limit: int) -> List[Dict]:
        """Scrape posts from the configured subreddit and download their images.

        Up to ``limit * 2`` submissions are requested so that enough remain
        after filtering. Posts below ``min_score`` or without an image-like
        URL are skipped. Scraping stops once *limit* images have been
        downloaded successfully.

        Args:
            sort_method: ``'hot'`` or ``'top'`` (top uses the all-time filter).
            limit: Target number of successfully downloaded posts.

        Returns:
            One record per attempted post: the extracted features merged with
            the download result. Failed downloads are included with
            ``download_success`` set to False so callers can report and
            filter them.

        Raises:
            ValueError: If *sort_method* is not supported.
        """
        subreddit_name = self.dataset_config['subreddit']
        min_score = self.dataset_config['min_score']

        print(f"üîÑ Scraping {limit} {sort_method} posts from r/{subreddit_name}")

        subreddit = self.reddit.subreddit(subreddit_name)

        # Get submissions
        if sort_method == 'hot':
            submissions = subreddit.hot(limit=limit * 2)  # Get extra to account for filtering
        elif sort_method == 'top':
            submissions = subreddit.top(limit=limit * 2, time_filter='all')
        else:
            raise ValueError(f"Unsupported sort method: {sort_method}")

        dataset = []
        downloaded = 0

        pbar = tqdm(desc=f"Processing {sort_method}", unit="posts")

        try:
            for submission in submissions:
                # Stop if we have enough successful downloads
                if downloaded >= limit:
                    break

                pbar.update(1)

                # Quality filters
                if submission.score < min_score:
                    continue

                if not self.is_valid_greentext_image(submission.url):
                    continue

                # Extract features
                post_data = self.extract_post_features(submission)

                # Download image
                image_metadata = self.download_image(submission.url, post_data['filename'])

                if image_metadata:
                    post_data.update(image_metadata)
                    dataset.append(post_data)
                    # download_image returns a (truthy) dict on failure too,
                    # so only a success flag may count as a download.
                    if image_metadata.get('download_success'):
                        downloaded += 1
                    pbar.set_postfix({'downloaded': downloaded, 'score': submission.score})

                time.sleep(0.1)  # Be respectful
        finally:
            # Close the bar even if a request raises mid-scrape.
            pbar.close()

        print(f"‚úÖ Downloaded {downloaded} {sort_method} posts")
        return dataset

    def save_dataset(self, df: pd.DataFrame, name: str):
        """Save dataset with metadata"""
        # Save CSV
        csv_path = self.metadata_path / f"{name}.csv"
        df.to_csv(csv_path, index=False)

        # Save info
        info = {
            'created_at': datetime.now().isoformat(),
            'total_posts': len(df),
            'images_downloaded': len(df[df['download_success'] == True]),
            'score_range': [df['score'].min(), df['score'].max()],
            'date_range': [df['created_date'].min(), df['created_date'].max()],
            'average_quality': df['quality_score'].mean(),
            'subreddit': self.dataset_config['subreddit']
        }

        info_path = self.metadata_path / f"{name}_info.json"
        with open(info_path, 'w') as f:
            json.dump(info, f, indent=2)

        print(f"üíæ Saved: {csv_path}")
        return csv_path

    def create_train_test_split(self, df: pd.DataFrame):
        """Split the dataset 80/20 and save both parts as CSV.

        The split is stratified by score quantile bins when that is possible.
        Equal-width bins would be a poor fit for skewed scores: most posts
        land in the first bin and the others can hold a single post, which
        makes a stratified split fail. If stratification is still not
        feasible (too few posts per bin, or fewer test rows than bins), a
        plain random split is used instead.

        Args:
            df: Dataset containing a numeric ``score`` column.

        Returns:
            ``(train_df, test_df)``.
        """
        # Stratify by score quintiles for balanced split
        stratify_labels = None
        try:
            score_bins = pd.qcut(df['score'], q=5, labels=False, duplicates='drop')
            # Every bin needs at least two members to appear in both splits.
            if score_bins.notna().all() and score_bins.value_counts().min() >= 2:
                stratify_labels = score_bins
        except ValueError:
            stratify_labels = None

        try:
            train_df, test_df = train_test_split(
                df, test_size=0.2, random_state=42, stratify=stratify_labels
            )
        except ValueError:
            if stratify_labels is None:
                raise
            # Stratification was rejected; fall back to a random split.
            train_df, test_df = train_test_split(
                df, test_size=0.2, random_state=42
            )

        # Save splits
        train_path = self.metadata_path / "train_set.csv"
        test_path = self.metadata_path / "test_set.csv"

        train_df.to_csv(train_path, index=False)
        test_df.to_csv(test_path, index=False)

        print(f"üîÑ Train/Test split: {len(train_df)} / {len(test_df)}")
        return train_df, test_df

    def create_archive(self):
        """Create downloadable zip archive"""
        if not self.dataset_config['create_archive']:
            return None

        archive_path = self.base_path.parent / "greentext_ml_dataset.zip"

        print("üì¶ Creating archive...")
        with zipfile.ZipFile(archive_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for file_path in self.base_path.rglob('*'):
                if file_path.is_file():
                    arcname = file_path.relative_to(self.base_path)
                    zipf.write(file_path, arcname)

        size_mb = os.path.getsize(archive_path) / (1024 * 1024)
        print(f"üì¶ Archive created: {archive_path} ({size_mb:.1f} MB)")
        return archive_path

    def run_complete_pipeline(self):
        """Run the complete ML dataset creation pipeline with user confirmation"""
        print("üöÄ GREENTEXT ML DATASET SCRAPER")
        print("=" * 60)

        # Step 1: Get subreddit statistics
        stats = self.get_subreddit_stats(self.dataset_config['subreddit'])

        # Step 2: Display information and get user confirmation
        if not self.display_subreddit_info(stats):
            print("‚ùå Unable to get subreddit information. Exiting.")
            return None, None, None, None

        if not self.get_user_confirmation():
            print("üëã Goodbye! Run the script again when you're ready.")
            return None, None, None, None

        # Step 3: Start the actual scraping process
        print("\nüî• STARTING DATASET CREATION PIPELINE")
        print("=" * 60)

        start_time = time.time()

        # Phase 1: Scrape hot posts
        hot_posts = self.scrape_posts('hot', self.dataset_config['hot_posts'])

        # Phase 2: Scrape top posts
        top_posts = self.scrape_posts('top', self.dataset_config['top_posts'])

        # Check if we got any data
        if not hot_posts and not top_posts:
            print("‚ùå No posts were successfully scraped. Please check your configuration and try again.")
            return None, None, None, None

        # Phase 3: Combine and deduplicate
        print("üîÑ Combining and deduplicating datasets...")
        all_posts = hot_posts + top_posts

        if not all_posts:
            print("‚ùå No posts to process. Exiting.")
            return None, None, None, None

        df = pd.DataFrame(all_posts)

        # Remove duplicates and posts with failed downloads
        original_count = len(df)
        df = df[df['download_success'] == True]  # Only keep successful downloads
        df = df.drop_duplicates(subset=['post_id']).reset_index(drop=True)

        print(f"üìä Dataset processing results:")
        print(f"   üì• Total posts attempted: {original_count}")
        print(f"   ‚úÖ Successfully downloaded: {len(df)}")
        print(f"   üîÑ After deduplication: {len(df)}")

        if len(df) == 0:
            print("‚ùå No valid posts remaining after filtering. Exiting.")
            return None, None, None, None

        # Phase 4: Save complete dataset
        try:
            self.save_dataset(df, "greentext_complete")
        except Exception as e:
            print(f"‚ùå Error saving dataset: {e}")
            return df, None, None, None

        # Phase 5: Create train/test split
        try:
            if len(df) < 10:
                print("‚ö†Ô∏è Dataset too small for train/test split. Skipping split creation.")
                train_df, test_df = df, pd.DataFrame()
            else:
                train_df, test_df = self.create_train_test_split(df)
        except Exception as e:
            print(f"‚ùå Error creating train/test split: {e}")
            train_df, test_df = df, pd.DataFrame()

        # Phase 6: Create archive
        try:
            archive_path = self.create_archive()
        except Exception as e:
            print(f"‚ùå Error creating archive: {e}")
            archive_path = None

        # Phase 7: Generate final summary
        elapsed = time.time() - start_time

        print("\nüéâ PIPELINE COMPLETED!")
        print("=" * 60)
        print(f"‚è±Ô∏è  Total time: {elapsed/60:.1f} minutes")
        print(f"üìä Final dataset size: {len(df)} posts")
        print(f"üñºÔ∏è  Images downloaded: {len(df[df['download_success']])}")

        if len(train_df) > 0:
            print(f"üèãÔ∏è  Training samples: {len(train_df)}")
        if len(test_df) > 0:
            print(f"üß™ Test samples: {len(test_df)}")

        print(f"üíæ Dataset location: {self.base_path}")

        if archive_path:
            print(f"üì¶ Download archive: {archive_path}")

        # Quality metrics
        if len(df) > 0:
            avg_score = df['score'].mean()
            avg_quality = df['quality_score'].mean()
            print(f"\nüìà Quality metrics:")
            print(f"   ‚≠ê Average post score: {avg_score:.1f}")
            print(f"   üéØ Average quality score: {avg_quality:.1f}")

        print("\n‚ú® Your ML dataset is ready for training!")
        print("üöÄ Next steps: Use the CSV files for ML model training")

        return df, train_df, test_df, archive_path

# ============== RUN PIPELINE ==============
def main():
    """Build the pipeline, run it, and preview the resulting dataset.

    Returns:
        ``(pipeline, dataset, train_set, test_set, archive)`` on success, or
        five ``None`` values when initialization fails, the run is cancelled
        or yields nothing, or an unexpected error / Ctrl+C occurs.
    """
    no_result = (None, None, None, None, None)

    try:
        print("üî• GREENTEXT ML DATASET SCRAPER")
        print("Initializing pipeline...")
        print("-" * 40)

        # Create pipeline with error handling
        try:
            pipeline = CompleteGreentextPipeline(REDDIT_CONFIG, DATASET_CONFIG)
        except Exception as e:
            print(f"‚ùå Failed to initialize pipeline: {e}")
            print("üí° Check your Reddit API credentials and try again.")
            return no_result

        # Run complete pipeline
        result = pipeline.run_complete_pipeline()

        # Identity checks, not tuple equality: comparing a DataFrame with
        # None is element-wise and its truth value raises ValueError, which
        # would turn every successful run into a reported failure.
        if all(item is None for item in result):
            print("‚ùå Pipeline execution cancelled or failed.")
            return no_result

        dataset, train_set, test_set, archive = result

        # Display sample data if available
        if dataset is not None and len(dataset) > 0:
            print("\nüìã DATASET PREVIEW:")
            print(dataset[['post_id', 'title', 'score', 'quality_score', 'download_success']].head())

        return pipeline, dataset, train_set, test_set, archive

    except KeyboardInterrupt:
        print("\n‚ö†Ô∏è Pipeline interrupted by user (Ctrl+C)")
        return no_result
    except Exception as e:
        print(f"‚ùå Pipeline failed with unexpected error: {e}")
        print("üí° Please check your configuration and try again.")
        return no_result

# ============== EXECUTE ==============
if __name__ == "__main__":
    print("üî• GREENTEXT ML DATASET SCRAPER")
    print("Starting automated pipeline...")
    print("-" * 40)

    # Run the complete pipeline
    pipeline, dataset, train_set, test_set, archive = main()

    # main() returns all-None when the run was cancelled or failed; only
    # point the user at output files when a dataset was actually produced.
    if dataset is not None:
        print("\nüéØ NEXT STEPS:")
        print("1. Check your dataset files in the created directory")
        print("2. Download the archive if needed")
        print("3. Use the train/test CSV files for ML training")
        print("4. Images are saved as JPEGs in the images/ folder")
        print("\nüöÄ Happy machine learning!")