## Provide list of paths for edits and run trials for all 10 classes

In [1]:
# General imports
import torch
import numpy as np
import os, sys
import json
from tqdm import tqdm
from datetime import datetime
import pandas as pd

In [2]:
# Local imports
sys.path.insert(0, 'src')
from utils import read_json, read_lists, informal_log, list_to_dict, write_lists, write_json, ensure_files
from utils.model_utils import prepare_device
from parse_config import ConfigParser
# from data_loader import data_loaders
import datasets.datasets as module_data
import model.model as module_arch
from utils.knn_utils import load_and_analyze_knn
from utils.results_to_csv import store_csv
from edit_knn import main as edit

In [3]:
# Reproducibility: use one fixed seed for both torch and numpy, request
# deterministic cuDNN behaviour and turn off cuDNN benchmarking.
SEED = 123
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(SEED)

In [4]:
# Experiment-level settings.
# sort_type selects which key/value path lists are read
# (key_images_<sort_type>.txt / value_images_<sort_type>.txt).
sort_type = 'softmax'
# Forwarded to edit() as do_analyze_knn.
analyze_in_edit = True

# Input files: class-name list and the edit-trial config.
class_list_path = 'metadata/cinic-10/class_names.txt'
config_path = 'configs/copies/cinic10_imagenet_segmentation_edit_trials.json'

In [5]:
# Number used in the '<class>_<n_select>' directory names for paths and results
n_select = 100

# Parse the experiment config from disk
config_dict = read_json(config_path)

# Read the class names and build the lookup returned by list_to_dict
# (presumably class name -> index; verify against utils.list_to_dict)
class_list = read_lists(class_list_path)
class_idx_dict = list_to_dict(class_list)

# K as configured for the editor
editor_config = config_dict['editor']
K = editor_config['K']

# Select the device(s) according to the configured GPU count
device, device_ids = prepare_device(config_dict['n_gpu'])

In [6]:
# Shared keyword arguments, copied so the config dict itself is not mutated
data_loader_args = dict(config_dict["data_loader"]["args"])
dataset_args = dict(config_dict["dataset_args"])


def _make_cinic10_loader(image_paths, labels, **extra_dataset_kwargs):
    """Wrap a CINIC10Dataset over the given paths/labels in a DataLoader.

    `extra_dataset_kwargs` are forwarded to the dataset ahead of the shared
    `dataset_args`; the loader itself is configured by `data_loader_args`.
    """
    dataset = module_data.CINIC10Dataset(
        data_dir="",
        image_paths=image_paths,
        labels=labels,
        **extra_dataset_kwargs,
        **dataset_args
    )
    return torch.utils.data.DataLoader(dataset, **data_loader_args)


# Validation loader: the dataset is asked to return image paths as well
val_image_paths = read_lists(config_dict['dataset_paths']['valid_images'])
val_labels = read_lists(config_dict['dataset_paths']['valid_labels'])
val_paths_data_loader = _make_cinic10_loader(
    val_image_paths,
    val_labels,
    return_paths=True
)

# Loader over the images used for the covariance matrix
covariance_image_paths = read_lists(config_dict['covariance_dataset']['images'])
covariance_labels = read_lists(config_dict['covariance_dataset']['labels'])
covariance_data_loader = _make_cinic10_loader(
    covariance_image_paths,
    covariance_labels
)

In [7]:
# Timestamp of the directory holding the precomputed key/value path lists
paths_timestamp = '0126_161209'
# Timestamp (MMDD_HHMMSS) naming this run's output directories
timestamp = '{:%m%d_%H%M%S}'.format(datetime.now())
# To write into an earlier run's directories instead, pin the value:
# timestamp = '0120_155829'

In [None]:
# Run every (key image, value image) edit trial for each class in class_list.
# For each class this:
#   1. sets up the per-class output directory and its two log files,
#   2. reads and validates the precomputed key/value image path lists,
#   3. calls edit() once per key/value pair, logging progress along the way.
for target_class_idx, target_class_name in enumerate(class_list):
    # To resume an interrupted run, skip already-finished classes here, e.g.:
    # if target_class_idx == 0:
    #     continue

    # Create save directories and logging paths
    class_tag = '{}_{}'.format(target_class_name, n_select)
    save_root = config_dict['trainer']['save_dir']
    class_save_dir = os.path.join(save_root, config_dict['name'], class_tag, timestamp)
    save_trials_path = os.path.join(class_save_dir, 'trial_paths.txt')
    progress_report_path = os.path.join(class_save_dir, 'progress_report.txt')

    # Clear a stale progress report BEFORE writing anything to it. Previously
    # the "Current target class" header was logged first and the report was
    # removed right afterwards, so the header never survived in the file.
    # NOTE(review): assumes informal_log appends to the given file -- confirm.
    trials_file_exists = os.path.exists(save_trials_path)
    if not trials_file_exists and os.path.exists(progress_report_path):
        os.remove(progress_report_path)

    informal_log("Current target class: {}".format(target_class_name), progress_report_path)

    if trials_file_exists:
        # NOTE(review): nothing in this cell truncates the existing file;
        # whether it is really overwritten depends on informal_log.
        print("Path {} already exists. Overwriting.".format(save_trials_path))
    else:
        print("Printing progress reports to {}".format(progress_report_path))
        informal_log("Saving path to directories for each trial to {}".format(save_trials_path), progress_report_path)

    # Obtain paths for keys and values
    paths_dir = os.path.join('paths', 'edits', 'semantics', class_tag, paths_timestamp)
    key_image_paths_path = os.path.join(paths_dir, 'key_images_{}.txt'.format(sort_type))
    key_image_paths = read_lists(key_image_paths_path)

    value_image_paths_path = os.path.join(paths_dir, 'value_images_{}.txt'.format(sort_type))
    value_image_paths = read_lists(value_image_paths_path)
    n_trials = len(value_image_paths)

    # Keys and values are paired positionally below, so the lists must have
    # the same length. Raised explicitly (rather than asserted) so the check
    # also runs under `python -O`.
    if len(key_image_paths) != n_trials:
        raise ValueError(
            "Number of key image paths ({}) does not match number of value image paths ({})".format(
                len(key_image_paths), n_trials))

    # ensure_files returns the subset of paths that do not exist (as used by
    # the checks below); fail fast before starting any trial.
    non_existent_key_paths = ensure_files(key_image_paths)
    non_existent_value_paths = ensure_files(value_image_paths)

    if len(non_existent_key_paths) > 0:
        raise ValueError("Following paths are non existent: {}".format(non_existent_key_paths))

    if len(non_existent_value_paths) > 0:
        raise ValueError("Following paths are non existent: {}".format(non_existent_value_paths))

    informal_log("Key image paths stored at {}".format(key_image_paths_path), progress_report_path)
    informal_log("Value image paths stored at {}".format(value_image_paths_path), progress_report_path)

    # Run edit for each key and value pair
    for idx, (key_path, value_path) in enumerate(zip(key_image_paths, value_image_paths)):
        # Key image id is "<class>-<split>-<file stem>", taken from the last
        # three components of the key path (.../<split>/<class>/<file>).
        key_parent_dir = os.path.dirname(key_path)
        split = os.path.basename(os.path.dirname(key_parent_dir))
        class_name = os.path.basename(key_parent_dir)
        # Everything before the FIRST dot of the file name.
        file_name = os.path.basename(key_path).split(".")[0]
        key_image_id = "{}-{}-{}".format(class_name, split, file_name)

        # Print Progress
        informal_log("({}) Starting Trial {}/{}...".format(datetime.now().strftime(r'%m%d_%H%M%S'), idx + 1, n_trials), progress_report_path)

        # Create run id: <class>_<n>/<timestamp>/results/<key id>/<value id>
        value_image_id = os.path.splitext(os.path.basename(value_path))[0]
        run_id = os.path.join(class_tag, timestamp, 'results', key_image_id, value_image_id)
        informal_log("Current run_id: {}".format(run_id), progress_report_path)

        # Re-read the config from disk so every trial starts from the file's
        # contents, then point the editor at this trial's key/value images.
        # This intentionally rebinds the module-level `config_dict`.
        config_dict = read_json(config_path)
        config_dict['editor'].update({
            'key_paths_file': key_path,
            'value_paths_file': value_path
        })

        # Create config object
        config = ConfigParser(config_dict, run_id=run_id)

        # Log the current trial path
        informal_log(os.path.dirname(config.save_dir), save_trials_path)

        informal_log("Calling edit()...", progress_report_path)

        edit(
            config=config,
            val_paths_data_loader=val_paths_data_loader,
            covariance_data_loader=covariance_data_loader,
            do_analyze_knn=analyze_in_edit)

        # Print progress
        informal_log("Finished trial {}/{}. Results saved to {}".format(idx + 1, n_trials, os.path.dirname(config.save_dir)),
                    progress_report_path)



Current target class: airplane
Printing progress reports to saved/edit/trials/CINIC10_ImageNet-VGG_16/airplane_100/0214_112633/progress_report.txt
Saving path to directories for each trial to saved/edit/trials/CINIC10_ImageNet-VGG_16/airplane_100/0214_112633/trial_paths.txt
Key image paths stored at paths/edits/semantics/airplane_100/0126_161209/key_images_softmax.txt
Value image paths stored at paths/edits/semantics/airplane_100/0126_161209/value_images_softmax.txt
(0214_112633) Starting Trial 1/158...
Current run_id: airplane_100/0214_112633/results/airplane-train-n03365231_4635/felzenszwalb_masked_softmax
saved/edit/trials/CINIC10_ImageNet-VGG_16/airplane_100/0214_112633/results/airplane-train-n03365231_4635/felzenszwalb_masked_softmax
Calling edit()...
Created ModelWrapperSanturkar model with 33646666 trainable parameters
Restored weights from external_code/PyTorch_CIFAR10/cifar10_models/state_dicts/vgg16_bn.pt
Using passed in data loader for validation.
Key images: data/cinic-10-i

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 27.00it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0058, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:06<00:00, 602.54it/s]

Loss (orig, final): 0.07915869355201721 0.0057900333777070045
L2 norm of weight change: 0.36614692211151123
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.19it/s]


Post-edit metrics: {'TP': array([6587, 5243, 5310, 3921, 3953, 2634, 5607, 5282, 3323, 4428]), 'TN': array([56758, 61325, 58795, 59692, 60703, 62217, 61131, 61292, 62708,
       61667]), 'FPs': array([6242, 1675, 4205, 3308, 2297,  783, 1869, 1708,  292, 1333]), 'FNs': array([ 413, 1757, 1690, 3079, 3047, 4366, 1393, 1718, 3677, 2572]), 'accuracy': 0.6612571428571429, 'per_class_accuracy': array([0.90492857, 0.95097143, 0.91578571, 0.90875714, 0.92365714,
       0.92644286, 0.9534    , 0.95105714, 0.9433    , 0.94421429]), 'per_class_accuracy_mean': 0.9322514285714287, 'precision': array([0.5134461 , 0.757878  , 0.55806621, 0.54239867, 0.63248   ,
       0.77085162, 0.75      , 0.75565093, 0.91922545, 0.76861656]), 'precision_mean': 0.6968613544637976, 'recall': array([0.941     , 0.749     , 0.75857143, 0.56014286, 0.56471429,
       0.37628571, 0.801     , 0.75457143, 0.47471429, 0.63257143]), 'recall_mean': 0.6612571428571428, 'predicted_class_distribution': array([12829,  6918,  95

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.62it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0128, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:07<00:00, 596.46it/s]

Loss (orig, final): 0.19816546142101288 0.0128183513879776
L2 norm of weight change: 0.9437816143035889
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.08it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6966, 4547, 3295,  436,    0,   74, 4580, 3833,    0, 3099]), 'TN': array([22542, 61951, 62553, 62839, 63000, 63000, 62568, 62734, 63000,
       62643]), 'FPs': array([40458,  1049,   447,   161,     0,     0,   432,   266,     0,
         357]), 'FNs': array([  34, 2453, 3705, 6564, 7000, 6926, 2420, 3167, 7000, 3901]), 'accuracy': 0.3832857142857143, 'per_class_accuracy': array([0.42154286, 0.94997143, 0.94068571, 0.90392857, 0.9       ,
       0.90105714, 0.95925714, 0.95095714, 0.9       , 0.93917143]), 'per_class_accuracy_mean': 0.876657142857143, 'precision': array([0.14688765, 0.81254467, 0.88054516, 0.73031826, 0.        ,
       1.        , 0.91380686, 0.93510612, 0.        , 0.89670139]), 'precision_mean': 0.6315910123421599, 'recall': array([0.99514286, 0.64957143, 0.47071429, 0.06228571, 0.        ,
       0.01057143, 0.65428571, 0.54757143, 0.        , 0.44271429]), 'recall_mean': 0.3832857142857143, 'predicted_class_distribution': array([4

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.44it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0169, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:55<00:00, 715.70it/s]

Loss (orig, final): 0.13328108191490173 0.016880877315998077
L2 norm of weight change: 0.5766276717185974
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 43.60it/s]


Post-edit metrics: {'TP': array([6659, 5553, 5097, 2802,    3, 2463, 5369, 4894, 3513, 4212]), 'TN': array([45738, 60134, 59500, 61713, 63000, 62371, 61681, 61975, 62642,
       61811]), 'FPs': array([17262,  2866,  3500,  1287,     0,   629,  1319,  1025,   358,
        1189]), 'FNs': array([ 341, 1447, 1903, 4198, 6997, 4537, 1631, 2106, 3487, 2788]), 'accuracy': 0.5795, 'per_class_accuracy': array([0.74852857, 0.93838571, 0.92281429, 0.92164286, 0.90004286,
       0.9262    , 0.95785714, 0.95527143, 0.94507143, 0.94318571]), 'per_class_accuracy_mean': 0.9159, 'precision': array([0.27837465, 0.65957952, 0.59288124, 0.68525312, 1.        ,
       0.7965718 , 0.8027811 , 0.82682886, 0.90751744, 0.77985558]), 'precision_mean': 0.7329643302707979, 'recall': array([9.51285714e-01, 7.93285714e-01, 7.28142857e-01, 4.00285714e-01,
       4.28571429e-04, 3.51857143e-01, 7.67000000e-01, 6.99142857e-01,
       5.01857143e-01, 6.01714286e-01]), 'recall_mean': 0.5795000000000001, 'predicted_class

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 39.90it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0046, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:44<00:00, 908.58it/s]

Loss (orig, final): 0.0834755077958107 0.004609189927577972
L2 norm of weight change: 0.42828142642974854
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:22<00:00, 12.17it/s]


Post-edit metrics: {'TP': array([6764, 5094, 4925, 3784, 4120, 2398, 5537, 5460, 1255, 4696]), 'TN': array([53398, 61493, 60078, 60177, 60304, 62437, 61393, 60525, 62980,
       61248]), 'FPs': array([9602, 1507, 2922, 2823, 2696,  563, 1607, 2475,   20, 1752]), 'FNs': array([ 236, 1906, 2075, 3216, 2880, 4602, 1463, 1540, 5745, 2304]), 'accuracy': 0.6290428571428571, 'per_class_accuracy': array([0.85945714, 0.95124286, 0.92861429, 0.91372857, 0.92034286,
       0.92621429, 0.95614286, 0.94264286, 0.91764286, 0.94205714]), 'per_class_accuracy_mean': 0.9258085714285714, 'precision': array([0.41329586, 0.77170126, 0.62762839, 0.5727259 , 0.60446009,
       0.80986153, 0.77505599, 0.68809074, 0.98431373, 0.72828784]), 'precision_mean': 0.6975421326568167, 'recall': array([0.96628571, 0.72771429, 0.70357143, 0.54057143, 0.58857143,
       0.34257143, 0.791     , 0.78      , 0.17928571, 0.67085714]), 'recall_mean': 0.6290428571428571, 'predicted_class_distribution': array([16366,  6601,  78

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 35.23it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0253, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:44<00:00, 897.22it/s]

Loss (orig, final): 0.19681312143802643 0.025264278054237366
L2 norm of weight change: 0.669150710105896
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.84it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6571, 5315,    0, 3844, 4092, 2649, 5272, 4865, 4425, 4597]), 'TN': array([47717, 61176, 63000, 60008, 60424, 62207, 61810, 62003, 62008,
       61277]), 'FPs': array([15283,  1824,     0,  2992,  2576,   793,  1190,   997,   992,
        1723]), 'FNs': array([ 429, 1685, 7000, 3156, 2908, 4351, 1728, 2135, 2575, 2403]), 'accuracy': 0.5947142857142858, 'per_class_accuracy': array([0.77554286, 0.94987143, 0.9       , 0.91217143, 0.92165714,
       0.92651429, 0.95831429, 0.95525714, 0.94904286, 0.94105714]), 'per_class_accuracy_mean': 0.9189428571428572, 'precision': array([0.30067722, 0.74450203, 0.        , 0.56231714, 0.61367726,
       0.76961069, 0.81584649, 0.82992153, 0.81687281, 0.72737342]), 'precision_mean': 0.618079859438555, 'recall': array([0.93871429, 0.75928571, 0.        , 0.54914286, 0.58457143,
       0.37842857, 0.75314286, 0.695     , 0.63214286, 0.65671429]), 'recall_mean': 0.5947142857142858, 'predicted_class_distribution': array([2

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 41.59it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0237, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:48<00:00, 818.14it/s]

Loss (orig, final): 0.17158308625221252 0.02373489737510681
L2 norm of weight change: 0.5284871459007263
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.90it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6536, 5320,    0, 3893, 4156, 2695, 5435, 4914, 4501, 4615]), 'TN': array([48845, 61173, 63000, 59904, 60287, 62148, 61595, 61932, 61942,
       61239]), 'FPs': array([14155,  1827,     0,  3096,  2713,   852,  1405,  1068,  1058,
        1761]), 'FNs': array([ 464, 1680, 7000, 3107, 2844, 4305, 1565, 2086, 2499, 2385]), 'accuracy': 0.6009285714285715, 'per_class_accuracy': array([0.79115714, 0.9499    , 0.9       , 0.91138571, 0.92061429,
       0.92632857, 0.95757143, 0.95494286, 0.94918571, 0.94077143]), 'per_class_accuracy_mean': 0.9201857142857144, 'precision': array([0.31588613, 0.74436827, 0.        , 0.55701817, 0.60503712,
       0.75979701, 0.79459064, 0.82146439, 0.809678  , 0.72380803]), 'precision_mean': 0.6131647772959874, 'recall': array([0.93371429, 0.76      , 0.        , 0.55614286, 0.59371429,
       0.385     , 0.77642857, 0.702     , 0.643     , 0.65928571]), 'recall_mean': 0.6009285714285715, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 43.16it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0144, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:44<00:00, 906.63it/s]

Loss (orig, final): 0.14408427476882935 0.014407632872462273
L2 norm of weight change: 0.4438484013080597
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:08<00:00, 30.50it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6447, 5450,    0, 3961, 4225, 2892, 5690, 5087, 4651, 4674]), 'TN': array([51750, 60820, 63000, 59780, 60164, 61900, 61103, 61723, 61804,
       61033]), 'FPs': array([11250,  2180,     0,  3220,  2836,  1100,  1897,  1277,  1196,
        1967]), 'FNs': array([ 553, 1550, 7000, 3039, 2775, 4108, 1310, 1913, 2349, 2326]), 'accuracy': 0.6153857142857143, 'per_class_accuracy': array([0.83138571, 0.94671429, 0.9       , 0.91058571, 0.91984286,
       0.9256    , 0.95418571, 0.95442857, 0.94935714, 0.93867143]), 'per_class_accuracy_mean': 0.9230771428571428, 'precision': array([0.36429903, 0.71428571, 0.        , 0.55159449, 0.59835717,
       0.7244489 , 0.74996705, 0.79934004, 0.79545066, 0.70380967]), 'precision_mean': 0.6001552716756159, 'recall': array([0.921     , 0.77857143, 0.        , 0.56585714, 0.60357143,
       0.41314286, 0.81285714, 0.72671429, 0.66442857, 0.66771429]), 'recall_mean': 0.6153857142857143, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 41.97it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0058, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:50<00:00, 788.52it/s]

Loss (orig, final): 0.17762506008148193 0.005809440743178129
L2 norm of weight change: 0.7439542412757874
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.33it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6873, 4563, 4962, 3886, 3671, 2176, 5450, 5293,    0, 3581]), 'TN': array([47327, 61934, 59756, 59893, 61088, 62599, 61548, 60897, 63000,
       62413]), 'FPs': array([15673,  1066,  3244,  3107,  1912,   401,  1452,  2103,     0,
         587]), 'FNs': array([ 127, 2437, 2038, 3114, 3329, 4824, 1550, 1707, 7000, 3419]), 'accuracy': 0.5779285714285715, 'per_class_accuracy': array([0.77428571, 0.94995714, 0.92454286, 0.91112857, 0.92512857,
       0.92535714, 0.95711429, 0.94557143, 0.9       , 0.94277143]), 'per_class_accuracy_mean': 0.9155857142857144, 'precision': array([0.30484343, 0.81062356, 0.6046795 , 0.55569856, 0.65753179,
       0.8443927 , 0.7896262 , 0.71565711, 0.        , 0.85916507]), 'precision_mean': 0.6142217918365541, 'recall': array([0.98185714, 0.65185714, 0.70885714, 0.55514286, 0.52442857,
       0.31085714, 0.77857143, 0.75614286, 0.        , 0.51157143]), 'recall_mean': 0.5779285714285715, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.80it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0058, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:00<00:00, 659.73it/s]

Loss (orig, final): 0.16259099543094635 0.00579608790576458
L2 norm of weight change: 0.6714223623275757
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.22it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6620, 5095,    0, 3765, 4129, 2400, 5197, 4634, 4390, 4660]), 'TN': array([46159, 61476, 63000, 60140, 60279, 62453, 61924, 62226, 62023,
       61210]), 'FPs': array([16841,  1524,     0,  2860,  2721,   547,  1076,   774,   977,
        1790]), 'FNs': array([ 380, 1905, 7000, 3235, 2871, 4600, 1803, 2366, 2610, 2340]), 'accuracy': 0.5841428571428572, 'per_class_accuracy': array([0.75398571, 0.95101429, 0.9       , 0.91292857, 0.92011429,
       0.92647143, 0.95887143, 0.95514286, 0.94875714, 0.941     ]), 'per_class_accuracy_mean': 0.9168285714285715, 'precision': array([0.28217041, 0.76975374, 0.        , 0.56830189, 0.60277372,
       0.81438751, 0.82847123, 0.8568787 , 0.81796162, 0.72248062]), 'precision_mean': 0.6263179433408438, 'recall': array([0.94571429, 0.72785714, 0.        , 0.53785714, 0.58985714,
       0.34285714, 0.74242857, 0.662     , 0.62714286, 0.66571429]), 'recall_mean': 0.5841428571428571, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:21<00:00, 12.68it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0042, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:53<00:00, 747.92it/s]

Loss (orig, final): 0.09873473644256592 0.004218454472720623
L2 norm of weight change: 0.3920460641384125
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.86it/s]


Post-edit metrics: {'TP': array([6550, 5161, 3284, 3960, 4316, 2521, 5615, 4943, 4432, 4725]), 'TN': array([53510, 61437, 62369, 59709, 59938, 62331, 61224, 61848, 62019,
       61122]), 'FPs': array([9490, 1563,  631, 3291, 3062,  669, 1776, 1152,  981, 1878]), 'FNs': array([ 450, 1839, 3716, 3040, 2684, 4479, 1385, 2057, 2568, 2275]), 'accuracy': 0.6501, 'per_class_accuracy': array([0.858     , 0.9514    , 0.9379    , 0.90955714, 0.91791429,
       0.92645714, 0.95484286, 0.95415714, 0.9493    , 0.94067143]), 'per_class_accuracy_mean': 0.9300200000000002, 'precision': array([0.40835411, 0.76754908, 0.83882503, 0.54613157, 0.58498238,
       0.79028213, 0.75970775, 0.81099262, 0.81876963, 0.71558383]), 'precision_mean': 0.7041178128118145, 'recall': array([0.93571429, 0.73728571, 0.46914286, 0.56571429, 0.61657143,
       0.36014286, 0.80214286, 0.70614286, 0.63314286, 0.675     ]), 'recall_mean': 0.6501, 'predicted_class_distribution': array([16040,  6724,  3915,  7251,  7378,  3190,

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.18it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0045, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:03<00:00, 628.57it/s]


Loss (orig, final): 0.1518983393907547 0.004508628509938717
L2 norm of weight change: 0.6822460293769836
Performing post-edit metric & KNN calculations on validation set.


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.46it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6735, 4859, 5295, 3501, 3989, 2808, 5282, 4535, 3231,    0]), 'TN': array([46825, 61730, 58526, 60790, 60523, 62033, 61807, 62261, 62740,
       63000]), 'FPs': array([16175,  1270,  4474,  2210,  2477,   967,  1193,   739,   260,
           0]), 'FNs': array([ 265, 2141, 1705, 3499, 3011, 4192, 1718, 2465, 3769, 7000]), 'accuracy': 0.5747857142857142, 'per_class_accuracy': array([0.76514286, 0.95127143, 0.91172857, 0.91844286, 0.9216    ,
       0.9263    , 0.95841429, 0.95422857, 0.94244286, 0.9       ]), 'per_class_accuracy_mean': 0.9149571428571429, 'precision': array([0.29397643, 0.79278838, 0.54202068, 0.61302749, 0.61691927,
       0.74384106, 0.8157529 , 0.85987865, 0.92552277, 0.        ]), 'precision_mean': 0.6203727629273414, 'recall': array([0.96214286, 0.69414286, 0.75642857, 0.50014286, 0.56985714,
       0.40114286, 0.75457143, 0.64785714, 0.46157143, 0.        ]), 'recall_mean': 0.5747857142857142, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.70it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0104, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 624.10it/s]

Loss (orig, final): 0.10011674463748932 0.010351084172725677
L2 norm of weight change: 0.3529185354709625
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.96it/s]


