# Extract hidden states from Qwen (and other Hugging Face) models for different datasets

In [None]:
from Get_Go_Emo import get_go
from Get_Isear import get_isr

In [None]:
# Load the goEmo dataset through the project helper in Get_Go_Emo (presumably GoEmotions).
# NOTE(review): extract_hidden_states below reads df['clean_text'] by default, so
# get_go() presumably returns a pandas DataFrame with that column — confirm.
goEmo = get_go()

In [None]:
# Load the ISEAR data through the project helper in Get_Isear.
# NOTE(review): in the visible cells `isear` is never passed to
# extract_hidden_states; only goEmo is processed.
isear = get_isr()

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import TensorDataset, DataLoader
import pandas as pd
import json
import time
import os  # Import os module for directory operations


# Supported ways of reducing a layer's token states to one vector per example.
_POOLING_MODES = ('first', 'last', 'mean')


def _resolve_output_dir(output_directory_name, use_external_storage,
                        storage_device_name, external_storage_path):
    """Return the directory that hidden-state JSON files are written to.

    With ``use_external_storage`` this is
    ``<external_storage_path>/<storage_device_name>/<output_directory_name>``
    (``/Volumes/Media/hidden_states`` with the defaults of
    ``extract_hidden_states``). Otherwise it is ``output_directory_name``, which
    is relative to the current working directory unless it is absolute.
    """
    if not use_external_storage:
        return output_directory_name
    # os.path.join inserts separators as needed, so a mount root given without a
    # trailing slash (e.g. '/mnt') still yields '/mnt/<device>/<dir>'.
    return os.path.join(external_storage_path, storage_device_name, output_directory_name)


def _pool_hidden_states(hidden_states, attention_mask, pooling):
    """Reduce every layer's token states to one vector per example.

    Args:
        hidden_states: Sequence of per-layer tensors (``outputs.hidden_states``),
            each indexed here as (batch, seq_len, hidden_dim).
        attention_mask: (batch, seq_len) tensor on the same device as
            ``hidden_states``; 1 marks real tokens, 0 marks padding.
        pooling: 'first', 'last' or 'mean' (see ``extract_hidden_states``).

    Returns:
        Tensor of shape (num_layers, batch, hidden_dim).
    """
    if pooling == 'first':
        # Position 0 regardless of padding: this is the [CLS] token for BERT-style
        # models. NOTE: for causal models (GPT-2, Qwen) position 0 typically only
        # attends to the first token, so 'last' or 'mean' may be more informative.
        pooled = [layer[:, 0, :] for layer in hidden_states]
    elif pooling == 'last':
        positions = torch.arange(attention_mask.size(1), device=attention_mask.device)
        # Highest position whose mask is 1; works for both left and right padding.
        last_idx = (attention_mask.long() * positions).amax(dim=1)
        rows = torch.arange(attention_mask.size(0), device=attention_mask.device)
        pooled = [layer[rows, last_idx, :] for layer in hidden_states]
    elif pooling == 'mean':
        mask = attention_mask.unsqueeze(-1).float()
        # clamp avoids 0/0 for a row that is entirely padding.
        counts = mask.sum(dim=1).clamp(min=1.0)
        # Accumulate in float32 so half-precision models don't lose accuracy.
        pooled = [(layer.float() * mask).sum(dim=1) / counts for layer in hidden_states]
    else:
        raise ValueError(f"pooling must be one of {_POOLING_MODES}, got {pooling!r}")
    return torch.stack(pooled)


