# CS 4650 Final Project: Evaluating Spreadsheet-to-Language Representations for Retrieval-Augmented Generation
Sanjay Jacob, Pranav Aluru, Justin Xu, Aryan Garg

## Setup

In [None]:
# Upgrade pip, then install the retrieval / LLM libraries this notebook imports.
!pip install -q --upgrade pip
!pip install -q faiss-cpu sentence-transformers transformers accelerate bitsandbytes datasets safetensors tqdm tabulate
# Upgrade bitsandbytes to its newest release (it is already installed by the line above).
!pip install -q -U bitsandbytes

In [None]:
import os
from google.colab import userdata
from google.colab import drive
# Mount Google Drive at /content/drive; the project folder used below lives there.
drive.mount('/content/drive')

In [None]:
# Fetch the Hugging Face token from Colab's user data and report whether one
# was found. The token is passed to the tokenizer / model loaders further down.
HF_TOKEN = userdata.get('HF_TOKEN')
print('HF_token provided' if HF_TOKEN else 'add HF_token')

In [None]:
# Make the project folder on the mounted Drive the working directory; the
# relative paths used later ("results", "figures") resolve against it.
%cd /content/drive/My Drive/TabularDataRetrieval---NLP-Final
print(os.getcwd())

In [None]:
import torch

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

# Project folders, all anchored at the current working directory.
Data_dir, Index_dir, Cache_dir = (
    f"{os.getcwd()}/{sub}" for sub in ("data", "indices", "cache")
)

# Create the output folders up front (the data folder is expected to exist).
for folder in (Index_dir, Cache_dir, "results", "figures"):
    os.makedirs(folder, exist_ok=True)

In [None]:
# Retrieval / embedding settings shared by the cells below.
# TOP_K: default number of documents retrieved per query.
# EMBED_BATCH: batch size handed to the sentence encoder.
# EMBED_DIM: placeholder, overwritten once the embedding model is loaded.
TOP_K, EMBED_BATCH, EMBED_DIM = 5, 32, None
# Chunking settings; not referenced anywhere in the visible part of the notebook.
CHUNK_SIZE, CHUNK_OVERLAP = 512, 64

## Data Processing

In [None]:
import zipfile
import re

# Unpack every representation archive found in the data folder.
# An archive whose name starts with R1..R5 is extracted into <Data_dir>/R<n>.
zips = [file for file in os.listdir(Data_dir) if file.endswith('.zip')]
print(zips)

for name in zips:
    loc = os.path.join(Data_dir, name)
    match = re.match(r'(R[1-5])', name)
    if not match:
        # Name does not start with R1..R5, so it is not a representation archive.
        print(f'skipping {name}')
        continue

    out_folder = match.group(1)
    out_dir = os.path.join(Data_dir, out_folder)

    # Only skip when the target already holds files. A folder that exists but
    # is empty (e.g. a previous run failed right after creating it) is
    # extracted again instead of being reported as done.
    if os.path.isdir(out_dir) and os.listdir(out_dir):
        print(f'{out_folder} already exists')
        continue

    os.makedirs(out_dir, exist_ok=True)
    print(f'unzipping {name}')
    with zipfile.ZipFile(loc, 'r') as z:
        z.extractall(out_dir)

## Hybrid Representations

In [None]:
from pathlib import Path

# Hybrid representations: hybrid name -> (primary rep, additional-context rep).
# Each hybrid file is the primary rep's text followed by the second rep's text.
HYBRID_CONFIGS = {
    'R2_R4': ('R2', 'R4'),
    'R3_R4': ('R3', 'R4'),
    'R2_R5': ('R2', 'R5'),
    'R3_R5': ('R3', 'R5'),
}