Post-edit metrics: {'TP': array([6452, 5668, 5306, 3541, 4066, 3027, 5592, 4983, 4335, 3748]), 'TN': array([57232, 60503, 58806, 60575, 60490, 61739, 61158, 61876, 62061,
       62278]), 'FPs': array([5768, 2497, 4194, 2425, 2510, 1261, 1842, 1124,  939,  722]), 'FNs': array([ 548, 1332, 1694, 3459, 2934, 3973, 1408, 2017, 2665, 3252]), 'accuracy': 0.6674, 'per_class_accuracy': array([0.90977143, 0.9453    , 0.91588571, 0.91594286, 0.92222857,
       0.92522857, 0.95357143, 0.95512857, 0.94851429, 0.94322857]), 'per_class_accuracy_mean': 0.93348, 'precision': array([0.52798691, 0.69418249, 0.55852632, 0.59353   , 0.618309  ,
       0.70592351, 0.75221953, 0.81594891, 0.82195677, 0.83847875]), 'precision_mean': 0.6927062181195082, 'recall': array([0.92171429, 0.80971429, 0.758     , 0.50585714, 0.58085714,
       0.43242857, 0.79885714, 0.71185714, 0.61928571, 0.53542857]), 'recall_mean': 0.6674, 'predicted_class_distribution': array([12220,  8165,  9500,  5966,  6576,  4288,  7434,  61

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.67it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0073, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:01<00:00, 647.66it/s]

Loss (orig, final): 0.09693054854869843 0.007289453875273466
L2 norm of weight change: 0.4085290729999542
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 23.53it/s]


Post-edit metrics: {'TP': array([6372, 4957, 4981, 4160, 4253, 2807, 5503, 4253, 5023, 4259]), 'TN': array([57120, 61619, 59966, 59084, 60089, 61961, 61429, 62566, 60865,
       61869]), 'FPs': array([5880, 1381, 3034, 3916, 2911, 1039, 1571,  434, 2135, 1131]), 'FNs': array([ 628, 2043, 2019, 2840, 2747, 4193, 1497, 2747, 1977, 2741]), 'accuracy': 0.6652571428571429, 'per_class_accuracy': array([0.90702857, 0.95108571, 0.92781429, 0.90348571, 0.91917143,
       0.92525714, 0.95617143, 0.95455714, 0.94125714, 0.94468571]), 'per_class_accuracy_mean': 0.9330514285714286, 'precision': array([0.52007835, 0.78210792, 0.62145976, 0.51510649, 0.59366276,
       0.72984919, 0.77791914, 0.90740346, 0.70173233, 0.79016698]), 'precision_mean': 0.6939486378772713, 'recall': array([0.91028571, 0.70814286, 0.71157143, 0.59428571, 0.60757143,
       0.401     , 0.78614286, 0.60757143, 0.71757143, 0.60842857]), 'recall_mean': 0.6652571428571428, 'predicted_class_distribution': array([12252,  6338,  80

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.13it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0052, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:57<00:00, 693.81it/s]

Loss (orig, final): 0.1756054013967514 0.005170649383217096
L2 norm of weight change: 0.758717954158783
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:13<00:00, 21.01it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6883, 5034, 4775, 3667,    0, 2428, 5174, 4362,    0, 3676]), 'TN': array([38816, 61492, 60211, 60371, 63000, 62430, 61973, 62357, 63000,
       62349]), 'FPs': array([24184,  1508,  2789,  2629,     0,   570,  1027,   643,     0,
         651]), 'FNs': array([ 117, 1966, 2225, 3333, 7000, 4572, 1826, 2638, 7000, 3324]), 'accuracy': 0.5142714285714286, 'per_class_accuracy': array([0.65284286, 0.95037143, 0.92837143, 0.91482857, 0.9       ,
       0.92654286, 0.95924286, 0.95312857, 0.9       , 0.94321429]), 'per_class_accuracy_mean': 0.9028542857142856, 'precision': array([0.22155342, 0.76948945, 0.63127975, 0.58243329, 0.        ,
       0.80987325, 0.83438155, 0.87152847, 0.        , 0.84954934]), 'precision_mean': 0.5570088519788772, 'recall': array([0.98328571, 0.71914286, 0.68214286, 0.52385714, 0.        ,
       0.34685714, 0.73914286, 0.62314286, 0.        , 0.52514286]), 'recall_mean': 0.5142714285714286, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:12<00:00, 22.71it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0063, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:03<00:00, 631.69it/s]

Loss (orig, final): 0.12652577459812164 0.006264050491154194
L2 norm of weight change: 0.5582863092422485
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.61it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6815, 5032, 5065, 3961, 3853, 2398, 5525, 5359,    0, 4359]), 'TN': array([51077, 61542, 59633, 59717, 60870, 62433, 61378, 60925, 63000,
       61792]), 'FPs': array([11923,  1458,  3367,  3283,  2130,   567,  1622,  2075,     0,
        1208]), 'FNs': array([ 185, 1968, 1935, 3039, 3147, 4602, 1475, 1641, 7000, 2641]), 'accuracy': 0.6052428571428572, 'per_class_accuracy': array([0.82702857, 0.95105714, 0.92425714, 0.90968571, 0.92461429,
       0.92615714, 0.95575714, 0.94691429, 0.9       , 0.94501429]), 'per_class_accuracy_mean': 0.9210485714285715, 'precision': array([0.36369943, 0.77534669, 0.60068786, 0.54679735, 0.64399131,
       0.80876897, 0.77305163, 0.72087705, 0.        , 0.78300701]), 'precision_mean': 0.6016227293882677, 'recall': array([0.97357143, 0.71885714, 0.72357143, 0.56585714, 0.55042857,
       0.34257143, 0.78928571, 0.76557143, 0.        , 0.62271429]), 'recall_mean': 0.6052428571428571, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:19<00:00, 14.24it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0086, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:51<00:00, 776.14it/s]

Loss (orig, final): 0.07867942750453949 0.008569806814193726
L2 norm of weight change: 0.2587571144104004
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:25<00:00, 10.96it/s]


Post-edit metrics: {'TP': array([6422, 5332, 5165, 3952, 4178, 2775, 5690, 5299, 3953, 4653]), 'TN': array([58806, 61180, 59474, 59593, 60299, 62031, 60946, 61334, 62458,
       61298]), 'FPs': array([4194, 1820, 3526, 3407, 2701,  969, 2054, 1666,  542, 1702]), 'FNs': array([ 578, 1668, 1835, 3048, 2822, 4225, 1310, 1701, 3047, 2347]), 'accuracy': 0.6774142857142857, 'per_class_accuracy': array([0.93182857, 0.95017143, 0.92341429, 0.90778571, 0.9211    ,
       0.9258    , 0.95194286, 0.9519    , 0.94872857, 0.94215714]), 'per_class_accuracy_mean': 0.9354828571428572, 'precision': array([0.60493595, 0.74552573, 0.59429295, 0.53702949, 0.60735572,
       0.7411859 , 0.7347624 , 0.76080402, 0.87942158, 0.73217939]), 'precision_mean': 0.6937493107621864, 'recall': array([0.91742857, 0.76171429, 0.73785714, 0.56457143, 0.59685714,
       0.39642857, 0.81285714, 0.757     , 0.56471429, 0.66471429]), 'recall_mean': 0.6774142857142857, 'predicted_class_distribution': array([10616,  7152,  86

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:17<00:00, 15.48it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0034, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:02<00:00, 644.34it/s]

Loss (orig, final): 0.08048973977565765 0.0033711367286741734
L2 norm of weight change: 0.2762582004070282
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.28it/s]


Post-edit metrics: {'TP': array([6398, 5334, 4372, 4062, 4290, 2833, 5646, 5100, 4694, 4686]), 'TN': array([57835, 61140, 61198, 59487, 60014, 61975, 61172, 61706, 61771,
       61117]), 'FPs': array([5165, 1860, 1802, 3513, 2986, 1025, 1828, 1294, 1229, 1883]), 'FNs': array([ 602, 1666, 2628, 2938, 2710, 4167, 1354, 1900, 2306, 2314]), 'accuracy': 0.6773571428571429, 'per_class_accuracy': array([0.91761429, 0.94962857, 0.93671429, 0.90784286, 0.91862857,
       0.92582857, 0.95454286, 0.95437143, 0.9495    , 0.94004286]), 'per_class_accuracy_mean': 0.9354714285714287, 'precision': array([0.55331661, 0.74145121, 0.70813087, 0.53623762, 0.58960968,
       0.7343183 , 0.75541879, 0.79762277, 0.7925038 , 0.71335059]), 'precision_mean': 0.6921960234425077, 'recall': array([0.914     , 0.762     , 0.62457143, 0.58028571, 0.61285714,
       0.40471429, 0.80657143, 0.72857143, 0.67057143, 0.66942857]), 'recall_mean': 0.677357142857143, 'predicted_class_distribution': array([11563,  7194,  617

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.49it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0042, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:03<00:00, 634.41it/s]

Loss (orig, final): 0.1415470838546753 0.004236942157149315
L2 norm of weight change: 0.5658667087554932
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:13<00:00, 20.88it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6549, 5241,    0, 4075, 4217, 2662, 5387, 4941, 4474, 4695]), 'TN': array([49405, 61295, 63000, 59541, 60187, 62175, 61693, 61886, 61981,
       61078]), 'FPs': array([13595,  1705,     0,  3459,  2813,   825,  1307,  1114,  1019,
        1922]), 'FNs': array([ 451, 1759, 7000, 2925, 2783, 4338, 1613, 2059, 2526, 2305]), 'accuracy': 0.6034428571428572, 'per_class_accuracy': array([0.79934286, 0.95051429, 0.9       , 0.9088    , 0.92005714,
       0.92624286, 0.95828571, 0.95467143, 0.94935714, 0.93961429]), 'per_class_accuracy_mean': 0.9206885714285715, 'precision': array([0.32510921, 0.75453498, 0.        , 0.54088134, 0.59985775,
       0.76340694, 0.80475052, 0.81601982, 0.81449117, 0.70953604]), 'precision_mean': 0.6128587783605545, 'recall': array([0.93557143, 0.74871429, 0.        , 0.58214286, 0.60242857,
       0.38028571, 0.76957143, 0.70585714, 0.63914286, 0.67071429]), 'recall_mean': 0.6034428571428571, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.36it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0080, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:00<00:00, 666.19it/s]

Loss (orig, final): 0.12984928488731384 0.008042232133448124
L2 norm of weight change: 0.5835991501808167
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:24<00:00, 11.22it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6560,    0, 4856, 4881, 4152, 1941, 5259, 4651, 4024, 3836]), 'TN': array([49189, 63000, 60272, 55876, 60313, 62694, 61922, 62273, 62387,
       62234]), 'FPs': array([13811,     0,  2728,  7124,  2687,   306,  1078,   727,   613,
         766]), 'FNs': array([ 440, 7000, 2144, 2119, 2848, 5059, 1741, 2349, 2976, 3164]), 'accuracy': 0.5737142857142857, 'per_class_accuracy': array([0.79641429, 0.9       , 0.9304    , 0.86795714, 0.92092857,
       0.92335714, 0.95972857, 0.95605714, 0.94872857, 0.94385714]), 'per_class_accuracy_mean': 0.9147428571428572, 'precision': array([0.32202641, 0.        , 0.64029536, 0.40658059, 0.6071063 ,
       0.86381842, 0.82988796, 0.86481964, 0.86780246, 0.83355063]), 'precision_mean': 0.6235887770621321, 'recall': array([0.93714286, 0.        , 0.69371429, 0.69728571, 0.59314286,
       0.27728571, 0.75128571, 0.66442857, 0.57485714, 0.548     ]), 'recall_mean': 0.5737142857142857, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:21<00:00, 12.98it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0543, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:52<00:00, 767.24it/s]

Loss (orig, final): 0.22232955694198608 0.05428614839911461
L2 norm of weight change: 0.6694274544715881
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 23.44it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6765, 5255, 5118, 3896, 3893, 2539, 5673, 5327,    0, 4342]), 'TN': array([52130, 61337, 59563, 59748, 60821, 62318, 60944, 61151, 63000,
       61796]), 'FPs': array([10870,  1663,  3437,  3252,  2179,   682,  2056,  1849,     0,
        1204]), 'FNs': array([ 235, 1745, 1882, 3104, 3107, 4461, 1327, 1673, 7000, 2658]), 'accuracy': 0.6115428571428572, 'per_class_accuracy': array([0.84135714, 0.95131429, 0.92401429, 0.9092    , 0.92448571,
       0.92652857, 0.95167143, 0.94968571, 0.9       , 0.94482857]), 'per_class_accuracy_mean': 0.9223085714285715, 'precision': array([0.38361213, 0.7596126 , 0.59824664, 0.54504757, 0.64113966,
       0.78826451, 0.73398887, 0.74233556, 0.        , 0.7829066 ]), 'precision_mean': 0.5975154151886893, 'recall': array([0.96642857, 0.75071429, 0.73114286, 0.55657143, 0.55614286,
       0.36271429, 0.81042857, 0.761     , 0.        , 0.62028571]), 'recall_mean': 0.6115428571428572, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:19<00:00, 14.23it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0522, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:55<00:00, 714.70it/s]

Loss (orig, final): 0.21074600517749786 0.05221380665898323
L2 norm of weight change: 0.6158848404884338
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:23<00:00, 11.80it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6754, 5262, 5137, 3893, 3962, 2556, 5678, 5315,    0, 4385]), 'TN': array([52424, 61327, 59545, 59759, 60716, 62302, 60942, 61174, 63000,
       61753]), 'FPs': array([10576,  1673,  3455,  3241,  2284,   698,  2058,  1826,     0,
        1247]), 'FNs': array([ 246, 1738, 1863, 3107, 3038, 4444, 1322, 1685, 7000, 2615]), 'accuracy': 0.6134571428571428, 'per_class_accuracy': array([0.8454    , 0.95127143, 0.92402857, 0.90931429, 0.92397143,
       0.92654286, 0.95171429, 0.94984286, 0.9       , 0.94482857]), 'per_class_accuracy_mean': 0.9226914285714285, 'precision': array([0.38972879, 0.75875991, 0.59788175, 0.54569666, 0.63432597,
       0.78549478, 0.73397104, 0.74429352, 0.        , 0.77858665]), 'precision_mean': 0.5968739074599503, 'recall': array([0.96485714, 0.75171429, 0.73385714, 0.55614286, 0.566     ,
       0.36514286, 0.81114286, 0.75928571, 0.        , 0.62642857]), 'recall_mean': 0.6134571428571429, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:20<00:00, 13.19it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0520, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:51<00:00, 780.74it/s]

Loss (orig, final): 0.2147991955280304 0.051950253546237946
L2 norm of weight change: 0.6605945229530334
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:18<00:00, 15.12it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6755, 5244, 5124, 3977, 3888, 2551, 5661, 5337,    0, 4387]), 'TN': array([52398, 61357, 59585, 59569, 60827, 62302, 60998, 61134, 63000,
       61754]), 'FPs': array([10602,  1643,  3415,  3431,  2173,   698,  2002,  1866,     0,
        1246]), 'FNs': array([ 245, 1756, 1876, 3023, 3112, 4449, 1339, 1663, 7000, 2613]), 'accuracy': 0.6132, 'per_class_accuracy': array([0.84504286, 0.95144286, 0.92441429, 0.9078    , 0.9245    ,
       0.92647143, 0.95227143, 0.94958571, 0.9       , 0.94487143]), 'per_class_accuracy_mean': 0.92264, 'precision': array([0.38918016, 0.76143459, 0.60007027, 0.53685205, 0.6414783 ,
       0.78516467, 0.73874462, 0.74094127, 0.        , 0.77880348]), 'precision_mean': 0.5972669403356059, 'recall': array([0.965     , 0.74914286, 0.732     , 0.56814286, 0.55542857,
       0.36442857, 0.80871429, 0.76242857, 0.        , 0.62671429]), 'recall_mean': 0.6132, 'predicted_class_distribution': array([17357,  6887,  8539,  7408,  6061, 

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:16<00:00, 16.53it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0070, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:57<00:00, 697.48it/s]

Loss (orig, final): 0.17569628357887268 0.006957086734473705
L2 norm of weight change: 0.6688560247421265
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.31it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6830, 5274, 5038, 3811, 3621, 2477, 5605, 5336,    0, 3987]), 'TN': array([50163, 61256, 59632, 60030, 61226, 62380, 61173, 60984, 63000,
       62135]), 'FPs': array([12837,  1744,  3368,  2970,  1774,   620,  1827,  2016,     0,
         865]), 'FNs': array([ 170, 1726, 1962, 3189, 3379, 4523, 1395, 1664, 7000, 3013]), 'accuracy': 0.5997, 'per_class_accuracy': array([0.81418571, 0.95042857, 0.92385714, 0.91201429, 0.92638571,
       0.92652857, 0.95397143, 0.94742857, 0.9       , 0.9446    ]), 'per_class_accuracy_mean': 0.9199399999999999, 'precision': array([0.34728225, 0.75149615, 0.59933381, 0.5620115 , 0.67117702,
       0.79980626, 0.75417115, 0.7257889 , 0.        , 0.821723  ]), 'precision_mean': 0.6032790047577332, 'recall': array([0.97571429, 0.75342857, 0.71971429, 0.54442857, 0.51728571,
       0.35385714, 0.80071429, 0.76228571, 0.        , 0.56957143]), 'recall_mean': 0.5997, 'predicted_class_distribution': array([19667,  7018,  8406,  67

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.13it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0125, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:03<00:00, 628.06it/s]

Loss (orig, final): 0.18212363123893738 0.012492194771766663
L2 norm of weight change: 0.6878743171691895
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.27it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6628, 5343, 5314, 3377, 3912, 2976, 5423, 4495, 3839,    0]), 'TN': array([48697, 61144, 58554, 60998, 60698, 61830, 61560, 62327, 62499,
       63000]), 'FPs': array([14303,  1856,  4446,  2002,  2302,  1170,  1440,   673,   501,
           0]), 'FNs': array([ 372, 1657, 1686, 3623, 3088, 4024, 1577, 2505, 3161, 7000]), 'accuracy': 0.5901, 'per_class_accuracy': array([0.79035714, 0.94981429, 0.9124    , 0.91964286, 0.923     ,
       0.9258    , 0.9569    , 0.9546    , 0.94768571, 0.9       ]), 'per_class_accuracy_mean': 0.9180200000000001, 'precision': array([0.3166595 , 0.74218641, 0.54446721, 0.62781186, 0.62954619,
       0.71780029, 0.79017922, 0.86977554, 0.88456221, 0.        ]), 'precision_mean': 0.6122988440257038, 'recall': array([0.94685714, 0.76328571, 0.75914286, 0.48242857, 0.55885714,
       0.42514286, 0.77471429, 0.64214286, 0.54842857, 0.        ]), 'recall_mean': 0.5901, 'predicted_class_distribution': array([20931,  7199,  9760,  53

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.27it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0032, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 616.09it/s]

Loss (orig, final): 0.0917544886469841 0.0031842796597629786
L2 norm of weight change: 0.3404562473297119
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.26it/s]


Post-edit metrics: {'TP': array([6476, 5425, 5303, 3760, 4073, 2887, 5521, 4969, 4330, 3779]), 'TN': array([56476, 61003, 58836, 60193, 60466, 61937, 61338, 61885, 62111,
       62278]), 'FPs': array([6524, 1997, 4164, 2807, 2534, 1063, 1662, 1115,  889,  722]), 'FNs': array([ 524, 1575, 1697, 3240, 2927, 4113, 1479, 2031, 2670, 3221]), 'accuracy': 0.6646142857142857, 'per_class_accuracy': array([0.89931429, 0.94897143, 0.91627143, 0.91361429, 0.92198571,
       0.92605714, 0.95512857, 0.95505714, 0.94915714, 0.94367143]), 'per_class_accuracy_mean': 0.9329228571428573, 'precision': array([0.49815385, 0.73093506, 0.56015633, 0.57255977, 0.61646738,
       0.73088608, 0.76862035, 0.81673241, 0.82966085, 0.8395912 ]), 'precision_mean': 0.6963763287204721, 'recall': array([0.92514286, 0.775     , 0.75757143, 0.53714286, 0.58185714,
       0.41242857, 0.78871429, 0.70985714, 0.61857143, 0.53985714]), 'recall_mean': 0.6646142857142856, 'predicted_class_distribution': array([13000,  7422,  94

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.36it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0062, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 617.32it/s]

Loss (orig, final): 0.12159065902233124 0.006190362852066755
L2 norm of weight change: 0.5683299899101257
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.29it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6583, 5240,    0, 4049, 3974, 2759, 4605, 5069, 4471, 4573]), 'TN': array([47221, 61287, 63000, 59560, 60657, 62049, 62488, 61721, 61961,
       61379]), 'FPs': array([15779,  1713,     0,  3440,  2343,   951,   512,  1279,  1039,
        1621]), 'FNs': array([ 417, 1760, 7000, 2951, 3026, 4241, 2395, 1931, 2529, 2427]), 'accuracy': 0.5903285714285714, 'per_class_accuracy': array([0.76862857, 0.95038571, 0.9       , 0.9087    , 0.9233    ,
       0.92582857, 0.95847143, 0.95414286, 0.94902857, 0.94217143]), 'per_class_accuracy_mean': 0.9180657142857143, 'precision': array([0.29438333, 0.75363153, 0.        , 0.54065963, 0.62909609,
       0.74366577, 0.89994137, 0.79851922, 0.81143376, 0.73829512]), 'precision_mean': 0.6209625818730505, 'recall': array([0.94042857, 0.74857143, 0.        , 0.57842857, 0.56771429,
       0.39414286, 0.65785714, 0.72414286, 0.63871429, 0.65328571]), 'recall_mean': 0.5903285714285713, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 27.34it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0214, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:03<00:00, 631.55it/s]

Loss (orig, final): 0.19405096769332886 0.021448753774166107
L2 norm of weight change: 0.6955103278160095
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.85it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6810, 5053, 5039, 4309, 3475, 2400, 5505, 5280,    0, 4285]), 'TN': array([50612, 61528, 59781, 58792, 61450, 62439, 61426, 61249, 63000,
       61879]), 'FPs': array([12388,  1472,  3219,  4208,  1550,   561,  1574,  1751,     0,
        1121]), 'FNs': array([ 190, 1947, 1961, 2691, 3525, 4600, 1495, 1720, 7000, 2715]), 'accuracy': 0.6022285714285714, 'per_class_accuracy': array([0.82031429, 0.95115714, 0.926     , 0.90144286, 0.9275    ,
       0.92627143, 0.95615714, 0.95041429, 0.9       , 0.9452    ]), 'per_class_accuracy_mean': 0.9204457142857143, 'precision': array([0.35472445, 0.77440613, 0.61019617, 0.50592932, 0.69154229,
       0.81053698, 0.77765221, 0.75096003, 0.        , 0.79263781]), 'precision_mean': 0.6068585396020969, 'recall': array([0.97285714, 0.72185714, 0.71985714, 0.61557143, 0.49642857,
       0.34285714, 0.78642857, 0.75428571, 0.        , 0.61214286]), 'recall_mean': 0.6022285714285716, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:13<00:00, 20.67it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0033, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:56<00:00, 711.78it/s]

Loss (orig, final): 0.06355391442775726 0.0033166962675750256
L2 norm of weight change: 0.23482918739318848
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:13<00:00, 20.88it/s]


