# In [6]:
#!/usr/bin/env python3
"""
Bat-Rat Seasonal Behavior Feature Engineering
Polished and completed version for Assignment 3
"""

import os
import re
import warnings
from datetime import datetime, timedelta

import numpy as np
import pandas as pd

# Suppress every warning for the whole process (this includes pandas
# FutureWarning / DeprecationWarning output), presumably to keep notebook
# output readable.
# NOTE(review): this also hides pandas deprecation notices that point at code
# likely to break on upgrade; consider scoping it with warnings.catch_warnings().
warnings.filterwarnings("ignore")


class BatSeasonalFeatureEngineer:
    """Engineer temporal, seasonal, encounter and behaviour features from a DataFrame.

    The working data lives in ``self.df``. Each ``prepare_*`` / ``create_*``
    method adds columns to it in place, and some steps also re-sort the rows
    and reset the index. A typical call order is::

        analyze_data_structure -> prepare_temporal_features ->
        create_seasonal_features -> create_encounter_features ->
        create_seasonal_interaction_features -> create_behavior_features ->
        get_final_feature_set
    """

    # Keywords of at most this many characters are matched only against whole
    # tokens of a column name (optionally plural), so "rat" matches "rat_count"
    # but not "temperature", and "id" matches "bat_id" but not "humidity".
    _SHORT_KEYWORD_LEN = 3

    def __init__(self, data_path=None, df=None):
        """Load the working data from a file or copy it from an existing DataFrame.

        Args:
            data_path: Path (str or path-like) to an ``.xls``/``.xlsx`` file or a
                CSV file; any other extension is parsed as CSV. Only used when
                the path exists.
            df: DataFrame to copy when ``data_path`` is not given or does not exist.

        Raises:
            RuntimeError: the file exists but could not be parsed.
            ValueError: ``data_path`` does not exist and no ``df`` was given, or
                neither argument was given.
        """
        if data_path and os.path.exists(data_path):
            path_str = os.fspath(data_path)  # accept pathlib.Path as well as str
            print(f"📁 Loading data from: {data_path}")
            try:
                if path_str.lower().endswith((".xls", ".xlsx")):
                    self.df = pd.read_excel(path_str)
                else:
                    # .csv and any unrecognised extension are parsed as CSV
                    self.df = pd.read_csv(path_str)
            except Exception as e:
                raise RuntimeError(f"Error loading data: {e}") from e
        elif df is not None:
            self.df = df.copy()
        elif data_path:
            raise ValueError(f"Data file not found: {data_path}")
        else:
            raise ValueError("Provide either data_path or df parameter")

        self.original_shape = self.df.shape
        print(f"✅ Data loaded: {self.original_shape[0]} rows, {self.original_shape[1]} cols")

        # Column roles, filled by _detect_column_types() or by explicit arguments
        self.date_column = None
        self.encounter_columns = []
        self.behavior_columns = []
        self.id_columns = []

    def analyze_data_structure(self):
        """Print a structural overview of ``self.df`` and auto-detect key columns.

        Prints dtypes, missing-value counts, the detected column roles and the
        first five rows.

        Returns:
            DataFrame: ``self.df.describe(include="all")``.
        """
        print("\n" + "=" * 60)
        print("🔍 DATA STRUCTURE ANALYSIS")
        print("=" * 60)
        print(f"Dataset shape: {self.df.shape}\n")
        print("📊 Columns and dtypes:")
        for i, (col, dtype) in enumerate(self.df.dtypes.items()):
            print(f"  {i+1:2d}. {col:<30} ({dtype})")

        missing = self.df.isnull().sum()
        if missing.sum() > 0:
            print("\n❓ Missing values (only columns with >0):")
            missing_pct = (missing / len(self.df) * 100).round(2)
            for col in missing[missing > 0].index:
                print(f"  {col:<30} {missing[col]:>6} ({missing_pct[col]:>6.2f}%)")
        else:
            print("\n  No missing values found! 🎉")

        self._detect_column_types()

        print("\n🎯 Key columns identified:")
        if self.date_column:
            print(f"  Date column: {self.date_column}")
        if self.encounter_columns:
            print(f"  Encounter columns: {self.encounter_columns}")
        if self.behavior_columns:
            print(f"  Behavior columns: {self.behavior_columns}")
        if self.id_columns:
            print(f"  ID columns: {self.id_columns}")

        print("\n📋 Sample data (first 5 rows):")
        print(self.df.head())

        return self.df.describe(include="all")

    @staticmethod
    def _name_tokens(name):
        """Split a column name into lowercase tokens on camelCase and non-alphanumerics."""
        spaced = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", str(name))
        return [tok for tok in re.split(r"[^0-9a-z]+", spaced.lower()) if tok]

    @classmethod
    def _name_matches(cls, name, keywords):
        """Return True if column *name* matches any of *keywords*.

        Short keywords (see ``_SHORT_KEYWORD_LEN``) must equal a whole token
        (or its plural with a trailing "s"); longer keywords match as substrings.
        """
        lowered = str(name).lower()
        tokens = set(cls._name_tokens(name))
        for kw in keywords:
            if len(kw) <= cls._SHORT_KEYWORD_LEN:
                if kw in tokens or f"{kw}s" in tokens:
                    return True
            elif kw in lowered:
                return True
        return False

    def _detect_column_types(self):
        """Heuristically assign column roles from column names.

        Sets ``date_column`` to the first matching column (left unchanged when
        none matches) and *replaces* the encounter/behaviour/ID column lists,
        so calling this repeatedly does not accumulate duplicates.
        """
        cols = list(self.df.columns)

        date_kw = ["date", "time", "timestamp", "day", "month", "year"]
        for col in cols:
            if self._name_matches(col, date_kw):
                self.date_column = col
                break

        encounter_kw = ["encounter", "rat", "contact", "interaction", "prey", "encounters"]
        self.encounter_columns = [c for c in cols if self._name_matches(c, encounter_kw)]

        behavior_kw = ["activity", "behavior", "behaviour", "movement", "flight", "foraging"]
        self.behavior_columns = [c for c in cols if self._name_matches(c, behavior_kw)]

        id_kw = ["id", "bat", "individual", "animal", "subject"]
        self.id_columns = [
            c for c in cols
            if self._name_matches(c, id_kw) and "date" not in str(c).lower()
        ]

    def prepare_temporal_features(self, date_col=None, date_format=None):
        """Parse the date column into ``date`` and derive calendar columns from it.

        Adds ``date``, ``year``, ``month``, ``day``, ``day_of_year``, ``weekday``
        (Monday=0) and ``week`` (ISO week number). Then it sorts rows by ``date``
        and resets the index. Values that fail to parse become NaT and sort last.

        Args:
            date_col: Column to parse; overrides the auto-detected one.
            date_format: Optional ``strftime``-style format passed to ``pd.to_datetime``.

        Raises:
            ValueError: no date column was given or detected.
        """
        print("\n" + "=" * 60)
        print("📅 TEMPORAL FEATURE ENGINEERING")
        print("=" * 60)

        if date_col:
            self.date_column = date_col

        if not self.date_column:
            raise ValueError("No date column specified. Provide date_col or ensure auto-detection found a date column.")

        # parse dates (format inference is the default since pandas 2.0;
        # infer_datetime_format is deprecated there)
        if date_format:
            self.df["date"] = pd.to_datetime(self.df[self.date_column], format=date_format, errors="coerce")
        else:
            self.df["date"] = pd.to_datetime(self.df[self.date_column], errors="coerce")

        null_dates = self.df["date"].isnull().sum()
        if null_dates > 0:
            print(f"⚠️  {null_dates} date(s) could not be parsed and are NaT")

        self.df = self.df.sort_values("date").reset_index(drop=True)

        self.df["year"] = self.df["date"].dt.year
        self.df["month"] = self.df["date"].dt.month
        self.df["day"] = self.df["date"].dt.day
        self.df["day_of_year"] = self.df["date"].dt.dayofyear
        self.df["weekday"] = self.df["date"].dt.dayofweek
        # isocalendar() yields a nullable integer column; plain int cannot hold
        # <NA> for unparseable dates, so keep a nullable Int64 in that case.
        week = self.df["date"].dt.isocalendar().week
        self.df["week"] = week.astype(int) if week.notna().all() else week.astype("Int64")

        if self.df["date"].notna().any():
            date_range = f"{self.df['date'].min().strftime('%Y-%m-%d')} to {self.df['date'].max().strftime('%Y-%m-%d')}"
            print(f"📊 Date range: {date_range}")
            print(f"📊 Total days: {(self.df['date'].max() - self.df['date'].min()).days}")
        else:
            print("⚠️  No valid dates to report range")

        print("✅ Basic temporal features created")

    def create_seasonal_features(self, location="australia"):
        """Add season labels, cyclic month encoding, food indices and transition flags.

        Seasons are 3-month blocks (Dec-Feb, Mar-May, Jun-Aug, Sep-Nov).
        ``location="australia"`` (case-insensitive) uses southern-hemisphere
        names; any other value uses northern-hemisphere names. Rows with no
        valid date get a NaN season, a 0.5 scarcity index, NaN season-start
        columns and a 0 transition flag.

        Raises:
            ValueError: ``prepare_temporal_features`` has not been run.
        """
        print("\n" + "=" * 60)
        print("🌟 SEASONAL FEATURE ENGINEERING")
        print("=" * 60)

        if "month" not in self.df.columns or "date" not in self.df.columns:
            raise ValueError("Call prepare_temporal_features first (to create month/day columns).")

        loc = location.lower()
        if loc == "australia":
            # order matches the blocks Dec-Feb, Mar-May, Jun-Aug, Sep-Nov
            block_names = ("summer", "autumn", "winter", "spring")
            season_food_map = {"winter": 0.9, "spring": 0.1, "summer": 0.4, "autumn": 0.6}
        else:
            block_names = ("winter", "spring", "summer", "autumn")
            season_food_map = {"winter": 0.9, "spring": 0.1, "summer": 0.3, "autumn": 0.7}

        # (m % 12) // 3 maps Dec/Jan/Feb -> 0, Mar-May -> 1, Jun-Aug -> 2, Sep-Nov -> 3
        season_by_month = {m: block_names[(m % 12) // 3] for m in range(1, 13)}
        # NaN months (unparseable dates) get NaN instead of a fall-through season
        self.df["season"] = self.df["month"].map(lambda m: season_by_month.get(m, np.nan))
        self.df["is_winter"] = (self.df["season"] == "winter").astype(int)
        self.df["is_spring"] = (self.df["season"] == "spring").astype(int)
        self.df["is_summer"] = (self.df["season"] == "summer").astype(int)
        self.df["is_autumn"] = (self.df["season"] == "autumn").astype(int)

        self.df["month_sin"] = np.sin(2 * np.pi * self.df["month"] / 12)
        self.df["month_cos"] = np.cos(2 * np.pi * self.df["month"] / 12)

        self.df["food_scarcity_index"] = self.df["season"].map(season_food_map).fillna(0.5)
        self.df["food_abundance_index"] = 1 - self.df["food_scarcity_index"]

        # Season start: both hemispheres use the same month blocks, so the start
        # is the 1st of (month - month % 3), except Jan/Feb whose block began on
        # 1 December of the previous year.
        def season_start_date(ts):
            if pd.isna(ts):
                return pd.NaT
            if ts.month < 3:
                return pd.Timestamp(year=ts.year - 1, month=12, day=1, tz=ts.tz)
            return pd.Timestamp(year=ts.year, month=ts.month - ts.month % 3, day=1, tz=ts.tz)

        self.df["season_start"] = pd.to_datetime(
            pd.Series([season_start_date(ts) for ts in self.df["date"]], index=self.df.index, dtype=object)
        )
        self.df["days_since_season_start"] = (self.df["date"] - self.df["season_start"]).dt.days.clip(lower=0)

        # Transition window: within 14 calendar days of a season boundary
        # (1 Mar, 1 Jun, 1 Sep, 1 Dec). Because the window is shorter than any
        # month, the only boundaries in reach are the 1st of the current month
        # (when it is a boundary month) or the 1st of the next month.
        dates = self.df["date"]
        month = dates.dt.month
        day = dates.dt.day
        after_boundary = month.isin([3, 6, 9, 12]) & (day - 1 <= 14)
        before_boundary = month.isin([2, 5, 8, 11]) & (dates.dt.days_in_month - day + 1 <= 14)
        self.df["is_transition_period"] = (after_boundary | before_boundary).astype(int)

        print("📊 Seasonal distribution:")
        seasonal_counts = self.df["season"].value_counts()
        for s, cnt in seasonal_counts.items():
            pct = cnt / len(self.df) * 100
            sc = season_food_map.get(s, np.nan)
            print(f"  {s.capitalize():<8} {cnt:>6} ({pct:5.1f}%) - Food scarcity: {sc}")

        print("✅ Seasonal features created")

    def create_encounter_features(self, encounter_cols=None):
        """Add rolling, indicator, recency and frequency features per encounter column.

        When no encounter column is known, a reproducible simulated
        ``encounter_count`` is created from ``food_abundance_index``. Rows are
        sorted by (first ID column, date) when an ID column exists, else by date.
        Rolling windows count rows, not calendar days.

        Args:
            encounter_cols: Column name or list of names overriding auto-detection.

        Raises:
            ValueError: simulation is needed but seasonal features are missing.
        """
        print("\n" + "=" * 60)
        print("🐀 ENCOUNTER FEATURE ENGINEERING")
        print("=" * 60)

        if encounter_cols:
            if isinstance(encounter_cols, str):
                encounter_cols = [encounter_cols]
            self.encounter_columns = list(encounter_cols)

        # If no encounter col found, simulate a reasonable one (seeded reproducible)
        if not self.encounter_columns:
            if "food_abundance_index" not in self.df.columns:
                raise ValueError("Call create_seasonal_features first (simulated encounters need food_abundance_index).")
            # A private RandomState(42) draws the same values as
            # np.random.seed(42) + np.random.poisson without resetting global RNG state.
            rng = np.random.RandomState(42)
            base = rng.poisson(1.5, len(self.df))
            seasonal_multiplier = self.df["food_abundance_index"] * 1.8 + 0.2
            self.df["encounter_count"] = np.maximum(0, (base * seasonal_multiplier).astype(int))
            self.encounter_columns = ["encounter_count"]
            print("✅ Created simulated encounter_count")

        # Sorting for rolling/shifts
        if self.id_columns:
            id_col = self.id_columns[0]
            self.df = self.df.sort_values([id_col, "date"]).reset_index(drop=True)
            group_col = id_col
        else:
            self.df = self.df.sort_values("date").reset_index(drop=True)
            group_col = None

        for enc_col in self.encounter_columns:
            print(f"\n🔍 Processing encounter column: {enc_col}")
            total = self.df[enc_col].sum()
            mean = self.df[enc_col].mean()
            print(f"  Total: {total}, Mean: {mean:.3f}")

            for w in (7, 14, 30):
                if group_col:
                    # reset_index(0) drops the group level so results align on the row index
                    self.df[f"{enc_col}_rolling_{w}d"] = (
                        self.df.groupby(group_col)[enc_col]
                        .rolling(w, min_periods=1)
                        .mean()
                        .reset_index(0, drop=True)
                    )
                    self.df[f"{enc_col}_std_{w}d"] = (
                        self.df.groupby(group_col)[enc_col]
                        .rolling(w, min_periods=1)
                        .std()
                        .reset_index(0, drop=True)
                    ).fillna(0)
                else:
                    self.df[f"{enc_col}_rolling_{w}d"] = self.df[enc_col].rolling(w, min_periods=1).mean()
                    self.df[f"{enc_col}_std_{w}d"] = self.df[enc_col].rolling(w, min_periods=1).std().fillna(0)

            # unique has_encounter column per encounter column (avoid overwrite)
            has_col = f"has_{enc_col}"
            self.df[has_col] = (self.df[enc_col] > 0).astype(int)

            # days_since_last_<enc>: 0 on encounter rows, days since the most
            # recent earlier encounter otherwise (within the individual when an
            # ID column exists), -1 when there is no prior encounter.
            had_encounter = self.df[has_col] == 1
            encounter_dates = self.df["date"].where(had_encounter)
            if group_col:
                last_encounter = encounter_dates.groupby(self.df[group_col]).ffill()
            else:
                last_encounter = encounter_dates.ffill()
            days_since = (self.df["date"] - last_encounter).dt.days
            days_since = days_since.where(~had_encounter, 0)
            self.df[f"days_since_last_{enc_col}"] = days_since.fillna(-1)

            # encounter frequency categories (quantile bins)
            if self.df[enc_col].max() > 0:
                q33, q67 = self.df[enc_col].quantile([0.33, 0.67]).values
                edges, labels = [-0.1], []
                for edge, label in zip((q33, q67, np.inf), ("low", "medium", "high")):
                    # Skip a non-increasing edge (e.g. q33 == q67 for zero-heavy
                    # counts): its interval is empty and pd.cut rejects duplicate edges.
                    if edge > edges[-1]:
                        edges.append(edge)
                        labels.append(label)
                self.df[f"{enc_col}_frequency"] = pd.cut(
                    self.df[enc_col],
                    bins=edges,
                    labels=labels,
                ).astype(object)

            print("  ✅ Rolling windows and days-since-last created")

    def create_seasonal_interaction_features(self):
        """Add season x encounter interaction, contrast, weekend and lag features.

        Uses the first encounter column as the primary one. Lag features shift
        by 30 and 90 *rows*, which approximate month/season lags only when
        there is one row per day.
        """
        print("\n" + "=" * 60)
        print("🎯 SEASONAL-ENCOUNTER INTERACTIONS")
        print("=" * 60)

        if not self.encounter_columns:
            print("❌ No encounter columns; run create_encounter_features first")
            return

        primary = self.encounter_columns[0]
        print(f"Primary encounter: {primary}")

        # season-encounter interactions
        self.df["winter_encounters"] = self.df["is_winter"] * self.df[primary]
        self.df["spring_encounters"] = self.df["is_spring"] * self.df[primary]
        self.df["summer_encounters"] = self.df["is_summer"] * self.df[primary]
        self.df["autumn_encounters"] = self.df["is_autumn"] * self.df[primary]

        # scarcity/abundance interactions (eps avoids division by zero)
        eps = 1e-6
        self.df["scarcity_encounter_ratio"] = self.df[primary] / (self.df["food_scarcity_index"] + eps)
        self.df["abundance_encounter_product"] = self.df[primary] * self.df["food_abundance_index"]

        if f"{primary}_rolling_7d" in self.df.columns:
            self.df["winter_rolling_encounters"] = self.df["is_winter"] * self.df[f"{primary}_rolling_7d"]
            self.df["spring_rolling_encounters"] = self.df["is_spring"] * self.df[f"{primary}_rolling_7d"]

        # seasonal mean contrasts per individual (if id available)
        if self.id_columns:
            id_col = self.id_columns[0]
            seasonal_means = (
                self.df.groupby([id_col, "season"])[primary]
                .mean()
                .unstack(fill_value=np.nan)
            )
            if "winter" in seasonal_means.columns and "spring" in seasonal_means.columns:
                contrast = seasonal_means["spring"] - seasonal_means["winter"]
                # map (not merge) keeps the row index and overwrites on re-runs
                # instead of creating _x/_y duplicate columns
                self.df["winter_spring_contrast"] = self.df[id_col].map(contrast)

        # weekend interactions, transition interactions
        self.df["weekend"] = (self.df["weekday"] >= 5).astype(int)
        self.df["winter_weekend_encounters"] = self.df["is_winter"] * self.df["weekend"] * self.df[primary]
        if "is_transition_period" in self.df.columns:
            self.df["transition_encounters"] = self.df["is_transition_period"] * self.df[primary]

        # lagged effects - note: these shifts are row-based approximations
        if self.id_columns:
            id_col = self.id_columns[0]
            self.df["prev_month_encounters"] = self.df.groupby(id_col)[primary].shift(30).fillna(0)
            self.df["prev_season_effect"] = self.df.groupby(id_col)["food_scarcity_index"].shift(90).fillna(0.5)
        else:
            self.df["prev_month_encounters"] = self.df[primary].shift(30).fillna(0)
            self.df["prev_season_effect"] = self.df["food_scarcity_index"].shift(90).fillna(0.5)

        print("✅ Interaction features created")

    def create_behavior_features(self):
        """Add z-score, rolling-mean and winter/spring interaction features per behaviour column.

        Non-numeric behaviour columns are skipped. Season interactions are only
        added when the corresponding ``is_<season>`` flag exists.
        """
        print("\n" + "=" * 60)
        print("🦇 BEHAVIOR FEATURE ENGINEERING")
        print("=" * 60)
        if not self.behavior_columns:
            print("📊 No behavior columns detected. Skipping.")
            return

        for bcol in self.behavior_columns:
            if bcol not in self.df.columns:
                continue
            if not pd.api.types.is_numeric_dtype(self.df[bcol]):
                print(f"  ⚠️ Skipping non-numeric behavior column: {bcol}")
                continue
            self.df[f"{bcol}_zscore"] = (self.df[bcol] - self.df[bcol].mean()) / (self.df[bcol].std() + 1e-6)

            for w in (7, 14):
                if self.id_columns:
                    id_col = self.id_columns[0]
                    self.df[f"{bcol}_rolling_{w}d"] = (
                        self.df.groupby(id_col)[bcol]
                        .rolling(w, min_periods=1)
                        .mean()
                        .reset_index(0, drop=True)
                    )
                else:
                    self.df[f"{bcol}_rolling_{w}d"] = self.df[bcol].rolling(w, min_periods=1).mean()

            for season in ("winter", "spring"):
                flag = f"is_{season}"
                if flag in self.df.columns:
                    self.df[f"{season}_{bcol}"] = self.df[flag] * self.df[bcol]

            print(f"  ✅ Created behavior features for {bcol}")

    def analyze_seasonal_patterns(self):
        """Print per-season encounter statistics and a winter-vs-spring comparison.

        Returns:
            DataFrame of per-season count/mean/std/min/max for the primary
            encounter column, or None when no encounter column is known.
        """
        print("\n" + "=" * 60)
        print("📈 SEASONAL PATTERN ANALYSIS")
        print("=" * 60)

        if not self.encounter_columns:
            print("❌ No encounter columns; run create_encounter_features first")
            return

        primary = self.encounter_columns[0]
        seasonal_summary = self.df.groupby("season")[primary].agg(["count", "mean", "std", "min", "max"]).round(3)
        print("🔸 Seasonal Encounter Summary:")
        print(seasonal_summary)

        # comparisons winter vs spring
        winter = self.df[self.df["is_winter"] == 1][primary].dropna()
        spring = self.df[self.df["is_spring"] == 1][primary].dropna()
        if len(winter) > 0 and len(spring) > 0:
            w_mean = winter.mean()
            s_mean = spring.mean()
            print(f"\n🎯 Winter mean: {w_mean:.3f} | Spring mean: {s_mean:.3f}")
            ratio = (s_mean / w_mean) if w_mean != 0 else np.nan
            print(f"  Spring/Winter ratio: {ratio:.3f}")

            # Welch's t-test (unequal variances); scipy is optional
            try:
                from scipy import stats
                t_stat, p_value = stats.ttest_ind(winter, spring, equal_var=False, nan_policy="omit")
                print(f"  T-test: t={t_stat:.3f}, p={p_value:.3f}")
                print("  🎉 Significant difference" if p_value < 0.05 else "  📊 No significant difference")
            except Exception as e:
                print(f"  ⚠️ Could not run t-test: {e}")

        if "food_scarcity_index" in self.df.columns:
            corr = self.df["food_scarcity_index"].corr(self.df[primary])
            print(f"  Food scarcity vs encounters correlation: {corr:.3f}")

        return seasonal_summary

    def get_final_feature_set(self, save_path=None):
        """Select the curated feature columns that exist in ``self.df``.

        Each column appears at most once, in first-listed order.

        Args:
            save_path: Optional CSV path to write the selection to (without index).

        Returns:
            DataFrame: a copy of the selected columns.
        """
        print("\n" + "=" * 60)
        print("🎯 FINAL FEATURE SET")
        print("=" * 60)

        # core temporal & seasonal features
        core = [
            "date", "year", "month", "day", "day_of_year", "weekday", "week",
            "season", "is_winter", "is_spring", "is_summer", "is_autumn",
            "month_sin", "month_cos", "days_since_season_start", "is_transition_period",
            "food_scarcity_index", "food_abundance_index"
        ]

        # per-encounter-column features
        extras = []
        for enc in self.encounter_columns:
            extras.append(enc)
            for w in (7, 14, 30):
                extras += [f"{enc}_rolling_{w}d", f"{enc}_std_{w}d"]
            extras += [f"has_{enc}", f"days_since_last_{enc}", f"{enc}_frequency"]

        # interaction features are built from the primary encounter column only,
        # so list them once rather than once per encounter column
        if self.encounter_columns:
            extras += [
                "winter_encounters", "spring_encounters", "scarcity_encounter_ratio",
                "abundance_encounter_product", "prev_month_encounters", "prev_season_effect",
                "winter_rolling_encounters",
            ]

        # behavior extras
        beh_extras = []
        for b in self.behavior_columns:
            beh_extras += [b, f"{b}_zscore", f"{b}_rolling_7d", f"{b}_rolling_14d", f"winter_{b}", f"spring_{b}"]

        # dict.fromkeys de-duplicates while preserving order; the membership
        # filter avoids KeyError for features that were never created
        candidate_cols = dict.fromkeys(core + extras + beh_extras)
        final_cols = [c for c in candidate_cols if c in self.df.columns]

        final_df = self.df[final_cols].copy()

        if save_path:
            final_df.to_csv(save_path, index=False)
            print(f"✅ Saved final feature set to: {save_path}")

        print(f"✅ Final feature set prepared: {final_df.shape[0]} rows, {final_df.shape[1]} cols")
        return final_df

