In [None]:
%load_ext autoreload
%autoreload 2

from darts.datasets import ETTh1Dataset
from darts.models import NLinearModel
from darts.metrics.metrics import mae, mse
import numpy as np
import pandas as pd
import torch
import random
import csv
import datetime
import os
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
from sklearn.preprocessing import StandardScaler

import optuna
from optuna.trial import TrialState
from optuna.visualization import plot_optimization_history, plot_param_importances, plot_timeline

# from models import CBM, TaskType
from models_redesign import CBM_redesigned, TaskType
from preprocess_helpers import *
from helper import *
from param_initializations import *
from optimization_strategy import greedy_selection

# Select the GPU only when one is usable. is_available must be *called*: the bare
# function object is always truthy and would unconditionally pick 'cuda'.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
# Load the ETTh1 dataset as a darts TimeSeries, report its time span, and plot it.
series = ETTh1Dataset().load()

for boundary in (series.start_time(), series.end_time()):
    print(boundary)

series.plot()

In [None]:
# Chronological 60% / 20% / 20% split into train / val / test.
# NOTE(review): these three series are not referenced by later cells shown here;
# preprocess_data() performs the same split internally.
train_series, _holdout = series.split_before(0.6)
val_series, test_series = _holdout.split_before(0.5)
del _holdout


In [None]:
class TimeSeriesDataset(Dataset):
    """Sliding-window forecasting dataset over a 2-D (time, features) tensor.

    Sample ``i`` is ``(X, y)``. ``X`` holds the ``T`` consecutive rows of ``data``
    starting at ``i * window_stride``. ``y`` holds the ``pred_len`` rows of
    ``targets`` that immediately follow that window, with the trailing
    dimension squeezed.

    Args:
        data: tensor of shape (N, V), i.e. time steps by input features.
        targets: tensor with the same first dimension N as ``data``. In this
            notebook it is a single (N, 1) target column.
        T: length of the input (look-back) window.
        window_stride: step between consecutive window starts.
        pred_len: number of future target steps per sample.
    """

    def __init__(self, data, targets, T, window_stride=1, pred_len=1):
        assert targets.size(0) == data.size(0), "data and targets must have the same number of time steps"
        self.data = data
        self.targets = targets
        self.T = T  # look-back window length
        self.window_stride = window_stride
        self.pred_len = pred_len
        self.N, self.V = data.shape

    def __len__(self):
        # Count window starts s = 0, stride, 2*stride, ... for which both the input
        # window [s, s+T) and the target horizon [s+T, s+T+pred_len) fit in N rows.
        return len(range(0, self.N - self.T - self.pred_len + 1, self.window_stride))

    def __getitem__(self, idx):
        n = len(self)
        if idx < 0:
            idx += n
        # Out-of-range indices used to return silently truncated/empty windows, and
        # plain iteration (for x in dataset) never terminated. IndexError fixes both.
        if not 0 <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of length {n}")

        start = idx * self.window_stride
        end = start + self.T

        X = self.data[start:end]
        y = self.targets[end:end + self.pred_len].squeeze(-1)  # target column only
        return X, y


In [None]:
def _scaled_frame(split, scaler, fit):
    """Return darts TimeSeries `split` as a DataFrame transformed by `scaler`.

    The scaler is fitted only when `fit` is True (the training split), so that
    validation/test statistics never leak into the transform.
    """
    frame = split.pd_dataframe()
    values = scaler.fit_transform(frame) if fit else scaler.transform(frame)
    return pd.DataFrame(values, columns=frame.columns)


