In [68]:
import csv
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

RESULTS_FOLDER = Path.cwd().parent.parent / 'work' / 'results'
SUBSETS = ['train_cnn', 'val', 'train_rnn', 'test']
FOLDS = [str(fold) for fold in range(1, 11)]


BEST_MODELS = {
    'test': [
        ('hp_pr', '29'),
        ('hp_po', '63'),
        ('hp_tf', '5'),
        ('hp_cr', '56'),
        ('hp_sh', '66'),
        ('hp_ro', '14'),
    ],
    'val': [
        ('hp_pr', '57'),
        ('hp_po', '63'),
        ('hp_tf', '69'),
        ('hp_cr', '55'),
        ('hp_sh', '72'),
        ('hp_ro', '33'),
    ]
}

In [7]:
def get_metrics(exp_name, model_id):
    metrics_path = RESULTS_FOLDER / exp_name / f'model_{model_id}' / 'metrics.json'
    
    with open(metrics_path, 'r') as metrics_json:
        metrics_file = json.load(metrics_json)

    return metrics_file['metrics']

In [56]:
def generate_model_f1_scores(model_type, model_id, data_set_name):
    model_metrics = get_metrics(model_type, model_id)
    return [
        {
            'pb_name': (fold, pb_name),
            'f1': model_metrics[fold][data_set_name][pb_name]['f1']
        }
        for fold in FOLDS
        for pb_name in model_metrics[fold][data_set_name].keys()

    ]

In [57]:
def find_pullbacks_by_f1(pullbacks, f1_value):
    return filter(lambda item: item['f1'] == f1_value, pullbacks)

In [58]:
def find_edge_pullbacks(pullbacks):
    flag_f1 = pullbacks[0]['f1']
    edge_pullbacks = []
    
    for pb in pullbacks:
        if pb['f1'] != flag_f1:
            flag_f1 = pb['f1']
            if len(edge_pullbacks) >= 3:
                break
        edge_pullbacks.append(pb)
    
    return edge_pullbacks

In [59]:
def get_edge_pullbacks(model_type, model_id, data_set_name):
    f1_scores = generate_model_f1_scores(model_type, model_id, data_set_name)
    sorted_f1_scores = sorted(f1_scores, key=lambda item: item['f1'])
    
    return {
        'worst': find_edge_pullbacks(sorted_f1_scores),
        'best': find_edge_pullbacks(sorted_f1_scores[::-1])
    }

In [60]:
get_edge_pullbacks('hp_pr', '57', 'val')

{'worst': [{'pb_name': ('3', '68_PDW49G5N'), 'f1': 0.0},
  {'pb_name': ('4', '24_PDKL3TD8'), 'f1': 0.0},
  {'pb_name': ('5', '51_PDBB4VBN'), 'f1': 0.0},
  {'pb_name': ('7', '20_PD6M9KHG'), 'f1': 0.0},
  {'pb_name': ('7', '06_PDLKTHJP'), 'f1': 0.0},
  {'pb_name': ('8', '05_PDP5ZUCK'), 'f1': 0.0},
  {'pb_name': ('9', '48_PDKS3P67'), 'f1': 0.0},
  {'pb_name': ('10', '43_PDQWC1XR'), 'f1': 0.0}],
 'best': [{'pb_name': ('10', '65_PDY5C128'), 'f1': 1.0},
  {'pb_name': ('9', '69_PDNVGH7Z'), 'f1': 1.0},
  {'pb_name': ('7', '64_PDJJP4QL'), 'f1': 1.0},
  {'pb_name': ('4', '70_PDUPMC2M'), 'f1': 1.0}]}

In [61]:
def get_all_edge_pullbacks(data_set_name):
    return {
        (model_type, model_id): get_edge_pullbacks(model_type, model_id, data_set_name)
        for model_type, model_id in BEST_MODELS[data_set_name]
    }

