In [1]:
import asyncio
import datetime as dt
import os
from uuid import uuid4

from dotenv import load_dotenv
from graphiti_core import Graphiti
from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig
from graphiti_core.llm_client.gemini_client import GeminiClient, LLMConfig
from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode
from pydantic import BaseModel, Field

# Google API key configuration.
# The key is read from the environment (optionally populated from a local
# .env file) so that no secret is stored in the notebook. Any key that was
# previously committed here should be treated as compromised and rotated.
load_dotenv()
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    raise RuntimeError(
        "GOOGLE_API_KEY is not set; export it or add it to a .env file "
        "before running this notebook."
    )

# Neo4j connection settings. The defaults reproduce the original local
# development setup; override them through the environment elsewhere.
neo4j_uri = os.environ.get("NEO4J_URI", "neo4j://127.0.0.1:7687")
neo4j_user = os.environ.get("NEO4J_USER", "neo4j")
neo4j_password = os.environ.get("NEO4J_PASSWORD", "citygraph")

# Initialize Graphiti with Gemini clients for the LLM, the embedder and the
# cross-encoder (reranker).
graphiti = Graphiti(
    neo4j_uri,
    neo4j_user,
    neo4j_password,
    llm_client=GeminiClient(
        config=LLMConfig(
            api_key=api_key,
            model="gemini-2.0-flash"
        )
    ),
    embedder=GeminiEmbedder(
        config=GeminiEmbedderConfig(
            api_key=api_key,
            embedding_model="embedding-001"
        )
    ),
    cross_encoder=GeminiRerankerClient(
        config=LLMConfig(
            api_key=api_key,
            model="gemini-2.5-flash-lite-preview-06-17"
        )
    )
)

# Now you can use Graphiti with Google Gemini for all components

In [2]:
# Build the graph database indices and constraints through the Graphiti client.
await graphiti.build_indices_and_constraints()

In [3]:
class Area(BaseModel):
    """A named area of the city, located by latitude and longitude."""

    area_name: str = Field(..., description="Name of the area")
    lat: float = Field(..., description="Latitude of the area")
    lon: float = Field(..., description="Longitude of the area")

class Event(BaseModel):
    """An event, described by its type, free-text description and optional severity."""

    event_type: str = Field(..., description="Type of the event")
    # Free-text description of the event (required).
    description: str
    # Severity label; optional, defaults to None when not given.
    severity: str | None = None

class Source(BaseModel):
    """A named source of information."""

    # Name of the source (required).
    source_name: str

# === Define Custom Edge Schemas === #
class HappenedAt(BaseModel):
    """Edge connecting an Event to the Area where it happened (Event -> Area)."""

    # Edge name; defaults to "happened_at".
    edge_name: str = Field(default="happened_at", description="Edge name connecting Event to Area")

class ReportedBy(BaseModel):
    """Edge connecting an Event to the Source that reported it (Event -> Source)."""

    # Edge name; defaults to "reported_by".
    edge_name: str = Field(default="reported_by", description="Edge name connecting Event to Source")

class Nearby(BaseModel):
    """Connects Area -> Area"""

    # The docstring must be the first statement of the class body; placed after
    # the field it was a stray string expression and the class had no __doc__.
    # Distance between the two areas in kilometres; optional.
    distance_km: float | None = None  # Optional

In [4]:
# example_data = [
#     {
#         "event": {
#             "event_type": "weather",
#             "description": "Heavy rain in Koramangala",
#             "severity": "high"
#         },
#         "area": {
#             "area_name": "Koramangala",
#             "lat": 12.9352,
#             "lon": 77.6245
#         },
#         "source_name": "weatherapi",
#         "nearby_areas": [
#             {"area_name": "BTM Layout", "lat": 12.9166, "lon": 77.6101}
#         ]
#     },
#     {
#         "event": {
#             "event_type": "traffic_jam",
#             "description": "Waterlogging traffic in BTM",
#             "severity": "medium"
#         },
#         "area": {
#             "area_name": "BTM Layout",
#             "lat": 12.9166,
#             "lon": 77.6101
#         },
#         "source_name": "twitter",
#         "nearby_areas": [
#             {"area_name": "Koramangala", "lat": 12.9352, "lon": 77.6245}
#         ]
#     }
# ]