def _frame_to_loader(frame, target_col, seq_len, window_stride, pred_len, batch_size, num_workers):
    """Wrap a scaled DataFrame in a sliding-window, non-shuffled DataLoader.

    Inputs are every column plus one finiteness indicator per column, so V
    features become 2*V inputs. The target is `target_col` alone.
    """
    X = torch.tensor(frame.to_numpy(), dtype=torch.float32)
    y = torch.tensor(frame[[target_col]].to_numpy(), dtype=torch.float32)

    # Cast the boolean mask explicitly rather than relying on torch.cat type promotion.
    # NOTE(review): NaNs are flagged here but not imputed, so they remain in X.
    indicators = torch.isfinite(X).to(X.dtype)
    X = torch.cat([X, indicators], dim=1)

    dataset = TimeSeriesDataset(X, y, seq_len, window_stride, pred_len)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False,
                      num_workers=num_workers, pin_memory=True)


def preprocess_data(series, seq_len, window_stride=1, pred_len=1, batch_size = 1024,
                    target_col="OT", train_frac=0.6, val_frac_of_rest=0.5, num_workers=4):
    """Split, mean-centre and window `series` into train/val/test DataLoaders.

    The series is split chronologically. The first `train_frac` goes to train,
    and the remainder is split again at `val_frac_of_rest` into val and test.
    A StandardScaler with ``with_std=False`` (mean-centring only) is fitted on
    train and applied to all three splits.

    Args:
        series: darts TimeSeries containing `target_col` among its columns.
        seq_len: look-back window length.
        window_stride: step between consecutive windows.
        pred_len: forecast horizon (number of target steps per sample).
        batch_size: DataLoader batch size.
        target_col: column used as the forecasting target.
        train_frac: fraction of `series` used for training.
        val_frac_of_rest: fraction of the non-training remainder used for validation.
        num_workers: DataLoader worker processes.

    Returns:
        (train_loader, val_loader, test_loader, scaler). Batches are
        X of shape (batch, seq_len, 2*V) and y of shape (batch, pred_len).
    """
    scaler = StandardScaler(with_std=False)

    train, rest = series.split_before(train_frac)
    val, test = rest.split_before(val_frac_of_rest)

    print("Train/Val/Test", len(train), len(val), len(test))

    loaders = []
    for split, fit in ((train, True), (val, False), (test, False)):
        frame = _scaled_frame(split, scaler, fit)
        loaders.append(_frame_to_loader(frame, target_col, seq_len, window_stride,
                                        pred_len, batch_size, num_workers))
    train_loader, val_loader, test_loader = loaders

    return train_loader, val_loader, test_loader, scaler


In [None]:
# Smoke check with a short window: inspect the first training batch's shapes
# and the number of batches in each split.
seq_len = 10
train_loader, val_loader, test_loader, scaler = preprocess_data(series, seq_len, pred_len=24)

first_batch = next(iter(train_loader), None)
if first_batch is not None:
    X, y = first_batch
    print(X.shape)
    print(y.shape)

print("Batches", len(train_loader), len(val_loader), len(test_loader))

Train/Val/Test 10451 3484 3485
torch.Size([1024, 10, 14])
torch.Size([1024, 24])
Batches 11 4 4


In [None]:
def plot_losses(train_losses, val_losses):
    """Draw the training (black) and validation (green) loss curves, then show the figure."""
    curves = ((train_losses, "black", "Train"), (val_losses, "green", "Val"))
    for values, colour, name in curves:
        plt.plot(values, color=colour, label=name)
    plt.legend()
    plt.show()


In [None]:
def initializeModel(n_concepts, input_dim, changing_dim, seq_len, output_dim, top_k='',
                    opt_lr=3e-3, opt_weight_decay=1e-05, l1_lambda=0.001,
                    cos_sim_lambda=0.01, task_type=None):
    """Build a CBM_redesigned concept-bottleneck model and move it to `device`.

    Args:
        n_concepts: number of concepts in the bottleneck.
        input_dim: total input feature width (this notebook passes features + indicators).
        changing_dim: number of raw (time-varying) features.
        seq_len: input window length.
        output_dim: width of the model output.
        top_k: forwarded unchanged to CBM_redesigned; its meaning is defined there.
        opt_lr: optimizer learning rate (2e-4 was also tried previously).
        opt_weight_decay: optimizer weight decay.
        l1_lambda: L1 regularisation weight.
        cos_sim_lambda: cosine-similarity regularisation weight.
        task_type: TaskType member. None (the default) means TaskType.REGRESSION,
            which is what every existing caller gets.

    Returns:
        The model, already moved to `device`.
    """
    # Resolve at call time so the default does not depend on import order.
    if task_type is None:
        task_type = TaskType.REGRESSION

    model = CBM_redesigned(
        input_dim=input_dim,
        changing_dim=changing_dim,
        seq_len=seq_len,
        num_concepts=n_concepts,
        opt_lr=opt_lr,
        opt_weight_decay=opt_weight_decay,
        l1_lambda=l1_lambda,
        cos_sim_lambda=cos_sim_lambda,
        output_dim=output_dim,
        top_k=top_k,
        task_type=task_type,
    )
    return model.to(device)

