# Truncate Training Datasets

This notebook loads the original training datasets and writes a truncated copy of each one containing at most the first 300 samples (a dataset with fewer than 300 samples is copied in full).

In [14]:
import json
import os
from google.colab import drive

# Mount Google Drive at /content/drive; every dataset path used below lives under this mount point
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [15]:
# Source training files, one per dataset, on the mounted shared drive.
# The list mixes .json and .jsonl files; the format is detected per file when it is processed.
train_paths = [
    '/content/drive/Shareddrives/517 nlp project/data/2WikiMultihopQA/train.json',
    '/content/drive/Shareddrives/517 nlp project/data/Bamboogle/train.json',
    '/content/drive/Shareddrives/517 nlp project/data/FEVER/fever_train.jsonl',
    '/content/drive/Shareddrives/517 nlp project/data/FEVEROUS/feverous_train.jsonl',
    '/content/drive/Shareddrives/517 nlp project/data/HotpotQA/train.json',
    '/content/drive/Shareddrives/517 nlp project/data/SVAMP/train.json',
    '/content/drive/Shareddrives/517 nlp project/data/VitaminC/vitaminc/train.jsonl'
]

# Maximum number of leading samples to keep from each dataset
num_samples = 300

In [16]:
# Function to check if a file is in JSONL format
def is_jsonl(file_path):
    """Heuristically decide whether ``file_path`` holds JSON Lines data.

    A path ending in ``.jsonl`` (case-insensitive) is accepted without opening
    the file. Otherwise the first three non-blank lines are parsed one by one:
    the file counts as JSONL only if every probed line is valid JSON on its own
    and at least two such lines exist.

    Blank lines are skipped while probing, so records separated by empty lines
    are still recognised instead of being mistaken for a single JSON document.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        True if the file looks like JSONL, False otherwise. False is also
        returned when the file cannot be opened or is not valid UTF-8.
    """
    # Trust an explicit .jsonl extension
    if file_path.lower().endswith('.jsonl'):
        return True

    probe_limit = 3  # how many non-blank lines to sample
    valid_lines = 0
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    json.loads(line)
                except json.JSONDecodeError:
                    # A line that is not JSON by itself (e.g. the "[" that opens
                    # a pretty-printed array) rules out JSONL
                    return False
                valid_lines += 1
                if valid_lines >= probe_limit:
                    break
    except (OSError, ValueError):
        # Unreadable file or undecodable bytes (UnicodeDecodeError is a ValueError)
        return False

    # A single parseable line is an ordinary one-line JSON document, not JSONL
    return valid_lines > 1

# Helpers and entry point for writing a truncated copy of one dataset file
def _truncated_output_path(file_path, num_samples):
    """Return the output path: same folder, ``_truncated_<num_samples>`` inserted before the extension."""
    directory, filename = os.path.split(file_path)
    base_filename, extension = os.path.splitext(filename)
    return os.path.join(directory, base_filename + f"_truncated_{num_samples}" + extension)


def _truncate_jsonl_file(file_path, output_path, num_samples):
    """Write the first ``num_samples`` non-blank lines to ``output_path``; return (original, kept) counts."""
    with open(file_path, 'r', encoding='utf-8') as f_in:
        lines = [line.strip() for line in f_in if line.strip()]

    truncated_lines = lines[:num_samples]

    with open(output_path, 'w', encoding='utf-8') as f_out:
        for line in truncated_lines:
            f_out.write(line + '\n')

    return len(lines), len(truncated_lines)


def _truncate_json_data(data, num_samples, dataset_name):
    """Return a list or dict cut to ``num_samples`` items, printing the before/after counts."""
    if isinstance(data, list):
        truncated_data = data[:num_samples]
        print(f"  Original samples: {len(data)}, Truncated to: {len(truncated_data)}")
        return truncated_data

    # For a dict, the first list-valued key (in the dict's own order) is taken as the samples
    data_keys = [k for k in data.keys() if isinstance(data[k], list)]
    if not data_keys:
        # Nothing to truncate; the data is passed through unchanged
        print(f"  Warning: Could not identify data structure to truncate in {dataset_name}")
        return data

    main_key = data_keys[0]
    truncated_data = dict(data)  # shallow copy so the loaded data is not modified
    truncated_data[main_key] = data[main_key][:num_samples]
    print(f"  Original samples: {len(data[main_key])}, Truncated to: {len(truncated_data[main_key])}")
    return truncated_data