In [5]:
# Custom entity models, keyed by entity type name.
entity_types = dict(Event=Event, Area=Area, Source=Source)

# Custom edge models, keyed by edge type name.
edge_types = dict(
    HAPPENED_AT=HappenedAt,
    REPORTED_BY=ReportedBy,
    NEARBY=Nearby,
)

# Maps a (source entity type, target entity type) pair to a list of edge type names.
edge_type_map = {
    ("Event", "Area"): ["HAPPENED_AT"],
    ("Event", "Source"): ["REPORTED_BY"],
    ("Area", "Area"): ["NEARBY"],
}

# Fresh random identifier generated each time this cell runs.
group_id = str(uuid4())


In [7]:
import datetime

In [8]:
await graphiti.add_episode(
    name="City Update",
    episode_body="There was a heavy rain event reported in Koramangala, classified with high severity, according to data from weatherapi. This event also notes BTM Layout as a nearby affected area, suggesting possible regional impact. Additionally, a traffic jam was reported in BTM Layout, caused by waterlogging, and marked with medium severity. This information was sourced from Twitter, and it identifies Koramangala as a nearby area, indicating a possible connection or shared impact zone between the two neighborhoods.",
    source_description="Traffic and weather updates",
    entity_types=entity_types,
    edge_types=edge_types,
    reference_time=datetime.datetime.now(),
    edge_type_map=edge_type_map
)