for hybrid_name, (rep_a, rep_b) in HYBRID_CONFIGS.items():
    hybrid_dir = Path(Data_dir) / hybrid_name
    if hybrid_dir.exists() and any(hybrid_dir.iterdir()):
        print(f"{hybrid_name} already exists")
        continue
    path_a = Path(Data_dir) / rep_a
    path_b = Path(Data_dir) / rep_b
    # Check the sources before creating the output folder so that a skipped
    # hybrid does not leave an empty directory behind.
    if not path_a.exists() or not path_b.exists():
        print(f"skipping {hybrid_name}")
        continue
    hybrid_dir.mkdir(exist_ok=True)
    files_a = sorted(f for f in path_a.iterdir() if f.is_file())
    files_b = sorted(f for f in path_b.iterdir() if f.is_file())

    # Pair files that share a name, so a missing or extra file in one folder
    # cannot shift every later pair onto the wrong table.
    by_name_b = {f.name: f for f in files_b}
    pairs = [(fa, by_name_b[fa.name]) for fa in files_a if fa.name in by_name_b]
    if not pairs:
        # No shared names at all: keep the positional pairing of the two
        # sorted listings.
        pairs = list(zip(files_a, files_b))
    unmatched = max(len(files_a), len(files_b)) - len(pairs)
    if unmatched:
        print(f"  warning: {unmatched} file(s) from {rep_a}/{rep_b} had no counterpart and were left out")

    print(f"creating {hybrid_name} w/ {len(pairs)} files")
    for fa, fb in pairs:
        text_a = fa.read_text(errors='ignore')
        text_b = fb.read_text(errors='ignore')
        hybrid_text = f"{text_a}\n\n ===ADDITIONAL CONTEXT=== \n\n{text_b}"
        out_file = hybrid_dir / fa.name
        out_file.write_text(hybrid_text)
    # Report what was actually written, not the size of the first folder.
    print(f"  {hybrid_name}: {len(pairs)} files created")

print("\nhybrid generation complete")

## REP_PATHS for all representations

In [None]:
from pathlib import Path
# Collect the sorted file list of every representation (base and hybrid).
# Representations whose folder is missing or empty are reported and left out.
REPS = ['R1', 'R2', 'R3', 'R4', 'R5', 'R2_R4', 'R3_R4', 'R2_R5', 'R3_R5']
REP_PATHS = {}
for rep in REPS:
    path = Path(Data_dir) / rep
    if not path.exists():
        print(f"{path} not found")
        continue
    files = sorted(p for p in path.iterdir() if p.is_file())
    if not files:
        print(f"{rep} directory is empty")
        continue
    print(f"{rep}: {len(files)} files")
    REP_PATHS[rep] = files
print(f"\nloaded : {list(REP_PATHS.keys())}")

## Load Models

In [None]:
from sentence_transformers import SentenceTransformer
import numpy as np

# Load the sentence encoder used for retrieval on the selected device and
# record the size of the vectors it produces.
embed_name = 'sentence-transformers/all-mpnet-base-v2'
embed_model = SentenceTransformer(embed_name, device=device)
EMBED_DIM = embed_model.get_sentence_embedding_dimension()
print(f'{embed_name} loaded', f'embed dim {EMBED_DIM}', sep='\n')

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Load the answer-generation model and its tokenizer.
LLAMA3 = 'meta-llama/Llama-3.1-8B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(LLAMA3, token=HF_TOKEN, use_fast=True)
bnb_config = BitsAndBytesConfig(load_in_8bit=True) if device == "cuda" else None

# Arguments common to both loading paths; the device decides the rest:
# 8-bit bitsandbytes quantization on GPU, float16 weights otherwise.
load_kwargs = {"device_map": "auto", "trust_remote_code": True, "token": HF_TOKEN}
if device == 'cuda':
    load_kwargs["quantization_config"] = bnb_config
else:
    load_kwargs["torch_dtype"] = torch.float16
model = AutoModelForCausalLM.from_pretrained(LLAMA3, **load_kwargs)

model.eval()
print("model loaded")

## Helpers

In [None]:
import re
import pickle
import hashlib
import json
from tqdm.auto import tqdm

def count_tokens(text):
    """Return the number of tokens in *text*.

    The notebook-level ``tokenizer`` is used; if that call fails for any
    reason, the count falls back to the number of whitespace-separated words.
    """
    try:
        encoded = tokenizer(text, return_tensors=None)
        n_tokens = len(encoded["input_ids"])
    except Exception:
        n_tokens = len(text.split())
    return n_tokens
def save_pickle(obj, path):
    """Serialize *obj* with pickle into the file at *path* (overwriting it)."""
    with open(path, mode="wb") as sink:
        pickle.dump(obj, sink)
def load_pickle(path):
    """Return the object unpickled from the file at *path*.

    NOTE(review): unpickling can execute arbitrary code, so this should only
    be pointed at files the notebook wrote itself.
    """
    with open(path, mode="rb") as source:
        restored = pickle.load(source)
    return restored

