In [None]:
# === Colab-ready: Simulação híbrida Plotly + Folium (otimizada) ===
# 10 tubarões, 3 dias (ping a cada 3h), global, CHL heat, hotspots dinâmicos, trajetórias, predição (RF).
# Copiar/colar no Google Colab e executar.

# 0) Instalar dependências (execute apenas uma vez)
!pip install -q plotly folium pandas numpy scikit-learn tqdm

# 1) imports
import math, random, pickle, json
from datetime import datetime, timedelta
import numpy as np
import pandas as pd
from tqdm import tqdm
import plotly.graph_objects as go
import plotly.io as pio
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import folium
from folium.plugins import TimestampedGeoJson
from IPython.display import HTML, display

pio.renderers.default = "notebook"  # intended to show figures inline in Colab; NOTE(review): confirm this renderer displays there

# 2) Parameters (tuned for speed)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

N_SHARKS = 10
DAYS = 3
PING_HOURS = 3
PTS_PER_DAY = int(24 / PING_HOURS)
TOTAL_STEPS = DAYS * PTS_PER_DAY   # 24 frames
# Midnight of the current day, taken from the naive UTC clock.
START_DATE = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)

# Global domain (whole world)
LAT_MIN, LAT_MAX = -80.0, 80.0
LON_MIN, LON_MAX = -180.0, 180.0

# Reduced grid used for the CHL field
GRID_LAT = 32
GRID_LON = 64
lat_centers = np.linspace(LAT_MIN, LAT_MAX, GRID_LAT)
lon_centers = np.linspace(LON_MIN, LON_MAX, GRID_LON)
LonG, LatG = np.meshgrid(lon_centers, lat_centers)

# Bin edges for the hotspot histogram
lat_edges = np.linspace(LAT_MIN, LAT_MAX, GRID_LAT + 1)
lon_edges = np.linspace(LON_MIN, LON_MAX, GRID_LON + 1)

# Light eddies: each tuple is (cx0, cy0, vx, vy, sigma, amp).
# The tuple elements are evaluated left to right, so the random draws happen
# in the same order as they are listed.
N_EDDIES = 12
eddy_params = [
    (
        np.random.uniform(LON_MIN, LON_MAX),   # cx0: initial centre longitude
        np.random.uniform(LAT_MIN, LAT_MAX),   # cy0: initial centre latitude
        np.random.uniform(-0.06, 0.06),        # vx: longitude drift per time unit
        np.random.uniform(-0.06, 0.06),        # vy: latitude drift per time unit
        np.random.uniform(1.0, 4.0),           # sigma: Gaussian width
        np.random.uniform(-1.2, 1.2),          # amp: signed amplitude
    )
    for _ in range(N_EDDIES)
]

# campos ambiente (tempo-dependentes)
def ssh_field(t):
    """Return the synthetic sea-surface-height anomaly on the grid at time ``t``.

    Every entry of ``eddy_params`` is ``(cx0, cy0, vx, vy, sigma, amp)`` and
    contributes one Gaussian bump of amplitude ``amp`` and width ``sigma``
    whose centre moves linearly from ``(cx0, cy0)`` with velocity ``(vx, vy)``.

    Args:
        t: Time value; the caller passes the simulation step index.

    Returns:
        Array with the same shape as ``LonG`` holding the summed bumps.
    """
    f = np.zeros_like(LonG)
    for (cx0, cy0, vx, vy, sigma, amp) in eddy_params:
        cx = cx0 + vx * t
        cy = cy0 + vy * t
        # Squared distance to the eddy centre. The exponent operators had been
        # stripped from the source, which left this line syntactically invalid.
        dist_sq = (LonG - cx) ** 2 + (LatG - cy) ** 2
        f += amp * np.exp(-dist_sq / (2 * sigma ** 2))
    return f