def truncate_dataset(file_path, num_samples):
    """Write a copy of a JSON or JSONL dataset holding only its first ``num_samples`` samples.

    The copy is saved next to the source file as
    ``<name>_truncated_<num_samples><ext>``. JSONL files (as decided by
    ``is_jsonl``) are truncated by non-blank line. JSON files are truncated
    either as a top-level list or, for a top-level dict, on the first key whose
    value is a list; the other keys are kept as they are. A dict without any
    list value is written out unchanged, with a warning.

    Progress and errors are reported with ``print``. Exceptions are caught on
    purpose so that one bad file does not abort a batch run.

    Args:
        file_path: Path of the source dataset file.
        num_samples: Maximum number of samples to keep. Must not be negative;
            ``None`` keeps every sample.

    Returns:
        The path of the written file, or None if the file could not be
        processed (unsupported top-level JSON type, negative ``num_samples``,
        or any I/O or parsing error).
    """
    try:
        # A negative count would turn the slices into "drop the last N samples"
        if num_samples is not None and num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")

        # The parent folder name doubles as the dataset label in the messages
        dataset_name = os.path.basename(os.path.dirname(file_path))
        output_path = _truncated_output_path(file_path, num_samples)

        print(f"Processing {dataset_name}...")

        if is_jsonl(file_path):
            print(f"  Detected JSONL format for {dataset_name}")
            original_count, kept_count = _truncate_jsonl_file(file_path, output_path, num_samples)
            print(f"  Original samples: {original_count}, Truncated to: {kept_count}")
        else:
            # Process as regular JSON
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if not isinstance(data, (list, dict)):
                print(f"  Error: Unexpected data format in {dataset_name}")
                return None

            truncated_data = _truncate_json_data(data, num_samples, dataset_name)

            # Save the truncated data
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(truncated_data, f, ensure_ascii=False, indent=2)

        print(f"  Saved to {output_path}")
        return output_path

    except Exception as e:
        print(f"  Error processing {file_path}: {str(e)}")
        return None

In [17]:
# Run the truncation over every configured dataset and collect the files that were written
truncated_files = []

for path in train_paths:
    output_path = truncate_dataset(path, num_samples)
    if not output_path:
        # truncate_dataset returns None on failure and has already printed the reason
        continue
    truncated_files.append(output_path)

print(
    "\nSummary:",
    f"Successfully truncated {len(truncated_files)} out of {len(train_paths)} datasets",
    sep="\n",
)

Processing 2WikiMultihopQA...
  Original samples: 167454, Truncated to: 300
  Saved to /content/drive/Shareddrives/517 nlp project/data/2WikiMultihopQA/train_truncated_300.json
Processing Bamboogle...
  Original samples: 125, Truncated to: 125
  Saved to /content/drive/Shareddrives/517 nlp project/data/Bamboogle/train_truncated_300.json
Processing FEVER...
  Detected JSONL format for FEVER
  Original samples: 116359, Truncated to: 300
  Saved to /content/drive/Shareddrives/517 nlp project/data/FEVER/fever_train_truncated_300.jsonl
Processing FEVEROUS...
  Detected JSONL format for FEVEROUS
  Original samples: 57033, Truncated to: 300
  Saved to /content/drive/Shareddrives/517 nlp project/data/FEVEROUS/feverous_train_truncated_300.jsonl
Processing HotpotQA...
  Original samples: 90447, Truncated to: 300
  Saved to /content/drive/Shareddrives/517 nlp project/data/HotpotQA/train_truncated_300.json
Processing SVAMP...
  Original samples: 700, Truncated to: 300
  Saved to /content/drive/Sha

In [18]:
# Re-open each truncated file and report how many samples it holds
for file_path in truncated_files:
    dataset_name = os.path.basename(os.path.dirname(file_path))

    if is_jsonl(file_path):
        # JSONL: every non-blank line is one sample
        with open(file_path, 'r', encoding='utf-8') as f:
            line_count = len([line for line in f if line.strip()])
        print(f"{dataset_name}: {line_count} samples (JSONL format)")
        continue

    # Regular JSON: parse the whole document
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if isinstance(data, list):
        print(f"{dataset_name}: {len(data)} samples")
    elif isinstance(data, dict):
        # Report the first list-valued entry of the dictionary, if there is one
        first_list = next(
            ((key, value) for key, value in data.items() if isinstance(value, list)),
            None,
        )
        if first_list is None:
            print(f"{dataset_name}: Structure unclear")
        else:
            key, value = first_list
            print(f"{dataset_name}: {len(value)} samples in '{key}'")

2WikiMultihopQA: 300 samples
Bamboogle: 125 samples
FEVER: 300 samples (JSONL format)
FEVEROUS: 300 samples (JSONL format)
HotpotQA: 300 samples
SVAMP: 300 samples
vitaminc: 300 samples (JSONL format)