AddEpisodeResults(episode=EpisodicNode(uuid='0207168b-c7c3-4a8a-800a-a62eb05082e7', name='City Update', group_id='', labels=[], created_at=datetime.datetime(2025, 7, 21, 7, 10, 52, 960287, tzinfo=datetime.timezone.utc), source=<EpisodeType.message: 'message'>, source_description='Traffic and weather updates', content='There was a heavy rain event reported in Koramangala, classified with high severity, according to data from weatherapi. This event also notes BTM Layout as a nearby affected area, suggesting possible regional impact. Additionally, a traffic jam was reported in BTM Layout, caused by waterlogging, and marked with medium severity. This information was sourced from Twitter, and it identifies Koramangala as a nearby area, indicating a possible connection or shared impact zone between the two neighborhoods.', valid_at=datetime.datetime(2025, 7, 21, 12, 40, 52, 960287), entity_edges=['3f514d73-f43c-42d0-86ea-b14a60b9d283', '76f8d7f2-bbf1-4617-b049-af64d0e031fc', '2177601a-06d2-4

In [None]:
# Body text of the second episode (a power outage report); passed to
# graphiti.add_episode in the next cell. Note the literal ends with a newline.
ep2 = """A widespread power outage occurred in Koramangala, reportedly due to a transformer failure during ongoing maintenance work. The outage, which lasted for several hours, caused disruption in both residential and commercial zones and is considered a high-impact event. Local reports and resident complaints were aggregated from the BESCOM monitoring app. BTM Layout and HSR Layout have also experienced partial blackouts and voltage fluctuations during this period, suggesting the outage might have had a cascading effect on the surrounding grid network. Field engineers were dispatched to multiple sites, and emergency power systems were temporarily activated in some hospitals and tech parks.
"""

In [74]:
await graphiti.add_episode(
    name="City Update",
    episode_body= ep2,
    source_description="Power updates",
    reference_time=datetime.now(),
    entity_types=entity_types,
    edge_types=edge_types,
    edge_type_map=edge_type_map
)

In [79]:
# Query the graph with a natural-language question; the printing cells below
# iterate over `results` and read uuid / fact / valid_at / invalid_at from each item.
results = await graphiti.search('What are reasons related to power outage?')

In [80]:
# Print every search result as a small text record terminated by '---'.
for result in results:
    lines = [f'UUID: {result.uuid}', f'Fact: {result.fact}']
    # The validity bounds are shown only when present and non-empty.
    if getattr(result, 'valid_at', None):
        lines.append(f'Valid from: {result.valid_at}')
    if getattr(result, 'invalid_at', None):
        lines.append(f'Valid until: {result.invalid_at}')
    lines.append('---')
    print('\n'.join(lines))

In [None]:
# for i, data in enumerate(example_data):
#         event_id = f"event_{i}"
#         area_id = data["area"]["area_name"]
#         source_id = data["source_name"]

#         # Register RawEpisode
#         episode = RawEpisode(
#             name=f"Event {i}",
#             content=data["event"]["description"],
#             reference_time="2025-07-18T10:00:00Z",
#             source=EpisodeType.message,
#             source_description="City Feed Ingestion"
#         )

#         await graphiti.add_episode(
#             name=f"Event {i}",
#             episode_body=episode.content,
#             reference_time=episode.reference_time,
#             source_description=episode.source_description,
#             group_id=group_id,
#             entity_types=entity_types,
#             edge_types=edge_types,
#             edge_type_map=edge_type_map
#         )


In [None]:
# results = await graphiti.search('weather in Koramangala?')

In [None]:
# results

In [40]:
# Dump the current `results` list: identifier, fact, then whichever validity
# bounds are set, followed by a separator line.
for result in results:
    print(f'UUID: {result.uuid}')
    print(f'Fact: {result.fact}')
    for label, attribute in (('Valid from', 'valid_at'), ('Valid until', 'invalid_at')):
        bound = getattr(result, attribute, None)
        if bound:
            print(f'{label}: {bound}')
    print('---')

In [51]:
import os
import re
from datetime import datetime, timedelta, timezone

from pydantic import BaseModel


class Speaker(BaseModel):
    """A transcript speaker, looked up by the numeric index used in message headers."""

    # Numeric speaker index as it appears at the start of a message header.
    index: int
    # Display name of the speaker.
    name: str
    # Role label, e.g. 'Host' or 'Guest'.
    role: str


class ParsedMessage(BaseModel):
    """One transcript message produced by parse_conversation_file."""

    # Numeric speaker index taken from the message header.
    speaker_index: int
    # Speaker name; 'Unknown Speaker <index>' when the index has no Speaker entry.
    speaker_name: str
    # Speaker role; 'Unknown' when the index has no Speaker entry.
    role: str
    # Timestamp text from the header with the surrounding parentheses stripped.
    relative_timestamp: str
    # Computed podcast start time plus the parsed relative offset.
    actual_timestamp: datetime
    # Message text with leading and trailing whitespace removed.
    content: str


def parse_timestamp(timestamp: str) -> timedelta:
    """Convert a relative transcript timestamp into a ``timedelta``.

    Accepted components, each optional and in this order, are hours (``h``),
    minutes (``m``) and seconds (``s``), with optional whitespace between
    them: ``'45s'``, ``'5m'``, ``'5m 30s'``, ``'1h 2m 3s'``.

    Previously only minutes and seconds were recognised, so any timestamp
    with an hour component silently became a zero offset.

    Args:
        timestamp: Timestamp text without the surrounding parentheses.

    Returns:
        The parsed duration, or a zero ``timedelta`` when nothing can be parsed.
    """
    # Every component is optional, so this pattern always matches; text that
    # holds no recognised component simply yields three empty groups.
    match = re.match(r'\s*(?:(\d+)h)?\s*(?:(\d+)m)?\s*(?:(\d+)s)?', timestamp)
    hours, minutes, seconds = (int(group) if group else 0 for group in match.groups())
    return timedelta(hours=hours, minutes=minutes, seconds=seconds)


def _split_message_block(message: str) -> tuple[int, str, str] | None:
    """Split one transcript block into (speaker_index, timestamp, text).

    The first line must look like ``<speaker index> <timestamp>: <text>``.
    Returns None for blocks without such a header.
    """
    lines = message.strip().split('\n')
    header, separator, first_text = lines[0].partition(':')
    if not separator:
        return None
    header_parts = header.split()
    if len(header_parts) < 2:
        return None
    try:
        speaker_index = int(header_parts[0])
    except ValueError:
        # Not a speaker header (e.g. a prose line that happens to contain ':').
        return None
    # Keep every token after the speaker index so that a timestamp such as
    # "(1m 3s)" is not cut down to "1m".
    timestamp = ' '.join(header_parts[1:]).strip('()')
    text = '\n'.join([first_text, *lines[1:]]).strip()
    return speaker_index, timestamp, text


def parse_conversation_file(file_path: str, speakers: list[Speaker]) -> list[ParsedMessage]:
    """Parse a transcript file into a list of ``ParsedMessage`` objects.

    The file is split into blocks on blank lines. Each block is expected to
    start with a ``<speaker index> (<timestamp>): <text>`` header; blocks
    without one are skipped. Absolute timestamps are computed so that the
    last message falls on the current UTC time.

    Args:
        file_path: Path of the transcript text file (read as UTF-8).
        speakers: Known speakers; indices not listed get an
            ``'Unknown Speaker <index>'`` name and an ``'Unknown'`` role.

    Returns:
        The parsed messages in file order.
    """
    with open(file_path, encoding='utf-8') as file:
        transcript = file.read()

    speaker_dict = {speaker.index: speaker for speaker in speakers}

    parsed_blocks = [
        parsed
        for parsed in (_split_message_block(block) for block in transcript.split('\n\n'))
        if parsed is not None
    ]

    # The last message's offset is taken as the podcast duration; the start
    # time is that far before "now".
    last_timestamp = parse_timestamp(parsed_blocks[-1][1]) if parsed_blocks else timedelta()
    podcast_start_time = datetime.now(timezone.utc) - last_timestamp

    parsed_messages: list[ParsedMessage] = []
    for speaker_index, timestamp, text in parsed_blocks:
        speaker = speaker_dict.get(speaker_index)
        if speaker:
            speaker_name = speaker.name
            role = speaker.role
        else:
            speaker_name = f'Unknown Speaker {speaker_index}'
            role = 'Unknown'

        parsed_messages.append(
            ParsedMessage(
                speaker_index=speaker_index,
                speaker_name=speaker_name,
                role=role,
                relative_timestamp=timestamp,
                actual_timestamp=podcast_start_time + parse_timestamp(timestamp),
                content=text,
            )
        )

    return parsed_messages


def parse_podcast_messages(file_path: str | None = None):
    """Parse the podcast transcript with the known speaker list.

    Args:
        file_path: Path of the transcript file. When omitted, the location
            is resolved exactly as before, from the original author's
            project directory, so existing calls keep working; pass a path
            to run the notebook on another machine.

    Returns:
        The list of parsed messages returned by ``parse_conversation_file``.
    """
    if file_path is None:
        # Legacy default kept for backward compatibility.
        script_dir = os.path.dirname("D:\\city-graph\\city-graph\\podcast_transcript.txt")
        file_path = os.path.join(script_dir, 'podcast_transcript.txt')

    speakers = [
        Speaker(index=0, name='Stephen DUBNER', role='Host'),
        Speaker(index=1, name='Tania Tetlow', role='Guest'),
        Speaker(index=4, name='Narrator', role='Narrator'),
        Speaker(index=5, name='Kamala Harris', role='Quoted'),
        Speaker(index=6, name='Unknown Speaker', role='Unknown'),
        Speaker(index=7, name='Unknown Speaker', role='Unknown'),
        Speaker(index=8, name='Unknown Speaker', role='Unknown'),
        Speaker(index=10, name='Unknown Speaker', role='Unknown'),
    ]

    parsed_conversation = parse_conversation_file(file_path, speakers)
    print(f'Number of messages: {len(parsed_conversation)}')
    return parsed_conversation

In [52]:
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import logging
import os
import sys
from uuid import uuid4

from dotenv import load_dotenv
from pydantic import BaseModel, Field


from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data

# Load environment variables via python-dotenv (presumably from a local .env file).
load_dotenv()


def setup_logging():
    """Configure the root logger to write INFO-level records to stdout.

    The function is idempotent: the console handler is attached only once,
    so calling it again (for example by re-running a notebook cell) no
    longer duplicates every log line.

    Returns:
        The root logger.
    """
    handler_name = 'notebook_stdout_console'

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Skip the setup when an earlier call already installed our handler.
    if any(handler.get_name() == handler_name for handler in logger.handlers):
        return logger

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.set_name(handler_name)
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(console_handler)

    return logger


class Person(BaseModel):
    """A human person, fictional or nonfictional."""

    # Each field accepts None but is declared with `...` (no default), so a
    # value must still be supplied explicitly.
    first_name: str | None = Field(..., description='First name')
    last_name: str | None = Field(..., description='Last name')
    occupation: str | None = Field(..., description="The person's work occupation")


class IsPresidentOf(BaseModel):
    """Relationship between a person and the entity they are a president of.

    The model declares no fields of its own.
    """


async def test(use_bulk: bool = False):
    """Ingest messages 3-13 of the podcast transcript into the graph.

    Args:
        use_bulk: When True, submit all messages with a single
            ``add_episode_bulk`` call. Otherwise add them one at a time,
            passing the uuids of up to three previously retrieved episodes
            of the same group as ``previous_episode_uuids``.

    NOTE(review): ``clear_data`` runs first on the shared ``graphiti``
    client; this presumably deletes everything ingested by earlier cells.
    Confirm before running against a graph you want to keep.
    """
    setup_logging()
    client = graphiti
    await clear_data(client.driver)
    await client.build_indices_and_constraints()
    messages = parse_podcast_messages()[3:14]
    group_id = str(uuid4())

    # One schema shared by both ingestion paths. The per-episode path used to
    # map to 'PRESIDENT_OF', a name that is not registered in edge_types.
    podcast_entity_types = {'Person': Person}
    podcast_edge_types = {'IS_PRESIDENT_OF': IsPresidentOf}
    podcast_edge_type_map = {('Person', 'Entity'): ['IS_PRESIDENT_OF']}

    if use_bulk:
        raw_episodes: list[RawEpisode] = [
            RawEpisode(
                name=f'Message {i}',
                content=f'{message.speaker_name} ({message.role}): {message.content}',
                reference_time=message.actual_timestamp,
                source=EpisodeType.message,
                source_description='Podcast Transcript',
            )
            for i, message in enumerate(messages)
        ]
        await client.add_episode_bulk(
            raw_episodes,
            group_id=group_id,
            entity_types=podcast_entity_types,
            edge_types=podcast_edge_types,
            edge_type_map=podcast_edge_type_map,
        )
        return

    for i, message in enumerate(messages):
        episodes = await client.retrieve_episodes(
            message.actual_timestamp, 3, group_ids=[group_id]
        )
        episode_uuids = [episode.uuid for episode in episodes]

        await client.add_episode(
            name=f'Message {i}',
            episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
            reference_time=message.actual_timestamp,
            source_description='Podcast Transcript',
            group_id=group_id,
            entity_types=podcast_entity_types,
            edge_types=podcast_edge_types,
            edge_type_map=podcast_edge_type_map,
            previous_episode_uuids=episode_uuids,
        )

In [None]:
# Run the podcast ingestion with use_bulk=False, i.e. one add_episode call per message.
# Note that test() calls clear_data on the shared graphiti client before ingesting.
await test(False)  # Run the test function without bulk mode