## FAISS Indices

In [None]:
import faiss
import numpy as np

def embed_texts(texts):
    """Encode *texts* with the notebook's sentence encoder.

    Args:
        texts: list of strings to embed.

    Returns:
        The result of ``embed_model.encode`` called with ``EMBED_BATCH`` as
        batch size and ``convert_to_numpy=True``.
    """
    return embed_model.encode(texts, batch_size=EMBED_BATCH, convert_to_numpy=True)

# Build (or load from disk) one inner-product FAISS index per representation.
# faiss_indexes[rep] holds the index plus the texts and file paths, in the
# same order as the vectors were added.
faiss_indexes = {}
for rep, files in REP_PATHS.items():
    idx_path = os.path.join(Index_dir, f"{rep}_faiss.index")
    meta_path = os.path.join(Index_dir, f"{rep}_meta.pkl")
    if os.path.exists(idx_path) and os.path.exists(meta_path):
        print(f"{rep} loading cache")
        meta = load_pickle(meta_path)
        if meta.get("embed_name") == embed_name:
            index = faiss.read_index(idx_path)
            faiss_indexes[rep] = {"index": index, "texts": meta["texts"], "paths": meta["paths"]}
            print(f"{rep} loaded from cache w/ {len(meta['texts'])} entries.")
            continue
        # The cache was produced by a different embedding model. Fall through
        # and rebuild it; skipping here would silently drop the representation
        # from every later step.
        print(f"{rep} cache was built with a different embedding model, rebuilding")

    print(f"{rep} embedding {len(files)} files")
    texts = [f.read_text(errors="ignore") for f in files]
    embs = embed_texts(texts)
    dim = embs.shape[1]
    index = faiss.IndexFlatIP(dim)
    # Vectors are L2-normalized in place before being added to the
    # inner-product index.
    faiss.normalize_L2(embs)
    index.add(embs)
    faiss_indexes[rep] = {"index": index, "texts": texts, "paths": files}
    print(f"{rep} saving to cache")
    faiss.write_index(index, idx_path)
    save_pickle({"texts": texts, "paths": files, "embed_name": embed_name}, meta_path)
    print(f"{rep} done")
print(f"\nfaiis indices built for: {list(faiss_indexes.keys())}")

## Retrieval functions

In [None]:
def retrieve(query, rep="R1", k=5):
    """Return the top-*k* documents of representation *rep* for *query*.

    Args:
        query: query string to embed and search with.
        rep: key into ``faiss_indexes`` selecting the representation.
        k: number of neighbours requested from the index.

    Returns:
        A list of ``(path, text)`` tuples in the order returned by the index.
        It holds fewer than *k* entries when the index contains fewer than
        *k* vectors.

    Raises:
        KeyError: if *rep* has no entry in ``faiss_indexes``.
    """
    store = faiss_indexes[rep]
    q_emb = embed_texts([query])
    # Normalize the query in place before the inner-product search.
    faiss.normalize_L2(q_emb)
    _, hit_ids = store["index"].search(q_emb, k)
    # FAISS pads the id row with -1 when it cannot supply k results; indexing
    # a Python list with -1 would silently return the last document instead.
    return [(store["paths"][i], store["texts"][i]) for i in hit_ids[0] if i >= 0]

def build_context_and_prompt(query, rep="R1", k=TOP_K):
    """Retrieve tables for *query* and wrap them in the fact-verification prompt.

    Args:
        query: the statement to verify; also used as the retrieval query.
        rep: representation whose index is searched.
        k: number of documents to retrieve.

    Returns:
        ``(context, prompt, retrieved)`` where *context* is the retrieved
        texts joined by blank lines (each prefixed with its file name),
        *prompt* is the full instruction text, and *retrieved* is the list of
        ``(path, text)`` pairs from ``retrieve``.
    """
    retrieved = retrieve(query, rep=rep, k=k)
    table_blocks = [f"TABLE FROM {path.name}:\n{text}" for path, text in retrieved]
    context = "\n\n".join(table_blocks)
    prompt = f"""You are a fact verification assistant.

    You are given one or more tables and a natural language statement.
    Determine whether the statement is supported by the table data.

    Reply with a single word: "yes" if the statement is supported (entailed),
    or "no" if the statement is not supported (refuted).

    Context:
    {context}

    Statement: {query}
    Answer (yes or no):"""
    return context, prompt, retrieved