## Regression

In [None]:
# Look-back window length and forecast horizon (in time steps) for the regression runs.
seq_len, pred_len = 96, 96


In [None]:
experiment_folder = f"/workdir/optimal-summaries-public/vasopressor/models/etth1/redesign-multi2single-L{seq_len}-T{pred_len}/"
model_path = experiment_folder + "forecasting_c{}.pt"  # formatted with the concept count
random_seed = 1

# Create the output directory if it is missing.
os.makedirs(experiment_folder, exist_ok=True)

In [None]:
# Sweep the number of concepts. For each value, train (or load) a model, then
# record [n_concepts, final val loss, test MAE, test MSE].
history_binary = []

set_seed(random_seed)

changing_dim = len(series.columns)
input_dim = 2 * changing_dim  # raw features + one finiteness indicator per feature
n_concepts_list = list(range(2,21,2)) + list(np.arange(50,401,50))

train_loader, val_loader, test_loader, scaler = preprocess_data(series, seq_len, pred_len=pred_len)

mae_metric = MeanAbsoluteError().to(device)
mse_metric = MeanSquaredError().to(device)

for n_concepts in n_concepts_list:
    print("n_concepts", n_concepts)

    # Bind to `model_shared`, the name the evaluation below uses. Previously the model
    # was bound to `model`, so this raised NameError on a fresh kernel, or silently
    # scored a stale model left over from an earlier run.
    model_shared = initializeModel(n_concepts, input_dim, changing_dim, seq_len, output_dim=pred_len)
    model_shared.fit(train_loader, val_loader, None, model_path.format(n_concepts), 5000)

    model_shared.eval()
    with torch.no_grad():
        for Xb, yb in test_loader:
            Xb, yb = Xb.to(device), yb.to(device)
            preds = model_shared.forward(Xb)
            # Accumulate only. Per-batch .item() forced a device sync for a value never used.
            mae_metric.update(preds, yb)
            mse_metric.update(preds, yb)
        # Distinct names so the darts `mae`/`mse` functions imported at the top stay intact.
        test_mae = mae_metric.compute().item()
        test_mse = mse_metric.compute().item()
        mae_metric.reset()
        mse_metric.reset()

    history = [n_concepts, round(model_shared.val_losses[-1],2), round(test_mae,2), round(test_mse,2)]
    display(history)
    history_binary.append(np.array(history))

    plot_losses(model_shared.train_losses, model_shared.val_losses)

history_binary = np.array(history_binary)
history_binary.shape


