# Data Augmentation for Contrastive Learning (Astronomy)

This notebook implements domain-aware augmentations suitable for contrastive learning on astronomical time-series / multi-band observations. Each augmentation is provided as a function that accepts per-object time-series data (pandas DataFrame) and returns an augmented copy.

Assumptions: the input DataFrame should contain at least an identifier for the source (e.g. `object_id`), a time column (`time`, `jd`, `mjd`, or `obsdate`), a brightness column (`flux` or `mag`), and a filter/band column (`filter`). Each function performs checks and falls back gracefully if columns are missing.

In [None]:
# Imports and helper utilities
import os
import numpy as np
import pandas as pd
from copy import deepcopy
import random

# Small helper: find time / brightness / filter cols
def detect_columns(df):
    cols_lower = [c.lower() for c in df.columns]
    time_col = next((c for c in df.columns if c.lower() in ['time','obsdate','jd','mjd','epoch']), None)
    flux_col = next((c for c in df.columns if c.lower() in ['flux','flux_calib','mag','mag_calib','instrumental_flux']), None)
    filter_col = next((c for c in df.columns if 'filter' in c.lower() or 'band' in c.lower()), None)
    id_col = next((c for c in df.columns if c.lower() in ['object_id','objid','id','source_id','ipac_gid']), None)
    # return 'filter' key (was 'filt' previously) for consistency with augmentation functions
    return dict(time=time_col, flux=flux_col, filter=filter_col, id=id_col)

# Utility: ensure DataFrame sorted by time for each object
def sort_by_time(df, time_col):
    if time_col is None or time_col not in df.columns:
        return df
    try:
        return df.sort_values(by=[time_col]).reset_index(drop=True)
    except Exception:
        return df

## 4.1 Temporal Augmentations
- Temporal jittering: add small random noise to timestamps.
- Random time shifts: shift the whole sequence by a random amount.
- Random cropping: return a contiguous partial subsequence to simulate partial coverage.

In [2]:
def temporal_jitter(df, time_col, sigma_fraction=0.01, seed=None):
    """Add Gaussian jitter to timestamps (fraction of median cadence).
    Handles datetime-like columns by converting to seconds since epoch, applying jitter,
    and converting back to datetimes. Returns DataFrame with same time dtype where possible."""
    if time_col is None or time_col not in df.columns:
        return df
    rng = np.random.default_rng(seed)
    # try to convert to datetime; if not possible, fall back to numeric as before
    try:
        times_dt = pd.to_datetime(df[time_col], errors='coerce')
        times_s = times_dt.astype('datetime64[ns]').astype('int64') / 1e9
        use_datetime = not times_dt.isna().all()
    except Exception:
        times_s = df[time_col].astype(float).values
        use_datetime = False
    times = np.array(times_s, dtype=float)
    diffs = np.diff(np.sort(times)) if len(times) > 1 else np.array([1.0])
    median_cadence = np.median(diffs) if len(diffs) > 0 else 1.0
    sigma = sigma_fraction * median_cadence
    jitter = rng.normal(loc=0.0, scale=sigma, size=times.shape)
    new_times = times + jitter
    df_aug = df.copy()
    if use_datetime:
        df_aug[time_col] = pd.to_datetime(new_times, unit='s')
    else:
        df_aug[time_col] = new_times
    return df_aug

def random_time_shift(df, time_col, shift_range=None, seed=None):
    """Shift the entire sequence by a random amount. shift_range can be (min,max) in same units as time_col. If None, uses +/- 0.5 * duration."""
    if time_col is None or time_col not in df.columns:
        return df
    rng = np.random.default_rng(seed)
    try:
        times_dt = pd.to_datetime(df[time_col], errors='coerce')
        times_s = times_dt.astype('datetime64[ns]').astype('int64') / 1e9
        use_datetime = not times_dt.isna().all()
    except Exception:
        times_s = df[time_col].astype(float).values
        use_datetime = False
    times = np.array(times_s, dtype=float)
    duration = times.max() - times.min() if len(times) > 1 else 0.0
    if shift_range is None:
        shift_range = (-0.5 * duration, 0.5 * duration)
    shift = float(rng.uniform(shift_range[0], shift_range[1]))
    new_times = times + shift
    df_aug = df.copy()
    if use_datetime:
        df_aug[time_col] = pd.to_datetime(new_times, unit='s')
    else:
        df_aug[time_col] = new_times
    return df_aug

def random_crop(df, time_col, min_fraction=0.5, seed=None):
    """Return a contiguous subsequence of the timeseries (per-object)."""
    if time_col is None or time_col not in df.columns:
        return df
    rng = np.random.default_rng(seed)
    n = len(df)
    if n < 2:
        return df
    min_len = max(1, int(np.ceil(min_fraction * n)))
    start = int(rng.integers(0, n - min_len + 1))
    end = int(rng.integers(start + min_len, n + 1))
    return df.iloc[start:end].reset_index(drop=True)

