In [5]:
import pandas as pd
import numpy as np
import re
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import sklearn
from datasets import load_metric


  from .autonotebook import tqdm as notebook_tqdm


# Confusion matrices

In [None]:
# Run identifier; it names both the results directory and the files inside it.
model_name = '20230120-024256'
# Results CSV of that run, addressed relative to this notebook.
results_path = f'../../results/{model_name}/results_{model_name}.csv'
df = pd.read_csv(results_path)

In [None]:
def get_label(caption, label):
    """Extract one labelled field from a caption.

    The caption is lowercased and searched for the text that sits between
    the field's own marker and the marker of the field that follows it
    (``core modality: ... modality``, ``plane: ... anatomy``,
    ``anatomy: ... findings``).

    Parameters
    ----------
    caption : str
        Caption text to search. Non-string values (for example the ``NaN``
        pandas produces for an empty CSV cell) are treated as carrying no
        label instead of raising.
    label : str
        Field to extract: ``'core_modality'``, ``'plane'`` or ``'anatomy'``.

    Returns
    -------
    str
        The extracted value, lowercased and stripped of surrounding
        whitespace, or ``'N/A'`` when the field cannot be found.

    Raises
    ------
    KeyError
        If ``label`` is not one of the supported field names.
    """
    patterns = {'core_modality': r'core modality:([ a-z/0-9]+)modality',
                'plane': r'plane:([ a-z/0-9]+)anatomy',
                'anatomy': r'anatomy:([ a-z/0-9]+)findings'}
    # Look the pattern up first so an unsupported label always fails loudly,
    # whatever the caption is.
    pattern = patterns[label]

    # Missing captions read from CSV arrive as float NaN, which has no .lower().
    if not isinstance(caption, str):
        return 'N/A'

    match = re.search(pattern, caption.lower())
    if match is None:
        # The field (or the marker that should follow it) is absent.
        return 'N/A'
    return match.group(1).strip()



In [None]:
# Extract modality, plane and anatomy labels from the captions ---------------------------
# For both the reference ("true") and the generated ("predicted") captions, pull out
# each labelled field into its own column, e.g. df['true_plane'] or
# df['predicted_anatomy']. The loop order creates all true_* columns first, then the
# predicted_* columns, each group in the order core_modality, plane, anatomy.
for source in ('true', 'predicted'):
    for label in ('core_modality', 'plane', 'anatomy'):
        # `args` forwards the field name as get_label's second positional argument.
        df[f'{source}_{label}'] = df[f'{source}_captions'].apply(get_label, args=(label,))

In [None]:
# Write the current frame to a labels CSV inside the same run directory.
labels_savepath = f'../../results/{model_name}/labels_{model_name}.csv'
df.to_csv(labels_savepath)

In [None]:
def plot_conf_matrix(y_true, y_pred, path):
    """Plot the confusion matrix of ``y_pred`` against ``y_true`` and save it.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels.
    y_pred : array-like
        Predicted labels, aligned element-wise with ``y_true``.
    path : str or path-like
        Destination of the saved figure; it is passed straight to
        ``Figure.savefig``.

    Returns
    -------
    None
        The figure is drawn and written to ``path`` as a side effect.
    """
    # Every label seen in either input gets a row/column. np.unique returns
    # them sorted, so the axes are laid out identically on every run
    # (iterating a set of strings is not order-stable across sessions).
    class_labels = np.unique(np.concatenate([y_true, y_pred])).tolist()
    cm = confusion_matrix(y_true, y_pred, labels=class_labels)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                  display_labels=class_labels)

    # Vertical tick labels keep long class names from overlapping.
    fig = disp.plot(xticks_rotation='vertical').figure_
    fig.savefig(path, dpi='figure',
                bbox_inches='tight', pad_inches=0.3,
                facecolor='auto', edgecolor='auto')

In [None]:
# Path template for the figures; the 'xx' placeholder is swapped for a short
# per-label tag when each figure is saved.
fig_savepath = f'../../results/{model_name}/cm_xx_{model_name}.png'

# (label column suffix, file tag) for each confusion matrix to draw.
for column_suffix, tag in (('core_modality', 'cm'),
                           ('plane', 'pl'),
                           ('anatomy', 'an')):
    plot_conf_matrix(y_true=df['true_' + column_suffix],
                     y_pred=df['predicted_' + column_suffix],
                     path=fig_savepath.replace('xx', tag))



# Quantitative scores

## Classification scores

