In [39]:
%load_ext autoreload
%autoreload 2

from darts.datasets import ETTh1Dataset
from darts.models import NLinearModel
from darts.metrics.metrics import mae, mse
import numpy as np
import pandas as pd
import torch
import random
import csv
import datetime
import os
import gc
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler

import optuna
from optuna.trial import TrialState
from optuna.visualization import plot_optimization_history, plot_param_importances, plot_timeline

import models
import models_3d_concepts_on_time
import models_3d_atomics_on_variate_to_concepts
from preprocess_helpers import *
from helper import *
from param_initializations import *
from optimization_strategy import greedy_selection

device = torch.device('cuda') if torch.cuda.is_available else torch.device('cpu')
device

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


device(type='cuda')

In [40]:
series = ETTh1Dataset().load()

print(series.start_time())
print(series.end_time())

# series.plot()

2016-07-01 00:00:00
2018-06-26 19:00:00


In [41]:
train_series, test_series = series.split_before(0.6)
val_series, test_series = test_series.split_before(0.5)


In [42]:
class TimeSeriesDataset(Dataset):
    def __init__(self, data, targets, T, window_stride=1, pred_len=1):
        self.data = data
        self.targets = targets
        assert targets.size(0) == data.size(0)
        self.T = T # time window
        self.window_stride = window_stride
        self.pred_len = pred_len
        self.N, self.V = data.shape

    def __len__(self):
        return len(range(0, self.N - self.T - self.pred_len + 1, self.window_stride))

    def __getitem__(self, idx):
        start = idx * self.window_stride
        end = start + self.T

        X = self.data[start:end]
        # if mode == "S": # predict only target
        y = self.targets[end:end + self.pred_len].flatten()
        # elif mode == "MS": # predict all variables
        #   y = self.data[end:end + self.pred_len, :7].flatten()
        return X, y


In [43]:
def preprocess_data(series, seq_len, window_stride=1, pred_len=1, batch_size = 512):
    scaler = StandardScaler()
    
    train, test = series.split_before(0.6)
    val, test = test.split_before(0.5)
    
    print("Train/Val/Test", len(train), len(val), len(test))
    
    train_og = train.pd_dataframe()
    train = scaler.fit_transform(train_og)
    train = pd.DataFrame(train, columns=train_og.columns)
    X_train = train
    y_train = train[["OT"]]
    X_train = torch.tensor(X_train.to_numpy(), dtype=torch.float32)
    y_train = torch.tensor(y_train.to_numpy(), dtype=torch.float32)
    
    indicators = torch.isfinite(X_train)
    X_train = torch.cat([X_train, indicators], axis=1)
    
    train_dataset = TimeSeriesDataset(X_train, y_train, seq_len, window_stride, pred_len)
    train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle=False, num_workers=4, pin_memory=True)

    val_og = val.pd_dataframe()
    val = scaler.transform(val_og)
    val = pd.DataFrame(val, columns=val_og.columns)
    X_val = val
    y_val = val[["OT"]]
    X_val = torch.tensor(X_val.to_numpy(), dtype=torch.float32)
    y_val = torch.tensor(y_val.to_numpy(), dtype=torch.float32)
    
    indicators = torch.isfinite(X_val)
    X_val = torch.cat([X_val, indicators], axis=1)
    
    val_dataset = TimeSeriesDataset(X_val, y_val, seq_len, window_stride, pred_len)
    val_loader = DataLoader(val_dataset, batch_size = batch_size, shuffle=False, num_workers=4, pin_memory=True)

    test_og = test.pd_dataframe()
    test = scaler.transform(test_og)
    test = pd.DataFrame(test, columns=test_og.columns)
    X_test = test
    y_test = test[["OT"]]
    X_test = torch.tensor(X_test.to_numpy(), dtype=torch.float32)
    y_test = torch.tensor(y_test.to_numpy(), dtype=torch.float32)
    
    indicators = torch.isfinite(X_test)
    X_test = torch.cat([X_test, indicators], axis=1)
    
    test_dataset = TimeSeriesDataset(X_test, y_test, seq_len, window_stride, pred_len)
    test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle=False, num_workers=4, pin_memory=True)
    
    return train_loader, val_loader, test_loader, scaler


In [44]:
seq_len = 10
train_loader, val_loader, test_loader, scaler = preprocess_data(series, seq_len, pred_len=24)

for X,y in train_loader:
    print(X.shape)
    print(y.shape)
    break