def sst_field(t):
    """Return a noisy synthetic sea-surface-temperature grid for time ``t``.

    The field is a latitude-dependent baseline (20.0 at the equator, dropping
    0.05 per degree of absolute latitude), plus a sinusoid in ``t`` with
    period 365 and amplitude 1.2, plus Gaussian noise scaled by 0.45 that is
    redrawn on every call.
    """
    annual_cycle = 1.2 * np.sin(2 * math.pi * (t / 365.0))
    latitudinal = 20.0 - 0.05 * np.abs(LatG)
    jitter = 0.45 * np.random.randn(*latitudinal.shape)
    warmed = latitudinal + annual_cycle
    return warmed + jitter

def chl_field(t, ssh):
    """Return a noisy synthetic chlorophyll grid driven by the SSH anomaly.

    The field is a 0.08 background plus a Gaussian-shaped bloom centred at
    longitude -20, latitude 10, plus 0.6 times the positive part of ``ssh``,
    plus Gaussian noise scaled by 0.12 that is redrawn on every call.

    Args:
        t: Accepted for signature parity with the other fields; not used.
        ssh: SSH anomaly array broadcastable against the grid.

    Returns:
        Array of chlorophyll values, floored at 0.01.
    """
    # Squared distance to the bloom centre. The stripped exponents had turned
    # this into a linear term, so the bloom was not a Gaussian.
    dist_sq = (LonG + 20) ** 2 + (LatG - 10) ** 2
    base = 0.08 + 1.0 * np.exp(-0.0009 * dist_sq)
    chl = base + 0.6 * np.maximum(0, ssh) + 0.12 * np.random.randn(*base.shape)
    return np.clip(chl, 0.01, None)

def current_field(t):
    """Return the synthetic current grid at time ``t``.

    The value is 0.18 plus a sinusoid of amplitude 0.08 whose phase is linear
    in longitude, latitude and ``t``.
    """
    phase = 0.02 * LonG + 0.03 * LatG + 0.05 * t
    ripple = 0.08 * np.sin(phase)
    return 0.18 + ripple

def depth_field():
    """Return the static synthetic depth grid.

    The value is 50 plus up to 5000 scaled by how far the longitude is from
    +/-180 (clipped to [0, 1]), plus a sinusoid in latitude of amplitude 120.
    """
    lon_fraction = np.clip((180 - np.abs(LonG)) / 360.0, 0, 1)
    lat_ripple = 120 * np.sin(0.03 * LatG)
    return 50 + 5000 * lon_fraction + lat_ripple

# 3) Simulate trajectories (one record per ping)
# depth_field() takes no time argument and ssh_field() draws no random numbers,
# so both are evaluated once up front instead of once per shark per step.
# sst_field() and chl_field() stay inside the loop because they draw noise and
# the order of random draws determines the simulated tracks.
depth = depth_field()
ssh_by_step = [ssh_field(s) for s in range(TOTAL_STEPS)]