In [41]:
# Weighted F1 between the true and predicted label columns, one score per label.
f1_by_label = {}
for label_name in ('core_modality', 'plane', 'anatomy'):
    y_true = df['true_' + label_name]
    y_pred = df['predicted_' + label_name]
    f1_by_label[label_name] = sklearn.metrics.f1_score(y_true, y_pred, average='weighted')

# Expose each score under its short name as well.
f1w_cm = f1_by_label['core_modality']
f1w_pl = f1_by_label['plane']
f1w_an = f1_by_label['anatomy']

print(f'core modality: {f1w_cm:2.3f}')
print(f'plane: {f1w_pl:2.3f}')
print(f'anatomy: {f1w_an:2.3f}')

KeyError: 'true_core_modality'

## Sacrebleu score

In [42]:
# NOTE(review): this rebinds `results_path` and `df` to a different, hard-coded run.
# Cells above that read label columns such as df['true_core_modality'] will raise
# KeyError if re-run after this cell unless this CSV also contains them; presumably
# that is the KeyError shown under the classification scores. Confirm and consider a
# separate variable name.
results_path = '../../results/20230124-060631/results_20230124-060631.csv'
df = pd.read_csv(results_path)

# Names of every column that holds reference ("true") captions.
true_caption_cols = [col_name for col_name in df.columns if 'true_caption' in col_name]
        
def extract_true_captions_list(row,references,true_caption_cols):
    """Append one row's reference captions to ``references``.

    Meant to be mapped over the rows of a results DataFrame (for example
    with ``df.apply(..., axis=1)``). For the given row, the values found
    under each name in ``true_caption_cols`` are gathered, in that order,
    into a single list which is appended to ``references``. After all rows
    are processed, ``references`` is a list of lists with one inner list
    per row.

    ``references`` is modified in place; the function always returns None.
    """
    references.append([row[col] for col in true_caption_cols])
    return None

# Accumulator that extract_true_captions_list fills with one list of
# reference captions per row.
references = []
# The row-wise apply is used only for its side effect on `references`;
# `args` supplies the function's remaining positional parameters.
df.apply(extract_true_captions_list, axis=1, args=(references, true_caption_cols))
# Generated captions, in the same row order as `references`.
predictions = df['predicted_captions'].tolist()

# Compute the sacrebleu metric.
# NOTE(review): `load_metric` is deprecated in newer `datasets` releases in
# favour of the separate `evaluate` package; confirm against the installed version.
sacrebleu = load_metric('sacrebleu')
sacrebleu_score = sacrebleu.compute(predictions=predictions, references=references)

sacrebleu_score

{'score': 31.143870626415477,
 'counts': [197, 104, 51, 26],
 'totals': [271, 246, 221, 196],
 'precisions': [72.69372693726937,
  42.27642276422764,
  23.076923076923077,
  13.26530612244898],
 'bp': 1.0,
 'sys_len': 271,
 'ref_len': 267}

## ROUGE score
    

In [44]:
# Score the same predictions with the ROUGE metric and display the result.
# NOTE(review): `references` is the nested list built for sacrebleu (one list of
# captions per row). Presumably ROUGE expects a single reference string per
# prediction; confirm how the nested lists are interpreted. The recall reported in
# this cell's output is far lower than the precision, which would be consistent with
# each prediction being compared against several captions at once.
rouge=load_metric('rouge')
rouge_score=rouge.compute(predictions=predictions,references=references)
rouge_score

{'rouge1': AggregateScore(low=Score(precision=0.6898924242424241, recall=0.1294223124856389, fmeasure=0.21757238652019723), mid=Score(precision=0.754121212121212, recall=0.14256209546629506, fmeasure=0.23904744976566936), high=Score(precision=0.8139242424242424, recall=0.15746251986047, fmeasure=0.2618547891544109)),
 'rouge2': AggregateScore(low=Score(precision=0.3411306998556998, recall=0.05993640245028824, fmeasure=0.1018080445512604), mid=Score(precision=0.4446810966810967, recall=0.07625699677353068, fmeasure=0.12957729349058653), high=Score(precision=0.5536443001443002, recall=0.09341542863914934, fmeasure=0.1592912272587512)),
 'rougeL': AggregateScore(low=Score(precision=0.6345136363636363, recall=0.11854956693288969, fmeasure=0.1994236815483046), mid=Score(precision=0.6943333333333332, recall=0.1311301069224129, fmeasure=0.21961716234541218), high=Score(precision=0.7589712121212122, recall=0.1436058336189333, fmeasure=0.23971242107768667)),
 'rougeLsum': AggregateScore(low=Sco