In [6]:
import argparse
import os
import sys
import logging
from dotenv import load_dotenv
import traceback2 as traceback
import json
from datetime import datetime, timezone
import gc

In [14]:
from research_case.analyzers.persona_analysis import PersonaAnalyzer , ExtendedPersonaAnalyzer
from research_case.LLMclients.ollama_client import OllamaClient

# Load environment variables from a local .env file, if one is present.
load_dotenv()

# Notebook-wide logging: configure the root handler at INFO, then take a
# logger named after this module for the cells below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [18]:
def main(args=None):
    """Extract personas for a sample of users with a local Ollama model.

    Samples users from the preprocessed users file via
    ``PersonaAnalyzer.load_and_sample_users`` (bounded by ``max_personas`` and
    ``min_posts_per_user``) and writes the sample to a temporary JSON file next
    to the input. It then runs ``ExtendedPersonaAnalyzer.analyze_persona_from_files``
    on that file ``n_iterations`` times, writing to
    ``<output>/<experiment_name>/personas.json``. The temporary file is always
    removed, whether the run succeeds or fails.

    Args:
        args: Optional dict of overrides for the default settings below (same
            keys). Omitted keys keep their defaults, so ``main()`` with no
            arguments runs the default experiment.

    Raises:
        FileNotFoundError: If the users or conversations input file is missing.
        Exception: Any error raised while sampling or analysing is logged with
            its traceback and re-raised.
    """
    # Paths are rooted at RESEARCH_CASE_ROOT when that variable is set, so the
    # notebook can run on another machine without editing absolute paths.
    project_root = os.environ.get("RESEARCH_CASE_ROOT", "/Users/mogen/Desktop/Research_Case")
    settings = {
        "input": os.path.join(project_root, "data", "test_preprocessed", "processed_users.json"),
        "conversations": os.path.join(project_root, "data", "test_preprocessed", "processed_conversations.json"),
        "output": os.path.join(project_root, "results", "ollama"),
        "n_posts": 200,
        "n_conversations": 5,
        "max_personas": 500,
        "experiment_name": "my_experiment",
        # NOTE(review): "use_random_fields" and "num_fields" are not passed to
        # the analyzer below; "num_fields" only appears in the final log line.
        "use_random_fields": True,
        "min_posts_per_user": 2,
        "num_fields": 8,
        # Number of analysis passes over the same sampled users file.
        "n_iterations": 3,
        "model_name": "llama3.1:8b",
    }
    if args:
        settings.update(args)

    input_path = settings["input"]
    conversations_path = settings["conversations"]
    output_path = os.path.join(settings["output"], settings["experiment_name"], "personas.json")

    # Validate inputs up front. Raise instead of calling exit() so the failure
    # surfaces as an ordinary exception with a clear message.
    for label, path in (("Input", input_path), ("Conversations", conversations_path)):
        if not os.path.exists(path):
            logger.error(f"{label} file not found: {path}")
            raise FileNotFoundError(path)

    # Make sure the experiment directory exists before anything is written there.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    llm_client = OllamaClient(model_name=settings["model_name"])
    analyzer = ExtendedPersonaAnalyzer(llm_client)

    logger.info("Starting persona analysis...")
    temp_input = os.path.join(os.path.dirname(input_path), "temp_sampled_users.json")
    try:
        sampled_users = PersonaAnalyzer.load_and_sample_users(
            input_path, settings["max_personas"], settings["min_posts_per_user"]
        )
        with open(temp_input, 'w') as f:
            json.dump(sampled_users, f)

        # NOTE(review): every iteration passes the same output_path; whether the
        # analyzer appends or overwrites there is decided inside
        # analyze_persona_from_files -- confirm this matches the intent.
        for iteration in range(settings["n_iterations"]):
            logger.info(f"Iteration {iteration}")
            analyzer.analyze_persona_from_files(
                posts_path=temp_input,
                conversations_path=conversations_path,
                output_path=output_path,
                n_posts=settings["n_posts"],
                n_conversations=settings["n_conversations"],
            )

        # Single quotes inside the f-string keep this valid before Python 3.12.
        logger.info(
            f"Persona analysis with fields: {settings['num_fields']} completed. "
            f"Results saved to {output_path}"
        )
    except Exception:
        logger.exception("Failed to analyze personas:")
        raise
    finally:
        # Always clean up the sampled-users scratch file, on success or failure.
        if os.path.exists(temp_input):
            os.remove(temp_input)

In [19]:
# Run the persona-analysis pipeline with the default settings defined in main().
main()

INFO:research_case.analyzers.persona_analysis:Initialized PersonaPromptGenerator
INFO:__main__:Starting persona analysis...
INFO:research_case.analyzers.persona_analysis:Filtered from 22 to 21 users with at least 2 posts


Iteration 0


INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
ERROR:research_case.analyzers.persona_analysis:Failed to parse LLM response as JSON: Expecting value: line 1 column 1 (char 0)
INFO:research_case.analyzers.persona_analysis:Successfully extracted and fixed JSON
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
ERROR:research_case.analyzers.persona_analysis:Failed to parse LLM response as JSON: Expecting value: line 1 column 1 (char 0)
INFO:research_case.analyzers.persona_analysis:Successfully extracted and fixed JSON
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
ERROR:research_case.analyzers.persona_analysis:Failed to parse LLM response as JSON: Expecting value: line 1 column 1 (char 0)
INFO:research_case.analyzers.persona_analysis:Successfully extracted and fixed JSON
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
ERROR:research_case.analyzers.p

Iteration 1


INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
ERROR:research_case.analyzers.persona_analysis:Failed to parse LLM response as JSON: Expecting value: line 1 column 1 (char 0)
INFO:research_case.analyzers.persona_analysis:Successfully extracted and fixed JSON
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
ERROR:research_case.analyzers.persona_analysis:Failed to parse LLM response as JSON: Expecting value: line 1 column 1 (char 0)
INFO:research_case.analyzers.persona_analysis:Successfully extracted and fixed JSON
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
ERROR:research_case.analyzers.persona_analysis:Failed to parse LLM response as JSON: Expecting value: line 1 column 1 (char 0)
INFO:research_case.analyzers.persona_analysis:Successfully extracted and fixed JSON
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
ERROR:research_case.analyzers.p

Iteration 2


INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
ERROR:research_case.analyzers.persona_analysis:Failed to parse LLM response as JSON: Expecting value: line 1 column 1 (char 0)
INFO:research_case.analyzers.persona_analysis:Successfully extracted and fixed JSON
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
ERROR:research_case.analyzers.persona_analysis:Failed to parse LLM response as JSON: Expecting value: line 1 column 1 (char 0)
INFO:research_case.analyzers.persona_analysis:Successfully extracted and fixed JSON
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
ERROR:research_case.analyzers.persona_analysis:Failed to parse LLM response as JSON: Expecting value: line 1 column 1 (char 0)
INFO:research_case.analyzers.persona_analysis:Successfully extracted and fixed JSON
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
ERROR:research_case.analyzers.p

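## Post generation

Generate synthetic posts for each persona produced above. For every user, a random subset of persona fields is combined with a stimulus derived from one of the user's original posts.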


In [21]:
from pathlib import Path
from typing import Dict, List, Optional
from itertools import islice

import random


from research_case.generators.post_generator import PostGenerator, StimulusGenerator, GenerationPrompt



In [26]:
def _is_missing_persona_value(value) -> bool:
    """Return True for persona values that carry no information (None, blank, or "N/A")."""
    if value is None:
        return True
    return isinstance(value, str) and value.strip().upper() in ("", "N/A")


def select_random_persona_fields(
    persona: Dict[str, str],
    max_fields: int,
    rng: Optional[random.Random] = None,
) -> Dict[str, str]:
    """
    Select a random number of fields (between 1 and max_fields) from a persona dictionary.

    Fields whose value is missing are never selected. Missing means ``None``, an
    empty or whitespace-only string, or "N/A" (case-insensitive, surrounding
    whitespace ignored).

    Args:
        persona: Complete persona dictionary
        max_fields: Maximum number of fields to select; must be >= 1
        rng: Optional ``random.Random`` instance for reproducible sampling.
            When omitted, the module-level ``random`` functions are used.

    Returns:
        Dictionary with only the randomly selected fields, in sampled order.

    Raises:
        ValueError: If ``max_fields`` < 1, or if the persona has no usable fields.
    """
    # Check explicitly so callers get a clear message instead of randint's
    # "empty range" error.
    if max_fields < 1:
        raise ValueError(f"max_fields must be at least 1, got {max_fields}")

    sampler = rng if rng is not None else random

    valid_fields = {k: v for k, v in persona.items() if not _is_missing_persona_value(v)}
    if not valid_fields:
        raise ValueError("No valid persona fields found (all marked as N/A)")

    # How many fields to use this time: uniform between 1 and the cap.
    num_fields = sampler.randint(1, min(max_fields, len(valid_fields)))
    selected_keys = sampler.sample(list(valid_fields.keys()), num_fields)

    return {k: valid_fields[k] for k in selected_keys}


# Generate posts for one user, sampling a random subset of persona fields for each post.
def generate_posts_for_user(
    user_id: str,
    persona: Dict,
    original_posts: List[Dict],
    post_generator: PostGenerator,
    stimulus_generator: StimulusGenerator,
    posts_per_persona: int,
    max_persona_fields: int,
) -> List[Dict]:
    """
    Produce up to ``posts_per_persona`` generated posts for one user as flat records.

    The user's original posts are consumed in order; at most one generated post is
    produced per original post. For each one, a fresh random subset of persona
    fields (at most ``max_persona_fields``) is drawn. That subset and a stimulus
    built from the original text are used to prompt the post generator.

    Args:
        user_id: User identifier
        persona: User's persona data
        original_posts: The user's original posts (dicts with 'full_text',
            'tweet_id' and 'created_at'; missing keys default to '')
        post_generator: PostGenerator instance
        stimulus_generator: StimulusGenerator instance
        posts_per_persona: Maximum number of posts to generate
        max_persona_fields: Maximum number of persona fields per prompt

    Returns:
        One record dict per generated post. Each record also carries every
        persona field, prefixed with ``persona_``.
    """
    # A single timestamp shared by every record produced in this call.
    batch_timestamp = datetime.now(timezone.utc).isoformat()

    records: List[Dict] = []
    # zip stops at whichever runs out first: the post budget or the originals.
    for gen_index, source_post in zip(range(posts_per_persona), original_posts):
        source_text = source_post.get('full_text', '')
        source_id = source_post.get('tweet_id', '')
        source_created_at = source_post.get('created_at', '')

        # Draw a new persona subset for every post.
        persona_subset = select_random_persona_fields(persona, max_persona_fields)

        stimulus = stimulus_generator.create_post_stimulus(source_text)
        generated_text = post_generator.generate_post(
            GenerationPrompt(persona=persona_subset, stimulus=stimulus)
        )

        records.append({
            "user_id": user_id,
            "generation_id": f"{user_id}_gen_{gen_index}",
            "original_post_id": source_id,
            "original_text": source_text,
            "original_timestamp": source_created_at,
            "stimulus": stimulus,
            "generated_text": generated_text,
            "generation_timestamp": batch_timestamp,
            # Which fields actually reached the prompt.
            "used_persona_fields": list(persona_subset.keys()),
            # The full persona is still stored for later analysis.
            **{f"persona_{key}": value for key, value in persona.items()},
        })

    return records

In [37]:
# Parameters for running post generation from this notebook.
# Paths are rooted at the RESEARCH_CASE_ROOT environment variable when it is set;
# otherwise they resolve to the original local project directory.
PROJECT_ROOT = os.environ.get("RESEARCH_CASE_ROOT", "/Users/mogen/Desktop/Research_Case")
EXPERIMENT_DIR = os.path.join(PROJECT_ROOT, "results", "ollama", "my_experiment")

args = {
    # Required: JSON file with user personas (keyed by user id).
    "personas": os.path.join(EXPERIMENT_DIR, "personas.json"),
    # JSON file with each user's original posts (keyed by user id).
    "posts": os.path.join(PROJECT_ROOT, "data", "test_preprocessed", "processed_users.json"),
    # Where generated posts are written; a falsy value means "next to the personas file".
    "output": os.path.join(EXPERIMENT_DIR, "generated_posts.json"),
    # Maximum number of posts to generate per persona.
    "posts_per_persona": 5,
    # Upper bound on persona fields sampled per post; a falsy value means "all fields".
    "max_persona_fields": 3,
}

In [38]:
# Step 2: Set up file paths
personas_path = args["personas"]
posts_path = args["posts"]
# A falsy "output" falls back to generated_posts.json next to the personas file.
output_path = args["output"] or os.path.join(
    os.path.dirname(personas_path), "generated_posts.json"
)

# Validate input files. Raise instead of calling exit() so the failure surfaces
# as an ordinary exception and nothing downstream runs on missing data.
missing_inputs = [path for path in (personas_path, posts_path) if not os.path.exists(path)]
for path in missing_inputs:
    logger.error(f"Input file not found: {path}")
if missing_inputs:
    raise FileNotFoundError(f"Input file(s) not found: {', '.join(missing_inputs)}")

# Make sure the output directory exists before the first write.
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

# Step 3: Load environment variables and initialize clients
load_dotenv()
llm_client = OllamaClient(model_name="llama3.1:8b")
post_generator = PostGenerator(llm_client)
stimulus_generator = StimulusGenerator(llm_client)


def _write_generated_posts(records, num_users, posts_per_persona, path):
    """Write generated post records plus run metadata to ``path`` as indented JSON."""
    output_data = {
        "metadata": {
            "generation_timestamp": datetime.now(timezone.utc).isoformat(),
            "num_users": num_users,
            "posts_per_persona": posts_per_persona,
            "total_posts_generated": len(records),
        },
        "generated_posts": records,
    }
    with open(path, 'w') as f:
        json.dump(output_data, f, indent=4)


# Step 4: Run generation
try:
    logger.info("Starting post generation...")

    # Load personas and original posts; both are looked up by user id below.
    with open(personas_path, 'r') as f:
        personas = json.load(f)
    with open(posts_path, 'r') as f:
        original_posts = json.load(f)

    all_generated_records = []
    for user_id, persona in personas.items():
        logger.info(f"Generating posts for user {user_id}")

        user_posts = original_posts.get(user_id, [])
        if not user_posts:
            logger.warning(f"No original posts found for user {user_id}")
            continue

        # A falsy max_persona_fields means "allow every persona field".
        max_fields = args["max_persona_fields"] or len(persona)
        user_records = generate_posts_for_user(
            user_id=user_id,
            persona=persona,
            original_posts=user_posts,
            post_generator=post_generator,
            stimulus_generator=stimulus_generator,
            posts_per_persona=args["posts_per_persona"],
            max_persona_fields=max_fields
        )
        all_generated_records.extend(user_records)

        # Checkpoint after each user so completed (slow, LLM-bound) work survives
        # a failure on a later user.
        _write_generated_posts(
            all_generated_records, len(personas), args["posts_per_persona"], output_path
        )

    # Final write after the loop. It still produces an output file when every
    # user was skipped, and completion is logged exactly once.
    _write_generated_posts(
        all_generated_records, len(personas), args["posts_per_persona"], output_path
    )
    logger.info(f"Post generation completed. Generated {len(all_generated_records)} posts.")
    logger.info(f"Results saved to {output_path}")

except Exception:
    logger.exception("Failed to generate posts:")
    raise


INFO:__main__:Starting post generation...
INFO:__main__:Generating posts for user 254185636.0
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
INFO:__main__:Post generation completed. Generated 5 posts.
INFO:_