def generate_answer_only(prompt, max_new_tokens=16):
    """Generate a continuation for *prompt* and return only the new text.

    Args:
        prompt: full prompt string fed to the model.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding: the checkpoint's default generation config may
            # enable sampling, which would make the yes/no evaluation differ
            # from run to run.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Drop the echoed prompt tokens so only the model's answer is decoded.
    generated_ids = output[0, prompt_len:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

## Load TabFact

In [None]:
import pandas as pd
from datasets import Dataset
from collections import defaultdict

# Folder holding the TabFact statement files (<split>.tsv).
TABFACT_DIR = "/content/drive/My Drive/TabularDataRetrieval---NLP-Final/data/qa_data"
def load_tabfact_split(split="test"):
    """Load one TabFact split from ``TABFACT_DIR`` as a ``datasets.Dataset``.

    The TSV is read without a header row into six named columns. A
    ``table_id`` (pattern ``digits-digits-digits``) is extracted from the
    ``table_file`` column, and rows without such an id are dropped.

    Args:
        split: file stem of the split to read (``<split>.tsv``).

    Returns:
        A Dataset with columns statement, label (int), table_id, table_file
        and context.
    """
    path_csv = os.path.join(TABFACT_DIR, f"{split}.tsv")
    print("loading", path_csv)
    columns = ["table_file", "col_id", "col_name", "context", "statement", "label"]
    df = pd.read_csv(path_csv, sep="\t", header=None, names=columns)
    df["label"] = df["label"].astype(int)
    table_id_pattern = re.compile(r"\d+-\d+-\d+")

    def extract_table_id(fn):
        # str() so that non-string cells (e.g. NaN) simply fail to match.
        found = table_id_pattern.search(str(fn))
        return None if found is None else found.group(0)

    df["table_id"] = df["table_file"].apply(extract_table_id)
    has_id = df["table_id"].notna()
    df = df[has_id].reset_index(drop=True)
    return Dataset.from_pandas(df[["statement", "label", "table_id", "table_file", "context"]])
# Load the test split once and print its first example as a sanity check.
tabfact_test = load_tabfact_split(split="test")
print(tabfact_test[0])

In [None]:
# For every indexed representation, map each table id found in a file name
# (pattern digits-digits-digits) to the list of files carrying that id.
# Files whose name contains no such id are ignored.
TABLEID_TO_PATHS = defaultdict(dict)
table_id_pattern = re.compile(r"\d+-\d+-\d+")
for rep, data in faiss_indexes.items():
    rep_map = {}
    for p in data["paths"]:
        m = table_id_pattern.search(p.name)
        if m is None:
            continue
        rep_map.setdefault(m.group(0), []).append(p)
    TABLEID_TO_PATHS[rep] = rep_map
for rep, rep_map in TABLEID_TO_PATHS.items():
    print(f"{rep} {len(rep_map)} table ids mapped")

## Evaluation 

In [None]:
import numpy as np
import json
from tqdm.auto import tqdm

def gold_label_to_str(label_int):
    """Map an integer label to its answer word: 1 -> "yes", anything else -> "no"."""
    if int(label_int) == 1:
        return "yes"
    return "no"
def extract_yes_no_label(text):
    """Extract a "yes"/"no" label from the model's raw answer text.

    The text is lower-cased and stripped; if it contains ``answer:``, only
    the part after the last occurrence is considered. The first standalone
    "yes" or "no" in that text is the label.

    Args:
        text: raw generated answer.

    Returns:
        "yes" or "no" when either word occurs; otherwise the first
        whitespace-separated token, or "" for empty text.
    """
    t = text.strip().lower()
    if "answer:" in t:
        t = t.split("answer:")[-1].strip()
    # Take whichever word comes first. Checking for "yes" anywhere before
    # looking for "no" turns answers such as "no ... answer (yes or no)"
    # into "yes".
    first = re.search(r"\b(yes|no)\b", t)
    if first:
        return first.group(1)
    tokens = t.split()
    return tokens[0] if tokens else ""
def compute_em(gold, pred):
    """Return 1 if *gold* and *pred* match ignoring case and outer whitespace, else 0."""
    normalized_gold = gold.strip().lower()
    normalized_pred = pred.strip().lower()
    return 1 if normalized_gold == normalized_pred else 0
def compute_classification_f1(golds, preds, positive_label="yes"):
    """Return the binary F1 score of *preds* against *golds*.

    Args:
        golds: gold labels.
        preds: predicted labels, aligned with *golds*.
        positive_label: label treated as the positive class; every other
            value counts as negative.

    Returns:
        F1 as a float in [0, 1]; 0.0 when there are no true positives.
    """
    tp = fp = fn = 0
    for g, p in zip(golds, preds):
        gold_pos = g == positive_label
        pred_pos = p == positive_label
        if gold_pos and pred_pos:
            tp += 1
        elif pred_pos:
            fp += 1
        elif gold_pos:
            fn += 1
    # F1 = 2*TP / (2*TP + FP + FN). Computing it directly avoids the epsilon
    # terms that kept a perfect prediction from scoring exactly 1.0.
    denom = 2 * tp + fp + fn
    if denom == 0:
        return 0.0
    return 2 * tp / denom

def evaluate_representation(dataset,rep="R1",k=TOP_K, max_examples=200, max_new_tokens=16, log_path=None):
    """Run retrieval + generation over *dataset* for one representation.

    For every example the statement is used as the retrieval query, the model
    answers the resulting prompt, and the extracted yes/no label is compared
    with the gold label.

    Args:
        dataset: examples with ``statement``, ``table_id`` and ``label``
            fields; must support ``select`` and ``len``.
        rep: representation to evaluate; must be a key of ``faiss_indexes``.
        k: number of documents retrieved per statement.
        max_examples: evaluate only the first this-many examples; ``None``
            evaluates everything. Values above the dataset size are clamped.
        max_new_tokens: generation budget per answer.
        log_path: if given, per-example records are written there as JSON
            lines.

    Returns:
        A dict with rep, k, num_examples, EM, F1, Recall@k,
        avg_context_tokens and avg_prompt_tokens.

    Raises:
        ValueError: if *rep* has no index.
    """
    if rep not in faiss_indexes:
        raise ValueError(f"{rep} not found")
    table_map = TABLEID_TO_PATHS.get(rep, {})
    em_list, gold_labels, pred_labels = [], [], []
    recall_hits, recall_total = 0, 0
    ctx_token_counts, prompt_token_counts = [], []
    logs = []
    if max_examples is not None:
        # Clamp so that asking for more examples than exist evaluates the
        # whole dataset instead of raising from select().
        ds = dataset.select(range(min(max_examples, len(dataset))))
    else:
        ds = dataset
    for ex in tqdm(ds, desc=f"Eval {rep}"):
        statement = ex["statement"]
        table_id = ex["table_id"]
        gold_label = gold_label_to_str(ex["label"])
        gold_labels.append(gold_label)
        context, prompt, retrieved = build_context_and_prompt(statement, rep=rep, k=k)
        retrieved_names = [p.name for p, _ in retrieved]

        # recall@k: counted only for examples whose gold table exists in this
        # representation; a hit is any retrieved file belonging to that table.
        has_gold = table_id in table_map
        gold_retrieved = False
        if has_gold:
            recall_total += 1
            gold_names = {p.name for p in table_map[table_id]}
            gold_retrieved = any(name in gold_names for name in retrieved_names)
            if gold_retrieved:
                recall_hits += 1
        ctx_token_counts.append(count_tokens(context))
        prompt_token_counts.append(count_tokens(prompt))
        answer_text = generate_answer_only(prompt, max_new_tokens=max_new_tokens)
        pred_label = extract_yes_no_label(answer_text)
        pred_labels.append(pred_label)
        em_list.append(compute_em(gold_label, pred_label))

        # Error bucket for wrong answers, reusing the retrieval outcome
        # computed above.
        if gold_label == pred_label:
            error_type = "none"
        elif not has_gold:
            error_type = "no_gold_table_in_rep"
        elif gold_retrieved:
            error_type = "prediction_or_representation_error"
        else:
            error_type = "retrieval_error"
        logs.append({
            "statement": statement,
            "table_id": table_id,
            "gold_label": gold_label,
            "pred_label": pred_label,
            "answer_text": answer_text,
            "retrieved_paths": retrieved_names,
            "error_type": error_type,
        })

    metrics = {
        "rep": rep, "k": k, "num_examples": len(ds),
        "EM":float(np.mean(em_list)) if em_list else 0.0,
        "F1":compute_classification_f1(gold_labels, pred_labels),
        f"Recall@{k}": recall_hits / recall_total if recall_total>0 else 0.0,
        "avg_context_tokens":float(np.mean(ctx_token_counts)) if ctx_token_counts else 0.0,
        "avg_prompt_tokens":float(np.mean(prompt_token_counts)) if prompt_token_counts else 0.0,
    }
    if log_path:
        with open(log_path,"w") as f:
            f.writelines(json.dumps(row) + "\n" for row in logs)
        print(f"saved at {log_path}")

    return metrics

In [None]:
NUM_EXAMPLES = None  # None = evaluate the full dataset

# Evaluate every representation that has an index; the per-example log of
# each run is written under results/.
available_reps = list(faiss_indexes.keys())
results_full = {}
for rep in available_reps:
    run_log = f"results/logs_{rep}_{NUM_EXAMPLES or 'full'}.jsonl"
    results_full[rep] = evaluate_representation(
        tabfact_test,
        rep=rep,
        k=TOP_K,
        max_examples=NUM_EXAMPLES,
        max_new_tokens=16,
        log_path=run_log,
    )

## Results Table

In [None]:
import pandas as pd

# One row per representation. Building the frame from a list of records lets
# pandas infer a dtype per column; transposing a dict-of-dicts leaves every
# column as object, on which round() has no effect.
df_results = pd.DataFrame(list(results_full.values()), index=list(results_full.keys()))
column_order = ['rep', 'num_examples', 'EM', 'F1', f'Recall@{TOP_K}', 'avg_context_tokens', 'avg_prompt_tokens']
df_results = df_results[[c for c in column_order if c in df_results.columns]]
# Round only the metric columns that are actually present.
numeric_cols = ['EM', 'F1', f'Recall@{TOP_K}', 'avg_context_tokens', 'avg_prompt_tokens']
df_results = df_results.round({col: 4 for col in numeric_cols if col in df_results.columns})
print(df_results.to_string())
csv_path = f"results/results_{NUM_EXAMPLES or 'full'}.csv"
df_results.to_csv(csv_path, index=True)
print(f"\nrresults saved to {csv_path}")
print(df_results.to_markdown())

## Error Analysis

In [None]:
from collections import Counter

# Tally the error_type field of each run's JSONL log, print a per-rep
# breakdown (most frequent first) and save the counts as a CSV.
error_summary = {}
for rep in available_reps:
    log_path = f"results/logs_{rep}_{NUM_EXAMPLES or 'full'}.jsonl"
    if not os.path.exists(log_path):
        continue
    with open(log_path, "r") as f:
        logs = [json.loads(line) for line in f]
    error_counts = Counter(entry["error_type"] for entry in logs)
    error_summary[rep] = dict(error_counts)
    print(f"\n{rep} error breakdown:")
    total = len(logs)
    for error_type, count in error_counts.most_common():
        print(f"  {error_type}: {count} ({count/total*100:.1f}%)")
# Error types missing for a representation become 0 in the table.
df_errors = pd.DataFrame(error_summary).T.fillna(0).astype(int)
print(df_errors.to_string())
df_errors.to_csv(f"results/errors_{NUM_EXAMPLES or 'full'}.csv")

## Visualizations

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from math import pi

# Global plotting defaults shared by all figures below.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")
plt.rcParams.update({'figure.figsize': (10, 6), 'font.size': 12})

# Two-line tick labels for each representation key.
REP_LABELS = dict([
    ('R1', 'R1: Naive\nSerialization'),
    ('R2', 'R2: Schema\nTemplates'),
    ('R3', 'R3: Header+Row\nContext'),
    ('R4', 'R4: Semantic\nSummaries'),
    ('R5', 'R5: Adjacency\nSummaries'),
    ('R2_R4', 'R2+R4:\nSchema+Semantic'),
    ('R3_R4', 'R3+R4:\nContext+Semantic'),
    ('R2_R5', 'R2+R5:\nSchema+Adjacency'),
    ('R3_R5', 'R3+R5:\nContext+Adjacency'),
])

# Figure 1: one bar chart per metric (EM, F1, Recall@k), one bar per representation.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
metrics = ['EM', 'F1', f'Recall@{TOP_K}']
titles = ['Exact Match (EM)', 'F1 Score', f'Recall@{TOP_K}']
colors = sns.color_palette("husl", len(available_reps))
bar_positions = range(len(available_reps))
tick_labels = [REP_LABELS.get(r, r) for r in available_reps]
for ax, metric, title in zip(axes, metrics, titles):
    values = [results_full[rep][metric] for rep in available_reps]
    bars = ax.bar(bar_positions, values, color=colors)
    ax.set_xticks(bar_positions)
    ax.set_xticklabels(tick_labels, fontsize=8, rotation=45, ha='right')
    ax.set_ylabel('Score')
    ax.set_title(title, fontweight='bold')
    ax.set_ylim(0, 1.0)
    # Print each value just above its bar.
    for bar, val in zip(bars, values):
        label_x = bar.get_x() + bar.get_width()/2
        label_y = bar.get_height() + 0.02
        ax.text(label_x, label_y, f'{val:.3f}', ha='center', va='bottom', fontsize=8)
plt.suptitle('Performance Metrics Across Representations', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('figure1_performance_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

# Figure 2: average context vs. total prompt token counts, side by side per representation.
fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(available_reps))
width = 0.35
ctx_tokens, prompt_tokens = [], []
for rep in available_reps:
    ctx_tokens.append(results_full[rep]['avg_context_tokens'])
    prompt_tokens.append(results_full[rep]['avg_prompt_tokens'])
# The two series are offset half a bar width to either side of each tick.
bars1 = ax.bar(x - width/2, ctx_tokens, width, label='Context Tokens', color='#3498db')
bars2 = ax.bar(x + width/2, prompt_tokens, width, label='Total Prompt Tokens', color='#e74c3c')
ax.set_xlabel('Representation')
ax.set_ylabel('Average Token Count')
ax.set_title('Token Efficiency by Representation', fontsize=14, fontweight='bold')
ax.set_xticks(x)
tick_labels = [REP_LABELS.get(r, r) for r in available_reps]
ax.set_xticklabels(tick_labels, fontsize=8, rotation=45, ha='right')
ax.legend()
plt.tight_layout()
plt.savefig('figure2_token_efficiency.png', dpi=300, bbox_inches='tight')
plt.show()

# Figure 3: stacked bars with the share of each outcome per representation.
fig, ax = plt.subplots(figsize=(14, 6))
# These keys must match the error_type values written to the evaluation logs;
# with any other spelling the lookup finds nothing and the segment is drawn
# as 0%.
error_types = ['none', 'retrieval_error', 'prediction_or_representation_error', 'no_gold_table_in_rep']
error_labels = ['Correct', 'Retrieval Error', 'Prediction Error', 'Missing Table']
error_colors = ['#2ecc71', '#e74c3c', '#f39c12', '#9b59b6']
bottom = np.zeros(len(available_reps))
for error_type, label, color in zip(error_types, error_labels, error_colors):
    values = []
    for rep in available_reps:
        rep_counts = error_summary.get(rep, {})
        total = sum(rep_counts.values())
        count = rep_counts.get(error_type, 0)
        pct = (count / total * 100) if total > 0 else 0
        values.append(pct)
    # Stack this segment on top of the ones already drawn.
    ax.bar(range(len(available_reps)), values, bottom=bottom, label=label, color=color)
    bottom += np.array(values)
ax.set_xticks(range(len(available_reps)))
ax.set_xticklabels([REP_LABELS.get(r, r) for r in available_reps], fontsize=8, rotation=45, ha='right')
ax.set_ylabel('Percentage (%)')
ax.set_title('Error Type Distribution by Representation', fontsize=14, fontweight='bold')
ax.legend(loc='upper right', bbox_to_anchor=(1.15, 1))
ax.set_ylim(0, 100)
plt.tight_layout()
plt.savefig('figure3_error_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

# Figure 4: radar chart of EM, F1, Recall@k and a token-efficiency score.
prompt_token_means = [results_full[rep]['avg_prompt_tokens'] for rep in available_reps]
max_tokens = max(prompt_token_means)
min_tokens = min(prompt_token_means)
categories = ['EM', 'F1', f'Recall@{TOP_K}', 'Token Efficiency']
N = len(categories)
angles = [n / float(N) * 2 * pi for n in range(N)]
angles += angles[:1]  # repeat the first angle to close the polygon
fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(polar=True))
colors = sns.color_palette("husl", len(available_reps))
# Small constant keeps the division defined when all reps use the same token count.
token_span = max_tokens - min_tokens + 1e-8
for idx, rep in enumerate(available_reps):
    rep_result = results_full[rep]
    # Token efficiency: 1 for the smallest average prompt, falling towards 0
    # for the largest.
    efficiency = 1 - (rep_result['avg_prompt_tokens'] - min_tokens) / token_span
    values = [rep_result['EM'], rep_result['F1'], rep_result[f'Recall@{TOP_K}'], efficiency]
    values += values[:1]
    ax.plot(angles, values, 'o-', linewidth=2, label=rep, color=colors[idx])
    ax.fill(angles, values, alpha=0.1, color=colors[idx])
ax.set_xticks(angles[:-1])
ax.set_xticklabels(categories, fontsize=11)
ax.set_ylim(0, 1)
ax.set_title('Multi-Metric Comparison', fontsize=14, fontweight='bold', y=1.08)
ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))
plt.tight_layout()
plt.savefig('figure4_radar_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

# Figure 5: EM against average prompt length, one marker per representation.
fig, ax = plt.subplots(figsize=(12, 8))
colors = sns.color_palette("husl", len(available_reps))
markers = ['o', 's', '^', 'D', 'p', 'h', '*', 'X', 'v']
for idx, rep in enumerate(available_reps):
    em = results_full[rep]['EM']
    tokens = results_full[rep]['avg_prompt_tokens']
    marker = markers[idx % len(markers)]  # cycle if there are more reps than markers
    ax.scatter(tokens, em, s=200, c=[colors[idx]], marker=marker,
               label=rep, edgecolors='black', linewidth=1.5, zorder=5)
    ax.annotate(rep, (tokens, em), xytext=(10, 5), textcoords='offset points',
                fontsize=9, fontweight='bold')
ax.set_xlabel('Average Prompt Tokens', fontsize=12)
ax.set_ylabel('Exact Match Score ', fontsize=12)
ax.set_title('Accuracy vs Token Efficiency Trade-off', fontsize=14, fontweight='bold')
# Dashed guide lines at the mean EM and the mean prompt length.
mean_em = np.mean([results_full[r]['EM'] for r in available_reps])
mean_tokens = np.mean([results_full[r]['avg_prompt_tokens'] for r in available_reps])
ax.axhline(y=mean_em, color='gray', linestyle='--', alpha=0.5)
ax.axvline(x=mean_tokens, color='gray', linestyle='--', alpha=0.5)
ax.legend(loc='best')
plt.tight_layout()
plt.savefig('figure5_accuracy_vs_efficiency.png', dpi=300, bbox_inches='tight')
plt.show()

# Figure 6: heatmap of EM / F1 / Recall@k (columns) per representation (rows).
fig, ax = plt.subplots(figsize=(10, 8))
metrics_for_heatmap = ['EM', 'F1', f'Recall@{TOP_K}']
rows = []
for rep in available_reps:
    rows.append([results_full[rep][m] for m in metrics_for_heatmap])
data = np.array(rows)
im = ax.imshow(data, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
ax.set_xticks(range(len(metrics_for_heatmap)))
ax.set_xticklabels(['Exact Match', 'F1 Score', f'Recall@{TOP_K}'], fontsize=11)
ax.set_yticks(range(len(available_reps)))
ax.set_yticklabels(available_reps, fontsize=11)
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Score', fontsize=11)
# Write each cell's value at its centre (x = column, y = row).
for (i, j), cell in np.ndenumerate(data):
    ax.text(j, i, f'{cell:.3f}', ha='center', va='center',
            color='black', fontsize=10, fontweight='bold')

ax.set_title('Performance Metrics Heatmap', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('figure6_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()

#much of the help in creating figures came from these two sources:
#https://www.machinelearningplus.com/plots/top-50-matplotlib-visualizations-the-master-plots-python/
#https://python-graph-gallery.com/
