In [1]:
import numpy as np
import matplotlib.pyplot as plt
import re, os

In [None]:
def compute_wer_loss(data):
    """Parse evaluation-log lines into per-checkpoint, per-dataset metrics.

    A line containing ``EVAL MODEL`` opens a new checkpoint section. Its key is
    ``"<iter>_iter_<id>"``: ``<iter>`` is the number found in front of
    ``_iter`` on that line (``'1'`` when there is none) and ``<id>`` is the last
    path component of the second-to-last space-separated token with the
    ``checkpoint-`` text removed.

    Every following line that carries a ``DATASET: <name>`` field plus at least
    one of ``WER:``, ``LOSS:`` or ``CALLSIGN WER:`` is stored under the current
    checkpoint. Lines seen before the first ``EVAL MODEL`` line are ignored.

    Args:
        data: Iterable of log lines (strings).

    Returns:
        dict mapping checkpoint key -> dataset name -> metrics dict with the
        keys ``'wer'`` and ``'loss'`` (``0.0`` when the field is absent) and,
        only when the line reports it, ``'cal_wer'``.
    """
    out = {}
    name = None
    for line in data:
        if "EVAL MODEL" in line:
            # `\d+` keeps multi-digit iteration numbers intact; a single `\d`
            # would turn "12_iter" into iteration "2".
            iter_match = re.search(r'(\d+)_iter', line)
            iter_number = iter_match.group(1) if iter_match is not None else '1'
            checkpoint = line.strip().split(' ')[-2].split('/')[-1]
            name = iter_number + '_iter_' + checkpoint.replace('checkpoint-', '')
            # A repeated header for the same checkpoint starts from scratch.
            out[name] = {}
            continue

        if name is None:
            # No checkpoint header seen yet: nothing to attach metrics to.
            continue

        dataset = re.search(r'DATASET: ([a-zA-Z_]+)', line)
        # The lookbehind stops the plain-WER pattern from matching the tail of
        # "CALLSIGN WER: ...", which would otherwise be reported as the WER
        # whenever the callsign field comes first or is the only one present.
        wer = re.search(r'(?<!CALLSIGN )WER: ([\d.]+)', line)
        loss = re.search(r'LOSS: ([\d.]+)', line)
        cal_wer = re.search(r'CALLSIGN WER: ([\d.]+)', line)
        if dataset is None or (wer is None and loss is None and cal_wer is None):
            continue

        metrics = {
            'wer': float(wer.group(1)) if wer is not None else 0.0,
            'loss': float(loss.group(1)) if loss is not None else 0.0,
        }
        if cal_wer is not None:
            metrics['cal_wer'] = float(cal_wer.group(1))
        out[name][dataset.group(1)] = metrics
    return out

In [None]:
def compute_total_wer(out, num_examples):
    """Collect per-dataset metric series and example-weighted totals.

    Args:
        out: dict mapping checkpoint name -> dataset name -> metrics dict with
            the keys ``'wer'`` and ``'loss'`` and optionally ``'cal_wer'``.
        num_examples: dict mapping dataset name -> number of examples. Datasets
            missing from this mapping are ignored entirely.

    Returns:
        Tuple ``(checkpoints, wer, loss, cal_wer)``. ``checkpoints`` lists the
        keys of ``out`` in iteration order. The other three are dicts mapping a
        dataset name (plus ``'total'``) to a list of values. ``'total'`` is the
        sum of ``value * num_examples[dataset]`` over the counted datasets of a
        checkpoint divided by ``sum(num_examples.values())``. The divisor is
        always the full sum, even when a checkpoint lacks some datasets.

        ``cal_wer['total']`` only receives an entry for checkpoints in which at
        least one counted dataset reports ``'cal_wer'``, so it can be shorter
        than ``checkpoints``.

    Raises:
        ZeroDivisionError: if ``out`` is non-empty and the example counts sum
            to zero.
    """
    checkpoints = []
    wer = {'total': []}
    cal_wer = {'total': []}
    loss = {'total': []}
    # Loop-invariant divisor for the weighted averages.
    total_examples = sum(num_examples.values())

    for checkpoint, ds_set in out.items():
        checkpoints.append(checkpoint)
        total_wer = 0
        total_loss = 0
        total_cal_wer = 0
        # Tracked per checkpoint instead of inspecting the inner loop variable
        # after the loop: that variable is unbound for an empty checkpoint and
        # otherwise refers to whichever dataset happened to be iterated last
        # (possibly one that was skipped, or one from a previous checkpoint).
        has_cal_wer = False

        for ds_name, wer_loss in ds_set.items():
            if ds_name not in num_examples:
                continue
            weight = num_examples[ds_name]

            wer.setdefault(ds_name, []).append(wer_loss['wer'])
            loss.setdefault(ds_name, []).append(wer_loss['loss'])
            # The series is created even when this dataset has no callsign WER.
            cal_series = cal_wer.setdefault(ds_name, [])
            if 'cal_wer' in wer_loss:
                cal_series.append(wer_loss['cal_wer'])
                total_cal_wer += wer_loss['cal_wer'] * weight
                has_cal_wer = True

            total_wer += wer_loss['wer'] * weight
            total_loss += wer_loss['loss'] * weight

        # average wer and loss
        wer['total'].append(total_wer / total_examples)
        loss['total'].append(total_loss / total_examples)
        if has_cal_wer:
            cal_wer['total'].append(total_cal_wer / total_examples)
    return checkpoints, wer, loss, cal_wer

In [None]:
# Evaluation logs to summarise; (un)comment entries to choose which runs are reported.
files = [
    # ===========================================
    # BASE SHORTTS+FULLTS
    # ===========================================
    # "fullts/allds-atcoen/mp_eval.txt",
    # "fullts/allds-atcoen/vp_eval.txt",
    # "shortts/allds-atcoen/mp_eval.txt",
    # "shortts/allds-atcoen/vp_eval.txt"
    # ===========================================
    # PROMPT FULLTS
    # ===========================================
    "fullts/vanmed-full/AG4B_eval.txt",
    "fullts/vanmed-full/AG40B_eval.txt",
    "fullts/vanmed-full/AG50CZB_eval.txt",
    "fullts/vanmed-full/AG_eval.txt",
    "fullts/vanmed-full/5B_eval.txt",
    "fullts/vanmed-full/40B_eval.txt",
    "fullts/vanmed-full/50CZB_eval.txt",
]
# SAVE=False

# number of examples in dataset; used as weights for the 'total' averages.
# Identical for every log file, so it is built once instead of on every iteration.
num_examples = {
    'atco_en_ruzyne': 70,
    'atco_en_stefanik': 53,
    'atco_en_zurich': 412,
}

for file in files:
    print('************************************************************************************')
    print(file)
    print('************************************************************************************')
    # The context manager closes the handle as soon as the lines are read,
    # instead of leaving it open until garbage collection.
    with open(file, 'r') as log_file:
        data = log_file.readlines()

    out = compute_wer_loss(data)
    # print(json.dumps(out, indent=2))
    checkpoints, wer, loss, cal_wer = compute_total_wer(out, num_examples)
    # plot_wer_loss(wer,loss, cal_wer, checkpoints, file)
    # NOTE(review): plot_wer_loss_to_text is not defined in this section of the
    # notebook; confirm its defining cell runs before this one.
    plot_wer_loss_to_text(wer, loss, cal_wer, checkpoints, file)