# In [None]:  (Jupyter cell marker left over from the notebook export)
import torch
import json
import os
from transformers import BertTokenizer
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.multiprocessing import Pool, set_start_method
import numpy as np

print("Script started")  # Debug: Confirm script starts

# ----------------- Config -----------------
# All paths are resolved relative to the current working directory.
base_dir = os.path.abspath("")
tokenizer_dir = os.path.join(base_dir, "final_tokenizer")  # Folder containing vocab.txt, tokenizer.json, etc.
lora_dataset_path = os.path.join(base_dir, "datasets/dataset_single_lora.jsonl")  # Specific JSONL file
output_dir = os.path.join(base_dir, "tokenized_output")  # Directory to save tokenized data
os.makedirs(output_dir, exist_ok=True)

# Set visible GPUs (skip 0, 1, 5; use 2, 3, 4, 6, 7).
# The variable only takes effect if it is set before CUDA is initialised,
# so it must stay above the first torch.cuda call.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3,4,6,7"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"ðŸ”‘ Device in use: {device}")  # Verification print statement
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    print(f"ðŸ”‘ Using {num_gpus} GPUs: {torch.cuda.get_device_name(0)} et al.")
else:
    # get_device_name(0) raises when no CUDA device is usable, which would
    # abort the script even though `device` already fell back to the CPU.
    print("No CUDA GPUs visible; continuing on CPU.")

# Tokenizer parameters
max_length = 512  # Maximum sequence length for BERT
stride = 128      # Stride for overlapping chunks

# Load the custom tokenizer from `tokenizer_dir` via transformers.
# Any failure is reported on stdout and then re-raised, so the script stops
# here instead of continuing without a tokenizer.
print("ðŸ”‘ Loading custom tokenizer...")
try:
    tokenizer = BertTokenizer.from_pretrained(tokenizer_dir)
    print("âœ… Loaded tokenizer using transformers")
except Exception as load_error:
    print(f"âš  Error loading tokenizer: {load_error}")
    raise

# Map-style Dataset wrapper around the in-memory sample records
class LoRADataset(Dataset):
    """Expose a sequence of sample records through the Dataset interface.

    Every record is indexed with the keys 'input' and 'output' when it is
    fetched, so each one must provide both entries.
    """

    def __init__(self, samples):
        # The sequence is stored by reference; no copy is made.
        self.samples = samples

    def __len__(self):
        """Return the number of stored records."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(input, output)`` pair of the record at ``idx``."""
        record = self.samples[idx]
        return tuple(record[field] for field in ('input', 'output'))

# Tokenization worker (defined at module level so it can be handed to Pool.map)
def tokenize_chunk(chunk_samples):
    """Tokenize a list of samples and slice each encoding into chunks.

    Args:
        chunk_samples: Sequence of dicts. Each dict must have an 'input'
            entry (the text to tokenize) and an 'output' entry (copied
            through unchanged as the label).

    Returns:
        A list with one dict per sample, in input order, with the keys
        "sample_index" (1-based position within ``chunk_samples``, not
        within the whole dataset), "label", "num_chunks" and "chunks".
        Each chunk is a dict with "token_ids" (list of ints), "start_idx"
        and "end_idx" (slice bounds into the padded encoding).
    """
    # Compare against the tokenizer's own padding id rather than assuming 0,
    # since the vocabulary is custom. Fall back to 0 if none is configured.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = 0
    window_step = max_length - stride

    chunk_data = []
    for sample_index, sample in enumerate(chunk_samples, start=1):
        # Tokenization runs on the CPU. The ids are kept as a plain Python
        # list: copying them to a CUDA device and straight back added a
        # transfer (and a CUDA context per worker) without any computation.
        encoded = tokenizer.encode_plus(
            sample['input'],
            max_length=max_length,
            padding='max_length',
            truncation=True,
        )
        input_ids = encoded['input_ids']  # Flat list of max_length ids

        # NOTE(review): the encoding is truncated to max_length before it is
        # sliced, so tokens beyond max_length are never seen and the slice
        # starting at max_length - stride is a suffix of the first one. If
        # sliding-window coverage of long inputs is intended, the windows
        # must be taken from un-truncated ids. Left unchanged here because
        # it would alter the saved output; confirm with its consumer.
        chunks = []
        for start in range(0, max_length, window_step):
            end = min(start + max_length, max_length)
            chunk_ids = input_ids[start:end]
            # Skip slices that contain nothing but padding.
            if any(token_id != pad_id for token_id in chunk_ids):
                chunks.append({
                    "token_ids": chunk_ids,
                    "start_idx": start,
                    "end_idx": end
                })

        chunk_data.append({
            "sample_index": sample_index,
            "label": sample['output'],
            "num_chunks": len(chunks),
            "chunks": chunks
        })
    return chunk_data

