In [1]:
import numpy as np
import pandas as pd
import geopandas as gpd

# Process the data

In [3]:
# move_dt = pd.read_csv("data/foood_mob_cover.csv")
# move_dt = move_dt.assign(
#     who = (move_dt['quitaz_1'] * 8e3 + move_dt['quitaz_2']).astype(int),
#     seq = move_dt['seiqd'],
#     lon_o = (move_dt['mean_dur'] - 200)/1e3 + 114,
#     lon_d = (move_dt['std_dur'] - 200)/1e3 + 114,
#     lat_o = (move_dt['mean_volm'])/10 + 22.65,
#     lat_d = (move_dt['std_volm'])/10 + 22.65,
#     date = move_dt['date'],
#     poi_o = move_dt['district_o'],
#     poi_d = move_dt['district_d'],
#     inplace=True
# ) 

# select_columns = ['who', 'seq', 'lon_o', 'lat_o', 'lon_d', 'lat_d', 'date', 'poi_o', 'poi_d']
# move_dt = move_dt[select_columns].copy()
# # sort the values of the move_dt dataframe by who, date and seq
# move_dt = move_dt.sort_values(by=['who', 'date', 'seq']).reset_index(drop=True)
# move_dt.to_csv("data/processed_moves.csv", index=False)

In [4]:
# move_dt = pd.read_csv("data/processed_moves.csv")
# persons = move_dt.who.drop_duplicates().reset_index(drop=True)
# persons.to_csv("data/processed_users.csv", index=False)

In [2]:
# process stay data
# read all the data as series, because they only contain one column
t_start = pd.read_csv('data/st/reptoire.csv').iloc[:, 0]
t_end = pd.read_csv('data/st/nif.csv').iloc[:, 0]
ptype = pd.read_csv('data/st/model.csv').iloc[:, 0]
poi = pd.read_csv('data/st/iop.csv').iloc[:, 0]
who = pd.read_csv('data/st/est.csv').iloc[:, 0]
date = pd.read_csv('data/st/aoz.csv').iloc[:, 0]
lon_p1 = pd.read_csv('data/st/mean_log_p1.csv').iloc[:, 0]
lon_p2 = pd.read_csv('data/st/mean_log_p2.csv').iloc[:, 0]
lat_p1 = pd.read_csv('data/st/std_log_p1.csv').iloc[:, 0]
lat_p2 = pd.read_csv('data/st/std_log_p2.csv').iloc[:, 0]

# concatenate the longitude and latitude parts
lon = pd.concat([lon_p1, lon_p2], ignore_index=True)
lat = pd.concat([lat_p1, lat_p2], ignore_index=True)

st_data = pd.DataFrame({
    't_start': t_start,
    't_end': t_end,
    'ptype': ptype,
    'poi': poi,
    'who': who,
    'date': date,
    'lon': lon,
    'lat': lat
})

In [3]:
st_sample = st_data.assign(
    # turn to the time stamp by adding the base time 1677600000
    who = st_data['who'].astype(int),
    date = st_data['date'].astype(int) + 20202020,
    t_start = pd.to_datetime(st_data['t_start'] + 1677600000, unit='s'),
    t_end = pd.to_datetime(st_data['t_end'] + 1677600000, unit='s'),
    lon = st_data['lon'] / 4,
    lat = st_data['lat'] * 4,
    ptype = st_data['ptype'].astype(int),
    poi = st_data['poi'].astype(int)
)

st_sample = st_sample[['who', 'date', 't_start', 't_end', 'lon', 'lat', 'ptype', 'poi']].\
    sort_values(by=['who', 'date', 't_start'])


In [4]:
# ====== Stay Record Processing Algorithm ======
def merge_consecutive_stays(df, gap_minutes=30):
    """
    Merge consecutive stay records with the same location and gap <= gap_minutes
    OR same location AND adjacent calendar dates
    
    Parameters:
    - df: DataFrame sorted by who and t_start
    - gap_minutes: Maximum gap threshold (minutes)
    
    Returns:
    - Merged DataFrame
    """
    if len(df) == 0:
        return df.copy()
    
    # Calculate time gap to next record (in minutes)
    time_diff = (df['t_start'].shift(-1) - df['t_end']).dt.total_seconds() / 60
    
    # Check if location changed (use np.isclose for floating-point comparison)
    lon_same = np.isclose(df['lon'].shift(-1), df['lon'], rtol=1e-9)
    lat_same = np.isclose(df['lat'].shift(-1), df['lat'], rtol=1e-9)
    pos_same = lon_same & lat_same
    
    # Check if dates are adjacent
    date_adjacent = (df['date'].shift(-1) - df['date']) == 1
    
    # Determine if merge is needed:
    # Option 1: same location AND time gap <= threshold
    # Option 2: same location AND adjacent calendar dates
    need_merge = pos_same & ((time_diff <= gap_minutes) | date_adjacent)
    need_merge.iloc[-1] = False  # Last record doesn't need merge
    
    # Create group identifier: new group whenever merge is not needed
    group_id = (~need_merge).cumsum()
    
    # Aggregate by group
    # Add group_id as a column first to avoid FutureWarning
    df = df.copy()
    df['group_id'] = group_id
    merged = df.groupby(['who', 'group_id'], as_index=False).agg({
        't_start': 'min',
        't_end': 'max',
        'lon': 'first',
        'lat': 'first',
        'ptype': 'first',
        'poi': 'first',
        'date': 'first'
    })
    merged = merged.drop(columns=['group_id'])
    
    return merged


def filter_short_stays(df, min_minutes=30):
    """
    Filter out stay records shorter than min_minutes
    
    Parameters:
    - df: DataFrame
    - min_minutes: Minimum stay duration (minutes)
    
    Returns:
    - Filtered DataFrame
    """
    stay_duration = (df['t_end'] - df['t_start']).dt.total_seconds() / 60
    return df[stay_duration >= min_minutes].reset_index(drop=True)


# ====== Processing Pipeline ======
print(f"Original data: {len(st_sample)} records")

# Step 1: Group by who (data is already sorted by who, date, t_start)
st_processed = st_sample.copy()

# Step 2: Merge consecutive stays for each individual
def process_single_person(group):
    """Process data for a single individual"""
    return merge_consecutive_stays(group, gap_minutes=30)

st_processed = st_processed.groupby('who', group_keys=False).apply(process_single_person)
print(f"After Step 2 (merge consecutive stays): {len(st_processed)} records")

# Step 3: Filter out stays shorter than 30 minutes
st_processed = filter_short_stays(st_processed, min_minutes=30)
print(f"After Step 3 (filter short stays): {len(st_processed)} records")

# Step 4: Merge again
st_processed = st_processed.groupby('who', group_keys=False).apply(process_single_person)
print(f"After Step 4 (merge again): {len(st_processed)} records")

# Re-sort
st_processed = st_processed.sort_values(by=['who', 't_start']).reset_index(drop=True)

st_processed

Original data: 720427 records


  st_processed = st_processed.groupby('who', group_keys=False).apply(process_single_person)


After Step 2 (merge consecutive stays): 652777 records
After Step 3 (filter short stays): 575138 records


  st_processed = st_processed.groupby('who', group_keys=False).apply(process_single_person)


After Step 4 (merge again): 567555 records


Unnamed: 0,who,t_start,t_end,lon,lat,ptype,poi,date
0,126272,2019-01-01 00:28:58,2019-01-01 10:11:49,113.833530,22.689639,1,0,20190101
1,126272,2019-01-01 10:19:15,2019-01-01 17:13:30,113.948257,22.529631,0,3,20190101
2,126272,2019-01-01 17:26:58,2019-01-01 18:47:28,113.889444,22.773309,0,14,20190101
3,126272,2019-01-01 18:59:06,2019-01-01 21:26:17,113.867473,22.573858,0,11,20190101
4,126272,2019-01-02 10:05:07,2019-01-02 19:41:13,114.095561,22.553550,0,2,20190102
...,...,...,...,...,...,...,...,...
567550,78623283,2019-12-30 11:00:45,2019-12-30 18:46:13,113.912155,22.534095,2,0,20191230
567551,78623283,2019-12-30 19:58:40,2019-12-30 23:15:45,113.908330,22.518993,1,1,20191230
567552,78623283,2019-12-31 09:02:22,2019-12-31 18:30:06,114.026052,22.625212,0,9,20191231
567553,78623283,2019-12-31 18:30:40,2019-12-31 20:05:57,113.945381,22.556100,2,0,20191231


# Format the trajectories

In [6]:
# ====== Step 1 & 2: Random Jitter & HDBSCAN Clustering ======
import hdbscan
from typing import Optional
import plotly.express as px
from pyproj import Transformer
from shapely.geometry import Point

# Random jitter function: uniform random offset within specified radius (area-uniform)
def jitter_within_radius(xy: np.ndarray, max_radius_m: float, rng: np.random.Generator) -> np.ndarray:
    """
    Apply uniform random jitter within a circle of specified radius.
    Uses sqrt(r) for radius to ensure uniform area distribution.
    """
    # Check: the dimension of the input array should be 2
    if xy.shape[1] != 2:
        raise ValueError("Input array should have exactly 2 columns (x, y)")
    n = xy.shape[0]
    angles = rng.uniform(0.0, 2.0 * np.pi, size=n)
    radii = max_radius_m * np.sqrt(rng.uniform(0.0, 1.0, size=n))
    offsets = np.column_stack([radii * np.cos(angles), radii * np.sin(angles)])
    return xy + offsets


def add_jitter_and_cluster(
    df: pd.DataFrame, 
    jitter_radius_m: float = 300, 
    min_cluster_size: int = 6, 
    min_samples: int = 6,
    seed: int = 114514
) -> pd.DataFrame:
    """
    Apply 300m random jitter to coordinates, then perform HDBSCAN clustering.
    
    This function:
    1. Converts lon/lat to UTM coordinates (EPSG:32650)
    2. Applies random jitter within specified radius
    3. Converts back to WGS84 lon/lat (updating original coordinates)
    4. Performs HDBSCAN clustering on the jittered coordinates
    
    Cluster labels:
    - 0: Missing coordinates (outside study area)
    - -1: HDBSCAN noise points
    - 1+: HDBSCAN clusters (original HDBSCAN labels shifted by +1)
    
    Parameters:
    - df: DataFrame with 'who', 'lon', 'lat' columns
    - jitter_radius_m: Radius for random jitter in meters
    - min_cluster_size: HDBSCAN parameter
    - min_samples: HDBSCAN parameter
    - seed: Random seed for reproducibility
    
    Returns:
    - DataFrame with additional 'cluster_id' column (0 = missing, -1 = noise, 1+ = clusters)
    """
    RNG = np.random.default_rng(seed)
    df = df.copy()
    
    def cluster_one_person(person_df: pd.DataFrame) -> pd.DataFrame:
        """Cluster stays for a single person using HDBSCAN.
        
        Returns:
        - cluster_id = 0: missing coordinates (outside study area)
        - cluster_id = -1: HDBSCAN noise
        - cluster_id = 1+: HDBSCAN clusters (original label + 1)
        """
        person_df = person_df.copy()
        
        # Check for missing coordinates
        has_coords = ~person_df[['lon', 'lat']].isna().any(axis=1)
        valid_mask = has_coords.values
        valid_count = valid_mask.sum()
        total_count = len(person_df)
        
        # Initialize all cluster_ids to 0 (missing coordinates)
        person_df['cluster_id'] = 0
        
        # If no valid coordinates at all, skip clustering
        if valid_count == 0:
            print(f"  User {person_df['who'].iloc[0]}: All {total_count} records have missing coordinates, skipping clustering")
            return person_df
        
        # If some records have missing coordinates, report stats
        if valid_count < total_count:
            missing_count = total_count - valid_count
            print(f"  User {person_df['who'].iloc[0]}: {missing_count} missing, {valid_count} valid coordinates")
        
        # Get valid records
        valid_indices = person_df[has_coords].index
        valid_lon = person_df.loc[valid_indices, 'lon'].values.copy()
        valid_lat = person_df.loc[valid_indices, 'lat'].values.copy()
        
        # UTM conversion (EPSG:32650 = UTM Zone 50N, covers Shenzhen)
        transformer = Transformer.from_crs("EPSG:4326", "EPSG:32650", always_xy=True)
        inverse_transformer = Transformer.from_crs("EPSG:32650", "EPSG:4326", always_xy=True)
        
        # Convert lon/lat to UTM coordinates (meters)
        utm_x, utm_y = transformer.transform(valid_lon, valid_lat)
        
        # Apply random jitter to UTM coordinates
        xy_valid = np.column_stack([utm_x, utm_y])
        xy_jittered = jitter_within_radius(xy_valid, max_radius_m=jitter_radius_m, rng=RNG)
        
        # Convert back to WGS84 lon/lat
        jittered_lon, jittered_lat = inverse_transformer.transform(xy_jittered[:, 0], xy_jittered[:, 1])
        
        # Update original lon/lat with jittered coordinates
        person_df.loc[valid_indices, 'lon'] = jittered_lon
        person_df.loc[valid_indices, 'lat'] = jittered_lat
        
        # HDBSCAN clustering
        clusterer = hdbscan.HDBSCAN(
            min_cluster_size=min_cluster_size, 
            min_samples=min_samples,
            metric='euclidean'
        )
        labels_valid = clusterer.fit_predict(xy_jittered)
        
        # Shift labels for valid records:
        # HDBSCAN returns -1 for noise, 0, 1, 2... for clusters
        # We want to keep -1 as noise, shift 0闂?, 1闂?, 2闂?...
        # So we add 1 to non-noise labels only
        labels_shifted = np.where(labels_valid == -1, -1, labels_valid + 1)
        
        # Assign shifted labels back to valid records
        person_df.loc[valid_indices, 'cluster_id'] = labels_shifted
        
        return person_df
    
    # Group by 'who' and apply clustering
    print("Applying random jitter (300m) and HDBSCAN clustering...")
    print(f"Total users: {df['who'].nunique()}, Total records: {len(df)}")
    df_clustered = df.groupby('who', group_keys=False).apply(cluster_one_person)
    
    return df_clustered


# Apply clustering to processed data
print("Applying random jitter (300m) and HDBSCAN clustering...")
st_clustered = add_jitter_and_cluster(
    st_processed, 
    jitter_radius_m=300, 
    min_cluster_size=6, 
    min_samples=10
)

# Display clustering results
print(f"\nClustering complete!")
print(f"Total records: {len(st_clustered)}")
print(f"Missing coordinates (cluster_id=0): {(st_clustered['cluster_id'] == 0).sum()}")
print(f"HDBSCAN noise (cluster_id=-1): {(st_clustered['cluster_id'] == -1).sum()}")
print(f"Records in clusters (cluster_id>=1): {(st_clustered['cluster_id'] >= 1).sum()}")
print(f"Number of unique clusters: {st_clustered[st_clustered['cluster_id'] >= 1]['cluster_id'].nunique()}")

st_clustered.head(10)


Applying random jitter (300m) and HDBSCAN clustering...
Applying random jitter (300m) and HDBSCAN clustering...
Total users: 441, Total records: 567555
  User 395753: 1 missing, 1199 valid coordinates
  User 2436270: 5 missing, 1385 valid coordinates
  User 2590264: 3 missing, 1588 valid coordinates
  User 3087457: 2 missing, 1023 valid coordinates
  User 3549549: 2 missing, 1409 valid coordinates
  User 3803542: 1 missing, 1028 valid coordinates
  User 3932007: 1 missing, 1169 valid coordinates
  User 4717957: 1 missing, 751 valid coordinates
  User 4975553: 1 missing, 1271 valid coordinates
  User 6071595: 3 missing, 1640 valid coordinates
  User 6854307: 1 missing, 1322 valid coordinates
  User 7477809: 5 missing, 2034 valid coordinates
  User 7823042: 1 missing, 1262 valid coordinates
  User 8614296: 2 missing, 1388 valid coordinates
  User 12521378: 7 missing, 1696 valid coordinates
  User 13976993: 3 missing, 1176 valid coordinates
  User 14921919: 1 missing, 1372 valid coordinat

  df_clustered = df.groupby('who', group_keys=False).apply(cluster_one_person)


Unnamed: 0,who,t_start,t_end,lon,lat,ptype,poi,date,cluster_id
0,126272,2019-01-01 00:28:58,2019-01-01 10:11:49,113.834643,22.69071,1,0,20190101,3
1,126272,2019-01-01 10:19:15,2019-01-01 17:13:30,113.948529,22.530949,0,3,20190101,20
2,126272,2019-01-01 17:26:58,2019-01-01 18:47:28,113.889141,22.773277,0,14,20190101,9
3,126272,2019-01-01 18:59:06,2019-01-01 21:26:17,113.86688,22.573279,0,11,20190101,-1
4,126272,2019-01-02 10:05:07,2019-01-02 19:41:13,114.098086,22.554694,0,2,20190102,26
5,126272,2019-01-02 19:56:04,2019-01-02 22:37:02,114.097994,22.598955,0,9,20190102,10
6,126272,2019-01-03 00:46:34,2019-01-03 07:51:17,113.831582,22.688753,1,0,20190103,3
7,126272,2019-01-03 08:11:55,2019-01-03 14:11:15,113.842961,22.607576,2,1,20190103,24
8,126272,2019-01-03 14:11:19,2019-01-03 16:40:00,113.829167,22.737915,0,10,20190103,-1
9,126272,2019-01-03 16:53:00,2019-01-03 21:15:07,114.023939,22.531267,0,2,20190103,22