records = []
for sid in range(N_SHARKS):
    lat = np.random.uniform(-40, 40)
    lon = np.random.uniform(-140, 140)
    heading = np.random.uniform(0, 2*math.pi)
    for step in range(TOTAL_STEPS):
        ssh = ssh_by_step[step]
        sst = sst_field(step)
        chl = chl_field(step, ssh)
        cur = current_field(step)
        # Nearest grid cell to the shark's current position
        i_lat = int(np.argmin(np.abs(lat_centers - lat)))
        i_lon = int(np.argmin(np.abs(lon_centers - lon)))
        local = {
            'chl': float(chl[i_lat, i_lon]),
            'ssha': float(ssh[i_lat, i_lon]),
            'sst': float(sst[i_lat, i_lon]),
            'current': float(cur[i_lat, i_lon]),
            'depth': float(depth[i_lat, i_lon])
        }
        # Synthetic foraging probability: weighted indicator score pushed
        # through a logistic curve.
        forage_score = 0.45*(local['chl'] > 0.9) + 0.25*(local['ssha'] > 0.12) + 0.2*(15 < local['sst'] < 28)
        p_forage = 1 / (1 + math.exp(-6 * (forage_score - 0.12)))
        is_forage = int(np.random.rand() < p_forage)
        odba = np.random.gamma(1.0 + 0.6 * (1 - is_forage), scale=0.25)
        # Movement
        if is_forage and np.random.rand() < 0.85:
            # Foraging: head towards the candidate direction with the highest CHL.
            # NOTE(review): the 0.02-degree candidate offsets are much smaller
            # than the ~5-degree grid spacing, so candidates usually fall in the
            # same cell and the first angle wins the strict '>' — confirm intent.
            best_ang = None; best_val = -1e9
            for ang in np.linspace(0, 2*math.pi, 12, endpoint=False):
                cand_lat = lat + 0.02 * math.sin(ang)
                cand_lon = lon + 0.02 * math.cos(ang)
                i_la2 = int(np.argmin(np.abs(lat_centers - cand_lat)))
                i_lo2 = int(np.argmin(np.abs(lon_centers - cand_lon)))
                val = chl[i_la2, i_lo2]
                if val > best_val:
                    best_val = val; best_ang = ang
            heading = best_ang + np.random.normal(0, 0.45)
            step_len = max(0.005, np.random.normal(0.1, 0.03))
        else:
            # Transit: random turn biased by a position-dependent drift term.
            cur_dir = 0.1 * np.sin(0.01 * lon) + 0.07 * np.cos(0.02 * lat)
            heading += np.random.normal(cur_dir, 0.8)
            step_len = max(0.005, np.random.normal(0.07, 0.04))
        lat += 0.01 * step_len * math.sin(heading)
        lon += 0.01 * step_len * math.cos(heading)
        # Clamp to the simulated domain
        lat = float(np.clip(lat, LAT_MIN + 0.01, LAT_MAX - 0.01))
        lon = float(np.clip(lon, LON_MIN + 0.01, LON_MAX - 0.01))
        temp_sensor = local['sst'] + np.random.normal(0, 0.12)
        timestamp = (START_DATE + timedelta(hours=step * PING_HOURS)).isoformat()
        # The record stores the position AFTER the move together with the
        # environment sampled BEFORE it.
        records.append({
            'shark_id': f'shark_{sid}',
            'step': step,
            'timestamp': timestamp,
            'lat': lat, 'lon': lon,
            'chl': local['chl'], 'ssha': local['ssha'], 'sst': local['sst'],
            'current': local['current'], 'depth': local['depth'],
            'temp_sensor': temp_sensor, 'odba': float(odba), 'is_forage': is_forage
        })

df = pd.DataFrame(records)
print("Pings simulados:", len(df))

# 4) Montar dataset supervisionado (lags -> next pos)
def build_features(df, n_lags=3):
    """Build the supervised table mapping recent pings to the next position.

    For every shark and every ping ``i`` that has ``n_lags - 1`` predecessors
    and one successor, a row is emitted with:

    * ``lat_lag_j`` / ``lon_lag_j`` for ``j = 0 .. n_lags - 1``: the positions
      of the ``n_lags`` most recent pings, oldest first, ending at ping ``i``;
    * the environment/sensor columns of ping ``i``;
    * ``lat_next`` / ``lon_next``: the position of ping ``i + 1``.

    The lag window ends at the current ping so that it matches the inference
    code, which feeds the last three pings including the current one. It
    previously ended one ping earlier, which left the current position out of
    the training features.

    Args:
        df: Ping table with ``shark_id``, ``step``, ``lat``, ``lon`` and the
            environment/sensor columns listed below.
        n_lags: Number of past positions (including the current one) per row.

    Returns:
        DataFrame with one row per usable ping; empty if no shark has more
        than ``n_lags`` pings.
    """
    env_cols = ['chl', 'ssha', 'sst', 'current', 'depth', 'temp_sensor', 'odba']
    rows = []
    for _, g in df.groupby('shark_id'):
        g = g.sort_values('step').reset_index(drop=True)
        # i is the "current" ping; it needs n_lags - 1 predecessors and a successor.
        for i in range(n_lags - 1, len(g) - 1):
            feat = {}
            for j in range(n_lags):
                src = i - n_lags + 1 + j  # oldest lag first, src == i when j == n_lags - 1
                feat[f'lat_lag_{j}'] = g.loc[src, 'lat']
                feat[f'lon_lag_{j}'] = g.loc[src, 'lon']
            for col in env_cols:
                feat[col] = g.loc[i, col]
            feat['lat_next'] = g.loc[i + 1, 'lat']
            feat['lon_next'] = g.loc[i + 1, 'lon']
            rows.append(feat)
    return pd.DataFrame(rows)