Post-edit metrics: {'TP': array([6348, 5344, 4582, 4026, 4305, 2836, 5710, 5116, 4704, 4723]), 'TN': array([58708, 61133, 60852, 59474, 60006, 61964, 61018, 61675, 61792,
       61072]), 'FPs': array([4292, 1867, 2148, 3526, 2994, 1036, 1982, 1325, 1208, 1928]), 'FNs': array([ 652, 1656, 2418, 2974, 2695, 4164, 1290, 1884, 2296, 2277]), 'accuracy': 0.6813428571428571, 'per_class_accuracy': array([0.92937143, 0.94967143, 0.93477143, 0.90714286, 0.91872857,
       0.92571429, 0.95325714, 0.95415714, 0.94994286, 0.93992857]), 'per_class_accuracy_mean': 0.9362685714285714, 'precision': array([0.59661654, 0.74109   , 0.6808321 , 0.53310381, 0.58980682,
       0.73243802, 0.74232969, 0.7942866 , 0.79566982, 0.71011878]), 'precision_mean': 0.6916292188642118, 'recall': array([0.90685714, 0.76342857, 0.65457143, 0.57514286, 0.615     ,
       0.40514286, 0.81571429, 0.73085714, 0.672     , 0.67471429]), 'recall_mean': 0.6813428571428571, 'predicted_class_distribution': array([10640,  7211,  67

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:16<00:00, 16.34it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0070, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:59<00:00, 674.76it/s]

Loss (orig, final): 0.08682620525360107 0.007002019323408604
L2 norm of weight change: 0.27361607551574707
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:15<00:00, 18.02it/s]


Post-edit metrics: {'TP': array([6298, 5372, 4332, 4018, 4399, 2948, 5736, 4991, 4943, 4683]), 'TN': array([58860, 61075, 61326, 59531, 59774, 61810, 60971, 61897, 61395,
       61081]), 'FPs': array([4140, 1925, 1674, 3469, 3226, 1190, 2029, 1103, 1605, 1919]), 'FNs': array([ 702, 1628, 2668, 2982, 2601, 4052, 1264, 2009, 2057, 2317]), 'accuracy': 0.6817142857142857, 'per_class_accuracy': array([0.93082857, 0.94924286, 0.93797143, 0.90784286, 0.91675714,
       0.92511429, 0.95295714, 0.95554286, 0.94768571, 0.93948571]), 'per_class_accuracy_mean': 0.9363428571428573, 'precision': array([0.60337229, 0.73619296, 0.72127872, 0.53666355, 0.57691803,
       0.71242146, 0.73869929, 0.8190023 , 0.75488699, 0.70933051]), 'precision_mean': 0.6908766096749537, 'recall': array([0.89971429, 0.76742857, 0.61885714, 0.574     , 0.62842857,
       0.42114286, 0.81942857, 0.713     , 0.70614286, 0.669     ]), 'recall_mean': 0.6817142857142857, 'predicted_class_distribution': array([10438,  7297,  60

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:27<00:00, 10.05it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0044, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:53<00:00, 748.29it/s]

Loss (orig, final): 0.09549523890018463 0.004378254991024733
L2 norm of weight change: 0.4044760763645172
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [01:39<00:00,  2.75it/s]


Post-edit metrics: {'TP': array([6514, 5286, 3125, 4108, 4267, 2741, 5630, 5059, 4507, 4725]), 'TN': array([54775, 61225, 62467, 59402, 60049, 62077, 61243, 61727, 61965,
       61032]), 'FPs': array([8225, 1775,  533, 3598, 2951,  923, 1757, 1273, 1035, 1968]), 'FNs': array([ 486, 1714, 3875, 2892, 2733, 4259, 1370, 1941, 2493, 2275]), 'accuracy': 0.6566, 'per_class_accuracy': array([0.87555714, 0.95015714, 0.93702857, 0.90728571, 0.9188    ,
       0.92597143, 0.95532857, 0.95408571, 0.9496    , 0.93938571]), 'per_class_accuracy_mean': 0.93132, 'precision': array([0.44195671, 0.74861918, 0.85429196, 0.5330911 , 0.59116099,
       0.74808952, 0.76214972, 0.79895768, 0.81324432, 0.70596145]), 'precision_mean': 0.6997522622156349, 'recall': array([0.93057143, 0.75514286, 0.44642857, 0.58685714, 0.60957143,
       0.39157143, 0.80428571, 0.72271429, 0.64385714, 0.675     ]), 'recall_mean': 0.6566, 'predicted_class_distribution': array([14739,  7061,  3658,  7706,  7218,  3664,  7387,  63

100%|█████████████████████████████████████████████████████████████████████| 274/274 [01:16<00:00,  3.60it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0030, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:49<00:00, 813.77it/s]

Loss (orig, final): 0.10101480036973953 0.0029556145891547203
L2 norm of weight change: 0.4681950509548187
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 24.95it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6514, 5286,    0, 4127, 4277, 2737, 5609, 5041, 4526, 4731]), 'TN': array([51229, 61215, 63000, 59345, 60017, 62069, 61279, 61755, 61948,
       60991]), 'FPs': array([11771,  1785,     0,  3655,  2983,   931,  1721,  1245,  1052,
        2009]), 'FNs': array([ 486, 1714, 7000, 2873, 2723, 4263, 1391, 1959, 2474, 2269]), 'accuracy': 0.6121142857142857, 'per_class_accuracy': array([0.8249    , 0.95001429, 0.9       , 0.90674286, 0.91848571,
       0.9258    , 0.95554286, 0.95422857, 0.94962857, 0.93888571]), 'per_class_accuracy_mean': 0.922422857142857, 'precision': array([0.35624829, 0.74756046, 0.        , 0.53032639, 0.58911846,
       0.74618321, 0.76521146, 0.80194082, 0.81140194, 0.70192878]), 'precision_mean': 0.6049919806995567, 'recall': array([0.93057143, 0.75514286, 0.        , 0.58957143, 0.611     ,
       0.391     , 0.80128571, 0.72014286, 0.64657143, 0.67585714]), 'recall_mean': 0.6121142857142857, 'predicted_class_distribution': array([1

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:19<00:00, 14.04it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0041, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:45<00:00, 877.79it/s]

Loss (orig, final): 0.13322454690933228 0.004062718711793423
L2 norm of weight change: 0.5569638013839722
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [01:43<00:00,  2.66it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6583, 5269,    0, 4035, 4156, 2626, 5455, 4945, 4365, 4679]), 'TN': array([49110, 61232, 63000, 59609, 60324, 62210, 61571, 61850, 62087,
       61120]), 'FPs': array([13890,  1768,     0,  3391,  2676,   790,  1429,  1150,   913,
        1880]), 'FNs': array([ 417, 1731, 7000, 2965, 2844, 4374, 1545, 2055, 2635, 2321]), 'accuracy': 0.6016142857142858, 'per_class_accuracy': array([0.79561429, 0.95001429, 0.9       , 0.9092    , 0.92114286,
       0.92622857, 0.95751429, 0.95421429, 0.94931429, 0.93998571]), 'per_class_accuracy_mean': 0.9203228571428573, 'precision': array([0.32154545, 0.74875657, 0.        , 0.54336116, 0.60831382,
       0.76873536, 0.7924172 , 0.81132075, 0.82701781, 0.71337094]), 'precision_mean': 0.6134839070803886, 'recall': array([0.94042857, 0.75271429, 0.        , 0.57642857, 0.59371429,
       0.37514286, 0.77928571, 0.70642857, 0.62357143, 0.66842857]), 'recall_mean': 0.6016142857142857, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [01:07<00:00,  4.05it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0071, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:49<00:00, 810.80it/s]

Loss (orig, final): 0.06533744186162949 0.007131048012524843
L2 norm of weight change: 0.24099388718605042
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:18<00:00, 15.02it/s]


Post-edit metrics: {'TP': array([6340, 5299, 4578, 4115, 4314, 2810, 5707, 5090, 4711, 4740]), 'TN': array([58786, 61206, 60880, 59236, 59985, 61990, 61036, 61742, 61777,
       61066]), 'FPs': array([4214, 1794, 2120, 3764, 3015, 1010, 1964, 1258, 1223, 1934]), 'FNs': array([ 660, 1701, 2422, 2885, 2686, 4190, 1293, 1910, 2289, 2260]), 'accuracy': 0.6814857142857143, 'per_class_accuracy': array([0.93037143, 0.95007143, 0.93511429, 0.90501429, 0.91855714,
       0.92571429, 0.95347143, 0.95474286, 0.94982857, 0.94008571]), 'per_class_accuracy_mean': 0.9362971428571429, 'precision': array([0.60072011, 0.74707458, 0.68348761, 0.5222744 , 0.58862055,
       0.73560209, 0.7439708 , 0.80182735, 0.79389956, 0.71021876]), 'precision_mean': 0.6927695805507137, 'recall': array([0.90571429, 0.757     , 0.654     , 0.58785714, 0.61628571,
       0.40142857, 0.81528571, 0.72714286, 0.673     , 0.67714286]), 'recall_mean': 0.6814857142857143, 'predicted_class_distribution': array([10554,  7093,  66

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:21<00:00, 12.52it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0061, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:48<00:00, 817.40it/s]

Loss (orig, final): 0.1514863520860672 0.006135610397905111
L2 norm of weight change: 0.7275477647781372
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:29<00:00,  9.29it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6905, 5073, 3915, 3569, 3354, 2137, 5549, 5146,    0, 4178]), 'TN': array([44211, 61450, 61823, 60563, 61560, 62618, 61379, 61254, 63000,
       61968]), 'FPs': array([18789,  1550,  1177,  2437,  1440,   382,  1621,  1746,     0,
        1032]), 'FNs': array([  95, 1927, 3085, 3431, 3646, 4863, 1451, 1854, 7000, 2822]), 'accuracy': 0.5689428571428572, 'per_class_accuracy': array([0.73022857, 0.95032857, 0.93911429, 0.91617143, 0.92734286,
       0.92507143, 0.95611429, 0.94857143, 0.9       , 0.94494286]), 'per_class_accuracy_mean': 0.9137885714285714, 'precision': array([0.26873978, 0.76596708, 0.7688531 , 0.59423909, 0.69962453,
       0.84835252, 0.77391911, 0.7466628 , 0.        , 0.80191939]), 'precision_mean': 0.6268277407295224, 'recall': array([0.98642857, 0.72471429, 0.55928571, 0.50985714, 0.47914286,
       0.30528571, 0.79271429, 0.73514286, 0.        , 0.59685714]), 'recall_mean': 0.5689428571428572, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:47<00:00,  5.79it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0037, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:46<00:00, 861.76it/s]

Loss (orig, final): 0.09510952234268188 0.003749515162780881
L2 norm of weight change: 0.41380104422569275
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:33<00:00,  8.18it/s]


Post-edit metrics: {'TP': array([6717, 5185, 4736, 3854, 4007, 2441, 5605, 5100, 3152, 4596]), 'TN': array([53795, 61362, 60402, 59956, 60565, 62406, 61208, 61607, 62751,
       61341]), 'FPs': array([9205, 1638, 2598, 3044, 2435,  594, 1792, 1393,  249, 1659]), 'FNs': array([ 283, 1815, 2264, 3146, 2993, 4559, 1395, 1900, 3848, 2404]), 'accuracy': 0.6484714285714286, 'per_class_accuracy': array([0.86445714, 0.95067143, 0.93054286, 0.91157143, 0.92245714,
       0.92638571, 0.95447143, 0.95295714, 0.94147143, 0.94195714]), 'per_class_accuracy_mean': 0.9296942857142858, 'precision': array([0.42186911, 0.75992965, 0.64575948, 0.55871267, 0.6220118 ,
       0.80428336, 0.75773962, 0.78546127, 0.92678624, 0.73477218]), 'precision_mean': 0.7017325378500721, 'recall': array([0.95957143, 0.74071429, 0.67657143, 0.55057143, 0.57242857,
       0.34871429, 0.80071429, 0.72857143, 0.45028571, 0.65657143]), 'recall_mean': 0.6484714285714286, 'predicted_class_distribution': array([15922,  6823,  73

100%|█████████████████████████████████████████████████████████████████████| 274/274 [01:03<00:00,  4.31it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0051, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:46<00:00, 867.65it/s]

Loss (orig, final): 0.14130845665931702 0.0051463390700519085
L2 norm of weight change: 0.6888740658760071
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:32<00:00,  8.43it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6935, 5025, 4105, 3231,   15, 2288, 5158, 4508,    0, 3695]), 'TN': array([35700, 61509, 61638, 61046, 63000, 62530, 62020, 62162, 63000,
       62355]), 'FPs': array([27300,  1491,  1362,  1954,     0,   470,   980,   838,     0,
         645]), 'FNs': array([  65, 1975, 2895, 3769, 6985, 4712, 1842, 2492, 7000, 3305]), 'accuracy': 0.49942857142857144, 'per_class_accuracy': array([0.60907143, 0.95048571, 0.93918571, 0.91824286, 0.90021429,
       0.92597143, 0.95968571, 0.95242857, 0.9       , 0.94357143]), 'per_class_accuracy_mean': 0.8998857142857144, 'precision': array([0.20257047, 0.77117864, 0.75086885, 0.62314368, 1.        ,
       0.82958666, 0.84033887, 0.84324729, 0.        , 0.85138249]), 'precision_mean': 0.6712316944948341, 'recall': array([0.99071429, 0.71785714, 0.58642857, 0.46157143, 0.00214286,
       0.32685714, 0.73685714, 0.644     , 0.        , 0.52785714]), 'recall_mean': 0.49942857142857144, 'predicted_class_distribution': array

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:55<00:00,  4.92it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0557, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:46<00:00, 858.97it/s]

Loss (orig, final): 0.20767760276794434 0.0556865818798542
L2 norm of weight change: 0.6049613952636719
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:43<00:00,  6.35it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6724, 5324, 5131, 3965, 3913, 2665, 5701, 5334,    0, 4439]), 'TN': array([53089, 61215, 59581, 59558, 60808, 62183, 60910, 61182, 63000,
       61670]), 'FPs': array([9911, 1785, 3419, 3442, 2192,  817, 2090, 1818,    0, 1330]), 'FNs': array([ 276, 1676, 1869, 3035, 3087, 4335, 1299, 1666, 7000, 2561]), 'accuracy': 0.6170857142857142, 'per_class_accuracy': array([0.85447143, 0.95055714, 0.92445714, 0.90747143, 0.92458571,
       0.9264    , 0.95158571, 0.95022857, 0.9       , 0.94441429]), 'per_class_accuracy_mean': 0.9234171428571429, 'precision': array([0.404208  , 0.74890983, 0.60011696, 0.53530444, 0.64095004,
       0.76536473, 0.73174175, 0.74580537, 0.        , 0.76945744]), 'precision_mean': 0.5941858569865279, 'recall': array([0.96057143, 0.76057143, 0.733     , 0.56642857, 0.559     ,
       0.38071429, 0.81442857, 0.762     , 0.        , 0.63414286]), 'recall_mean': 0.6170857142857142, 'predicted_class_distribution': array([16635,  7109,  85

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:35<00:00,  7.75it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0058, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:50<00:00, 796.33it/s]

Loss (orig, final): 0.1228443831205368 0.005846063140779734
L2 norm of weight change: 0.5000464916229248
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [01:02<00:00,  4.41it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6521, 5188, 5503, 3552, 4132, 2966, 5373, 4742, 4075,    0]), 'TN': array([51275, 61390, 57539, 60551, 60329, 61836, 61594, 62192, 62346,
       63000]), 'FPs': array([11725,  1610,  5461,  2449,  2671,  1164,  1406,   808,   654,
           0]), 'FNs': array([ 479, 1812, 1497, 3448, 2868, 4034, 1627, 2258, 2925, 7000]), 'accuracy': 0.6007428571428571, 'per_class_accuracy': array([0.82565714, 0.95111429, 0.9006    , 0.91575714, 0.92087143,
       0.92574286, 0.95667143, 0.9562    , 0.94887143, 0.9       ]), 'per_class_accuracy_mean': 0.9201485714285715, 'precision': array([0.3573934 , 0.76316564, 0.50191536, 0.59190135, 0.6073791 ,
       0.71815981, 0.79259478, 0.85441441, 0.86170438, 0.        ]), 'precision_mean': 0.6048628220782863, 'recall': array([0.93157143, 0.74114286, 0.78614286, 0.50742857, 0.59028571,
       0.42371429, 0.76757143, 0.67742857, 0.58214286, 0.        ]), 'recall_mean': 0.6007428571428571, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:38<00:00,  7.13it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0065, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:54<00:00, 737.18it/s]

Loss (orig, final): 0.12332012504339218 0.0064758555963635445
L2 norm of weight change: 0.4898682236671448
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:12<00:00, 21.50it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6379, 5271, 5704, 3395, 4167, 2998, 5328, 4671, 4344,    0]), 'TN': array([53079, 61283, 56055, 60804, 60255, 61791, 61642, 62321, 62027,
       63000]), 'FPs': array([9921, 1717, 6945, 2196, 2745, 1209, 1358,  679,  973,    0]), 'FNs': array([ 621, 1729, 1296, 3605, 2833, 4002, 1672, 2329, 2656, 7000]), 'accuracy': 0.6036714285714285, 'per_class_accuracy': array([0.8494    , 0.95077143, 0.88227143, 0.91712857, 0.92031429,
       0.92555714, 0.95671429, 0.95702857, 0.94815714, 0.9       ]), 'per_class_accuracy_mean': 0.9207342857142858, 'precision': array([0.39134969, 0.75429307, 0.45094474, 0.6072259 , 0.60286458,
       0.71262182, 0.79688902, 0.87308411, 0.81700207, 0.        ]), 'precision_mean': 0.6006275011503048, 'recall': array([0.91128571, 0.753     , 0.81485714, 0.485     , 0.59528571,
       0.42828571, 0.76114286, 0.66728571, 0.62057143, 0.        ]), 'recall_mean': 0.6036714285714286, 'predicted_class_distribution': array([16300,  6988, 126

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.17it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0032, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:59<00:00, 677.67it/s]

Loss (orig, final): 0.09498517960309982 0.003190010553225875
L2 norm of weight change: 0.31329286098480225
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:20<00:00, 13.24it/s]


Post-edit metrics: {'TP': array([6417, 5382, 5263, 4008, 3189, 2861, 5507, 5103, 4450, 4592]), 'TN': array([56936, 61031, 59021, 59663, 61840, 61964, 61367, 61702, 61986,
       61262]), 'FPs': array([6064, 1969, 3979, 3337, 1160, 1036, 1633, 1298, 1014, 1738]), 'FNs': array([ 583, 1618, 1737, 2992, 3811, 4139, 1493, 1897, 2550, 2408]), 'accuracy': 0.6681714285714285, 'per_class_accuracy': array([0.90504286, 0.94875714, 0.91834286, 0.90958571, 0.92898571,
       0.92607143, 0.95534286, 0.95435714, 0.94908571, 0.94077143]), 'per_class_accuracy_mean': 0.9336342857142859, 'precision': array([0.5141415 , 0.73214529, 0.56946548, 0.54567733, 0.73327202,
       0.73415448, 0.77128852, 0.79721918, 0.81442167, 0.72543444]), 'precision_mean': 0.693721989916064, 'recall': array([0.91671429, 0.76885714, 0.75185714, 0.57257143, 0.45557143,
       0.40871429, 0.78671429, 0.729     , 0.63571429, 0.656     ]), 'recall_mean': 0.6681714285714284, 'predicted_class_distribution': array([12481,  7351,  924

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:36<00:00,  7.50it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0044, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:48<00:00, 819.82it/s]

Loss (orig, final): 0.06192366033792496 0.0044164350256323814
L2 norm of weight change: 0.2132744938135147
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:16<00:00, 17.04it/s]


Post-edit metrics: {'TP': array([6298, 5408, 5273, 3992, 3695, 2942, 5626, 5170, 4587, 4660]), 'TN': array([59065, 61029, 59113, 59579, 61215, 61854, 61116, 61624, 61885,
       61171]), 'FPs': array([3935, 1971, 3887, 3421, 1785, 1146, 1884, 1376, 1115, 1829]), 'FNs': array([ 702, 1592, 1727, 3008, 3305, 4058, 1374, 1830, 2413, 2340]), 'accuracy': 0.6807285714285715, 'per_class_accuracy': array([0.93375714, 0.9491    , 0.9198    , 0.90815714, 0.92728571,
       0.92565714, 0.95345714, 0.9542    , 0.9496    , 0.94044286]), 'per_class_accuracy_mean': 0.9361457142857142, 'precision': array([0.61545979, 0.73289064, 0.57565502, 0.53851342, 0.67427007,
       0.71966732, 0.74913449, 0.78979529, 0.80445458, 0.71813839]), 'precision_mean': 0.6917979006295811, 'recall': array([0.89971429, 0.77257143, 0.75328571, 0.57028571, 0.52785714,
       0.42028571, 0.80371429, 0.73857143, 0.65528571, 0.66571429]), 'recall_mean': 0.6807285714285713, 'predicted_class_distribution': array([10233,  7379,  91

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:26<00:00, 10.26it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0115, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:49<00:00, 812.45it/s]

Loss (orig, final): 0.18947891891002655 0.011542519554495811
L2 norm of weight change: 0.6846410036087036
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:59<00:00,  4.59it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6589, 5177, 5335, 3482, 4109, 2992, 5321, 4519, 4070,    0]), 'TN': array([49336, 61392, 58462, 60839, 60307, 61795, 61770, 62340, 62353,
       63000]), 'FPs': array([13664,  1608,  4538,  2161,  2693,  1205,  1230,   660,   647,
           0]), 'FNs': array([ 411, 1823, 1665, 3518, 2891, 4008, 1679, 2481, 2930, 7000]), 'accuracy': 0.5942, 'per_class_accuracy': array([0.79892857, 0.95098571, 0.91138571, 0.91887143, 0.92022857,
       0.92552857, 0.95844286, 0.95512857, 0.9489    , 0.9       ]), 'per_class_accuracy_mean': 0.91884, 'precision': array([0.32533452, 0.76300663, 0.54036261, 0.61704767, 0.60408703,
       0.71289016, 0.81224241, 0.87256227, 0.86283655, 0.        ]), 'precision_mean': 0.6110369843349244, 'recall': array([0.94128571, 0.73957143, 0.76214286, 0.49742857, 0.587     ,
       0.42742857, 0.76014286, 0.64557143, 0.58142857, 0.        ]), 'recall_mean': 0.5942000000000001, 'predicted_class_distribution': array([20253,  6785,  9873,  5

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:35<00:00,  7.63it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0049, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:49<00:00, 804.67it/s]

Loss (orig, final): 0.1059262603521347 0.00485608447343111
L2 norm of weight change: 0.3615017235279083
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:37<00:00,  7.29it/s]


Post-edit metrics: {'TP': array([6440, 5439, 5286, 3669, 4232, 2980, 5493, 4888, 4481, 2889]), 'TN': array([55636, 60994, 58833, 60365, 60130, 61787, 61408, 62023, 61916,
       62705]), 'FPs': array([7364, 2006, 4167, 2635, 2870, 1213, 1592,  977, 1084,  295]), 'FNs': array([ 560, 1561, 1714, 3331, 2768, 4020, 1507, 2112, 2519, 4111]), 'accuracy': 0.6542428571428571, 'per_class_accuracy': array([0.8868    , 0.94904286, 0.91598571, 0.91477143, 0.91945714,
       0.92524286, 0.95572857, 0.95587143, 0.94852857, 0.93705714]), 'per_class_accuracy_mean': 0.9308485714285715, 'precision': array([0.46653144, 0.73055742, 0.55918756, 0.58201142, 0.59588848,
       0.71070832, 0.77529993, 0.83341858, 0.80521114, 0.90734925]), 'precision_mean': 0.6966163549133774, 'recall': array([0.92      , 0.777     , 0.75514286, 0.52414286, 0.60457143,
       0.42571429, 0.78471429, 0.69828571, 0.64014286, 0.41271429]), 'recall_mean': 0.6542428571428571, 'predicted_class_distribution': array([13804,  7445,  94

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:25<00:00, 10.81it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0188, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:47<00:00, 846.67it/s]

Loss (orig, final): 0.19090276956558228 0.018752720206975937
L2 norm of weight change: 0.7326511144638062
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [01:15<00:00,  3.63it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6860, 5162, 4928,    1, 3510, 1979, 5417, 5181,    0, 3962]), 'TN': array([40816, 61283, 59897, 62997, 61414, 62716, 61582, 61123, 63000,
       62172]), 'FPs': array([22184,  1717,  3103,     3,  1586,   284,  1418,  1877,     0,
         828]), 'FNs': array([ 140, 1838, 2072, 6999, 3490, 5021, 1583, 1819, 7000, 3038]), 'accuracy': 0.5285714285714286, 'per_class_accuracy': array([0.68108571, 0.94921429, 0.92607143, 0.89997143, 0.92748571,
       0.92421429, 0.95712857, 0.9472    , 0.9       , 0.94477143]), 'per_class_accuracy_mean': 0.9057142857142857, 'precision': array([0.23619336, 0.75039977, 0.61362221, 0.25      , 0.68877551,
       0.87450287, 0.79253841, 0.73406064, 0.        , 0.82713987]), 'precision_mean': 0.5767232646036599, 'recall': array([9.80000000e-01, 7.37428571e-01, 7.04000000e-01, 1.42857143e-04,
       5.01428571e-01, 2.82714286e-01, 7.73857143e-01, 7.40142857e-01,
       0.00000000e+00, 5.66000000e-01]), 'recall_mean': 0.5285714285

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:22<00:00, 11.94it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0547, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:46<00:00, 860.26it/s]

Loss (orig, final): 0.21581627428531647 0.05470753833651543
L2 norm of weight change: 0.6321216225624084
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [01:13<00:00,  3.74it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6757, 5278, 5110, 3897, 3966, 2569, 5685, 5320,    0, 4400]), 'TN': array([52420, 61307, 59639, 59773, 60718, 62292, 60933, 61172, 63000,
       61728]), 'FPs': array([10580,  1693,  3361,  3227,  2282,   708,  2067,  1828,     0,
        1272]), 'FNs': array([ 243, 1722, 1890, 3103, 3034, 4431, 1315, 1680, 7000, 2600]), 'accuracy': 0.6140285714285715, 'per_class_accuracy': array([0.84538571, 0.95121429, 0.92498571, 0.90957143, 0.92405714,
       0.92658571, 0.95168571, 0.94988571, 0.9       , 0.94468571]), 'per_class_accuracy_mean': 0.9228057142857145, 'precision': array([0.38974448, 0.75713671, 0.60323456, 0.54702414, 0.63476312,
       0.78394873, 0.73335913, 0.74426413, 0.        , 0.77574048]), 'precision_mean': 0.5969215495379329, 'recall': array([0.96528571, 0.754     , 0.73      , 0.55671429, 0.56657143,
       0.367     , 0.81214286, 0.76      , 0.        , 0.62857143]), 'recall_mean': 0.6140285714285714, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:50<00:00,  5.44it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0040, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:53<00:00, 740.81it/s]

Loss (orig, final): 0.13472872972488403 0.0040040938183665276
L2 norm of weight change: 0.5197356939315796
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:16<00:00, 16.40it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6536, 5296,    0, 4067, 4224, 2701, 5500, 4991, 4493, 4690]), 'TN': array([50108, 61200, 63000, 59519, 60169, 62122, 61504, 61837, 61972,
       61067]), 'FPs': array([12892,  1800,     0,  3481,  2831,   878,  1496,  1163,  1028,
        1933]), 'FNs': array([ 464, 1704, 7000, 2933, 2776, 4299, 1500, 2009, 2507, 2310]), 'accuracy': 0.6071142857142857, 'per_class_accuracy': array([0.8092    , 0.94994286, 0.9       , 0.90837143, 0.9199    ,
       0.92604286, 0.9572    , 0.95468571, 0.9495    , 0.93938571]), 'per_class_accuracy_mean': 0.9214228571428572, 'precision': array([0.33642166, 0.74633596, 0.        , 0.53881823, 0.59872431,
       0.75468008, 0.78616352, 0.81101722, 0.81380185, 0.70813831]), 'precision_mean': 0.6094101140590492, 'recall': array([0.93371429, 0.75657143, 0.        , 0.581     , 0.60342857,
       0.38585714, 0.78571429, 0.713     , 0.64185714, 0.67      ]), 'recall_mean': 0.6071142857142857, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:18<00:00, 14.97it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0028, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:47<00:00, 848.82it/s]

Loss (orig, final): 0.0781460851430893 0.002840617671608925
L2 norm of weight change: 0.26978766918182373
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [01:01<00:00,  4.47it/s]


Post-edit metrics: {'TP': array([6382, 5318, 4393, 4096, 4312, 2812, 5697, 5094, 4701, 4713]), 'TN': array([58228, 61161, 61180, 59357, 59967, 61988, 61058, 61719, 61784,
       61076]), 'FPs': array([4772, 1839, 1820, 3643, 3033, 1012, 1942, 1281, 1216, 1924]), 'FNs': array([ 618, 1682, 2607, 2904, 2688, 4188, 1303, 1906, 2299, 2287]), 'accuracy': 0.6788285714285714, 'per_class_accuracy': array([0.923     , 0.9497    , 0.93675714, 0.90647143, 0.91827143,
       0.92571429, 0.95364286, 0.95447143, 0.94978571, 0.93984286]), 'per_class_accuracy_mean': 0.9357657142857141, 'precision': array([0.57217142, 0.74304876, 0.70706583, 0.52926735, 0.58706603,
       0.73535565, 0.74577824, 0.79905882, 0.79449045, 0.71010999]), 'precision_mean': 0.6923412545987611, 'recall': array([0.91171429, 0.75971429, 0.62757143, 0.58514286, 0.616     ,
       0.40171429, 0.81385714, 0.72771429, 0.67157143, 0.67328571]), 'recall_mean': 0.6788285714285713, 'predicted_class_distribution': array([11154,  7157,  62

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:40<00:00,  6.69it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0032, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:48<00:00, 825.74it/s]

Loss (orig, final): 0.12995630502700806 0.003233768045902252
L2 norm of weight change: 0.5119864344596863
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:17<00:00, 15.38it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6511, 5292,    0, 4103, 4271, 2728, 5524, 5000, 4558, 4704]), 'TN': array([50652, 61208, 63000, 59427, 60048, 62088, 61476, 61831, 61910,
       61051]), 'FPs': array([12348,  1792,     0,  3573,  2952,   912,  1524,  1169,  1090,
        1949]), 'FNs': array([ 489, 1708, 7000, 2897, 2729, 4272, 1476, 2000, 2442, 2296]), 'accuracy': 0.6098714285714286, 'per_class_accuracy': array([0.81661429, 0.95      , 0.9       , 0.90757143, 0.91884286,
       0.92594286, 0.95714286, 0.95472857, 0.94954286, 0.93935714]), 'per_class_accuracy_mean': 0.9219742857142856, 'precision': array([0.3452463 , 0.74703557, 0.        , 0.53452319, 0.59130555,
       0.74945055, 0.78376844, 0.81050413, 0.80701133, 0.70704945]), 'precision_mean': 0.6075894526284606, 'recall': array([0.93014286, 0.756     , 0.        , 0.58614286, 0.61014286,
       0.38971429, 0.78914286, 0.71428571, 0.65114286, 0.672     ]), 'recall_mean': 0.6098714285714285, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:30<00:00,  9.11it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0033, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:55<00:00, 721.10it/s]

Loss (orig, final): 0.13241255283355713 0.003343985415995121
L2 norm of weight change: 0.5169293880462646
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:12<00:00, 22.11it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6521, 5286,    0, 4100, 4262, 2714, 5514, 4993, 4542, 4700]), 'TN': array([50481, 61220, 63000, 59442, 60069, 62100, 61505, 61832, 61926,
       61057]), 'FPs': array([12519,  1780,     0,  3558,  2931,   900,  1495,  1168,  1074,
        1943]), 'FNs': array([ 479, 1714, 7000, 2900, 2738, 4286, 1486, 2007, 2458, 2300]), 'accuracy': 0.6090285714285715, 'per_class_accuracy': array([0.81431429, 0.95008571, 0.9       , 0.90774286, 0.91901429,
       0.92591429, 0.95741429, 0.95464286, 0.94954286, 0.93938571]), 'per_class_accuracy_mean': 0.9218057142857143, 'precision': array([0.3424895 , 0.74808944, 0.        , 0.53538783, 0.59252051,
       0.75096846, 0.78670281, 0.81042039, 0.80876068, 0.70751167]), 'precision_mean': 0.6082851277120451, 'recall': array([0.93157143, 0.75514286, 0.        , 0.58571429, 0.60885714,
       0.38771429, 0.78771429, 0.71328571, 0.64885714, 0.67142857]), 'recall_mean': 0.6090285714285715, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.84it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0119, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:49<00:00, 807.46it/s]

