<a href="https://colab.research.google.com/github/RyuichiSaito1/inflation-reddit-usa/blob/main/src/split_main_data_to_training_and_validation_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
# Make Google Drive visible to this Colab runtime; every CSV read or written
# below lives under this mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from collections import Counter

# Step 1: Read the CSV file of labelled Reddit comments.
print("Step 1: Reading CSV file...")
df = pd.read_csv('/content/drive/MyDrive/world-inflation/data/reddit/production/main-prod-130.csv')

print(f"Original dataset shape: {df.shape}")
print(f"Columns: {df.columns.tolist()}")

# Guard the assumptions the rest of the notebook relies on. Later steps only
# iterate over the labels 0, 1 and 2, and value_counts() drops missing values,
# so a NaN or any other label would silently fall out of the sample while still
# counting toward total_records. An empty file would divide by zero below.
assert len(df) > 0, "Input CSV contains no records"
missing_labels = int(df['inflation'].isna().sum())
assert missing_labels == 0, f"{missing_labels} records have no inflation label"
unexpected_labels = set(df['inflation'].unique()) - {0, 1, 2}
assert not unexpected_labels, f"Unexpected inflation labels: {unexpected_labels}"

# Step 2 & 3: Calculate the ratio of each class (0, 1, 2)
print("\nStep 3: Calculating class ratios...")
class_counts = df['inflation'].value_counts().sort_index()
total_records = len(df)

print("Class distribution:")
for class_label in [0, 1, 2]:
    count = class_counts.get(class_label, 0)
    ratio = count / total_records
    print(f"  Class {class_label}: {count} records ({ratio:.4f})")