Train/Val/Test 10451 3484 3485
n_concepts 2
Loaded model from /workdir/optimal-summaries-public/vasopressor/models/etth1/redesign-multi2single-L96-T96/forecasting_c2.pt


  0%|          | 0/4800 [00:00<?, ?it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])


  0%|          | 1/4800 [00:00<51:04,  1.57it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 2/4800 [00:00<33:39,  2.38it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 3/4800 [00:01<27:33,  2.90it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 4/4800 [00:01<25:17,  3.16it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 5/4800 [00:01<23:50,  3.35it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 6/4800 [00:01<22:54,  3.49it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 7/4800 [00:02<22:12,  3.60it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 8/4800 [00:02<22:12,  3.60it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 9/4800 [00:02<21:17,  3.75it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 10/4800 [00:02<20:57,  3.81it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 11/4800 [00:03<20:44,  3.85it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 12/4800 [00:03<20:36,  3.87it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 13/4800 [00:03<20:42,  3.85it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 14/4800 [00:04<20:32,  3.88it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 15/4800 [00:04<20:31,  3.88it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 16/4800 [00:04<20:18,  3.93it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 17/4800 [00:04<20:20,  3.92it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 18/4800 [00:05<20:11,  3.95it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 19/4800 [00:05<20:27,  3.89it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 20/4800 [00:05<20:34,  3.87it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 21/4800 [00:05<21:25,  3.72it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 22/4800 [00:06<20:47,  3.83it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 23/4800 [00:06<20:39,  3.85it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  0%|          | 24/4800 [00:06<20:28,  3.89it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  1%|          | 25/4800 [00:06<20:14,  3.93it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  1%|          | 26/4800 [00:07<20:28,  3.89it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  1%|          | 27/4800 [00:07<20:33,  3.87it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


  1%|          | 28/4800 [00:07<20:05,  3.96it/s]

torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([1024, 96, 14])
torch.Size([1024, 96, 98])
torch.Size([20, 96, 14])
torch.Size([20, 96, 98])


In [None]:
# Test MAE and MSE versus number of concepts (log x-axis). Each point is annotated
# with its value: MAE labels above the point, MSE labels below.
concept_counts = history_binary[:, 0]
for col, name in ((2, 'MAE'), (3, 'MSE')):
    plt.plot(concept_counts, history_binary[:, col], label=name)

plt.xlabel('Num Concepts')
plt.ylabel('Criteria')
plt.title('Plot of Concepts vs Criteria')
plt.xticks(n_concepts_list)
plt.xscale('log')

for col, dy in ((2, 10), (3, -10)):
    for x, value in zip(concept_counts, history_binary[:, col]):
        plt.annotate("{:.2f}".format(value),  # text
                     (x, value),                # anchor point
                     textcoords="offset points",
                     xytext=(0, dy),            # vertical offset in points
                     ha='center')

plt.legend()
plt.show()


In [None]:
# Qualitative check: fit (or load) the 2-concept model, run it on the first
# validation batch, and plot one sample's forecast against the ground truth.
train_loader, val_loader, test_loader, scaler = preprocess_data(series, seq_len, pred_len=pred_len)

mae_metric = MeanAbsoluteError().to(device)
mse_metric = MeanSquaredError().to(device)
n_concepts = 2

model_shared = initializeModel(n_concepts, input_dim, changing_dim, seq_len, output_dim=pred_len)
model_shared.fit(train_loader, val_loader, None, model_path.format(n_concepts), 10000)

model_shared.eval()
with torch.no_grad():
    for batch_idx, (Xb, yb) in enumerate(val_loader):
        Xb = Xb.to(device)
        yb = yb.to(device)
        preds = model_shared.forward(Xb)

        mae = mae_metric(preds, yb).item()
        mse = mse_metric(preds, yb).item()
        break  # a single batch is enough for the plot below
    mae, mse = mae_metric.compute().item(), mse_metric.compute().item()
    for metric in (mae_metric, mse_metric):
        metric.reset()


i = 20  # which sample of the batch to plot
yb = yb.cpu().numpy()[i]
preds = preds.cpu().numpy()[i]

print(yb.shape)
print(preds.shape)

for values, colour, name in ((yb, "black", "True"), (preds, "red", "Pred")):
    plt.plot(values, color=colour, label=name)
plt.legend()
plt.show()


In [None]:
# Fit (or load) a 5-concept model and pull out its bottleneck weight matrix.
n_concepts = 5

# output_dim is a required argument of initializeModel; omitting it raised TypeError.
model_shared = initializeModel(n_concepts, input_dim, changing_dim, seq_len, output_dim=pred_len)
model_shared.fit(train_loader, val_loader, None, model_path.format(n_concepts), 1000)

# Reset first so a failed lookup cannot silently reuse weights left over from an earlier cell.
bottleneck_weights = None
for name, param in model_shared.named_parameters():
    if "bottleneck.weight" in name:
        bottleneck_weights = param
assert bottleneck_weights is not None, "no parameter matching 'bottleneck.weight' found on the model"
feature_weights = bottleneck_weights.cpu().detach().numpy()

feature_weights.shape

In [None]:
# For each concept, bar-chart the largest absolute feature weights (at most 100).
for c in range(n_concepts):
    abs_weights = np.abs(feature_weights[c])
    inds = np.argsort(-abs_weights)[:100]

    fig = plt.figure()
    ax = fig.add_subplot(111)
    # Size the x positions by len(inds). A fixed np.arange(1, 101) causes a shape
    # mismatch in ax.bar when a concept has fewer than 100 features.
    ax.bar(np.arange(1, len(inds) + 1), abs_weights[inds])
    ax.set_xlabel(f"Top {len(inds)} features")
    ax.set_ylabel("abs value of feature coefficient")
    ax.set_title(f"Concept {c}")
    plt.show()


In [None]:
# Per-concept budget: 90% of the total absolute weight across all features
# (a share of the sum, not a percentile).
sum90p = 0.90 * np.abs(feature_weights).sum(axis=-1)
sum90p.shape


In [None]:
# For every concept, keep the features with the largest |weight|, in descending
# order, for as long as their running |weight| total stays within that concept's
# 90% budget (sum90p). If fewer than 10 survive, fall back to the 10 strongest.
top_k_inds = []
for c in range(n_concepts):
    weights_c = feature_weights[c]
    order = np.argsort(-np.abs(weights_c))  # feature indices by descending |weight|

    selected = []
    running_total = 0
    for feat_idx, w in zip(order, weights_c[order]):
        running_total += abs(w)
        within_budget = running_total <= sum90p[c]
        if not within_budget:
            break
        selected.append(feat_idx)

    if len(selected) < 10:
        selected = order[:10].tolist()

    top_k_inds.append(selected)

top_k_inds

In [None]:
# Persist the selected feature indices as CSV, one row per concept.
filename = experiment_folder + "top-k/top_k_inds_c{}.csv".format(n_concepts)

directory = os.path.dirname(filename)
os.makedirs(directory, exist_ok=True)

# The csv module requires newline='' on the file object; otherwise rows get
# extra blank lines on platforms that translate '\n'.
with open(filename, 'w', newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerows(top_k_inds)


In [None]:
# Run greedy_selection (from optimization_strategy, not shown here) over the
# per-concept top-k index lists, scored with `auroc_metric` on the test loader.
# NOTE(review): `auroc_metric` is only created in the Multi-class section further
# down, so on a fresh kernel this line raises NameError. AUROC is also a
# classification metric while this section is regression. Confirm which metric
# greedy_selection should receive here (e.g. mae_metric / mse_metric).
best_aucs, best_auc_inds, best_auc_concepts = greedy_selection(auroc_metric, test_loader, top_k_inds, model_shared)


In [None]:
# Save greedy_selection's results as CSV: a header row, then one row per
# (score, concept, feature index) triple returned by greedy_selection.
filename = experiment_folder + "top-k/bottleneck_r{}_c{}_topkinds.csv".format(random_seed, n_concepts)

# The top-k directory may not exist yet if the previous export cell was skipped.
os.makedirs(os.path.dirname(filename), exist_ok=True)

# newline='' as required by the csv module.
with open(filename, 'w', newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(["Best AUC", "Best AUC Concept #", "Best AUC ind #"])
    csvwriter.writerows(list(row) for row in zip(best_aucs, best_auc_concepts, best_auc_inds))


## Multi-class

In [None]:
experiment_folder = "/workdir/optimal-summaries-public/vasopressor/models/arabic/multiclass/"
model_path = experiment_folder + "arabic_c{}.pt"  # formatted with the concept count
random_seed = 1

# Create the output directory if it is missing.
os.makedirs(experiment_folder, exist_ok=True)

In [None]:
# Multi-class sweep over 1..15 concepts. For each value, train (or load) a model
# and record [n_concepts, final val loss, test AUROC, test accuracy].
# NOTE(review): `X` and `y` are not defined in any cell shown, and AUROC/Accuracy
# are not imported explicitly (they may come from a star import). Confirm before
# running this section on a fresh kernel.
history_multiclass = []

set_seed(random_seed)

data, y_ohe, num_classes, weights = preprocess_data_multiclass(X, y)
train_loader, val_loader, test_loader = initialize_data(1, data, y_ohe, multiclass=True)

input_dim = data.shape[2]
changing_dim = X[0].shape[0]
seq_len = data.shape[1]

auroc_metric = AUROC(task="multiclass", num_classes=num_classes).to(device)
accuracy_metric = Accuracy(task="multiclass", num_classes=num_classes).to(device)

for n_concepts in range(1,16):
    print(n_concepts)

    # NOTE(review): initializeModel builds a TaskType.REGRESSION model by default;
    # here num_classes is only the output width. Confirm this is intended.
    model_shared = initializeModel(n_concepts, input_dim, changing_dim, seq_len, num_classes)
    model_shared.fit(train_loader, val_loader, weights, model_path.format(n_concepts), 1000)

    # Evaluate in inference mode without autograd, as the regression sweep does.
    model_shared.eval()
    with torch.no_grad():
        for Xb, yb in test_loader:
            Xb, yb = Xb.to(device), yb.to(device)
            probs = model_shared.forward_probabilities(Xb)
            auroc_metric.update(probs, yb)
            accuracy_metric.update(probs, yb)
    auc = auroc_metric.compute().item()
    acc = accuracy_metric.compute().item()
    auroc_metric.reset()
    accuracy_metric.reset()

    history = [n_concepts, model_shared.val_losses[-1], auc, acc]
    history_multiclass.append(np.array(history))
history_multiclass = np.array(history_multiclass)
history_multiclass.shape


In [None]:
# Plot test AUROC and accuracy against the number of concepts and label each
# point with its value (AUC labels above the marker, ACC labels below).
concept_counts = history_multiclass[:, 0]
# (column in history_multiclass, legend label, annotation y-offset in points)
series_specs = [(2, 'AUC', 10), (3, 'ACC', -10)]

for col, legend_label, _ in series_specs:
    plt.plot(concept_counts, history_multiclass[:, col], label=legend_label)

plt.xlabel('Num Concepts')
plt.ylabel('Criteria')
plt.title('Plot of Concepts vs Criteria')
plt.xticks(np.arange(min(concept_counts), max(concept_counts) + 1, 1))

for col, _, y_offset in series_specs:
    for x_val, y_val in zip(concept_counts, history_multiclass[:, col]):
        plt.annotate("{:.2f}".format(y_val),
                     (x_val, y_val),
                     textcoords="offset points",
                     xytext=(0, y_offset),
                     ha='center')

plt.legend()
plt.show()


In [None]:
# Train a model with a fixed number of concepts and extract its bottleneck
# layer weights for inspection.
n_concepts = 5

model_shared = initializeModel(n_concepts, input_dim, changing_dim, seq_len, num_classes)
model_shared.fit(train_loader, val_loader, weights, model_path.format(n_concepts), 1000)

# Look the weights up fresh so a stale `bottleneck_weights` left by an earlier
# cell is never reused silently when no parameter name matches.
bottleneck_matches = [param for name, param in model_shared.named_parameters()
                      if "bottleneck.weight" in name]
assert bottleneck_matches, "no parameter matching 'bottleneck.weight' in model_shared"
bottleneck_weights = bottleneck_matches[-1]  # the last match, as the original loop kept
# Presumably (n_concepts, n_features): later cells index one row per concept.
feature_weights = bottleneck_weights.detach().cpu().numpy()

feature_weights.shape

In [None]:
# Bar chart of the largest absolute bottleneck weights for each concept.
TOP_N_BARS = 100
abs_weights = np.abs(feature_weights)
# Clamp so the bar positions and heights have equal length even when the
# model has fewer than TOP_N_BARS features (the fixed 100 would crash).
n_bars = min(TOP_N_BARS, abs_weights.shape[1])

for c in range(n_concepts):
    top_inds = np.argsort(-abs_weights[c])[:n_bars]  # descending |weight|
    fig, ax = plt.subplots()
    ax.bar(np.arange(1, n_bars + 1), abs_weights[c][top_inds])
    ax.set_xlabel(f"Top {n_bars} features")
    ax.set_ylabel("abs value of feature coefficient")
    ax.set_title(f"Concept {c}")
    plt.show()


In [None]:
# Per-concept threshold: 90% of the total absolute weight mass (sum of |w|
# over the last axis). This is NOT a 90th percentile of the weights.
sum90p = 0.90 * np.abs(feature_weights).sum(axis=-1)
sum90p.shape


In [None]:
# Get the top-K feature indices per concept: walk features in descending
# |weight| order and keep them while the running |weight| sum stays within
# sum90p[c]. The feature that would cross the threshold, and all after it,
# are dropped. If fewer than MIN_TOP_K survive, keep the MIN_TOP_K largest.
MIN_TOP_K = 10

top_k_inds = []
for c in range(n_concepts):
    abs_w = np.abs(feature_weights[c])
    inds = np.argsort(-abs_w)  # feature indices by descending |weight|
    # Accumulate in float64 so float32 rounding cannot shift the cut-off.
    cum_mass = np.cumsum(abs_w[inds], dtype=np.float64)
    # cum_mass is non-decreasing (|w| >= 0), so the entries within the
    # threshold form a prefix; counting them gives that prefix's length.
    n_keep = max(int(np.count_nonzero(cum_mass <= sum90p[c])), MIN_TOP_K)
    # .tolist() yields plain Python ints regardless of which rule applied.
    top_k_inds.append(inds[:n_keep].tolist())

top_k_inds

In [None]:
# Write the top-k indices to CSV, one row per concept.
filename = experiment_folder + "top-k/top_k_inds_c{}.csv".format(n_concepts)

os.makedirs(os.path.dirname(filename), exist_ok=True)

# newline='' is required by the csv module; without it extra blank rows
# appear on platforms that translate '\n'.
with open(filename, 'w', newline='') as csvfile:
    csv.writer(csvfile).writerows(top_k_inds)


In [None]:
# Map every selected bottleneck index back to its (feature, summary) pair.
# Column names: 1..13, followed by the matching "<i>_ind" columns.
data_cols = list(range(1, 14)) + [f"{i}_ind" for i in range(1, 14)]

for concept_id, concept_inds in enumerate(top_k_inds):
    for feat_ind in concept_inds:
        feat_name, feat_summary = getConcept(data_cols, input_dim, changing_dim, int(feat_ind))
        print(f"Concept {concept_id}: ID {feat_ind}, Feature {feat_name}, Summary {feat_summary}")


In [None]:
# Greedy selection over top_k_inds, scored with auroc_metric on test_loader;
# accuracy is tracked as an extra column ("acc").
# NOTE(review): selecting on the test loader biases subsequent test metrics.
extra_metrics = {"acc": accuracy_metric}
greedy_results = greedy_selection(
    auroc_metric, test_loader, top_k_inds, model_shared, track_metrics=extra_metrics
)
greedy_results.head()

In [None]:
# Persist the greedy-selection table (header row + one row per record).
top_k_csv_file = experiment_folder + "top-k/bottleneck_r{}_c{}_topkinds.csv".format(random_seed, n_concepts)
os.makedirs(os.path.dirname(top_k_csv_file), exist_ok=True)

# newline='' is required by the csv module; without it extra blank rows
# appear on platforms that translate '\n'.
with open(top_k_csv_file, 'w', newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(greedy_results.columns)
    csvwriter.writerows(greedy_results.itertuples(index=False))


In [None]:
# Print the greedily selected features grouped by concept.
data_cols = list(range(1, 14)) + [f"{i}_ind" for i in range(1, 14)]

ordered_results = greedy_results.sort_values(["Concept", "ID"])

for record in ordered_results.itertuples(index=False):
    # NOTE(review): positional access assumes column 1 is the feature ID and
    # column 2 the concept — confirm against greedy_selection's output columns.
    feat_id, concept_id = record[1], record[2]
    feat_name, feat_summary = getConcept(data_cols, input_dim, changing_dim, feat_id)
    print(f"Concept {concept_id}: ID {feat_id}, Feature {feat_name}, Summary {feat_summary}")

In [None]:
# Greedy-selection trajectory: "Score" and tracked "acc" for each row of
# greedy_results (x-axis is the DataFrame index).
fig, ax = plt.subplots()
ax.plot(greedy_results["Score"], label="Score")
ax.plot(greedy_results["acc"], label="acc")
ax.set_xlabel("greedy_results index")
ax.set_ylabel("Metric value")
ax.set_title("Greedy selection trajectory")
ax.legend()
plt.show()

In [None]:
# Rebuild a 6-concept model restricted to a saved top-k selection and evaluate
# it on the test set.
n_concepts = 6
# Derived from the config cell instead of hardcoding the absolute path
# (with random_seed = 1 this is .../top-k/bottleneck_r1_c6_topkinds.csv).
top_k_csv_file = experiment_folder + "top-k/bottleneck_r{}_c{}_topkinds.csv".format(random_seed, n_concepts)
# NOTE(review): this section of the notebook writes the c5 top-k file
# (n_concepts = 5 above); the c6 file must come from a separate run.
assert os.path.exists(top_k_csv_file), f"missing top-k file: {top_k_csv_file}"

model_shared = initializeModel(n_concepts, input_dim, changing_dim, seq_len, num_classes, top_k=top_k_csv_file)
# NOTE(review): evaluated without calling fit — confirm initializeModel(top_k=...)
# provides trained weights; the following cell fine-tunes the model.

model_shared.eval()
# Discard any state left in the shared metric objects by earlier cells
# (greedy_selection is also given these metrics).
auroc_metric.reset()
accuracy_metric.reset()
with torch.no_grad():
    for Xb, yb in test_loader:
        Xb, yb = Xb.to(device), yb.to(device)
        probs = model_shared.forward_probabilities(Xb)
        auroc_metric.update(probs, yb)
        accuracy_metric.update(probs, yb)
auc = auroc_metric.compute().item()
acc = accuracy_metric.compute().item()
auroc_metric.reset()
accuracy_metric.reset()

print(auc)
print(acc)


In [None]:
# Fine-tune the top-k-restricted model and re-evaluate it on the test set.
# Path derived from the config cell (with n_concepts = 6 this is
# .../top-k/arabic_c6_finetuned.pt).
finetuned_model_path = experiment_folder + "top-k/arabic_c{}_finetuned.pt".format(n_concepts)
model_shared.fit(train_loader, val_loader, weights, save_model_path=finetuned_model_path, epochs=3000)

model_shared.eval()
# Start from clean metric state so only this evaluation is aggregated.
auroc_metric.reset()
accuracy_metric.reset()
with torch.no_grad():
    for Xb, yb in test_loader:
        Xb, yb = Xb.to(device), yb.to(device)
        probs = model_shared.forward_probabilities(Xb)
        auroc_metric.update(probs, yb)
        accuracy_metric.update(probs, yb)
auc = auroc_metric.compute().item()
acc = accuracy_metric.compute().item()
auroc_metric.reset()
accuracy_metric.reset()

print(auc)
print(acc)
