In [1]:
import os
import torch
import numpy as np
import pandas as pd

from torch.utils.data import DataLoader

from whale_imitation import *

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [2]:
CSV_PATH = "sperm-whale-dialogues.csv"

assert os.path.exists(CSV_PATH), f"Missing {CSV_PATH}. Put it next to the notebook or update CSV_PATH."

df = load_csv(CSV_PATH)
df.head(), df.shape

(             REC  nClicks  Duration      ICI1      ICI2      ICI3      ICI4  \
 0  sw061b001_124       17  0.927792  0.033950  0.036117  0.048867  0.048908   
 1  sw061b001_124       20  1.092992  0.032817  0.039358  0.045325  0.049933   
 2  sw061b001_124        5  0.898250  0.275850  0.286783  0.197400  0.138217   
 3  sw061b001_124        5  0.865575  0.262275  0.276883  0.183333  0.143083   
 4  sw061b001_124        5  0.858317  0.266675  0.269358  0.178325  0.143958   
 
        ICI5      ICI6      ICI7  ...  ICI21  ICI22  ICI23  ICI24  ICI25  \
 0  0.042275  0.040483  0.040975  ...    0.0    0.0    0.0    0.0    0.0   
 1  0.044083  0.044783  0.043683  ...    0.0    0.0    0.0    0.0    0.0   
 2  0.000000  0.000000  0.000000  ...    0.0    0.0    0.0    0.0    0.0   
 3  0.000000  0.000000  0.000000  ...    0.0    0.0    0.0    0.0    0.0   
 4  0.000000  0.000000  0.000000  ...    0.0    0.0    0.0    0.0    0.0   
 
    ICI26  ICI27  ICI28  Whale      TsTo  
 0    0.0    0.0 

In [3]:
train_recs, val_recs, test_recs = split_recs(df, val_frac=0.1, test_frac=0.0, seed=0)

print("Num conversations (REC):", df["REC"].nunique())
print("Train RECs:", len(train_recs))
print("Val RECs:", len(val_recs))
print("Test RECs:", len(test_recs))

Num conversations (REC): 219
Train RECs: 197
Val RECs: 22
Test RECs: 0


In [4]:
cfg = BCConfig(
    hidden_size=256,
    num_layers=1,
    dropout=0.0,
    use_log_ici=True,
)

train_ds = WhaleBCDataset(CSV_PATH, cfg=cfg, recs=train_recs)
val_ds   = WhaleBCDataset(CSV_PATH, cfg=cfg, recs=val_recs, ici_mean=train_ds.ici_mean, ici_std=train_ds.ici_std)

print("Train episodes (rows):", len(train_ds))
print("Val episodes (rows):", len(val_ds))
print("Num whales (mapped):", train_ds.n_whales)
print("ICI norm mean/std:", train_ds.ici_mean, train_ds.ici_std)

Train episodes (rows): 3521
Val episodes (rows): 319
Num whales (mapped): 10
ICI norm mean/std: -1.8894754295816543 0.8070952322128051


In [5]:
BATCH_TRAIN = 32
BATCH_VAL   = 64

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_TRAIN,
    shuffle=True,
    collate_fn=lambda b: collate_bc(b, cfg=cfg),
    num_workers=0,
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_VAL,
    shuffle=False,
    collate_fn=lambda b: collate_bc(b, cfg=cfg),
    num_workers=0,
)

In [6]:
model = GRUBCModel(n_whales=train_ds.n_whales, cfg=cfg)
model

GRUBCModel(
  (whale_emb): Embedding(10, 32)
  (type_emb): Embedding(4, 16)
  (gru): GRU(49, 256, batch_first=True)
  (head_eos): Linear(in_features=256, out_features=1, bias=True)
  (head_ici): Linear(in_features=256, out_features=1, bias=True)
)

In [7]:
pre_val = evaluate_bc(model, val_loader, device=device)
pre_val

{'loss_total': 1.7427553415298462,
 'loss_eos': 0.7539610266685486,
 'loss_ici': 0.9887943267822266}

In [8]:
history = train_bc(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=10,
    lr=1e-3,
    device=device,
)

history[:2], history[-2:]

  from .autonotebook import tqdm as notebook_tqdm
                                                                                                        

[BC] ep 1/10 train_total=0.7314 train_eos=0.3216 train_ici=0.4099 | val_total=0.4460 val_eos=0.1657 val_ici=0.2803


                                                                                                        

[BC] ep 2/10 train_total=0.4285 train_eos=0.1953 train_ici=0.2332 | val_total=0.3841 val_eos=0.1293 val_ici=0.2548


                                                                                                        

[BC] ep 3/10 train_total=0.3584 train_eos=0.1459 train_ici=0.2125 | val_total=0.3379 val_eos=0.1067 val_ici=0.2312


                                                                                                        

[BC] ep 4/10 train_total=0.3798 train_eos=0.1305 train_ici=0.2493 | val_total=0.3318 val_eos=0.1011 val_ici=0.2307


                                                                                                        