## 4.2 Magnitude Augmentations
- Magnitude scaling: multiply flux by a random factor.
- Brightness warping: apply smooth multiplicative warp across time to simulate calibration/seeing changes.

In [3]:
def magnitude_scaling(df, flux_col, scale_range=(0.8,1.2), seed=None):
    if flux_col is None or flux_col not in df.columns:
        return df
    rng = np.random.default_rng(seed)
    factor = float(rng.uniform(scale_range[0], scale_range[1]))
    df_aug = df.copy()
    df_aug[flux_col] = df_aug[flux_col].astype(float) * factor
    return df_aug

def brightness_warping(df, time_col, flux_col, n_knots=3, warp_scale=0.1, seed=None):
    """Apply a smooth multiplicative warp across time using piecewise linear interpolation.
    Handles datetime-like time columns by converting to seconds for interpolation."""
    if flux_col is None or flux_col not in df.columns:
        return df
    df_aug = df.copy()
    # try to get numeric times in seconds
    use_datetime = False
    try:
        times_dt = pd.to_datetime(df_aug[time_col], errors='coerce')
        if not times_dt.isna().all():
            times = times_dt.astype('datetime64[ns]').astype('int64') / 1e9
            use_datetime = True
        else:
            times = df_aug[time_col].astype(float).values
    except Exception:
        times = df_aug[time_col].astype(float).values
    tmin, tmax = np.min(times), np.max(times)
    if tmax == tmin:
        return df_aug
    knots = np.linspace(tmin, tmax, n_knots)
    rng = np.random.default_rng(seed)
    knot_factors = rng.normal(loc=1.0, scale=warp_scale, size=len(knots))
    factors = np.interp(times, knots, knot_factors)
    df_aug[flux_col] = df_aug[flux_col].astype(float) * factors
    return df_aug

## 4.3 Noise Augmentations
- Gaussian noise injection to flux.
- Photometric uncertainty simulation: add noise based on provided flux_err or an assumed S/N model.

In [4]:
def gaussian_noise_injection(df, flux_col, sigma_fraction=0.05, seed=None):
    if flux_col is None or flux_col not in df.columns:
        return df
    rng = np.random.default_rng(seed)
    flux = df[flux_col].astype(float).values
    sigma = np.maximum(np.abs(flux) * sigma_fraction, 1e-8)
    noise = rng.normal(loc=0.0, scale=sigma)
    df_aug = df.copy()
    df_aug[flux_col] = flux + noise
    return df_aug

def photometric_uncertainty_simulation(df, flux_col, flux_err_col=None, seed=None):
    """If a flux_err column exists, perturb flux by that uncertainty; otherwise assume poisson or fractional error."""
    rng = np.random.default_rng(seed)
    if flux_col is None or flux_col not in df.columns:
        return df
    df_aug = df.copy()
    flux = df_aug[flux_col].astype(float).values
    if flux_err_col and flux_err_col in df_aug.columns:
        err = df_aug[flux_err_col].astype(float).values
        noise = rng.normal(loc=0.0, scale=err)
        df_aug[flux_col] = flux + noise
    else:
        rel_err = 0.05
        sigma = np.maximum(np.abs(flux) * rel_err, np.sqrt(np.maximum(flux, 0)) * 0.1 + 1e-8)
        noise = rng.normal(loc=0.0, scale=sigma)
        df_aug[flux_col] = flux + noise
    return df_aug

## 4.4 Multi-band Transformations
- Filter-dependent transformations: apply different augmentation strengths per band.
- Dropout of random bands: remove observations from random filters to mimic missing bands.

In [5]:
def filter_dependent_transform(df, flux_col, filter_col, per_filter_scale=None, seed=None):
    """Apply a multiplicative factor per filter. `per_filter_scale` is a dict {filter: (min,max)} or None to sample small variations."""
    if flux_col is None or flux_col not in df.columns or filter_col is None or filter_col not in df.columns:
        return df
    df_aug = df.copy()
    rng = np.random.default_rng(seed)
    unique_filters = df_aug[filter_col].dropna().unique()
    for f in unique_filters:
        mask = df_aug[filter_col] == f
        if per_filter_scale and f in per_filter_scale:
            lo, hi = per_filter_scale[f]
            factor = float(rng.uniform(lo, hi))
        else:
            factor = float(rng.uniform(0.9, 1.1))
        df_aug.loc[mask, flux_col] = df_aug.loc[mask, flux_col].astype(float) * factor
    return df_aug