In [None]:
# ====== Step 3: Interactive Map Visualization ======
def visualize_user_clusters(
    df: pd.DataFrame, 
    user_id: int, 
    zoom: int = 12,
    height: int = 600
) -> px.scatter_mapbox:
    """
    Create an interactive map showing cluster locations for a specific user.
    
    Parameters:
    - df: DataFrame with 'who', 'lon', 'lat', 'cluster_id' columns
    - user_id: User ID to visualize
    - zoom: Initial zoom level
    - height: Map height in pixels
    
    Returns:
    - Plotly Express scatter_mapbox figure
    """
    user_data = df[df['who'] == user_id].copy()
    
    if user_data.empty:
        raise ValueError(f"No data found for user {user_id}")
    
    # Add cluster label for hover (missing vs noise vs cluster)
    # cluster_id = 0: Missing coordinates (outside study area)
    # cluster_id = -1: HDBSCAN noise
    # cluster_id = 1+: HDBSCAN clusters
    user_data['cluster_label'] = user_data['cluster_id'].apply(
        lambda x: f"Missing (Outside Study Area)" if x == 0 
                  else (f"HDBSCAN Noise" if x == -1 
                  else f"Cluster {x}")
    )
    
    # Define custom colors: missing=red, noise=gray, clusters=auto
    color_discrete_map = {
        'Missing (Outside Study Area)': 'red',
        'HDBSCAN Noise': 'gray',
    }
    
    # Get unique cluster labels for ordering
    cluster_labels = [f'Cluster {i}' for i in sorted(user_data[user_data['cluster_id'] >= 1]['cluster_id'].unique())]
    
    # Create scatter mapbox with custom colors
    fig = px.scatter_mapbox(
        user_data,
        lat='lat',
        lon='lon',
        color='cluster_label',
        color_discrete_map=color_discrete_map,
        category_orders={'cluster_label': ['Missing (Outside Study Area)', 'HDBSCAN Noise'] + cluster_labels},
        hover_data={
            'who': True,
            't_start': True,
            't_end': True,
            'cluster_id': True,
            'lat': ':.5f',
            'lon': ':.5f'
        },
        title=f"Stay locations for User {user_id}",
        zoom=zoom,
        height=height,
        size_max=15
    )
    
    fig.update_layout(mapbox_style="carto-positron")
    fig.update_layout(margin={"r": 0, "t": 40, "l": 0, "b": 0})
    
    return fig


# Example: Visualize clusters for a specific user
# Change user_id to visualize different users
inspect_user = st_clustered['who'].iloc[0]  # First user as example
print(f"Visualizing clusters for user {inspect_user}...")
fig = visualize_user_clusters(st_clustered, user_id=inspect_user, zoom=12, height=500)
fig.show()

# Visualize summary statistics per user
# New cluster_id meanings: 0 = missing, -1 = noise, 1+ = clusters
print("\nCluster statistics per user:")
user_stats = st_clustered.groupby('who').agg({
    'cluster_id': [
        'count',                                          # total_stays
        lambda x: (x == 0).sum(),                         # missing_stays
        lambda x: (x != 0).sum(),                          # valid_stays (not missing)
        lambda x: (x == -1).sum(),                         # noise_stays
        lambda x: (x >= 1).sum(),                          # clustered_stays
        lambda x: (x[x >= 1].nunique() if (x >= 1).any() else 0)  # num_clusters
    ]
}).reset_index()
user_stats.columns = ['who', 'total_stays', 'missing_stays', 'valid_stays', 'noise_stays', 'clustered_stays', 'num_clusters']
print(user_stats.head(10))

# Summary
print(f"\n=== Overall Statistics ===")
print(f"Total records: {len(st_clustered)}")
print(f"Missing coordinates (cluster_id=0): {(st_clustered['cluster_id'] == 0).sum()}")
print(f"Valid coordinates: {(st_clustered['cluster_id'] != 0).sum()}")
print(f"  - HDBSCAN noise (cluster_id=-1): {(st_clustered['cluster_id'] == -1).sum()}")
print(f"  - In clusters (cluster_id>=1): {(st_clustered['cluster_id'] >= 1).sum()}")
print(f"Number of unique clusters: {st_clustered[st_clustered['cluster_id'] >= 1]['cluster_id'].nunique()}")


In [7]:

# ====== Step 4 & 5: OOP Data Models & Conversion ======
from dataclasses import dataclass, field
from typing import List, Dict, Optional
from datetime import datetime, timedelta


@dataclass
class Visit:
    """
    Represents a single stay/visit at a location.
    
    Attributes:
    - t_start: Start timestamp of the visit
    - t_end: End timestamp of the visit
    - lon: Longitude of the location
    - lat: Latitude of the location
    - cluster_id: Cluster label (0 = missing/outside study area, -1 = HDBSCAN noise, 1+ = HDBSCAN clusters)
    - ptype: Place type (optional)
    - poi: POI information (optional)
    """
    t_start: datetime
    t_end: datetime
    lon: float
    lat: float
    cluster_id: int = 0  # 0 = missing, -1 = noise, 1+ = clusters
    ptype: Optional[int] = None
    poi: Optional[int] = None
    
    @property
    def duration(self) -> timedelta:
        """Calculate duration of the visit."""
        return self.t_end - self.t_start
    
    @property
    def date(self) -> int:
        """Extract date in YYYYMMDD format from t_start."""
        return int(self.t_start.strftime('%Y%m%d'))
    
    @classmethod
    def from_dict(cls, data: dict) -> 'Visit':
        """Create Visit from dictionary."""
        return cls(
            t_start=pd.to_datetime(data['t_start']),
            t_end=pd.to_datetime(data['t_end']),
            lon=float(data['lon']),
            lat=float(data['lat']),
            cluster_id=int(data.get('cluster_id', -1)),
            ptype=int(data['ptype']) if pd.notna(data.get('ptype')) else None,
            poi=int(data['poi']) if pd.notna(data.get('poi')) else None
        )
    
    def to_dict(self) -> dict:
        """Convert Visit to dictionary."""
        return {
            't_start': self.t_start,
            't_end': self.t_end,
            'lon': self.lon,
            'lat': self.lat,
            'cluster_id': self.cluster_id,
            'ptype': self.ptype,
            'poi': self.poi
        }


@dataclass
class Trajectory:
    """
    Represents a daily trajectory (from 3:00 AM to next day 3:00 AM).
    
    Attributes:
    - date: Date of the trajectory (YYYYMMDD format, representing the day starting at 3 AM)
    - visits: List of Visit objects
    """
    date: int
    visits: List[Visit] = field(default_factory=list)
    
    @property
    def num_visits(self) -> int:
        """Return number of visits."""
        return len(self.visits)
    
    @property
    def start_time(self) -> Optional[datetime]:
        """Return start time of first visit."""
        return self.visits[0].t_start if self.visits else None
    
    @property
    def end_time(self) -> Optional[datetime]:
        """Return end time of last visit."""
        return self.visits[-1].t_end if self.visits else None
    
    def add_visit(self, visit: Visit):
        """Add a visit to the trajectory."""
        self.visits.append(visit)
        # Keep visits sorted by start time
        self.visits.sort(key=lambda v: v.t_start)
    
    def to_dict(self) -> dict:
        """Convert trajectory to dictionary."""
        return {
            'date': self.date,
            'visits': [v.to_dict() for v in self.visits]
        }


class User:
    """
    Represents a user with their trajectories and memory.
    
    Attributes:
    - id: User identifier
    - trajectories: Dictionary mapping date to Trajectory
    - memory: User's memory (to be implemented)
    """
    
    # Reference time: 3:00 AM threshold for day boundaries
    DAY_THRESHOLD_HOUR = 3
    
    def __init__(self, id: int):
        """
        Initialize a User.
        
        Parameters:
        - id: User identifier
        """
        self.id = id
        self.trajectories: Dict[int, Trajectory] = {}
        self.memory: Optional[Dict] = None
    
    @property
    def num_trajectories(self) -> int:
        """Return number of trajectories."""
        return len(self.trajectories)
    
    @property
    def total_visits(self) -> int:
        """Return total number of visits across all trajectories."""
        return sum(t.num_visits for t in self.trajectories.values())
    
    @property
    def unique_clusters(self) -> set:
        """Return set of unique cluster IDs visited by this user.
        
        Note: cluster_id >= 1 are actual clusters (0 = missing, -1 = noise)
        """
        clusters = set()
        for traj in self.trajectories.values():
            for visit in traj.visits:
                if visit.cluster_id >= 1:  # Only count actual clusters (not noise=-1 or missing=0)
                    clusters.add(visit.cluster_id)
        return clusters
    
    def add_trajectory(self, trajectory: Trajectory):
        """Add a trajectory to the user."""
        self.trajectories[trajectory.date] = trajectory
    
    def get_trajectory(self, date: int) -> Optional[Trajectory]:
        """Get trajectory for a specific date."""
        return self.trajectories.get(date)
    
    def get_all_visits(self) -> List[Visit]:
        """Get all visits across all trajectories."""
        visits = []
        for traj in self.trajectories.values():
            visits.extend(traj.visits)
        return sorted(visits, key=lambda v: v.t_start)
    
    def to_dataframe(self) -> pd.DataFrame:
        """Convert user data to DataFrame."""
        records = []
        for date, traj in self.trajectories.items():
            for visit in traj.visits:
                records.append({
                    'who': self.id,
                    'date': date,
                    't_start': visit.t_start,
                    't_end': visit.t_end,
                    'lon': visit.lon,
                    'lat': visit.lat,
                    'cluster_id': visit.cluster_id,
                    'ptype': visit.ptype,
                    'poi': visit.poi
                })
        return pd.DataFrame(records)
    
    @classmethod
    def from_dataframe(cls, df: pd.DataFrame) -> 'User':
        """
        Create User objects from a DataFrame.
        
        Handles splitting visits that cross midnight (3 AM threshold).
        """
        # Ensure datetime columns
        df = df.copy()
        df['t_start'] = pd.to_datetime(df['t_start'])
        df['t_end'] = pd.to_datetime(df['t_end'])
        
        users_dict: Dict[int, User] = {}
        
        for _, row in df.iterrows():
            who = int(row['who'])
            
            if who not in users_dict:
                users_dict[who] = User(id=who)
            
            user = users_dict[who]
            
            # Process visits that may cross the 3 AM boundary
            visits_to_add = cls._process_crossing_visits(row)
            
            for visit, visit_date in visits_to_add:
                # Get or create trajectory for this date
                if visit_date not in user.trajectories:
                    user.trajectories[visit_date] = Trajectory(date=visit_date)
                
                user.trajectories[visit_date].add_visit(visit)
        
        return users_dict
    
    @staticmethod
    def _process_crossing_visits(row: pd.Series) -> List[tuple]:
        """
        Process a visit record that may cross midnight (3 AM threshold).
        
        The daily trajectory is defined as: from 3:00 AM today to 3:00 AM tomorrow.
        A timestamp is assigned to its "active day" based on when the person was likely active:
        - Before 3 AM: belongs to previous calendar day (stayed up late)
        - After 3 AM: belongs to current calendar day
        
        This method splits a visit into segments that fit within individual active days.
        
        Handles complex cases:
        1. Visit crosses a single 3 AM threshold
        2. Visit crosses multiple 3 AM thresholds (e.g., spanning multiple days)
        3. Visit starts after 3 AM and ends after 3 AM but on different days
        
        Returns:
        - List of (Visit, date) tuples, where date is YYYYMMDD of the active day
        """
        threshold_hour = User.DAY_THRESHOLD_HOUR
        result = []
        
        t_start = pd.to_datetime(row['t_start'])
        t_end = pd.to_datetime(row['t_end'])
        
        # Create initial visit with original attributes
        visit = Visit(
            t_start=t_start,
            t_end=t_end,
            lon=float(row['lon']),
            lat=float(row['lat']),
            cluster_id=int(row.get('cluster_id', -1)),
            ptype=int(row['ptype']) if pd.notna(row.get('ptype')) else None,
            poi=int(row['poi']) if pd.notna(row.get('poi')) else None
        )
        
        # Calculate threshold dates for start and end
        # The threshold date represents the "active day" for a given timestamp
        # This is based on the intuition that:
        # - Early morning (e.g., 2 AM) belongs to the previous calendar day (staying up late)
        # - Late night (e.g., 4 AM) belongs to the current calendar day
        #
        # A trajectory day is: from 3:00 AM today to 3:00 AM tomorrow
        # e.g., 2019-01-01 02:00 belongs to threshold date 2018-12-31 (stayed up late)
        # e.g., 2019-01-01 04:00 belongs to threshold date 2019-01-01
        def get_threshold_date(dt: pd.Timestamp) -> int:
            """Get the threshold date (YYYYMMDD) for a given datetime."""
            if dt.hour >= threshold_hour:
                # After 3 AM: belongs to current calendar day
                threshold_date = dt.date()
            else:
                # Before 3 AM: belongs to previous calendar day (stayed up late)
                threshold_date = (dt - timedelta(days=1)).date()
            return int(threshold_date.strftime('%Y%m%d'))
        
        start_threshold_date = get_threshold_date(t_start)
        end_threshold_date = get_threshold_date(t_end)
        
        # If no threshold crossing, no splitting needed
        if start_threshold_date == end_threshold_date:
            result.append((visit, start_threshold_date))
            return result
        
        # Visit spans multiple threshold dates - need to split
        # Find the first threshold that is AFTER t_start
        # If t_start is already past today's threshold, first threshold is tomorrow
        first_threshold = t_start.replace(hour=threshold_hour, minute=0, second=0)
        if t_start >= first_threshold:
            first_threshold = first_threshold + timedelta(days=1)
        
        # Part 1: From t_start to first threshold (if t_start is before first threshold)
        if t_start < first_threshold:
            visit1 = Visit(
                t_start=t_start,
                t_end=first_threshold,
                lon=visit.lon,
                lat=visit.lat,
                cluster_id=visit.cluster_id,
                ptype=visit.ptype,
                poi=visit.poi
            )
            # Segment ends before first threshold, so it belongs to previous trajectory
            # But we use start_threshold_date which correctly handles this
            result.append((visit1, start_threshold_date))
        
        # Middle parts: full threshold-to-threshold segments
        # These represent full days where the person was stationary
        current_threshold = first_threshold
        while current_threshold < t_end:
            next_threshold = current_threshold + timedelta(days=1)
            if next_threshold > t_end:
                break
            
            # Create a visit representing the full stationary day
            middle_visit = Visit(
                t_start=current_threshold,
                t_end=next_threshold,
                lon=visit.lon,
                lat=visit.lat,
                cluster_id=visit.cluster_id,
                ptype=visit.ptype,
                poi=visit.poi
            )
            # This segment crosses midnight, so belongs to the day after current_threshold starts
            # Which is exactly what get_threshold_date(t_end) gives us
            current_threshold_date = get_threshold_date(current_threshold + timedelta(hours=12))
            result.append((middle_visit, current_threshold_date))
            
            current_threshold = next_threshold
        
        # Last part: from last threshold to t_end
        last_threshold = current_threshold
        if t_end > last_threshold:
            last_visit = Visit(
                t_start=last_threshold,
                t_end=t_end,
                lon=visit.lon,
                lat=visit.lat,
                cluster_id=visit.cluster_id,
                ptype=visit.ptype,
                poi=visit.poi
            )
            # Use the end time to determine the trajectory date
            last_threshold_date = get_threshold_date(t_end)
            result.append((last_visit, last_threshold_date))
        
        return result
    
    def __repr__(self) -> str:
        return f"User(id={self.id}, trajectories={self.num_trajectories}, total_visits={self.total_visits})"


def convert_dataframe_to_users(df: pd.DataFrame) -> Dict[int, User]:
    """
    Convert a DataFrame to a dictionary of User objects.
    
    Parameters:
    - df: DataFrame with columns ['who', 'date', 't_start', 't_end', 'lon', 'lat', 'cluster_id', 'ptype', 'poi']
    
    Returns:
    - Dictionary mapping user_id to User objects
    """
    print(f"Converting DataFrame with {len(df)} records to User objects...")
    users = User.from_dataframe(df)
    print(f"Created {len(users)} User objects")
    return users


# ====== Execute Step 5: Convert DataFrame to User Model ======
print("\n" + "="*60)
print("Converting DataFrame to User Model...")
print("="*60)

users_dict = convert_dataframe_to_users(st_clustered)

# Display sample users
print("\nSample users:")
for i, (user_id, user) in enumerate(list(users_dict.items())[:3]):
    print(f"\n{user}")
    print(f"  - Unique clusters visited: {len(user.unique_clusters)}")
    print(f"  - Date range: {min(user.trajectories.keys())} to {max(user.trajectories.keys())}")

# Convert back to DataFrame to verify
print("\nVerification: Converting User objects back to DataFrame...")
test_user_id = list(users_dict.keys())[0]
test_user = users_dict[test_user_id]
test_df = test_user.to_dataframe()
print(f"User {test_user_id} converted back to DataFrame: {len(test_df)} records")
print(test_df.head())


Converting DataFrame to User Model...
Converting DataFrame with 567555 records to User objects...
Created 441 User objects

Sample users:

User(id=126272, trajectories=350, total_visits=1320)
  - Unique clusters visited: 33
  - Date range: 20181231 to 20191231

User(id=278978, trajectories=346, total_visits=1182)
  - Unique clusters visited: 25
  - Date range: 20181231 to 20191231

User(id=395753, trajectories=360, total_visits=1537)
  - Unique clusters visited: 36
  - Date range: 20181231 to 20191231

Verification: Converting User objects back to DataFrame...
User 126272 converted back to DataFrame: 1320 records
      who      date             t_start               t_end         lon  \
0  126272  20181231 2019-01-01 00:28:58 2019-01-01 03:00:00  113.834643   
1  126272  20190101 2019-01-01 03:00:00 2019-01-01 10:11:49  113.834643   
2  126272  20190101 2019-01-01 10:19:15 2019-01-01 17:13:30  113.948529   
3  126272  20190101 2019-01-01 17:26:58 2019-01-01 18:47:28  113.889141   
4  

# Model-free RL modeling

## Simple TD learning

In [8]:
# Reference: 251111_rl_demo.ipynb - MF modeling approach
# Enhanced with: time encoding, day sequence, forgetting rate, end-of-day action (-9), delayed state inclusion
import time
from scipy.special import logsumexp
from scipy.optimize import minimize
from collections import defaultdict
from typing import Dict, List, Tuple, Optional, Any, Set
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import warnings


In [None]:
# =============================================================================
# MF Model Configuration
# =============================================================================