[BC] ep 5/10 train_total=0.3150 train_eos=0.1134 train_ici=0.2016 | val_total=0.3212 val_eos=0.0963 val_ici=0.2249


                                                                                                       

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

epochs = [h["epoch"] for h in history]
val_total = [h.get("val_loss_total", np.nan) for h in history]
val_eos   = [h.get("val_loss_eos", np.nan) for h in history]
val_ici   = [h.get("val_loss_ici", np.nan) for h in history]

plt.figure()
plt.plot(epochs, val_total, label="val_total")
plt.plot(epochs, val_eos, label="val_eos")
plt.plot(epochs, val_ici, label="val_ici")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.title("BC validation losses")
plt.show()

In [None]:
# Choose an episode index from val set
episode_index = 0

ex = val_ds[episode_index]

# Find SOC index; everything before it is history
soc_idx = ex["tok_types"].index(cfg.TOK_SOC)

history_tokens = list(zip(
    ex["whale_ids"][:soc_idx],
    ex["tok_types"][:soc_idx],
    ex["ici_feats"][:soc_idx],
))

current_whale = ex["whale_ids"][soc_idx]

gen = rollout_coda(
    model=model,
    history_tokens=history_tokens,
    current_whale=current_whale,
    max_len=cfg.max_ici_cols,  # max ICIs to generate
    ici_mean=val_ds.ici_mean,
    ici_std=val_ds.ici_std,
    eos_threshold=0.5,
    device=device,
)

print("Generated ICI count:", len(gen))
print("Generated ICIs:", gen[:10], "..." if len(gen) > 10 else "")

In [None]:
# Extract the real coda ICIs for this episode from the dataset tensors
# ex contains normalized feats; convert back to raw ICIs for the real ICIs after SOC until EOS.

real_feats = []
for tt, f in zip(ex["tok_types"][soc_idx+1:], ex["ici_feats"][soc_idx+1:]):
    if tt == cfg.TOK_EOS:
        break
    if tt == cfg.TOK_ICI:
        real_feats.append(float(f))

from whale_imitation import feat_to_ici

real_ici = [feat_to_ici(f, cfg, val_ds.ici_mean, val_ds.ici_std) for f in real_feats]

print("Real ICI count:", len(real_ici))
print("Real ICIs:", real_ici[:10], "..." if len(real_ici) > 10 else "")

# Simple side-by-side for first min length
m = min(len(real_ici), len(gen))
pairs = list(zip(real_ici[:m], gen[:m]))
pairs[:10]

In [None]:
from whale_imitation import (
    GRUDiscriminator, DiscConfig,
    make_disc_loader_from_dataset,
    train_discriminator, evaluate_discriminator,
)

# How many episode pairs to sample for the discriminator dataset
NUM_PAIRS = min(1000, len(val_ds))   # increase if you have enough data
DISC_BATCH = 32
MAX_GEN_LEN = cfg.max_ici_cols
EOS_THRESH = 0.5

disc_loader = make_disc_loader_from_dataset(
    ds=val_ds,
    bc_model=model,
    num_pairs=NUM_PAIRS,
    batch_size=DISC_BATCH,
    max_len=MAX_GEN_LEN,
    eos_threshold=EOS_THRESH,
    shuffle=True,
    seed=0,
    device_for_generation=device,   # generation can be on GPU if available
)

print("Discriminator examples (real+fake):", 2 * NUM_PAIRS)

In [None]:
disc_cfg = DiscConfig(
    hidden_size=256,
    num_layers=1,
    dropout=0.0,
)

disc = GRUDiscriminator(
    n_whales=val_ds.n_whales,
    cfg=disc_cfg,
    bc_cfg=cfg,
)

pre = evaluate_discriminator(disc, disc_loader, device=device)
pre

In [None]:
disc_hist = train_discriminator(
    disc=disc,
    loader=disc_loader,
    epochs=5,
    lr=1e-3,
    device=device,
)

disc_hist

In [None]:
post = evaluate_discriminator(disc, disc_loader, device=device)
post

In [None]:
import matplotlib.pyplot as plt

plt.figure()
plt.plot([h["epoch"] for h in disc_hist], [h["disc_loss"] for h in disc_hist])
plt.xlabel("epoch")
plt.ylabel("disc_loss")
plt.title("Discriminator training loss")
plt.show()

In [None]:
print("Discriminator accuracy:", post["disc_acc"])
print("Discriminator AUROC:", post["disc_auc"])

if np.isfinite(post["disc_auc"]):
    if post["disc_auc"] > 0.8:
        print("-> Discriminator can easily tell real vs generated (BC rollouts not very realistic yet).")
    elif post["disc_auc"] > 0.65:
        print("-> Discriminator has moderate power; BC is partially matching structure.")
    else:
        print("-> Discriminator struggles to separate; generated episodes look relatively real (or disc is underfit).")