# BB Retrain Strategy (Refactored)

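This notebook demonstrates how to use the refactored code modules (`data_utils`, `training`, `evaluate`, `monitor`) for data preparation, training, evaluation, and resource monitoring. The evaluation and threading cells assume that a model named `model_with_exits` has already been built or loaded; this notebook does not create it.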

In [None]:
# Imports: standard library, third-party, then the refactored project modules
import threading

import torch
from torch.utils.data import DataLoader

from data_utils import split_cifar10_dataset, filter_by_samples, filter_by_classes, create_subset_dataloaders
from evaluate import evaluate_and_log, summarize_exits, plot_exit_distribution, per_class_accuracy, exit_class_heatmap, print_exit_byclass_table, extract_exit_dataset
from training import trainEE_KL, inference_loop, retrain_loop
from monitor import print_resource_usage

In [None]:
# Data loading: split_cifar10_dataset yields two training datasets and the test set.
train_dataset_1, train_dataset_2, test_dataset = split_cifar10_dataset()
# Report both training-split sizes, one per line.
print(
    f'Dataset 1 size: {len(train_dataset_1)}',
    f'Dataset 2 size: {len(train_dataset_2)}',
    sep='\n',
)

In [None]:
# Example: build a class-shifted test subset.
# Each CIFAR-10 class name maps to the ratio handed to filter_by_samples;
# classes sharing a ratio are grouped so the three tiers are easy to read.
classes_ratios = {
    **dict.fromkeys(('airplane', 'automobile', 'bird', 'cat'), 0.8),
    **dict.fromkeys(('deer', 'dog', 'frog'), 0.6),
    **dict.fromkeys(('horse', 'ship', 'truck'), 0.1),
}
# Fixed seed so the sampled subset is the same on every run.
subset = filter_by_samples(test_dataset, classes_ratios, seed=42)
print(f'Total samples in subset: {len(subset)}')

In [None]:
# Create the DataLoader for the class-shifted subset.
# Shuffling stays on (this loader is also the one handed to inference_loop in the
# threading cell), but it is driven by a seeded generator: an unseeded shuffle
# changes the batch order on every Restart & Run All, so results are not reproducible.
LOADER_BATCH_SIZE = 64
LOADER_SEED = 42
new_distribution_loader = DataLoader(
    subset,
    batch_size=LOADER_BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

In [None]:
# Example: evaluate the model with exits on the class-shifted subset.
# `model_with_exits` is not created anywhere in this notebook, so it must be built
# or loaded before this cell runs. Fail with an actionable message rather than a
# bare NameError when the notebook is run top to bottom on a fresh kernel.
if 'model_with_exits' not in globals():
    raise NameError(
        "model_with_exits is not defined; build or load the model with exits "
        "before running the evaluation cells."
    )

# Passed as `thresholds` to evaluate_and_log; presumably one confidence threshold
# per early exit, in exit order (TODO confirm against evaluate.evaluate_and_log).
EXIT_THRESHOLDS = (0.88, 0.54)

logs = evaluate_and_log(model_with_exits, new_distribution_loader, thresholds=EXIT_THRESHOLDS)
summarize_exits(logs)
plot_exit_distribution(logs)
# TODO: call per_class_accuracy(logs, class_names) once `class_names` is defined.

In [None]:
# Example: resource monitoring.
# Prints a resource-usage report via monitor.print_resource_usage; which metrics
# it covers is defined in monitor.py, not in this notebook.
print_resource_usage()

## Training and Inference Threads
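The `inference_loop` and `retrain_loop` functions from the `training` module can be run in separate threads for concurrent execution. Both are passed the same `model_lock`, which presumably serializes their access to `model_with_exits`.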

In [None]:
# Optionally run inference and retraining concurrently on the same model.
# Both loops receive `model_lock`; presumably they hold it while reading or
# updating `model_with_exits` (TODO confirm in training.py).
RUN_CONCURRENT_LOOPS = False  # set True to start the inference and retrain threads

model_lock = threading.Lock()
# Keep the Thread handles (instead of fire-and-forget .start()) so the workers
# can be inspected with .is_alive() or waited on with .join() from later cells.
worker_threads = []

if RUN_CONCURRENT_LOOPS:
    worker_threads = [
        threading.Thread(
            target=inference_loop,
            args=(model_with_exits, new_distribution_loader, model_lock),
            name='inference_loop',
        ),
        threading.Thread(
            target=retrain_loop,
            args=(model_with_exits, model_lock),
            name='retrain_loop',
        ),
    ]
    for worker in worker_threads:
        worker.start()