# Hybrid Core-based + Stratified Sampling

In [1]:
import json
import pandas as pd
import numpy as np
from collections import Counter, defaultdict
import random
import glob
import os
from typing import Dict, List, Tuple, Set
import gc
import psutil
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# SCALE CONFIGURATION
# =============================================================================

def get_scale_parameters(scale_name: str, target_playlists: int) -> Dict:
    """Return the filtering thresholds and expected entity counts for a scale.

    The tier is chosen purely from ``target_playlists`` (``scale_name`` is
    accepted for interface symmetry but is not consulted):

    - up to 300 playlists   -> tiny
    - up to 600 playlists   -> small
    - up to 1000 playlists  -> medium
    - anything larger       -> large

    A fresh dict is returned on every call, so callers may mutate it freely.
    """
    field_names = (
        'min_track_frequency',
        'min_user_playlists',
        'expected_total_nodes',
        'expected_tracks',
        'expected_artists',
        'expected_albums',
        'expected_users',
    )

    # (inclusive upper bound on target_playlists, values in field_names order)
    bounded_tiers = (
        (300, (3, 3, 5000, 2000, 500, 400, 100)),        # tiny
        (600, (5, 5, 10000, 4000, 1000, 800, 200)),      # small
        (1000, (6, 7, 20000, 8000, 2000, 1500, 300)),    # medium
    )
    large_tier = (8, 10, 30000, 12000, 3000, 2500, 500)  # large (1500+)

    chosen = next(
        (values for upper_bound, values in bounded_tiers
         if target_playlists <= upper_bound),
        large_tier,
    )
    return dict(zip(field_names, chosen))

# =============================================================================
# MULTI-SCALE HYBRID STREAMING SAMPLER
# =============================================================================

