## Process the raw data

TODO: Figure out how to load the iterable dataset directly into the Trainer. It also needs to know that the dataset contains multiple different patients.

In [None]:
import pandas as pd

# Load the gzip-compressed groups export from the backup location.
from_path = '/data/shared/cache/data/gluroo_2026/groups_backups/groups.csv.gz'
groups_df = pd.read_csv(from_path, compression='gzip')

# Tally how often each gid occurs, then keep only the gids seen more than once.
gid_counts = groups_df['gid'].value_counts()
is_repeated = gid_counts > 1
duplicate_gid_values = gid_counts[is_repeated]

# Report how many gids are duplicated, followed by the per-gid counts.
n_repeated_gids = len(duplicate_gid_values)
print(f"Number of gid values that appear more than once: {n_repeated_gids}")
print(duplicate_gid_values)

In [None]:
# Known bug in the data export script: a group that onboarded multiple times
# appears with a duplicated gid. Resolve it by keeping, for each gid, only the
# row with the highest index.
print(f"Original shape: {groups_df.shape}")
print(f"Original unique gid count: {groups_df['gid'].nunique()}")

# Order rows by index so that "last occurrence" means "highest index".
groups_by_index = groups_df.sort_index()
# Flag every row that is followed by a later row with the same gid ...
superseded_rows = groups_by_index.duplicated(subset=['gid'], keep='last')
# ... and keep only the rows that are not superseded.
groups_df_dedup = groups_by_index[~superseded_rows]

# Write the deduplicated table (without the index column) to the raw directory.
to_path = '/data/shared/cache/data/gluroo_2026/raw/groups.csv.gz'
groups_df_dedup.to_csv(to_path, compression='gzip', index=False)


Original shape: (102875, 22)
Original unique gid count: 102844


In [None]:
import pandas as pd

from_path = '/data/shared/cache/data/gluroo_2026/raw/groups.csv.gz'
# Read only the header to get the column names: nrows=0 parses the header row
# and no data rows, so the rest of the file is never loaded.
# NOTE: this rebinds `groups_df` to an empty, header-only frame.
groups_df = pd.read_csv(from_path, nrows=0)
groups_df.columns

In [None]:
# Load the full groups table from the raw directory and display it as the
# cell output.
updated_groups_path = '/data/shared/cache/data/gluroo_2026/raw/groups.csv.gz'
updated_groups_df = pd.read_csv(
    filepath_or_buffer=updated_groups_path,
    compression='gzip',
)
updated_groups_df

In [None]:
from src.data.diabetes_datasets.data_loader import get_loader

# Extra dependency to install before running this cell:
#   1. pip install psycopg2-binary

# Options handed to get_loader, collected in one place.
loader_options = {
    "data_source_name": "gluroo_2026",
    "keep_columns": None,
    "use_cached": False,
    # Testing for now: the values below are temporary test settings.
    "patients_per_batch": 2,
    "patients_per_file": 6,
    "number_of_patients_to_process": 4,
    "min_date_span_days": 30,
    "load_all": True,
}
loader = get_loader(**loader_options)

In [None]:
# Inspect what the loader produced; as the last expression in the cell, the
# value is rendered as the cell output.
loader.processed_data

In [None]:
import pandas as pd

from src.data.models import ColumnNames

# Stack every patient's processed frame into a single DataFrame and export it
# to CSV. The patient-number column is (re)assigned on every frame first,
# because newly filled rows won't have a p_num.
# NOTE(review): the assignment appears to write into the frames held by
# loader.processed_data — confirm that this in-place change is intended.
all_data = []
for patient_id, patient_df in loader.processed_data.items():
    patient_df[ColumnNames.P_NUM.value] = patient_id
    all_data.append(patient_df)

if not all_data:
    # Nothing was collected, so there is nothing to concatenate or write.
    print("No processed data found in loader.processed_data.")
else:
    df_all = pd.concat(all_data, ignore_index=True)
    df_all.to_csv("gluroo_processed_data.csv", index=False)

## Iterable Dataset - WIP

In [None]:
# Columns requested from the streaming dataset.
stream_columns = ["datetime", "p_num", "bg_mM", "food_g", "dose_units", "cob", "iob"]