print("Batches", len(train_loader), len(val_loader), len(test_loader))

Train/Val/Test 10451 3484 3485


torch.Size([512, 10, 14])
torch.Size([512, 24])
Batches 21 7 7


In [45]:
# plots
def plot_losses(train_losses, val_losses):
    plt.plot(train_losses, color="black", label="Train")
    plt.plot(val_losses, color="green", label="Val")
    plt.legend()
    plt.show()

def plot_mae_mse(history, title, dec="{:.3g}"):
    xticks = range(len(history[:, 0]))
    plt.plot(xticks, history[:, 2], label='MAE')
    plt.plot(xticks, history[:, 3], label='MSE')

    plt.xlabel('Num Concepts')
    plt.ylabel('Criteria')
    # plt.ylim(0, 1)
    xtick_labels = list(map(int, history[:, 0]))
    plt.xticks(xticks, xtick_labels)
    plt.yscale('log')

    if dec:
        for x,_y in zip(xticks, history[:, 2]):
            label = dec.format(_y)
            plt.annotate(label, # this is the text
                        (x,_y), # these are the coordinates to position the label
                        textcoords="offset points", # how to position the text
                        xytext=(0,-10), # distance from text to points (x,y)
                        ha='center') # horizontal alignment can be left, right or center
            
        for x,_y in zip(xticks, history[:, 3]):
            label = dec.format(_y)
            plt.annotate(label, # this is the text
                        (x,_y), # these are the coordinates to position the label
                        textcoords="offset points", # how to position the text
                        xytext=(0,-10), # distance from text to points (x,y)
                        ha='center') # horizontal alignment can be left, right or center

    plt.title(title)
    plt.legend()
    plt.show()

def plot_prediction_vs_true(yb, preds, title):
    plt.plot(yb, color="black", label="True")
    plt.plot(preds, color="red", label="Pred")
    plt.title(title)
    plt.legend()
    plt.show()


In [46]:
def initializeModel(n_concepts, input_dim, changing_dim, seq_len, output_dim, top_k=''):
    model = models.CBM(input_dim = input_dim, 
                            changing_dim = changing_dim, 
                            seq_len = seq_len,
                            num_concepts = n_concepts,
                            opt_lr = 3e-3, # 2e-4
                            opt_weight_decay = 1e-05,
                            l1_lambda=0.001,
                            cos_sim_lambda=0.01,
                            output_dim = output_dim,
                            top_k=top_k,
                            task_type=models.TaskType.REGRESSION,
                            )
    model = model.to(device)
    return model

def initializeModel_with_atomics(n_atomics, n_concepts, input_dim, changing_dim, seq_len, output_dim, use_summaries_for_atomics, top_k=''):
    model = models_3d_atomics_on_variate_to_concepts.CBM(input_dim = input_dim, 
                            changing_dim = changing_dim, 
                            seq_len = seq_len,
                            num_concepts = n_concepts,
                            num_atomics = n_atomics,
                            use_summaries_for_atomics = use_summaries_for_atomics,
                            opt_lr = 3e-3, # 2e-4
                            opt_weight_decay = 1e-05,
                            l1_lambda=0.001,
                            cos_sim_lambda=0.01,
                            output_dim = output_dim,
                            top_k=top_k,
                            task_type=models_3d_atomics_on_variate_to_concepts.TaskType.REGRESSION,
                            )
    model = model.to(device)
    return model

In [47]:
random_seed = 1
set_seed(random_seed)


In [48]:
seq_len = 336
pred_len = 96
n_atomics_list = list(range(2,11,2))
n_concepts_list = list(range(2,11,2))
changing_dim = len(series.columns)
input_dim = 2 * changing_dim


## Original

In [None]:
experiment_folder = f"/workdir/optimal-summaries-public/vasopressor/models/etth1/multi2single-L{seq_len}-T{pred_len}/"
model_path_og = experiment_folder + "forecasting_c{}.pt"

if not os.path.exists(experiment_folder):
    os.makedirs(experiment_folder)

In [51]:
history_og = []

train_loader, val_loader, test_loader, scaler = preprocess_data(series, seq_len, pred_len=pred_len)

mae_metric = MeanAbsoluteError().to(device)
mse_metric = MeanSquaredError().to(device)

