In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')

import argparse
from argparse import Namespace

import models.models_original as models_original
import models.models_3d_atomics as models_3d_atomics
import models.models_3d as models_3d
from models.data import *
from models.helper import *
from models.param_initializations import *
from models.optimization_strategy import *


In [2]:
# Experiment configuration. Every later cell reads its settings from `args`.
args = Namespace(
    random_states=[1],
    dataset='spoken_arabic_digits',
    model='atomics',
    pruning='sparse_learning',  # alternative value noted by the author: 'weight_gradient_magnitude'
    device='cuda',
    save_load_path='/workdir/optimal-summaries-public/_models/',
    n_concepts=4,
    n_atomics=10,
    switch_encode_dim=True,  # default True
    switch_summaries_layer=False,  # default True
    switch_indicators=True,  # default True
    switch_use_only_last_timestep=False,  # default False
    # switch_use_summaries=True,  # default True
)

# Echo the configuration, in definition order, so the rendered notebook records it.
print("All arguments:")
for name, value in vars(args).items():
    print(f"{name}: {value}")


All arguments:
random_states: [1]
dataset: spoken_arabic_digits
model: atomics
pruning: sparse_learning
device: cuda
save_load_path: /workdir/optimal-summaries-public/_models/
n_concepts: 4
n_atomics: 10
switch_encode_dim: True
switch_summaries_layer: False
switch_indicators: True
switch_use_only_last_timestep: False


In [3]:
def get_dataloader(random_state, dataset=None):
    """Reset the seed and build the data loaders for one dataset.

    Args:
        random_state: Seed passed to ``set_seed`` and to the dataset-specific loader.
        dataset: Dataset name. Defaults to ``args.dataset``. Supported values are
            ``"mimic"``, ``"tiselac"`` and ``"spoken_arabic_digits"``.

    Returns:
        Whatever the dataset-specific loader returns. Callers in this notebook
        unpack it as ``(train_loader, val_loader, test_loader, class_weights,
        num_classes, changing_dim, static_dim, seq_len)``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names. A notebook
            exception is raised instead of calling ``sys.exit``, so the cause
            stays visible in the traceback.
    """
    set_seed(random_state)

    if dataset is None:
        dataset = args.dataset

    if dataset == "mimic":
        return get_MIMIC_dataloader(random_state=random_state)
    if dataset == "tiselac":
        return get_tiselac_dataloader(random_state=random_state)
    if dataset == "spoken_arabic_digits":
        return get_arabic_spoken_digits_dataloader(random_state=random_state)
    raise ValueError(f"No known dataset selected: {dataset!r}")


def get_model(random_state):
    """Construct an untrained CBM variant selected by ``args.model``.

    The dataset is loaded only to read the dimensions that parameterize the model:
    number of classes, static/time-varying feature sizes and sequence length.

    Args:
        random_state: Seed passed to ``set_seed`` before the data is loaded and
            the model is constructed.

    Returns:
        A ``CBM`` instance from ``models_original`` (``"original"``),
        ``models_3d`` (``"shared"``) or ``models_3d_atomics`` (``"atomics"``).

    Raises:
        ValueError: If ``args.model`` is not one of the names above.
    """
    set_seed(random_state)

    # Only the shape information is used here; the loaders and class weights are discarded.
    _, _, _, _, num_classes, changing_dim, static_dim, seq_len = get_dataloader(random_state)

    if args.model == "original":
        model = models_original.CBM(
            n_concepts=args.n_concepts,
            use_indicators=args.switch_indicators,
            use_only_last_timestep=args.switch_use_only_last_timestep,
            static_dim=static_dim,
            changing_dim=changing_dim,
            seq_len=seq_len,
            output_dim=num_classes,
            device=args.device,
        )
    elif args.model == "shared":
        # NOTE(review): switch_encode_dim is forwarded only to this variant.
        model = models_3d.CBM(
            n_concepts=args.n_concepts,
            encode_time_dim=args.switch_encode_dim,
            use_indicators=args.switch_indicators,
            static_dim=static_dim,
            changing_dim=changing_dim,
            seq_len=seq_len,
            output_dim=num_classes,
            device=args.device,
        )
    elif args.model == "atomics":
        model = models_3d_atomics.CBM(
            n_concepts=args.n_concepts,
            n_atomics=args.n_atomics,
            use_summaries_for_atomics=args.switch_summaries_layer,
            use_indicators=args.switch_indicators,
            static_dim=static_dim,
            changing_dim=changing_dim,
            seq_len=seq_len,
            output_dim=num_classes,
            device=args.device,
        )
    else:
        # Raise instead of sys.exit(): in a notebook SystemExit hides the cause.
        raise ValueError(f"No known model selected: {args.model!r}")
    return model


def get_trained_model(random_state):
    """Build the configured model, load its checkpoint or fit it, then evaluate on validation.

    Args:
        random_state: Seed used for data loading, model construction and the
            checkpoint path.

    Returns:
        The (loaded or freshly fitted) model.
    """
    set_seed(random_state)

    (train_loader, val_loader, _test_loader, class_weights,
     num_classes, _changing_dim, _static_dim, _seq_len) = get_dataloader(random_state)

    # NOTE(review): get_model() loads the dataset a second time internally.
    model = get_model(random_state)
    checkpoint_path = model.get_model_path(
        base_path=args.save_load_path,
        dataset=args.dataset,
        pruning=args.pruning,
        seed=random_state,
    )

    # Full (non-sparse) training run; the large max_epochs relies on patience-based stopping.
    fit_options = dict(
        p_weight=class_weights,
        save_model_path=checkpoint_path,
        max_epochs=10000,
        save_every_n_epochs=10,
        patience=10,
        sparse_fit=False,
    )
    model.try_load_else_fit(train_loader, val_loader, **fit_options)

    evaluate_classification(model=model, dataloader=val_loader, num_classes=num_classes)
    return model