In [62]:
all_edge_pullbacks = get_all_edge_pullbacks('val')
len(all_edge_pullbacks)

6

In [63]:
def pullback_histogram(pullbacks, key):
    result = {}
    
    for model, data in pullbacks.items():
        for pb in data[key]:
            if pb['pb_name'] in result:
                result[pb['pb_name']] += 1
            else:
                result[pb['pb_name']] = 1
    
    return result

In [64]:
best_pullbacks_histogram = pullback_histogram(all_edge_pullbacks, 'best')

In [65]:
worst_pullbacks_histogram = pullback_histogram(all_edge_pullbacks, 'worst')

In [66]:
max_counter_value = max(best_pullbacks_histogram.values())
best_pullbacks = [pb_name for pb_name, counter in best_pullbacks_histogram.items() if counter == max_counter_value]
best_pullbacks

[('10', '65_PDY5C128'),
 ('9', '69_PDNVGH7Z'),
 ('7', '64_PDJJP4QL'),
 ('4', '70_PDUPMC2M')]

In [67]:
max_counter_value = max(worst_pullbacks_histogram.values())
worst_pullbacks = [pb_name for pb_name, counter in worst_pullbacks_histogram.items() if counter == max_counter_value]
worst_pullbacks

[('7', '20_PD6M9KHG'), ('8', '05_PDP5ZUCK'), ('10', '43_PDQWC1XR')]

In [83]:
def get_pullback_prediction(model, pullback, data_set_name, real=False):
    model_type, model_id = model
    fold, pb_name = pullback

    prediction_path = (
        RESULTS_FOLDER / model_type / f'model_{model_id}'
        / 'predictions' / fold / data_set_name.upper() / f'{pb_name}.csv'
    )
    
    with open(prediction_path, 'r') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        return [
            row[1 if real else 0]
            for row in csv_reader
        ]

In [85]:
values = get_pullback_prediction(BEST_MODELS['val'][0], best_pullbacks[0], 'val', real=True)
len(values)

67

In [88]:
def get_all_pullback_predictions(data_set_name):
    all_pullback_predictions = {
        pb_name: {
            model: get_pullback_prediction(model, (fold, pb_name), data_set_name)
            for model in BEST_MODELS[data_set_name]
        }
        for fold, pb_name in [*best_pullbacks, *worst_pullbacks]
    }

    for fold, pb_name in [*best_pullbacks, *worst_pullbacks]:
        all_pullback_predictions[pb_name]['real'] = get_pullback_prediction(BEST_MODELS[data_set_name][0], (fold, pb_name), data_set_name, real=True)
    
    return all_pullback_predictions

In [89]:
all_predictions = get_all_pullback_predictions('val')
len(all_predictions)

7

In [90]:
all_predictions

{'65_PDY5C128': {('hp_pr', '57'): ['0.18729616701602936',
   '0.18388748168945312',
   '0.1768369972705841',
   '0.16464954614639282',
   '0.14577552676200867',
   '0.12344910949468613',
   '0.10092870146036148',
   '0.0812835767865181',
   '0.06444144994020462',
   '0.05093476548790932',
   '0.040659721940755844',
   '0.03339049965143204',
   '0.028377309441566467',
   '0.02474644035100937',
   '0.02256985381245613',
   '0.021258432418107986',
   '0.020717956125736237',
   '0.02023935317993164',
   '0.01969520002603531',
   '0.0182176623493433',
   '0.016208263114094734',
   '0.013652096502482891',
   '0.011084593832492828',
   '0.009232127107679844',
   '0.00749211385846138',
   '0.006706266663968563',
   '0.006426114123314619',
   '0.007315034978091717',
   '0.009599120356142521',
   '0.01378605142235756',
   '0.019422458484768867',
   '0.03062065690755844',
   '0.050593599677085876',
   '0.0763104259967804',
   '0.1117815300822258',
   '0.16479991376399994',
   '0.16213268041610718