def random_band_dropout(df, filter_col, dropout_prob=0.2, seed=None):
    """Randomly drop a fraction of observations for some bands. Returns a copy where dropped rows are removed."""
    if filter_col is None or filter_col not in df.columns:
        return df
    rng = np.random.default_rng(seed)
    df_aug = df.copy()
    mask = rng.random(size=df_aug.shape[0]) < dropout_prob
    df_aug = df_aug.loc[~mask].reset_index(drop=True)
    return df_aug

## Utilities: apply augmentations to grouped time-series and create contrastive pairs
The following helper applies augmentations per object (grouped by `id_col`) and returns a concatenated augmented dataset. You can call augmentations sequentially (composition) to create positive pairs for contrastive learning.

In [6]:
def apply_to_group(df_group, funcs):
    """Apply a list of augmentation functions (each accepts and returns a DataFrame for the group)."""
    g = df_group.copy().reset_index(drop=True)
    for f, kwargs in funcs:
        g = f(g, **kwargs) if kwargs is not None else f(g)
    return g

def augment_dataset(df, id_col=None, funcs_per_group=None, sample_frac=1.0, random_state=None):
    """Apply augmentations per object id and return augmented examples.
    - `funcs_per_group` should be a list of tuples (func, kwargs) to apply to each group."""
    if funcs_per_group is None:
        return df
    if id_col is None or id_col not in df.columns:
        # treat whole DF as one sequence
        return apply_to_group(df, funcs_per_group)
    groups = df.groupby(id_col)
    out_rows = []
    rng = np.random.default_rng(random_state)
    ids = list(groups.groups.keys())
    if sample_frac < 1.0:
        k = max(1, int(len(ids) * sample_frac))
        ids = rng.choice(ids, size=k, replace=False).tolist()
    for objid in ids:
        g = groups.get_group(objid)
        g_sorted = sort_by_time(g, detect_columns(g)['time'])
        g_aug = apply_to_group(g_sorted, funcs_per_group)
        out_rows.append(g_aug)
    if len(out_rows) == 0:
        return pd.DataFrame(columns=df.columns)
    return pd.concat(out_rows, ignore_index=True)

def make_contrastive_pair(group_df, augA, augB):
    a = apply_to_group(group_df, augA)
    b = apply_to_group(group_df, augB)
    return a, b

## Example usage and saving augmented samples
Below is an example that loads the cleaned dataset, selects a small sample of objects, applies a pipeline of augmentations, and saves augmented CSVs for later training.

In [7]:
# Example usage (will run if the cleaned dataset exists in workspace)
DATA_CLEANED = 'ztf_image_search_results_full_cleaned.csv'
if not os.path.exists(DATA_CLEANED):
    print(DATA_CLEANED, 'not found in workspace â€” update path or run preprocessing first')
else:
    df_all = pd.read_csv(DATA_CLEANED)
    cols = detect_columns(df_all)
    print('Detected columns:', cols)
    idc = cols['id']
    if idc is None or idc not in df_all.columns:
        sample_df = df_all
    else:
        unique_ids = df_all[idc].dropna().unique()
        sel_ids = unique_ids[:50] if len(unique_ids) > 50 else unique_ids
        sample_df = df_all[df_all[idc].isin(sel_ids)].reset_index(drop=True)

    # define two augmentation pipelines for contrastive positives
    augA = [
        (lambda g, **kw: temporal_jitter(g, cols['time'], sigma_fraction=0.01, seed=42), None),
        (lambda g, **kw: magnitude_scaling(g, cols['flux'], scale_range=(0.95,1.05), seed=42), None),
        (lambda g, **kw: gaussian_noise_injection(g, cols['flux'], sigma_fraction=0.03, seed=42), None)
    ]
    augB = [
        (lambda g, **kw: random_time_shift(g, cols['time'], seed=24), None),
        (lambda g, **kw: brightness_warping(g, cols['time'], cols['flux'], n_knots=4, warp_scale=0.08, seed=24), None),
        (lambda g, **kw: photometric_uncertainty_simulation(g, cols['flux'], flux_err_col=None, seed=24), None)
    ]
    # create augmented sets (this will return concatenated groups)
    aug_set_A = augment_dataset(sample_df, id_col=cols['id'], funcs_per_group=augA, sample_frac=1.0, random_state=42)
    aug_set_B = augment_dataset(sample_df, id_col=cols['id'], funcs_per_group=augB, sample_frac=1.0, random_state=24)
    # save examples
    aug_set_A.to_csv('augmented_A_sample.csv', index=False)
    aug_set_B.to_csv('augmented_B_sample.csv', index=False)
    print('Saved augmented_A_sample.csv and augmented_B_sample.csv (sample)')

Detected columns: {'time': 'obsdate', 'flux': None, 'filt': 'filtercode', 'id': 'ipac_gid'}


ValueError: could not convert string to float: '2018-03-25 06:35:35+00'