for n_concepts in n_concepts_list:
    print("n_concepts", n_concepts)
    
    model = initializeModel(n_concepts, input_dim, changing_dim, seq_len, output_dim=pred_len)
    model.fit(train_loader, val_loader, None, save_model_path=model_path_og.format(n_concepts), max_epochs=10000)
    
    display(model)
    
    model.eval()
    with torch.inference_mode():
        for batch_idx, (Xb, yb) in enumerate(test_loader):
            Xb, yb = Xb.to(device), yb.to(device)
            preds = model.forward(Xb)
            
            mae = mae_metric(preds, yb).item()
            mse = mse_metric(preds, yb).item()
        mae = mae_metric.compute().item()
        mse = mse_metric.compute().item()
        mae_metric.reset()
        mse_metric.reset()
    
    history = [n_concepts, round(model.val_losses[-1],2), mae, mse]
    display(history)
    history_og.append(np.array(history))
    
    plot_losses(model.train_losses, model.val_losses)
    
history_og = np.array(history_og)
history_og.shape


Train/Val/Test 10451 3484 3485
n_concepts 2
test 96




NameError: name 'model_path_og' is not defined

In [None]:
plot_mae_mse(history_og, "Original")


In [None]:
# Plot Prediction vs actual
train_loader, val_loader, test_loader, scaler = preprocess_data(series, seq_len, pred_len=pred_len)

mae_metric = MeanAbsoluteError().to(device)
mse_metric = MeanSquaredError().to(device)
n_concepts = 10

model = initializeModel(n_concepts, input_dim, changing_dim, seq_len, output_dim=pred_len)
model.fit(train_loader, val_loader, None, save_model_path=model_path_og.format(n_concepts), max_epochs=10000)

model.eval()
with torch.no_grad():
    for batch_idx, (Xb, yb) in enumerate(val_loader):
        Xb, yb = Xb.to(device), yb.to(device)
        preds = model.forward(Xb)
        
        mae = mae_metric(preds, yb).item()
        mse = mse_metric(preds, yb).item()
        break
    mae = mae_metric.compute().item()
    mse = mse_metric.compute().item()
    mae_metric.reset()
    mse_metric.reset()


i = 20
yb = yb.cpu().numpy()[i]
preds = preds.cpu().numpy()[i]

print(yb.shape)
print(preds.shape)

plot_prediction_vs_true(yb, preds, title=f"Original - Predictions with {n_concepts} Concepts")


## Redesigned

In [28]:
experiment_folder = f"/workdir/optimal-summaries-public/vasopressor/models/etth1/atomics-from-summaries-L{seq_len}-T{pred_len}/"
model_path_re = experiment_folder + "forecasting_c{}.pt"

if not os.path.exists(experiment_folder):
    os.makedirs(experiment_folder)

In [50]:
pred_len

96

In [54]:
history_re = []

train_loader, val_loader, test_loader, scaler = preprocess_data(series, seq_len, pred_len=pred_len)

mae_metric = MeanAbsoluteError().to(device)
mse_metric = MeanSquaredError().to(device)

for n_concepts in n_concepts_list:
    for n_atomics in n_atomics_list:
        print("n_atomics", n_atomics, "n_concepts", n_concepts)
        
        model = initializeModel_with_atomics(n_atomics, n_concepts, input_dim, changing_dim, seq_len, output_dim=pred_len, use_summaries_for_atomics=True)
        print("test", model.output_dim)
        model.fit(train_loader, val_loader, None, save_model_path=model_path_re.format(n_concepts), max_epochs=10000)
        
        print("Trained for ", model.curr_epoch+1)
        display(model)
        
        model.eval()
        with torch.no_grad():
            for batch_idx, (Xb, yb) in enumerate(test_loader):
                Xb, yb = Xb.to(device), yb.to(device)
                preds = model.forward(Xb)
                
                mae = mae_metric(preds, yb).item()
                mse = mse_metric(preds, yb).item()
            mae = mae_metric.compute().item()
            mse = mse_metric.compute().item()
            mae_metric.reset()
            mse_metric.reset()
        
        history = [n_atomics, n_concepts, round(model.val_losses[-1],2), mae, mse]
        display(history)
        history_re.append(np.array(history))
    
        plot_losses(model.train_losses, model.val_losses)
        torch.cuda.empty_cache()
    
history_re = np.array(history_re)
history_re.shape