Loss (orig, final): 0.13259945809841156 0.011858126148581505
L2 norm of weight change: 0.5522463917732239
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:45<00:00,  6.04it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6749, 5120, 4957, 4325, 3833, 2502, 5612, 5349,    0, 4578]), 'TN': array([52915, 61494, 60022, 58535, 60942, 62336, 61167, 61107, 63000,
       61507]), 'FPs': array([10085,  1506,  2978,  4465,  2058,   664,  1833,  1893,     0,
        1493]), 'FNs': array([ 251, 1880, 2043, 2675, 3167, 4498, 1388, 1651, 7000, 2422]), 'accuracy': 0.6146428571428572, 'per_class_accuracy': array([0.85234286, 0.95162857, 0.92827143, 0.898     , 0.92535714,
       0.92625714, 0.95398571, 0.94937143, 0.9       , 0.94407143]), 'per_class_accuracy_mean': 0.9229285714285714, 'precision': array([0.40091482, 0.77271355, 0.62470069, 0.49203641, 0.65065354,
       0.79027164, 0.75379449, 0.73860812, 0.        , 0.75407676]), 'precision_mean': 0.5977770012106436, 'recall': array([0.96414286, 0.73142857, 0.70814286, 0.61785714, 0.54757143,
       0.35742857, 0.80171429, 0.76414286, 0.        , 0.654     ]), 'recall_mean': 0.6146428571428572, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:35<00:00,  7.71it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0100, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:46<00:00, 864.30it/s]

Loss (orig, final): 0.0866127759218216 0.010029886849224567
L2 norm of weight change: 0.2927623987197876
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:54<00:00,  5.03it/s]


Post-edit metrics: {'TP': array([6344, 5459, 5475, 3676, 3976, 2863, 5654, 5343, 3858, 4491]), 'TN': array([59140, 60940, 58055, 60181, 60691, 61946, 60953, 61276, 62467,
       61490]), 'FPs': array([3860, 2060, 4945, 2819, 2309, 1054, 2047, 1724,  533, 1510]), 'FNs': array([ 656, 1541, 1525, 3324, 3024, 4137, 1346, 1657, 3142, 2509]), 'accuracy': 0.6734142857142857, 'per_class_accuracy': array([0.93548571, 0.94855714, 0.90757143, 0.91224286, 0.92381429,
       0.92584286, 0.95152857, 0.9517    , 0.9475    , 0.94258571]), 'per_class_accuracy_mean': 0.934682857142857, 'precision': array([0.62171697, 0.7260274 , 0.52543186, 0.56597383, 0.63261734,
       0.73091652, 0.73419036, 0.75604924, 0.87861535, 0.74837527]), 'precision_mean': 0.6919914147658146, 'recall': array([0.90628571, 0.77985714, 0.78214286, 0.52514286, 0.568     ,
       0.409     , 0.80771429, 0.76328571, 0.55114286, 0.64157143]), 'recall_mean': 0.6734142857142857, 'predicted_class_distribution': array([10204,  7519, 1042

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:39<00:00,  6.97it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0048, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:53<00:00, 743.10it/s]

Loss (orig, final): 0.11036579310894012 0.004771464969962835
L2 norm of weight change: 0.40330392122268677
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:19<00:00, 13.75it/s]


Post-edit metrics: {'TP': array([6748, 5307, 5165, 3897, 3912, 2625, 5613, 5326,    1, 4462]), 'TN': array([52778, 61227, 59339, 59835, 60774, 62233, 61140, 61141, 63000,
       61589]), 'FPs': array([10222,  1773,  3661,  3165,  2226,   767,  1860,  1859,     0,
        1411]), 'FNs': array([ 252, 1693, 1835, 3103, 3088, 4375, 1387, 1674, 6999, 2538]), 'accuracy': 0.6150857142857142, 'per_class_accuracy': array([0.85037143, 0.95048571, 0.92148571, 0.91045714, 0.92408571,
       0.92654286, 0.95361429, 0.94952857, 0.90001429, 0.94358571]), 'per_class_accuracy_mean': 0.9230171428571428, 'precision': array([0.3976429 , 0.74957627, 0.58520281, 0.55182668, 0.63734115,
       0.77387972, 0.75110397, 0.74126653, 1.        , 0.759748  ]), 'precision_mean': 0.694758802986075, 'recall': array([9.64000000e-01, 7.58142857e-01, 7.37857143e-01, 5.56714286e-01,
       5.58857143e-01, 3.75000000e-01, 8.01857143e-01, 7.60857143e-01,
       1.42857143e-04, 6.37428571e-01]), 'recall_mean': 0.61508571428

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:13<00:00, 20.61it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0089, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:51<00:00, 782.10it/s]

Loss (orig, final): 0.0787421390414238 0.008922312408685684
L2 norm of weight change: 0.2616383135318756
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:23<00:00, 11.87it/s]


Post-edit metrics: {'TP': array([6449, 5347, 5165, 3879, 4064, 2731, 5701, 5305, 3966, 4707]), 'TN': array([58541, 61126, 59486, 59793, 60535, 62083, 60898, 61269, 62441,
       61142]), 'FPs': array([4459, 1874, 3514, 3207, 2465,  917, 2102, 1731,  559, 1858]), 'FNs': array([ 551, 1653, 1835, 3121, 2936, 4269, 1299, 1695, 3034, 2293]), 'accuracy': 0.6759142857142857, 'per_class_accuracy': array([0.92842857, 0.94961429, 0.92358571, 0.9096    , 0.92284286,
       0.92591429, 0.95141429, 0.95105714, 0.94867143, 0.9407    ]), 'per_class_accuracy_mean': 0.9351828571428571, 'precision': array([0.59121746, 0.74047916, 0.59511464, 0.54741744, 0.62245367,
       0.74862939, 0.73061643, 0.75397953, 0.87646409, 0.71698401]), 'precision_mean': 0.6923355812588918, 'recall': array([0.92128571, 0.76385714, 0.73785714, 0.55414286, 0.58057143,
       0.39014286, 0.81442857, 0.75785714, 0.56657143, 0.67242857]), 'recall_mean': 0.6759142857142857, 'predicted_class_distribution': array([10908,  7221,  86

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:27<00:00,  9.88it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0033, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:48<00:00, 829.11it/s]

Loss (orig, final): 0.14275555312633514 0.003277318784967065
L2 norm of weight change: 0.7910994291305542
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:19<00:00, 13.90it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6923,    0, 4795,  926, 3658, 2262, 4634, 3986,    0,    0]), 'TN': array([27310, 63000, 59982, 62272, 61072, 62556, 62505, 62487, 63000,
       63000]), 'FPs': array([35690,     0,  3018,   728,  1928,   444,   495,   513,     0,
           0]), 'FNs': array([  77, 7000, 2205, 6074, 3342, 4738, 2366, 3014, 7000, 7000]), 'accuracy': 0.38834285714285716, 'per_class_accuracy': array([0.48904286, 0.9       , 0.92538571, 0.90282857, 0.92471429,
       0.92597143, 0.95912857, 0.94961429, 0.9       , 0.9       ]), 'per_class_accuracy_mean': 0.8776685714285716, 'precision': array([0.16246216, 0.        , 0.61372072, 0.5598549 , 0.65485141,
       0.83592018, 0.90348996, 0.88597466, 0.        , 0.        ]), 'precision_mean': 0.46162739902532124, 'recall': array([0.989     , 0.        , 0.685     , 0.13228571, 0.52257143,
       0.32314286, 0.662     , 0.56942857, 0.        , 0.        ]), 'recall_mean': 0.38834285714285716, 'predicted_class_distribution': arra

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:20<00:00, 13.34it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0068, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:48<00:00, 826.38it/s]

Loss (orig, final): 0.08743598312139511 0.006783650256693363
L2 norm of weight change: 0.5197508931159973
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:30<00:00,  8.94it/s]


Post-edit metrics: {'TP': array([6499, 4767, 5085, 3388, 4638, 2758, 5193, 4445, 4704, 2863]), 'TN': array([52915, 61802, 59314, 60981, 58701, 62095, 61950, 62391, 61474,
       62717]), 'FPs': array([10085,  1198,  3686,  2019,  4299,   905,  1050,   609,  1526,
         283]), 'FNs': array([ 501, 2233, 1915, 3612, 2362, 4242, 1807, 2555, 2296, 4137]), 'accuracy': 0.6334285714285715, 'per_class_accuracy': array([0.84877143, 0.95098571, 0.91998571, 0.91955714, 0.90484286,
       0.92647143, 0.95918571, 0.9548    , 0.9454    , 0.93685714]), 'per_class_accuracy_mean': 0.9266857142857143, 'precision': array([0.39188374, 0.79916178, 0.57975145, 0.62659515, 0.5189661 ,
       0.75293475, 0.83181163, 0.87950139, 0.75505618, 0.9100445 ]), 'precision_mean': 0.7045706672218377, 'recall': array([0.92842857, 0.681     , 0.72642857, 0.484     , 0.66257143,
       0.394     , 0.74185714, 0.635     , 0.672     , 0.409     ]), 'recall_mean': 0.6334285714285713, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:15<00:00, 17.51it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0071, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:45<00:00, 882.75it/s]

Loss (orig, final): 0.12613266706466675 0.007069211918860674
L2 norm of weight change: 0.5476809144020081
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:24<00:00, 11.11it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6688, 5241,    0, 3937, 4020, 2525, 5450, 5014, 3819, 4610]), 'TN': array([47263, 61289, 63000, 59780, 60556, 62323, 61563, 61749, 62486,
       61295]), 'FPs': array([15737,  1711,     0,  3220,  2444,   677,  1437,  1251,   514,
        1705]), 'FNs': array([ 312, 1759, 7000, 3063, 2980, 4475, 1550, 1986, 3181, 2390]), 'accuracy': 0.5900571428571428, 'per_class_accuracy': array([0.77072857, 0.95042857, 0.9       , 0.91024286, 0.92251429,
       0.9264    , 0.95732857, 0.95375714, 0.94721429, 0.9415    ]), 'per_class_accuracy_mean': 0.9180114285714284, 'precision': array([0.29823857, 0.75388377, 0.        , 0.55009082, 0.62190594,
       0.78856964, 0.79134601, 0.80031923, 0.88137549, 0.73000792]), 'precision_mean': 0.6215737408364413, 'recall': array([0.95542857, 0.74871429, 0.        , 0.56242857, 0.57428571,
       0.36071429, 0.77857143, 0.71628571, 0.54557143, 0.65857143]), 'recall_mean': 0.5900571428571428, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:15<00:00, 17.38it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0048, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:56<00:00, 709.23it/s]

Loss (orig, final): 0.09182490408420563 0.004789463244378567
L2 norm of weight change: 0.3654452860355377
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:25<00:00, 10.64it/s]


Post-edit metrics: {'TP': array([6543, 5340, 3512, 3904, 4176, 2705, 5647, 5012, 4445, 4620]), 'TN': array([54129, 61099, 62189, 59896, 60282, 62135, 61185, 61813, 61977,
       61199]), 'FPs': array([8871, 1901,  811, 3104, 2718,  865, 1815, 1187, 1023, 1801]), 'FNs': array([ 457, 1660, 3488, 3096, 2824, 4295, 1353, 1988, 2555, 2380]), 'accuracy': 0.6557714285714286, 'per_class_accuracy': array([0.86674286, 0.94912857, 0.93858571, 0.91142857, 0.92082857,
       0.92628571, 0.95474286, 0.95464286, 0.94888571, 0.94027143]), 'per_class_accuracy_mean': 0.9311542857142857, 'precision': array([0.42448424, 0.7374672 , 0.8123988 , 0.55707763, 0.60574413,
       0.75770308, 0.75676762, 0.8085175 , 0.81291149, 0.71951409]), 'precision_mean': 0.6992585769860888, 'recall': array([0.93471429, 0.76285714, 0.50171429, 0.55771429, 0.59657143,
       0.38642857, 0.80671429, 0.716     , 0.635     , 0.66      ]), 'recall_mean': 0.6557714285714284, 'predicted_class_distribution': array([15414,  7241,  43

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:19<00:00, 13.77it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0118, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:59<00:00, 671.11it/s]

Loss (orig, final): 0.09520719945430756 0.011795062571763992
L2 norm of weight change: 0.3190569579601288
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:13<00:00, 20.92it/s]


Post-edit metrics: {'TP': array([6383, 5280, 5260, 3984, 4233, 2846, 5495, 4891, 4565, 4186]), 'TN': array([57876, 61298, 59049, 59538, 60127, 61952, 61406, 62013, 61882,
       61982]), 'FPs': array([5124, 1702, 3951, 3462, 2873, 1048, 1594,  987, 1118, 1018]), 'FNs': array([ 617, 1720, 1740, 3016, 2767, 4154, 1505, 2109, 2435, 2814]), 'accuracy': 0.6731857142857143, 'per_class_accuracy': array([0.91798571, 0.95111429, 0.9187    , 0.90745714, 0.91942857,
       0.92568571, 0.95572857, 0.95577143, 0.94924286, 0.94525714]), 'per_class_accuracy_mean': 0.9346371428571428, 'precision': array([0.55470583, 0.75623031, 0.57105635, 0.53505238, 0.59569378,
       0.730868  , 0.77514459, 0.83208574, 0.80327292, 0.80438125]), 'precision_mean': 0.6958491140572368, 'recall': array([0.91185714, 0.75428571, 0.75142857, 0.56914286, 0.60471429,
       0.40657143, 0.785     , 0.69871429, 0.65214286, 0.598     ]), 'recall_mean': 0.6731857142857144, 'predicted_class_distribution': array([11507,  6982,  92

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:16<00:00, 16.92it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0067, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:56<00:00, 708.57it/s]

Loss (orig, final): 0.1520841121673584 0.006745567545294762
L2 norm of weight change: 0.7007972002029419
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 23.87it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6898, 4902, 4834, 2816, 3633, 2491, 5377, 4469,    0,    0]), 'TN': array([38572, 61653, 60070, 61669, 61136, 62404, 61655, 62261, 63000,
       63000]), 'FPs': array([24428,  1347,  2930,  1331,  1864,   596,  1345,   739,     0,
           0]), 'FNs': array([ 102, 2098, 2166, 4184, 3367, 4509, 1623, 2531, 7000, 7000]), 'accuracy': 0.506, 'per_class_accuracy': array([0.64957143, 0.95078571, 0.9272    , 0.92121429, 0.92527143,
       0.92707143, 0.9576    , 0.95328571, 0.9       , 0.9       ]), 'per_class_accuracy_mean': 0.9012, 'precision': array([0.22020047, 0.78444551, 0.62261721, 0.67904509, 0.66090595,
       0.8069323 , 0.79991074, 0.85810292, 0.        , 0.        ]), 'precision_mean': 0.5432160189061168, 'recall': array([0.98542857, 0.70028571, 0.69057143, 0.40228571, 0.519     ,
       0.35585714, 0.76814286, 0.63842857, 0.        , 0.        ]), 'recall_mean': 0.506, 'predicted_class_distribution': array([31326,  6249,  7764,  4147,  5497,  30

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:19<00:00, 14.09it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0063, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:50<00:00, 792.31it/s]

Loss (orig, final): 0.15058936178684235 0.006334173958748579
L2 norm of weight change: 0.48887956142425537
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:32<00:00,  8.38it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6804, 5200, 5009, 3846, 3954, 2425, 5597, 5305,    0, 4394]), 'TN': array([51318, 61377, 59798, 59940, 60679, 62413, 61206, 61084, 63000,
       61719]), 'FPs': array([11682,  1623,  3202,  3060,  2321,   587,  1794,  1916,     0,
        1281]), 'FNs': array([ 196, 1800, 1991, 3154, 3046, 4575, 1403, 1695, 7000, 2606]), 'accuracy': 0.6076285714285714, 'per_class_accuracy': array([0.83031429, 0.9511    , 0.92581429, 0.91122857, 0.92332857,
       0.92625714, 0.95432857, 0.94841429, 0.9       , 0.94447143]), 'per_class_accuracy_mean': 0.9215257142857143, 'precision': array([0.36806232, 0.7621281 , 0.61003532, 0.55690704, 0.63011952,
       0.80511288, 0.75727236, 0.73466279, 0.        , 0.77427313]), 'precision_mean': 0.5998573448241554, 'recall': array([0.972     , 0.74285714, 0.71557143, 0.54942857, 0.56485714,
       0.34642857, 0.79957143, 0.75785714, 0.        , 0.62771429]), 'recall_mean': 0.6076285714285714, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:19<00:00, 14.29it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0100, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:52<00:00, 755.29it/s]

Loss (orig, final): 0.15873795747756958 0.009976967237889767
L2 norm of weight change: 0.5245724320411682
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:18<00:00, 14.59it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6784, 5162, 5169, 3835, 3825, 2342, 5596, 5377,    0, 4456]), 'TN': array([51750, 61419, 59289, 59964, 60949, 62488, 61186, 60849, 63000,
       61652]), 'FPs': array([11250,  1581,  3711,  3036,  2051,   512,  1814,  2151,     0,
        1348]), 'FNs': array([ 216, 1838, 1831, 3165, 3175, 4658, 1404, 1623, 7000, 2544]), 'accuracy': 0.6078, 'per_class_accuracy': array([0.8362    , 0.95115714, 0.92082857, 0.91141429, 0.92534286,
       0.92614286, 0.95402857, 0.94608571, 0.9       , 0.9444    ]), 'per_class_accuracy_mean': 0.92156, 'precision': array([0.37617833, 0.76553463, 0.58209459, 0.55814292, 0.65095303,
       0.82060266, 0.75519568, 0.71426674, 0.        , 0.76774638]), 'precision_mean': 0.5990714965467727, 'recall': array([0.96914286, 0.73742857, 0.73842857, 0.54785714, 0.54642857,
       0.33457143, 0.79942857, 0.76814286, 0.        , 0.63657143]), 'recall_mean': 0.6077999999999999, 'predicted_class_distribution': array([18034,  6743,  8880,  6

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:32<00:00,  8.38it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0114, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:52<00:00, 763.80it/s]

Loss (orig, final): 0.06266534328460693 0.011382699012756348
L2 norm of weight change: 0.2368786334991455
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:13<00:00, 19.66it/s]


Post-edit metrics: {'TP': array([6319, 5237, 4710, 4222, 4280, 2793, 5730, 5178, 4563, 4838]), 'TN': array([59314, 61331, 60724, 58875, 60069, 62002, 60973, 61624, 61978,
       60980]), 'FPs': array([3686, 1669, 2276, 4125, 2931,  998, 2027, 1376, 1022, 2020]), 'FNs': array([ 681, 1763, 2290, 2778, 2720, 4207, 1270, 1822, 2437, 2162]), 'accuracy': 0.6838571428571428, 'per_class_accuracy': array([0.93761429, 0.95097143, 0.93477143, 0.90138571, 0.91927143,
       0.92564286, 0.9529    , 0.95431429, 0.95058571, 0.94025714]), 'per_class_accuracy_mean': 0.9367714285714286, 'precision': array([0.63158421, 0.75832609, 0.67420555, 0.50581047, 0.59353765,
       0.73674492, 0.73868764, 0.79005188, 0.81700985, 0.70545348]), 'precision_mean': 0.6951411745413828, 'recall': array([0.90271429, 0.74814286, 0.67285714, 0.60314286, 0.61142857,
       0.399     , 0.81857143, 0.73971429, 0.65185714, 0.69114286]), 'recall_mean': 0.6838571428571429, 'predicted_class_distribution': array([10005,  6906,  69

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.63it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0042, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:54<00:00, 737.71it/s]

Loss (orig, final): 0.05031372606754303 0.004209229722619057
L2 norm of weight change: 0.23807592689990997
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.31it/s]


Post-edit metrics: {'TP': array([6407, 5306, 4863, 3973, 4247, 2757, 5698, 5183, 4416, 4702]), 'TN': array([58561, 61218, 60230, 59590, 60119, 62066, 60972, 61548, 62080,
       61168]), 'FPs': array([4439, 1782, 2770, 3410, 2881,  934, 2028, 1452,  920, 1832]), 'FNs': array([ 593, 1694, 2137, 3027, 2753, 4243, 1302, 1817, 2584, 2298]), 'accuracy': 0.6793142857142858, 'per_class_accuracy': array([0.92811429, 0.95034286, 0.9299    , 0.90804286, 0.91951429,
       0.92604286, 0.95242857, 0.9533    , 0.94994286, 0.941     ]), 'per_class_accuracy_mean': 0.9358628571428571, 'precision': array([0.59072469, 0.74858916, 0.63710206, 0.53812813, 0.5958193 ,
       0.74695205, 0.73750971, 0.78116051, 0.82758621, 0.71962045]), 'precision_mean': 0.6923192268343664, 'recall': array([0.91528571, 0.758     , 0.69471429, 0.56757143, 0.60671429,
       0.39385714, 0.814     , 0.74042857, 0.63085714, 0.67171429]), 'recall_mean': 0.6793142857142856, 'predicted_class_distribution': array([10846,  7088,  76

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:23<00:00, 11.70it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0040, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:48<00:00, 827.04it/s]

Loss (orig, final): 0.17750228941440582 0.004011891782283783
L2 norm of weight change: 0.7034056782722473
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:51<00:00,  5.30it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6839, 5071, 4989, 3903, 3649, 2283, 5571, 5384,    0, 4138]), 'TN': array([49960, 61492, 59808, 59825, 61180, 62531, 61286, 60725, 63000,
       62020]), 'FPs': array([13040,  1508,  3192,  3175,  1820,   469,  1714,  2275,     0,
         980]), 'FNs': array([ 161, 1929, 2011, 3097, 3351, 4717, 1429, 1616, 7000, 2862]), 'accuracy': 0.5975285714285714, 'per_class_accuracy': array([0.81141429, 0.9509    , 0.92567143, 0.9104    , 0.92612857,
       0.92591429, 0.9551    , 0.94441429, 0.9       , 0.94511429]), 'per_class_accuracy_mean': 0.9195057142857144, 'precision': array([0.34403139, 0.77078583, 0.60982765, 0.55142696, 0.66721521,
       0.82957849, 0.76472203, 0.70296383, 0.        , 0.80851895]), 'precision_mean': 0.6049070348898508, 'recall': array([0.977     , 0.72442857, 0.71271429, 0.55757143, 0.52128571,
       0.32614286, 0.79585714, 0.76914286, 0.        , 0.59114286]), 'recall_mean': 0.5975285714285714, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:43<00:00,  6.25it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0119, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:50<00:00, 797.94it/s]

Loss (orig, final): 0.17294208705425262 0.011920848861336708
L2 norm of weight change: 0.6573811769485474
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:14<00:00, 19.39it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6815, 5115, 5051, 3933, 3762, 2327, 5595, 5432,    0, 4296]), 'TN': array([51184, 61479, 59693, 59766, 61036, 62501, 61218, 60584, 63000,
       61865]), 'FPs': array([11816,  1521,  3307,  3234,  1964,   499,  1782,  2416,     0,
        1135]), 'FNs': array([ 185, 1885, 1949, 3067, 3238, 4673, 1405, 1568, 7000, 2704]), 'accuracy': 0.6046571428571429, 'per_class_accuracy': array([0.82855714, 0.95134286, 0.92491429, 0.90998571, 0.92568571,
       0.92611429, 0.95447143, 0.94308571, 0.9       , 0.94515714]), 'per_class_accuracy_mean': 0.9209314285714285, 'precision': array([0.3657882 , 0.77079566, 0.60433118, 0.54876517, 0.65700314,
       0.82342534, 0.75843839, 0.69215087, 0.        , 0.79101455]), 'precision_mean': 0.601171249781141, 'recall': array([0.97357143, 0.73071429, 0.72157143, 0.56185714, 0.53742857,
       0.33242857, 0.79928571, 0.776     , 0.        , 0.61371429]), 'recall_mean': 0.6046571428571428, 'predicted_class_distribution': array([1

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:33<00:00,  8.11it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0091, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:52<00:00, 760.14it/s]

Loss (orig, final): 0.1477348804473877 0.009105049073696136
L2 norm of weight change: 0.6882073283195496
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 23.25it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6514, 4588, 4718, 2093, 4658, 2466, 5148, 2312, 4869,    0]), 'TN': array([43040, 61953, 60387, 62204, 58438, 62408, 62034, 62896, 61006,
       63000]), 'FPs': array([19960,  1047,  2613,   796,  4562,   592,   966,   104,  1994,
           0]), 'FNs': array([ 486, 2412, 2282, 4907, 2342, 4534, 1852, 4688, 2131, 7000]), 'accuracy': 0.5338, 'per_class_accuracy': array([0.70791429, 0.95058571, 0.93007143, 0.91852857, 0.90137143,
       0.92677143, 0.95974286, 0.93154286, 0.94107143, 0.9       ]), 'per_class_accuracy_mean': 0.90676, 'precision': array([0.24605273, 0.81419698, 0.64356841, 0.72447214, 0.50520607,
       0.80640942, 0.84200196, 0.95695364, 0.70945651, 0.        ]), 'precision_mean': 0.6248317860606145, 'recall': array([0.93057143, 0.65542857, 0.674     , 0.299     , 0.66542857,
       0.35228571, 0.73542857, 0.33028571, 0.69557143, 0.        ]), 'recall_mean': 0.5338, 'predicted_class_distribution': array([26474,  5635,  7331,  2889,  9220, 

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 23.49it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0322, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:55<00:00, 724.56it/s]

Loss (orig, final): 0.17104390263557434 0.03218160569667816
L2 norm of weight change: 0.479002445936203
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:13<00:00, 19.97it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6706, 5172, 5145, 4163, 4005, 2602, 5636, 5274,    0, 4556]), 'TN': array([53458, 61461, 59539, 58941, 60619, 62226, 61108, 61351, 63000,
       61556]), 'FPs': array([9542, 1539, 3461, 4059, 2381,  774, 1892, 1649,    0, 1444]), 'FNs': array([ 294, 1828, 1855, 2837, 2995, 4398, 1364, 1726, 7000, 2444]), 'accuracy': 0.6179857142857142, 'per_class_accuracy': array([0.85948571, 0.9519    , 0.92405714, 0.90148571, 0.9232    ,
       0.92611429, 0.95348571, 0.95178571, 0.9       , 0.94445714]), 'per_class_accuracy_mean': 0.9235971428571428, 'precision': array([0.41272772, 0.77067501, 0.59783872, 0.5063245 , 0.62715315,
       0.7707346 , 0.74867163, 0.76180846, 0.        , 0.75933333]), 'precision_mean': 0.5955267112409453, 'recall': array([0.958     , 0.73885714, 0.735     , 0.59471429, 0.57214286,
       0.37171429, 0.80514286, 0.75342857, 0.        , 0.65085714]), 'recall_mean': 0.6179857142857144, 'predicted_class_distribution': array([16248,  6711,  86

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:12<00:00, 22.44it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0086, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:50<00:00, 795.71it/s]

Loss (orig, final): 0.14646518230438232 0.008632240816950798
L2 norm of weight change: 0.5949752926826477
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:38<00:00,  7.18it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6618, 5410,    0, 3287, 3616, 2653, 4800, 4863, 4301, 4452]), 'TN': array([43767, 60857, 63000, 60966, 61250, 62219, 62335, 61971, 62122,
       61513]), 'FPs': array([19233,  2143,     0,  2034,  1750,   781,   665,  1029,   878,
        1487]), 'FNs': array([ 382, 1590, 7000, 3713, 3384, 4347, 2200, 2137, 2699, 2548]), 'accuracy': 0.5714285714285714, 'per_class_accuracy': array([0.71978571, 0.94667143, 0.9       , 0.9179    , 0.92665714,
       0.92674286, 0.95907143, 0.95477143, 0.9489    , 0.94235714]), 'per_class_accuracy_mean': 0.9142857142857143, 'precision': array([0.25600557, 0.71627168, 0.        , 0.61774103, 0.67387253,
       0.77256843, 0.87831656, 0.82535642, 0.8304692 , 0.74962115]), 'precision_mean': 0.6320222566992866, 'recall': array([0.94542857, 0.77285714, 0.        , 0.46957143, 0.51657143,
       0.379     , 0.68571429, 0.69471429, 0.61442857, 0.636     ]), 'recall_mean': 0.5714285714285714, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:20<00:00, 13.59it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0049, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:52<00:00, 757.17it/s]