# Supervised table: every column except the two targets is a model input.
feat_df = build_features(df, n_lags=3)
feature_cols = [c for c in feat_df.columns if c != 'lat_next' and c != 'lon_next']
X = feat_df[feature_cols]
y_lat = feat_df['lat_next']
y_lon = feat_df['lon_next']

# Train the RF models (one for latitude, one for longitude) on a shared split.
X_train, X_test, ylat_train, ylat_test, ylon_train, ylon_test = train_test_split(
    X, y_lat, y_lon,
    test_size=0.25,
    random_state=SEED,
)

# The scaler is fitted on the training rows only and reused for the test rows.
scaler = StandardScaler()
scaler.fit(X_train)
X_train_s = scaler.transform(X_train)
X_test_s = scaler.transform(X_test)

rf_lat = RandomForestRegressor(n_estimators=150, max_depth=12, random_state=SEED, n_jobs=-1)
rf_lat.fit(X_train_s, ylat_train)
rf_lon = RandomForestRegressor(n_estimators=150, max_depth=12, random_state=SEED, n_jobs=-1)
rf_lon.fit(X_train_s, ylon_train)

# Hold-out error, in degrees, for each coordinate.
from sklearn.metrics import mean_absolute_error
pred_lat_test = rf_lat.predict(X_test_s)
pred_lon_test = rf_lon.predict(X_test_s)
mae_lat = mean_absolute_error(ylat_test, pred_lat_test)
mae_lon = mean_absolute_error(ylon_test, pred_lon_test)
print(f"MAE (lat): {mae_lat:.5f}°, (lon): {mae_lon:.5f}°")

# Bundle everything the prediction code needs.
models_obj = {
    'rf_lat': rf_lat,
    'rf_lon': rf_lon,
    'scaler': scaler,
    'feature_cols': feature_cols,
}

# 5) Build Plotly frames (CHL + hotspots + tracks + predictions)
#
# Every frame carries the same number of traces in the same order:
#   [CHL layer, hotspots] + for each shark [track, current, predicted, connector].
# The figure is later created from frames[0].data and frame traces are matched
# to figure traces by position, so slots with nothing to draw (e.g. no
# prediction yet) get an empty placeholder instead of being omitted. Previously
# early frames had fewer traces than later ones and the slots shifted.
frames = []
times = sorted(df['step'].unique())