def extract_hidden_states(df, model_names, text_column='clean_text', batch_size=16, dataset_name="no_dataset_selected",
                          device='cuda' if torch.cuda.is_available() else 'cpu', start_from_batch=0, output_directory_name='hidden_states',
                          use_external_storage=False, storage_device_name='Media', external_storage_path='/Volumes/',
                          pooling='first'):
    """
    Extract one pooled hidden-state vector per layer for every text in ``df``
    and write the results to JSON, one file per batch.

    For each model, the whole text column is tokenized once (padded to the
    longest text, with truncation enabled) and run through the model in batches.
    Each batch is written to
    ``<base_dir>/<model name without '/'>_<dataset_name>_<batch_idx>.json``.
    The file holds a list with one ``{'layer_<k>': [floats, ...]}`` dict per
    example, where ``layer_k`` is the k-th entry of ``outputs.hidden_states``.

    Args:
        df (pd.DataFrame): Input DataFrame containing the text data.
        model_names (list): Hugging Face model ids/paths to extract hidden states from.
        text_column (str): Name of the column containing text data.
        batch_size (int): Texts per forward pass and per output file. Must be >= 1.
        dataset_name (str): Name of the dataset; used in output filenames.
        device (str): Device to run the model on ('cuda' if available, else 'cpu').
        start_from_batch (int): 0-based batch to resume from. Earlier batches are
            neither run nor written. Must be >= 0.
        output_directory_name (str): Name of the output directory.
        use_external_storage (bool): If True, write under
            ``external_storage_path/storage_device_name/output_directory_name``
            instead of ``output_directory_name``. Defaults to False.
        storage_device_name (str): Volume name under ``external_storage_path``.
            Defaults to 'Media'.
        external_storage_path (str): Mount root of external storage. Defaults to
            '/Volumes/', so the default external directory is
            '/Volumes/Media/hidden_states'.
        pooling (str): How token states are reduced per layer:
            - 'first': the token at position 0 (the default and the original behavior).
            - 'last': the last non-padding token.
            - 'mean': the mean over non-padding tokens.

    Returns:
        list[str]: Paths of the JSON files written, in order. The list is empty
        when there was nothing to process.

    Raises:
        ValueError: If ``pooling`` is unknown, ``batch_size < 1`` or
            ``start_from_batch < 0``.
    """
    # Validate before touching the filesystem or loading any (large) model.
    if pooling not in _POOLING_MODES:
        raise ValueError(f"pooling must be one of {_POOLING_MODES}, got {pooling!r}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if start_from_batch < 0:
        raise ValueError(f"start_from_batch must be >= 0, got {start_from_batch}")

    base_dir = _resolve_output_dir(output_directory_name, use_external_storage,
                                   storage_device_name, external_storage_path)
    os.makedirs(base_dir, exist_ok=True)

    texts = df[text_column].tolist()
    total_batches = (len(texts) + batch_size - 1) // batch_size  # ceil(len / batch_size)
    written_files = []

    # Without this guard the summary below would divide by zero and reference
    # an output filename that was never assigned.
    if start_from_batch >= total_batches:
        print(f"Nothing to do: start_from_batch={start_from_batch} but '{dataset_name}' "
              f"has only {total_batches} batch(es) of size {batch_size}.")
        return written_files

    for model_name in model_names:
        print(f"\nProcessing model: {model_name}")
        model_start_time = time.time()
        # Hub ids contain '/', which would otherwise create sub-directories in filenames.
        print_name = model_name.replace('/', '')
        print(print_name)  # e.g. 'Qwen/Qwen2-7B' -> 'QwenQwen2-7B'

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
        model.eval()
        model.to(device)

        # Batched padding needs a pad token; some causal-LM tokenizers ship without one.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        tokenized = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

        # Drop rows belonging to already-processed batches so the loader starts at
        # the resume point; batch k still covers rows [k*batch_size, (k+1)*batch_size).
        offset = start_from_batch * batch_size
        dataset = TensorDataset(tokenized['input_ids'][offset:], tokenized['attention_mask'][offset:])
        dataloader = DataLoader(dataset, batch_size=batch_size)

        loop_start_time = time.time()
        with torch.no_grad():
            for batch_idx, (input_ids_batch, attention_mask_batch) in enumerate(dataloader, start=start_from_batch):
                batch_start_time = time.time()
                input_ids_batch = input_ids_batch.to(device)
                attention_mask_batch = attention_mask_batch.to(device)

                outputs = model(input_ids=input_ids_batch, attention_mask=attention_mask_batch)
                pooled = _pool_hidden_states(outputs.hidden_states, attention_mask_batch, pooling)

                # One device->host copy per batch. float() makes bf16/fp16 outputs
                # serializable. (layers, batch, dim) -> per-example lists of per-layer vectors.
                per_example = pooled.float().cpu().transpose(0, 1).tolist()
                all_hidden_dicts = [
                    {f'layer_{layer_idx}': vector for layer_idx, vector in enumerate(layer_vectors)}
                    for layer_vectors in per_example
                ]

                batch_time = time.time() - batch_start_time
                batches_done = batch_idx + 1 - start_from_batch
                print(
                    f"Batch {batch_idx + 1}/{total_batches} | "
                    f"Time: {batch_time:.2f}s | "
                    f"Avg: {(time.time() - loop_start_time) / batches_done:.2f}s/batch",
                    end='\r'
                )

                output_filename = os.path.join(base_dir, f"{print_name}_{dataset_name}_{batch_idx}.json")
                with open(output_filename, 'w') as f:
                    json.dump(all_hidden_dicts, f, indent=2)  # indent for pretty-printing
                written_files.append(output_filename)

        # At least one batch was processed (guarded above), so both names are bound.
        total_time = time.time() - model_start_time
        processed_batches = total_batches - start_from_batch
        print(f"\nCompleted {model_name} in {total_time:.2f}s ({total_time/processed_batches:.4f}s/batch)")
        print(f"Saved {processed_batches} batch file(s) to '{base_dir}' (last: {output_filename})")

        # Release the model before loading the next one.
        del model, tokenizer, outputs, pooled
        torch.cuda.empty_cache()

    print(f"\nAll models processed and hidden states saved in '{base_dir}' directory.")
    return written_files

In [None]:
# Checkpoints previously used with this notebook (swap into model_names as needed):
#   'Qwen/Qwen2-7B', 'Qwen/Qwen-7B', 'gpt2', 'bert-base-uncased', 'Qwen/Qwen2-0.5B'
model_names = ['Qwen/Qwen2-0.5B']

# Extract hidden states for the goEmo DataFrame into the external-storage directory.
# NOTE(review): start_from_batch=654 looks like a resume point from an earlier run;
# batches 0-653 are skipped. Set it to 0 for a fresh extraction.
goEmo_with_hidden = extract_hidden_states(
    goEmo,
    model_names,
    dataset_name="goEmo",
    start_from_batch=654,
    use_external_storage=True,
)

In [None]:
from analysis import describe_hidden_states, analyze_hidden_states, describe_all_hidden_states

In [None]:
# Run the project's analysis over previously saved hidden-state files.
# NOTE(review): analyze_hidden_states() is called with no arguments, so it uses
# whatever defaults analysis.py defines (not visible here). The extraction cell
# above writes with use_external_storage=True, so confirm those defaults point
# at the same directory.
analysis = analyze_hidden_states()

In [None]:
# Describe all files in the default hidden_states directory.
# NOTE(review): that default is defined in analysis.py (not visible here). The
# extraction cell above writes with use_external_storage=True, so confirm the
# files it produced actually live in the directory described here.
describe_all_hidden_states()