Loss (orig, final): 0.11085586249828339 0.004902631044387817
L2 norm of weight change: 0.5004245042800903
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:12<00:00, 21.86it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6812, 5296, 5186, 3719, 4037, 2646, 5531, 5137,    0, 3894]), 'TN': array([50892, 61207, 59133, 60287, 60489, 62204, 61306, 61541, 63000,
       62199]), 'FPs': array([12108,  1793,  3867,  2713,  2511,   796,  1694,  1459,     0,
         801]), 'FNs': array([ 188, 1704, 1814, 3281, 2963, 4354, 1469, 1863, 7000, 3106]), 'accuracy': 0.6036857142857143, 'per_class_accuracy': array([0.82434286, 0.95004286, 0.91884286, 0.91437143, 0.9218    ,
       0.92642857, 0.95481429, 0.95254286, 0.9       , 0.94418571]), 'per_class_accuracy_mean': 0.9207371428571429, 'precision': array([0.36004228, 0.74707293, 0.57284878, 0.57820274, 0.61652413,
       0.76873911, 0.76553633, 0.77880534, 0.        , 0.82939297]), 'precision_mean': 0.6017164603588185, 'recall': array([0.97314286, 0.75657143, 0.74085714, 0.53128571, 0.57671429,
       0.378     , 0.79014286, 0.73385714, 0.        , 0.55628571]), 'recall_mean': 0.6036857142857143, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.29it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0041, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:58<00:00, 685.10it/s]

Loss (orig, final): 0.11324313282966614 0.004123925697058439
L2 norm of weight change: 0.5105297565460205
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.44it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6807, 5191, 5099, 3832, 3854, 2501, 5566, 5311,    0, 4290]), 'TN': array([51158, 61371, 59467, 60020, 60862, 62353, 61286, 61093, 63000,
       61841]), 'FPs': array([11842,  1629,  3533,  2980,  2138,   647,  1714,  1907,     0,
        1159]), 'FNs': array([ 193, 1809, 1901, 3168, 3146, 4499, 1434, 1689, 7000, 2710]), 'accuracy': 0.6064428571428572, 'per_class_accuracy': array([0.82807143, 0.95088571, 0.92237143, 0.91217143, 0.92451429,
       0.92648571, 0.95502857, 0.94862857, 0.9       , 0.94472857]), 'per_class_accuracy_mean': 0.9212885714285715, 'precision': array([0.36500617, 0.7611437 , 0.59070899, 0.5625367 , 0.64319092,
       0.79447268, 0.76456044, 0.73579939, 0.        , 0.78730042]), 'precision_mean': 0.6004719405676546, 'recall': array([0.97242857, 0.74157143, 0.72842857, 0.54742857, 0.55057143,
       0.35728571, 0.79514286, 0.75871429, 0.        , 0.61285714]), 'recall_mean': 0.6064428571428572, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.40it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0051, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 619.50it/s]

Loss (orig, final): 0.15070386230945587 0.005143491551280022
L2 norm of weight change: 0.6447407603263855
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 23.14it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6822, 5142, 5040, 3888, 3705, 2312, 5598, 5395,    0, 4246]), 'TN': array([50800, 61407, 59660, 59862, 61116, 62508, 61196, 60703, 63000,
       61896]), 'FPs': array([12200,  1593,  3340,  3138,  1884,   492,  1804,  2297,     0,
        1104]), 'FNs': array([ 178, 1858, 1960, 3112, 3295, 4688, 1402, 1605, 7000, 2754]), 'accuracy': 0.6021142857142857, 'per_class_accuracy': array([0.82317143, 0.9507    , 0.92428571, 0.91071429, 0.92601429,
       0.926     , 0.9542    , 0.94425714, 0.9       , 0.94488571]), 'per_class_accuracy_mean': 0.9204228571428572, 'precision': array([0.35863737, 0.76347439, 0.60143198, 0.55337319, 0.66290929,
       0.82453638, 0.75628209, 0.70137806, 0.        , 0.79364486]), 'precision_mean': 0.6015667584565978, 'recall': array([0.97457143, 0.73457143, 0.72      , 0.55542857, 0.52928571,
       0.33028571, 0.79971429, 0.77071429, 0.        , 0.60657143]), 'recall_mean': 0.6021142857142856, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 23.40it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0034, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:01<00:00, 650.46it/s]

Loss (orig, final): 0.11627104878425598 0.0034460576716810465
L2 norm of weight change: 0.4142155349254608
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.54it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6785, 5212, 5120, 3858, 3942, 2530, 5587, 5248,    0, 4436]), 'TN': array([51812, 61354, 59438, 59916, 60705, 62325, 61227, 61295, 63000,
       61646]), 'FPs': array([11188,  1646,  3562,  3084,  2295,   675,  1773,  1705,     0,
        1354]), 'FNs': array([ 215, 1788, 1880, 3142, 3058, 4470, 1413, 1752, 7000, 2564]), 'accuracy': 0.6102571428571428, 'per_class_accuracy': array([0.8371    , 0.95094286, 0.92225714, 0.91105714, 0.92352857,
       0.9265    , 0.95448571, 0.95061429, 0.9       , 0.94402857]), 'per_class_accuracy_mean': 0.9220514285714285, 'precision': array([0.37751071, 0.75998833, 0.58972587, 0.55574762, 0.63203463,
       0.78939158, 0.75910326, 0.75478211, 0.        , 0.76614853]), 'precision_mean': 0.5984432647041608, 'recall': array([0.96928571, 0.74457143, 0.73142857, 0.55114286, 0.56314286,
       0.36142857, 0.79814286, 0.74971429, 0.        , 0.63371429]), 'recall_mean': 0.6102571428571428, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.96it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0061, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:02<00:00, 637.78it/s]

Loss (orig, final): 0.10212555527687073 0.006100296042859554
L2 norm of weight change: 0.37425941228866577
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.06it/s]


Post-edit metrics: {'TP': array([6479, 5465, 4532, 3022, 4291, 2828, 5557, 4977, 4633, 4373]), 'TN': array([54471, 60859, 60875, 61414, 59957, 62001, 61351, 61859, 61761,
       61609]), 'FPs': array([8529, 2141, 2125, 1586, 3043,  999, 1649, 1141, 1239, 1391]), 'FNs': array([ 521, 1535, 2468, 3978, 2709, 4172, 1443, 2023, 2367, 2627]), 'accuracy': 0.6593857142857142, 'per_class_accuracy': array([0.87071429, 0.94748571, 0.93438571, 0.92051429, 0.91782857,
       0.92612857, 0.95582857, 0.9548    , 0.94848571, 0.9426    ]), 'per_class_accuracy_mean': 0.9318771428571428, 'precision': array([0.43170309, 0.7185117 , 0.68078714, 0.65581597, 0.58508317,
       0.73896002, 0.77116292, 0.81350114, 0.78899864, 0.75867453]), 'precision_mean': 0.6943198334842366, 'recall': array([0.92557143, 0.78071429, 0.64742857, 0.43171429, 0.613     ,
       0.404     , 0.79385714, 0.711     , 0.66185714, 0.62471429]), 'recall_mean': 0.6593857142857142, 'predicted_class_distribution': array([15008,  7606,  66

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.10it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0076, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:03<00:00, 630.85it/s]

Loss (orig, final): 0.08840510994195938 0.00761287659406662
L2 norm of weight change: 0.410312682390213
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.37it/s]


Post-edit metrics: {'TP': array([6340, 5556, 5292,  217, 4316, 2773, 5587, 5016, 4751, 4313]), 'TN': array([53999, 60643, 58646, 62789, 59843, 62093, 61063, 61834, 61574,
       61677]), 'FPs': array([9001, 2357, 4354,  211, 3157,  907, 1937, 1166, 1426, 1323]), 'FNs': array([ 660, 1444, 1708, 6783, 2684, 4227, 1413, 1984, 2249, 2687]), 'accuracy': 0.6308714285714285, 'per_class_accuracy': array([0.86198571, 0.9457    , 0.9134    , 0.90008571, 0.91655714,
       0.92665714, 0.95214286, 0.955     , 0.9475    , 0.94271429]), 'per_class_accuracy_mean': 0.9261742857142858, 'precision': array([0.41327163, 0.70213573, 0.54862119, 0.50700935, 0.57754583,
       0.75353261, 0.74255715, 0.8113879 , 0.7691436 , 0.76525905]), 'precision_mean': 0.659046402435296, 'recall': array([0.90571429, 0.79371429, 0.756     , 0.031     , 0.61657143,
       0.39614286, 0.79814286, 0.71657143, 0.67871429, 0.61614286]), 'recall_mean': 0.6308714285714285, 'predicted_class_distribution': array([15341,  7913,  964

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.72it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0128, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:00<00:00, 655.93it/s]

Loss (orig, final): 0.14946970343589783 0.012786869890987873
L2 norm of weight change: 0.5635129809379578
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 23.43it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6512, 5329, 5529, 3164, 4239, 2995, 5293, 4926, 4204,    0]), 'TN': array([51508, 61181, 57385, 61307, 59993, 61842, 61774, 61944, 62257,
       63000]), 'FPs': array([11492,  1819,  5615,  1693,  3007,  1158,  1226,  1056,   743,
           0]), 'FNs': array([ 488, 1671, 1471, 3836, 2761, 4005, 1707, 2074, 2796, 7000]), 'accuracy': 0.6027285714285714, 'per_class_accuracy': array([0.82885714, 0.95014286, 0.89877143, 0.92101429, 0.9176    ,
       0.92624286, 0.9581    , 0.95528571, 0.94944286, 0.9       ]), 'per_class_accuracy_mean': 0.9205457142857144, 'precision': array([0.3616974 , 0.74552322, 0.49614142, 0.65143092, 0.58501242,
       0.72116542, 0.81193435, 0.82347041, 0.84980796, 0.        ]), 'precision_mean': 0.6046183534335612, 'recall': array([0.93028571, 0.76128571, 0.78985714, 0.452     , 0.60557143,
       0.42785714, 0.75614286, 0.70371429, 0.60057143, 0.        ]), 'recall_mean': 0.6027285714285714, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 23.98it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0034, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:54<00:00, 738.23it/s]

Loss (orig, final): 0.05999001860618591 0.003442182671278715
L2 norm of weight change: 0.24511569738388062
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:15<00:00, 17.36it/s]


Post-edit metrics: {'TP': array([6352, 5456, 5260, 3478, 4254, 2875, 5575, 5125, 4632, 4330]), 'TN': array([58355, 60954, 59034, 60685, 60039, 61932, 61225, 61652, 61780,
       61681]), 'FPs': array([4645, 2046, 3966, 2315, 2961, 1068, 1775, 1348, 1220, 1319]), 'FNs': array([ 648, 1544, 1740, 3522, 2746, 4125, 1425, 1875, 2368, 2670]), 'accuracy': 0.6762428571428571, 'per_class_accuracy': array([0.92438571, 0.94871429, 0.91848571, 0.91661429, 0.91847143,
       0.92581429, 0.95428571, 0.95395714, 0.94874286, 0.94301429]), 'per_class_accuracy_mean': 0.9352485714285714, 'precision': array([0.57761208, 0.72727273, 0.5701279 , 0.60037977, 0.58960499,
       0.72914025, 0.7585034 , 0.79175035, 0.79152427, 0.76650735]), 'precision_mean': 0.6902423070140943, 'recall': array([0.90742857, 0.77942857, 0.75142857, 0.49685714, 0.60771429,
       0.41071429, 0.79642857, 0.73214286, 0.66171429, 0.61857143]), 'recall_mean': 0.6762428571428571, 'predicted_class_distribution': array([10997,  7502,  92

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.02it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0032, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:59<00:00, 674.77it/s]

Loss (orig, final): 0.14752653241157532 0.003241928294301033
L2 norm of weight change: 0.710345983505249
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.37it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6685, 5307,    0, 2829, 3832, 2460, 5039, 4556, 4050, 4351]), 'TN': array([41862, 61074, 63000, 61562, 60819, 62423, 62083, 62233, 62330,
       61723]), 'FPs': array([21138,  1926,     0,  1438,  2181,   577,   917,   767,   670,
        1277]), 'FNs': array([ 315, 1693, 7000, 4171, 3168, 4540, 1961, 2444, 2950, 2649]), 'accuracy': 0.5587, 'per_class_accuracy': array([0.69352857, 0.9483    , 0.9       , 0.91987143, 0.92358571,
       0.9269    , 0.95888571, 0.95412857, 0.94828571, 0.94391429]), 'per_class_accuracy_mean': 0.9117399999999998, 'precision': array([0.24026884, 0.73372045, 0.        , 0.66299508, 0.63728588,
       0.81000988, 0.84603761, 0.85590832, 0.85805085, 0.77309879]), 'precision_mean': 0.6417375698263776, 'recall': array([0.955     , 0.75814286, 0.        , 0.40414286, 0.54742857,
       0.35142857, 0.71985714, 0.65085714, 0.57857143, 0.62157143]), 'recall_mean': 0.5586999999999999, 'predicted_class_distribution': array([27823,  7233

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.27it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0090, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 622.58it/s]

Loss (orig, final): 0.08622988313436508 0.009018299169838428
L2 norm of weight change: 0.31286725401878357
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.87it/s]


Post-edit metrics: {'TP': array([6317, 5530, 4328, 3898, 4241, 3094, 5734, 5025, 4886, 4530]), 'TN': array([58452, 60678, 61328, 59839, 60156, 61563, 60983, 61879, 61421,
       61284]), 'FPs': array([4548, 2322, 1672, 3161, 2844, 1437, 2017, 1121, 1579, 1716]), 'FNs': array([ 683, 1470, 2672, 3102, 2759, 3906, 1266, 1975, 2114, 2470]), 'accuracy': 0.6797571428571428, 'per_class_accuracy': array([0.92527143, 0.94582857, 0.93794286, 0.91052857, 0.91995714,
       0.92367143, 0.9531    , 0.95577143, 0.94724286, 0.9402    ]), 'per_class_accuracy_mean': 0.9359514285714287, 'precision': array([0.58140819, 0.70427916, 0.72133333, 0.55220286, 0.59858857,
       0.68285147, 0.73977551, 0.81760495, 0.75576179, 0.72526417]), 'precision_mean': 0.6879070008464693, 'recall': array([0.90242857, 0.79      , 0.61828571, 0.55685714, 0.60585714,
       0.442     , 0.81914286, 0.71785714, 0.698     , 0.64714286]), 'recall_mean': 0.6797571428571428, 'predicted_class_distribution': array([10865,  7852,  60

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.89it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0057, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 622.87it/s]

Loss (orig, final): 0.19306322932243347 0.005747493822127581
L2 norm of weight change: 0.7187454104423523
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.86it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6832, 5101, 5036, 3887, 3603, 2292, 5569, 5386,    0, 4118]), 'TN': array([50021, 61461, 59676, 59850, 61234, 62517, 61287, 60741, 63000,
       62037]), 'FPs': array([12979,  1539,  3324,  3150,  1766,   483,  1713,  2259,     0,
         963]), 'FNs': array([ 168, 1899, 1964, 3113, 3397, 4708, 1431, 1614, 7000, 2882]), 'accuracy': 0.5974857142857143, 'per_class_accuracy': array([0.81218571, 0.95088571, 0.92445714, 0.91052857, 0.92624286,
       0.92584286, 0.95508571, 0.94467143, 0.9       , 0.94507143]), 'per_class_accuracy_mean': 0.9194971428571428, 'precision': array([0.34485892, 0.76822289, 0.60239234, 0.55236607, 0.67107469,
       0.82594595, 0.76476243, 0.70451275, 0.        , 0.81047038]), 'precision_mean': 0.6044606413066151, 'recall': array([0.976     , 0.72871429, 0.71942857, 0.55528571, 0.51471429,
       0.32742857, 0.79557143, 0.76942857, 0.        , 0.58828571]), 'recall_mean': 0.5974857142857143, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.57it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0045, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:05<00:00, 615.29it/s]

Loss (orig, final): 0.13582585752010345 0.0045173075050115585
L2 norm of weight change: 0.41782617568969727
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.02it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6758, 5276, 5151, 3881, 3948, 2578, 5592, 5303,    0, 4440]), 'TN': array([52365, 61280, 59412, 59878, 60702, 62278, 61194, 61168, 63000,
       61650]), 'FPs': array([10635,  1720,  3588,  3122,  2298,   722,  1806,  1832,     0,
        1350]), 'FNs': array([ 242, 1724, 1849, 3119, 3052, 4422, 1408, 1697, 7000, 2560]), 'accuracy': 0.6132428571428571, 'per_class_accuracy': array([0.84461429, 0.9508    , 0.92232857, 0.91084286, 0.92357143,
       0.92651429, 0.95408571, 0.94958571, 0.9       , 0.94414286]), 'per_class_accuracy_mean': 0.9226485714285715, 'precision': array([0.38854712, 0.75414523, 0.58942671, 0.55419106, 0.63208453,
       0.78121212, 0.75587997, 0.74323756, 0.        , 0.76683938]), 'precision_mean': 0.5965563673763182, 'recall': array([0.96542857, 0.75371429, 0.73585714, 0.55442857, 0.564     ,
       0.36828571, 0.79885714, 0.75757143, 0.        , 0.63428571]), 'recall_mean': 0.6132428571428572, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.41it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0422, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:55<00:00, 717.33it/s]

Loss (orig, final): 0.17757049202919006 0.042221799492836
L2 norm of weight change: 0.49503347277641296
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:33<00:00,  8.09it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6684, 5303, 5127, 4047, 3881, 2690, 5704, 5378,    0, 4623]), 'TN': array([53918, 61243, 59594, 59280, 60870, 62141, 60896, 61126, 63000,
       61369]), 'FPs': array([9082, 1757, 3406, 3720, 2130,  859, 2104, 1874,    0, 1631]), 'FNs': array([ 316, 1697, 1873, 2953, 3119, 4310, 1296, 1622, 7000, 2377]), 'accuracy': 0.6205285714285714, 'per_class_accuracy': array([0.86574286, 0.95065714, 0.92458571, 0.90467143, 0.92501429,
       0.92615714, 0.95142857, 0.95005714, 0.9       , 0.94274286]), 'per_class_accuracy_mean': 0.9241057142857143, 'precision': array([0.42395027, 0.75113314, 0.60084378, 0.5210506 , 0.64564964,
       0.75795999, 0.73053279, 0.74158853, 0.        , 0.73920691]), 'precision_mean': 0.5911915651680586, 'recall': array([0.95485714, 0.75757143, 0.73242857, 0.57814286, 0.55442857,
       0.38428571, 0.81485714, 0.76828571, 0.        , 0.66042857]), 'recall_mean': 0.6205285714285714, 'predicted_class_distribution': array([15766,  7060,  85

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:15<00:00, 18.05it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0347, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:51<00:00, 771.10it/s]

Loss (orig, final): 0.18862126767635345 0.03471451997756958
L2 norm of weight change: 0.5382944345474243
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:18<00:00, 14.42it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6780, 5243, 5142, 3872, 3977, 2611, 5601, 5252,    0, 4248]), 'TN': array([51719, 61341, 59439, 59870, 60651, 62248, 61163, 61373, 63000,
       61922]), 'FPs': array([11281,  1659,  3561,  3130,  2349,   752,  1837,  1627,     0,
        1078]), 'FNs': array([ 220, 1757, 1858, 3128, 3023, 4389, 1399, 1748, 7000, 2752]), 'accuracy': 0.6103714285714286, 'per_class_accuracy': array([0.8357    , 0.9512    , 0.92258571, 0.9106    , 0.92325714,
       0.92655714, 0.95377143, 0.95178571, 0.9       , 0.94528571]), 'per_class_accuracy_mean': 0.922074285714286, 'precision': array([0.3753945 , 0.75963489, 0.59083075, 0.55298486, 0.62867531,
       0.77639013, 0.75302501, 0.76348306, 0.        , 0.7975967 ]), 'precision_mean': 0.5998015197044052, 'recall': array([0.96857143, 0.749     , 0.73457143, 0.55314286, 0.56814286,
       0.373     , 0.80014286, 0.75028571, 0.        , 0.60685714]), 'recall_mean': 0.6103714285714286, 'predicted_class_distribution': array([1

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:37<00:00,  7.34it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0233, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:48<00:00, 818.39it/s]

Loss (orig, final): 0.2057744562625885 0.023343782871961594
L2 norm of weight change: 0.6759790778160095
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:22<00:00, 12.09it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6534, 5512, 5158,    0, 4016, 2408, 5282, 5088, 4228, 4086]), 'TN': array([48721, 60649, 59267, 63000, 60553, 62444, 61829, 61634, 62216,
       61999]), 'FPs': array([14279,  2351,  3733,     0,  2447,   556,  1171,  1366,   784,
        1001]), 'FNs': array([ 466, 1488, 1842, 7000, 2984, 4592, 1718, 1912, 2772, 2914]), 'accuracy': 0.6044571428571428, 'per_class_accuracy': array([0.78935714, 0.94515714, 0.92035714, 0.9       , 0.92241429,
       0.92645714, 0.95872857, 0.95317143, 0.9492    , 0.94407143]), 'per_class_accuracy_mean': 0.9208914285714285, 'precision': array([0.3139384 , 0.70100471, 0.58013722, 0.        , 0.62138326,
       0.81241565, 0.81853402, 0.78834831, 0.84357542, 0.8032239 ]), 'precision_mean': 0.6282560889320973, 'recall': array([0.93342857, 0.78742857, 0.73685714, 0.        , 0.57371429,
       0.344     , 0.75457143, 0.72685714, 0.604     , 0.58371429]), 'recall_mean': 0.6044571428571428, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:18<00:00, 14.84it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0241, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:54<00:00, 738.74it/s]


Loss (orig, final): 0.18598546087741852 0.024084679782390594
L2 norm of weight change: 0.5865170955657959
Performing post-edit metric & KNN calculations on validation set.


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.28it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6452, 5493, 5150,    0, 4131, 2506, 5326, 5013, 4563, 4209]), 'TN': array([49899, 60705, 59289, 63000, 60327, 62376, 61747, 61801, 61837,
       61862]), 'FPs': array([13101,  2295,  3711,     0,  2673,   624,  1253,  1199,  1163,
        1138]), 'FNs': array([ 548, 1507, 1850, 7000, 2869, 4494, 1674, 1987, 2437, 2791]), 'accuracy': 0.6120428571428571, 'per_class_accuracy': array([0.80501429, 0.94568571, 0.92055714, 0.9       , 0.92082857,
       0.92688571, 0.95818571, 0.95448571, 0.94857143, 0.94387143]), 'per_class_accuracy_mean': 0.9224085714285716, 'precision': array([0.32997494, 0.70531587, 0.58119851, 0.        , 0.60714286,
       0.80063898, 0.80954552, 0.80698648, 0.79689137, 0.78717038]), 'precision_mean': 0.6224864905600279, 'recall': array([0.92171429, 0.78471429, 0.73571429, 0.        , 0.59014286,
       0.358     , 0.76085714, 0.71614286, 0.65185714, 0.60128571]), 'recall_mean': 0.6120428571428571, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:12<00:00, 21.59it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0340, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:57<00:00, 694.38it/s]

Loss (orig, final): 0.20154443383216858 0.03395078331232071
L2 norm of weight change: 0.6222022175788879
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.47it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6778, 5255, 5077, 4051, 3778, 2547, 5647, 5338,    0, 4327]), 'TN': array([52026, 61320, 59683, 59413, 61017, 62297, 61059, 61163, 63000,
       61820]), 'FPs': array([10974,  1680,  3317,  3587,  1983,   703,  1941,  1837,     0,
        1180]), 'FNs': array([ 222, 1745, 1923, 2949, 3222, 4453, 1353, 1662, 7000, 2673]), 'accuracy': 0.6114, 'per_class_accuracy': array([0.84005714, 0.95107143, 0.92514286, 0.90662857, 0.92564286,
       0.92634286, 0.95294286, 0.95001429, 0.9       , 0.94495714]), 'per_class_accuracy_mean': 0.9222800000000001, 'precision': array([0.38181613, 0.75775054, 0.60483679, 0.53037444, 0.65578893,
       0.78369231, 0.74420137, 0.74397213, 0.        , 0.78572726]), 'precision_mean': 0.5988159891348572, 'recall': array([0.96828571, 0.75071429, 0.72528571, 0.57871429, 0.53971429,
       0.36385714, 0.80671429, 0.76257143, 0.        , 0.61814286]), 'recall_mean': 0.6113999999999999, 'predicted_class_distribution': array([17752,  6935

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.82it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0261, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:00<00:00, 656.82it/s]

Loss (orig, final): 0.16886310279369354 0.026098787784576416
L2 norm of weight change: 0.4645383954048157
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.97it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6744, 5277, 5198, 3915, 3926, 2590, 5637, 5295,    0, 4434]), 'TN': array([52730, 61289, 59259, 59753, 60757, 62258, 61060, 61236, 63000,
       61674]), 'FPs': array([10270,  1711,  3741,  3247,  2243,   742,  1940,  1764,     0,
        1326]), 'FNs': array([ 256, 1723, 1802, 3085, 3074, 4410, 1363, 1705, 7000, 2566]), 'accuracy': 0.6145142857142857, 'per_class_accuracy': array([0.84962857, 0.95094286, 0.92081429, 0.90954286, 0.92404286,
       0.9264    , 0.95281429, 0.95044286, 0.9       , 0.9444    ]), 'per_class_accuracy_mean': 0.9229028571428571, 'precision': array([0.39637945, 0.75515169, 0.58149681, 0.54663502, 0.63640785,
       0.77731092, 0.74396199, 0.75010625, 0.        , 0.76979167]), 'precision_mean': 0.5957241644993972, 'recall': array([0.96342857, 0.75385714, 0.74257143, 0.55928571, 0.56085714,
       0.37      , 0.80528571, 0.75642857, 0.        , 0.63342857]), 'recall_mean': 0.6145142857142857, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.01it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0055, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:53<00:00, 747.24it/s]

Loss (orig, final): 0.09612706303596497 0.005478637758642435
L2 norm of weight change: 0.3454355001449585
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:16<00:00, 16.62it/s]