# Subsample the grid to lighten the CHL layer (GRID_LAT=32 => slice_r=2)
slice_r = max(1, GRID_LAT // 16)
grid_lats_sub = LatG[::slice_r, ::slice_r].flatten()
grid_lons_sub = LonG[::slice_r, ::slice_r].flatten()


def _predict_next(history):
    """Return the predicted (lat, lon) after ``history``, or None with fewer than 3 pings."""
    if len(history) < 3:
        return None
    last3 = history.tail(3).reset_index(drop=True)
    feat = {}
    for j in range(3):
        feat[f'lat_lag_{j}'] = last3.loc[j, 'lat']
        feat[f'lon_lag_{j}'] = last3.loc[j, 'lon']
    for col in ['chl','ssha','sst','current','depth','temp_sensor','odba']:
        feat[col] = last3.loc[2, col]
    Xf = pd.DataFrame([feat])[feature_cols]
    Xfs = models_obj['scaler'].transform(Xf)
    plat = float(models_obj['rf_lat'].predict(Xfs)[0])
    plon = float(models_obj['rf_lon'].predict(Xfs)[0])
    return plat, plon


print("Construindo frames (Plotly) — aguarde...")
for t in tqdm(times):
    ssh_t = ssh_field(t)
    chl_t = chl_field(t, ssh_t)
    # Subsample CHL for the marker layer
    chl_flat_sub = chl_t[::slice_r, ::slice_r].flatten()

    # Hotspots: ping counts per histogram cell at step t
    sub = df[df['step'] == t]
    if len(sub) > 0:
        H, yedges, xedges = np.histogram2d(sub['lat'], sub['lon'], bins=[lat_edges, lon_edges])
    else:
        H = np.zeros((len(lat_edges)-1, len(lon_edges)-1))
        yedges = lat_edges; xedges = lon_edges
    dens = H
    dens_flat = dens.flatten()
    if dens_flat.max() > 0:
        # Keep occupied cells at or above the 90th percentile of occupied cells.
        # The threshold is positive, so empty cells can never pass it.
        thresh = np.percentile(dens_flat[dens_flat > 0], 90)
        mask = dens_flat >= thresh
    else:
        mask = np.zeros(dens_flat.shape, dtype=bool)

    # Centres of the histogram cells
    lat_centers_hist = 0.5 * (yedges[:-1] + yedges[1:])
    lon_centers_hist = 0.5 * (xedges[:-1] + xedges[1:])
    LonC, LatC = np.meshgrid(lon_centers_hist, lat_centers_hist)
    hotspot_lons = LonC.flatten()[mask]
    hotspot_lats = LatC.flatten()[mask]
    hotspot_vals = dens_flat[mask]

    traces = []
    # Slot 0: CHL layer (subsampled markers)
    traces.append(go.Scattergeo(
        lon=grid_lons_sub, lat=grid_lats_sub,
        mode='markers',
        marker=dict(size=3, color=chl_flat_sub, colorscale='Viridis', opacity=0.75, showscale=True, colorbar=dict(title='CHL')),
        hoverinfo='text',
        text=[f"CHL: {v:.2f}" for v in chl_flat_sub],
        showlegend=False
    ))

    # Slot 1: hotspots as translucent red markers (possibly empty)
    sizes = [max(10, min(60, 6 + v * 10)) for v in hotspot_vals]
    traces.append(go.Scattergeo(
        lon=list(hotspot_lons), lat=list(hotspot_lats),
        mode='markers',
        marker=dict(size=sizes, color='rgba(220,20,20,0.55)', line=dict(width=0)),
        hoverinfo='text',
        text=[f"Hotspot density: {int(v)}" for v in hotspot_vals],
        name='Hotspots',
        showlegend=False
    ))

    # Four slots per shark: track up to t, current position, predicted next, connector
    for sid in range(N_SHARKS):
        name = f'shark_{sid}'
        g = df[(df['shark_id'] == name) & (df['step'] <= t)].sort_values('step')
        has_track = len(g) > 0
        last = g.iloc[-1] if has_track else None
        pred = _predict_next(g)

        # Track (line + markers)
        traces.append(go.Scattergeo(
            lon=g['lon'], lat=g['lat'],
            mode='lines+markers',
            line=dict(width=2),
            marker=dict(size=4),
            hoverinfo='text',
            text=[f"{name} step {int(r)}" for r in g['step']],
            showlegend=False
        ))
        # Current position
        traces.append(go.Scattergeo(
            lon=[last['lon']] if has_track else [],
            lat=[last['lat']] if has_track else [],
            mode='markers',
            marker=dict(size=9, symbol='circle'),
            hoverinfo='text',
            text=[f"{name} t={int(last['step'])}<br>CHL:{last['chl']:.2f}"] if has_track else [],
            showlegend=False
        ))
        # Predicted next position (star); empty until three pings exist
        traces.append(go.Scattergeo(
            lon=[pred[1]] if pred else [],
            lat=[pred[0]] if pred else [],
            mode='markers',
            marker=dict(size=12, symbol='star', color='red'),
            hoverinfo='text',
            text=[f"Pred {name}"] if pred else [],
            showlegend=False
        ))
        # Dashed connector from the current to the predicted position
        traces.append(go.Scattergeo(
            lon=[last['lon'], pred[1]] if pred else [],
            lat=[last['lat'], pred[0]] if pred else [],
            mode='lines',
            line=dict(width=2, dash='dash', color='red'),
            hoverinfo='none',
            showlegend=False
        ))

    frames.append(go.Frame(data=traces, name=str(t)))

# 6) Assemble the Plotly figure with a slider and play/pause buttons
if not frames:
    raise RuntimeError("Sem frames para animar (verifique dados).")

# The first frame doubles as the figure's initial data.
fig = go.Figure(data=frames[0].data, frames=frames)

# One slider step per frame; each step jumps straight to its frame.
steps = [
    {
        "method": "animate",
        "args": [
            [fr.name],
            {"mode": "immediate", "frame": {"duration": 300, "redraw": True}, "transition": {"duration": 0}},
        ],
        "label": str(fr.name),
    }
    for fr in frames
]
sliders = [{"active": 0, "pad": {"t": 50}, "steps": steps, "currentvalue": {"prefix": "step: "}}]

fig.update_layout(
    title="Shark sim — CHL heat, dynamic hotspots, tracks & next-step predictions",
    geo={
        "projection_type": 'mercator',
        "showland": True,
        "landcolor": 'rgb(230,230,230)',
        "lataxis": {"range": [LAT_MIN, LAT_MAX]},
        "lonaxis": {"range": [LON_MIN, LON_MAX]},
    },
    updatemenus=[{
        "type": "buttons",
        "buttons": [
            {
                "label": "Play",
                "method": "animate",
                "args": [None, {"frame": {"duration": 300, "redraw": True}, "fromcurrent": True}],
            },
            {
                "label": "Pause",
                "method": "animate",
                "args": [[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}],
            },
        ],
    }],
    sliders=sliders,
    height=700,
)

# Show inline
fig.show()

# Save the interactive Plotly HTML
plotly_path = "/content/shark_simulation_plotly.html"
pio.write_html(fig, file=plotly_path, auto_open=False)
print("Plotly salvo:", plotly_path)

# 7) Folium (satellite) - TimestampedGeoJson for positions and hotspots (lightweight)
print("Construindo mapa Folium (satellite). Pode demorar alguns segundos...")

# One GeoJSON point feature per ping
features_positions = [
    {
        "type": "Feature",
        "geometry": {"type": "Point", "coordinates": [float(rec['lon']), float(rec['lat'])]},
        "properties": {
            "time": rec['timestamp'],
            "popup": f"{rec['shark_id']} t={int(rec['step'])} CHL={rec['chl']:.2f}",
            "icon": "circle",
            "shark_id": rec['shark_id'],
        },
    }
    for rec in df.to_dict('records')
]

# Hotspot features, computed per time step
hotspot_features = []
for t in times:
    sub = df[df['step'] == t]
    if len(sub) > 0:
        H, yedges, xedges = np.histogram2d(sub['lat'], sub['lon'], bins=[lat_edges, lon_edges])
    else:
        H = np.zeros((len(lat_edges)-1, len(lon_edges)-1))
        yedges, xedges = lat_edges, lon_edges
    dens_flat = H.flatten()
    if dens_flat.max() > 0:
        # Occupied cells at or above the 90th percentile of occupied cells
        nonzero = dens_flat[dens_flat > 0]
        thr = np.percentile(nonzero, 90) if len(nonzero) > 0 else np.inf
        mask = (dens_flat >= thr) & (dens_flat > 0)
    else:
        mask = np.array([False] * dens_flat.shape[0])
    lat_centers_hist = 0.5 * (yedges[:-1] + yedges[1:])
    lon_centers_hist = 0.5 * (xedges[:-1] + xedges[1:])
    LonC, LatC = np.meshgrid(lon_centers_hist, lat_centers_hist)
    # Flatten once per step rather than once per selected cell
    cell_lons = LonC.flatten()
    cell_lats = LatC.flatten()
    stamp = (START_DATE + timedelta(hours=int(t * PING_HOURS))).isoformat()
    for idx in np.flatnonzero(mask):
        hotspot_features.append({
            "type": "Feature",
            "geometry": {"type": "Point", "coordinates": [float(cell_lons[idx]), float(cell_lats[idx])]},
            "properties": {
                "time": stamp,
                "style": {"color": "red", "radius": 8},
                "popup": f"hotspot density {int(dens_flat[idx])}",
            },
        })

positions_geojson = {"type": "FeatureCollection", "features": features_positions}
hotspots_geojson = {"type": "FeatureCollection", "features": hotspot_features}

# Folium map centred on (0, 0) with no default tiles
m = folium.Map(location=[0, 0], zoom_start=2, tiles=None)
# Satellite base layer (Esri)
esri_layer = folium.TileLayer(
    tiles='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
    attr='Esri',
    name='Esri.WorldImagery',
    overlay=False,
    control=True,
)
esri_layer.add_to(m)

# Static sample of CHL points from time 0 as a visual backdrop
ssh0 = ssh_field(0)
chl0 = chl_field(0, ssh0)
lat_stride = max(1, GRID_LAT // 16)
lon_stride = max(1, GRID_LON // 32)
sample = [
    (float(LatG[i, j]), float(LonG[i, j]), float(chl0[i, j]))
    for i in range(0, GRID_LAT, lat_stride)
    for j in range(0, GRID_LON, lon_stride)
]
for la, lo, cval in random.sample(sample, min(len(sample), 800)):
    marker = folium.CircleMarker(location=[la, lo], radius=2, fill=True, fill_opacity=0.7, popup=f"CHL: {cval:.2f}")
    marker.add_to(m)

# Time-aware layers for positions and hotspots
TimestampedGeoJson(positions_geojson, period='PT3H', add_last_point=True, auto_play=False, loop=False).add_to(m)
TimestampedGeoJson(hotspots_geojson, period='PT3H', add_last_point=False, auto_play=False, loop=False).add_to(m)
folium.LayerControl().add_to(m)

folium_path = "/content/shark_simulation_folium.html"
m.save(folium_path)
print("Folium salvo:", folium_path)

# 8) Save the pings as CSV and show the next-step predictions for the final step
df.to_csv("/content/shark_simulation_data.csv", index=False)
print("CSV salvo: /content/shark_simulation_data.csv")

# Last ping of every shark, ordered by shark_id. This avoids groupby(...).apply,
# whose inclusion of the grouping column in each group is deprecated; without
# that column the 'shark_id' lookup below would fail.
last_positions = (
    df.sort_values('step')
      .groupby('shark_id')
      .tail(1)
      .sort_values('shark_id')
      .reset_index(drop=True)
)

# Predict the next position from the last three pings of each shark.
pred_rows = []
for sid in last_positions['shark_id']:
    g = df[df['shark_id'] == sid].sort_values('step')
    if len(g) < 3:
        continue  # not enough history to fill the three lag slots
    last3 = g.tail(3).reset_index(drop=True)
    feat = {}
    for j in range(3):
        feat[f'lat_lag_{j}'] = last3.loc[j, 'lat']
        feat[f'lon_lag_{j}'] = last3.loc[j, 'lon']
    for col in ['chl','ssha','sst','current','depth','temp_sensor','odba']:
        feat[col] = last3.loc[2, col]
    # Reorder columns to the order the scaler and models were fitted with.
    Xf = pd.DataFrame([feat])[feature_cols]
    Xfs = models_obj['scaler'].transform(Xf)
    plat = float(models_obj['rf_lat'].predict(Xfs)[0])
    plon = float(models_obj['rf_lon'].predict(Xfs)[0])
    pred_rows.append({'shark_id': sid, 'last_lat': last3.loc[2,'lat'], 'last_lon': last3.loc[2,'lon'], 'pred_lat': plat, 'pred_lon': plon})
pred_df = pd.DataFrame(pred_rows)
print("\nPredições de próximo passo (últimas posições):")
print(pred_df)

# Links to open / download the generated files
display(HTML(f"<a href='{plotly_path}' target='_blank'>Abrir Plotly interativo (HTML)</a>"))
display(HTML(f"<a href='{folium_path}' target='_blank'>Abrir Folium (satellite) map</a>"))
display(HTML(f"<a href='/content/shark_simulation_data.csv' target='_blank'>Download CSV</a>"))