def get_metrics(num_classes: int, average: str = "macro", device=None):
    """Create accuracy, F1 and AUROC metric objects for a classification task.

    ``AUROC``, ``Accuracy`` and ``F1Score`` come in through a star import
    (presumably torchmetrics — confirm).

    Args:
        num_classes: Number of target classes. ``2`` selects the binary task;
            anything else selects the multiclass task.
        average: Reduction across classes for the multiclass metrics. The binary
            metrics ignore it. Defaults to ``"macro"``, as before.
        device: Device the metrics are moved to. Defaults to ``args.device``.

    Returns:
        Dict with keys ``"acc"``, ``"f1"`` and ``"auc"``.
    """
    if device is None:
        device = args.device

    if num_classes == 2:
        auroc_metric = AUROC(task="binary").to(device)
        accuracy_metric = Accuracy(task="binary").to(device)
        f1_metric = F1Score(task="binary").to(device)
    else:
        auroc_metric = AUROC(task="multiclass", num_classes=num_classes, average=average).to(device)
        accuracy_metric = Accuracy(task="multiclass", num_classes=num_classes, top_k=1, average=average).to(device)
        f1_metric = F1Score(task="multiclass", num_classes=num_classes, top_k=1, average=average).to(device)

    return {"acc": accuracy_metric, "f1": f1_metric, "auc": auroc_metric}


In [4]:
# Create the checkpoint root directory that model.get_model_path() builds paths under.
# NOTE(review): makedir comes from a star import (models.helper, not shown); presumably a
# no-op when the directory already exists — confirm.
makedir(args.save_load_path)


### Train and evaluate

In [5]:
# Results table: the evaluation cell below appends one row per evaluated split.
result_df = pd.DataFrame(columns=["Model", "Dataset", "Seed", "Split", "Pruning", "Finetuned", "AUC", "ACC", "F1", "Total parameter", "Remaining parameter"])

# Single-seed run. Take the seed from the config instead of a hard-coded literal, so this
# cell stays in sync with args.random_states. For a multi-seed sweep, wrap the rest of this
# cell (and the evaluation cells) in `for random_state in args.random_states:`.
random_state = args.random_states[0]

model = get_model(random_state)
# NOTE(review): get_model() already loads the dataset internally; this second call reloads
# it to get the loaders, class weights and num_classes used below and in later cells.
train_loader, val_loader, test_loader, class_weights, num_classes, changing_dim, static_dim, seq_len = get_dataloader(random_state)
model_path = model.get_model_path(base_path=args.save_load_path, dataset=args.dataset, pruning=args.pruning, seed=random_state)
# Short sparse fine-tuning run (sparse_fit=True). When a checkpoint already exists at
# model_path it is loaded instead (see the "Loaded model from ..." output).
model.try_load_else_fit(train_loader, val_loader, p_weight=class_weights, save_model_path=model_path, max_epochs=10, save_every_n_epochs=10, patience=10, sparse_fit=True)




Loaded model from /workdir/optimal-summaries-public/_models/spoken_arabic_digits/atomics/sparse_learning/atomics_num_concepts_4_num_atomics_10_use_summaries_for_atomics_False_use_indicators_True_use_summaries_True_seed_1.pt


In [12]:
# Number of surviving (non-zero) weights in the time -> atomics layer after sparse training.
model.layer_time_to_atomics.weight.count_nonzero()


tensor(62, device='cuda:0')

In [13]:
# Number of surviving (non-zero) weights in the layer feeding the concepts.
model.layer_to_concepts.weight.count_nonzero()


tensor(79, device='cuda:0')

In [None]:
# Parameter counts over the model's regularized layers; both values are stored in result_df
# by the evaluation cell. NOTE(review): presumably (total parameters, non-zero parameters)
# — confirm against get_total_and_remaining_parameters in models.helper.
total, remaining = get_total_and_remaining_parameters(model.regularized_layers)


In [None]:

# Evaluate the fine-tuned model on the validation and test splits, adding one row per
# split to result_df. Re-running this cell appends duplicate rows.
for split_name, split_loader in (("val", val_loader), ("test", test_loader)):
    # Pass num_classes explicitly, matching the call in get_trained_model().
    metrics = evaluate_classification(model=model, dataloader=split_loader, num_classes=num_classes)
    # NOTE(review): assumes metrics is ordered (AUC, ACC, F1) — confirm in models.helper.
    result_df.loc[len(result_df)] = {
        "Model": model.get_short_model_name(),
        "Dataset": args.dataset,
        "Seed": random_state,
        "Split": split_name,
        "Pruning": args.pruning,  # same value used to build model_path
        "Finetuned": True,
        "AUC": metrics[0],
        "ACC": metrics[1],
        "F1": metrics[2],
        "Total parameter": total,
        "Remaining parameter": remaining,
    }

result_df
    