@dataclass
class MFConfig:
    """Configuration for Model-Free RL estimation."""
    # Learning rate for TD updates (alpha)
    alpha_init: float = 0.1
    # Softmax inverse temperature (beta)
    beta_init: float = 1.0
    # Exploration probability (epsilon)
    epsilon_init: float = 0.1
    # Forgetting rate (phi) - discount applied when day changes
    phi_init: float = 0.1  # Will be estimated as sigmoid(logit_phi)
    # Reward type: 'linear', 'power', 'log'
    reward_type: str = 'log'
    # Reward parameter (for power/log functions)
    reward_param_init: float = 1.0
    # Visit threshold before adding to known set
    visit_threshold: int = 3
    # Maximum iterations for optimization
    maxiter: int = 1000
    # Convergence tolerance
    ftol: float = 1e-6
    # Reference date for day sequence calculation
    ref_date: Tuple[int, int, int] = (2018, 12, 31)


# =============================================================================
# Data Preparation Functions
# =============================================================================

def compute_day_sequence(date_array: np.ndarray[int], 
                        ref_date: int = None) -> np.ndarray:
    """
    Compute day sequence as days since reference date.
    
    Parameters:
    - date_array: Array of dates in YYYYMMDD format
    - ref_year, ref_month, ref_day: Reference date
    
    Returns:
    - Array of day sequences (integers, monotonically increasing)
    """
    if ref_date is None:
        ref_date = datetime(2018, 12, 31)
    else:
        # convert the YYYYMMDD integer string to datetime
        ref_date = datetime.strptime(str(ref_date), '%Y%m%d')
    day_seq = np.zeros(len(date_array), dtype=int)
    for i, d in enumerate(date_array):
        year = d // 10000
        month = (d // 100) % 100
        day = d % 100
        current_date = datetime(year, month, day)
        day_seq[i] = (current_date - ref_date).days
    
    return day_seq


def compute_time_angle(time_val: datetime,
                       angle_base_hour = 3) -> float:
    """
    Compute time angle scaled to [0, 1] where:
    - 0.0 = same day's 3 AM
    - 1.0 = next day's 3 AM (24 hours later)
    
    Example: 4 AM -> 1/24, Noon (12 PM) -> 9/24 = 0.375
    
    Parameters:
    - time_val: Datetime of stay end (departure time)
    
    Returns:
    - Time angle in [0, 1]
    """
    # 3 AM as the reference "start" of the day
    
    hour = time_val.hour
    minute = time_val.minute
    second = time_val.second

    # Compute total seconds since 3 AM; wrap around 24 hours
    seconds_in_day = 24 * 3600
    time_seconds = hour * 3600 + minute * 60 + second
    base_seconds = angle_base_hour * 3600

    delta = (time_seconds - base_seconds) % seconds_in_day  # wrap negative values to next day
    time_angle = delta / seconds_in_day  # scale to 0~1
    return time_angle

In [10]:
def compute_reward_array(stay_minutes: np.ndarray, 
                           reward_type: str = 'log', 
                           reward_param: Optional[float] = None) -> np.ndarray:
    """
    Compute reward array based on stay duration (time-based rewards).
    """
    base = np.maximum(stay_minutes / 30.0, 0.0)
    
    if reward_type == 'linear' or reward_param is None:
        rewards = base
    elif reward_type == 'power':
        rewards = np.power(base, reward_param)
    elif reward_type == 'log':
        # Normalized log reward: 0 to 1 range
        rewards = np.log1p(reward_param * base) / np.log1p(reward_param)
    else:
        rewards = base
    
    return np.nan_to_num(rewards, nan=0.0, posinf=0.0, neginf=0.0)


def compute_time_discount_factor(time_angle: float) -> float:
    """
    Compute time discount factor based on time angle.
    
    The factor is: 1 / (1 - time_angle)
    This scales Q values so that values at different times of day are comparable.
    
    Example:
    - time_angle = 0 (3 AM): factor = 1.0
    - time_angle = 0.75 (9 PM): factor = 4.0
    - time_angle = 0.99 (2:58 AM next day): factor = 100.0
    
    Parameters:
    - time_angle: Time of day in [0, 1] scale
    
    Returns:
    - Discount factor for time scaling
    """
    if 1.0 - time_angle <= 0:
        factor = 100.0
    else:
        factor = 1.0 / (1.0 - time_angle)
        if factor > 100.0:
            factor = 100.0
    return factor

In [None]:
def prepare_mf_data(user_df: pd.DataFrame,
                   config: MFConfig = None) -> Dict[str, Any]:
    """
    Prepare MF modeling data from user trajectory.
    
    This function:
    1. Computes day sequence (days since reference date)
    2. Computes time angles for each stay
    3. Labels actions: -9 (end-of-day), -1 (explore), 0+ (cluster)
    
    Parameters:
    - user_df: DataFrame with columns ['t_start', 't_end', 'cluster_id', 'date']
    - config: MF configuration
    
    Returns:
    - Dictionary with prepared data for MF modeling
    """
    if config is None:
        config = MFConfig()
    
    # Sort by timestamp
    df = user_df.sort_values('t_start').reset_index(drop=True)
    
    # Extract basic data
    n_records = len(df)
    
    # Compute day sequence
    if 'date' not in df.columns:
        raise ValueError("Input DataFrame must contain 'date' column for MF model preparation.")
    date_array = df['date'].to_numpy()
    day_seq = compute_day_sequence(date_array, 
                                   config.ref_date[0], 
                                   config.ref_date[1], 
                                   config.ref_date[2])

    
    # Compute time angles (0-1 scale from 3 AM to next 3 AM)
    time_angles = np.zeros(n_records)
    for i, t_end in enumerate(df['t_end']):
        time_angles[i] = compute_time_angle(t_end)
    
    # Extract states (cluster_ids)
    states = df['cluster_id'].astype(int).to_numpy()
    
    # Compute stay durations for rewards
    # note, here the reward should be the next one, not current one.
    stay_minutes = (df['t_end'] - df['t_start']).dt.total_seconds() / 60.0
    stay_minutes = stay_minutes.to_numpy()
    stay_minutes = np.roll(stay_minutes, -1)
    stay_minutes[-1] = 0
    
    # Label actions
    # -9: end-of-day action (when next day_seq < current day_seq or last record)
    # -1: explore new location
    # 0+: transition to specific cluster
    actions = np.zeros(n_records, dtype=int)
    actions[-1] = -9
    for t in range(n_records - 1):
        current_day = day_seq[t]
        next_day = day_seq[t + 1]
        
        if next_day > current_day:
            # Transition to next day -> end-of-day action
            actions[t] = -9
        else:
            # Normal transition
            next_state = states[t + 1]
            if next_state == -1:
                # Transition to noise/unknown location -> explore action
                actions[t] = -1
            elif next_state == 0:
                # Transition to outside study area
                actions[t] = 0
            else:
                # Transition to known cluster
                actions[t] = next_state
    
    # Compute same-day next indicator for TD bootstrapping
    same_day_next = np.zeros(n_records, dtype=bool)
    if n_records > 2:
        for t in range(n_records - 1):
            if day_seq[t] == day_seq[t + 1]: 
                same_day_next[t] = True
    
    return {
        'states': states,
        'actions': actions,
        'day_seq': day_seq,
        'time_angles': time_angles,
        'stay_minutes': stay_minutes,
        'date_array': date_array,
        'n_records': n_records, 
        'same_day_next': same_day_next,
        'df': df
    }


In [13]:

# =============================================================================
# Parameter Handling
# =============================================================================

def unpack_params_mf(theta: np.ndarray, 
                     feature_dim: int = 0, 
                     has_reward_param: bool = True) -> Dict[str, float]:
    """
    Unpack MF model parameters from optimization vector.
    
    Parameters:
    - theta: Parameter vector
    - feature_dim: Dimension of feature weights (for future extensions)
    - has_reward_param: Whether reward_param is included
    
    Returns:
    - Dictionary with alpha, beta, epsilon, phi, reward_param
    """
    idx = 0
    
    # Feature weights (placeholder for future extensions)
    w = theta[:feature_dim] if feature_dim > 0 else np.array([])
    idx += feature_dim
    
    # TD learning rate: alpha = sigmoid(logit_alpha)
    logit_alpha = theta[idx]; idx += 1
    alpha = 1.0 / (1.0 + np.exp(-logit_alpha))
    
    # Softmax temperature: beta = exp(log_beta)
    log_beta = theta[idx]; idx += 1
    beta = np.exp(log_beta)
    
    # Exploration probability: epsilon = sigmoid(logit_epsilon)
    logit_epsilon = theta[idx]; idx += 1
    epsilon = 1.0 / (1.0 + np.exp(-logit_epsilon))
    
    # Forgetting rate: phi = sigmoid(logit_phi) in [0, 1)
    logit_phi = theta[idx]; idx += 1
    phi = 1.0 / (1.0 + np.exp(-logit_phi))
    
    result = {
        'w': w,
        'alpha': alpha,
        'beta': beta,
        'epsilon': epsilon,
        'phi': phi
    }
    
    if has_reward_param:
        logit_reward = theta[idx]; idx += 1
        reward_param = np.exp(logit_reward)
        result['reward_param'] = reward_param
    
    return result


def pack_params_mf(alpha: float, beta: float, epsilon: float, phi: float,
                   reward_param: Optional[float] = None,
                   feature_dim: int = 0) -> np.ndarray:
    """
    Pack MF model parameters into optimization vector.
    
    Parameters:
    - alpha: TD learning rate (0, 1)
    - beta: Softmax inverse temperature (0, +inf)
    - epsilon: Exploration probability (0, 1)
    - phi: Forgetting rate (0, 1)
    - reward_param: Reward function parameter
    - feature_dim: Dimension of feature weights
    
    Returns:
    - Parameter vector
    """
    params = []
    
    if feature_dim > 0:
        params.extend([0.0] * feature_dim)
    
    params.append(np.log(alpha / (1.0 - alpha)))  # logit_alpha
    params.append(np.log(beta))                   # log_beta
    params.append(np.log(epsilon / (1.0 - epsilon)))  # logit_epsilon
    params.append(np.log(phi / (1.0 - phi)))     # logit_phi
    
    if reward_param is not None:
        params.append(np.log(reward_param))       # log_reward_param
    
    return np.array(params, dtype=np.float64)


In [None]:
# =============================================================================
# MF Simulation and Log-Likelihood
# =============================================================================

def simulate_and_loglik_mf(theta: np.ndarray,
                          mf_data: Dict[str, Any],
                          feature_dim: int = 0,
                          reward_type: str = 'log',
                          has_reward_param: bool = True,
                          visit_threshold: int = 3) -> float:
    """
    Compute negative log-likelihood for enhanced MF (TD) model.
    
    Features:
    - TD learning with time scaling
    - Forgetting rate between days
    - Delayed state inclusion (after visit_threshold visits)
    - Actions: -9 (end), -1 (explore), 0+ (clusters)
    
    Parameters:
    - theta: Parameter vector
    - mf_data: Prepared MF data dictionary
    - feature_dim: Dimension of feature weights
    - reward_type: Reward function type
    - has_reward_param: Whether reward_param is included
    - visit_threshold: Visits required before adding to known set
    
    Returns:
    - Negative log-likelihood
    """
    # Unpack parameters
    params = unpack_params_mf(theta, feature_dim, has_reward_param)
    alpha = params['alpha']
    beta = params['beta']
    epsilon = params['epsilon']
    phi = params['phi']
    reward_param = params.get('reward_param', 1.0)
    
    # Extract data
    states = mf_data['states']
    actions = mf_data['actions']
    day_seq = mf_data['day_seq']
    time_angles = mf_data['time_angles']
    n_records = mf_data['n_records']
    same_day_next = mf_data['same_day_next']
    
    # Compute reward array
    reward_array = compute_reward_array(mf_data['stay_minutes'], reward_type, reward_param)
    
    # Q-tables: Q[s][a] -> Q value
    Q_tables: Dict[int, Dict[int, float]] = defaultdict(lambda: defaultdict(float))
    
    # Track visit counts for delayed inclusion
    visit_counts: Dict[int, int] = defaultdict(int)
    
    # Known states and actions (with delayed inclusion)
    known_states: Set[int] = {-1, 0}  # Noise and outside area always known
    known_actions: Set[int] = {-9, -1, 0}  # End, explore, outside always known
    
    loglik = 0.0
    
    # Track previous day for forgetting
    prev_day = None
    
    for t in range(n_records):
        s = int(states[t])
        a = int(actions[t])
        r_t = float(reward_array[t])
        current_day = int(day_seq[t])
        time_angle = float(time_angles[t])
        
        # Apply forgetting when day changes
        if prev_day is not None and current_day > prev_day:
            # New day: apply forgetting discount to all Q values
            discount_factor = (1.0 - phi)
            for state_dict in Q_tables.values():
                for action_key in state_dict:
                    multi_day_discount = discount_factor ** (current_day - prev_day)
                    state_dict[action_key] *= multi_day_discount
        prev_day = current_day
                
        # Update visit counts and known sets
        if a > 0:  # Transition to a cluster
            visit_counts[a] += 1
            if visit_counts[a] >= visit_threshold:
                known_states.add(a)
                known_actions.add(a)

        s_perc = s if s in known_states else -1
        a_perc = a if a in known_actions else -1

        # Apply time scaling to current Q values
        # time_scale = compute_time_discount_factor(time_angle)
        time_scale = 1
        
        # Build action list for current state
        # Actions available: -9 (end), -1 (explore), 0 (outside), plus visited clusters
        # Note: -1 (explore) is not selected through Q-function, but through epsilon,
        # so exclude -1 from evaluated_actions for softmax policy
        evaluated_actions = sorted([act for act in known_actions if act != -1])
        
        q_values = []
        for act in evaluated_actions:
            q_td = Q_tables[s_perc].get(act, 0.0)
            # Scale Q value by time discount
            q_values.append(q_td * time_scale)
        
        q_values = np.asarray(q_values, dtype=np.float64)
        
        # Softmax policy
        logits = beta * q_values
        probs_exploit = np.exp(logits - logsumexp(logits))
        
        # Map action to probability
        if a_perc in evaluated_actions:
            idx_a = evaluated_actions.index(a_perc)
            if idx_a < len(probs_exploit):
                action_prob = (1.0 - epsilon) * probs_exploit[idx_a]
        elif a_perc == -1: 
        # meaning the action is exploration -1.
        # Add exploration probability
            action_prob = epsilon
        else:
            action_prob = 0
            warnings.warn("The action falls out the consideration when evaluation.")

        loglik += np.log(action_prob + 1e-12)
        
        # TD update with time scaling
        # Next Q value also needs time scaling
        if t < n_records - 1 and same_day_next[t]:
            a_next = int(actions[t + 1])
            a_next_perc = a_next if a_next in known_actions else -1

            next_time_angle = float(time_angles[t + 1])
            # next_time_discount = compute_time_discount_factor(next_time_angle)
            next_time_scale = 1
            next_Q_td = Q_tables[a_perc].get(a_next_perc, 0.0) * next_time_scale
        else:
            next_Q_td = 0.0
        
        # TD error (using scaled Q values)
        current_Q = Q_tables[s_perc].get(a_perc, 0.0) * time_scale
        delta = r_t + next_Q_td - current_Q
        
        # Update Q value (unscale before update)
        Q_tables[s_perc][a_perc] = Q_tables[s_perc].get(a_perc, 0.0) + alpha * delta
    
    return -loglik

In [None]:

# =============================================================================
# Model Fitting
# =============================================================================

def fit_mf_model(user_df: pd.DataFrame,
                config: MFConfig = None,
                verbose: bool = True) -> Dict[str, Any]:
    """
    Fit enhanced MF model to a single user's trajectory data.
    
    Parameters:
    - user_df: DataFrame with columns ['t_start', 't_end', 'cluster_id', 'date']
    - config: MF configuration
    - verbose: Whether to print progress
    
    Returns:
    - Dictionary with fitted parameters, statistics, and metadata
    """
    if config is None:
        config = MFConfig()
    
    # Prepare MF data
    mf_data = prepare_mf_data(user_df, config)
    n_records = mf_data['n_records']
    
    if verbose:
        print(f"Preparing MF model for {len(user_df)} visits, {n_records} actions (including end-of-day)...")
    
    # Parameter dimension
    has_reward_param = config.reward_type in ('power', 'log')
    extra_params = 4 + (1 if has_reward_param else 0)  # alpha, beta, epsilon, phi, [reward_param]
    param_dim = extra_params
    
    # Initial parameters
    initial_theta = pack_params_mf(
        alpha=config.alpha_init,
        beta=config.beta_init,
        epsilon=config.epsilon_init,
        phi=config.phi_init,
        reward_param=config.reward_param_init if has_reward_param else None,
        feature_dim=0
    )
    
    # Optimize
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        result = minimize(
            simulate_and_loglik_mf,
            initial_theta,
            args=(mf_data, 0, config.reward_type, has_reward_param, config.visit_threshold),
            method='L-BFGS-B',
            options={'maxiter': config.maxiter, 'ftol': config.ftol}
        )
    
    # Extract fitted parameters
    fitted_params = unpack_params_mf(result.x, 0, has_reward_param)
    final_loglik = -result.fun
    
    # Compute model selection criteria
    k_params = param_dim
    AIC = 2 * k_params - 2 * final_loglik
    BIC = k_params * np.log(n_records) - 2 * final_loglik if n_records > 0 else np.inf
    
    summary = {
        'n_records': int(n_records),
        'log_likelihood': float(final_loglik),
        'AIC': float(AIC),
        'BIC': float(BIC),
        'alpha_td': float(fitted_params['alpha']),
        'beta': float(fitted_params['beta']),
        'epsilon_explore': float(fitted_params['epsilon']),
        'phi_forget': float(fitted_params['phi']),
        'reward_type': config.reward_type,
        'reward_param': float(fitted_params.get('reward_param', None)),
        'visit_threshold': config.visit_threshold,
        'converged': result.success,
        'n_iterations': result.nit if hasattr(result, 'nit') else None,
        'optimization_message': result.message if hasattr(result, 'message') else None
    }
    
    if verbose:
        print(f"MF model fitting {'converged' if result.success else 'did not converge'}.")
        print(f"  Log-likelihood: {final_loglik:.2f}, AIC: {AIC:.2f}, BIC: {BIC:.2f}")
        print(f"  alpha (TD rate): {fitted_params['alpha']:.4f}")
        print(f"  beta (softmax temp): {fitted_params['beta']:.4f}")
        print(f"  epsilon (explore): {fitted_params['epsilon']:.4f}")
        print(f"  phi (forgetting): {fitted_params['phi']:.4f}")
        if has_reward_param:
            print(f"  reward_param: {fitted_params.get('reward_param', None):.4f}")
    
    return summary


def fit_mf_for_all_users(users_dict: Dict[int, Any],
                         config: MFConfig = None,
                         sample_size: Optional[int] = None,
                         verbose: bool = True) -> pd.DataFrame:
    """
    Fit enhanced MF model for all users in the dictionary.
    
    Parameters:
    - users_dict: Dictionary mapping user_id to User objects
    - config: MF configuration
    - sample_size: Number of users to fit (None for all)
    - verbose: Whether to print progress
    
    Returns:
    - DataFrame with fitted parameters for each user
    """
    if config is None:
        config = MFConfig()
    
    user_ids = list(users_dict.keys())
    if sample_size is not None:
        user_ids = user_ids[:sample_size]
    
    results = []
    for i, user_id in enumerate(user_ids):
        if verbose:
            print(f"[{i+1}/{len(user_ids)}] Fitting MF for user {user_id}...", end="")
        
        user = users_dict[user_id]
        user_df = user.to_dataframe()
        
        try:
            t_start = time.time()
            result = fit_mf_model(user_df, config, verbose=False)
            t_end = time.time()
            elapsed = t_end - t_start
            if verbose:
                print(f"\tUser {user_id} MF model fit time: {elapsed:.2f} seconds")
            result['user_id'] = user_id
            result['fit_time_seconds'] = elapsed
            results.append(result)
        except Exception as e:
            if verbose:
                print(f"  Error fitting user {user_id}: {e}")
            results.append({
                'user_id': user_id, 'n_visits': 0, 'n_records': 0,
                'log_likelihood': np.nan, 'AIC': np.nan, 'BIC': np.nan,
                'alpha_td': np.nan, 'beta': np.nan, 'epsilon_explore': np.nan,
                'phi_forget': np.nan, 'reward_type': config.reward_type,
                'reward_param': np.nan, 'visit_threshold': config.visit_threshold,
                'converged': False, 'n_iterations': None, 'optimization_message': str(e),
                'fit_time_seconds': np.nan
            })
    
    return pd.DataFrame(results)



In [None]:

# =============================================================================
# Execute Enhanced MF Modeling
# =============================================================================

print("\n" + "="*60)
print("Enhanced Model-Free (MF) RL Estimation")
print("Features: Time encoding, Day sequence, Forgetting rate,")
print("          End-of-day action (-9), Delayed state inclusion")
print("="*60)

# Configure enhanced MF model
mf_config = MFConfig(
    alpha_init=0.1,
    beta_init=1.0,
    epsilon_init=0.1,
    phi_init=0.1,  # Initial forgetting rate
    reward_type='log',
    reward_param_init=1.0,
    visit_threshold=3,  # Add to known set after 3 visits
    maxiter=1000,
    ftol=1e-6
)

# Fit MF model for sample users
mf_results = fit_mf_for_all_users(
    users_dict, 
    config=mf_config, 
    sample_size=50,  # Change to None for all users
    verbose=True
)

print("\nEnhanced MF Model Fitting Summary:")
print(mf_results[['user_id', 'n_records', 'log_likelihood', 'AIC', 'BIC', 
                  'alpha_td', 'beta', 'epsilon_explore', 'phi_forget']].to_string(index=False))

# Save results
mf_results.to_csv('mf_estimation_results_enhanced.csv', index=False)
print(f"\nEnhanced MF estimation results saved to 'mf_estimation_results_enhanced.csv'")

# Aggregate statistics
print("\nAggregate Statistics Across Users:")
print(f"  Average alpha (TD rate): {mf_results['alpha_td'].mean():.4f} ± {mf_results['alpha_td'].std():.4f}")
print(f"  Average beta (softmax temp): {mf_results['beta'].mean():.4f} ± {mf_results['beta'].std():.4f}")
print(f"  Average epsilon (explore): {mf_results['epsilon_explore'].mean():.4f} ± {mf_results['epsilon_explore'].std():.4f}")
print(f"  Average phi (forgetting): {mf_results['phi_forget'].mean():.4f} ± {mf_results['phi_forget'].std():.4f}")
print(f"  Average log-likelihood: {mf_results['log_likelihood'].mean():.2f} ± {mf_results['log_likelihood'].std():.2f}")



Enhanced Model-Free (MF) RL Estimation
Features: Time encoding, Day sequence, Forgetting rate,
          End-of-day action (-9), Delayed state inclusion
[1/50] Fitting MF for user 126272...
    User 126272 MF model fit time: 16.96 seconds
[2/50] Fitting MF for user 278978...
    User 278978 MF model fit time: 9.96 seconds
[3/50] Fitting MF for user 395753...
    User 395753 MF model fit time: 15.92 seconds
[4/50] Fitting MF for user 506035...
    User 506035 MF model fit time: 12.79 seconds
[5/50] Fitting MF for user 612431...
    User 612431 MF model fit time: 9.39 seconds
[6/50] Fitting MF for user 661336...
    User 661336 MF model fit time: 19.77 seconds
[7/50] Fitting MF for user 824617...



overflow encountered in exp



    User 824617 MF model fit time: 11.81 seconds
[8/50] Fitting MF for user 928827...
    User 928827 MF model fit time: 21.78 seconds
[9/50] Fitting MF for user 1017498...
    User 1017498 MF model fit time: 10.35 seconds
[10/50] Fitting MF for user 1159109...
    User 1159109 MF model fit time: 15.34 seconds
[11/50] Fitting MF for user 1196732...
    User 1196732 MF model fit time: 13.46 seconds
[12/50] Fitting MF for user 1200620...
    User 1200620 MF model fit time: 13.82 seconds
[13/50] Fitting MF for user 1369893...
    User 1369893 MF model fit time: 12.41 seconds
[14/50] Fitting MF for user 1581867...
    User 1581867 MF model fit time: 11.12 seconds
[15/50] Fitting MF for user 1840362...
    User 1840362 MF model fit time: 13.13 seconds
[16/50] Fitting MF for user 2046261...
    User 2046261 MF model fit time: 10.79 seconds
[17/50] Fitting MF for user 2221830...
    User 2221830 MF model fit time: 14.87 seconds
[18/50] Fitting MF for user 2436270...
    User 2436270 MF model 

## MF enhanced with episodic memory

When considering generalization to novel time state, we refer to the approach in area of episodic RL to estimate the new Q-value or SR value. It functions in a non-parametric way: 
- Retrival: kernal weights + recency bias
- Storage: direct input.
Bayesian method (no explicity episodic memory) is not a possible choice, because the value estimation problem is not a stochastic one.

In [36]:
# =============================================================================
# MF Enhanced with Episodic Memory (Time-Augmented)
# =============================================================================

from typing import Dict, List, Tuple, Optional, Any, Set
from dataclasses import dataclass
from datetime import datetime
from collections import defaultdict


@dataclass
class MFEpiConfig:
    """Configuration for MF + episodic memory model."""
    alpha_init: float = 0.1
    beta_init: float = 1.0
    epsilon_init: float = 0.1
    phi_init: float = 0.1
    sigma_t_init: float = 1.0 / 12.0

    reward_type: str = 'log'
    reward_param_init: float = 1.0

    visit_threshold: int = 3
    memory_threshold: float = 0.01

    maxiter: int = 1000
    ftol: float = 1e-6


@dataclass
class EpisodicRecord:
    """One memory trace for a (state, action, time) tuple."""
    q_value: float
    time_angle: float
    day_seq: int
    record_date: int
    strength: float = 1.0


class EpisodicMemory:
    """Non-parametric episodic memory for Q-value retrieval."""

    def __init__(self, phi: float, config: Optional[MFEpiConfig] = None):
        self.config = config if config is not None else MFEpiConfig()
        self.phi = phi
        self.Q_table: Dict[int, Dict[int, List[EpisodicRecord]]] = defaultdict(lambda: defaultdict(list))
        self.Q_decay: Dict[int, Dict[int, float]] = defaultdict(lambda: defaultdict(float))
        self.node_strength: Dict[int, float] = {}
        self.node_visits: Dict[int, int] = {}
        self.last_day: Optional[int] = None
        self._active_nodes: Optional[Set[int]] = None

    def add_record(self, node_id: int, action_id: int, time_angle: float, q_value: float,
                   day_seq: int, record_date: int):
        rec = EpisodicRecord(
            q_value=float(q_value),
            time_angle=float(time_angle),
            day_seq=int(day_seq),
            record_date=int(record_date),
            strength=1.0,
        )
        self.Q_table[node_id][action_id].append(rec)
        self.Q_decay[node_id][action_id] = 1.0
        self.node_strength[node_id] = self.node_strength.get(node_id, 0.0) + 1.0
        self.node_visits[node_id] = self.node_visits.get(node_id, 0) + 1
        self._active_nodes = None

    def decay(self, current_day: int):
        if self.last_day is None:
            self.last_day = current_day
            return

        day_diff = int(current_day - self.last_day)
        if day_diff > 0:
            factor = (1.0 - self.phi) ** day_diff
            for s in self.Q_table:
                for a in self.Q_table[s]:
                    for rec in self.Q_table[s][a]:
                        rec.strength *= factor
            for node_id in self.node_strength:
                self.node_strength[node_id] *= factor
                for action_id in self.Q_decay[node_id]:
                    self.Q_decay[node_id][action_id] *= factor
            self._active_nodes = None

        self.last_day = current_day

    @staticmethod
    def compute_time_similarity(t1: float, t2: float, sigma_t: float) -> float:
        """Circular Gaussian kernel on normalized time angle [0, 1)."""
        diff = abs(float(t1) - float(t2))
        sigma = max(float(sigma_t), 1e-6)
        return float(np.exp(-0.5 * (diff / sigma) ** 2))

    def get_records_for_sa_pair(self, node_id: int, action_id: int) -> List[EpisodicRecord]:
        return self.Q_table[int(node_id)][int(action_id)]

    def retrieve_q(self, target_node: int, target_action: int, target_time: float,
                   sigma_t: Optional[float] = None) -> float:
        """Return Q-hat(target_node, target_action, target_time); returns 0.0 if no evidence."""
        sigma = self.config.sigma_t_init if sigma_t is None else float(sigma_t)
        recs = self.get_records_for_sa_pair(target_node, target_action)
        if not recs:
            return 0.0

        q_values = np.array([rec.q_value for rec in recs], dtype=np.float64)
        strengths = np.array([rec.strength for rec in recs], dtype=np.float64)
        similarities = np.array([self.compute_time_similarity(target_time, rec.time_angle, sigma) for rec in recs])

        weights = similarities * strengths
        q_estimate = np.average(q_values, weights=weights)
        q_estimate *= self.Q_decay[target_node][target_action]

        return q_estimate

    def get_active_nodes(self) -> Set[int]:
        if self._active_nodes is not None:
            return self._active_nodes

        self._active_nodes = {
            node_id
            for node_id, strength in self.node_strength.items()
            if strength >= self.config.memory_threshold
            and self.node_visits.get(node_id, 0) >= self.config.visit_threshold
        }
        return self._active_nodes

In [37]:
# =============================================================================
# Data Preparation and Parameter Utilities
# =============================================================================

def prepare_data(user_df: pd.DataFrame, config: Optional[MFEpiConfig] = None) -> Dict[str, Any]:
    """Prepare trajectory arrays for MF + episodic model."""
    config = config if config is not None else MFEpiConfig()

    df = user_df.sort_values(by=['t_start']).reset_index(drop=True)
    n_records = len(df)

    states = df['cluster_id'].astype(int).to_numpy()
    time_angles = df['t_end'].apply(compute_time_angle).to_numpy(dtype=float)
    date_array = df['date'].to_numpy()

    date_baseline = int(date_array.min())
    day_seq = compute_day_sequence(date_array, date_baseline)

    stay_minutes = (df['t_end'] - df['t_start']).dt.total_seconds() / 60.0
    stay_minutes = stay_minutes.to_numpy(dtype=float)
    stay_minutes = np.roll(stay_minutes, -1)
    stay_minutes[-1] = 0.0

    actions = np.zeros(n_records, dtype=int)
    actions[-1] = -9
    for t in range(n_records - 1):
        if day_seq[t + 1] > day_seq[t]:
            actions[t] = -9
        else:
            actions[t] = int(states[t + 1])


    same_day_next = np.zeros(n_records, dtype=bool)
    for t in range(n_records - 1):
        same_day_next[t] = day_seq[t] == day_seq[t + 1]

    return {
        'states': states,
        'actions': actions,
        'day_seq': day_seq,
        'time_angles': time_angles,
        'date_array': date_array,
        'stay_minutes': stay_minutes,
        'same_day_next': same_day_next,
        'n_records': n_records,
    }


def unpack_params_time_epi(theta: np.ndarray) -> Dict[str, float]:
    """theta = [log_alpha, log_beta, logit_epsilon, logit_phi, log_sigma_t]."""
    idx = 0
    alpha = 1.0 / (1.0 + np.exp(-theta[idx])); idx += 1
    beta = np.exp(theta[idx]); idx += 1
    epsilon = 1.0 / (1.0 + np.exp(-theta[idx])); idx += 1
    phi = 1.0 / (1.0 + np.exp(-theta[idx])); idx += 1

    alpha = float(np.clip(alpha, 1e-6, 1.0))
    epsilon = float(np.clip(epsilon, 1e-6, 1.0 - 1e-6))

    return {
        'alpha': alpha,
        'beta': float(beta),
        'epsilon': epsilon,
        'phi': float(phi),
    }


def pack_params_time_epi(alpha: float, beta: float,
                         epsilon: float, phi: float) -> np.ndarray:
    return np.array([
        np.log(alpha / (1.0 - alpha)),
        np.log(beta),
        np.log(epsilon / (1.0 - epsilon)),
        np.log(phi / (1.0 - phi)),
    ], dtype=float)


In [40]:
# =============================================================================
# Training Objective and Model Fitting
# =============================================================================

def simulate_and_loglik_mfe(theta: np.ndarray,
                            mfe_data: Dict[str, Any],
                            config: Optional[MFEpiConfig] = None) -> float:
    """Negative log-likelihood for MF + episodic memory with TD updates."""
    config = config if config is not None else MFEpiConfig()

    params = unpack_params_time_epi(theta)
    alpha = params['alpha']
    beta = params['beta']
    epsilon = params['epsilon']
    phi = params['phi']


    states = mfe_data['states']
    actions = mfe_data['actions']
    time_angles = mfe_data['time_angles']
    day_seq = mfe_data['day_seq']
    date_array = mfe_data['date_array']
    same_day_next = mfe_data['same_day_next']
    n_records = mfe_data['n_records']

    reward_array = compute_reward_array(
        mfe_data['stay_minutes'],
        config.reward_type,
        config.reward_param_init,
    )

    visit_counts: Dict[int, int] = defaultdict(int)
    known_states: Set[int] = {-1, 0}
    known_actions: Set[int] = {-9, -1, 0}

    memory = EpisodicMemory(phi, config)
    loglik = 0.0

    for t in range(n_records):
        s = int(states[t])
        a = int(actions[t])
        r_t = float(reward_array[t])
        current_day = int(day_seq[t])
        current_date = int(date_array[t])
        time_angle = float(time_angles[t])

        memory.decay(current_day)

        if a > 0:
            visit_counts[a] += 1
            if visit_counts[a] >= config.visit_threshold:
                known_states.add(a)
                known_actions.add(a)

        s_perc = s if s in known_states else -1
        a_perc = a if a in known_actions else -1

        evaluated_actions = sorted([act for act in known_actions if act != -1])
        if len(evaluated_actions) == 0:
            action_prob = np.clip(epsilon, 1e-12, 1.0)
        else:
            q_values = np.array([
                memory.retrieve_q(s_perc, act, time_angle, config.sigma_t_init)
                for act in evaluated_actions
            ], dtype=float)

            beta_q = beta * q_values
            beta_q -= np.max(beta_q)
            softmax = np.exp(beta_q)
            softmax /= (np.sum(softmax) + 1e-12)

            probs = (1.0 - epsilon) * softmax
            if a_perc in evaluated_actions:
                idx_a = evaluated_actions.index(a_perc)
                action_prob = float(np.clip(probs[idx_a], 1e-12, 1.0))
            else:
                action_prob = float(np.clip(epsilon, 1e-12, 1.0))

        loglik += np.log(action_prob)

        if t < n_records - 1 and same_day_next[t]:
            a_next = int(actions[t + 1])
            a_next_perc = a_next if a_next in known_actions else -1
            next_time_angle = float(time_angles[t + 1])
            next_q = memory.retrieve_q(a_perc, a_next_perc, next_time_angle, config.sigma_t_init)
        else:
            next_q = 0.0

        if a_perc != -1:
            current_q = memory.retrieve_q(s_perc, a_perc, time_angle, config.sigma_t_init)
            delta = r_t + next_q - current_q
            new_q = current_q + alpha * delta
            memory.add_record(s_perc, a_perc, time_angle, new_q, current_day, current_date)

    return float(-loglik)


def fit_mfe_model(user_df: pd.DataFrame,
                            config: Optional[MFEpiConfig] = None,
                            verbose: bool = True) -> Dict[str, Any]:
    """Fit MF + episodic memory model for one user trajectory."""

    episodic_data = prepare_data(user_df, config)
    n_records = episodic_data['n_records']
    if n_records < 2:
        return {
            'n_records': int(n_records),
            'log_likelihood': np.nan,
            'AIC': np.nan,
            'BIC': np.nan,
            'alpha_td': np.nan,
            'beta': np.nan,
            'epsilon_explore': np.nan,
            'phi_forget': np.nan,
            'converged': False,
            'n_iterations': 0,
            'optimization_message': 'Insufficient records',
        }

    theta_init = pack_params_time_epi(
        alpha=np.clip(config.alpha_init, 1e-6, 1.0),
        beta=max(config.beta_init, 1e-6),
        epsilon=np.clip(config.epsilon_init, 1e-4, 1 - 1e-4),
        phi=np.clip(config.phi_init, 1e-4, 1 - 1e-4),
    )


    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        result = minimize(
            simulate_and_loglik_mfe,
            theta_init,
            args=(episodic_data, config),
            method='L-BFGS-B',
            options={'maxiter': int(config.maxiter), 'ftol': float(config.ftol), 'disp': False},
        )

    fitted = unpack_params_time_epi(result.x)
    log_likelihood = -float(result.fun)

    k_params = 4
    aic = 2 * k_params - 2 * log_likelihood
    bic = k_params * np.log(max(n_records, 1)) - 2 * log_likelihood

    summary = {
        'n_records': int(n_records),
        'log_likelihood': float(log_likelihood),
        'AIC': float(aic),
        'BIC': float(bic),
        'alpha_td': float(fitted['alpha']),
        'beta': float(fitted['beta']),
        'epsilon_explore': float(fitted['epsilon']),
        'phi_forget': float(fitted['phi']),
        'converged': bool(result.success),
        'n_iterations': int(result.nit),
        'optimization_message': str(result.message),
    }

    if verbose:
        print(f"MFE model fitting {'converged' if result.success else 'did not converge'}.")
        print(f"  Log-likelihood: {summary['log_likelihood']:.2f}, AIC: {summary['AIC']:.2f}, BIC: {summary['BIC']:.2f}")
        print(f"  alpha (TD rate): {summary['alpha_td']:.4f}")
        print(f"  beta (softmax temp): {summary['beta']:.4f}")
        print(f"  epsilon (explore): {summary['epsilon_explore']:.4f}")
        print(f"  phi (forgetting): {summary['phi_forget']:.4f}")

    return summary


def fit_mfe_for_all_users(users_dict: Dict[int, Any],
                          config: Optional[MFEpiConfig] = None,
                          sample_size: Optional[int] = None,
                          verbose: bool = True) -> pd.DataFrame:
    """Fit MF+episodic model for all users."""
    config = config if config is not None else MFEpiConfig()

    user_ids = list(users_dict.keys())
    if sample_size is not None:
        user_ids = user_ids[:sample_size]

    results = []
    for i, user_id in enumerate(user_ids):
        if verbose:
            print(f"[{i+1}/{len(user_ids)}] Fitting MFE for user {user_id}...", end='')

        user = users_dict[user_id]
        user_df = user.to_dataframe()

        try:
            t0 = time.time()
            result = fit_mfe_model(user_df, config, verbose=False)
            elapsed = time.time() - t0
            if verbose:
                print(f"\tUser {user_id} MFE model fit time: {elapsed:.2f} seconds")
            result['user_id'] = user_id
            result['fit_time_seconds'] = elapsed
            results.append(result)
        except Exception as e:
            if verbose:
                print(f"  Error fitting user {user_id}: {e}")
            results.append({
                'user_id': user_id,
                'n_records': 0,
                'log_likelihood': np.nan,
                'AIC': np.nan,
                'BIC': np.nan,
                'alpha_td': np.nan,
                'beta': np.nan,
                'epsilon_explore': np.nan,
                'phi_forget': np.nan,
                'converged': False,
                'n_iterations': None,
                'optimization_message': str(e),
                'fit_time_seconds': np.nan,
            })

    return pd.DataFrame(results)


In [42]:
# =============================================================================
# Execute MFE (MF + Episodic Memory) Modeling
# =============================================================================

print("\n" + "=" * 60)
print("MF Enhanced with Episodic Memory Estimation")
print("Features: TD updates + non-parametric episodic Q retrieval")
print("=" * 60)

mfe_config = MFEpiConfig(
    alpha_init=0.1,
    beta_init=1.0,
    epsilon_init=0.1,
    phi_init=0.1,
    sigma_t_init=1.0 / 12.0,
    reward_type='log',
    reward_param_init=1.0,
    visit_threshold=3,
    memory_threshold=0.01,
    maxiter=1000,
    ftol=1e-6,
)

mfe_results = fit_mfe_for_all_users(
    users_dict,
    config=mfe_config,
    sample_size=10,
    verbose=True,
)

print("\nMFE Model Fitting Summary:")
print(mfe_results[['user_id', 'n_records', 'log_likelihood', 'AIC', 'BIC',
                   'alpha_td', 'beta', 'epsilon_explore', 'phi_forget']].to_string(index=False))

mfe_results.to_csv('mfe_estimation_results.csv', index=False)
print("\nMFE estimation results saved to 'mfe_estimation_results.csv'")

print("\nAggregate Statistics Across Users:")
print(f"  Average alpha (TD rate): {mfe_results['alpha_td'].mean():.4f} ± {mfe_results['alpha_td'].std():.4f}")
print(f"  Average beta (softmax temp): {mfe_results['beta'].mean():.4f} ± {mfe_results['beta'].std():.4f}")
print(f"  Average epsilon (explore): {mfe_results['epsilon_explore'].mean():.4f} ± {mfe_results['epsilon_explore'].std():.4f}")
print(f"  Average phi (forgetting): {mfe_results['phi_forget'].mean():.4f} ± {mfe_results['phi_forget'].std():.4f}")
print(f"  Average log-likelihood: {mfe_results['log_likelihood'].mean():.2f} ± {mfe_results['log_likelihood'].std():.2f}")



MF Enhanced with Episodic Memory Estimation
Features: TD updates + non-parametric episodic Q retrieval
[1/10] Fitting MFE for user 126272...	User 126272 MFE model fit time: 44.81 seconds
[2/10] Fitting MFE for user 278978...	User 278978 MFE model fit time: 29.88 seconds
[3/10] Fitting MFE for user 395753...	User 395753 MFE model fit time: 55.93 seconds
[4/10] Fitting MFE for user 506035...	User 506035 MFE model fit time: 69.07 seconds
[5/10] Fitting MFE for user 612431...	User 612431 MFE model fit time: 54.67 seconds
[6/10] Fitting MFE for user 661336...	User 661336 MFE model fit time: 123.93 seconds
[7/10] Fitting MFE for user 824617...	User 824617 MFE model fit time: 9.13 seconds
[8/10] Fitting MFE for user 928827...	User 928827 MFE model fit time: 98.38 seconds
[9/10] Fitting MFE for user 1017498...	User 1017498 MFE model fit time: 19.54 seconds
[10/10] Fitting MFE for user 1159109...	User 1159109 MFE model fit time: 63.69 seconds

MFE Model Fitting Summary:
 user_id  n_records  lo

# Model-based RL modeling

## Episodic RL (Not very successful)

In [None]:
"""
Episodic RL Modeling with Contextual Memory
============================================

This module implements an episodic reinforcement learning model that incorporates
contextual/situational memory for urban cognition modeling.

Key Features:
1. Graph Memory: Tracks visited locations and their memory strength
2. Episodic Q-Table: Records daily cumulative rewards and temporal context
3. Time-Weighted Retrieval: Computes Q-values using temporal similarity
4. Memory Decay: Geometric decay of memory strength across days
"""

from typing import Dict, List, Tuple, Optional, Any, Set
from dataclasses import dataclass, field
from collections import defaultdict
from scipy.special import logsumexp
import warnings
from datetime import timedelta


# =============================================================================
# Episodic Memory Configuration
# =============================================================================

@dataclass
class EpisodicConfig:
    """Configuration for Episodic RL model."""
    # Memory formation threshold
    visit_threshold: int = 3  # Visits required to add cluster to graph
    
    # Memory decay parameter (phi_episodic)
    phi_episodic: float = 0.1  # Daily decay rate for memory strength
    
    # Memory strength threshold for inclusion in decisions
    memory_threshold: float = 0.1  # Exclude nodes with strength < 0.1
    
    # Time similarity Gaussian std (in time_angle units, 0-1 = 24h)
    # std = 2 hours = 2/24 = 1/12 闂?0.0833
    time_sim_std: float = 1.0 / 12.0
    
    # Exploration rate
    epsilon_init: float = 0.1
    
    # Softmax temperature
    beta_init: float = 1.0
    
    # Reward type for episodic gains
    reward_type: str = 'log'
    reward_param: float = 1.0
    
    # Optimization settings
    maxiter: int = 1000
    ftol: float = 1e-6

    # Reference date for day sequence calculation
    ref_date: Tuple[int, int, int] = (2018, 12, 31)



# =============================================================================
# Data Structures for Episodic Memory
# =============================================================================

@dataclass
class EpisodicRecord:
    """
    A single episodic memory record for one action (cluster) on one day.

    Stores:
    - cluster_id: The cluster visited (action taken)
    - gain: Cumulative reward for that day (sum of rewards from this action onward)
    - time_angle: Time angle when the action was taken
    - record_date: Date of the record (YYYYMMDD)
    - trace: Decay factor for this record (starts at 1.0, decays daily)

    This record is created at the end of each day and added to episodic_table.
    It is used in Q-value computation via weighted average.
    """
    cluster_id: int           # The cluster visited (action)
    gain: float               # Cumulative reward for the day (no discount)
    time_angle: float         # Time angle when the action was taken
    record_date: int          # Date when the record was made
    trace: float = 1.0        # Internal difference in days (for decay)
    
    def to_dict(self) -> dict:
        return {
            'cluster_id': self.cluster_id,
            'gain': self.gain,
            'time_angle': self.time_angle,
            'record_date': self.record_date
        }
    
    @classmethod
    def from_dict(cls, data: dict) -> 'EpisodicRecord':
        return cls(
            cluster_id=int(data['cluster_id']),
            gain=float(data['gain']),
            time_angle=float(data['time_angle']),
            record_date=int(data['record_date'])
        )


class Memory:
    """
    Graph-based memory structure for episodic RL.

    Maintains:
    - graph_nodes: Set of cluster IDs that have been added to graph (after visit_threshold visits)
    - visit_counts: Number of times each cluster has been visited
    - memory_strength: Current memory strength for each cluster in graph
    - episodic_table: List of EpisodicRecord (all historical records, used for Q-computation)

    Key Design:
    - A cluster enters graph when visited visit_threshold times
    - Memory strength is initialized to 1.0 upon entry
    - Memory strength increases by 1.0 after each day's visit (applied at day-end)
    - Memory strength decays geometrically at the start of each new day
    """

    def __init__(self, config: EpisodicConfig):
        self.config = config
        self.graph_nodes: Set[int] = set()  # Clusters in the graph
        self.visit_counts: Dict[int, int] = defaultdict(int)
        self.memory_strength: Dict[int, float] = defaultdict(float)
        self.episodic_table: List[EpisodicRecord] = []  # date -> table
        
        # Special nodes always present (but with 0 memory strength initially)
        # -1: noise/unknown, 0: outside study area
        self.always_present: Set[int] = {-1, 0}
    
    def add_visit(self, cluster_id: int):
        """
        Record a visit to a cluster and update memory strength.

        If visit count reaches threshold, add cluster to graph with initial strength 1.0.
        If already in graph, increase memory strength by 1.0.

        Note: This is called at the END of each day. The purpose of this function is to select the suitable cluster.
        """
        if cluster_id <= 0:  # Skip special nodes (-1: noise, 0: outside)
            return

        self.visit_counts[cluster_id] += 1

        # Add to graph if threshold reached
        if self.visit_counts[cluster_id] >= self.config.visit_threshold:
            if cluster_id not in self.graph_nodes:
                # First time entering graph
                self.graph_nodes.add(cluster_id)
                self.memory_strength[cluster_id] = 1.0
            else:
                # Already in graph: increase strength by 1.0
                self.memory_strength[cluster_id] += 1.0
    

    def decay(self, date_diff: int = 1):
        """
        Apply geometric decay to memory strength at the start of a new day.

        Called at the START of each new day (before any decision-making).

        Memory strength decays as: strength *= (1 - phi_episodic) ^ date_diff
        EpisodicRecord.trace also decays similarly (used for Q-value computation).

        Parameters:
        - date_diff: Number of days since last decay (usually 1, can be >1 for gaps)
        """
        decay_factor = (1.0 - self.config.phi_episodic) ** date_diff
        for node_id in self.graph_nodes:
            self.memory_strength[node_id] *= decay_factor
        for episode_record in self.episodic_table:
            episode_record.trace *= decay_factor


    def get_active_nodes(self) -> Set[int]:
        """
        Get clusters with memory strength above threshold for Q-function inclusion.

        Only clusters with sufficient memory strength participate in episodic Q-computation.
        Weak memories (below threshold) are excluded, effectively "forgotten".

        Returns:
        - Set of cluster IDs with memory_strength >= memory_threshold
        """
        if self.config.memory_threshold <= 0:
            return self.graph_nodes.copy()

        return {
            node_id for node_id in self.graph_nodes
            if self.memory_strength.get(node_id, 0.0) >= self.config.memory_threshold
        }

    
    def add_episodic_record(self, record: EpisodicRecord):
        """
        Add an episodic record to the episodic table.

        Called at day-end to store the day's experiences for future Q-computation.
        """
        self.episodic_table.append(record)


    def get_records_for_action(self, cluster_id: int) -> List[EpisodicRecord]:
        """
        Get all historical episodic records for a specific action (cluster).

        These records are used in compute_episodic_Q to calculate Q-values
        based on time-weighted average of historical gains.

        Parameters:
        - cluster_id: The cluster/action to retrieve records for

        Returns:
        - List of EpisodicRecord objects for this cluster
        """
        assert cluster_id != -1, "cluster_id could not be -1."
        return [r for r in self.episodic_table if r.cluster_id == cluster_id]
    


# =============================================================================
# Reward Computation Functions
# =============================================================================




# =============================================================================
# Data Preparation for Episodic Model
# =============================================================================

def prepare_episodic_data(user_df: pd.DataFrame,
                         config: EpisodicConfig = None) -> Dict[str, Any]:
    """
    Prepare episodic modeling data from user trajectory.
    
    Parameters:
    - user_df: DataFrame with columns ['t_start', 't_end', 'cluster_id', 'date']
    - config: Episodic configuration
    
    Returns:
    - Dictionary with prepared data for episodic modeling
    """
    if config is None:
        config = EpisodicConfig()
    
    # Ensure datetime columns
    df = user_df.copy()
    # Sort by time
    df = df.sort_values(by=['t_start']).reset_index(drop=True)
    
    # Basic data
    n_records = len(df)
    states = df['cluster_id'].to_numpy()
    
    # Compute time angles
    time_angles = np.array([
        compute_time_angle(dt) for dt in df['t_end']
    ])
    
    # Extract dates as array
    date_array = df['date'].to_numpy()
    unique_dates = sorted(df['date'].unique())
    # Compute day sequence (consecutive integers starting from 0)
    day_seq = compute_day_sequence(date_array, 
                                   config.ref_date[0], 
                                   config.ref_date[1], 
                                   config.ref_date[2])
    
    # Compute stay durations and rewards
    stay_minutes = (df['t_end'] - df['t_start']).dt.total_seconds() / 60.0
    stay_minutes = stay_minutes.to_numpy()
    stay_minutes = np.roll(stay_minutes, -1)
    stay_minutes[-1] = 0

    # Reward for each record
    reward_array = compute_reward_array(
        stay_minutes, 
        config.reward_type, 
        config.reward_param
    )
    
    # Compute daily gains as backward cumulative sum of rewards
    # Gain at time t: sum of rewards from t until the end of that day
    # This represents the cumulative reward obtained AFTER taking the action at time t
    gains = np.zeros(n_records)
    for date in unique_dates:
        day_mask = df['date'] == date
        idxs, = np.where(day_mask.values)
        day_rewards = reward_array[idxs]
        # Compute backward cumulative sum (Gain at each time step: sum of future rewards including now)
        backward_cumsum = np.cumsum(day_rewards[::-1])[::-1]
        gains[idxs] = backward_cumsum
    
    # Compute actions
    # -9: end-of-day, -1: explore (noise), 0: outside, 1+: clusters
    actions = np.zeros(n_records, dtype=int)
    actions[-1] = -9  # Last action is end-of-day
    
    for t in range(n_records - 1):
        current_date = df.iloc[t]['date']
        next_date = df.iloc[t + 1]['date']
        
        if next_date > current_date:
            # Next day - end of trajectory
            actions[t] = -9
        else:
            next_state = states[t + 1]
            if next_state == -1:
                actions[t] = -1  # Explore to noise
            elif next_state == 0:
                actions[t] = 0  # Outside study area
            else:
                actions[t] = next_state  # Transition to cluster
    
    # Compute same-day next indicator
    same_day_next = np.zeros(n_records, dtype=bool)
    for t in range(n_records - 1):
        if day_seq[t] == day_seq[t + 1]:
            same_day_next[t] = True
        
    return {
        'states': states,
        'actions': actions,
        'time_angles': time_angles,
        'date_array': date_array,
        'day_seq': day_seq,
        'stay_minutes': stay_minutes,
        'reward_array': reward_array,
        'gains': gains,
        'n_records': n_records,
        'same_day_next': same_day_next,
        'df': df
    }


# =============================================================================
# Parameter Handling
# =============================================================================

def unpack_params_epi(theta: np.ndarray) -> Dict[str, float]:
    """
    Unpack episodic model parameters from optimization vector.
    
    Parameters:
    - theta: Parameter vector [phi_episodic, beta, epsilon]
    
    Returns:
    - Dictionary with phi_episodic, beta, epsilon
    """
    idx = 0
    
    # Episodic decay rate: phi = sigmoid(logit_phi) in [0, 1)
    logit_phi = theta[idx]; idx += 1
    phi_episodic = 1.0 / (1.0 + np.exp(-logit_phi))
    
    # Softmax temperature: beta = exp(log_beta)
    log_beta = theta[idx]; idx += 1
    beta = np.exp(log_beta)
    
    # Exploration probability: epsilon = sigmoid(logit_epsilon)
    logit_epsilon = theta[idx]; idx += 1
    epsilon = 1.0 / (1.0 + np.exp(-logit_epsilon))
    
    return {
        'phi_episodic': phi_episodic,
        'beta': beta,
        'epsilon': epsilon
    }


def pack_params_epi(phi_episodic: float, beta: float, epsilon: float) -> np.ndarray:
    """
    Pack episodic model parameters into optimization vector.
    
    Parameters:
    - phi_episodic: Episodic decay rate (0, 1)
    - beta: Softmax inverse temperature (0, +inf)
    - epsilon: Exploration probability (0, 1)
    
    Returns:
    - Parameter vector
    """
    return np.array([
        np.log(phi_episodic / (1.0 - phi_episodic)),  # logit_phi
        np.log(beta),                                  # log_beta
        np.log(epsilon / (1.0 - epsilon))              # logit_epsilon
    ], dtype=np.float64)


# =============================================================================
# Episodic Q-Value Computation
# =============================================================================

def compute_episodic_Q(
    cluster_id: int,
    current_time_angle: float,
    memory: Memory,
    episode_decay: bool = True
) -> float:
    """
    Compute episodic Q-value for a given action (cluster) at current time.

    Q(action) = weighted_average(gain, weights = time_similarity * decay_trace)

    This retrieves all historical records for this cluster and computes
    a time-weighted average of gains, where:
    - Time similarity: Gaussian similarity between current time and record time
    - Decay trace: Decay factor based on days since record (from EpisodicRecord.trace)

    Parameters:
    - cluster_id: The action (cluster) to compute Q-value for
    - current_time_angle: Current time of decision (in [0, 1) scale)
    - memory: Memory object containing episodic_table with historical records
    - episode_decay: Whether to apply decay weights (True) or use raw traces (False)

    Returns:
    - Episodic Q-value (float). Returns 0.0 if no historical records exist.
    """

    # Get all historical records for this cluster
    records = memory.get_records_for_action(cluster_id)
    if not records:
        return 0.0

    gain_array = np.array([r.gain for r in records])
    time_sim_array = np.array([compute_time_similarity(r.time_angle, current_time_angle) for r in records])

    if episode_decay:
        decay_array = np.array([r.trace for r in records])
        weights = decay_array * time_sim_array
    else:
        weights = time_sim_array

    # Weighted average of gains based on time similarity and decay
    estimate_Q = np.average(gain_array, weights=weights)

    return estimate_Q


# =============================================================================
# Episodic Simulation and Log-Likelihood
# =============================================================================

def simulate_and_loglik_epi(
    theta: np.ndarray,
    episodic_data: Dict[str, Any],
    config: EpisodicConfig = None
) -> float:
    """
    Compute negative log-likelihood for episodic RL model.

    This function simulates decision-making using episodic memory:

    CAUSAL FLOW (Critical Design):
    --------------------------------
    1. At time t, decision is based on HISTORICAL memory only
    2. Current action is recorded but NOT immediately added to memory
    3. At day-end: all recorded actions are added to memory (update graph + episodic records)
    4. At next-day start: memory decays based on phi_episodic

    This ensures no information leakage: decisions use only what was known BEFORE.

    Processing Order at Each Step t:
    --------------------------------
    1. Record current action (add to adopted_actions, cached_episodes)
    2. Apply decay if new day started (based on prev_date)
    3. Get active nodes (memory_strength >= threshold)
    4. Compute Q-values from historical episodic records
    5. Compute action probability via softmax(beta * Q)
    6. Accumulate log-likelihood
    7. At day-end: update memory (add_visit, add_episodic_record)

    Parameters:
    - theta: Parameter vector [phi_episodic, beta, epsilon]
    - episodic_data: Prepared episodic data (states, actions, time_angles, gains, etc.)
    - config: Episodic configuration

    Returns:
    - Negative log-likelihood (to be minimized)
    """
    if config is None:
        config = EpisodicConfig()

    # Unpack parameters
    params = unpack_params_epi(theta)
    config.phi_episodic = params['phi_episodic']
    beta = params['beta']
    epsilon = params['epsilon']

    # Initialize fresh memory for this user
    base_memory = Memory(config)

    # Extract prepared data
    day_seq = episodic_data['day_seq']
    time_angles = episodic_data['time_angles']
    date_array = episodic_data['date_array']
    action_seq = episodic_data['actions']
    n_records = episodic_data['n_records']
    same_day_next = episodic_data['same_day_next']
    gains = episodic_data['gains']

    # Tracking variables
    prev_date = None           # Previous date for detecting day transitions
    adopted_actions = set()    # Actions taken on current day (for memory update)
    cached_episodes = list()   # Episode records for current day
    loglik = 0.0

    # Main simulation loop
    for t in range(n_records):
        a = int(action_seq[t])           # Current action
        time_angle = float(time_angles[t])  # Current time angle
        current_date = int(date_array[t])   # Current date
        current_day_seq = int(day_seq[t])   # Current day sequence number

        # --- Step 1: Record current action (but don't update memory yet) ---
        adopted_actions.add(a)  # Track for day-end update

        # Create episodic record (will be added to episodic_table at day-end)
        cached_episodes.append(EpisodicRecord(
            cluster_id=a,
            time_angle=time_angle,
            record_date=current_date,
            gain=gains[t],
        ))

        # --- Step 2: Apply decay at start of new day ---
        if prev_date is not None and current_date > prev_date:
            prev_date_seq = day_seq[t-1]
            date_diff = current_day_seq - prev_date_seq
            assert date_diff > 0, f"Date difference should be positive"
            base_memory.decay(date_diff)
        prev_date = current_date

        # --- Step 3 & 4: Compute Q-values from historical memory ---
        active_nodes = base_memory.get_active_nodes()

        # Available actions for Q-function:
        # -9 (end trajectory): always available
        # 0 (outside city): always available
        # active clusters: only those with memory_strength >= threshold
        available_actions = sorted([-9, 0] + list(active_nodes))

        # Compute episodic Q-values for each available action
        q_values = []
        for action in available_actions:
            Q = compute_episodic_Q(
                action,
                time_angle,
                base_memory,
                episode_decay=False  # Trace already applied in decay()
            )
            q_values.append(Q)

        q_values = np.asarray(q_values, dtype=np.float64)

        # --- Step 5: Compute action probability ---
        # Softmax policy over Q-values
        logits = beta * q_values
        probs_exploit = np.exp(logits - logsumexp(logits))

        # Map observed action to probability
        action_prob = 0.0

        if a == -1:
            # Noise/explore action: always pure exploration
            action_prob = epsilon
        elif a in available_actions:
            # Exploitation: (1-epsilon) * softmax probability
            idx = available_actions.index(a)
            if idx < len(probs_exploit):
                action_prob = (1.0 - epsilon) * probs_exploit[idx]
        else:
            # Action in memory but forgotten (strength < threshold)
            # Or action never recorded: treated as exploration
            action_prob = epsilon

        # --- Step 6: Accumulate log-likelihood ---
        loglik += np.log(action_prob + 1e-12)

        # --- Step 7: At day-end, update memory ---
        if not same_day_next[t]:
            # Add all visited actions to graph (update memory_strength)
            for act in adopted_actions:
                base_memory.add_visit(act)
            adopted_actions.clear()

            # Add all episode records to episodic_table (for future Q-computation)
            for record in cached_episodes:
                base_memory.add_episodic_record(record)
            cached_episodes.clear()
    
    return -loglik


# =============================================================================
# Model Fitting
# =============================================================================

def fit_epi_model(
    user_df: pd.DataFrame,
    config: EpisodicConfig = None,
    verbose: bool = True
) -> Dict[str, Any]:
    """
    Fit episodic RL model to a single user's trajectory data.
    
    Parameters:
    - user_df: DataFrame with columns ['t_start', 't_end', 'cluster_id', 'date']
    - config: Episodic configuration
    - verbose: Whether to print progress
    
    Returns:
    - Dictionary with fitted parameters and statistics
    """
    
    if config is None:
        config = EpisodicConfig()
    
    # Prepare episodic data
    episodic_data = prepare_episodic_data(user_df, config)
    n_records = episodic_data['n_records']
    
    if verbose:
        print(f"Preparing Episodic model for {len(user_df)} visits, {n_records} actions...")
    
    # Initial parameters
    initial_theta = pack_params_epi(
        phi_episodic=config.phi_episodic,
        beta=config.beta_init,
        epsilon=config.epsilon_init
    )
    
    # Optimize
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        result = minimize(
            simulate_and_loglik_epi,
            initial_theta,
            args=(episodic_data, config),
            method='L-BFGS-B',
            options={'maxiter': config.maxiter, 'ftol': config.ftol}
        )
    
    # Extract fitted parameters
    fitted_params = unpack_params_epi(result.x)
    final_loglik = -result.fun
    
    # Compute model selection criteria
    k_params = 3  # phi_episodic, beta, epsilon
    n_obs = n_records
    AIC = 2 * k_params - 2 * final_loglik
    BIC = k_params * np.log(n_obs) - 2 * final_loglik if n_obs > 0 else np.inf
    
    summary = {
        'n_records': int(n_records),
        'log_likelihood': float(final_loglik),
        'AIC': float(AIC),
        'BIC': float(BIC),
        'phi_episodic': float(fitted_params['phi_episodic']),
        'beta': float(fitted_params['beta']),
        'epsilon': float(fitted_params['epsilon']),
        'converged': result.success,
        'n_iterations': result.nit if hasattr(result, 'nit') else None,
        'optimization_message': result.message if hasattr(result, 'message') else None
    }
    
    if verbose:
        print(f"Episodic model fitting {'converged' if result.success else 'did not converge'}.")
        print(f"  Log-likelihood: {final_loglik:.2f}, AIC: {AIC:.2f}, BIC: {BIC:.2f}")
        print(f"  phi_episodic (decay): {fitted_params['phi_episodic']:.4f}")
        print(f"  beta (softmax temp): {fitted_params['beta']:.4f}")
        print(f"  epsilon (explore): {fitted_params['epsilon']:.4f}")
    
    return summary


def fit_epi_for_all_users(
    users_dict: Dict[int, Any],
    config: EpisodicConfig = None,
    sample_size: Optional[int] = None,
    verbose: bool = True
) -> pd.DataFrame:
    """
    Fit episodic RL model for all users in the dictionary.
    
    Parameters:
    - users_dict: Dictionary mapping user_id to User objects
    - config: Episodic configuration
    - sample_size: Number of users to fit (None for all)
    - verbose: Whether to print progress
    
    Returns:
    - DataFrame with fitted parameters for each user
    """
    if config is None:
        config = EpisodicConfig()
    
    user_ids = list(users_dict.keys())
    if sample_size is not None:
        user_ids = user_ids[:sample_size]
    
    results = []
    for i, user_id in enumerate(user_ids):
        if verbose:
            print(f"[{i+1}/{len(user_ids)}] Fitting Episodic for user {user_id}...", end="")
        
        user = users_dict[user_id]
        user_df = user.to_dataframe()
        
        try:
            t_start = time.time()
            result = fit_epi_model(user_df, config, verbose=False)
            t_end = time.time()
            elapsed = t_end - t_start
            if verbose:
                print(f"\tUser {user_id} Episodic model fit time: {elapsed:.2f} seconds")
            result['user_id'] = user_id
            result['fit_time_seconds'] = elapsed
            results.append(result)
        except Exception as e:
            if verbose:
                print(f"  Error fitting user {user_id}: {e}")
            results.append({
                'user_id': user_id,
                'n_records': 0,
                'log_likelihood': np.nan,
                'AIC': np.nan,
                'BIC': np.nan,
                'phi_episodic': np.nan,
                'beta': np.nan,
                'epsilon': np.nan,
                'converged': False,
                'n_iterations': None,
                'optimization_message': str(e)
            })
    
    return pd.DataFrame(results)


In [38]:
# =============================================================================
# Diagnostic: Check if phi_episodic is actually used in optimization
# =============================================================================

def simulate_with_different_phi(theta, episodic_data, phi_values=[0.01, 0.1, 0.4, 0.9], verbose=True):
    """
    Test the model with different phi values to see if they produce different log-likelihoods.
    """
    results = []

    for phi in phi_values:
        config = EpisodicConfig(
            visit_threshold=3,
            phi_episodic=phi,
            memory_threshold=0.1,
            time_sim_std=1.0/12.0,
            epsilon_init=0.1,
            beta_init=1.0,
        )

        params = unpack_params_epi(theta)
        beta = params['beta']
        epsilon = params['epsilon']
        base_memory = Memory(config)

        day_seq = episodic_data['day_seq']
        time_angles = episodic_data['time_angles']
        date_array = episodic_data['date_array']
        action_seq = episodic_data['actions']
        n_records = episodic_data['n_records']
        same_day_next = episodic_data['same_day_next']
        gains = episodic_data['gains']

        prev_date = None
        adopted_actions = set()
        cached_episodes = list()
        loglik = 0.0

        for t in range(n_records):
            a = int(action_seq[t])
            time_angle = float(time_angles[t])
            current_date = int(date_array[t])
            current_day_seq = int(day_seq[t])

            adopted_actions.add(a)
            cached_episodes.append(EpisodicRecord(
                cluster_id=a, time_angle=time_angle,
                record_date=current_date, gain=gains[t],
            ))

            if prev_date is not None and current_date > prev_date:
                prev_date_seq = day_seq[t-1]
                date_diff = current_day_seq - prev_date_seq
                base_memory.decay(date_diff)
            prev_date = current_date

            active_nodes = base_memory.get_active_nodes()
            available_actions = sorted([-9, 0] + list(active_nodes))

            q_values = []
            for action in available_actions:
                Q = compute_episodic_Q(action, time_angle, base_memory, episode_decay=False)
                q_values.append(Q)
            q_values = np.asarray(q_values, dtype=np.float64)

            logits = beta * q_values
            probs_exploit = np.exp(logits - logsumexp(logits))

            if a == -1:
                action_prob = epsilon
            elif a in available_actions:
                idx = available_actions.index(a)
                action_prob = (1.0 - epsilon) * probs_exploit[idx]
            else:
                action_prob = epsilon

            loglik += np.log(action_prob + 1e-12)

            if not same_day_next[t]:
                for act in adopted_actions:
                    base_memory.add_visit(act)
                adopted_actions.clear()
                for record in cached_episodes:
                    base_memory.add_episodic_record(record)
                cached_episodes.clear()

        results.append({
            'phi': phi,
            'loglik': loglik,
            'n_active_nodes': len(base_memory.get_active_nodes())
        })

        if verbose:
            print(f"phi={phi:.2f}: loglik={loglik:.2f}, active_nodes={len(base_memory.get_active_nodes())}")

    return results


# Run diagnostic
sample_user_id = list(users_dict.keys())[0]
sample_df = users_dict[sample_user_id].to_dataframe()
episodic_data = prepare_episodic_data(sample_df)
theta = pack_params_epi(phi_episodic=0.4, beta=1.0, epsilon=0.1)

print("Testing different phi values:")
results = simulate_with_different_phi(theta, episodic_data)

range_val = max(r['loglik'] for r in results) - min(r['loglik'] for r in results)
print(f"\nLog-likelihood range: {range_val:.4f}")

if range_val < 1.0:
    print("闂備礁鐤囧▔鏇熷垔鐎靛摜绠? WARNING: phi has almost NO effect on log-likelihood!")
else:
    print("闂?phi does affect the model.")

Testing different phi values:
phi=0.01: loglik=-8035.39, active_nodes=33
phi=0.10: loglik=-7260.12, active_nodes=14
phi=0.40: loglik=-5954.11, active_nodes=6
phi=0.90: loglik=-4338.02, active_nodes=2

Log-likelihood range: 3697.3654
闂?phi does affect the model.


In [39]:
# =============================================================================
# Diagnostic: Detailed phi optimization analysis
# =============================================================================

def detailed_phi_analysis(theta, episodic_data, phi_values, verbose=True):
    results = []

    for phi in phi_values:
        config = EpisodicConfig(
            visit_threshold=3, phi_episodic=phi, memory_threshold=0.1,
            time_sim_std=1.0/12.0, epsilon_init=0.1, beta_init=1.0,
        )

        params = unpack_params_epi(theta)
        beta = params['beta']
        epsilon = params['epsilon']
        base_memory = Memory(config)

        day_seq = episodic_data['day_seq']
        time_angles = episodic_data['time_angles']
        date_array = episodic_data['date_array']
        action_seq = episodic_data['actions']
        n_records = episodic_data['n_records']
        same_day_next = episodic_data['same_day_next']
        gains = episodic_data['gains']

        prev_date = None
        adopted_actions = set()
        cached_episodes = list()
        loglik = 0.0
        n_explored = 0

        for t in range(n_records):
            a = int(action_seq[t])
            time_angle = float(time_angles[t])
            current_date = int(date_array[t])
            current_day_seq = int(day_seq[t])

            adopted_actions.add(a)
            cached_episodes.append(EpisodicRecord(
                cluster_id=a, time_angle=time_angle,
                record_date=current_date, gain=gains[t],
            ))

            if prev_date is not None and current_date > prev_date:
                base_memory.decay(current_day_seq - day_seq[t-1])
            prev_date = current_date

            active_nodes = base_memory.get_active_nodes()
            available_actions = sorted([-9, 0] + list(active_nodes))

            if a not in available_actions:
                n_explored += 1

            q_values = [compute_episodic_Q(a, time_angle, base_memory, episode_decay=False) 
                       for a in available_actions]
            q_values = np.asarray(q_values, dtype=np.float64)

            logits = beta * q_values
            probs_exploit = np.exp(logits - logsumexp(logits))

            if a == -1:
                action_prob = epsilon
            elif a in available_actions:
                idx = available_actions.index(a)
                action_prob = (1.0 - epsilon) * probs_exploit[idx]
            else:
                action_prob = epsilon

            loglik += np.log(action_prob + 1e-12)

            if not same_day_next[t]:
                for act in adopted_actions:
                    base_memory.add_visit(act)
                adopted_actions.clear()
                for record in cached_episodes:
                    base_memory.add_episodic_record(record)
                cached_episodes.clear()

        results.append({
            'phi': phi, 'loglik': loglik,
            'n_active_final': len(base_memory.get_active_nodes()),
            'explore_rate': n_explored / n_records
        })

        if verbose:
            print(f"phi={phi:.2f}: loglik={loglik:.1f}, active={len(base_memory.get_active_nodes())}, "
                  f"explore_rate={n_explored/n_records:.2%}")

    return results


# Run
sample_user_id = list(users_dict.keys())[0]
sample_df = users_dict[sample_user_id].to_dataframe()
episodic_data = prepare_episodic_data(sample_df)
theta = pack_params_epi(phi_episodic=0.4, beta=1.0, epsilon=0.1)

print("Effect of phi on exploration:")
results = detailed_phi_analysis(theta, episodic_data, [0.01, 0.1, 0.2, 0.4, 0.6, 0.8, 0.9])

Effect of phi on exploration:
phi=0.01: loglik=-8035.4, active=33, explore_rate=29.17%
phi=0.10: loglik=-7260.1, active=14, explore_rate=32.88%
phi=0.20: loglik=-6702.7, active=9, explore_rate=35.23%
phi=0.40: loglik=-5954.1, active=6, explore_rate=40.91%
phi=0.60: loglik=-5289.0, active=4, explore_rate=47.95%
phi=0.80: loglik=-4551.8, active=2, explore_rate=54.55%
phi=0.90: loglik=-4338.0, active=2, explore_rate=56.74%


In [41]:
# =============================================================================
# Episodic RL Model Fitting (Sample Test)
# =============================================================================

print("\n" + "="*60)
print("Episodic RL Model Fitting")
print("Features: Episodic memory, Time-weighted Q retrieval, Memory decay")
print("="*60)

# 1. Create configuration
epi_config = EpisodicConfig(
    visit_threshold=3,
    phi_episodic=0.1,
    memory_threshold=0.01,
    time_sim_std=1.0/12.0,
    epsilon_init=0.1,
    beta_init=1.0,
    maxiter=1000,
    ftol=1e-6
)

# 2. Fit Episodic model for sample users
epi_results = fit_epi_for_all_users(
    users_dict,
    config=epi_config,
    sample_size=10,  # Change to None for all users
    verbose=True
)

print("\nEpisodic Model Fitting Summary:")
print(epi_results[['user_id', 'n_records', 'log_likelihood', 'AIC', 'BIC',
                   'phi_episodic', 'beta', 'epsilon']].to_string(index=False))

# Save results
epi_results.to_csv('epi_estimation_results.csv', index=False)
print(f"\nEpisodic estimation results saved to 'epi_estimation_results.csv'")

# Aggregate statistics
print("\nAggregate Statistics Across Users:")
print(f"  Average phi_episodic (decay): {epi_results['phi_episodic'].mean():.4f} 闂?{epi_results['phi_episodic'].std():.4f}")
print(f"  Average beta (softmax temp): {epi_results['beta'].mean():.4f} 闂?{epi_results['beta'].std():.4f}")
print(f"  Average epsilon (explore): {epi_results['epsilon'].mean():.4f} 闂?{epi_results['epsilon'].std():.4f}")
print(f"  Average log-likelihood: {epi_results['log_likelihood'].mean():.2f} 闂?{epi_results['log_likelihood'].std():.2f}")


Episodic RL Model Fitting
Features: Episodic memory, Time-weighted Q retrieval, Memory decay
[1/10] Fitting Episodic for user 126272...	User 126272 Episodic model fit time: 170.71 seconds
[2/10] Fitting Episodic for user 278978...	User 278978 Episodic model fit time: 128.25 seconds
[3/10] Fitting Episodic for user 395753...	User 395753 Episodic model fit time: 184.04 seconds
[4/10] Fitting Episodic for user 506035...	User 506035 Episodic model fit time: 206.65 seconds
[5/10] Fitting Episodic for user 612431...	User 612431 Episodic model fit time: 275.18 seconds
[6/10] Fitting Episodic for user 661336...	User 661336 Episodic model fit time: 351.99 seconds
[7/10] Fitting Episodic for user 824617...	User 824617 Episodic model fit time: 8.21 seconds
[8/10] Fitting Episodic for user 928827...	User 928827 Episodic model fit time: 281.36 seconds
[9/10] Fitting Episodic for user 1017498...	User 1017498 Episodic model fit time: 55.59 seconds
[10/10] Fitting Episodic for user 1159109...	User 11

In [44]:
# =============================================================================
# Episodic RL Model Fitting (Sample Test)
# =============================================================================

print("\n" + "="*60)
print("Episodic RL Model Fitting")
print("Features: Episodic memory, Time-weighted Q retrieval, Memory decay")
print("="*60)

# 1. Create configuration
epi_config = EpisodicConfig(
    visit_threshold=3,
    phi_episodic=0.1,
    memory_threshold=0,
    time_sim_std=1.0/12.0,
    epsilon_init=0.1,
    beta_init=1.0,
    maxiter=1000,
    ftol=1e-6
)

# 2. Fit Episodic model for sample users
epi_results = fit_epi_for_all_users(
    users_dict,
    config=epi_config,
    sample_size=10,  # Change to None for all users
    verbose=True
)

print("\nEpisodic Model Fitting Summary:")
print(epi_results[['user_id', 'n_records', 'log_likelihood', 'AIC', 'BIC',
                   'phi_episodic', 'beta', 'epsilon']].to_string(index=False))

# Save results
epi_results.to_csv('epi_estimation_results.csv', index=False)
print(f"\nEpisodic estimation results saved to 'epi_estimation_results.csv'")

# Aggregate statistics
print("\nAggregate Statistics Across Users:")
print(f"  Average phi_episodic (decay): {epi_results['phi_episodic'].mean():.4f} 闂?{epi_results['phi_episodic'].std():.4f}")
print(f"  Average beta (softmax temp): {epi_results['beta'].mean():.4f} 闂?{epi_results['beta'].std():.4f}")
print(f"  Average epsilon (explore): {epi_results['epsilon'].mean():.4f} 闂?{epi_results['epsilon'].std():.4f}")
print(f"  Average log-likelihood: {epi_results['log_likelihood'].mean():.2f} 闂?{epi_results['log_likelihood'].std():.2f}")


Episodic RL Model Fitting
Features: Episodic memory, Time-weighted Q retrieval, Memory decay
[1/10] Fitting Episodic for user 126272...	User 126272 Episodic model fit time: 149.96 seconds
[2/10] Fitting Episodic for user 278978...	User 278978 Episodic model fit time: 114.38 seconds
[3/10] Fitting Episodic for user 395753...	User 395753 Episodic model fit time: 196.55 seconds
[4/10] Fitting Episodic for user 506035...	User 506035 Episodic model fit time: 242.20 seconds
[5/10] Fitting Episodic for user 612431...	User 612431 Episodic model fit time: 355.89 seconds
[6/10] Fitting Episodic for user 661336...	User 661336 Episodic model fit time: 487.11 seconds
[7/10] Fitting Episodic for user 824617...	User 824617 Episodic model fit time: 20.37 seconds
[8/10] Fitting Episodic for user 928827...	User 928827 Episodic model fit time: 490.92 seconds
[9/10] Fitting Episodic for user 1017498...	User 1017498 Episodic model fit time: 106.61 seconds
[10/10] Fitting Episodic for user 1159109...	User 

In [46]:
# =============================================================================
# Episodic RL Model Fitting (Sample Test)
# =============================================================================

print("\n" + "="*60)
print("Episodic RL Model Fitting")
print("Features: Episodic memory, Time-weighted Q retrieval, Memory decay")
print("="*60)

# 1. Create configuration
epi_config = EpisodicConfig(
    visit_threshold=3,
    phi_episodic=0.1,
    memory_threshold=0,
    time_sim_std=1.0/12.0,
    epsilon_init=0.1,
    beta_init=1.0,
    maxiter=1000,
    ftol=1e-6
)

# 2. Fit Episodic model for sample users
epi_results = fit_epi_for_all_users(
    users_dict,
    config=epi_config,
    sample_size=10,  # Change to None for all users
    verbose=True
)

print("\nEpisodic Model Fitting Summary:")
print(epi_results[['user_id', 'n_records', 'log_likelihood', 'AIC', 'BIC',
                   'phi_episodic', 'beta', 'epsilon']].to_string(index=False))

# Save results
epi_results.to_csv('epi_estimation_results.csv', index=False)
print(f"\nEpisodic estimation results saved to 'epi_estimation_results.csv'")

# Aggregate statistics
print("\nAggregate Statistics Across Users:")
print(f"  Average phi_episodic (decay): {epi_results['phi_episodic'].mean():.4f} 闂?{epi_results['phi_episodic'].std():.4f}")
print(f"  Average beta (softmax temp): {epi_results['beta'].mean():.4f} 闂?{epi_results['beta'].std():.4f}")
print(f"  Average epsilon (explore): {epi_results['epsilon'].mean():.4f} 闂?{epi_results['epsilon'].std():.4f}")
print(f"  Average log-likelihood: {epi_results['log_likelihood'].mean():.2f} 闂?{epi_results['log_likelihood'].std():.2f}")


Episodic RL Model Fitting
Features: Episodic memory, Time-weighted Q retrieval, Memory decay
[1/10] Fitting Episodic for user 126272...	User 126272 Episodic model fit time: 194.47 seconds
[2/10] Fitting Episodic for user 278978...	User 278978 Episodic model fit time: 165.09 seconds
[3/10] Fitting Episodic for user 395753...	User 395753 Episodic model fit time: 314.74 seconds
[4/10] Fitting Episodic for user 506035...	User 506035 Episodic model fit time: 1423.11 seconds
[5/10] Fitting Episodic for user 612431...	User 612431 Episodic model fit time: 15273.50 seconds
[6/10] Fitting Episodic for user 661336...	User 661336 Episodic model fit time: 513.82 seconds
[7/10] Fitting Episodic for user 824617...	User 824617 Episodic model fit time: 12.90 seconds
[8/10] Fitting Episodic for user 928827...	User 928827 Episodic model fit time: 594.02 seconds
[9/10] Fitting Episodic for user 1017498...	User 1017498 Episodic model fit time: 107.40 seconds
[10/10] Fitting Episodic for user 1159109...	Us

## Successor Representation (SR)

This method should ideally be integrated with the classic Dyna architecture. Here we prospect its structure.

Initialize:
* A network graph containing all node states, for random walk.
* A null world model, with the ability to online learning.
* A huge null sparse matrix to record the Successsor Representation (SR).

Procedures:
The agent decides to exploit or explore based on the world model and SR today.
At the end of the day:
1 Collect all 
2 
3 

In [None]:
# =============================================================================
# SR Dyna: Model-Based RL using Successor Representation (Revised)
# =============================================================================
"""
SR Dyna combines Successor Representation (SR) with Dyna architecture.

Key Concept (No gamma discount - trajectories end within day):
- SR[s][a][s'] = Expected visit count to future state s' 
                  from state s taking action a, WITHIN THE SAME DAY
- Q(s, a) = Σ_{s'} SR[s][a][s'] * R(s')  (weighted sum of rewards)

Key Differences from MF Episodic Memory:
- MF stores: Q-values (cumulative rewards from time t to end of day)
- SR stores: Visit counts to each future state s' from (s,a)
- Both use episodic memory + time similarity kernel for generalization
- Both can reuse EpisodicRecord-like structure

Algorithm:
1. Collect real trajectory data
2. Update SR from real transitions: SR[s,a][s'] += 1 for observed transition
3. Dyna planning: simulate N steps using learned model, update SR
4. Compute Q-values from SR: Q(s,a) = SR[s,a] · R (weighted by rewards)
5. Use Q-values for action selection (softmax policy)

Note: No gamma discount needed because all trajectories end within the day.
"""

from typing import Dict, List, Tuple, Optional, Any, Set
from dataclasses import dataclass, field
from collections import defaultdict
import numpy as np
import pandas as pd
import warnings
from scipy.optimize import minimize
import time


# =============================================================================
# SR Dyna Configuration
# =============================================================================

@dataclass
class SRDynaConfig:
    """Configuration for SR Dyna model."""
    # Learning rates
    alpha_init: float = 0.1       # SR learning rate (for real transitions)
    alpha_plan: float = 0.1       # Planning SR learning rate
    
    # Policy parameters
    beta_init: float = 1.0        # Softmax temperature
    epsilon_init: float = 0.1     # Exploration rate
    
    # Forgetting (memory decay between days, NO gamma - no discount within day)
    phi_init: float = 0.1         # Memory decay rate between days
    
    # Dyna planning
    n_planning_steps: int = 10     # Number of simulated steps per real step
    
    # Reward function
    reward_type: str = 'log'
    reward_param_init: float = 1.0
    
    # Memory thresholds
    visit_threshold: int = 3       # Visits required to add node to active set
    memory_threshold: float = 0.01  # Memory strength threshold
    
    # Time similarity kernel
    sigma_t_init: float = 1.0 / 12.0  # 2 hours = 2/24 in time_angle units
    
    # Optimization settings
    maxiter: int = 1000
    ftol: float = 1e-6
    
    # Reference date
    ref_date: Tuple[int, int, int] = (2018, 12, 31)


# =============================================================================
# SR Record: Similar to EpisodicRecord but stores visit counts instead of Q-values
# =============================================================================

@dataclass
class SRRecord:
    """
    A single SR memory record for one (state, action) pair at a specific time.
    
    Unlike EpisodicRecord which stores Q-values (gains), this stores:
    - future_state: The destination state visited
    - visit_count: Number of visits to this future state from (s,a)
    - time_angle: Time when the action was taken
    - day_seq: Day sequence number
    - record_date: Date of the record
    - strength: Memory strength (for decay)
    """
    future_state: int        # Destination state s'
    visit_count: float      # Expected visit count (can be fractional)
    time_angle: float       # Time angle when action was taken
    day_seq: int           # Day sequence number
    record_date: int       # Date (YYYYMMDD)
    strength: float = 1.0  # Memory strength (for decay)


# =============================================================================
# SR Memory: Stores and manages Successor Representation
# =============================================================================

class SRMemory:
    """
    Memory structure for Successor Representation.
    
    Key difference from EpisodicMemory:
    - EpisodicMemory: stores Q-values (cumulative rewards)
    - SRMemory: stores visit counts to future states
    
    Structure:
    - SR[state][action] = List[SRRecord] (all visits to various future states)
    - R[state] = Expected reward at state (same as EpisodicMemory)
    - Memory decay applied when day changes (phi)
    - Time similarity kernel (sigma_t) for generalization to unknown times
    """

    def __init__(self, phi: float, config: Optional[SRDynaConfig] = None):
        self.config = config if config is not None else SRDynaConfig()
        self.phi = phi
        
        # SR matrices: SR[state][action] = List[SRRecord]
        # Each record stores: (future_state, visit_count, time_angle, ...)
        self.SR: Dict[int, Dict[int, List[SRRecord]]] = defaultdict(
            lambda: defaultdict(list)
        )
        
        # Decay factors: similar to EpisodicMemory.Q_decay
        self.SR_decay: Dict[int, Dict[int, float]] = defaultdict(
            lambda: defaultdict(float)
        )
        
        # Expected rewards: R[state] = expected reward at state
        self.R: Dict[int, float] = defaultdict(float)
        
        # Visit counts and memory strengths (same as EpisodicMemory)
        self.visit_counts: Dict[int, int] = defaultdict(int)
        self.memory_strength: Dict[int, float] = defaultdict(float)
        
        # Active states (above threshold)
        self._active_states: Optional[Set[int]] = None
        self.last_day: Optional[int] = None

    def add_transition(self, state: int, action: int, future_state: int,
                       reward: float, is_planning: bool = False):
        """
        Add a transition to SR memory (increment visit count).
        
        Similar to EpisodicMemory.add_record, but instead of storing Q-value,
        we store/increment the visit count to the future state.
        
        Parameters:
        - state: Current state s
        - action: Action taken a
        - future_state: Resulting state s'
        - reward: Immediate reward R(s')
        - is_planning: Whether this is a planning (simulated) update
        """
        # Find existing record for this (future_state, time_bucket) if any
        # For simplicity, we just append new records (non-parametric approach)
        
        if future_state > 0:  # Only valid states (not noise/outside)
            # Create new SR record
            # Note: time_angle and day_seq would be passed from caller
            # For now, we'll handle time in retrieve_sr
            pass  # Will be handled in add_record method
        
        # Update expected reward (exponential moving average)
        if state > 0:
            alpha = self.config.alpha_init if not is_planning else self.config.alpha_plan
            current_R = self.R.get(state, 0.0)
            self.R[state] = current_R + alpha * (reward - current_R)
        
        # Update visit counts and memory strength
        if state > 0:
            self.visit_counts[state] += 1
            if self.visit_counts[state] >= self.config.visit_threshold:
                self.memory_strength[state] = self.memory_strength.get(state, 0.0) + 1.0
        
        self._active_states = None
    
    def add_record(self, state: int, action: int, future_state: int,
                   time_angle: float, day_seq: int, record_date: int,
                   visit_count: float = 1.0):
        """
        Add an SR record to memory (similar to EpisodicMemory.add_record).
        
        Parameters:
        - state: Current state s
        - action: Action taken a  
        - future_state: Destination state s'
        - time_angle: Time when action was taken
        - day_seq: Day sequence number
        - record_date: Date
        - visit_count: Visit count to increment
        """
        if future_state <= 0:
            return
            
        rec = SRRecord(
            future_state=future_state,
            visit_count=visit_count,
            time_angle=time_angle,
            day_seq=day_seq,
            record_date=record_date,
            strength=1.0,
        )
        self.SR[state][action].append(rec)
        self.SR_decay[state][action] = 1.0
        self._active_states = None

    @staticmethod
    def compute_time_similarity(t1: float, t2: float, sigma_t: float) -> float:
        """Circular Gaussian kernel on normalized time angle [0, 1)."""
        diff = abs(float(t1) - float(t2))
        sigma = max(float(sigma_t), 1e-6)
        return float(np.exp(-0.5 * (diff / sigma) ** 2))

    def retrieve_sr(self, state: int, action: int, target_time: float,
                   sigma_t: Optional[float] = None) -> Dict[int, float]:
        """
        Retrieve SR values for (state, action) at target time.
        
        Similar to EpisodicMemory.retrieve_q, but returns a dictionary
        of future_state -> expected_visit_count
        
        Returns:
        - Dict mapping future_state -> time-weighted visit count
        """
        sigma = self.config.sigma_t_init if sigma_t is None else float(sigma_t)
        records = self.SR[int(state)][int(action)]
        
        if not records:
            return {}
        
        # Aggregate visit counts by future state, weighted by time similarity
        sr_values: Dict[int, float] = defaultdict(float)
        
        for rec in records:
            time_sim = self.compute_time_similarity(target_time, rec.time_angle, sigma)
            weight = time_sim * rec.strength
            sr_values[rec.future_state] += rec.visit_count * weight
        
        # Apply decay
        decay = self.SR_decay[int(state)][int(action)]
        for s in sr_values:
            sr_values[s] *= decay
        
        return dict(sr_values)

    def compute_Q(self, state: int, action: int, target_time: float,
                 sigma_t: Optional[float] = None) -> float:
        """
        Compute Q(s, a) using Successor Representation.
        
        Q(s, a) = Σ_{s'} SR[s,a][s'] * R(s')
        
        Parameters:
        - state: Current state s
        - action: Action a
        - target_time: Time for SR retrieval
        - sigma_t: Time similarity kernel width
        
        Returns:
        - Q-value estimate
        """
        sigma = self.config.sigma_t_init if sigma_t is None else float(sigma_t)
        
        # Get SR values (visit counts to each future state)
        sr_values = self.retrieve_sr(state, action, target_time, sigma)
        
        # Compute Q as weighted sum of rewards
        q_value = 0.0
        for future_state, sr_count in sr_values.items():
            reward_at_future = self.R.get(future_state, 0.0)
            q_value += sr_count * reward_at_future
        
        return q_value

    def decay(self, current_day: int):
        """Apply memory decay when day changes (same as EpisodicMemory)."""
        if self.last_day is None:
            self.last_day = current_day
            return
        
        day_diff = int(current_day - self.last_day)
        if day_diff > 0:
            factor = (1.0 - self.phi) ** day_diff
            
            # Decay all SR records
            for state in self.SR:
                for action in self.SR[state]:
                    for rec in self.SR[state][action]:
                        rec.strength *= factor
            
            # Decay SR decay factors
            for state in self.SR_decay:
                for action in self.SR_decay[state]:
                    self.SR_decay[state][action] *= factor
            
            # Decay rewards
            for state in self.R:
                self.R[state] *= factor
            
            # Decay memory strengths
            for state in self.memory_strength:
                self.memory_strength[state] *= factor
            
            self._active_states = None
        
        self.last_day = current_day

    def get_active_states(self) -> Set[int]:
        """Get states with memory strength above threshold."""
        if self._active_states is not None:
            return self._active_states
        
        self._active_states = {
            state_id
            for state_id, strength in self.memory_strength.items()
            if strength >= self.config.memory_threshold
            and self.visit_counts.get(state_id, 0) >= self.config.visit_threshold
        }
        return self._active_states

    def simulate_step(self, current_state: int, target_time: float) -> Tuple[int, int, float]:
        """
        Simulate one step using learned SR model (for Dyna planning).
        
        Uses the SR matrix to sample a future state based on visit frequencies.
        
        Returns:
        - (next_state, action, reward)
        """
        active_states = self.get_active_states()
        
        if current_state not in self.SR or not self.SR[current_state]:
            return (-1, -1, 0.0)  # No knowledge: exploration
        
        # Get all actions with records
        actions_with_records = list(self.SR[current_state].keys())
        
        if not actions_with_records:
            return (-1, -1, 0.0)
        
        # For each action, compute total visit count to active states
        action_weights = {}
        total_weight = 0.0
        
        for action in actions_with_records:
            sr_values = self.retrieve_sr(current_state, action, target_time)
            # Weight = sum of visit counts to active states
            weight = sum(
                sr_values.get(s, 0.0)
                for s in sr_values.keys()
                if s in active_states or s <= 0
            )
            if weight > 0:
                action_weights[action] = weight
                total_weight += weight
        
        if total_weight <= 0:
            return (-1, -1, 0.0)
        
        # Sample action based on weights
        actions = list(action_weights.keys())
        probs = [action_weights[a] / total_weight for a in actions]
        action = actions[np.random.choice(len(actions), p=probs)]
        
        # Sample next state based on SR values
        sr_values = self.retrieve_sr(current_state, action, target_time)
        
        if not sr_values:
            return (action, action, 0.0)
        
        # Normalize to probability distribution
        states = list(sr_values.keys())
        counts = [sr_values[s] for s in states]
        total = sum(counts)
        
        if total <= 0:
            return (action, action, 0.0)
        
        state_probs = [c / total for c in counts]
        next_state = states[np.random.choice(len(states), p=state_probs)]
        
        reward = self.R.get(next_state, 0.0)
        
        return (next_state, action, reward)


# =============================================================================
# Data Preparation for SR Dyna
# =============================================================================

def prepare_sr_dyna_data(user_df: pd.DataFrame,
                        config: Optional[SRDynaConfig] = None) -> Dict[str, Any]:
    """
    Prepare trajectory data for SR Dyna modeling.
    """
    if config is None:
        config = SRDynaConfig()
    
    df = user_df.sort_values('t_start').reset_index(drop=True)
    n_records = len(df)
    
    # Extract arrays
    states = df['cluster_id'].astype(int).to_numpy()
    date_array = df['date'].to_numpy()
    
    # Compute time angles
    time_angles = np.array([compute_time_angle(dt) for dt in df['t_end']])
    
    # Compute day sequence
    day_seq = compute_day_sequence(date_array, config.ref_date[0], 
                                   config.ref_date[1], config.ref_date[2])
    
    # Compute stay durations and rewards
    stay_minutes = (df['t_end'] - df['t_start']).dt.total_seconds() / 60.0
    stay_minutes = stay_minutes.to_numpy()
    stay_minutes = np.roll(stay_minutes, -1)
    stay_minutes[-1] = 0.0
    
    reward_array = compute_reward_array(stay_minutes, config.reward_type, 
                                        config.reward_param_init)
    
    # Compute actions
    actions = np.zeros(n_records, dtype=int)
    actions[-1] = -9
    for t in range(n_records - 1):
        if day_seq[t + 1] > day_seq[t]:
            actions[t] = -9  # End of day
        else:
            actions[t] = int(states[t + 1])
    
    # Same-day indicator
    same_day_next = np.zeros(n_records, dtype=bool)
    for t in range(n_records - 1):
        same_day_next[t] = (day_seq[t] == day_seq[t + 1])
    
    return {
        'states': states,
        'actions': actions,
        'day_seq': day_seq,
        'time_angles': time_angles,
        'date_array': date_array,
        'reward_array': reward_array,
        'same_day_next': same_day_next,
        'n_records': n_records,
    }


# =============================================================================
# Parameter Handling
# =============================================================================

def unpack_params_sr_dyna(theta: np.ndarray) -> Dict[str, float]:
    """
    Unpack SR Dyna parameters from optimization vector.
    
    Following Simple MF parameterization:
    - alpha = sigmoid(logit_alpha) ∈ (0, 1)  # 使用 logit 变换
    - beta = exp(log_beta) ∈ (0, +inf)       # 使用 log 变换
    - epsilon = sigmoid(logit_epsilon) ∈ (0, 1)  # 使用 logit 变换
    - phi = sigmoid(logit_phi) ∈ (0, 1)          # 使用 logit 变换
    """
    idx = 0
    
    # Learning rate: alpha = sigmoid(logit_alpha) ∈ (0, 1)
    logit_alpha = theta[idx]; idx += 1
    alpha = 1.0 / (1.0 + np.exp(-logit_alpha))
    
    # Softmax temperature: beta = exp(log_beta) ∈ (0, +inf)
    log_beta = theta[idx]; idx += 1
    beta = np.exp(log_beta)
    
    # Exploration: epsilon = sigmoid(logit_epsilon) ∈ (0, 1)
    logit_epsilon = theta[idx]; idx += 1
    epsilon = 1.0 / (1.0 + np.exp(-logit_epsilon))
    
    # Forgetting: phi = sigmoid(logit_phi) ∈ (0, 1)
    logit_phi = theta[idx]; idx += 1
    phi = 1.0 / (1.0 + np.exp(-logit_phi))
    
    # Clip values
    alpha = float(np.clip(alpha, 1e-6, 1.0 - 1e-6))
    epsilon = float(np.clip(epsilon, 1e-6, 1.0 - 1e-6))
    phi = float(np.clip(phi, 1e-6, 1.0 - 1e-6))
    
    return {
        'alpha': alpha,
        'beta': float(beta),
        'epsilon': epsilon,
        'phi': phi,
    }


def pack_params_sr_dyna(alpha: float, beta: float, epsilon: float, phi: float) -> np.ndarray:
    """Pack SR Dyna parameters into optimization vector (following Simple MF)."""
    return np.array([
        np.log(alpha / (1.0 - alpha)),   # logit_alpha
        np.log(beta),                      # log_beta
        np.log(epsilon / (1.0 - epsilon)),  # logit_epsilon
        np.log(phi / (1.0 - phi)),         # logit_phi
    ], dtype=np.float64)


# =============================================================================
# SR Dyna Simulation and Log-Likelihood
# =============================================================================

def simulate_and_loglik_sr_dyna(theta: np.ndarray,
                                sr_data: Dict[str, Any],
                                config: Optional[SRDynaConfig] = None) -> float:
    """
    Compute negative log-likelihood for SR Dyna model.
    
    Algorithm:
    1. At each step t:
       a. Get current state and time
       b. Apply memory decay if new day
       c. Compute Q-values using SR: Q(s,a) = SR[s,a] · R
       d. Sample action using softmax policy (with epsilon exploration)
       e. Add SR record from real transition
       f. Perform Dyna planning: simulate N steps, update SR
       g. Accumulate log-likelihood
    """
    if config is None:
        config = SRDynaConfig()
    
    # Unpack parameters
    params = unpack_params_sr_dyna(theta)
    alpha = params['alpha']
    beta = params['beta']
    epsilon = params['epsilon']
    phi = params['phi']
    
    # Extract data
    states = sr_data['states']
    actions = sr_data['actions']
    day_seq = sr_data['day_seq']
    time_angles = sr_data['time_angles']
    date_array = sr_data['date_array']
    reward_array = sr_data['reward_array']
    same_day_next = sr_data['same_day_next']
    n_records = sr_data['n_records']
    
    # Initialize SR memory
    memory = SRMemory(phi, config)
    
    # Track known states and actions
    known_states: Set[int] = {-1, 0}
    known_actions: Set[int] = {-9, -1, 0}
    
    loglik = 0.0
    
    for t in range(n_records):
        s = int(states[t])
        a = int(actions[t])
        r_t = float(reward_array[t])
        current_day = int(day_seq[t])
        current_date = int(date_array[t])
        time_angle = float(time_angles[t])
        
        # Apply decay if new day
        memory.decay(current_day)
        
        # Update known sets
        if a > 0:
            known_states.add(a)
            known_actions.add(a)
        
        s_perc = s if s in known_states else -1
        a_perc = a if a in known_actions else -1
        
        # Get available actions
        available_actions = sorted([act for act in known_actions if act != -1])
        
        # Compute Q-values using SR
        q_values = []
        for act in available_actions:
            q = memory.compute_Q(s_perc, act, time_angle, config.sigma_t_init)
            q_values.append(q)
        
        q_values = np.asarray(q_values, dtype=np.float64)
        
        # Softmax policy
        if len(q_values) > 0:
            logits = beta * q_values
            logits -= np.max(logits)
            softmax = np.exp(logits)
            softmax /= (np.sum(softmax) + 1e-12)
        else:
            softmax = np.array([1.0])
        
        # Map action to probability
        if a_perc in available_actions:
            idx_a = available_actions.index(a_perc)
            if idx_a < len(softmax):
                action_prob = (1.0 - epsilon) * softmax[idx_a]
            else:
                action_prob = epsilon / len(available_actions) if available_actions else epsilon
        elif a_perc == -1:
            action_prob = epsilon
        else:
            action_prob = epsilon / len(available_actions) if available_actions else epsilon
        
        loglik += np.log(action_prob + 1e-12)
        
        # Add SR record from real transition
        if t < n_records - 1:
            a_next = int(actions[t + 1])
            a_next_perc = a_next if a_next in known_actions else -1
            
            # Get next state
            if same_day_next[t]:
                next_state = int(states[t + 1])
            else:
                next_state = -9  # End of day
            
            # Add SR record (increment visit count)
            if a_perc != -1 and next_state > 0:
                memory.add_record(
                    state=s_perc,
                    action=a_perc,
                    future_state=next_state,
                    time_angle=time_angle,
                    day_seq=current_day,
                    record_date=current_date,
                    visit_count=1.0
                )
                # Update reward estimate
                memory.add_transition(s_perc, a_perc, next_state, r_t, is_planning=False)
            
            # Dyna planning: simulate additional steps
            for _ in range(config.n_planning_steps):
                # Random start state from visited states
                if len(memory.visit_counts) > 0:
                    visited_states = list(memory.visit_counts.keys())
                    counts = list(memory.visit_counts.values())
                    total = sum(counts)
                    probs = [c/total for c in counts]
                    plan_state = np.random.choice(visited_states, p=probs)
                else:
                    plan_state = s_perc
                
                # Simulate step (need to provide time - use current time for simplicity)
                plan_next_state, plan_action, plan_reward = memory.simulate_step(plan_state, time_angle)
                
                if plan_next_state > 0 and plan_action > 0:
                    memory.add_record(
                        state=plan_state,
                        action=plan_action,
                        future_state=plan_next_state,
                        time_angle=time_angle,
                        day_seq=current_day,
                        record_date=current_date,
                        visit_count=0.5 * alpha  # Smaller weight for planning
                    )
                    memory.add_transition(plan_state, plan_action, plan_next_state, 
                                         plan_reward, is_planning=True)
    
    return -loglik


# =============================================================================
# Model Fitting
# =============================================================================

def fit_sr_dyna_model(user_df: pd.DataFrame,
                      config: Optional[SRDynaConfig] = None,
                      verbose: bool = True) -> Dict[str, Any]:
    """Fit SR Dyna model for one user."""
    if config is None:
        config = SRDynaConfig()
    
    # Prepare data
    sr_data = prepare_sr_dyna_data(user_df, config)
    n_records = sr_data['n_records']
    
    if n_records < 2:
        return {
            'n_records': int(n_records),
            'log_likelihood': np.nan,
            'AIC': np.nan,
            'BIC': np.nan,
            'alpha': np.nan,
            'beta': np.nan,
            'epsilon': np.nan,
            'phi': np.nan,
            'converged': False,
            'n_iterations': 0,
            'optimization_message': 'Insufficient records',
        }
    
    # Initial parameters
    theta_init = pack_params_sr_dyna(
        alpha=config.alpha_init,
        beta=config.beta_init,
        epsilon=config.epsilon_init,
        phi=config.phi_init,
    )
    
    # Optimize (no bounds - following Simple MF approach for better optimization)
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        result = minimize(
            simulate_and_loglik_sr_dyna,
            theta_init,
            args=(sr_data, config),
            method='L-BFGS-B',
            options={'maxiter': config.maxiter, 'ftol': config.ftol, 'disp': False},
        )
    
    # Extract results
    fitted = unpack_params_sr_dyna(result.x)
    log_likelihood = -float(result.fun)
    
    k_params = 4  # alpha, beta, epsilon, phi (no gamma)
    aic = 2 * k_params - 2 * log_likelihood
    bic = k_params * np.log(max(n_records, 1)) - 2 * log_likelihood
    
    summary = {
        'n_records': int(n_records),
        'log_likelihood': float(log_likelihood),
        'AIC': float(aic),
        'BIC': float(bic),
        'alpha': float(fitted['alpha']),
        'beta': float(fitted['beta']),
        'epsilon': float(fitted['epsilon']),
        'phi': float(fitted['phi']),
        'converged': bool(result.success),
        'n_iterations': int(result.nit),
        'optimization_message': str(result.message),
    }
    
    if verbose:
        print(f"SR Dyna model fitting {'converged' if result.success else 'did not converge'}.")
        print(f"  Log-likelihood: {log_likelihood:.2f}, AIC: {aic:.2f}, BIC: {bic:.2f}")
        print(f"  alpha (SR rate): {fitted['alpha']:.4f}")
        print(f"  beta (softmax temp): {fitted['beta']:.4f}")
        print(f"  epsilon (explore): {fitted['epsilon']:.4f}")
        print(f"  phi (forgetting): {fitted['phi']:.4f}")
    
    return summary


def fit_sr_dyna_for_all_users(users_dict: Dict[int, Any],
                              config: Optional[SRDynaConfig] = None,
                              sample_size: Optional[int] = None,
                              verbose: bool = True) -> pd.DataFrame:
    """Fit SR Dyna model for all users."""
    if config is None:
        config = SRDynaConfig()
    
    user_ids = list(users_dict.keys())
    if sample_size is not None:
        user_ids = user_ids[:sample_size]
    
    results = []
    for i, user_id in enumerate(user_ids):
        if verbose:
            print(f"[{i+1}/{len(user_ids)}] Fitting SR Dyna for user {user_id}...", end='')
        
        user = users_dict[user_id]
        user_df = user.to_dataframe()
        
        try:
            t0 = time.time()
            result = fit_sr_dyna_model(user_df, config, verbose=False)
            elapsed = time.time() - t0
            if verbose:
                print(f"\tUser {user_id} SR Dyna fit time: {elapsed:.2f}s")
            result['user_id'] = user_id
            result['fit_time_seconds'] = elapsed
            results.append(result)
        except Exception as e:
            if verbose:
                print(f"  Error: {e}")
            results.append({
                'user_id': user_id,
                'n_records': 0,
                'log_likelihood': np.nan,
                'AIC': np.nan,
                'BIC': np.nan,
                'alpha': np.nan,
                'beta': np.nan,
                'epsilon': np.nan,
                'phi': np.nan,
                'converged': False,
                'n_iterations': None,
                'optimization_message': str(e),
                'fit_time_seconds': np.nan,
            })
    
    return pd.DataFrame(results)