In [1]:
from datastreams import generate_crypto_time_series
from models import *
import numpy as np
import pandas as pd
from driftdetector import DriftDetector
from tqdm import tqdm

In [11]:
def make_dataset(data, seq_len):
    time_series = []
    for i in range(len(data) - seq_len + 1):
        time_series.append(data[i:i + seq_len])
    return np.array(time_series)[:-1], data[seq_len:], data[seq_len-1:-1]

In [3]:
def increase_lr_only(optimizer, drifted, decay_state, factor=2.0, max_lr=0.01):
    if drifted:
        for i, pg in enumerate(optimizer.param_groups):
            base_lr = decay_state['base_lrs'][i]
            pg['lr'] = min(base_lr * factor, max_lr)

def increase_and_decay_lr(optimizer, drifted, decay_state, factor=2.0, decay_steps=10):
    base_lrs = decay_state['base_lrs']
    
    if drifted:
        for i, pg in enumerate(optimizer.param_groups):
            pg['lr'] = base_lrs[i] * factor
        decay_state['current_decay_step'] = decay_steps

    elif decay_state['current_decay_step'] > 0:
        decay_state['current_decay_step'] -= 1
        for i, pg in enumerate(optimizer.param_groups):
            boosted_lr = base_lrs[i] * factor
            ratio = decay_state['current_decay_step'] / decay_steps
            pg['lr'] = base_lrs[i] + (boosted_lr - base_lrs[i]) * ratio

In [24]:
def train_model(model, detector, x_all, y_all, optimizer, loss_fn, device, drift_action):

    model.train()
    preds = []
    all_loss = []

    # For drift action state
    base_lrs = [pg['lr'] for pg in optimizer.param_groups]
    decay_state = {
        'base_lrs': base_lrs,
        'current_decay_step': 0
    }

    step = 0
    mae_values = []
    mae_values_full = []

    for x, y in zip(x_all, y_all):
        step += 1
        x_ten = torch.tensor(x, dtype=torch.float32).view(1, -1).to(device)
        y_ten = torch.tensor(y, dtype=torch.float32).view(1, -1).to(device)
        # Training step
        optimizer.zero_grad()
        pred = model(x_ten)
        preds.append(pred.item())
        loss = loss_fn(pred, y_ten)
        loss.backward()
        optimizer.step()
        all_loss.append(loss.item())

        # Drift handling
        if detector.update(loss.item()):
            drift_action(optimizer, drifted=True, decay_state=decay_state)

        drift_action(optimizer, drifted=False, decay_state=decay_state)
        
        mae_full = torch.abs(pred - y_ten).mean().item()
        mae_values_full.append(mae_full)

        # Post-warmup MAE tracking
        if step >= 200:
            mae_values.append(mae_full)

    avg_loss = sum(all_loss) / len(all_loss) if all_loss else 0.0
    avg_mae_post_warmup = sum(mae_values) / len(mae_values) if mae_values else 0.0
    avg_mae_full = sum(mae_values_full) / len(mae_values_full) if mae_values_full else 0.0


    return model, avg_loss, avg_mae_post_warmup, avg_mae_full, preds

def simulate(prices, preds):
    budget = 100
    amount = 0
    max_budget = 100
    for i in range(len(prices)):
        if preds[i] > prices[i] and amount == 0:
            amount = budget / prices[i]
        if preds[i] < prices[i] and amount > 0:
            budget = amount * prices[i]
            amount = 0

        if budget > max_budget:
            max_budget = budget

    return budget, max_budget

In [14]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [15]:
paths = ["../data/btc_data.csv", "../data/eth_data.csv"]
detector_types = ["ADWIN", "PageHinkley"]
drift_actions = [increase_lr_only, increase_and_decay_lr]
seq_lens = [16, 32, 64]

In [16]:
hidden_sizess = [[16], [32, 32], [64, 64]]

In [17]:
len(paths) * len(detector_types) * len(drift_actions) * len(seq_lens) * len(hidden_sizess)

72

In [None]:
for path in paths:
    data = generate_crypto_time_series(path)
    for detector_type in detector_types:
        detector = DriftDetector(method=detector_type)
        for drift_action in drift_actions:
            for seq_len in seq_lens:
                for hidden_sizes in hidden_sizess:
                    x_all, y_all, prices = make_dataset(data, seq_len)
                    model = TimeSeriesMLP(
                        input_size=seq_len,
                        hidden_sizes=hidden_sizes,
                        output_size=1,
                    )
                    model.to(device)
                    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
                    loss_fn = nn.MSELoss()
                    model, loss, mae_full, mae_warmup, preds = train_model(
                        model,
                        detector,
                        x_all,
                        y_all,
                        optimizer,
                        loss_fn,
                        device=device,
                        drift_action=drift_action
                    )
                    budget, max_budget = simulate(prices, preds)
                    # save stream, n_dim, n_points, detector_type, seq_len, hidden_sizes, train_loss, test_loss to CVS file
                    with open("crypto_mlp.csv", "a") as f:
                        f.write(
                            f"{path.split("/")[-1].split("_")[0]};1;{len(data)};{detector_type};{seq_len};{hidden_sizes};{loss};{mae_full};{mae_warmup};{budget};{max_budget}\n"
                        )

In [30]:
len(preds), prices.shape, len(x_all), len(y_all)

(17449, (17449,), 17449, 17449)

In [None]:
def simulate(prices, preds):
    budget = 100
    amount = 0
    max_budget = 100
    for i in range(len(prices)):
        if preds[i] > prices[i] and amount == 0:
            amount = budget / prices[i]
        if preds[i] < prices[i] and amount > 0:
            budget = amount * prices[i]
            amount = 0

        if budget > max_budget:
            max_budget = budget

    return budget, max_budget

In [31]:
simulate(prices, preds)

(np.float64(177.86073128870297), np.float64(191.46093239838908))

In [33]:
preds

[1947.1392822265625,
 2514.072998046875,
 3001.7666015625,
 3454.382080078125,
 3931.270751953125,
 4389.62158203125,
 4875.89013671875,
 5321.4775390625,
 5808.173828125,
 6258.42919921875,
 6747.05224609375,
 7154.353515625,
 7630.01513671875,
 8074.1181640625,
 8509.576171875,
 8918.287109375,
 9339.6806640625,
 9771.3779296875,
 10204.8310546875,
 10599.875,
 11026.181640625,
 11416.5185546875,
 11765.890625,
 12086.8212890625,
 12426.5556640625,
 12773.7958984375,
 13100.720703125,
 13433.8515625,
 13777.40625,
 14141.7431640625,
 14511.4462890625,
 15211.6845703125,
 15914.166015625,
 16629.5390625,
 17306.09765625,
 17963.40625,
 18549.81640625,
 19178.13671875,
 19766.560546875,
 20372.15625,
 20953.833984375,
 21540.08984375,
 22119.529296875,
 22665.056640625,
 23201.21484375,
 23715.6015625,
 24196.67578125,
 24651.85546875,
 25092.017578125,
 25505.818359375,
 25897.435546875,
 26263.046875,
 26608.640625,
 26934.673828125,
 27225.78515625,
 27495.06640625,
 27714.1875,
 27