Post-edit metrics: {'TP': array([6494, 5504, 5142, 3772, 3931, 2830, 5741, 5435, 3596, 4568]), 'TN': array([58268, 60694, 59556, 60051, 60814, 62000, 60767, 60957, 62612,
       61294]), 'FPs': array([4732, 2306, 3444, 2949, 2186, 1000, 2233, 2043,  388, 1706]), 'FNs': array([ 506, 1496, 1858, 3228, 3069, 4170, 1259, 1565, 3404, 2432]), 'accuracy': 0.6716142857142857, 'per_class_accuracy': array([0.92517143, 0.94568571, 0.92425714, 0.91175714, 0.92492857,
       0.92614286, 0.95011429, 0.94845714, 0.94582857, 0.94088571]), 'per_class_accuracy_mean': 0.9343228571428572, 'precision': array([0.57847853, 0.70473752, 0.5988819 , 0.56122601, 0.64263528,
       0.73890339, 0.71996489, 0.72679861, 0.90261044, 0.72808416]), 'precision_mean': 0.6902320723512676, 'recall': array([0.92771429, 0.78628571, 0.73457143, 0.53885714, 0.56157143,
       0.40428571, 0.82014286, 0.77642857, 0.51371429, 0.65257143]), 'recall_mean': 0.6716142857142857, 'predicted_class_distribution': array([11226,  7810,  85

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.45it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0040, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:00<00:00, 661.82it/s]

Loss (orig, final): 0.08053191006183624 0.003979570232331753
L2 norm of weight change: 0.3407222628593445
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.50it/s]


Post-edit metrics: {'TP': array([6418, 5070, 4365, 4417, 4292, 2689, 5395, 5022, 4740, 4786]), 'TN': array([57499, 61544, 61262, 58519, 59968, 62130, 61706, 61847, 61690,
       61029]), 'FPs': array([5501, 1456, 1738, 4481, 3032,  870, 1294, 1153, 1310, 1971]), 'FNs': array([ 582, 1930, 2635, 2583, 2708, 4311, 1605, 1978, 2260, 2214]), 'accuracy': 0.6742, 'per_class_accuracy': array([0.9131    , 0.95162857, 0.93752857, 0.89908571, 0.918     ,
       0.92598571, 0.95858571, 0.95527143, 0.949     , 0.94021429]), 'per_class_accuracy_mean': 0.93484, 'precision': array([0.53846799, 0.77689243, 0.71522202, 0.49640369, 0.58601857,
       0.75554931, 0.80654806, 0.81327935, 0.78347107, 0.7083025 ]), 'precision_mean': 0.6980155003132962, 'recall': array([0.91685714, 0.72428571, 0.62357143, 0.631     , 0.61314286,
       0.38414286, 0.77071429, 0.71742857, 0.67714286, 0.68371429]), 'recall_mean': 0.6741999999999999, 'predicted_class_distribution': array([11919,  6526,  6103,  8898,  7324,  3559

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.82it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0044, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:03<00:00, 626.94it/s]

Loss (orig, final): 0.08140794187784195 0.00444080401211977
L2 norm of weight change: 0.3846830129623413
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 23.13it/s]


Post-edit metrics: {'TP': array([6593, 5311, 4075, 3852, 4073, 2621, 5692, 5077, 4151, 4627]), 'TN': array([54450, 61156, 61586, 59951, 60478, 62218, 61036, 61686, 62259,
       61252]), 'FPs': array([8550, 1844, 1414, 3049, 2522,  782, 1964, 1314,  741, 1748]), 'FNs': array([ 407, 1689, 2925, 3148, 2927, 4379, 1308, 1923, 2849, 2373]), 'accuracy': 0.6581714285714285, 'per_class_accuracy': array([0.87204286, 0.94952857, 0.93801429, 0.91147143, 0.92215714,
       0.92627143, 0.95325714, 0.95375714, 0.94871429, 0.94112857]), 'per_class_accuracy_mean': 0.9316342857142856, 'precision': array([0.43538269, 0.74227813, 0.74239388, 0.55817997, 0.61758908,
       0.77020276, 0.74346917, 0.79439837, 0.84852821, 0.72580392]), 'precision_mean': 0.6978226187848244, 'recall': array([0.94185714, 0.75871429, 0.58214286, 0.55028571, 0.58185714,
       0.37442857, 0.81314286, 0.72528571, 0.593     , 0.661     ]), 'recall_mean': 0.6581714285714286, 'predicted_class_distribution': array([15143,  7155,  54

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:21<00:00, 12.71it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0042, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:46<00:00, 855.02it/s]

Loss (orig, final): 0.10257057845592499 0.004155976697802544
L2 norm of weight change: 0.44259801506996155
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:31<00:00,  8.61it/s]


Post-edit metrics: {'TP': array([6500, 5297,    2, 4114, 4294, 2770, 5588, 5037, 4581, 4713]), 'TN': array([51257, 61192, 62998, 59383, 59986, 62033, 61332, 61769, 61889,
       61057]), 'FPs': array([11743,  1808,     2,  3617,  3014,   967,  1668,  1231,  1111,
        1943]), 'FNs': array([ 500, 1703, 6998, 2886, 2706, 4230, 1412, 1963, 2419, 2287]), 'accuracy': 0.6128, 'per_class_accuracy': array([0.8251    , 0.94984286, 0.9       , 0.9071    , 0.91828571,
       0.92575714, 0.956     , 0.95437143, 0.94957143, 0.93957143]), 'per_class_accuracy_mean': 0.9225599999999998, 'precision': array([0.35630105, 0.74553132, 0.5       , 0.53214332, 0.58757526,
       0.74123629, 0.77012128, 0.80360562, 0.80481377, 0.70808293]), 'precision_mean': 0.6549410829014076, 'recall': array([9.28571429e-01, 7.56714286e-01, 2.85714286e-04, 5.87714286e-01,
       6.13428571e-01, 3.95714286e-01, 7.98285714e-01, 7.19571429e-01,
       6.54428571e-01, 6.73285714e-01]), 'recall_mean': 0.6128, 'predicted_class

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:15<00:00, 17.63it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0030, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:02<00:00, 642.21it/s]

Loss (orig, final): 0.1092892438173294 0.002997185569256544
L2 norm of weight change: 0.47515708208084106
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.05it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6517, 5289,    0, 4118, 4284, 2750, 5549, 5038, 4547, 4707]), 'TN': array([50915, 61217, 63000, 59414, 60008, 62064, 61433, 61766, 61927,
       61055]), 'FPs': array([12085,  1783,     0,  3586,  2992,   936,  1567,  1234,  1073,
        1945]), 'FNs': array([ 483, 1711, 7000, 2882, 2716, 4250, 1451, 1962, 2453, 2293]), 'accuracy': 0.6114142857142857, 'per_class_accuracy': array([0.82045714, 0.95008571, 0.9       , 0.9076    , 0.91845714,
       0.92591429, 0.95688571, 0.95434286, 0.94962857, 0.93945714]), 'per_class_accuracy_mean': 0.922282857142857, 'precision': array([0.35033867, 0.74787896, 0.        , 0.53452752, 0.58878505,
       0.7460662 , 0.77979202, 0.80325255, 0.80907473, 0.70760673]), 'precision_mean': 0.6067322430777915, 'recall': array([0.931     , 0.75557143, 0.        , 0.58828571, 0.612     ,
       0.39285714, 0.79271429, 0.71971429, 0.64957143, 0.67242857]), 'recall_mean': 0.6114142857142857, 'predicted_class_distribution': array([1

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.50it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0043, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 620.17it/s]

Loss (orig, final): 0.1294988989830017 0.004278149455785751
L2 norm of weight change: 0.5764836668968201
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:12<00:00, 22.83it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6656, 5213,    0, 3957, 4098, 2515, 5453, 4972, 4025, 4660]), 'TN': array([47867, 61317, 63000, 59734, 60378, 62330, 61556, 61796, 62371,
       61200]), 'FPs': array([15133,  1683,     0,  3266,  2622,   670,  1444,  1204,   629,
        1800]), 'FNs': array([ 344, 1787, 7000, 3043, 2902, 4485, 1547, 2028, 2975, 2340]), 'accuracy': 0.5935571428571429, 'per_class_accuracy': array([0.7789    , 0.95042857, 0.9       , 0.90987143, 0.92108571,
       0.92635714, 0.95727143, 0.95382857, 0.94851429, 0.94085714]), 'per_class_accuracy_mean': 0.9187114285714285, 'precision': array([0.30547524, 0.75594548, 0.        , 0.54783331, 0.60982143,
       0.78963893, 0.79063361, 0.80505181, 0.86484744, 0.72136223]), 'precision_mean': 0.619060948121318, 'recall': array([0.95085714, 0.74471429, 0.        , 0.56528571, 0.58542857,
       0.35928571, 0.779     , 0.71028571, 0.575     , 0.66571429]), 'recall_mean': 0.5935571428571429, 'predicted_class_distribution': array([2

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.95it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0049, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 616.74it/s]

Loss (orig, final): 0.1268656700849533 0.0049268221482634544
L2 norm of weight change: 0.5678382515907288
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.68it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6646, 5218,    0, 3989, 4134, 2539, 5480, 4985, 4050, 4670]), 'TN': array([48271, 61305, 63000, 59685, 60329, 62316, 61513, 61773, 62342,
       61177]), 'FPs': array([14729,  1695,     0,  3315,  2671,   684,  1487,  1227,   658,
        1823]), 'FNs': array([ 354, 1782, 7000, 3011, 2866, 4461, 1520, 2015, 2950, 2330]), 'accuracy': 0.5958714285714286, 'per_class_accuracy': array([0.78452857, 0.95032857, 0.9       , 0.90962857, 0.9209    ,
       0.9265    , 0.95704286, 0.95368571, 0.94845714, 0.94067143]), 'per_class_accuracy_mean': 0.9191742857142857, 'precision': array([0.31092398, 0.75480978, 0.        , 0.5461391 , 0.60749449,
       0.78777536, 0.78656524, 0.80247907, 0.86023789, 0.7192361 ]), 'precision_mean': 0.6175661013300783, 'recall': array([0.94942857, 0.74542857, 0.        , 0.56985714, 0.59057143,
       0.36271429, 0.78285714, 0.71214286, 0.57857143, 0.66714286]), 'recall_mean': 0.5958714285714286, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.55it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0061, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:02<00:00, 636.91it/s]

Loss (orig, final): 0.1714026927947998 0.006062868516892195
L2 norm of weight change: 0.7601000666618347
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.67it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6846, 4022, 5248, 4311, 3417, 2328, 5321, 4911,    0,    0]), 'TN': array([42739, 62268, 58951, 58807, 61500, 62484, 61757, 61898, 63000,
       63000]), 'FPs': array([20261,   732,  4049,  4193,  1500,   516,  1243,  1102,     0,
           0]), 'FNs': array([ 154, 2978, 1752, 2689, 3583, 4672, 1679, 2089, 7000, 7000]), 'accuracy': 0.5200571428571429, 'per_class_accuracy': array([0.70835714, 0.947     , 0.91712857, 0.90168571, 0.92738571,
       0.92588571, 0.95825714, 0.95441429, 0.9       , 0.9       ]), 'per_class_accuracy_mean': 0.9040114285714287, 'precision': array([0.25255469, 0.8460244 , 0.56448317, 0.50693791, 0.69493594,
       0.8185654 , 0.81063376, 0.81673042, 0.        , 0.        ]), 'precision_mean': 0.5310865684081015, 'recall': array([0.978     , 0.57457143, 0.74971429, 0.61585714, 0.48814286,
       0.33257143, 0.76014286, 0.70157143, 0.        , 0.        ]), 'recall_mean': 0.5200571428571429, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.71it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0046, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:02<00:00, 637.89it/s]

Loss (orig, final): 0.1415344923734665 0.004579402040690184
L2 norm of weight change: 0.5554224252700806
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.83it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6538, 5307,    0, 4039, 4205, 2703, 5476, 4951, 4490, 4675]), 'TN': array([49808, 61168, 63000, 59596, 60215, 62127, 61541, 61876, 61968,
       61085]), 'FPs': array([13192,  1832,     0,  3404,  2785,   873,  1459,  1124,  1032,
        1915]), 'FNs': array([ 462, 1693, 7000, 2961, 2795, 4297, 1524, 2049, 2510, 2325]), 'accuracy': 0.6054857142857143, 'per_class_accuracy': array([0.80494286, 0.94964286, 0.9       , 0.90907143, 0.92028571,
       0.92614286, 0.95738571, 0.95467143, 0.9494    , 0.93942857]), 'per_class_accuracy_mean': 0.921097142857143, 'precision': array([0.33137354, 0.74338143, 0.        , 0.54265753, 0.60157368,
       0.75587248, 0.78961788, 0.81497942, 0.81311119, 0.70940819]), 'precision_mean': 0.6101975349282764, 'recall': array([0.934     , 0.75814286, 0.        , 0.577     , 0.60071429,
       0.38614286, 0.78228571, 0.70728571, 0.64142857, 0.66785714]), 'recall_mean': 0.6054857142857143, 'predicted_class_distribution': array([1

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.96it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0033, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:03<00:00, 630.40it/s]

Loss (orig, final): 0.149329274892807 0.0032847244292497635
L2 norm of weight change: 0.5839526057243347
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.42it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6580, 5263,    0, 4023, 4155, 2626, 5391, 4912, 4399, 4663]), 'TN': array([48810, 61247, 63000, 59630, 60313, 62213, 61678, 61907, 62055,
       61159]), 'FPs': array([14190,  1753,     0,  3370,  2687,   787,  1322,  1093,   945,
        1841]), 'FNs': array([ 420, 1737, 7000, 2977, 2845, 4374, 1609, 2088, 2601, 2337]), 'accuracy': 0.6001714285714286, 'per_class_accuracy': array([0.79128571, 0.95014286, 0.9       , 0.90932857, 0.92097143,
       0.92627143, 0.95812857, 0.95455714, 0.94934286, 0.94031429]), 'per_class_accuracy_mean': 0.9200342857142857, 'precision': array([0.31680308, 0.75014253, 0.        , 0.5441634 , 0.60727857,
       0.76941108, 0.80306867, 0.81798501, 0.82316617, 0.71694342]), 'precision_mean': 0.6148961931662853, 'recall': array([0.94      , 0.75185714, 0.        , 0.57471429, 0.59357143,
       0.37514286, 0.77014286, 0.70171429, 0.62842857, 0.66614286]), 'recall_mean': 0.6001714285714286, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.86it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0048, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 623.87it/s]

Loss (orig, final): 0.10928415507078171 0.0048429155722260475
L2 norm of weight change: 0.40966880321502686
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.97it/s]


Post-edit metrics: {'TP': array([6776, 5290, 4954, 3912, 4012, 2616, 5621, 5321,  148, 4509]), 'TN': array([52532, 61241, 59996, 59840, 60560, 62243, 61155, 61083, 62999,
       61510]), 'FPs': array([10468,  1759,  3004,  3160,  2440,   757,  1845,  1917,     1,
        1490]), 'FNs': array([ 224, 1710, 2046, 3088, 2988, 4384, 1379, 1679, 6852, 2491]), 'accuracy': 0.6165571428571428, 'per_class_accuracy': array([0.84725714, 0.95044286, 0.92785714, 0.91074286, 0.92245714,
       0.92655714, 0.95394286, 0.94862857, 0.9021    , 0.94312857]), 'per_class_accuracy_mean': 0.9233114285714287, 'precision': array([0.39294827, 0.75046106, 0.62251822, 0.55316742, 0.62182269,
       0.77557071, 0.75287972, 0.73514783, 0.99328859, 0.75162527]), 'precision_mean': 0.6949429784627073, 'recall': array([0.968     , 0.75571429, 0.70771429, 0.55885714, 0.57314286,
       0.37371429, 0.803     , 0.76014286, 0.02114286, 0.64414286]), 'recall_mean': 0.6165571428571429, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.80it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0077, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 616.95it/s]

Loss (orig, final): 0.0899723470211029 0.00772605836391449
L2 norm of weight change: 0.31624412536621094
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.66it/s]


Post-edit metrics: {'TP': array([6271, 5282, 5599, 3738, 3965, 2703, 5622, 5310, 3748, 4657]), 'TN': array([59440, 61271, 57193, 59989, 60702, 62158, 60995, 61327, 62509,
       61311]), 'FPs': array([3560, 1729, 5807, 3011, 2298,  842, 2005, 1673,  491, 1689]), 'FNs': array([ 729, 1718, 1401, 3262, 3035, 4297, 1378, 1690, 3252, 2343]), 'accuracy': 0.6699285714285714, 'per_class_accuracy': array([0.93872857, 0.95075714, 0.89702857, 0.91038571, 0.92381429,
       0.92658571, 0.95167143, 0.95195714, 0.94652857, 0.9424    ]), 'per_class_accuracy_mean': 0.9339857142857142, 'precision': array([0.63788017, 0.75338753, 0.49088199, 0.55385983, 0.63308319,
       0.76248237, 0.73711813, 0.76041816, 0.88417079, 0.73384809]), 'precision_mean': 0.6947130267977234, 'recall': array([0.89585714, 0.75457143, 0.79985714, 0.534     , 0.56642857,
       0.38614286, 0.80314286, 0.75857143, 0.53542857, 0.66528571]), 'recall_mean': 0.6699285714285714, 'predicted_class_distribution': array([ 9831,  7011, 114

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.48it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0056, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:03<00:00, 627.03it/s]

Loss (orig, final): 0.15200495719909668 0.005576925352215767
L2 norm of weight change: 0.6786588430404663
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:12<00:00, 21.83it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6920, 4210, 4104, 3448, 3150, 2353,    0, 5313,    0, 3909]), 'TN': array([35512, 62121, 61611, 60696, 61882, 62483, 63000, 60919, 63000,
       62183]), 'FPs': array([27488,   879,  1389,  2304,  1118,   517,     0,  2081,     0,
         817]), 'FNs': array([  80, 2790, 2896, 3552, 3850, 4647, 7000, 1687, 7000, 3091]), 'accuracy': 0.47724285714285714, 'per_class_accuracy': array([0.60617143, 0.94758571, 0.93878571, 0.91634286, 0.92902857,
       0.92622857, 0.9       , 0.94617143, 0.9       , 0.94417143]), 'per_class_accuracy_mean': 0.8954485714285714, 'precision': array([0.20111602, 0.82727451, 0.74713271, 0.59944367, 0.73805061,
       0.81986063, 0.        , 0.71855559, 0.        , 0.82712653]), 'precision_mean': 0.5478560275356475, 'recall': array([0.98857143, 0.60142857, 0.58628571, 0.49257143, 0.45      ,
       0.33614286, 0.        , 0.759     , 0.        , 0.55842857]), 'recall_mean': 0.4772428571428571, 'predicted_class_distribution': array(

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:13<00:00, 19.60it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0498, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:54<00:00, 736.60it/s]

Loss (orig, final): 0.1810174286365509 0.04975363612174988
L2 norm of weight change: 0.4843009412288666
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:13<00:00, 20.50it/s]


Post-edit metrics: {'TP': array([6364, 5602, 5228,   70, 4215, 2876, 5397, 5141, 4647, 4243]), 'TN': array([52338, 60493, 59106, 62916, 60165, 61970, 61622, 61634, 61751,
       61788]), 'FPs': array([10662,  2507,  3894,    84,  2835,  1030,  1378,  1366,  1249,
        1212]), 'FNs': array([ 636, 1398, 1772, 6930, 2785, 4124, 1603, 1859, 2353, 2757]), 'accuracy': 0.6254714285714286, 'per_class_accuracy': array([0.8386    , 0.94421429, 0.91905714, 0.8998    , 0.91971429,
       0.92637143, 0.95741429, 0.95392857, 0.94854286, 0.9433    ]), 'per_class_accuracy_mean': 0.9250942857142859, 'precision': array([0.37378128, 0.69083734, 0.57311993, 0.45454545, 0.59787234,
       0.73630312, 0.79660517, 0.79007223, 0.78816147, 0.77781852]), 'precision_mean': 0.6579116841614551, 'recall': array([0.90914286, 0.80028571, 0.74685714, 0.01      , 0.60214286,
       0.41085714, 0.771     , 0.73442857, 0.66385714, 0.60614286]), 'recall_mean': 0.6254714285714286, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:13<00:00, 20.10it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0060, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:59<00:00, 675.88it/s]

Loss (orig, final): 0.16830569505691528 0.006016704719513655
L2 norm of weight change: 0.749998927116394
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.72it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6869, 5023, 4731, 3736, 3513, 2194, 5510, 5386,    0, 4072]), 'TN': array([47773, 61548, 60367, 60237, 61359, 62595, 61439, 60631, 63000,
       62085]), 'FPs': array([15227,  1452,  2633,  2763,  1641,   405,  1561,  2369,     0,
         915]), 'FNs': array([ 131, 1977, 2269, 3264, 3487, 4806, 1490, 1614, 7000, 2928]), 'accuracy': 0.5862, 'per_class_accuracy': array([0.7806    , 0.95101429, 0.92997143, 0.9139    , 0.92674286,
       0.92555714, 0.95641429, 0.9431    , 0.9       , 0.9451    ]), 'per_class_accuracy_mean': 0.9172399999999999, 'precision': array([0.31087075, 0.7757529 , 0.64244976, 0.57485767, 0.68160652,
       0.84417083, 0.77923915, 0.69451966, 0.        , 0.81652296]), 'precision_mean': 0.6119990191947056, 'recall': array([0.98128571, 0.71757143, 0.67585714, 0.53371429, 0.50185714,
       0.31342857, 0.78714286, 0.76942857, 0.        , 0.58171429]), 'recall_mean': 0.5862, 'predicted_class_distribution': array([22096,  6475,  7364,  64

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.87it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0165, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:02<00:00, 635.04it/s]

Loss (orig, final): 0.10949495434761047 0.016530655324459076
L2 norm of weight change: 0.3210103511810303
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.82it/s]


Post-edit metrics: {'TP': array([6370, 5420, 5272, 3759, 4098, 2769, 5791, 5333, 3727, 4631]), 'TN': array([59163, 60984, 58898, 59951, 60492, 62057, 60478, 61296, 62561,
       61290]), 'FPs': array([3837, 2016, 4102, 3049, 2508,  943, 2522, 1704,  439, 1710]), 'FNs': array([ 630, 1580, 1728, 3241, 2902, 4231, 1209, 1667, 3273, 2369]), 'accuracy': 0.6738571428571428, 'per_class_accuracy': array([0.93618571, 0.94862857, 0.91671429, 0.91014286, 0.92271429,
       0.92608571, 0.9467    , 0.95184286, 0.94697143, 0.94172857]), 'per_class_accuracy_mean': 0.9347714285714286, 'precision': array([0.62408151, 0.7288865 , 0.56240666, 0.55214454, 0.62034514,
       0.74595905, 0.69661975, 0.75785136, 0.89462314, 0.73032645]), 'precision_mean': 0.6913244091802897, 'recall': array([0.91      , 0.77428571, 0.75314286, 0.537     , 0.58542857,
       0.39557143, 0.82728571, 0.76185714, 0.53242857, 0.66157143]), 'recall_mean': 0.6738571428571428, 'predicted_class_distribution': array([10207,  7436,  93

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.44it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0072, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:03<00:00, 625.18it/s]

Loss (orig, final): 0.1825065016746521 0.007219737395644188
L2 norm of weight change: 0.651501476764679
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.69it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6804, 5170, 5092, 3907, 3748, 2363, 5603, 5410,    0, 4268]), 'TN': array([51366, 61397, 59542, 59781, 61064, 62467, 61174, 60692, 63000,
       61882]), 'FPs': array([11634,  1603,  3458,  3219,  1936,   533,  1826,  2308,     0,
        1118]), 'FNs': array([ 196, 1830, 1908, 3093, 3252, 4637, 1397, 1590, 7000, 2732]), 'accuracy': 0.6052142857142857, 'per_class_accuracy': array([0.831     , 0.95095714, 0.92334286, 0.90982857, 0.92588571,
       0.92614286, 0.95395714, 0.94431429, 0.9       , 0.945     ]), 'per_class_accuracy_mean': 0.9210428571428573, 'precision': array([0.3690205 , 0.76332497, 0.59555556, 0.54827393, 0.65939479,
       0.81595304, 0.75420649, 0.7009588 , 0.        , 0.79242481]), 'precision_mean': 0.5999112871767901, 'recall': array([0.972     , 0.73857143, 0.72742857, 0.55814286, 0.53542857,
       0.33757143, 0.80042857, 0.77285714, 0.        , 0.60971429]), 'recall_mean': 0.6052142857142857, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.31it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0087, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 617.40it/s]

Loss (orig, final): 0.14466431736946106 0.008713189512491226
L2 norm of weight change: 0.4171293079853058
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.74it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6753, 5313, 5130, 3893, 3910, 2602, 5638, 5338,    0, 4492]), 'TN': array([52662, 61229, 59515, 59871, 60792, 62261, 61087, 61097, 63000,
       61555]), 'FPs': array([10338,  1771,  3485,  3129,  2208,   739,  1913,  1903,     0,
        1445]), 'FNs': array([ 247, 1687, 1870, 3107, 3090, 4398, 1362, 1662, 7000, 2508]), 'accuracy': 0.6152714285714286, 'per_class_accuracy': array([0.84878571, 0.9506    , 0.9235    , 0.91091429, 0.92431429,
       0.92661429, 0.95321429, 0.94907143, 0.9       , 0.94352857]), 'per_class_accuracy_mean': 0.9230542857142856, 'precision': array([0.39512024, 0.75      , 0.59547301, 0.55440046, 0.63909774,
       0.77880874, 0.74665607, 0.737191  , 0.        , 0.75661108]), 'precision_mean': 0.5953358341680772, 'recall': array([0.96471429, 0.759     , 0.73285714, 0.55614286, 0.55857143,
       0.37171429, 0.80542857, 0.76257143, 0.        , 0.64171429]), 'recall_mean': 0.6152714285714286, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.32it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0051, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:03<00:00, 629.02it/s]

Loss (orig, final): 0.14169779419898987 0.00514156324788928
L2 norm of weight change: 0.6584505438804626
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.96it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6867, 4613, 5115, 3811, 3934, 2512, 5404, 5052,    0,   88]), 'TN': array([43883, 61904, 59258, 60074, 60680, 62346, 61617, 61634, 63000,
       63000]), 'FPs': array([19117,  1096,  3742,  2926,  2320,   654,  1383,  1366,     0,
           0]), 'FNs': array([ 133, 2387, 1885, 3189, 3066, 4488, 1596, 1948, 7000, 6912]), 'accuracy': 0.5342285714285714, 'per_class_accuracy': array([0.725     , 0.95024286, 0.91961429, 0.91264286, 0.92305714,
       0.92654286, 0.95744286, 0.95265714, 0.9       , 0.90125714]), 'per_class_accuracy_mean': 0.9068457142857141, 'precision': array([0.26427802, 0.80802242, 0.57750931, 0.56568205, 0.62903742,
       0.7934302 , 0.79622808, 0.78716111, 0.        , 1.        ]), 'precision_mean': 0.622134861133812, 'recall': array([0.981     , 0.659     , 0.73071429, 0.54442857, 0.562     ,
       0.35885714, 0.772     , 0.72171429, 0.        , 0.01257143]), 'recall_mean': 0.5342285714285715, 'predicted_class_distribution': array([2

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.85it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0097, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 622.86it/s]

Loss (orig, final): 0.09996029734611511 0.009733401238918304
L2 norm of weight change: 0.3777402937412262
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.29it/s]