class MultiScaleHybridStreamingSampler:
    """
    Multi-Scale Hybrid Core-Based + Stratified Streaming Sampler

    Maintains original methodology with scale-optimized parameters:
    - Pass 1: Core-based filtering (length, user activity, track frequency)
    - Pass 2: Stratified sampling with priority scoring

    Slice files are read in batches. Between the passes only per-playlist
    metadata (``self.playlist_stats``) is kept; the playlists that end up
    selected are re-read from their source files at the end of pass 2.
    """

    def __init__(self,
                 scale_name: str,
                 target_playlists: int,
                 batch_size: int = 20,
                 min_playlist_length: int = 10,
                 max_playlist_length: int = 100):
        """Configure the sampler for one scale.

        Args:
            scale_name: Label used in log output and in the returned stats.
            target_playlists: Number of playlists to aim for; also selects the
                parameter tier via ``get_scale_parameters``.
            batch_size: Number of slice files per batch.
            min_playlist_length: Minimum number of tracks (inclusive).
            max_playlist_length: Maximum number of tracks (inclusive).
        """
        # Get scale-specific parameters
        scale_params = get_scale_parameters(scale_name, target_playlists)

        # Core parameters
        self.scale_name = scale_name
        self.target_playlists = target_playlists
        self.batch_size = batch_size
        self.min_playlist_length = min_playlist_length
        self.max_playlist_length = max_playlist_length
        self.min_track_frequency = scale_params['min_track_frequency']
        self.min_user_playlists = scale_params['min_user_playlists']

        # Expected targets for verification
        self.expected_total_nodes = scale_params['expected_total_nodes']
        self.expected_tracks = scale_params['expected_tracks']
        self.expected_artists = scale_params['expected_artists']
        self.expected_albums = scale_params['expected_albums']
        self.expected_users = scale_params['expected_users']

        # Statistics collectors (filled by pass 1)
        self.track_counts = Counter()
        self.user_counts = Counter()
        self.playlist_stats = []

        print(f"🎯 MULTI-SCALE SAMPLER ({scale_name.upper()} SCALE)")
        print("=" * 70)
        print("📊 SCALE-OPTIMIZED PARAMETERS:")
        print(f"   • Target playlists: {target_playlists:,}")
        print(f"   • Min track frequency: {scale_params['min_track_frequency']}")
        print(f"   • Min user playlists: {scale_params['min_user_playlists']}")
        print(f"   • Expected total nodes: ~{self.expected_total_nodes:,}")
        print()

    # ------------------------------------------------------------------
    # Small internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _ratio(part, total) -> float:
        """Return ``part / total``, or 0.0 when ``total`` is zero/empty."""
        return part / total if total else 0.0

    @staticmethod
    def _read_playlists(file_path: str) -> List[Dict]:
        """Read one slice file and return the dict entries of its 'playlists' list.

        The file is decoded as UTF-8 explicitly so the result does not depend
        on the platform's default encoding. Errors (missing file, invalid
        JSON, unexpected top-level shape) propagate to the caller, which
        decides how to report them.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # Entries that are not JSON objects cannot be processed downstream.
        return [p for p in data.get('playlists', []) if isinstance(p, dict)]

    @staticmethod
    def _warn_load_failure(file_path: str, error: Exception) -> None:
        """Print a uniform warning for a slice file that could not be read."""
        print(f"   ⚠️  Error loading {os.path.basename(file_path)}: {error}")

    # ------------------------------------------------------------------
    # Memory reporting
    # ------------------------------------------------------------------

    def get_memory_usage(self):
        """Get current memory usage in MB (0.0 if it cannot be determined)."""
        try:
            process = psutil.Process(os.getpid())
            return process.memory_info().rss / 1024 / 1024
        except Exception:
            # Memory reporting is informational only; it must never abort a run.
            return 0.0

    def print_memory_status(self, stage: str):
        """Print current memory usage, labelled with ``stage``."""
        memory_mb = self.get_memory_usage()
        print(f"💾 Memory usage after {stage}: {memory_mb:.1f} MB")

    # ------------------------------------------------------------------
    # File handling
    # ------------------------------------------------------------------

    def get_file_batches(self, file_pattern: str) -> List[List[str]]:
        """Split the files matching ``file_pattern`` into batches of ``batch_size``.

        Raises:
            FileNotFoundError: If the glob pattern matches no files.
        """
        file_paths = sorted(glob.glob(file_pattern))

        if not file_paths:
            raise FileNotFoundError(f"No files found: {file_pattern}")

        print(f"📁 Found {len(file_paths)} files")

        batches = [
            file_paths[start:start + self.batch_size]
            for start in range(0, len(file_paths), self.batch_size)
        ]

        print(f"📦 Created {len(batches)} batches of ~{self.batch_size} files each")
        return batches

    def _extract_user_id(self, playlist: Dict) -> str:
        """Derive a pseudo user identifier for a playlist.

        The identifier is the alphanumeric part of the first 2 (target <= 600)
        or 3 characters of the first word of the lower-cased playlist name.
        When that yields nothing, the playlist id is bucketed instead.
        A missing or ``null`` name/pid is treated as empty/0.
        """
        # Name-based grouping with consolidation based on scale
        name = str(playlist.get('name') or '').lower().strip()
        words = name.split()
        if words:
            # Adjust consolidation level based on scale
            char_count = 2 if self.target_playlists <= 600 else 3
            user_id = ''.join(c for c in words[0][:char_count] if c.isalnum())
            if user_id:
                return user_id

        # PID-based consolidation
        pid = playlist.get('pid') or 0
        # Scale user bins based on target size
        user_bins = max(100, self.target_playlists // 3)
        return f"u{pid % user_bins}"

    # ------------------------------------------------------------------
    # Pass 1
    # ------------------------------------------------------------------

    def pass1_core_filtering(self, file_pattern: str) -> Dict:
        """PASS 1: Core-based filtering + Statistics collection

        Every matching file is read twice: sub-pass 1a counts playlists per
        pseudo user, sub-pass 1b applies the length and user-activity filters
        and collects track frequencies. Finally, playlists that contain no
        track meeting ``min_track_frequency`` are dropped.

        Files that cannot be read are reported and skipped in both sub-passes.

        Returns:
            Dict with 'total_playlists', 'valid_playlists', 'core_tracks' (set),
            'active_users' (set), 'unique_tracks' and 'stage_counts'.
            Surviving playlist metadata is stored in ``self.playlist_stats``.

        Raises:
            FileNotFoundError: If ``file_pattern`` matches no files.
        """
        print("🔍 PASS 1: HYBRID CORE-BASED FILTERING")
        print("=" * 60)

        batches = self.get_file_batches(file_pattern)

        stage_counts = {
            'total_seen': 0,
            'passed_length_filter': 0,
            'passed_user_filter': 0,
            'passed_track_frequency_filter': 0,
            'final_valid': 0
        }

        print("🚀 Applying Core-Based Filters:")
        print(f"   ✅ Step 1: Playlist length ({self.min_playlist_length}-{self.max_playlist_length} tracks)")
        print(f"   ✅ Step 2: User activity (≥{self.min_user_playlists} playlists per user)")
        print(f"   ✅ Step 3: Track frequency (≥{self.min_track_frequency} appearances)")
        print()

        # Sub-pass 1a: Collect user activity statistics
        print("📊 Sub-pass 1a: Collecting user activity statistics...")
        user_playlist_count = Counter()

        for batch_idx, file_batch in enumerate(batches):
            if batch_idx % 20 == 0:
                print(f"   User stats progress: {batch_idx + 1}/{len(batches)} batches")

            for file_path in file_batch:
                try:
                    for playlist in self._read_playlists(file_path):
                        user_playlist_count[self._extract_user_id(playlist)] += 1
                except Exception as e:
                    # Best effort per file, but never silently: a skipped file
                    # changes which users count as active.
                    self._warn_load_failure(file_path, e)

        # Identify active users
        active_users = {
            user for user, count in user_playlist_count.items()
            if count >= self.min_user_playlists
        }

        print(f"   ✅ Identified {len(active_users):,} active users (target: ~{self.expected_users})")
        print()

        # Sub-pass 1b: Apply all core filters
        print("🔍 Sub-pass 1b: Applying all core filters...")

        for batch_idx, file_batch in enumerate(batches):
            print(f"📦 Processing batch {batch_idx + 1}/{len(batches)}")

            for file_path in file_batch:
                try:
                    file_playlists = self._read_playlists(file_path)
                except Exception as e:
                    self._warn_load_failure(file_path, e)
                    continue

                for playlist in file_playlists:
                    stage_counts['total_seen'] += 1

                    # CORE FILTER 1: Playlist length
                    tracks = playlist.get('tracks') or []
                    playlist_length = len(tracks)
                    if not (self.min_playlist_length <= playlist_length <= self.max_playlist_length):
                        continue
                    stage_counts['passed_length_filter'] += 1

                    # CORE FILTER 2: User activity
                    user_id = self._extract_user_id(playlist)
                    if user_id not in active_users:
                        continue
                    stage_counts['passed_user_filter'] += 1

                    # Count tracks for frequency analysis (every occurrence
                    # counts, even repeats within one playlist)
                    playlist_tracks = set()
                    for track in tracks:
                        track_uri = track.get('track_uri', '')
                        if track_uri:
                            self.track_counts[track_uri] += 1
                            playlist_tracks.add(track_uri)

                    self.user_counts[user_id] += 1

                    # Store playlist metadata; null values are normalised so
                    # later numeric/string handling cannot trip over None.
                    self.playlist_stats.append({
                        'file_path': file_path,
                        'pid': playlist.get('pid'),
                        'length': playlist_length,
                        'modified_at': playlist.get('modified_at') or 0,
                        'user_id': user_id,
                        'track_uris': list(playlist_tracks),
                        'name': playlist.get('name') or '',
                        'collaborative': playlist.get('collaborative', False),
                        'num_followers': playlist.get('num_followers') or 0
                    })

                # Drop the parsed file before loading the next one
                del file_playlists

            gc.collect()

            if batch_idx % 10 == 0:
                print(f"   Progress: {stage_counts['total_seen']:,} seen, {len(self.playlist_stats):,} valid so far")
                self.print_memory_status(f"batch {batch_idx + 1}")

        # CORE FILTER 3: Track frequency
        print(f"\n🔍 Applying final core filter: Track frequency (≥{self.min_track_frequency})")

        core_tracks = {
            track for track, count in self.track_counts.items()
            if count >= self.min_track_frequency
        }

        print(f"   ✅ Identified {len(core_tracks):,} core tracks (target: ~{self.expected_tracks})")

        # Keep playlists that contain at least one core track
        self.playlist_stats = [
            meta for meta in self.playlist_stats
            if not core_tracks.isdisjoint(meta['track_uris'])
        ]
        stage_counts['passed_track_frequency_filter'] = len(self.playlist_stats)
        stage_counts['final_valid'] = len(self.playlist_stats)

        total_seen = stage_counts['total_seen']

        def pct(stage: str) -> float:
            # Guarded so an empty corpus reports 0.0% instead of raising.
            return self._ratio(stage_counts[stage], total_seen) * 100

        print("\n✅ CORE-BASED FILTERING COMPLETE:")
        print("   📊 Filtering Funnel:")
        print(f"      • Total playlists: {total_seen:,}")
        print(f"      • Length filter: {stage_counts['passed_length_filter']:,} ({pct('passed_length_filter'):.1f}%)")
        print(f"      • User filter: {stage_counts['passed_user_filter']:,} ({pct('passed_user_filter'):.1f}%)")
        print(f"      • Track frequency: {stage_counts['passed_track_frequency_filter']:,} ({pct('passed_track_frequency_filter'):.1f}%)")
        print(f"      • 🎯 FINAL VALID: {stage_counts['final_valid']:,} ({pct('final_valid'):.1f}%)")
        print()

        return {
            'total_playlists': total_seen,
            'valid_playlists': stage_counts['final_valid'],
            'core_tracks': core_tracks,
            'active_users': active_users,
            'unique_tracks': len(self.track_counts),
            'stage_counts': stage_counts
        }

    # ------------------------------------------------------------------
    # Stratification
    # ------------------------------------------------------------------

    def create_strata(self) -> Dict[str, List[int]]:
        """Create comprehensive strata for stratified sampling

        Each playlist in ``self.playlist_stats`` is assigned to one of 12
        strata: length (short <= 30 < medium <= 60 < long) x recency (split
        at the median ``modified_at`` of playlists with a positive timestamp)
        x user activity (split at the median playlists-per-user).

        Returns:
            Mapping of stratum name to indices into ``self.playlist_stats``.
        """
        print("📊 CREATING STRATIFIED SAMPLING STRATA")
        print("=" * 50)

        # Get temporal split
        timestamps = [p['modified_at'] for p in self.playlist_stats if p['modified_at'] > 0]
        median_time = np.median(timestamps) if timestamps else 1500000000

        # Get user activity split
        user_playlist_counts = Counter(p['user_id'] for p in self.playlist_stats)
        user_activity_median = np.median(list(user_playlist_counts.values())) if user_playlist_counts else 5

        print(f"   📅 Temporal split at timestamp: {median_time}")
        print(f"   👥 User activity split at: {user_activity_median} playlists per user")

        # Create 12 comprehensive strata. The iteration order of this dict
        # drives the order of random draws in pass 2, so keep it stable.
        strata = {
            f"{length_cat}_{time_cat}_{activity_cat}": []
            for length_cat in ('short', 'medium', 'long')
            for time_cat in ('old', 'recent')
            for activity_cat in ('casual', 'active')
        }

        for i, playlist_meta in enumerate(self.playlist_stats):
            length = playlist_meta['length']
            user_activity = user_playlist_counts[playlist_meta['user_id']]

            # Categorize
            length_cat = 'short' if length <= 30 else 'medium' if length <= 60 else 'long'
            time_cat = 'recent' if playlist_meta['modified_at'] >= median_time else 'old'
            activity_cat = 'active' if user_activity >= user_activity_median else 'casual'

            strata[f"{length_cat}_{time_cat}_{activity_cat}"].append(i)

        # Print strata distribution
        print("   📋 Strata Distribution:")
        total_playlists = len(self.playlist_stats)

        for stratum, indices in strata.items():
            if indices:
                percentage = len(indices) / total_playlists * 100
                print(f"      • {stratum:20s}: {len(indices):6,} ({percentage:4.1f}%)")

        print()
        return strata

    def _calculate_priority_score(self, playlist_meta: Dict) -> float:
        """Calculate priority score for playlist selection

        Expects the metadata dict built in pass 1 ('track_uris', 'length',
        and optionally 'num_followers', 'name', 'collaborative').

        NOTE(review): the percentage labels below come from the original
        comments. As implemented the follower factor can add up to
        3.0 * 2.5 = 7.5 points, i.e. more than the stated 25% of the
        maximum score - confirm whether it should be normalised.
        """
        score = 0.0

        # Factor 1: Track diversity (30% weight)
        playlist_length = playlist_meta['length']
        if playlist_length > 0:
            score += len(playlist_meta['track_uris']) / playlist_length * 3.0

        # Factor 2: User engagement (25% weight)
        num_followers = playlist_meta.get('num_followers') or 0
        if num_followers > 0:
            score += min(np.log10(num_followers + 1), 3.0) * 2.5

        # Factor 3: Playlist completeness (20% weight)
        name = playlist_meta.get('name') or ''
        if len(name.strip()) > 3 and not name.lower().startswith('my playlist'):
            score += 2.0

        # Factor 4: Collaborative playlists bonus (10% weight)
        if playlist_meta.get('collaborative', False):
            score += 1.0

        # Factor 5: Length balance bonus (15% weight)
        if 20 <= playlist_length <= 80:
            score += 1.5

        return score

    # ------------------------------------------------------------------
    # Pass 2
    # ------------------------------------------------------------------

    def pass2_stratified_sampling(self, strata: Dict[str, List[int]]) -> List[Dict]:
        """PASS 2: Stratified sampling with priority scoring

        If no more playlists survived pass 1 than ``target_playlists``, all
        of them are taken. Otherwise each non-empty stratum contributes
        ``max(1, int(size * target / available))`` playlists: 70% are the
        highest-scored ones and the rest are drawn with ``random.sample``
        from the remainder of that stratum.

        Returns:
            The selected playlists, re-read from their source files.
        """
        print("🎲 PASS 2: STRATIFIED SAMPLING WITH PRIORITY SCORING")
        print("=" * 60)

        total_available = len(self.playlist_stats)

        if total_available <= self.target_playlists:
            print(f"   📝 Available ({total_available:,}) ≤ target ({self.target_playlists:,})")
            selected_indices = list(range(total_available))
        else:
            sampling_ratio = self.target_playlists / total_available
            chosen = set()

            print(f"   📊 Global sampling ratio: {sampling_ratio:.3f}")
            print("   🏆 Using priority scoring within strata")
            print()

            # Stratified sampling with priority scoring
            for stratum, indices in strata.items():
                if not indices:
                    continue

                stratum_target = min(max(1, int(len(indices) * sampling_ratio)), len(indices))

                # Score playlists in this stratum, best first (stable for ties)
                scores = {
                    idx: self._calculate_priority_score(self.playlist_stats[idx])
                    for idx in indices
                }
                ranked = sorted(indices, key=scores.__getitem__, reverse=True)

                # Hybrid: 70% top-scored + 30% random
                top_count = int(stratum_target * 0.7)
                random_count = stratum_target - top_count

                selected = ranked[:top_count]
                remaining = ranked[top_count:]
                if random_count > 0 and remaining:
                    selected.extend(random.sample(remaining, min(random_count, len(remaining))))

                chosen.update(selected)

                # Average over the playlists actually selected (including the
                # random draws), not over the top of the ranking.
                avg_score = np.mean([scores[idx] for idx in selected])
                print(f"      • {stratum:20s}: {len(selected):4,} / {len(indices):5,} (avg score: {avg_score:.2f})")

            # Sorted so the output order does not depend on set iteration order
            selected_indices = sorted(chosen)

        print(f"\n   🎯 Selected {len(selected_indices):,} playlists for final loading")
        return self._load_selected_playlists(selected_indices)

    def _load_selected_playlists(self, selected_indices: List[int]) -> List[Dict]:
        """Load only the selected playlists from files

        Each returned playlist dict gets a '_sampling_score' entry. Files
        that cannot be read are reported and skipped; a summary warning is
        printed when fewer playlists were loaded than selected.
        """
        print("   📁 Loading selected playlists...")

        # Group by file so every file is opened once
        file_to_playlists = defaultdict(list)
        for idx in selected_indices:
            playlist_meta = self.playlist_stats[idx]
            file_to_playlists[playlist_meta['file_path']].append(playlist_meta)

        print(f"   📂 Loading from {len(file_to_playlists)} files")

        final_playlists = []

        for file_idx, (file_path, playlist_metas) in enumerate(file_to_playlists.items()):
            if file_idx % 100 == 0:
                print(f"      📖 File {file_idx + 1}/{len(file_to_playlists)}")

            try:
                file_playlists = self._read_playlists(file_path)
            except Exception as e:
                self._warn_load_failure(file_path, e)
                continue

            pid_to_playlist = {p.get('pid'): p for p in file_playlists}

            for meta in playlist_metas:
                playlist = pid_to_playlist.get(meta['pid'])
                if playlist is not None:
                    playlist['_sampling_score'] = self._calculate_priority_score(meta)
                    final_playlists.append(playlist)

        missing = len(selected_indices) - len(final_playlists)
        if missing:
            print(f"   ⚠️  {missing:,} selected playlists could not be loaded")

        print(f"   ✅ Loaded {len(final_playlists):,} final playlists")
        return final_playlists

    # ------------------------------------------------------------------
    # Verification and orchestration
    # ------------------------------------------------------------------

    def _verify_final_scale(self, final_playlists: List[Dict]) -> Dict:
        """Verify that final scale meets experimental targets

        Counts distinct tracks, artists, albums and pseudo users in
        ``final_playlists`` and prints them next to the expected values.

        Returns:
            Dict with 'playlists', 'tracks', 'artists', 'albums', 'users'
            and their sum 'total_nodes'.
        """
        print("\n🔍 VERIFYING FINAL SCALE FOR EXPERIMENTAL CONTROL")
        print("=" * 50)

        # Count actual entities
        actual_tracks = set()
        actual_artists = set()
        actual_albums = set()
        actual_users = set()

        for playlist in final_playlists:
            actual_users.add(self._extract_user_id(playlist))

            for track in playlist.get('tracks', []):
                track_uri = track.get('track_uri', '')
                artist_uri = track.get('artist_uri', '')
                album_uri = track.get('album_uri', '')

                if track_uri:
                    actual_tracks.add(track_uri)
                if artist_uri:
                    actual_artists.add(artist_uri)
                if album_uri:
                    actual_albums.add(album_uri)

        actual_scale = {
            'playlists': len(final_playlists),
            'tracks': len(actual_tracks),
            'artists': len(actual_artists),
            'albums': len(actual_albums),
            'users': len(actual_users),
            'total_nodes': len(final_playlists) + len(actual_tracks) + len(actual_artists) + len(actual_albums) + len(actual_users)
        }

        print("   📊 Final Entity Counts:")
        print(f"      • Playlists: {actual_scale['playlists']:,} (target: {self.target_playlists:,})")
        print(f"      • Tracks: {actual_scale['tracks']:,} (target: ~{self.expected_tracks:,})")
        print(f"      • Artists: {actual_scale['artists']:,} (target: ~{self.expected_artists:,})")
        print(f"      • Albums: {actual_scale['albums']:,} (target: ~{self.expected_albums:,})")
        print(f"      • Users: {actual_scale['users']:,} (target: ~{self.expected_users:,})")
        print(f"      🎯 TOTAL: {actual_scale['total_nodes']:,} (target: ~{self.expected_total_nodes:,})")

        return actual_scale

    def run_hybrid_sampling(self, file_pattern: str) -> Tuple[List[Dict], Dict]:
        """Main method: Complete hybrid sampling workflow

        Runs pass 1, builds the strata, runs pass 2 and compares the result
        with the expected scale. Ratios are reported as 0.0 when no playlists
        were seen at all.

        Returns:
            ``(final_playlists, final_stats)`` where ``final_stats`` contains
            only JSON-serialisable summary values.

        Raises:
            FileNotFoundError: If ``file_pattern`` matches no files.
        """
        print("🚀 STARTING MULTI-SCALE HYBRID SAMPLING")
        print("=" * 70)

        self.print_memory_status("start")

        # Pass 1: Core-based filtering
        stats = self.pass1_core_filtering(file_pattern)
        self.print_memory_status("pass 1 complete")

        # Create strata
        strata = self.create_strata()
        self.print_memory_status("strata created")

        # Pass 2: Stratified sampling
        final_playlists = self.pass2_stratified_sampling(strata)
        self.print_memory_status("pass 2 complete")

        # Scale verification
        actual_scale = self._verify_final_scale(final_playlists)

        total_playlists = stats['total_playlists']
        retention_rate = self._ratio(len(final_playlists), total_playlists)

        # Final statistics
        final_stats = {
            'methodology': 'multi_scale_hybrid_core_based_stratified_streaming',
            'scale': self.scale_name,
            'original_total': total_playlists,
            'final_sampled': len(final_playlists),
            'retention_rate': retention_rate,
            'core_filtering_retention': self._ratio(stats['valid_playlists'], total_playlists),
            'unique_tracks': stats['unique_tracks'],
            'core_tracks_count': len(stats['core_tracks']),
            'active_users_count': len(stats['active_users']),
            'stage_counts': stats['stage_counts'],
            'actual_scale': actual_scale,
            'scale_targets': {
                'total_nodes': self.expected_total_nodes,
                'playlists': self.target_playlists,
                'tracks': self.expected_tracks,
                'artists': self.expected_artists,
                'albums': self.expected_albums,
                'users': self.expected_users
            }
        }

        print("\n🎉 MULTI-SCALE HYBRID SAMPLING COMPLETE!")
        print("=" * 70)
        print(f"📊 Results: {total_playlists:,} → {len(final_playlists):,} playlists")
        print(f"📈 Overall retention: {retention_rate:.1%}")

        # Scale verification summary
        total_actual = actual_scale['total_nodes']
        scale_ratio = self._ratio(total_actual, self.expected_total_nodes)
        print("\n🎯 EXPERIMENTAL SCALE VERIFICATION:")
        print(f"   • Actual total nodes: {total_actual:,}")
        print(f"   • Target total nodes: {self.expected_total_nodes:,}")
        print(f"   • Scale ratio: {scale_ratio:.3f} ({'✅ GOOD' if 0.8 <= scale_ratio <= 1.2 else '⚠️ ADJUST'})")

        return final_playlists, final_stats

# =============================================================================
# MULTI-SCALE CREATION FUNCTIONS
# =============================================================================

def run_multi_scale_sampling(file_pattern: str, scales: Dict[str, int] = None):
    """Create multiple scale datasets in one execution

    For every ``scale_name -> target_playlists`` entry a sampler is run and
    its result is written to
    ``../data/processed/spotify_scaled_hybrid_<scale_name>.json``.
    A scale that fails is reported (with traceback) and skipped; the
    remaining scales are still processed.

    The random generators are re-seeded before each scale. This makes every
    scale's sample independent of which other scales were requested and
    identical to what ``run_single_scale_sampling`` produces for the same
    scale and input.

    Args:
        file_pattern: Glob pattern of the input slice files.
        scales: Scales to create; defaults to tiny/small/medium/large with
            300/600/1000/1500 playlists.

    Returns:
        Dict mapping each successfully created scale name to a summary with
        'file_path', 'playlists', 'total_nodes', 'file_size_mb' and 'stats'.
    """

    if scales is None:
        scales = {
            'tiny': 300,
            'small': 600,
            'medium': 1000,
            'large': 1500
        }

    print("🎯 MULTI-SCALE DATASET CREATION")
    print("=" * 60)
    print("Creating multiple scale datasets for flexible experimentation")
    print()

    # Show scale overview
    print("📊 SCALES TO CREATE:")
    print("Scale    Playlists  Expected Nodes  Expected Time")
    print("-" * 50)

    time_estimates = {
        'tiny': '15-30 min',
        'small': '45-75 min',
        'medium': '2-3 hours',
        'large': '4-6 hours'
    }

    for scale_name, target_playlists in scales.items():
        expected_nodes = get_scale_parameters(scale_name, target_playlists)['expected_total_nodes']
        time_est = time_estimates.get(scale_name, 'Unknown')
        print(f"{scale_name:<8} {target_playlists:<10} {expected_nodes:<15,} {time_est}")

    print()

    results = {}

    for scale_name, target_playlists in scales.items():
        print(f"\n{'='*70}")
        print(f"🔧 CREATING {scale_name.upper()} SCALE DATASET")
        print(f"Target: {target_playlists:,} playlists")
        print(f"{'='*70}")

        try:
            # Re-seed per scale: seeding only once before the loop would make
            # each sample depend on the random draws of the scales before it.
            random.seed(42)
            np.random.seed(42)

            # Create scale-specific sampler
            sampler = MultiScaleHybridStreamingSampler(
                scale_name=scale_name,
                target_playlists=target_playlists,
                batch_size=20,
                min_playlist_length=10,
                max_playlist_length=100
            )

            # Run sampling
            sampled_playlists, stats = sampler.run_hybrid_sampling(file_pattern)

            # Create output data
            output_data = {
                'info': {
                    'generated_on': datetime.now().isoformat(),
                    'sampling_method': 'multi_scale_hybrid_core_based_stratified_streaming',
                    'scale': scale_name,
                    'target_playlists': target_playlists,
                    'parameters': get_scale_parameters(scale_name, target_playlists)
                },
                'sampling_stats': stats,
                'playlists': sampled_playlists
            }

            # Save scale-specific file
            os.makedirs('../data/processed', exist_ok=True)
            output_file = f'../data/processed/spotify_scaled_hybrid_{scale_name}.json'

            # json.dump escapes non-ASCII by default; the explicit encoding
            # just keeps the file independent of the platform default.
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(output_data, f, indent=2)

            file_size_mb = os.path.getsize(output_file) / (1024 * 1024)
            total_nodes = stats['actual_scale']['total_nodes']

            print(f"\n✅ {scale_name.upper()} SCALE COMPLETED:")
            print(f"   📁 File: {output_file}")
            print(f"   📦 Size: {file_size_mb:.1f} MB")
            print(f"   📊 Playlists: {len(sampled_playlists):,}")
            print(f"   📊 Total nodes: {total_nodes:,}")

            # Store results
            results[scale_name] = {
                'file_path': output_file,
                'playlists': len(sampled_playlists),
                'total_nodes': total_nodes,
                'file_size_mb': file_size_mb,
                'stats': stats
            }

        except Exception as e:
            # Deliberate per-scale boundary: report fully, then carry on with
            # the remaining scales.
            print(f"❌ Failed to create {scale_name} scale: {e}")
            import traceback
            traceback.print_exc()
            continue

    # Print final summary
    print(f"\n{'='*70}")
    print("🎉 MULTI-SCALE DATASET CREATION COMPLETED!")
    print(f"{'='*70}")

    if results:
        print("📊 CREATED DATASETS:")
        print("Scale    File                                 Playlists  Nodes     Size")
        print("-" * 80)

        for scale_name, result in results.items():
            filename = os.path.basename(result['file_path'])
            print(f"{scale_name:<8} {filename:<35} {result['playlists']:<10,} {result['total_nodes']:<9,} {result['file_size_mb']:.1f}MB")

    return results

def run_single_scale_sampling(file_pattern: str,
                             scale_name: str = 'small',
                             target_playlists: int = None):
    """Create a single scale dataset

    Seeds the random generators with 42, runs one sampler over the files
    matching ``file_pattern`` and writes the result to
    ``../data/processed/spotify_scaled_hybrid_<scale_name>.json``.

    When ``target_playlists`` is omitted, the default size of the named
    scale is used (600 for an unknown scale name).

    Returns:
        ``(sampled_playlists, stats, output_file)``.
    """
    if target_playlists is None:
        # Default targets for each scale
        size_by_scale = {'tiny': 300, 'small': 600, 'medium': 1000, 'large': 1500}
        target_playlists = size_by_scale.get(scale_name, 600)

    print(f"🎯 CREATING SINGLE {scale_name.upper()} SCALE DATASET")
    print(f"Target: {target_playlists:,} playlists")
    print("=" * 60)

    # Set seeds for reproducibility
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(42)

    # Build the sampler and run the full two-pass workflow
    sampler_options = dict(
        scale_name=scale_name,
        target_playlists=target_playlists,
        batch_size=20,
        min_playlist_length=10,
        max_playlist_length=100,
    )
    sampler = MultiScaleHybridStreamingSampler(**sampler_options)
    sampled_playlists, stats = sampler.run_hybrid_sampling(file_pattern)

    # Assemble the document that is written to disk
    run_info = {
        'generated_on': datetime.now().isoformat(),
        'sampling_method': 'multi_scale_hybrid_core_based_stratified_streaming',
        'scale': scale_name,
        'target_playlists': target_playlists,
        'parameters': get_scale_parameters(scale_name, target_playlists),
    }
    output_data = {
        'info': run_info,
        'sampling_stats': stats,
        'playlists': sampled_playlists,
    }

    # Save file
    os.makedirs('../data/processed', exist_ok=True)
    output_file = f'../data/processed/spotify_scaled_hybrid_{scale_name}.json'
    with open(output_file, 'w') as out_handle:
        json.dump(output_data, out_handle, indent=2)

    file_size_mb = os.path.getsize(output_file) / (1024 * 1024)

    summary_lines = (
        f"\n✅ {scale_name.upper()} SCALE COMPLETED:",
        f"   📁 File: {output_file}",
        f"   📦 Size: {file_size_mb:.1f} MB",
        f"   📊 Playlists: {len(sampled_playlists):,}",
        f"   📊 Total nodes: {stats['actual_scale']['total_nodes']:,}",
    )
    for summary_line in summary_lines:
        print(summary_line)

    duration_by_scale = {
        'tiny': '15-30 minutes',
        'small': '45-75 minutes',
        'medium': '2-3 hours',
        'large': '4-6 hours',
    }
    expected_time = duration_by_scale.get(scale_name, 'Unknown')
    print(f"   ⏱️ Expected experiment time: {expected_time}")

    return sampled_playlists, stats, output_file

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def compare_scales_detailed():
    """Print a fixed-width comparison table of the four predefined scales.

    Columns: scale name, playlist target, approximate node count, expected
    experiment time, minimum track frequency and minimum playlists per user.
    The values are hard-coded here; nothing is returned.
    """
    # (scale, playlists, nodes, experiment time, min track freq, min user playlists)
    table_rows = (
        ('tiny', 300, '~5K', '15-30 min', 3, 3),
        ('small', 600, '~10K', '45-75 min', 5, 5),
        ('medium', 1000, '~20K', '2-3 hours', 6, 7),
        ('large', 1500, '~30K', '4-6 hours', 8, 10),
    )
    rule_width = 100

    print("📊 DETAILED SCALE COMPARISON")
    print("=" * rule_width)
    print("Scale   Playlists  Nodes   Exp.Time   Track_Freq  User_Freq")
    print("-" * rule_width)

    for label, playlists, nodes, duration, track_freq, user_freq in table_rows:
        left_half = f"{label:<7} {playlists:<10} {nodes:<7} {duration:<10} "
        right_half = f"{track_freq:<11} {user_freq:<10}"
        print(left_half + right_half)

# =============================================================================
# MAIN EXECUTION
# =============================================================================

if __name__ == "__main__":
    # Script entry point: show the scale table and usage hints, then build
    # every default scale from the raw slice files.
    print("🎯 MULTI-SCALE SPOTIFY SAMPLER")
    print("=" * 60)

    # Show scale comparison
    compare_scales_detailed()

    # Usage hints, printed as one block (same lines as individual prints)
    print("\n".join([
        "\n🔧 USAGE OPTIONS:",
        "1. Create all scales at once:",
        "   run_multi_scale_sampling(file_pattern)",
        "",
        "2. Create single scale:",
        "   run_single_scale_sampling(file_pattern, 'small')",
        "",
        "3. Create custom scales:",
        "   custom_scales = {'quick': 200, 'normal': 800}",
        "   run_multi_scale_sampling(file_pattern, custom_scales)",
    ]))

    # Example usage
    file_pattern = "../data/raw/data/mpd.slice.*.json"
    run_multi_scale_sampling(file_pattern)

🎯 MULTI-SCALE SPOTIFY SAMPLER
📊 DETAILED SCALE COMPARISON
Scale   Playlists  Nodes   Exp.Time   Track_Freq  User_Freq
----------------------------------------------------------------------------------------------------
tiny    300        ~5K     15-30 min  3           3         
small   600        ~10K    45-75 min  5           5         
medium  1000       ~20K    2-3 hours  6           7         
large   1500       ~30K    4-6 hours  8           10        

🔧 USAGE OPTIONS:
1. Create all scales at once:
   run_multi_scale_sampling(file_pattern)

2. Create single scale:
   run_single_scale_sampling(file_pattern, 'small')

3. Create custom scales:
   custom_scales = {'quick': 200, 'normal': 800}
   run_multi_scale_sampling(file_pattern, custom_scales)
🎯 MULTI-SCALE DATASET CREATION
Creating multiple scale datasets for flexible experimentation

📊 SCALES TO CREATE:
Scale    Playlists  Expected Nodes  Expected Time
--------------------------------------------------
tiny     300        5,0