def load_jsonl_records(path):
    """Read a JSON Lines file and return the objects that parsed successfully.

    Blank lines are skipped silently. Lines that are not valid JSON are
    reported on stdout with their 1-based line number and skipped.

    Args:
        path: Path of the JSONL file, read as UTF-8.

    Returns:
        List of decoded JSON values in file order.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    records = []
    with open(path, 'r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, 1):
            stripped = line.strip()
            if not stripped:
                # An empty line (e.g. a trailing newline) is not a malformed
                # record, so it is not worth a warning.
                continue
            try:
                records.append(json.loads(stripped))
            except json.JSONDecodeError as e:
                print(f"âš  Line {line_number} is malformed: {e}. Skipping.")
    return records


def split_into_chunks(samples, num_chunks):
    """Split ``samples`` into ``num_chunks`` contiguous slices.

    Every slice has ``len(samples) // num_chunks`` items, except the last
    one, which also receives the remainder.

    Args:
        samples: Sliceable sequence to partition.
        num_chunks: Number of slices to produce; must be at least 1.

    Returns:
        List of ``num_chunks`` slices that together cover ``samples`` in order.

    Raises:
        ValueError: If ``num_chunks`` is smaller than 1.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be at least 1, got {num_chunks}")
    chunk_size = len(samples) // num_chunks
    boundaries = [i * chunk_size for i in range(num_chunks)] + [len(samples)]
    return [samples[start:end] for start, end in zip(boundaries, boundaries[1:])]


if __name__ == '__main__':
    set_start_method('spawn', force=True)
    print("Starting main process")  # Debug: Confirm main process starts

    try:
        print(f"Loading dataset from {lora_dataset_path}")  # Debug: Confirm file access attempt
        df = pd.DataFrame(load_jsonl_records(lora_dataset_path))
        total_samples = len(df)
        print(f"ðŸ“Š Loaded dataset with {total_samples} rows. Columns: {df.columns.tolist()}")
        if total_samples != 18208:
            print(f"âš  Expected 18,208 samples, but found {total_samples}. Proceeding anyway.")

        # Validate columns
        text_column = 'input'
        label_column = 'output'
        if text_column not in df.columns or label_column not in df.columns:
            print(f"âš  Columns {text_column} or {label_column} not found. Available columns: {df.columns.tolist()}")
        else:
            samples = df[[text_column, label_column]].to_dict(orient='records')

            # One worker per visible GPU, but never fewer than one: with no
            # GPU the split would divide by zero and Pool rejects 0 processes.
            num_workers = max(1, num_gpus)

            # Parallel processing across workers
            print("\nðŸš€ Tokenizing entire dataset with chunking and saving to file...")
            sample_chunks = split_into_chunks(samples, num_workers)
            with Pool(processes=num_workers) as pool:
                results = pool.map(tokenize_chunk, sample_chunks)  # Pass only chunks

            # Combine results; pool.map returns them in input order.
            all_tokenized_data = []
            for chunk_results in results:
                all_tokenized_data.extend(chunk_results)

            # Each worker numbers its own share starting from 1, so the merged
            # list would contain duplicate indices. Renumber over the whole
            # dataset so "sample_index" identifies a sample uniquely.
            for global_index, record in enumerate(all_tokenized_data, start=1):
                record["sample_index"] = global_index

            # Validate and save
            if len(all_tokenized_data) == total_samples:
                print(f"âœ… Processed exactly {total_samples} samples.")
            else:
                print(f"âš  Processed {len(all_tokenized_data)} samples, expected {total_samples}.")

            # Save to a JSON file
            output_file = os.path.join(output_dir, "tokenized_full_dataset.json")
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(all_tokenized_data, f, ensure_ascii=False, indent=2)
            print(f"ðŸ“„ Saved tokenized data for entire dataset to {output_file}")

    except FileNotFoundError:
        print(f"âš  File {lora_dataset_path} not found. Please check the path.")
    except Exception as e:
        # Top-level boundary of the script: report the failure and exit.
        print(f"âš  Error: {e}")