stream_ds = loader.get_hf_streaming_dataset(
    columns=stream_columns,
    # patient_ids=["gluroo_1", "gluroo_2"],  # drop to None to check all data
    batch_size=1024,
    validate_non_empty=True,  # default; set False if you don't want the peek
)
# first_batch = next(iter(stream_ds))

In [None]:
# first_batch

In [None]:
import torch
from transformers import TrainingArguments
import os

from src.models.ttm.model import create_ttm_model
from src.models.ttm.ttm import get_model

# Build the TTM wrapper from the pretrained checkpoint and keep a handle on
# the model object it exposes.
ttm = create_ttm_model(model_path="ibm-granite/granite-timeseries-ttm-r2")
model = ttm.model

out_dir = "./out"
batch_size = 1024

# Keyword arguments for TrainingArguments, grouped by concern.
training_kwargs = dict(
    # --- output and checkpointing ---
    output_dir=out_dir,
    overwrite_output_dir=False,
    save_strategy="steps",
    save_steps=2000,  # write a checkpoint every 2000 steps
    save_total_limit=100,
    load_best_model_at_end=True,  # reload the best checkpoint once training ends
    metric_for_best_model="eval_loss",  # metric used to decide which checkpoint is best
    greater_is_better=False,  # it is a loss, so lower is better
    # --- optimisation schedule ---
    # learning_rate=learning_rate,
    # num_train_epochs=num_epochs,
    max_steps=5000,  # required for streaming datasets (no __len__); LR scheduler needs known total steps
    # --- evaluation ---
    do_eval=True,
    eval_strategy="steps",
    eval_steps=1000,  # evaluate every 1000 steps (less frequent = faster training)
    # --- batching and data loading ---
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    dataloader_num_workers=1,
    remove_unused_columns=False,
    # --- device and precision ---
    fp16=False,
    use_cpu=False,
    # --- logging and reporting ---
    report_to="none",
    logging_strategy="steps",
    logging_steps=100,  # log every 100 steps
    logging_first_step=True,  # also log the very first step
    logging_dir=os.path.join(out_dir, "logs"),
    log_level="info",  # log verbosity
    disable_tqdm=False,  # keep progress bars
)
finetune_forecast_args = TrainingArguments(**training_kwargs)

# Without CUDA, switch mixed precision off everywhere. Per the original note,
# fp16 is only supported on CUDA, and this avoids the
# "fp16 requires a GPU (not 'mps')" error on MPS/CPU.
if not torch.cuda.is_available():
    finetune_forecast_args.fp16 = False
    finetune_forecast_args.bf16 = False
    os.environ["ACCELERATE_MIXED_PRECISION"] = "no"

# # Build train dataset from notebook's ts_df (from stream_ds). TTM expects datetime, bg_mM, p_num.
# train_df = ts_df.reset_index().rename(
#     columns={"timestamp": "datetime", "target": "bg_mM", "item_id": "p_num"}
# )
# train_loader, val_loader, _ = ttm._prepare_data(train_data=train_df)
# train_dataset = train_loader.dataset if train_loader else None
# eval_dataset = val_loader.dataset if val_loader else None
#
#
# Load the model to fine-tune from the pretrained TTM checkpoint.
ttm_checkpoint = "ibm-granite/granite-timeseries-ttm-r2"
forecast_model_options = {
    "context_length": 512,
    "prediction_length": 96,
    "freq_prefix_tuning": False,
    "prefer_l1_loss": False,
    "prefer_longer_context": True,
    # Can also provide TTM Config args. A param?
    "loss": "mse",
    "quantile": 0.5,
}
finetune_forecast_model = get_model(ttm_checkpoint, **forecast_model_options)

# TODO: import the pickle-safe data collator (gluroo_data_collator) from the gluroo module before enabling the Trainer below

# Use Trainer directly
# trainer = Trainer(
#     model=finetune_forecast_model,
#     args=finetune_forecast_args,
#     train_dataset=stream_ds,
#     eval_dataset=stream_ds,  # For testing
#     data_collator=gluroo_data_collator,
# )
# trainer.train()
# trainer.save_model()

In [None]:
import pandas as pd

# Load the gzip-compressed groups table from the raw directory with pandas.
groups_csv_path = '/data/shared/cache/data/gluroo_2026/raw/groups.csv.gz'
groups_df = pd.read_csv(groups_csv_path, compression='gzip')