# Step 4: Retrieve exactly half records while approximating class ratios
print("\nStep 4: Sampling exactly half records with approximate stratification...")
target_size = -(-total_records // 2)  # Ceiling division to round up half


def allocate_proportional_counts(counts, target):
    """Split ``target`` records across labels in proportion to ``counts``.

    Uses the largest-remainder (Hamilton) method. Each label first receives the
    floor of its exact quota ``target * count / total``. The records still missing
    are then handed out one at a time to the labels with the largest fractional
    remainders. Every label therefore ends up within one record of its exact
    quota, and the allocation sums to exactly ``target``. Ties go to the larger
    class, then to the smaller label.

    All arithmetic is integer. A quota that is a whole number is never truncated
    to one less by floating-point error, as ``int(target * ratio)`` can be.

    Args:
        counts: mapping of label -> number of records carrying that label.
        target: total number of records to allocate (non-negative).

    Returns:
        dict mapping each label in ``counts`` to its allocated record count.
        Every label gets 0 when ``counts`` sums to 0.
    """
    total = sum(counts.values())
    if total == 0:
        return {label: 0 for label in counts}

    allocation = {label: target * count // total for label, count in counts.items()}
    shortfall = target - sum(allocation.values())

    # (target * count) % total is the fractional remainder scaled by `total`;
    # hand the shortfall to the largest remainders first.
    ranked = sorted(
        counts,
        key=lambda label: (-(target * counts[label] % total), -counts[label], label),
    )
    for label in ranked[:shortfall]:
        allocation[label] += 1
    return allocation


target_counts = allocate_proportional_counts(
    {class_label: int(class_counts.get(class_label, 0)) for class_label in [0, 1, 2]},
    target_size,
)

print(f"Target sample size: {target_size}")
print("Target counts per class:")
for class_label in [0, 1, 2]:
    print(f"  Class {class_label}: {target_counts[class_label]} records")

# Verify total
actual_target_total = sum(target_counts.values())
print(f"Total target records: {actual_target_total}")

# Perform sampling: draw each class's quota without replacement
sampled_dfs = []
for class_label in [0, 1, 2]:
    class_df = df[df['inflation'] == class_label]
    target_count = target_counts[class_label]
    if len(class_df) > 0 and target_count > 0:
        sampled_class_df = class_df.sample(n=min(target_count, len(class_df)),
                                           random_state=42)
        sampled_dfs.append(sampled_class_df)

# Combine sampled data
sampled_df = pd.concat(sampled_dfs, ignore_index=True)

# Shuffle the combined data so rows are not grouped by class
sampled_df = sampled_df.sample(frac=1, random_state=42).reset_index(drop=True)

# This step promises exactly target_size rows. A class smaller than its quota
# would break that silently via min() above; that can only happen when some
# records carry a label other than 0, 1 or 2 (including a missing label).
assert len(sampled_df) == target_size, (
    f"Sampled {len(sampled_df)} records, expected {target_size}"
)

print(f"\nSampled dataset shape: {sampled_df.shape}")
sampled_class_counts = sampled_df['inflation'].value_counts().sort_index()
print("Sampled class distribution:")
for class_label in [0, 1, 2]:
    count = sampled_class_counts.get(class_label, 0)
    ratio = count / len(sampled_df) if len(sampled_df) > 0 else 0
    print(f"  Class {class_label}: {count} records ({ratio:.4f})")

# Step 5: Save the retrieved records
print("\nStep 5: Saving sampled dataset...")
sampled_path = '/content/drive/MyDrive/world-inflation/data/reddit/production/main-prod-65.csv'
sampled_df.to_csv(sampled_path, index=False)
# Report the file name taken from the path actually written, so the log
# cannot drift from the file on disk.
print(f"Saved sampled dataset to {sampled_path.rsplit('/', 1)[-1]}")

# Step 6: Split into training and validation data with stratification
print("\nStep 6: Splitting into training and validation data...")
X, y = sampled_df[['body']], sampled_df['inflation']

# Hold out 25% for validation. stratify=y keeps each class's share close to
# its share in the sample; random_state=42 makes the split repeatable.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)

# Re-attach the label column so each split is a self-contained frame.
train_df, val_df = (
    pd.concat([features, labels], axis=1)
    for features, labels in ((X_train, y_train), (X_val, y_val))
)

print(f"Training set shape: {train_df.shape}")
print(f"Validation set shape: {val_df.shape}")


def _report_class_distribution(header, frame):
    """Print ``header``, then the count and share of labels 0, 1, 2 in ``frame``.

    A share is reported as 0 when ``frame`` is empty. Returns the per-label
    counts (``value_counts`` sorted by label) so the caller can keep them.
    """
    print(header)
    counts_by_label = frame['inflation'].value_counts().sort_index()
    n_rows = len(frame)
    for class_label in [0, 1, 2]:
        count = counts_by_label.get(class_label, 0)
        ratio = count / n_rows if n_rows > 0 else 0
        print(f"  Class {class_label}: {count} records ({ratio:.4f})")
    return counts_by_label


# Verify class distributions in splits
train_class_counts = _report_class_distribution("\nTraining set class distribution:", train_df)
val_class_counts = _report_class_distribution("\nValidation set class distribution:", val_df)

# Step 7: Save training and validation data
print("\nStep 7: Saving training and validation datasets...")
train_path = '/content/drive/MyDrive/world-inflation/data/reddit/production/training-data-65.csv'
val_path = '/content/drive/MyDrive/world-inflation/data/reddit/production/validation-data-65.csv'
train_df.to_csv(train_path, index=False)
val_df.to_csv(val_path, index=False)

# Report file names taken from the paths actually written, so the log
# cannot drift from the files on disk.
print(f"Saved training data to {train_path.rsplit('/', 1)[-1]}")
print(f"Saved validation data to {val_path.rsplit('/', 1)[-1]}")

# Summary statistics: record counts at each stage, then per-class shares
# compared against the original dataset.
separator = "="*50
print("\n" + separator)
print("SUMMARY")
print(separator)

n_original = df.shape[0]
n_sampled = sampled_df.shape[0]
n_train = train_df.shape[0]
n_val = val_df.shape[0]
print(f"Original dataset: {n_original} records")
print(f"Sampled dataset: {n_sampled} records ({n_sampled/n_original:.1%})")
print(f"Training set: {n_train} records ({n_train/n_sampled:.1%})")
print(f"Validation set: {n_val} records ({n_val/n_sampled:.1%})")

# Verify that class ratios are preserved
print("\nClass ratio preservation check:")
original_ratios, sampled_ratios, train_ratios, val_ratios = (
    frame['inflation'].value_counts(normalize=True).sort_index()
    for frame in (df, sampled_df, train_df, val_df)
)

for class_label in [0, 1, 2]:
    orig_ratio = original_ratios.get(class_label, 0)
    print(f"Class {class_label}:")
    print(f"  Original: {orig_ratio:.4f}")

    samp_ratio = sampled_ratios.get(class_label, 0)
    print(f"  Sampled:  {samp_ratio:.4f} (diff: {abs(orig_ratio-samp_ratio):.4f})")

    train_ratio = train_ratios.get(class_label, 0)
    print(f"  Training: {train_ratio:.4f} (diff: {abs(orig_ratio-train_ratio):.4f})")

    val_ratio = val_ratios.get(class_label, 0)
    print(f"  Validation: {val_ratio:.4f} (diff: {abs(orig_ratio-val_ratio):.4f})")

print("\nProcess completed successfully!")

Step 1: Reading CSV file...
Original dataset shape: (130, 2)
Columns: ['body', 'inflation']

Step 3: Calculating class ratios...
Class distribution:
  Class 0: 40 records (0.3077)
  Class 1: 51 records (0.3923)
  Class 2: 39 records (0.3000)

Step 4: Sampling exactly half records with approximate stratification...
Target sample size: 65
Target counts per class:
  Class 0: 20 records
  Class 1: 26 records
  Class 2: 19 records
Total target records: 65

Sampled dataset shape: (65, 2)
Sampled class distribution:
  Class 0: 20 records (0.3077)
  Class 1: 26 records (0.4000)
  Class 2: 19 records (0.2923)

Step 5: Saving sampled dataset...
Saved sampled dataset to main-prod-622.csv

Step 6: Splitting into training and validation data...
Training set shape: (48, 2)
Validation set shape: (17, 2)

Training set class distribution:
  Class 0: 15 records (0.3125)
  Class 1: 19 records (0.3958)
  Class 2: 14 records (0.2917)

Validation set class distribution:
  Class 0: 5 records (0.2941)
  Class