Post-edit metrics: {'TP': array([6476, 5564, 5443, 3608, 3923, 2880, 5663, 5334, 3487, 4200]), 'TN': array([58086, 60672, 58113, 60326, 60800, 61920, 60911, 61242, 62648,
       61860]), 'FPs': array([4914, 2328, 4887, 2674, 2200, 1080, 2089, 1758,  352, 1140]), 'FNs': array([ 524, 1436, 1557, 3392, 3077, 4120, 1337, 1666, 3513, 2800]), 'accuracy': 0.6654, 'per_class_accuracy': array([0.92231429, 0.94622857, 0.90794286, 0.91334286, 0.92461429,
       0.92571429, 0.95105714, 0.95108571, 0.94478571, 0.94371429]), 'per_class_accuracy_mean': 0.93308, 'precision': array([0.56856892, 0.70501774, 0.52691191, 0.57433938, 0.640699  ,
       0.72727273, 0.73052116, 0.75211506, 0.90830946, 0.78651685]), 'precision_mean': 0.6920272204618921, 'recall': array([0.92514286, 0.79485714, 0.77757143, 0.51542857, 0.56042857,
       0.41142857, 0.809     , 0.762     , 0.49814286, 0.6       ]), 'recall_mean': 0.6653999999999999, 'predicted_class_distribution': array([11390,  7892, 10330,  6282,  6123,  3960

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.86it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0050, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 624.42it/s]

Loss (orig, final): 0.17585551738739014 0.004960277117788792
L2 norm of weight change: 0.6561248898506165
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.95it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6814, 5156, 5065, 3907, 3747, 2350, 5604, 5413,    0, 4279]), 'TN': array([51248, 61413, 59634, 59782, 61055, 62479, 61173, 60682, 63000,
       61869]), 'FPs': array([11752,  1587,  3366,  3218,  1945,   521,  1827,  2318,     0,
        1131]), 'FNs': array([ 186, 1844, 1935, 3093, 3253, 4650, 1396, 1587, 7000, 2721]), 'accuracy': 0.6047857142857143, 'per_class_accuracy': array([0.82945714, 0.95098571, 0.92427143, 0.90984286, 0.92574286,
       0.92612857, 0.95395714, 0.94421429, 0.9       , 0.94497143]), 'per_class_accuracy_mean': 0.9209571428571428, 'precision': array([0.36701497, 0.76464482, 0.6007591 , 0.54835088, 0.65829234,
       0.81853013, 0.75413807, 0.70016815, 0.        , 0.7909427 ]), 'precision_mean': 0.6002841163095236, 'recall': array([0.97342857, 0.73657143, 0.72357143, 0.55814286, 0.53528571,
       0.33571429, 0.80057143, 0.77328571, 0.        , 0.61128571]), 'recall_mean': 0.6047857142857143, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.44it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0049, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 619.02it/s]

Loss (orig, final): 0.07321176677942276 0.004914899356663227
L2 norm of weight change: 0.22401650249958038
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.41it/s]


Post-edit metrics: {'TP': array([6424, 5384, 5196, 3830, 4132, 2793, 5671, 5284, 4065, 4624]), 'TN': array([58670, 61070, 59349, 59919, 60390, 62022, 60984, 61351, 62369,
       61279]), 'FPs': array([4330, 1930, 3651, 3081, 2610,  978, 2016, 1649,  631, 1721]), 'FNs': array([ 576, 1616, 1804, 3170, 2868, 4207, 1329, 1716, 2935, 2376]), 'accuracy': 0.6771857142857143, 'per_class_accuracy': array([0.92991429, 0.94934286, 0.92207143, 0.9107    , 0.92174286,
       0.92592857, 0.95221429, 0.95192857, 0.94905714, 0.94147143]), 'per_class_accuracy_mean': 0.935437142857143, 'precision': array([0.59735912, 0.7361225 , 0.58731773, 0.55418897, 0.61287452,
       0.74065235, 0.73773904, 0.76215203, 0.86563032, 0.72876281]), 'precision_mean': 0.692279939621743, 'recall': array([0.91771429, 0.76914286, 0.74228571, 0.54714286, 0.59028571,
       0.399     , 0.81014286, 0.75485714, 0.58071429, 0.66057143]), 'recall_mean': 0.6771857142857144, 'predicted_class_distribution': array([10754,  7314,  8847

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.41it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0064, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 620.56it/s]

Loss (orig, final): 0.13449276983737946 0.006381826940923929
L2 norm of weight change: 0.4965915083885193
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.13it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6620, 5112, 5096, 3077, 4171, 2761, 5311, 4818, 4134,    0]), 'TN': array([47530, 61453, 59411, 61376, 60117, 62083, 61808, 62014, 62308,
       63000]), 'FPs': array([15470,  1547,  3589,  1624,  2883,   917,  1192,   986,   692,
           0]), 'FNs': array([ 380, 1888, 1904, 3923, 2829, 4239, 1689, 2182, 2866, 7000]), 'accuracy': 0.5871428571428572, 'per_class_accuracy': array([0.77357143, 0.95092857, 0.92152857, 0.92075714, 0.9184    ,
       0.92634286, 0.95884286, 0.95474286, 0.94917143, 0.9       ]), 'per_class_accuracy_mean': 0.9174285714285715, 'precision': array([0.29968311, 0.76768284, 0.58675878, 0.65454159, 0.59129572,
       0.75067972, 0.81669998, 0.83011716, 0.85661003, 0.        ]), 'precision_mean': 0.6154068926383004, 'recall': array([0.94571429, 0.73028571, 0.728     , 0.43957143, 0.59585714,
       0.39442857, 0.75871429, 0.68828571, 0.59057143, 0.        ]), 'recall_mean': 0.5871428571428572, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.73it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0047, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:03<00:00, 633.33it/s]

Loss (orig, final): 0.12387923896312714 0.00467915553599596
L2 norm of weight change: 0.6099483966827393
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.20it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6654, 5305,    0, 2944, 3869, 2688, 3833, 5085, 4222, 4403]), 'TN': array([42168, 61133, 63000, 61456, 60819, 62189, 62809, 61602, 62186,
       61641]), 'FPs': array([20832,  1867,     0,  1544,  2181,   811,   191,  1398,   814,
        1359]), 'FNs': array([ 346, 1695, 7000, 4056, 3131, 4312, 3167, 1915, 2778, 2597]), 'accuracy': 0.5571857142857143, 'per_class_accuracy': array([0.69745714, 0.94911429, 0.9       , 0.92      , 0.92411429,
       0.92681429, 0.95202857, 0.95267143, 0.94868571, 0.94348571]), 'per_class_accuracy_mean': 0.9114371428571427, 'precision': array([0.24208688, 0.7396821 , 0.        , 0.65597148, 0.63950413,
       0.76821949, 0.95253479, 0.78435909, 0.83836378, 0.76414439]), 'precision_mean': 0.6384866140004244, 'recall': array([0.95057143, 0.75785714, 0.        , 0.42057143, 0.55271429,
       0.384     , 0.54757143, 0.72642857, 0.60314286, 0.629     ]), 'recall_mean': 0.5571857142857144, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.80it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0371, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 619.46it/s]

Loss (orig, final): 0.21750865876674652 0.03710708022117615
L2 norm of weight change: 0.7015421986579895
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.13it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6815, 5161, 5078, 3899, 3752, 2423, 5618, 5310,    0, 4209]), 'TN': array([50864, 61419, 59591, 59748, 61037, 62428, 61110, 61115, 63000,
       61953]), 'FPs': array([12136,  1581,  3409,  3252,  1963,   572,  1890,  1885,     0,
        1047]), 'FNs': array([ 185, 1839, 1922, 3101, 3248, 4577, 1382, 1690, 7000, 2791]), 'accuracy': 0.6037857142857143, 'per_class_accuracy': array([0.82398571, 0.95114286, 0.92384286, 0.90924286, 0.92555714,
       0.92644286, 0.95325714, 0.94892857, 0.9       , 0.94517143]), 'per_class_accuracy_mean': 0.9207571428571428, 'precision': array([0.35961163, 0.76549985, 0.59832685, 0.54523843, 0.65651794,
       0.80901503, 0.74826851, 0.73801251, 0.        , 0.80079909]), 'precision_mean': 0.6021289832024724, 'recall': array([0.97357143, 0.73728571, 0.72542857, 0.557     , 0.536     ,
       0.34614286, 0.80257143, 0.75857143, 0.        , 0.60128571]), 'recall_mean': 0.6037857142857143, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.18it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0236, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 617.09it/s]


Loss (orig, final): 0.11024066805839539 0.023615913465619087
L2 norm of weight change: 0.2949005663394928
Performing post-edit metric & KNN calculations on validation set.


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.07it/s]


Post-edit metrics: {'TP': array([6507, 5321, 5190, 3927, 4077, 2704, 5659, 5254, 3769, 4659]), 'TN': array([57856, 61198, 59422, 59735, 60507, 62110, 61035, 61407, 62539,
       61258]), 'FPs': array([5144, 1802, 3578, 3265, 2493,  890, 1965, 1593,  461, 1742]), 'FNs': array([ 493, 1679, 1810, 3073, 2923, 4296, 1341, 1746, 3231, 2341]), 'accuracy': 0.6723857142857143, 'per_class_accuracy': array([0.91947143, 0.95027143, 0.92302857, 0.90945714, 0.92262857,
       0.92591429, 0.95277143, 0.9523    , 0.94725714, 0.94167143]), 'per_class_accuracy_mean': 0.9344771428571429, 'precision': array([0.55849283, 0.74701671, 0.59192518, 0.54602336, 0.62054795,
       0.75236505, 0.74226128, 0.76734336, 0.89101655, 0.72785502]), 'precision_mean': 0.6944847292858664, 'recall': array([0.92957143, 0.76014286, 0.74142857, 0.561     , 0.58242857,
       0.38628571, 0.80842857, 0.75057143, 0.53842857, 0.66557143]), 'recall_mean': 0.6723857142857144, 'predicted_class_distribution': array([11651,  7123,  87

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.78it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0169, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:03<00:00, 633.02it/s]

Loss (orig, final): 0.2073083519935608 0.016881154850125313
L2 norm of weight change: 0.7358757853507996
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 24.88it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6843, 5067, 5000, 3873, 3585, 2273, 5576, 5315,    0, 4083]), 'TN': array([49487, 61498, 59724, 59853, 61233, 62534, 61264, 60960, 63000,
       62062]), 'FPs': array([13513,  1502,  3276,  3147,  1767,   466,  1736,  2040,     0,
         938]), 'FNs': array([ 157, 1933, 2000, 3127, 3415, 4727, 1424, 1685, 7000, 2917]), 'accuracy': 0.5945, 'per_class_accuracy': array([0.80471429, 0.95092857, 0.92462857, 0.91037143, 0.92597143,
       0.92581429, 0.95485714, 0.94678571, 0.9       , 0.94492857]), 'per_class_accuracy_mean': 0.9189, 'precision': array([0.33616624, 0.77135028, 0.6041566 , 0.5517094 , 0.66984305,
       0.82986491, 0.76258206, 0.72263766, 0.        , 0.81318462]), 'precision_mean': 0.6061494828091001, 'recall': array([0.97757143, 0.72385714, 0.71428571, 0.55328571, 0.51214286,
       0.32471429, 0.79657143, 0.75928571, 0.        , 0.58328571]), 'recall_mean': 0.5945, 'predicted_class_distribution': array([20356,  6569,  8276,  7020,  5352,  

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.99it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0135, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [01:04<00:00, 623.88it/s]

Loss (orig, final): 0.13021735846996307 0.013461943715810776
L2 norm of weight change: 0.38087427616119385
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.99it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6734, 5287, 5240, 3873, 4007, 2613, 5609, 5270,    0, 4459]), 'TN': array([52923, 61276, 59127, 59863, 60588, 62237, 61155, 61296, 63000,
       61627]), 'FPs': array([10077,  1724,  3873,  3137,  2412,   763,  1845,  1704,     0,
        1373]), 'FNs': array([ 266, 1713, 1760, 3127, 2993, 4387, 1391, 1730, 7000, 2541]), 'accuracy': 0.6156, 'per_class_accuracy': array([0.85224286, 0.9509    , 0.91952857, 0.91051429, 0.92278571,
       0.92642857, 0.95377143, 0.95094286, 0.9       , 0.94408571]), 'per_class_accuracy_mean': 0.9231199999999999, 'precision': array([0.40057105, 0.7541007 , 0.57500274, 0.55249643, 0.62424054,
       0.77399289, 0.75248189, 0.75566389, 0.        , 0.76457476]), 'precision_mean': 0.5953124900801472, 'recall': array([0.962     , 0.75528571, 0.74857143, 0.55328571, 0.57242857,
       0.37328571, 0.80128571, 0.75285714, 0.        , 0.637     ]), 'recall_mean': 0.6156, 'predicted_class_distribution': array([16811,  7011,  9113,  70

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.51it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0078, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:55<00:00, 719.34it/s]

Loss (orig, final): 0.1469847410917282 0.007828700356185436
L2 norm of weight change: 0.6315382122993469
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 43.87it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6899, 5179,    0, 4047, 3708, 2415, 5518, 5236,    0, 4502]), 'TN': array([42839, 61371, 63000, 59507, 61090, 62431, 61457, 61263, 63000,
       61546]), 'FPs': array([20161,  1629,     0,  3493,  1910,   569,  1543,  1737,     0,
        1454]), 'FNs': array([ 101, 1821, 7000, 2953, 3292, 4585, 1482, 1764, 7000, 2498]), 'accuracy': 0.5357714285714286, 'per_class_accuracy': array([0.71054286, 0.95071429, 0.9       , 0.90791429, 0.92568571,
       0.92637143, 0.95678571, 0.94998571, 0.9       , 0.94354286]), 'per_class_accuracy_mean': 0.9071542857142857, 'precision': array([0.25495196, 0.76072268, 0.        , 0.5367374 , 0.66002136,
       0.80931635, 0.78147571, 0.75089631, 0.        , 0.75587643]), 'precision_mean': 0.5309998205287145, 'recall': array([0.98557143, 0.73985714, 0.        , 0.57814286, 0.52971429,
       0.345     , 0.78828571, 0.748     , 0.        , 0.64314286]), 'recall_mean': 0.5357714285714286, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 44.72it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0043, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:44<00:00, 889.27it/s]

Loss (orig, final): 0.07807697355747223 0.004263629671186209
L2 norm of weight change: 0.3381984829902649
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:08<00:00, 31.43it/s]


Post-edit metrics: {'TP': array([6539, 5274, 5406, 3784, 4062, 2691, 5506, 5169, 3779, 4364]), 'TN': array([57207, 61307, 58251, 60018, 60472, 62174, 61307, 61569, 62511,
       61758]), 'FPs': array([5793, 1693, 4749, 2982, 2528,  826, 1693, 1431,  489, 1242]), 'FNs': array([ 461, 1726, 1594, 3216, 2938, 4309, 1494, 1831, 3221, 2636]), 'accuracy': 0.6653428571428571, 'per_class_accuracy': array([0.91065714, 0.95115714, 0.90938571, 0.91145714, 0.92191429,
       0.92664286, 0.95447143, 0.9534    , 0.947     , 0.9446    ]), 'per_class_accuracy_mean': 0.9330685714285714, 'precision': array([0.53024651, 0.75699727, 0.5323486 , 0.55926692, 0.61638847,
       0.76514074, 0.76482845, 0.78318182, 0.88542643, 0.77845166]), 'precision_mean': 0.6972276872677263, 'recall': array([0.93414286, 0.75342857, 0.77228571, 0.54057143, 0.58028571,
       0.38442857, 0.78657143, 0.73842857, 0.53985714, 0.62342857]), 'recall_mean': 0.6653428571428571, 'predicted_class_distribution': array([12332,  6967, 101

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 45.04it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0026, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:44<00:00, 906.22it/s]

Loss (orig, final): 0.15208272635936737 0.002588572446256876
L2 norm of weight change: 0.6998029351234436
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 35.39it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6682, 5181,    0, 3667, 3931, 2517, 5084, 4693, 4028, 4439]), 'TN': array([44442, 61352, 63000, 60319, 60646, 62339, 62031, 62146, 62342,
       61605]), 'FPs': array([18558,  1648,     0,  2681,  2354,   661,   969,   854,   658,
        1395]), 'FNs': array([ 318, 1819, 7000, 3333, 3069, 4483, 1916, 2307, 2972, 2561]), 'accuracy': 0.5746, 'per_class_accuracy': array([0.73034286, 0.95047143, 0.9       , 0.91408571, 0.92252857,
       0.92651429, 0.95878571, 0.95484286, 0.94814286, 0.94348571]), 'per_class_accuracy_mean': 0.9149200000000001, 'precision': array([0.26473851, 0.75867623, 0.        , 0.57766226, 0.62545744,
       0.79200755, 0.83991409, 0.84604291, 0.85958173, 0.76088447]), 'precision_mean': 0.6324965191532076, 'recall': array([0.95457143, 0.74014286, 0.        , 0.52385714, 0.56157143,
       0.35957143, 0.72628571, 0.67042857, 0.57542857, 0.63414286]), 'recall_mean': 0.5746, 'predicted_class_distribution': array([25240,  6829,     0,  63

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 45.02it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0037, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:45<00:00, 888.33it/s]

Loss (orig, final): 0.10641852021217346 0.003705390030518174
L2 norm of weight change: 0.4753883481025696
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 34.34it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6515, 5293,    0, 4111, 4287, 2741, 5565, 5040, 4543, 4714]), 'TN': array([51002, 61207, 63000, 59400, 60001, 62067, 61393, 61761, 61934,
       61044]), 'FPs': array([11998,  1793,     0,  3600,  2999,   933,  1607,  1239,  1066,
        1956]), 'FNs': array([ 485, 1707, 7000, 2889, 2713, 4259, 1435, 1960, 2457, 2286]), 'accuracy': 0.6115571428571429, 'per_class_accuracy': array([0.82167143, 0.95      , 0.9       , 0.9073    , 0.9184    ,
       0.92582857, 0.95654286, 0.9543    , 0.94967143, 0.9394    ]), 'per_class_accuracy_mean': 0.9223114285714284, 'precision': array([0.35191487, 0.74696585, 0.        , 0.53313448, 0.58838869,
       0.74605335, 0.77593419, 0.80267559, 0.8099483 , 0.70674663]), 'precision_mean': 0.6061761938338905, 'recall': array([0.93071429, 0.75614286, 0.        , 0.58728571, 0.61242857,
       0.39157143, 0.795     , 0.72      , 0.649     , 0.67342857]), 'recall_mean': 0.6115571428571429, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 45.10it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0038, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:43<00:00, 909.74it/s]

Loss (orig, final): 0.07126792520284653 0.003837807569652796
L2 norm of weight change: 0.2680896818637848
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:08<00:00, 32.45it/s]


Post-edit metrics: {'TP': array([6371, 5312, 4457, 4061, 4351, 2829, 5681, 5094, 4719, 4722]), 'TN': array([58439, 61167, 61081, 59446, 59879, 61973, 61097, 61698, 61754,
       61063]), 'FPs': array([4561, 1833, 1919, 3554, 3121, 1027, 1903, 1302, 1246, 1937]), 'FNs': array([ 629, 1688, 2543, 2939, 2649, 4171, 1319, 1906, 2281, 2278]), 'accuracy': 0.6799571428571428, 'per_class_accuracy': array([0.92585714, 0.9497    , 0.93625714, 0.90724286, 0.91757143,
       0.92574286, 0.95397143, 0.95417143, 0.94961429, 0.93978571]), 'per_class_accuracy_mean': 0.9359914285714286, 'precision': array([0.58278449, 0.74345696, 0.6990276 , 0.53328956, 0.58230728,
       0.73366183, 0.749077  , 0.79643527, 0.79111484, 0.70911548]), 'precision_mean': 0.6920270314272253, 'recall': array([0.91014286, 0.75885714, 0.63671429, 0.58014286, 0.62157143,
       0.40414286, 0.81157143, 0.72771429, 0.67414286, 0.67457143]), 'recall_mean': 0.6799571428571428, 'predicted_class_distribution': array([10932,  7145,  63

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 44.19it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0038, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:44<00:00, 899.78it/s]

Loss (orig, final): 0.14969608187675476 0.003844767576083541
L2 norm of weight change: 0.6565413475036621
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.19it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6828, 5160, 5055, 3962, 3360, 2367, 5560, 5365,    0, 4294]), 'TN': array([50167, 61399, 59616, 59721, 61597, 62466, 61294, 60844, 63000,
       61847]), 'FPs': array([12833,  1601,  3384,  3279,  1403,   534,  1706,  2156,     0,
        1153]), 'FNs': array([ 172, 1840, 1945, 3038, 3640, 4633, 1440, 1635, 7000, 2706]), 'accuracy': 0.5993, 'per_class_accuracy': array([0.81421429, 0.95084286, 0.92387143, 0.90975714, 0.92795714,
       0.92618571, 0.95505714, 0.94584286, 0.9       , 0.94487143]), 'per_class_accuracy_mean': 0.9198599999999999, 'precision': array([0.34728651, 0.76320071, 0.59900462, 0.54716199, 0.70543775,
       0.81592554, 0.76520782, 0.71333599, 0.        , 0.78832385]), 'precision_mean': 0.6044884781844688, 'recall': array([0.97542857, 0.73714286, 0.72214286, 0.566     , 0.48      ,
       0.33814286, 0.79428571, 0.76642857, 0.        , 0.61342857]), 'recall_mean': 0.5993, 'predicted_class_distribution': array([19661,  6761,  8439,  72

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 39.95it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0040, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:44<00:00, 897.51it/s]

Loss (orig, final): 0.11592701077461243 0.004042710177600384
L2 norm of weight change: 0.4025658071041107
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:09<00:00, 27.41it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6776, 5281, 5140, 3915, 3765, 2602, 5588, 5256,    0, 4493]), 'TN': array([51968, 61254, 59432, 59776, 61026, 62252, 61237, 61325, 63000,
       61546]), 'FPs': array([11032,  1746,  3568,  3224,  1974,   748,  1763,  1675,     0,
        1454]), 'FNs': array([ 224, 1719, 1860, 3085, 3235, 4398, 1412, 1744, 7000, 2507]), 'accuracy': 0.6116571428571429, 'per_class_accuracy': array([0.8392    , 0.9505    , 0.92245714, 0.90987143, 0.92558571,
       0.92648571, 0.95464286, 0.95115714, 0.9       , 0.94341429]), 'per_class_accuracy_mean': 0.9223314285714285, 'precision': array([0.38050314, 0.75152981, 0.59026183, 0.54839613, 0.65603764,
       0.77671642, 0.76016868, 0.75833213, 0.        , 0.75550698]), 'precision_mean': 0.5977452769322766, 'recall': array([0.968     , 0.75442857, 0.73428571, 0.55928571, 0.53785714,
       0.37171429, 0.79828571, 0.75085714, 0.        , 0.64185714]), 'recall_mean': 0.6116571428571429, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 41.02it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0086, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:48<00:00, 820.45it/s]

Loss (orig, final): 0.12469330430030823 0.00861372984945774
L2 norm of weight change: 0.564173698425293
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:09<00:00, 29.96it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6595, 5265, 5247, 3610, 4135, 2948, 5341, 4770, 4094,    0]), 'TN': array([49899, 61244, 58916, 60614, 60291, 61852, 61758, 62100, 62331,
       63000]), 'FPs': array([13101,  1756,  4084,  2386,  2709,  1148,  1242,   900,   669,
           0]), 'FNs': array([ 405, 1735, 1753, 3390, 2865, 4052, 1659, 2230, 2906, 7000]), 'accuracy': 0.6000714285714286, 'per_class_accuracy': array([0.80705714, 0.95012857, 0.91661429, 0.91748571, 0.92037143,
       0.92571429, 0.95855714, 0.95528571, 0.94892857, 0.9       ]), 'per_class_accuracy_mean': 0.9200142857142858, 'precision': array([0.33483956, 0.74989318, 0.56231915, 0.60206805, 0.60417884,
       0.71972656, 0.81133222, 0.84126984, 0.85954231, 0.        ]), 'precision_mean': 0.6085169706696736, 'recall': array([0.94214286, 0.75214286, 0.74957143, 0.51571429, 0.59071429,
       0.42114286, 0.763     , 0.68142857, 0.58485714, 0.        ]), 'recall_mean': 0.6000714285714286, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 39.98it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0060, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:44<00:00, 905.55it/s]

Loss (orig, final): 0.09961603581905365 0.006023285910487175
L2 norm of weight change: 0.38647156953811646
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:08<00:00, 32.66it/s]


Post-edit metrics: {'TP': array([6227, 5380, 5491, 3938, 4141, 3088, 5544, 4652, 4574, 3627]), 'TN': array([59328, 61140, 57823, 58824, 60342, 61493, 61202, 62347, 61766,
       62397]), 'FPs': array([3672, 1860, 5177, 4176, 2658, 1507, 1798,  653, 1234,  603]), 'FNs': array([ 773, 1620, 1509, 3062, 2859, 3912, 1456, 2348, 2426, 3373]), 'accuracy': 0.6666, 'per_class_accuracy': array([0.9365    , 0.95028571, 0.90448571, 0.8966    , 0.92118571,
       0.92258571, 0.95351429, 0.95712857, 0.94771429, 0.9432    ]), 'per_class_accuracy_mean': 0.9333200000000001, 'precision': array([0.62905344, 0.74309392, 0.51471691, 0.48533399, 0.60906016,
       0.67203482, 0.7551076 , 0.87690858, 0.78753444, 0.85744681]), 'precision_mean': 0.693029066047093, 'recall': array([0.88957143, 0.76857143, 0.78442857, 0.56257143, 0.59157143,
       0.44114286, 0.792     , 0.66457143, 0.65342857, 0.51814286]), 'recall_mean': 0.6666000000000001, 'predicted_class_distribution': array([ 9899,  7240, 10668,  8114,  6

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 45.17it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0038, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:45<00:00, 879.51it/s]

Loss (orig, final): 0.16194212436676025 0.003780404105782509
L2 norm of weight change: 0.6567239761352539
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:08<00:00, 30.61it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6820, 5068, 5033, 3935, 3737, 2310, 5589, 5392,    0, 4279]), 'TN': array([50804, 61497, 59698, 59758, 61058, 62515, 61226, 60731, 63000,
       61876]), 'FPs': array([12196,  1503,  3302,  3242,  1942,   485,  1774,  2269,     0,
        1124]), 'FNs': array([ 180, 1932, 1967, 3065, 3263, 4690, 1411, 1608, 7000, 2721]), 'accuracy': 0.6023285714285714, 'per_class_accuracy': array([0.8232    , 0.95092857, 0.92472857, 0.9099    , 0.92564286,
       0.92607143, 0.9545    , 0.94461429, 0.9       , 0.94507143]), 'per_class_accuracy_mean': 0.9204657142857144, 'precision': array([0.35864535, 0.77126769, 0.60383923, 0.54827923, 0.65803839,
       0.82647585, 0.7590656 , 0.70382457, 0.        , 0.79196743]), 'precision_mean': 0.6021403326632043, 'recall': array([0.97428571, 0.724     , 0.719     , 0.56214286, 0.53385714,
       0.33      , 0.79842857, 0.77028571, 0.        , 0.61128571]), 'recall_mean': 0.6023285714285714, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 43.70it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0144, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:48<00:00, 822.94it/s]

