# RAG-based Social Media Feed Generator for Fictional Universes

### John Skorcik

In [None]:
# Install required packages
!pip install -q langchain llama-cpp-python faiss-cpu sentence-transformers tqdm pydantic transformers accelerate networkx pyvis neo4j
!pip install -U langchain-community

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 MB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.9/41.9 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m44.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.3/31.3 MB[0m [31m53.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m345.7/345.7 kB[0m [31m23.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m756.0/756.0 kB[0m [31m43.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Import necessary libraries
import os
import json
import random
import numpy as np
import networkx as nx
import torch

from datetime import datetime, timedelta
from tqdm.notebook import tqdm
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field
from IPython.display import HTML, display, IFrame

# RAG Components

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms import HuggingFacePipeline



print("Libraries and Componenets successfully imported.")

Libraries and Componenets successfully imported.


In [None]:
# Set up a smaller open-source model using Hugging Face's models
def setup_huggingface_pipeline():
    """Build a LangChain LLM backed by a local Hugging Face text-generation pipeline.

    Loads the ``microsoft/phi-2`` tokenizer and causal language model (float16,
    automatic device placement), wraps them in a ``text-generation`` pipeline
    and returns that pipeline wrapped in ``HuggingFacePipeline``.

    Returns:
        HuggingFacePipeline: the LangChain wrapper around the pipeline.
    """
    print("Loading model - this might take a minute...")

    # Small but capable checkpoint; "google/gemma-2b-it" or
    # "mistralai/Mistral-7B-Instruct-v0.2" are possible alternatives.
    checkpoint = "microsoft/phi-2"

    load_options = {
        "torch_dtype": torch.float16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    generation_options = {
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        # "do_sample": True,  # intentionally left disabled, as before
        "repetition_penalty": 1.05,
    }

    hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    hf_model = AutoModelForCausalLM.from_pretrained(checkpoint, **load_options)

    # Text-generation pipeline combining model, tokenizer and decoding options
    text_generator = pipeline(
        "text-generation",
        model=hf_model,
        tokenizer=hf_tokenizer,
        **generation_options,
    )

    # LangChain-facing wrapper
    wrapped_llm = HuggingFacePipeline(pipeline=text_generator)

    print("Model loaded successfully!")
    return wrapped_llm

### Define social media post data models

In [None]:
class UserProfile(BaseModel):
    """Profile of the account shown as the author of a social media post.

    ``followers`` and ``following`` are not supplied by callers in the common
    case: each instance draws its own random value when the field is omitted.
    """
    # Identifier string of the profile.
    id: str
    # Display name shown in bold on a rendered post.
    name: str
    # Handle rendered after an "@" on a post.
    username: str
    # Optional free-text biography; defaults to None.
    bio: Optional[str] = None
    avatar_emoji: str = "👤"  # Using emoji as avatar placeholder
    # Random follower count in [50, 10000], drawn per instance.
    followers: int = Field(default_factory=lambda: random.randint(50, 10000))
    # Random following count in [50, 500], drawn per instance.
    following: int = Field(default_factory=lambda: random.randint(50, 500))


In [None]:
class SocialMediaPost(BaseModel):
    """A generated social media post together with its display metadata.

    Engagement counters and the timestamp default to random values drawn per
    instance, so a post can be created from just ``id``, ``user`` and
    ``content``.
    """
    # Identifier string of the post.
    id: str
    # Profile of the posting account.
    user: UserProfile
    # Text of the post (typically model-generated).
    content: str
    # Random engagement counters, drawn per instance when omitted.
    likes: int = Field(default_factory=lambda: random.randint(0, 1000))
    shares: int = Field(default_factory=lambda: random.randint(0, 100))
    comments: int = Field(default_factory=lambda: random.randint(0, 50))
    # Defaults to a moment between now and 72 hours ago.
    timestamp: datetime = Field(default_factory=lambda: datetime.now() - timedelta(
        hours=random.randint(0, 72)))
    # Names of entities mentioned in the post.
    referenced_entities: List[str] = []

    def format_time(self) -> str:
        """Return the post's age as a short relative label.

        Returns:
            ``"<n>d ago"`` for ages of one day or more, ``"<n>h ago"`` for ages
            of one hour or more, otherwise ``"<n>m ago"``. A timestamp in the
            future is treated as "just now" (``"0m ago"``) instead of yielding
            a misleading hour count from a negative timedelta.
        """
        # Clamp so a timestamp slightly ahead of the clock cannot go negative
        # (a negative timedelta has days == -1 and a large positive .seconds).
        delta = max(datetime.now() - self.timestamp, timedelta(0))
        if delta.days > 0:
            return f"{delta.days}d ago"
        elif delta.seconds // 3600 > 0:
            return f"{delta.seconds // 3600}h ago"
        else:
            return f"{delta.seconds // 60}m ago"

    def display_post(self):
        """Display the post in a social media-like format.

        Every piece of text interpolated into the markup is HTML-escaped first:
        the content and profile fields may be model-generated and can contain
        ``<``, ``>`` or ``&``, which would otherwise break the card or inject
        markup into the notebook output.
        """
        from html import escape  # function-scope import keeps this fix self-contained

        avatar = escape(self.user.avatar_emoji)
        author_name = escape(self.user.name)
        handle = escape(self.user.username)
        body = escape(self.content)

        markup = f"""
        <div style="border: 1px solid #ddd; border-radius: 8px; padding: 12px; margin-bottom: 16px; max-width: 500px; font-family: Arial, sans-serif;">
            <div style="display: flex; align-items: center; margin-bottom: 8px;">
                <div style="font-size: 32px; margin-right: 12px;">{avatar}</div>
                <div>
                    <div style="font-weight: bold;">{author_name}</div>
                    <div style="color: #536471; font-size: 14px;">@{handle}</div>
                </div>
            </div>
            <p style="margin: 12px 0; font-size: 15px; line-height: 1.4;">{body}</p>
            <div style="color: #536471; font-size: 14px; margin-top: 12px;">{self.format_time()}</div>
            <div style="display: flex; justify-content: space-between; color: #536471; font-size: 14px; margin-top: 12px;">
                <div>❤️ {self.likes}</div>
                <div>🔄 {self.shares}</div>
                <div>💬 {self.comments}</div>
            </div>
        </div>
        """
        display(HTML(markup))

In [None]:
class Entity(BaseModel):
    """A node of the knowledge graph: any entity of the fictional universe."""
    id: str
    name: str
    type: str
    description: str
    # Free-form extra properties; every instance starts with its own empty dict.
    attributes: Dict[str, Any] = Field(default_factory=dict)

    def __str__(self):
        # Rendered as "<name> (<type>)"
        return "{} ({})".format(self.name, self.type)

In [None]:
class Relationship(BaseModel):
    """A directed, typed link from one entity id to another."""
    source_id: str
    target_id: str
    type: str
    description: str
    # Free-form extra properties; every instance starts with its own empty dict.
    attributes: Dict[str, Any] = Field(default_factory=dict)

    def __str__(self):
        # Rendered as "<type>: <description>"
        return "{}: {}".format(self.type, self.description)

In [None]:
class KnowledgeGraph:
    """Class to manage the knowledge graph for a fictional universe"""

    def __init__(self):
        """Create an empty knowledge graph with no entities and no relationships."""
        # Entity lookup table keyed by entity id (id -> Entity).
        self.entities = {}  # id -> Entity
        # Relationships in insertion order.
        self.relationships = []  # List of Relationship objects
        # Directed graph that add_entity / add_relationship fill with nodes and edges.
        self.graph = nx.DiGraph()  # NetworkX graph

    def add_entity(self, entity: Entity):
        """Register ``entity`` under its id and mirror it as a graph node.

        The node receives every field of the entity (``model_dump()``) as node
        attributes. Storing an entity under an id that is already present
        replaces the previous table entry.
        """
        node_key = entity.id
        self.entities[node_key] = entity
        node_attributes = entity.model_dump()
        self.graph.add_node(node_key, **node_attributes)

    def add_relationship(self, relationship: Relationship):
        """Add a relationship to the knowledge graph.

        The relationship is appended to ``self.relationships`` and mirrored as
        a directed edge ``source_id -> target_id`` whose attributes are the
        relationship's custom ``attributes`` plus ``type`` and ``description``.

        Args:
            relationship: the relationship to store.
        """
        self.relationships.append(relationship)

        # Merge into one dict instead of passing type=/description= next to
        # **attributes: a custom attribute named "type" or "description" used to
        # raise "TypeError: got multiple values for keyword argument".
        # The relationship's own fields take precedence over custom attributes.
        edge_attributes = dict(relationship.attributes)
        edge_attributes["type"] = relationship.type
        edge_attributes["description"] = relationship.description

        self.graph.add_edge(
            relationship.source_id,
            relationship.target_id,
            **edge_attributes
        )

    def get_entity_by_name(self, name: str) -> Optional[Entity]:
        """Return the first stored entity whose name equals ``name`` ignoring case.

        Entities are examined in insertion order; ``None`` is returned when no
        name matches.
        """
        matches = (
            candidate
            for candidate in self.entities.values()
            if candidate.name.lower() == name.lower()
        )
        return next(matches, None)

    def get_related_entities(self, entity_id: str) -> List[Dict]:
        """Get all entities related to the given entity.

        Args:
            entity_id: id of the entity whose neighbourhood is requested.

        Returns:
            A list of ``{"entity": Entity, "relationship": edge_attributes}``
            dicts: first one per outgoing edge, then one per incoming edge
            (self-loops are reported once, on the outgoing side). Returns an
            empty list when ``entity_id`` is not a node of the graph.
        """
        if entity_id not in self.graph:
            return []

        related = []

        # Outgoing edges: entity_id -> neighbor
        for neighbor in self.graph.neighbors(entity_id):
            # add_relationship() lets NetworkX create nodes implicitly for ids
            # that were never registered with add_entity(); skip those instead
            # of failing with KeyError on the lookup table.
            entity = self.entities.get(neighbor)
            if entity is None:
                continue
            related.append({
                "entity": entity,
                "relationship": self.graph.get_edge_data(entity_id, neighbor)
            })

        # Incoming edges: predecessor -> entity_id
        for predecessor in self.graph.predecessors(entity_id):
            if predecessor == entity_id:  # Skip self-loops (already reported above)
                continue
            entity = self.entities.get(predecessor)
            if entity is None:
                continue
            related.append({
                "entity": entity,
                "relationship": self.graph.get_edge_data(predecessor, entity_id)
            })

        return related

    def get_subgraph_for_entity(self, entity_id: str, depth: int = 2) -> nx.DiGraph:
        """Get a subgraph centered on the given entity with a specified depth.

        Edge direction is ignored while collecting nodes: both successors and
        predecessors count as neighbours.

        Args:
            entity_id: id of the node the subgraph is centred on.
            depth: maximum number of hops from ``entity_id`` to include.

        Returns:
            An independent copy of the induced subgraph, or an empty
            ``nx.DiGraph`` when ``entity_id`` is not in the graph.
        """
        if entity_id not in self.graph:
            return nx.DiGraph()

        nodes = {entity_id}
        frontier = {entity_id}

        for _ in range(depth):
            reached = set()
            for node in frontier:
                # Add outgoing neighbors
                reached.update(self.graph.neighbors(node))
                # Add incoming neighbors
                reached.update(self.graph.predecessors(node))

            # Only expand nodes seen for the first time: re-expanding visited
            # nodes cannot add anything new and repeats work on every layer.
            frontier = reached - nodes
            if not frontier:
                break  # the whole reachable component is already collected
            nodes.update(frontier)

        # Create subgraph with these nodes
        return self.graph.subgraph(nodes).copy()

    def visualize(self, filename="knowledge_graph.html", height=800):
        """Render the graph to an HTML file with pyvis and return an IFrame for it.

        Nodes are labelled with the entity name and coloured by entity type;
        edges are labelled with the relationship type. When pyvis cannot be
        imported, an installation hint is printed and ``None`` is returned.
        """
        try:
            from pyvis.network import Network

            canvas = Network(height=f"{height}px", width="100%", directed=True, notebook=True)

            # One node per stored entity
            for node_id, node_entity in self.entities.items():
                hover_text = f"{node_entity.name}: {node_entity.description}"
                fill = self._get_color_for_entity_type(node_entity.type)
                canvas.add_node(node_id, label=node_entity.name, title=hover_text, color=fill)

            # One edge per stored relationship
            for link in self.relationships:
                canvas.add_edge(
                    link.source_id,
                    link.target_id,
                    title=link.description,
                    label=link.type,
                )

            # Write the HTML file and embed it
            canvas.save_graph(filename)
            return IFrame(filename, width="100%", height=height)

        except ImportError:
            print("pyvis is required for visualization. Install with: pip install pyvis")
            return None

    def _get_color_for_entity_type(self, entity_type: str) -> str:
        """Map an entity type to its hex display colour (gray for unknown types)."""
        fallback = "#9E9E9E"  # Default gray
        palette = {
            "Character": "#4285F4",     # Blue
            "Location": "#34A853",      # Green
            "Event": "#FBBC05",         # Yellow
            "Item": "#EA4335",          # Red
            "Organization": "#8E24AA",  # Purple
            "Concept": "#00ACC1",       # Teal
        }
        if entity_type in palette:
            return palette[entity_type]
        return fallback

    def to_dict(self) -> Dict:
        """Convert the knowledge graph to a dictionary representation.

        Returns:
            ``{"entities": {entity_id: entity_fields}, "relationships":
            [relationship_fields, ...]}`` where each ``*_fields`` value is the
            model's ``model_dump()`` output. This is the shape ``from_dict``
            consumes.
        """
        return {
            # `entity_id` instead of `id` so the builtin is not shadowed.
            "entities": {
                entity_id: entity.model_dump()
                for entity_id, entity in self.entities.items()
            },
            # model_dump() for relationships too: entities already use the
            # pydantic v2 API, while .dict() is the deprecated v1 spelling.
            "relationships": [
                relationship.model_dump() for relationship in self.relationships
            ]
        }

    def to_json(self, filename: str):
        """Write the ``to_dict()`` representation to ``filename`` as indented JSON.

        Values the JSON encoder cannot serialise natively are written via ``str``.
        """
        with open(filename, 'w') as handle:
            payload = self.to_dict()
            json.dump(payload, handle, indent=2, default=str)

    @classmethod
    def from_dict(cls, data: Dict) -> 'KnowledgeGraph':
        """Rebuild a knowledge graph from the structure produced by ``to_dict``.

        Entities are added first (the keys of ``data["entities"]`` are ignored;
        each entity carries its own id), then relationships in list order.
        """
        graph = cls()

        for entity_fields in data["entities"].values():
            graph.add_entity(Entity(**entity_fields))

        for relationship_fields in data["relationships"]:
            graph.add_relationship(Relationship(**relationship_fields))

        return graph

    @classmethod
    def from_json(cls, filename: str) -> 'KnowledgeGraph':
        """Read a JSON file and build a knowledge graph from it via ``from_dict``."""
        with open(filename, 'r') as handle:
            payload = json.load(handle)

        return cls.from_dict(payload)

In [None]:
class UniverseKG:
    """A class to manage fictional universe knowledge through a knowledge graph"""

    def __init__(self, universe_name: str):
        self.universe_name = universe_name
        self.knowledge_graph = KnowledgeGraph()
        self.vectorstore = None
        self.embeddings = None
        self.llm = None
        self.universe_summary = ""
        self.expansion_chain = None
        self.entity_chain = None
        self.relationship_chain = None

        # Create a universe entity
        universe_entity = Entity(
            id=f"universe_{self._generate_id()}",
            name=universe_name,
            type="Universe",
            description=f"The {universe_name} fictional universe"
        )
        self.knowledge_graph.add_entity(universe_entity)
        self.universe_id = universe_entity.id

    def initialize_llm(self):
        """Initialize the LLM for knowledge graph operations.

        Loads the model through ``setup_huggingface_pipeline()`` and builds three
        ``LLMChain`` objects on top of it, stored as ``self.expansion_chain``,
        ``self.entity_chain`` and ``self.relationship_chain``.

        NOTE(review): the prompts are wrapped in ``<s>[INST] ... [/INST]``, which
        looks like a Mistral-style instruction format, while
        ``setup_huggingface_pipeline`` loads phi-2 - confirm the model honours it.
        """
        print("Initializing LLM...")

        self.llm = setup_huggingface_pipeline()

        # Universe expansion prompt
        # Inputs: universe_name, universe_description (the current description).
        expansion_template = """<s>[INST] You are a knowledge graph expert for the fictional universe: {universe_name}.

I need you to expand on the universe concept with creative and consistent worldbuilding details. Generate a rich and
coherent description of this universe based on the name alone. Think about the core themes, setting, timeline, magic
or technology systems, major factions, and overall atmosphere.

Your description should be detailed, imaginative, and internally consistent. Make sure the universe you describe
would work well for interactive storytelling and social media simulation.

Current universe information:
{universe_description}

Please generate an expanded universe summary that adds depth and richness to this concept. [/INST]"""

        expansion_prompt = PromptTemplate(
            input_variables=["universe_name", "universe_description"],
            template=expansion_template
        )

        # Used by expand_universe().
        self.expansion_chain = LLMChain(llm=self.llm, prompt=expansion_prompt)

        # Entity generation prompt
        # NOTE(review): generate_entities() builds its own prompt and chain, so
        # self.entity_chain is not referenced anywhere in the visible code.
        entity_template = """<s>[INST] You are a knowledge graph expert for the fictional universe: {universe_name}.

Based on the universe description below, generate detailed information for {num_entities} new {entity_type} entities
that would exist in this universe. Make these entities interesting, diverse, and consistent with the universe lore.

Universe description:
{universe_description}

For each entity, provide:
1. A name
2. A detailed description (2-3 sentences)
3. Several key attributes or characteristics
4. How this entity fits into the universe

Current entities already in the universe:
{existing_entities}

Please generate these new entities in a creative but consistent way that adds depth to the universe.
Do not regenerate existing entities. [/INST]"""

        entity_prompt = PromptTemplate(
            input_variables=["universe_name", "universe_description", "entity_type", "num_entities", "existing_entities"],
            template=entity_template
        )

        self.entity_chain = LLMChain(llm=self.llm, prompt=entity_prompt)

        # Relationship generation prompt
        # The prompt asks for source / target / type / description per
        # relationship but does not pin an exact line format; the parsing in
        # generate_relationships() has to cope with whatever layout comes back.
        relationship_template = """<s>[INST] You are a knowledge graph expert for the fictional universe: {universe_name}.

I need you to identify meaningful relationships between entities in this universe.
Based on the universe description and the entities listed below, generate {num_relationships} realistic
relationships between these entities.

Universe description:
{universe_description}

Entities:
{entity_list}

For each relationship, specify:
1. The source entity
2. The target entity
3. The type of relationship (e.g., "FRIENDS_WITH", "ENEMY_OF", "LOCATED_IN", "MEMBER_OF", "CREATED", etc.)
4. A brief description of the relationship (1-2 sentences)

Make sure the relationships are logical and consistent with what we know about the universe and these entities.
For characters, think about their alliances, rivalries, family ties, and organizational affiliations.
For locations, consider which entities might be located there or have special connections to these places.
For events, think about which entities participated in or were affected by them.

Please generate diverse and interesting relationships that enrich the universe's narrative. [/INST]"""

        relationship_prompt = PromptTemplate(
            input_variables=["universe_name", "universe_description", "entity_list", "num_relationships"],
            template=relationship_template
        )

        # Used by generate_relationships().
        self.relationship_chain = LLMChain(llm=self.llm, prompt=relationship_prompt)

        print("LLM initialized successfully!")

    def expand_universe(self):
        """Expand the universe description using the LLM"""
        if not self.llm:
            self.initialize_llm()

        universe_entity = self.knowledge_graph.entities[self.universe_id]
        current_description = universe_entity.description

        print("Expanding universe concept...")
        expanded_description = self.expansion_chain.run(
            universe_name=self.universe_name,
            universe_description=current_description
        )

        # Update universe entity
        universe_entity.description = expanded_description
        self.universe_summary = expanded_description

        # Update in knowledge graph
        self.knowledge_graph.entities[self.universe_id] = universe_entity

        return expanded_description

    def generate_entities(self, entity_type: str, num_entities: int = 5):
        """Generate new entities for the universe"""
        if not self.llm:
            self.initialize_llm()

        # Get existing entities of this type for context
        existing_entities = "\n".join([
            f"- {entity.name}: {entity.description}"
            for entity in self.knowledge_graph.entities.values()
            if entity.type == entity_type
        ])

        if not self.universe_summary:
            self.universe_summary = self.knowledge_graph.entities[self.universe_id].description

        print(f"Generating {num_entities} {entity_type} entities...")

        # More structured prompt to get consistent output
        entity_template = """
        You are creating {num_entities} new {entity_type} entities for the fictional universe: {universe_name}.

        Universe description:
        {universe_description}

        Existing {entity_type} entities:
        {existing_entities}

        Generate exactly {num_entities} new {entity_type} entities.

        FORMAT YOUR RESPONSE EXACTLY AS FOLLOWS FOR EACH ENTITY:

        ###ENTITY###
        Name: [entity name]
        Description: [brief description]
        Attribute1: [value]
        Attribute2: [value]
        (add more attributes as appropriate)
        ###END###

        Make each entity interesting and fitting for the universe. Do not include any other text outside the specified format.
        """

        generation_prompt = PromptTemplate(
            input_variables=["universe_name", "universe_description", "entity_type", "num_entities", "existing_entities"],
            template=entity_template
        )

        entity_chain = LLMChain(llm=self.llm, prompt=generation_prompt)

        generation_result = entity_chain.run(
            universe_name=self.universe_name,
            universe_description=self.universe_summary,
            entity_type=entity_type,
            num_entities=num_entities,
            existing_entities=existing_entities
        )

        print("Raw generation result:")
        print(generation_result[:500] + "..." if len(generation_result) > 500 else generation_result)

        # Process the result to extract entities using the ###ENTITY### markers
        entities = []
        entity_blocks = generation_result.split("###ENTITY###")

        for block in entity_blocks:
            if "###END###" not in block:
                continue

            entity_text = block.split("###END###")[0].strip()
            if not entity_text:
                continue

            entity_data = {"attributes": {}}

            # Parse the entity data
            lines = entity_text.split('\n')
            for line in lines:
                line = line.strip()
                if not line:
                    continue

                if ":" in line:
                    key, value = line.split(":", 1)
                    key = key.strip()
                    value = value.strip()

                    if key.lower() == "name":
                        entity_data["name"] = value
                    elif key.lower() == "description":
                        entity_data["description"] = value
                    else:
                        # Everything else is an attribute
                        entity_data["attributes"][key] = value

            if "name" in entity_data and "description" in entity_data:
                entities.append(entity_data)

        print(f"Extracted {len(entities)} entities from the generation result")

        # Create and add entities to the knowledge graph
        added_entities = []
        for entity_data in entities:
            entity = Entity(
                id=f"{entity_type.lower()}_{self._generate_id()}",
                name=entity_data["name"],
                type=entity_type,
                description=entity_data["description"],
                attributes=entity_data.get("attributes", {})
            )

            self.knowledge_graph.add_entity(entity)
            added_entities.append(entity)

            # Connect to universe
            relationship = Relationship(
                source_id=entity.id,
                target_id=self.universe_id,
                type="PART_OF",
                description=f"{entity.name} is part of the {self.universe_name} universe"
            )
            self.knowledge_graph.add_relationship(relationship)

        print(f"Added {len(added_entities)} {entity_type} entities to the universe")
        return added_entities

    def generate_relationships(self, num_relationships: int = 10):
        """Generate relationships between existing entities"""
        if not self.llm:
            self.initialize_llm()

        if len(self.knowledge_graph.entities) < 3:
            print("Not enough entities to generate meaningful relationships")
            return []

        # Prepare entity list for the prompt
        entity_list = "\n".join([
            f"- {entity.name} ({entity.type}): {entity.description}"
            for entity_id, entity in self.knowledge_graph.entities.items()
            if entity_id != self.universe_id  # Exclude the universe itself
        ])

        if not self.universe_summary:
            self.universe_summary = self.knowledge_graph.entities[self.universe_id].description

        print(f"Generating {num_relationships} relationships...")
        generation_result = self.relationship_chain.run(
            universe_name=self.universe_name,
            universe_description=self.universe_summary,
            entity_list=entity_list,
            num_relationships=num_relationships
        )

        # Process the result to extract relationships
        # This is a simplified parsing logic - for production, use a more robust approach
        relationships = []
        lines = generation_result.strip().split('\n')

        current_relationship = {}
        for line in lines:
            line = line.strip()
            if not line:
                continue

            # New relationship starts with a number
            if line[0].isdigit() or line.startswith("-"):
                # Save previous relationship if it exists
                if current_relationship.get("source") and current_relationship.get("target"):
                    relationships.append(current_relationship)
                    current_relationship = {}

                # This might be the beginning of a new relationship
                continue

            # Source entity
            if line.startswith("Source:") or line.startswith("From:"):
                current_relationship["source"] = line.split(":", 1)[1].strip()

            # Target entity
            elif line.startswith("Target:") or line.startswith("To:"):
                current_relationship["target"] = line.split(":", 1)[1].strip()

            # Relationship type
            elif line.startswith("Type:") or line.startswith("Relationship:"):
                current_relationship["type"] = line.split(":", 1)[1].strip()

            # Description
            elif line.startswith("Description:"):
                current_relationship["description"] = line.split(":", 1)[1].strip()

        # Add the last relationship
        if current_relationship.get("source") and current_relationship.get("target"):
            relationships.append(current_relationship)

        # Create and add relationships to the knowledge graph
        added_relationships = []
        for rel_data in relationships:
            if not all(k in rel_data for k in ["source", "target", "type", "description"]):
                continue  # Skip incomplete relationships

            # Find the entity IDs
            source_entity = self.knowledge_graph.get_entity_by_name(rel_data["source"])
            target_entity = self.knowledge_graph.get_entity_by_name(rel_data["target"])

            if not source_entity or not target_entity:
                continue  # Skip if entities not found

            relationship = Relationship(
                source_id=source_entity.id,
                target_id=target_entity.id,
                type=rel_data["type"].upper().replace(" ", "_"),
                description=rel_data["description"]
            )

            self.knowledge_graph.add_relationship(relationship)
            added_relationships.append(relationship)

        print(f"Added {len(added_relationships)} relationships to the universe")
        return added_relationships

    def initialize_embeddings(self):
        """Initialize the embeddings for vector search"""
        print("Initializing embeddings...")
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        print("Embeddings initialized successfully!")

    def build_vectorstore(self):
        """Build a vector store from knowledge graph entities and relationships"""
        if not self.embeddings:
            self.initialize_embeddings()

        print("Building vector store...")

        # Prepare documents
        docs = []

        # Add entity documents
        for entity_id, entity in self.knowledge_graph.entities.items():
            doc_text = f"Name: {entity.name}\nType: {entity.type}\nDescription: {entity.description}\n"

            # Add attributes
            if entity.attributes:
                doc_text += "Attributes:\n"
                for key, value in entity.attributes.items():
                    doc_text += f"- {key}: {value}\n"

            docs.append(doc_text)

        # Add relationship documents
        for relationship in self.knowledge_graph.relationships:
            source_entity = self.knowledge_graph.entities[relationship.source_id]
            target_entity = self.knowledge_graph.entities[relationship.target_id]

            doc_text = f"Relationship: {source_entity.name} {relationship.type} {target_entity.name}\n"
            doc_text += f"Description: {relationship.description}\n"

            docs.append(doc_text)

        # Create vector store
        self.vectorstore = FAISS.from_texts(docs, self.embeddings)

        print(f"Vector store built successfully with {len(docs)} documents")
        return self.vectorstore

    def get_entity_context(self, entity_name: str, max_depth: int = 2) -> str:
        """Get context information about an entity and its relationships"""
        entity = self.knowledge_graph.get_entity_by_name(entity_name)
        if not entity:
            return f"No information found about {entity_name}"

        context = f"Entity: {entity.name} ({entity.type})\n"
        context += f"Description: {entity.description}\n\n"

        # Add attributes
        if entity.attributes:
            context += "Attributes:\n"
            for key, value in entity.attributes.items():
                context += f"- {key}: {value}\n"
            context += "\n"

        # Add relationships
        related = self.knowledge_graph.get_related_entities(entity.id)
        if related:
            context += "Relationships:\n"
            for rel in related:
                related_entity = rel["entity"]
                relationship = rel["relationship"]
                context += f"- {relationship['type']}: {related_entity.name} - {relationship['description']}\n"

        return context

    def retrieve_related_context(self, query: str, k: int = 3) -> str:
        """Retrieve context information relevant to a query"""
        if not self.vectorstore:
            self.build_vectorstore()

        docs = self.vectorstore.similarity_search(query, k=k)
        context = "\n\n".join([doc.page_content for doc in docs])
        return context

    def _generate_id(self) -> str:
        """Generate a random ID string"""
        return ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=8))

    def auto_populate(self, num_characters=5, num_locations=3, num_events=2, num_items=2, num_relationships=15):
        """Automatically populate the universe with various entity types and relationships"""
        if not self.llm:
            self.initialize_llm()

        # First, expand the universe concept
        self.expand_universe()

        # Generate entities
        print(f"Attempting to generate {num_characters} characters...")
        if num_characters > 0:
            characters = self.generate_entities("Character", num_characters)
            print(f"Generated {len(characters)} characters: {[c.name for c in characters]}")

        if num_locations > 0:
            self.generate_entities("Location", num_locations)

        if num_events > 0:
            self.generate_entities("Event", num_events)

        if num_items > 0:
            self.generate_entities("Item", num_items)

        # Generate relationships
        if num_relationships > 0:
            self.generate_relationships(num_relationships)

        # Initialize vector store
        self.build_vectorstore()

        print(f"Universe '{self.universe_name}' auto-populated successfully!")
        return self.knowledge_graph

In [None]:
class FeedGenerator:
    """Generator for social media feeds based on fictional universes with KG-based RAG.

    Wraps a ``UniverseKG`` together with an LLM chain. Every post is written
    "by" one Character entity of the universe; the character's knowledge-graph
    context is passed to the prompt as additional material.
    """

    # Avatar pool; one emoji is picked per character the first time it posts
    AVATAR_EMOJIS = ["👩", "👨", "🧙", "👸", "🤴", "👩‍🚀", "👨‍🚀", "🧝‍♀️", "🧝", "🧛", "🦸", "🦹"]

    def __init__(self, universe_kg: UniverseKG):
        """Store the universe and immediately set up the LLM and prompt chain."""
        self.universe = universe_kg
        self.llm = None
        self.post_chain = None
        # character id -> UserProfile. Cached so a character keeps the same
        # username and avatar across every post made by this generator.
        self._profiles = {}
        self.initialize_llm()

    def initialize_llm(self):
        """Initialize the LLM and chain for generating posts.

        The LLM comes from ``setup_huggingface_pipeline()``; the chain is an
        ``LLMChain`` over the post prompt defined below.
        """
        print("Initializing LLM...")

        self.llm = setup_huggingface_pipeline()

        # Create prompt template for generating posts
        post_template = """<s>[INST] You are a social media post generator for characters in a fictional universe.

Universe Information:
{universe_description}

Character Information:
Name: {character_name}
Description: {character_description}

Additional Context:
{context}

Based on this information and context, generate a realistic social media post that this character would make.
The post should reference universe-specific knowledge, other characters, locations, events, or items in a natural way.
Make the post authentic to how people post on social media. Avoid hashtags unless they're being used ironically.
Keep the post between 1-3 short paragraphs maximum.

Ensure the tone, vocabulary, and content align with the character's personality and role in the universe.
The post should feel like something written spontaneously, not a carefully crafted narrative.

After generating the post, please list any specific entities (characters, locations, events, items, etc.) that were
referenced in the post. Format as a simple comma-separated list. [/INST]"""

        post_prompt = PromptTemplate(
            input_variables=["universe_description", "character_name", "character_description", "context"],
            template=post_template
        )

        self.post_chain = LLMChain(llm=self.llm, prompt=post_prompt)
        print("LLM initialized successfully!")

    def generate_character_username(self, character_name: str) -> str:
        """Generate a plausible username for a character.

        One of three styles is chosen at random: first initial + last name,
        the name followed by a number from 1 to 99, or a
        ``the_real_<last>`` / ``<name>_official`` variation. The result is
        lowercase and contains only alphanumerics and underscores.

        A blank name, or a name that leaves nothing after sanitizing, yields
        a generic ``user<number>`` handle instead of raising ``IndexError``
        or returning an empty string.
        """
        name_parts = character_name.lower().split()
        if not name_parts:
            return f"user{random.randint(1, 99)}"

        first, last = name_parts[0], name_parts[-1]
        has_surname = len(name_parts) > 1

        # Generate different username styles
        if random.random() < 0.3 and has_surname:
            # First initial + last name
            username = f"{first[0]}{last}"
        elif random.random() < 0.5:
            # Full name with random number
            username = f"{first}{last if has_surname else ''}{random.randint(1, 99)}"
        elif has_surname:
            username = f"the_real_{last}"
        else:
            username = f"{first}_official"

        # Remove spaces and special characters
        username = ''.join(c for c in username if c.isalnum() or c == '_').lower()

        return username or f"user{random.randint(1, 99)}"

    @staticmethod
    def _split_post_and_entities(raw_output: str):
        """Separate the post text from the trailing entity list in LLM output.

        The output is split on blank lines. The final paragraph is treated as
        the entity list only when it looks like one:

        * it contains a ``:`` (e.g. ``Referenced entities: A, B``), in which
          case everything after the first colon is parsed, or
        * it is a single line containing commas that does not end in
          sentence punctuation (a bare ``A, B, C`` list).

        Otherwise the final paragraph is ordinary post text and is kept.
        These are heuristics; free-form model output cannot be parsed exactly.

        Returns:
            ``(post_content, referenced_entities)`` where ``post_content`` is
            stripped and ``referenced_entities`` is a list of non-empty names.
        """
        text = raw_output.strip()
        parts = text.split("\n\n")
        if len(parts) < 2:
            return text, []

        tail = parts[-1].strip()
        if ":" in tail:
            listing = tail.split(":", 1)[1]
        elif "\n" not in tail and "," in tail and tail[-1] not in ".!?":
            listing = tail
        else:
            # Last paragraph is part of the post; do not discard it
            return text, []

        entities = [name.strip() for name in listing.split(",") if name.strip()]
        return "\n\n".join(parts[:-1]).strip(), entities

    def _get_profile(self, character):
        """Return the cached ``UserProfile`` for ``character``, creating it on first use."""
        profile = self._profiles.get(character.id)
        if profile is None:
            description = character.description
            # Bio is the description, truncated to 100 characters plus "..."
            bio = description[:100] + "..." if len(description) > 100 else description
            profile = UserProfile(
                id=character.id,
                name=character.name,
                username=self.generate_character_username(character.name),
                bio=bio,
                avatar_emoji=random.choice(self.AVATAR_EMOJIS)
            )
            self._profiles[character.id] = profile
        return profile

    def generate_post(self, character_id: Optional[str] = None) -> SocialMediaPost:
        """Generate a social media post from a random or specified character.

        Args:
            character_id: ID of the posting entity. When omitted, or not
                present in the knowledge graph, a random Character entity
                is used instead.

        Returns:
            A ``SocialMediaPost`` with the parsed post text, the author's
            profile and the list of entities the model said it referenced.

        Raises:
            ValueError: If a random character is needed but the universe
                has no Character entities.
        """
        entities = self.universe.knowledge_graph.entities

        # Select a character
        if character_id and character_id in entities:
            character = entities[character_id]
        else:
            # Get a random character
            characters = [entity for entity in entities.values() if entity.type == "Character"]
            if not characters:
                raise ValueError("No characters in universe")
            character = random.choice(characters)

        # Get context about this character
        context = self.universe.get_entity_context(character.name)

        # Prefer the expanded summary; fall back to the universe entity's description
        universe_description = (
            self.universe.universe_summary
            or entities[self.universe.universe_id].description
        )

        post_result = self.post_chain.run(
            universe_description=universe_description,
            character_name=character.name,
            character_description=character.description,
            context=context
        )

        # Split the raw output into post text and referenced entities
        post_content, referenced_entities = self._split_post_and_entities(post_result)

        # Create the post
        post = SocialMediaPost(
            id=f"post_{random.randint(10000, 99999)}",
            user=self._get_profile(character),
            content=post_content,
            referenced_entities=referenced_entities
        )

        return post

    def generate_feed(self, num_posts: int = 5) -> List[SocialMediaPost]:
        """Generate a feed with multiple posts from different characters.

        When ``num_posts`` is smaller than the number of characters, distinct
        characters are sampled. Otherwise every character posts once and the
        remaining slots are filled with random repeats.

        Returns:
            The posts, sorted by their ``timestamp`` attribute (ascending).

        Raises:
            ValueError: If the universe has no Character entities, or
                (from ``random.sample``) if ``num_posts`` is negative.
        """
        # Get all characters
        character_ids = [
            entity_id for entity_id, entity in self.universe.knowledge_graph.entities.items()
            if entity.type == "Character"
        ]

        if not character_ids:
            raise ValueError("No characters in universe")

        if num_posts < len(character_ids):
            # Fewer posts than characters: pick distinct authors
            selected_chars = random.sample(character_ids, num_posts)
        else:
            # Everyone posts once; extra slots may repeat characters
            selected_chars = character_ids.copy()
            selected_chars.extend(random.choices(character_ids, k=num_posts - len(character_ids)))

        # Generate posts from selected characters
        feed = [
            self.generate_post(char_id)
            for char_id in tqdm(selected_chars, desc="Generating posts")
        ]

        # Sort by timestamp to create chronological feed
        feed.sort(key=lambda post: post.timestamp)

        return feed

### Demo

In [None]:
# Demo function to create a universe from just a name
def create_universe_from_name(
    universe_name: str,
    auto_populate: bool = True,
    num_characters: int = 6,
    num_locations: int = 4,
    num_events: int = 3,
    num_items: int = 3,
    num_relationships: int = 20,
):
    """Create a universe from just a name using the LLM to expand it.

    Args:
        universe_name: Name passed to ``UniverseKG``.
        auto_populate: When true, call ``UniverseKG.auto_populate`` with the
            counts below; when false, only call ``expand_universe()``.
        num_characters: Characters to request during auto-population.
        num_locations: Locations to request during auto-population.
        num_events: Events to request during auto-population.
        num_items: Items to request during auto-population.
        num_relationships: Relationships to request during auto-population.

    The count arguments are ignored when ``auto_populate`` is false. Their
    defaults match the values this function previously hard-coded.

    Returns:
        The created ``UniverseKG`` instance.
    """
    universe = UniverseKG(universe_name)

    # Initialize the LLM
    universe.initialize_llm()

    if auto_populate:
        # Auto-populate the universe
        universe.auto_populate(
            num_characters=num_characters,
            num_locations=num_locations,
            num_events=num_events,
            num_items=num_items,
            num_relationships=num_relationships
        )
    else:
        # Just expand the universe description
        universe.expand_universe()

    return universe

# Function to visualize knowledge graph
def visualize_universe(universe: UniverseKG):
    """Return whatever ``visualize()`` on the universe's knowledge graph produces."""
    graph = universe.knowledge_graph
    return graph.visualize()

# Demo function to generate a feed
def generate_demo_feed(universe: UniverseKG, num_posts: int = 8):
    """Build a feed for ``universe`` with a fresh ``FeedGenerator``, print it, and return it.

    Each post is shown through its own ``display_post()``; when the post
    lists referenced entities they are printed underneath, followed by a
    dashed separator.
    """
    feed = FeedGenerator(universe).generate_feed(num_posts)

    separator = "\n" + "-"*50 + "\n"
    print(f"\n==== Social Media Feed from {universe.universe_name} ====\n")
    for entry in feed:
        entry.display_post()
        mentioned = entry.referenced_entities
        if mentioned:
            print(f"Referenced entities: {', '.join(mentioned)}")
        print(separator)

    return feed

### UI for creating a universe

In [None]:
def create_universe_ui():
    """Build and display a three-tab widget UI for universes.

    Tabs: "Create Universe" (name, auto-populate toggle, entity-count
    sliders), "Generate Feed" (post count + button) and "Visualize"
    (knowledge-graph button). Each button writes into its own Output widget.

    If an import fails, an installation hint is printed and ``None`` is
    returned.
    """
    try:
        import ipywidgets as widgets
        from IPython.display import display, clear_output

        # Shared by the three button callbacks; set by the "Create Universe" handler
        state = {"universe": None}

        # --- "Create Universe" tab controls ---
        name_box = widgets.Text(description="Universe Name:", placeholder="Enter a name for your universe")

        populate_toggle = widgets.Checkbox(
            value=True,
            description='Auto-populate universe',
            disabled=False
        )

        characters_slider = widgets.IntSlider(value=6, min=0, max=15, step=1, description='Characters:')
        locations_slider = widgets.IntSlider(value=4, min=0, max=10, step=1, description='Locations:')
        events_slider = widgets.IntSlider(value=3, min=0, max=8, step=1, description='Events:')
        items_slider = widgets.IntSlider(value=3, min=0, max=8, step=1, description='Items:')
        relationships_slider = widgets.IntSlider(value=20, min=0, max=50, step=5, description='Relationships:')

        slider_box = widgets.VBox([
            characters_slider,
            locations_slider,
            events_slider,
            items_slider,
            relationships_slider,
        ])

        create_button = widgets.Button(description="Create Universe")
        create_log = widgets.Output()

        # --- "Generate Feed" tab controls ---
        post_count_slider = widgets.IntSlider(value=5, min=1, max=20, description='# Posts:')
        feed_button = widgets.Button(description="Generate Feed")
        feed_log = widgets.Output()

        # --- "Visualize" tab controls ---
        graph_button = widgets.Button(description="Visualize Knowledge Graph")
        graph_log = widgets.Output()

        def toggle_slider_visibility(change):
            """Show the count sliders only while auto-populate is ticked."""
            slider_box.layout.display = '' if change['new'] else 'none'

        populate_toggle.observe(toggle_slider_visibility, names='value')

        def handle_create(_button):
            """Create (and optionally populate) a universe from the form values."""
            with create_log:
                clear_output()
                requested_name = name_box.value
                if not requested_name:
                    print("Please enter a universe name")
                    return

                print(f"Creating universe: {requested_name}")

                # Published to the other handlers before the LLM work starts
                fresh = UniverseKG(requested_name)
                state["universe"] = fresh
                fresh.initialize_llm()

                if populate_toggle.value:
                    fresh.auto_populate(
                        num_characters=characters_slider.value,
                        num_locations=locations_slider.value,
                        num_events=events_slider.value,
                        num_items=items_slider.value,
                        num_relationships=relationships_slider.value
                    )
                else:
                    fresh.expand_universe()

                print(f"\nUniverse '{fresh.universe_name}' created successfully!")
                print(f"\nUniverse description:\n{fresh.universe_summary}")

        create_button.on_click(handle_create)

        def handle_generate_feed(_button):
            """Generate and print a feed for the current universe."""
            with feed_log:
                clear_output()
                current = state["universe"]
                if current is None:
                    print("Please create a universe first")
                    return

                print(f"Generating social media feed for {current.universe_name}...")
                posts = FeedGenerator(current).generate_feed(post_count_slider.value)

                print(f"\n==== Social Media Feed from {current.universe_name} ====\n")
                for entry in posts:
                    entry.display_post()
                    if entry.referenced_entities:
                        print(f"Referenced entities: {', '.join(entry.referenced_entities)}")
                    print("\n" + "-"*50 + "\n")

        feed_button.on_click(handle_generate_feed)

        def handle_visualize(_button):
            """Display the knowledge graph of the current universe."""
            with graph_log:
                clear_output()
                current = state["universe"]
                if current is None:
                    print("Please create a universe first")
                    return

                print(f"Visualizing knowledge graph for {current.universe_name}...")
                display(current.knowledge_graph.visualize())

        graph_button.on_click(handle_visualize)

        # One VBox per tab
        create_tab = widgets.VBox([
            name_box,
            populate_toggle,
            slider_box,
            create_button,
            create_log
        ])
        feed_tab = widgets.VBox([post_count_slider, feed_button, feed_log])
        graph_tab = widgets.VBox([graph_button, graph_log])

        # Assemble the tab container and label each page
        tabs = widgets.Tab()
        tabs.children = [create_tab, feed_tab, graph_tab]
        for position, label in enumerate(['Create Universe', 'Generate Feed', 'Visualize']):
            tabs.set_title(position, label)

        display(tabs)

    except ImportError:
        print("ipywidgets is required for the UI. Install with: pip install ipywidgets")
        return None

### Example universes

In [None]:
def create_example_universes():
    """Build the three showcase universes and return them keyed by name.

    Each one is created through ``create_universe_from_name`` with
    auto-population enabled.
    """
    print("Creating example universes...")

    # Fantasy, sci-fi and urban-fantasy settings, in that order
    showcase_names = ("Eldoria", "Nova Prism", "Hidden Hollows")

    examples = {}
    for name in showcase_names:
        examples[name] = create_universe_from_name(name, auto_populate=True)

    print("Example universes created!")
    return examples

### Main App

In [None]:
# Informational banner only: lists the three ways to drive the notebook.
# Nothing is selected or executed here.
print(
    "This notebook demonstrates generating social media feeds based on fictional universes using knowledge graphs and RAG.",
    "The system can automatically expand a universe from just a name!",
    "\nOptions:",
    "1. Create a universe from scratch with a name",
    "2. Use example universes (Eldoria, Nova Prism, Hidden Hollows)",
    "3. Use the interactive UI to create and explore universes",
    sep="\n",
)


Enhanced RAG-based Social Media Feed Generator
This notebook demonstrates generating social media feeds based on fictional universes using knowledge graphs and RAG.
The system can automatically expand a universe from just a name!

Options:
1. Create a universe from scratch with a name
2. Use example universes (Eldoria, Nova Prism, Hidden Hollows)
3. Use the interactive UI to create and explore universes


In [None]:
# Option 1 (disabled): build one custom universe from a name, print a feed
# and visualize it. Uncomment these three lines to run it:
#universe = create_universe_from_name("Star Wars during the Clone Wars", auto_populate=True)
#generate_demo_feed(universe, num_posts=8)
#visualize_universe(universe)

# Option 2 (active): build all three example universes, then print an
# 8-post feed for "Eldoria".
examples = create_example_universes()
generate_demo_feed(examples["Eldoria"], num_posts=8)
# Last expression in the cell, so its return value becomes the cell output
visualize_universe(examples["Eldoria"])

# Option 3 (disabled): widget-based UI; uncomment to use it.
# create_universe_ui()  # Interactive UI for creating and exploring universes


Creating example universes...
Initializing LLM...
Loading model - this might take a minute...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cpu
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Model loaded successfully!
LLM initialized successfully!
Expanding universe concept...