Train/Val/Test 10451 3484 3485
n_atomics 2 n_concepts 2
test 96
Loaded model from /workdir/optimal-summaries-public/vasopressor/models/etth1/atomics-from-summaries-L336-T96/forecasting_c2.pt


  0%|          | 0/9980 [00:00<?, ?it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 1/9980 [00:00<1:04:09,  2.59it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 2/9980 [00:00<1:04:25,  2.58it/s]

mean_feats torch.Size([292, 7])
var_feats torch.Size([292, 7])
out torch.Size([292, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 3/9980 [00:01<1:03:34,  2.62it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 4/9980 [00:01<1:03:06,  2.63it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 5/9980 [00:01<1:03:52,  2.60it/s]Exception ignored in: <function _releaseLock at 0x7ff256da9750>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/logging/__init__.py", line 228, in _releaseLock
    def _releaseLock():
KeyboardInterrupt: 


mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 6/9980 [00:02<1:03:35,  2.61it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 7/9980 [00:02<1:02:47,  2.65it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 8/9980 [00:03<1:01:55,  2.68it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 9/9980 [00:03<1:01:14,  2.71it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 10/9980 [00:03<1:01:52,  2.69it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 11/9980 [00:04<1:02:37,  2.65it/s]

mean_feats torch.Size([292, 7])
var_feats torch.Size([292, 7])
out torch.Size([292, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 12/9980 [00:04<1:01:59,  2.68it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 13/9980 [00:04<1:01:56,  2.68it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 14/9980 [00:05<1:01:42,  2.69it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 15/9980 [00:05<1:02:39,  2.65it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 16/9980 [00:06<1:02:31,  2.66it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 17/9980 [00:06<1:02:46,  2.64it/s]

mean_feats torch.Size([292, 7])
var_feats torch.Size([292, 7])
out torch.Size([292, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 18/9980 [00:06<1:02:51,  2.64it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 19/9980 [00:07<1:03:31,  2.61it/s]

mean_feats torch.Size([292, 7])
var_feats torch.Size([292, 7])
out torch.Size([292, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 20/9980 [00:07<1:13:14,  2.27it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([493, 7])
var_feats torch.Size([493, 7])
out torch.Size([493, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 21/9980 [00:08<1:09:48,  2.38it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 22/9980 [00:08<1:07:12,  2.47it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 23/9980 [00:08<1:05:31,  2.53it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 24/9980 [00:09<1:04:37,  2.57it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 25/9980 [00:09<1:04:20,  2.58it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 26/9980 [00:09<1:03:04,  2.63it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 27/9980 [00:10<1:02:40,  2.65it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 28/9980 [00:10<1:02:39,  2.65it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 29/9980 [00:11<1:03:01,  2.63it/s]

mean_feats torch.Size([292, 7])
var_feats torch.Size([292, 7])
out torch.Size([292, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 30/9980 [00:11<1:03:09,  2.63it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 31/9980 [00:11<1:03:02,  2.63it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 32/9980 [00:12<1:03:41,  2.60it/s]

mean_feats torch.Size([292, 7])
var_feats torch.Size([292, 7])
out torch.Size([292, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 33/9980 [00:12<1:03:37,  2.61it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 34/9980 [00:13<1:03:32,  2.61it/s]

mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size([512, 7])
var_feats torch.Size([512, 7])
out torch.Size([512, 96])
mean_feats torch.Size

  0%|          | 35/9980 [00:13<1:03:04,  2.63it/s]

In [None]:
plot_mae_mse(history_re, "Redesigned")


In [None]:
# Plot Prediction vs actual
train_loader, val_loader, test_loader, scaler = preprocess_data(series, seq_len, pred_len=pred_len)

mae_metric = MeanAbsoluteError().to(device)
mse_metric = MeanSquaredError().to(device)
n_concepts = 600

model = initializeModel_redesigned(n_concepts, input_dim, changing_dim, seq_len, output_dim=pred_len)
model.fit(train_loader, val_loader, None, save_model_path=model_path_re.format(n_concepts), max_epochs=10000)

model.eval()
with torch.no_grad():
    for batch_idx, (Xb, yb) in enumerate(val_loader):
        Xb, yb = Xb.to(device), yb.to(device)
        preds = model.forward(Xb)
        
        mae = mae_metric(preds, yb).item()
        mse = mse_metric(preds, yb).item()
        break
    mae = mae_metric.compute().item()
    mse = mse_metric.compute().item()
    mae_metric.reset()
    mse_metric.reset()


i = 20
yb = yb.cpu().numpy()[i]
preds = preds.cpu().numpy()[i]

print(yb.shape)
print(preds.shape)

plot_prediction_vs_true(yb, preds, title=f"Redesigned - Predictions with {n_concepts} Concepts")


## Redesigned + LambdaLR as paper

In [None]:
experiment_folder = f"/workdir/optimal-summaries-public/vasopressor/models/etth1/time2con-lambdalr-multi2single-L{seq_len}-T{pred_len}/"
model_path_re_lamdalr = experiment_folder + "forecasting_c{}.pt"

if not os.path.exists(experiment_folder):
    os.makedirs(experiment_folder)

In [None]:
history_re_lambdalr = []

train_loader, val_loader, test_loader, scaler = preprocess_data(series, seq_len, pred_len=pred_len)

mae_metric = MeanAbsoluteError().to(device)
mse_metric = MeanSquaredError().to(device)

for n_concepts in n_concepts_list:
    print("n_concepts", n_concepts)
    
    model = initializeModel_redesigned(n_concepts, input_dim, changing_dim, seq_len, output_dim=pred_len)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=model.optimizer, patience=5)
    lr_lambda = lambda epoch: 0.5 ** ((epoch - 1) // 1)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=model.optimizer, lr_lambda=lr_lambda)
    model.fit(train_loader, val_loader, None, save_model_path=model_path_re_lamdalr.format(n_concepts), max_epochs=10000, scheduler=scheduler)
    
    display(model)
    
    model.eval()
    with torch.no_grad():
        for batch_idx, (Xb, yb) in enumerate(test_loader):
            Xb, yb = Xb.to(device), yb.to(device)
            preds = model.forward(Xb)
            
            mae = mae_metric(preds, yb).item()
            mse = mse_metric(preds, yb).item()
        mae = mae_metric.compute().item()
        mse = mse_metric.compute().item()
        mae_metric.reset()
        mse_metric.reset()
    
    history = [n_concepts, round(model.val_losses[-1],2), mae, mse]
    display(history)
    history_re_lambdalr.append(np.array(history))
    
    plot_losses(model.train_losses, model.val_losses)
    torch.cuda.empty_cache()
    
history_re_lambdalr = np.array(history_re_lambdalr)
history_re_lambdalr.shape


In [None]:
plot_mae_mse(history_re_lambdalr, "Redesigned + LambdaLR")


In [None]:
# Plot Prediction vs actual
train_loader, val_loader, test_loader, scaler = preprocess_data(series, seq_len, pred_len=pred_len)

mae_metric = MeanAbsoluteError().to(device)
mse_metric = MeanSquaredError().to(device)
n_concepts = 4

model = initializeModel_redesigned(n_concepts, input_dim, changing_dim, seq_len, output_dim=pred_len)
# model.fit(train_loader, val_loader, None, save_model_path=model_path_re_lamdalr.format(n_concepts), max_epochs=10000)

model.eval()
with torch.no_grad():
    for batch_idx, (Xb, yb) in enumerate(val_loader):
        Xb, yb = Xb.to(device), yb.to(device)
        preds = model.forward(Xb)
        
        mae = mae_metric(preds, yb).item()
        mse = mse_metric(preds, yb).item()
        break
    mae = mae_metric.compute().item()
    mse = mse_metric.compute().item()
    mae_metric.reset()
    mse_metric.reset()


i = 20
yb = yb.cpu().numpy()[i]
preds = preds.cpu().numpy()[i]

print(yb.shape)
print(preds.shape)

plot_prediction_vs_true(yb, preds, title=f"Redesigned + LambdaLR - Predictions with {n_concepts} Concepts")


## Redesigned + ReduceLROnPlateau

In [None]:
experiment_folder = f"/workdir/optimal-summaries-public/vasopressor/models/etth1/time2con-reduceonplateau-multi2single-L{seq_len}-T{pred_len}/"
model_path_re_reduceonplateau = experiment_folder + "forecasting_c{}.pt"

if not os.path.exists(experiment_folder):
    os.makedirs(experiment_folder)

In [None]:
history_re_reduceonplateau = []

train_loader, val_loader, test_loader, scaler = preprocess_data(series, seq_len, pred_len=pred_len)

mae_metric = MeanAbsoluteError().to(device)
mse_metric = MeanSquaredError().to(device)

for n_concepts in n_concepts_list:
    print("n_concepts", n_concepts)
    
    model = initializeModel_redesigned(n_concepts, input_dim, changing_dim, seq_len, output_dim=pred_len)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=model.optimizer, patience=5) # half patience of early stopping
    model.fit(train_loader, val_loader, None, save_model_path=model_path_re_reduceonplateau.format(n_concepts), max_epochs=10000, scheduler=scheduler)
    
    display(model)
    
    model.eval()
    with torch.no_grad():
        for batch_idx, (Xb, yb) in enumerate(test_loader):
            Xb, yb = Xb.to(device), yb.to(device)
            preds = model.forward(Xb)
            
            mae = mae_metric(preds, yb).item()
            mse = mse_metric(preds, yb).item()
        mae = mae_metric.compute().item()
        mse = mse_metric.compute().item()
        mae_metric.reset()
        mse_metric.reset()
    
    history = [n_concepts, round(model.val_losses[-1],2), mae, mse]
    display(history)
    history_re_reduceonplateau.append(np.array(history))
    
    plot_losses(model.train_losses, model.val_losses)
    torch.cuda.empty_cache()
    
history_re_reduceonplateau = np.array(history_re_reduceonplateau)
history_re_reduceonplateau.shape


In [None]:
plot_mae_mse(history_re_reduceonplateau, "Redesigned + ReduceLROnPlateau")


In [None]:
# Plot Prediction vs actual
train_loader, val_loader, test_loader, scaler = preprocess_data(series, seq_len, pred_len=pred_len)

mae_metric = MeanAbsoluteError().to(device)
mse_metric = MeanSquaredError().to(device)
n_concepts = 10

model = initializeModel_redesigned(n_concepts, input_dim, changing_dim, seq_len, output_dim=pred_len)
# model.fit(train_loader, val_loader, None, save_model_path=model_path_re_reduceonplateau.format(n_concepts), max_epochs=10000)

model.eval()
with torch.no_grad():
    for batch_idx, (Xb, yb) in enumerate(val_loader):
        Xb, yb = Xb.to(device), yb.to(device)
        preds = model.forward(Xb)
        
        mae = mae_metric(preds, yb).item()
        mse = mse_metric(preds, yb).item()
        break
    mae = mae_metric.compute().item()
    mse = mse_metric.compute().item()
    mae_metric.reset()
    mse_metric.reset()


i = 20
yb = yb.cpu().numpy()[i]
preds = preds.cpu().numpy()[i]

print(yb.shape)
print(preds.shape)

plot_prediction_vs_true(yb, preds, title=f"Redesigned + ReduceLROnPlateau - Predictions with {n_concepts} Concepts")


## Optimization

In [None]:
# feature weights
n_concepts = 5

model = initializeModel(n_concepts, input_dim, changing_dim, seq_len)
model.fit(train_loader, val_loader, None, model_path.format(n_concepts), 1000)

for name, param in model.named_parameters():
    if "bottleneck.weight" in name:
        bottleneck_weights = param
feature_weights = bottleneck_weights.cpu().detach().numpy()

feature_weights.shape

In [None]:
# visualize weight magnitudes
for c in range(n_concepts):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    inds = np.argsort(-np.abs(feature_weights[c]))[:100]
    ax.bar(np.arange(1,101),np.abs(feature_weights[c])[inds])
    ax.set_xlabel("Top 100 features")
    ax.set_ylabel("abs value of feature coefficient")
    plt.show()


In [None]:
# get 90th percentile of feature weights
sum90p = np.sum(np.abs(feature_weights), axis=-1)*0.90
sum90p.shape


In [None]:
# get top K indizes
top_k_inds = []
for c in range(n_concepts):
    topkinds_conc = []
    curr_sum = 0
    inds = np.argsort(-np.abs(feature_weights[c])) #desc
    sorted_weights = feature_weights[c][inds]
    
    for ind, weight in zip(inds, sorted_weights):
        curr_sum += abs(weight)
        if curr_sum <= sum90p[c]:
            topkinds_conc.append(ind)
        else:
            break
    
    # if selects less than 10, choose 10 best
    if len(topkinds_conc) < 10:
        topkinds_conc = np.argsort(-np.abs(feature_weights[c]))[:10].tolist()
    
    top_k_inds.append(topkinds_conc)

top_k_inds

In [None]:
# write top k inds to csv
filename = experiment_folder + "top-k/top_k_inds_c{}.csv".format(n_concepts)

directory = os.path.dirname(filename)
if not os.path.exists(directory):
    os.makedirs(directory)

# writing to csv file 
with open(filename, 'w') as csvfile: 
    # creating a csv writer object 
    csvwriter = csv.writer(csvfile)
    # writing the data rows 
    csvwriter.writerows(top_k_inds)


In [None]:
best_aucs, best_auc_inds, best_auc_concepts = greedy_selection(auroc_metric, test_loader, top_k_inds, model)


In [None]:
filename = experiment_folder + "top-k/bottleneck_r{}_c{}_topkinds.csv".format(random_seed, n_concepts)

# writing to csv file
with open(filename, 'w') as csvfile: 
    # creating a csv writer object 
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(["Best AUC", "Best AUC Concept #", "Best AUC ind #"])
    # writing the data rows 
    for row in zip(best_aucs, best_auc_concepts, best_auc_inds):
        csvwriter.writerow(list(row))


## Multi-class

In [None]:
experiment_folder = "/workdir/optimal-summaries-public/vasopressor/models/arabic/multiclass/"
model_path = experiment_folder + "arabic_c{}.pt"
random_seed = 1

if not os.path.exists(experiment_folder):
    os.makedirs(experiment_folder)

In [None]:
history_multiclass = []

set_seed(random_seed)

data, y_ohe, num_classes, weights = preprocess_data_multiclass(X, y)
train_loader, val_loader, test_loader = initialize_data(1, data, y_ohe, multiclass=True)

input_dim = data.shape[2]
changing_dim = X[0].shape[0]
seq_len = data.shape[1]

auroc_metric = AUROC(task="multiclass", num_classes=num_classes).to(device)
accuracy_metric = Accuracy(task="multiclass", num_classes=num_classes).to(device)

for n_concepts in range(1,16):
    print(n_concepts)
    
    model = initializeModel(n_concepts, input_dim, changing_dim, seq_len, num_classes)
    model.fit(train_loader, val_loader, weights, model_path.format(n_concepts), 1000)
    
    for batch_idx, (Xb, yb) in enumerate(test_loader):
        Xb, yb = Xb.to(device), yb.to(device)
        probs = model.forward_probabilities(Xb)
        
        auc = auroc_metric(probs, yb).item()
        acc = accuracy_metric(probs, yb).item()
    auc = auroc_metric.compute().item()
    acc = accuracy_metric.compute().item()
    auroc_metric.reset()
    accuracy_metric.reset()
    
    history = [n_concepts, model.val_losses[-1], auc, acc]
    history_multiclass.append(np.array(history))
history_multiclass = np.array(history_multiclass)
history_multiclass.shape


In [None]:
# plot
plt.plot(history_multiclass[:, 0], history_multiclass[:, 2], label='AUC')
plt.plot(history_multiclass[:, 0], history_multiclass[:, 3], label='ACC')

plt.xlabel('Num Concepts')
plt.ylabel('Criteria')
plt.title('Plot of Concepts vs Criteria')
plt.xticks(np.arange(min(history_multiclass[:, 0]), max(history_multiclass[:, 0])+1, 1))

for x,_y in zip(history_multiclass[:, 0], history_multiclass[:, 2]):
    label = "{:.2f}".format(_y)
    plt.annotate(label, # this is the text
                 (x,_y), # these are the coordinates to position the label
                 textcoords="offset points", # how to position the text
                 xytext=(0,10), # distance from text to points (x,y)
                 ha='center') # horizontal alignment can be left, right or center
    
for x,_y in zip(history_multiclass[:, 0], history_multiclass[:, 3]):
    label = "{:.2f}".format(_y)
    plt.annotate(label, # this is the text
                 (x,_y), # these are the coordinates to position the label
                 textcoords="offset points", # how to position the text
                 xytext=(0,-10), # distance from text to points (x,y)
                 ha='center') # horizontal alignment can be left, right or center

plt.legend()
plt.show()


In [None]:
# feature weights
n_concepts = 5

model = initializeModel(n_concepts, input_dim, changing_dim, seq_len, num_classes)
model.fit(train_loader, val_loader, weights, model_path.format(n_concepts), 1000)

for name, param in model.named_parameters():
    if "bottleneck.weight" in name:
        bottleneck_weights = param
feature_weights = bottleneck_weights.cpu().detach().numpy()

feature_weights.shape

In [None]:
# visualize weight magnitudes
for c in range(n_concepts):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    inds = np.argsort(-np.abs(feature_weights[c]))[:100]
    ax.bar(np.arange(1,101),np.abs(feature_weights[c])[inds])
    ax.set_xlabel("Top 100 features")
    ax.set_ylabel("abs value of feature coefficient")
    plt.show()


In [None]:
# get 90th percentile of feature weights
sum90p = np.sum(np.abs(feature_weights), axis=-1)*0.90
sum90p.shape


In [None]:
# get top K indizes
top_k_inds = []
for c in range(n_concepts):
    topkinds_conc = []
    curr_sum = 0
    inds = np.argsort(-np.abs(feature_weights[c])) #desc
    sorted_weights = feature_weights[c][inds]
    
    for ind, weight in zip(inds, sorted_weights):
        curr_sum += abs(weight)
        if curr_sum <= sum90p[c]:
            topkinds_conc.append(ind)
        else:
            break
    
    # if selects less than 10, choose 10 best
    if len(topkinds_conc) < 10:
        topkinds_conc = np.argsort(-np.abs(feature_weights[c]))[:10].tolist()
    
    top_k_inds.append(topkinds_conc)

top_k_inds

In [None]:
# write top k inds to csv
filename = experiment_folder + "top-k/top_k_inds_c{}.csv".format(n_concepts)

directory = os.path.dirname(filename)
if not os.path.exists(directory):
    os.makedirs(directory)

# writing to csv file 
with open(filename, 'w') as csvfile: 
    # creating a csv writer object 
    csvwriter = csv.writer(csvfile)
    # writing the data rows 
    csvwriter.writerows(top_k_inds)


In [None]:
data_cols = [i for i in range(1,14)] + [str(i) + "_ind" for i in range(1,14)]

for c, _list in enumerate(top_k_inds):
    for ind in _list:
        name, summary = getConcept(data_cols, input_dim, changing_dim, int(ind))
        print(f"Concept {c}: ID {ind}, Feature {name}, Summary {summary}")


In [None]:
greedy_results = greedy_selection(auroc_metric, test_loader, top_k_inds, model, track_metrics={"acc": accuracy_metric})
greedy_results.head()

In [None]:
top_k_csv_file = experiment_folder + "top-k/bottleneck_r{}_c{}_topkinds.csv".format(random_seed, n_concepts)

# writing to csv file
with open(top_k_csv_file, 'w') as csvfile: 
    # creating a csv writer object 
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(greedy_results.columns)
    # writing the data rows 
    for row in greedy_results.itertuples(index=False):
        csvwriter.writerow(list(row))


In [None]:
data_cols = [i for i in range(1,14)] + [str(i) + "_ind" for i in range(1,14)]

sorted_ = greedy_results.sort_values(["Concept", "ID"])

for row in sorted_.itertuples(index=False):
    name, summary = getConcept(data_cols, input_dim, changing_dim, row[1])
    print(f"Concept {row[2]}: ID {row[1]}, Feature {name}, Summary {summary}")

In [None]:
plt.plot(greedy_results["Score"])
plt.plot(greedy_results["acc"])
plt.show()

In [None]:
top_k_csv_file = "/workdir/optimal-summaries-public/vasopressor/models/arabic/multiclass/top-k/bottleneck_r1_c6_topkinds.csv"
n_concepts = 6
model = initializeModel(n_concepts, input_dim, changing_dim, seq_len, num_classes, top_k=top_k_csv_file)
# model.fit(train_loader, val_loader, weights, model_path.format(n_concepts), 1000)

model.eval()
with torch.no_grad():
    for batch_idx, (Xb, yb) in enumerate(test_loader):
        Xb, yb = Xb.to(device), yb.to(device)
        probs = model.forward_probabilities(Xb)
        
        auc = auroc_metric(probs, yb).item()
        acc = accuracy_metric(probs, yb).item()
    auc = auroc_metric.compute().item()
    acc = accuracy_metric.compute().item()
    auroc_metric.reset()
    accuracy_metric.reset()

print(auc)
print(acc)


In [None]:
model.fit(train_loader, val_loader, weights, save_model_path="/workdir/optimal-summaries-public/vasopressor/models/arabic/multiclass/top-k/arabic_c6_finetuned.pt", epochs=3000)

model.eval()
with torch.no_grad():
    for batch_idx, (Xb, yb) in enumerate(test_loader):
        Xb, yb = Xb.to(device), yb.to(device)
        probs = model.forward_probabilities(Xb)
        
        auc = auroc_metric(probs, yb)
        acc = accuracy_metric(probs, yb)
    auc = auroc_metric.compute().item()
    acc = accuracy_metric.compute().item()
    auroc_metric.reset()
    accuracy_metric.reset()
    
print(auc)
print(acc)