Loss (orig, final): 0.0848778784275055 0.014396525919437408
L2 norm of weight change: 0.24481768906116486
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:08<00:00, 32.01it/s]


Post-edit metrics: {'TP': array([6349, 5324, 5391, 3823, 4081, 2772, 5647, 5301, 3991, 4652]), 'TN': array([59115, 61189, 58545, 59865, 60490, 62067, 60988, 61380, 62404,
       61288]), 'FPs': array([3885, 1811, 4455, 3135, 2510,  933, 2012, 1620,  596, 1712]), 'FNs': array([ 651, 1676, 1609, 3177, 2919, 4228, 1353, 1699, 3009, 2348]), 'accuracy': 0.6761571428571429, 'per_class_accuracy': array([0.9352    , 0.95018571, 0.91337143, 0.90982857, 0.92244286,
       0.92627143, 0.95192857, 0.95258571, 0.9485    , 0.942     ]), 'per_class_accuracy_mean': 0.9352314285714286, 'precision': array([0.62038304, 0.7461808 , 0.54753199, 0.54943949, 0.61917767,
       0.74817814, 0.73730252, 0.76592978, 0.87006758, 0.7309868 ]), 'precision_mean': 0.6935177808673062, 'recall': array([0.907     , 0.76057143, 0.77014286, 0.54614286, 0.583     ,
       0.396     , 0.80671429, 0.75728571, 0.57014286, 0.66457143]), 'recall_mean': 0.6761571428571428, 'predicted_class_distribution': array([10234,  7135,  98

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:05<00:00, 45.80it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0718, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:45<00:00, 881.73it/s]

Loss (orig, final): 0.21077954769134521 0.0718313455581665
L2 norm of weight change: 0.5485361814498901
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:08<00:00, 32.14it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6176,    0, 5426, 4165, 4460, 2524, 5322, 4816, 4837, 4771]), 'TN': array([54066, 63000, 58320, 58866, 59532, 62278, 61702, 62152, 61298,
       61283]), 'FPs': array([8934,    0, 4680, 4134, 3468,  722, 1298,  848, 1702, 1717]), 'FNs': array([ 824, 7000, 1574, 2835, 2540, 4476, 1678, 2184, 2163, 2229]), 'accuracy': 0.6071, 'per_class_accuracy': array([0.8606    , 0.9       , 0.91065714, 0.90044286, 0.91417143,
       0.92574286, 0.95748571, 0.95668571, 0.94478571, 0.94362857]), 'per_class_accuracy_mean': 0.9214199999999998, 'precision': array([0.40873594, 0.        , 0.53690877, 0.50186769, 0.56256307,
       0.7775724 , 0.80392749, 0.85028249, 0.73971555, 0.73535758]), 'precision_mean': 0.5916930977232759, 'recall': array([0.88228571, 0.        , 0.77514286, 0.595     , 0.63714286,
       0.36057143, 0.76028571, 0.688     , 0.691     , 0.68157143]), 'recall_mean': 0.6071, 'predicted_class_distribution': array([15110,     0, 10106,  8299,  7928,  3246,

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 45.41it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0096, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:47<00:00, 848.54it/s]

Loss (orig, final): 0.1497228890657425 0.009584104642271996
L2 norm of weight change: 0.725858747959137
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:09<00:00, 30.34it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6820,    0, 4831,  162, 3970, 2140, 4967, 4025, 2834,    0]), 'TN': array([30186, 63000, 60004, 62768, 60525, 62630, 62236, 62573, 62827,
       63000]), 'FPs': array([32814,     0,  2996,   232,  2475,   370,   764,   427,   173,
           0]), 'FNs': array([ 180, 7000, 2169, 6838, 3030, 4860, 2033, 2975, 4166, 7000]), 'accuracy': 0.4249857142857143, 'per_class_accuracy': array([0.52865714, 0.9       , 0.92621429, 0.899     , 0.92135714,
       0.92528571, 0.96004286, 0.9514    , 0.93801429, 0.9       ]), 'per_class_accuracy_mean': 0.8849971428571429, 'precision': array([0.17207448, 0.        , 0.61722244, 0.41116751, 0.61598138,
       0.85258964, 0.86668993, 0.90408805, 0.94246758, 0.        ]), 'precision_mean': 0.5382281009626493, 'recall': array([0.97428571, 0.        , 0.69014286, 0.02314286, 0.56714286,
       0.30571429, 0.70957143, 0.575     , 0.40485714, 0.        ]), 'recall_mean': 0.42498571428571436, 'predicted_class_distribution': array(

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 45.18it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0438, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:43<00:00, 921.15it/s]

Loss (orig, final): 0.16976721584796906 0.04381655529141426
L2 norm of weight change: 0.4310559630393982
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:09<00:00, 27.51it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6650, 5305, 5207, 4021, 4014, 2749, 5685, 5318,    0, 4627]), 'TN': array([54342, 61243, 59367, 59327, 60607, 62067, 60941, 61294, 63000,
       61388]), 'FPs': array([8658, 1757, 3633, 3673, 2393,  933, 2059, 1706,    0, 1612]), 'FNs': array([ 350, 1695, 1793, 2979, 2986, 4251, 1315, 1682, 7000, 2373]), 'accuracy': 0.6225142857142857, 'per_class_accuracy': array([0.87131429, 0.95068571, 0.92248571, 0.90497143, 0.92315714,
       0.92594286, 0.9518    , 0.9516    , 0.9       , 0.94307143]), 'per_class_accuracy_mean': 0.9245028571428572, 'precision': array([0.43441338, 0.75120363, 0.58902715, 0.52261502, 0.62650226,
       0.74660511, 0.73411674, 0.75711845, 0.        , 0.74162526]), 'precision_mean': 0.590322699376782, 'recall': array([0.95      , 0.75785714, 0.74385714, 0.57442857, 0.57342857,
       0.39271429, 0.81214286, 0.75971429, 0.        , 0.661     ]), 'recall_mean': 0.6225142857142858, 'predicted_class_distribution': array([15308,  7062,  884

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 43.13it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0045, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:47<00:00, 839.93it/s]

Loss (orig, final): 0.12027763575315475 0.00446388078853488
L2 norm of weight change: 0.5884215235710144
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:12<00:00, 21.51it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6731,    0, 4294, 3560, 3993, 2507, 4722, 4637, 3804, 3055]), 'TN': array([39789, 63000, 61311, 60571, 60541, 62335, 62427, 62165, 62509,
       62655]), 'FPs': array([23211,     0,  1689,  2429,  2459,   665,   573,   835,   491,
         345]), 'FNs': array([ 269, 7000, 2706, 3440, 3007, 4493, 2278, 2363, 3196, 3945]), 'accuracy': 0.5329, 'per_class_accuracy': array([0.66457143, 0.9       , 0.93721429, 0.91615714, 0.92191429,
       0.92631429, 0.95927143, 0.95431429, 0.94732857, 0.93871429]), 'per_class_accuracy_mean': 0.9065799999999999, 'precision': array([0.22480128, 0.        , 0.71770015, 0.59442311, 0.61887787,
       0.79035309, 0.8917847 , 0.84740497, 0.88568102, 0.89852941]), 'precision_mean': 0.646955560832182, 'recall': array([0.96157143, 0.        , 0.61342857, 0.50857143, 0.57042857,
       0.35814286, 0.67457143, 0.66242857, 0.54342857, 0.43642857]), 'recall_mean': 0.5328999999999999, 'predicted_class_distribution': array([29942,     0,

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:05<00:00, 45.69it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0050, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:44<00:00, 905.22it/s]

Loss (orig, final): 0.046584371477365494 0.00501144677400589
L2 norm of weight change: 0.21154743432998657
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:09<00:00, 28.89it/s]


Post-edit metrics: {'TP': array([6261, 5290, 5403, 3773, 4300, 2865, 5510, 5107, 4548, 4440]), 'TN': array([59261, 61289, 58435, 59984, 59982, 61960, 61308, 61711, 61921,
       61646]), 'FPs': array([3739, 1711, 4565, 3016, 3018, 1040, 1692, 1289, 1079, 1354]), 'FNs': array([ 739, 1710, 1597, 3227, 2700, 4135, 1490, 1893, 2452, 2560]), 'accuracy': 0.6785285714285715, 'per_class_accuracy': array([0.93602857, 0.95112857, 0.91197143, 0.91081429, 0.91831429,
       0.92607143, 0.95454286, 0.95454286, 0.94955714, 0.94408571]), 'per_class_accuracy_mean': 0.9357057142857143, 'precision': array([0.6261    , 0.75560634, 0.54203451, 0.55575195, 0.58759224,
       0.73367478, 0.76506526, 0.79846779, 0.80824596, 0.76630998]), 'precision_mean': 0.6938848803166336, 'recall': array([0.89442857, 0.75571429, 0.77185714, 0.539     , 0.61428571,
       0.40928571, 0.78714286, 0.72957143, 0.64971429, 0.63428571]), 'recall_mean': 0.6785285714285714, 'predicted_class_distribution': array([10000,  7001,  99

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 41.85it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0067, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:48<00:00, 827.15it/s]

Loss (orig, final): 0.1935807168483734 0.006739319767802954
L2 norm of weight change: 0.7001426219940186
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.74it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6641, 5081, 5295, 3426, 4061, 2916, 5320, 4493, 3827,    0]), 'TN': array([48139, 61500, 58624, 60925, 60378, 61901, 61763, 62320, 62510,
       63000]), 'FPs': array([14861,  1500,  4376,  2075,  2622,  1099,  1237,   680,   490,
           0]), 'FNs': array([ 359, 1919, 1705, 3574, 2939, 4084, 1680, 2507, 3173, 7000]), 'accuracy': 0.5865714285714285, 'per_class_accuracy': array([0.78257143, 0.95115714, 0.91312857, 0.9193    , 0.92055714,
       0.92595714, 0.95832857, 0.95447143, 0.94767143, 0.9       ]), 'per_class_accuracy_mean': 0.9173142857142856, 'precision': array([0.30885499, 0.77207111, 0.54751318, 0.62279586, 0.60766123,
       0.72627646, 0.81134665, 0.86854823, 0.88649525, 0.        ]), 'precision_mean': 0.6151562971304431, 'recall': array([0.94871429, 0.72585714, 0.75642857, 0.48942857, 0.58014286,
       0.41657143, 0.76      , 0.64185714, 0.54671429, 0.        ]), 'recall_mean': 0.5865714285714285, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 36.35it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0049, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:46<00:00, 852.56it/s]

Loss (orig, final): 0.09877819567918777 0.0049400427378714085
L2 norm of weight change: 0.31037089228630066
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [06:29<00:00,  1.42s/it]


Post-edit metrics: {'TP': array([6376, 5536, 5288, 3648, 4255, 3032, 5509, 4926, 4639, 3704]), 'TN': array([57660, 60811, 58868, 60388, 60073, 61720, 61354, 61984, 61741,
       62314]), 'FPs': array([5340, 2189, 4132, 2612, 2927, 1280, 1646, 1016, 1259,  686]), 'FNs': array([ 624, 1464, 1712, 3352, 2745, 3968, 1491, 2074, 2361, 3296]), 'accuracy': 0.6701857142857143, 'per_class_accuracy': array([0.9148    , 0.94781429, 0.91651429, 0.9148    , 0.91897143,
       0.92502857, 0.95518571, 0.95585714, 0.94828571, 0.94311429]), 'per_class_accuracy_mean': 0.9340371428571428, 'precision': array([0.54421304, 0.7166343 , 0.56135881, 0.5827476 , 0.59245336,
       0.70315399, 0.76995108, 0.8290138 , 0.78653781, 0.84373576]), 'precision_mean': 0.6929799561305526, 'recall': array([0.91085714, 0.79085714, 0.75542857, 0.52114286, 0.60785714,
       0.43314286, 0.787     , 0.70371429, 0.66271429, 0.52914286]), 'recall_mean': 0.6701857142857143, 'predicted_class_distribution': array([11716,  7725,  94

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:08<00:00, 33.72it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0030, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:53<00:00, 754.15it/s]

Loss (orig, final): 0.05970033258199692 0.003009126055985689
L2 norm of weight change: 0.2282068282365799
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:09<00:00, 29.64it/s]


Post-edit metrics: {'TP': array([6350, 5309, 4651, 4048, 4315, 2835, 5632, 5119, 4707, 4709]), 'TN': array([58676, 61190, 60690, 59440, 59980, 61971, 61156, 61676, 61792,
       61104]), 'FPs': array([4324, 1810, 2310, 3560, 3020, 1029, 1844, 1324, 1208, 1896]), 'FNs': array([ 650, 1691, 2349, 2952, 2685, 4165, 1368, 1881, 2293, 2291]), 'accuracy': 0.6810714285714285, 'per_class_accuracy': array([0.92894286, 0.94998571, 0.93344286, 0.90697143, 0.9185    ,
       0.9258    , 0.95411429, 0.95421429, 0.94998571, 0.94018571]), 'per_class_accuracy_mean': 0.9362142857142859, 'precision': array([0.5949035 , 0.74575081, 0.66815113, 0.5320715 , 0.58827539,
       0.73369565, 0.75334403, 0.79450567, 0.79577346, 0.71294474]), 'precision_mean': 0.6919415882512883, 'recall': array([0.90714286, 0.75842857, 0.66442857, 0.57828571, 0.61642857,
       0.405     , 0.80457143, 0.73128571, 0.67242857, 0.67271429]), 'recall_mean': 0.6810714285714285, 'predicted_class_distribution': array([10674,  7119,  69

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 45.17it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0035, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:43<00:00, 914.09it/s]

Loss (orig, final): 0.11000146716833115 0.0035476258490234613
L2 norm of weight change: 0.4000532925128937
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:12<00:00, 22.76it/s]


Post-edit metrics: {'TP': array([6526, 5280,   36, 4084, 4257, 2724, 5558, 5034, 4523, 4679]), 'TN': array([50619, 61235, 62976, 59475, 60091, 62093, 61385, 61760, 61947,
       61120]), 'FPs': array([12381,  1765,    24,  3525,  2909,   907,  1615,  1240,  1053,
        1880]), 'FNs': array([ 474, 1720, 6964, 2916, 2743, 4276, 1442, 1966, 2477, 2321]), 'accuracy': 0.6100142857142857, 'per_class_accuracy': array([0.81635714, 0.95021429, 0.90017143, 0.90798571, 0.91925714,
       0.92595714, 0.95632857, 0.9542    , 0.94957143, 0.93998571]), 'per_class_accuracy_mean': 0.9220028571428571, 'precision': array([0.34516317, 0.74946771, 0.6       , 0.53673282, 0.59405526,
       0.75020655, 0.77485013, 0.80235894, 0.81115495, 0.71337094]), 'precision_mean': 0.6677360471281366, 'recall': array([0.93228571, 0.75428571, 0.00514286, 0.58342857, 0.60814286,
       0.38914286, 0.794     , 0.71914286, 0.64614286, 0.66842857]), 'recall_mean': 0.6100142857142857, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 44.72it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0323, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:43<00:00, 922.66it/s]

Loss (orig, final): 0.20986032485961914 0.03229523450136185
L2 norm of weight change: 0.6665641665458679
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:09<00:00, 29.27it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6800, 5130, 5080, 4068, 3767, 2412, 5613, 5308,    0, 4313]), 'TN': array([51387, 61449, 59633, 59407, 61029, 62427, 61136, 61171, 63000,
       61852]), 'FPs': array([11613,  1551,  3367,  3593,  1971,   573,  1864,  1829,     0,
        1148]), 'FNs': array([ 200, 1870, 1920, 2932, 3233, 4588, 1387, 1692, 7000, 2687]), 'accuracy': 0.6070142857142857, 'per_class_accuracy': array([0.83124286, 0.95112857, 0.92447143, 0.90678571, 0.92565714,
       0.92627143, 0.95355714, 0.9497    , 0.9       , 0.94521429]), 'per_class_accuracy_mean': 0.9214028571428571, 'precision': array([0.3693043 , 0.76784912, 0.60139695, 0.53100117, 0.65650052,
       0.8080402 , 0.75070215, 0.74372986, 0.        , 0.78978209]), 'precision_mean': 0.6018306367484385, 'recall': array([0.97142857, 0.73285714, 0.72571429, 0.58114286, 0.53814286,
       0.34457143, 0.80185714, 0.75828571, 0.        , 0.61614286]), 'recall_mean': 0.6070142857142857, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:08<00:00, 31.48it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0307, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:43<00:00, 920.10it/s]

Loss (orig, final): 0.15627171099185944 0.030677489936351776
L2 norm of weight change: 0.4097956717014313
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:17<00:00, 15.87it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6704, 5289, 5227, 3913, 3979, 2610, 5666, 5302,    0, 4575]), 'TN': array([53508, 61270, 59198, 59722, 60674, 62237, 60983, 61230, 63000,
       61443]), 'FPs': array([9492, 1730, 3802, 3278, 2326,  763, 2017, 1770,    0, 1557]), 'FNs': array([ 296, 1711, 1773, 3087, 3021, 4390, 1334, 1698, 7000, 2425]), 'accuracy': 0.6180714285714286, 'per_class_accuracy': array([0.86017143, 0.95084286, 0.92035714, 0.90907143, 0.92361429,
       0.92638571, 0.95212857, 0.95045714, 0.9       , 0.94311429]), 'per_class_accuracy_mean': 0.9236142857142857, 'precision': array([0.41392937, 0.75352614, 0.57891239, 0.54415241, 0.63108644,
       0.77379188, 0.73747234, 0.74971719, 0.        , 0.74608611]), 'precision_mean': 0.5928674272517884, 'recall': array([0.95771429, 0.75557143, 0.74671429, 0.559     , 0.56842857,
       0.37285714, 0.80942857, 0.75742857, 0.        , 0.65357143]), 'recall_mean': 0.6180714285714286, 'predicted_class_distribution': array([16196,  7019,  90

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 40.51it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0029, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:43<00:00, 921.51it/s]

Loss (orig, final): 0.12462717294692993 0.0028799879364669323
L2 norm of weight change: 0.5283616781234741
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:12<00:00, 21.44it/s]
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Post-edit metrics: {'TP': array([6562, 5289,    0, 4049, 4187, 2671, 5504, 4973, 4415, 4669]), 'TN': array([49672, 61204, 63000, 59570, 60252, 62155, 61494, 61829, 62036,
       61107]), 'FPs': array([13328,  1796,     0,  3430,  2748,   845,  1506,  1171,   964,
        1893]), 'FNs': array([ 438, 1711, 7000, 2951, 2813, 4329, 1496, 2027, 2585, 2331]), 'accuracy': 0.6045571428571429, 'per_class_accuracy': array([0.80334286, 0.9499    , 0.9       , 0.90884286, 0.92055714,
       0.92608571, 0.95711429, 0.95431429, 0.9493    , 0.93965714]), 'per_class_accuracy_mean': 0.9209114285714286, 'precision': array([0.32991453, 0.7465067 , 0.        , 0.54138254, 0.6037491 ,
       0.75967008, 0.78516405, 0.80940755, 0.82078453, 0.71152088]), 'precision_mean': 0.6108099964062791, 'recall': array([0.93742857, 0.75557143, 0.        , 0.57842857, 0.59814286,
       0.38157143, 0.78628571, 0.71042857, 0.63071429, 0.667     ]), 'recall_mean': 0.6045571428571428, 'predicted_class_distribution': array([

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:09<00:00, 29.35it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0039, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:53<00:00, 745.11it/s]

Loss (orig, final): 0.0627262145280838 0.0038970704190433025
L2 norm of weight change: 0.23487889766693115
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:08<00:00, 31.27it/s]


Post-edit metrics: {'TP': array([6337, 5338, 4558, 4078, 4295, 2859, 5710, 5110, 4726, 4723]), 'TN': array([58787, 61143, 60921, 59369, 60019, 61936, 61031, 61695, 61766,
       61067]), 'FPs': array([4213, 1857, 2079, 3631, 2981, 1064, 1969, 1305, 1234, 1933]), 'FNs': array([ 663, 1662, 2442, 2922, 2705, 4141, 1290, 1890, 2274, 2277]), 'accuracy': 0.6819142857142857, 'per_class_accuracy': array([0.93034286, 0.94972857, 0.93541429, 0.90638571, 0.91877143,
       0.92564286, 0.95344286, 0.95435714, 0.94988571, 0.93985714]), 'per_class_accuracy_mean': 0.9363828571428572, 'precision': array([0.60066351, 0.7419041 , 0.68675606, 0.52899209, 0.59029687,
       0.728779  , 0.7435864 , 0.79657054, 0.79295302, 0.70958534]), 'precision_mean': 0.6920086919867467, 'recall': array([0.90528571, 0.76257143, 0.65114286, 0.58257143, 0.61357143,
       0.40842857, 0.81571429, 0.73      , 0.67514286, 0.67471429]), 'recall_mean': 0.6819142857142857, 'predicted_class_distribution': array([10550,  7195,  66

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 44.51it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0059, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:47<00:00, 837.00it/s]

Loss (orig, final): 0.10483500361442566 0.005947549361735582
L2 norm of weight change: 0.35867226123809814
Performing post-edit metric & KNN calculations on validation set.



100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:11<00:00, 23.20it/s]


Post-edit metrics: {'TP': array([6431, 5311, 3193, 4172, 4329, 2803, 5693, 5016, 4716, 4718]), 'TN': array([55968, 61175, 62437, 59191, 59883, 61992, 61123, 61837, 61734,
       61042]), 'FPs': array([7032, 1825,  563, 3809, 3117, 1008, 1877, 1163, 1266, 1958]), 'FNs': array([ 569, 1689, 3807, 2828, 2671, 4197, 1307, 1984, 2284, 2282]), 'accuracy': 0.6626, 'per_class_accuracy': array([0.89141429, 0.9498    , 0.93757143, 0.90518571, 0.91731429,
       0.92564286, 0.95451429, 0.95504286, 0.94928571, 0.93942857]), 'per_class_accuracy_mean': 0.93252, 'precision': array([0.47767957, 0.74425448, 0.8501065 , 0.52274151, 0.58138598,
       0.73550249, 0.75204756, 0.81178184, 0.7883651 , 0.70671061]), 'precision_mean': 0.6970575628021397, 'recall': array([0.91871429, 0.75871429, 0.45614286, 0.596     , 0.61842857,
       0.40042857, 0.81328571, 0.71657143, 0.67371429, 0.674     ]), 'recall_mean': 0.6626, 'predicted_class_distribution': array([13463,  7136,  3756,  7981,  7446,  3811,  7570,  61

100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 45.60it/s]


Pre-edit metrics: {'TP': array([6142, 5382, 5203, 3842, 4270, 2990, 5659, 5136, 4820, 4679]), 'TN': array([60115, 61110, 59441, 59902, 60157, 61756, 61050, 61747, 61643,
       61202]), 'FPs': array([2885, 1890, 3559, 3098, 2843, 1244, 1950, 1253, 1357, 1798]), 'FNs': array([ 858, 1618, 1797, 3158, 2730, 4010, 1341, 1864, 2180, 2321]), 'accuracy': 0.6874714285714286, 'per_class_accuracy': array([0.94652857, 0.94988571, 0.92348571, 0.91062857, 0.92038571,
       0.92494286, 0.95298571, 0.95547143, 0.94947143, 0.94115714]), 'per_class_accuracy_mean': 0.9374942857142857, 'precision': array([0.68040323, 0.74009901, 0.5938142 , 0.55360231, 0.60030929,
       0.706188  , 0.74372454, 0.80388167, 0.78031407, 0.72240235]), 'precision_mean': 0.6924738665961387, 'recall': array([0.87742857, 0.76885714, 0.74328571, 0.54885714, 0.61      ,
       0.42714286, 0.80842857, 0.73371429, 0.68857143, 0.66842857]), 'recall_mean': 0.6874714285714286, 'predicted_class_distribution': array([9027, 7272, 8762, 

tensor(0.0046, device='cuda:0', grad_fn=<DivBackward0>): 100%|███████| 40000/40000 [00:47<00:00, 834.26it/s]

Loss (orig, final): 0.1403559297323227 0.004618950188159943
L2 norm of weight change: 0.4223584830760956
Performing post-edit metric & KNN calculations on validation set.



 34%|████████████████████████                                              | 94/274 [00:03<00:05, 35.30it/s]

## FIN

In [None]:
# File name passed to load_and_analyze_knn() as knn_analysis_filename.
knn_analysis_filename = 'knn_analysis_results.pth'

# Text file listing the trial directories of a previous edit run;
# presumably one directory path per line -- verify against read_lists().
save_trials_path = 'saved/edit/trials/CINIC10_ImageNet-VGG_16/0125_114341/trial_paths.txt'
trial_dirs = read_lists(save_trials_path)

In [None]:
# Sanity check: report whether the trial-path list is defined and present on disk.
# Only the lookup/existence test sits inside `try`, and only the errors that mean
# "save_trials_path is missing or not a path" are caught: NameError (never
# defined) and TypeError (e.g. set to None). Anything else, including
# KeyboardInterrupt, propagates instead of being silently swallowed.
try:
    trials_path_exists = os.path.exists(save_trials_path)
except (NameError, TypeError):
    print("Need to define save_trials_path.")
else:
    if trials_path_exists:
        print("Obtaining trial paths from {}".format(save_trials_path))
    else:
        print("Path {} does not exist".format(save_trials_path))



## Analyze KNN

In [None]:
## Process KNN results for each trial
# Runs load_and_analyze_knn() on the '<trial_dir>/models' directory of every
# entry in trial_dirs, logging progress to progress_report_analysis.txt next
# to the trial-path list.
# NOTE(review): target_class_idx is not defined in this cell; it must come from
# an earlier cell. Confirm it holds the right class for every trial processed here.
n_trials = len(trial_dirs)
n_log = n_trials // 10 + 1  # log every 10% (only used by the disabled throttle below)
progress_report_path = os.path.join(os.path.dirname(save_trials_path), 'progress_report_analysis.txt')

informal_log("Starting KNN analysis...", progress_report_path)
# enumerate() has no length, so pass the total explicitly; otherwise tqdm can
# only show a raw iteration counter without percentage or ETA.
for trial_idx, trial_dir in tqdm(enumerate(trial_dirs), total=n_trials):
    # if trial_idx % n_log == 0:
    informal_log("Processing {}/{} trials. Currently processing {}".format(
        trial_idx+1, n_trials, os.path.basename(trial_dir)), progress_report_path)

    # Pre-/post-edit KNN files are looked up inside the trial's 'models' folder.
    results_save_dir = os.path.join(trial_dir, 'models')
    load_and_analyze_knn(
        restore_dir=results_save_dir,
        pre_edit_knn_path=os.path.join(results_save_dir, 'pre_edit_{}-nn.pth'.format(K)),
        post_edit_knn_path=os.path.join(results_save_dir, 'post_edit_{}-nn.pth'.format(K)),
        knn_analysis_filename=knn_analysis_filename,
        target_class_idx=target_class_idx,
        class_list=class_list,
        progress_report_path=progress_report_path,
        save_images=False,
        save_plots=True)
    


## Convert to CSV for all trials

In [None]:
# The results table is written alongside the trial-path list file.
csv_save_path = os.path.join(
    os.path.dirname(save_trials_path),
    'results_table.csv',
)
# Hand the trial directories, class names and destination path to store_csv().
store_csv(save_path=csv_save_path, class_list=class_list, trial_dirs=trial_dirs)
