In [None]:
# Mount Google Drive at /content/drive (needs the google.colab package, i.e. a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install -U plotly
!pip install --upgrade "kaleido==0.1.*"
import os
# Ensure the "images" directory exists. exist_ok=True makes the cell safe to
# re-run and avoids the check-then-create race of exists() followed by mkdir().
os.makedirs("images", exist_ok=True)

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import numpy as np

In [None]:
# Marker symbol names (seven entries).
symbols = [
    'circle',
    'square',
    'triangle-right',
    'star',
    'x',
    'triangle-left',
    'circle-open',
]

# Plotly's built-in qualitative color sequence.
colors = px.colors.qualitative.Plotly

# Model names as listed in this notebook.
models = [
    'kNN-LM',
    'REALM',
    'DPR + FiD',
    'Contriever + ATLAS',
    'Contriever + Flan-T5',
]

# Fixed model -> color assignment: the i-th name below gets colors[i].
# Note that this order differs from `models`.
color_map = {
    name: colors[idx]
    for idx, name in enumerate(
        ('REALM', 'DPR + FiD', 'Contriever + Flan-T5', 'kNN-LM', 'Contriever + ATLAS')
    )
}

### Data

In [None]:
# Gold + Distracting Facts
#
# Layout: task ("qa" / "lm") -> setting key ("1".."4") -> model name -> metric name
# -> list of ten scores stored as 4-decimal strings.
# "qa" entries carry F1 / Precision / Recall. "lm" entries carry
# "Target ranking accuracy", plus "Hits@1" for every model except REALM and
# "Hits@5" for kNN-LM only.
# NOTE(review): the meaning of the setting keys and of the ten list positions is
# not stated in this cell; confirm against the plotting code.
ret_lm_performance = {
    "qa": {
        "1": {
            "kNN-LM": {
                "F1": ["0.1398", "0.1294", "0.1279", "0.1281", "0.1296", "0.1262", "0.1262", "0.1262", "0.1262", "0.1274"],
                "Precision": ["0.1215", "0.1133", "0.1111", "0.1119", "0.1137", "0.1095", "0.1095", "0.1095", "0.1095", "0.1105"],
                "Recall": ["0.2329", "0.2128", "0.2128", "0.2119", "0.2115", "0.2095", "0.2095", "0.2095", "0.2095", "0.2110"],
            },
            "REALM": {
                "F1": ["0.1860", "0.1943", "0.1943", "0.1943", "0.1943", "0.1943", "0.1943", "0.1943", "0.1943", "0.1943"],
                "Precision": ["0.2423", "0.2497", "0.2497", "0.2497", "0.2497", "0.2497", "0.2497", "0.2497", "0.2497", "0.2497"],
                "Recall": ["0.1773", "0.1857", "0.1857", "0.1857", "0.1857", "0.1857", "0.1857", "0.1857", "0.1857", "0.1857"],
            },
            "DPR + FiD": {
                "F1": ["0.2312", "0.2670", "0.2930", "0.3089", "0.3153", "0.3214", "0.3214", "0.3214", "0.3214", "0.3214"],
                "Precision": ["0.2762", "0.3148", "0.3454", "0.3615", "0.3703", "0.3735", "0.3735", "0.3735", "0.3735", "0.3735"],
                "Recall": ["0.2238", "0.2617", "0.2906", "0.3049", "0.3099", "0.3153", "0.3153", "0.3153", "0.3153", "0.3153"],
            },
            "Contriever + ATLAS": {
                "F1": ["0.2227", "0.3120", "0.3485", "0.3595", "0.3679", "0.3626", "0.3626", "0.3626", "0.3626", "0.3626"],
                "Precision": ["0.2546", "0.3558", "0.4089", "0.4197", "0.4278", "0.4214", "0.4214", "0.4214", "0.4214", "0.4214"],
                "Recall": ["0.2275", "0.3126", "0.3455", "0.3580", "0.3659", "0.3623", "0.3623", "0.3623", "0.3623", "0.3623"],
            },
            "Contriever + Flan-T5": {
                "F1": ["0.2190", "0.3671", "0.4314", "0.4381", "0.4492", "0.4533", "0.4540", "0.4540", "0.4540", "0.4540"],
                "Precision": ["0.2579", "0.4250", "0.5040", "0.5072", "0.5167", "0.5237", "0.5246", "0.5246", "0.5246", "0.5246"],
                "Recall": ["0.2113", "0.3541", "0.4143", "0.4231", "0.4351", "0.4385", "0.4391", "0.4391", "0.4391", "0.4391"],
            },
        },
        "2": {
            "kNN-LM": {
                "F1": ["0.0790", "0.0813", "0.0792", "0.0771", "0.0737", "0.0712", "0.0686", "0.0660", "0.0637", "0.0611"],
                "Precision": ["0.0743", "0.0760", "0.0736", "0.0720", "0.0687", "0.0658", "0.0625", "0.0610", "0.0640", "0.0604"],
                "Recall": ["0.1148", "0.1193", "0.1110", "0.1080", "0.1055", "0.1004", "0.0983", "0.0947", "0.0905", "0.0868"],
            },
            "REALM": {
                "F1": ["0.1189", "0.1233", "0.1314", "0.1314", "0.1314", "0.1314", "0.1314", "0.1314", "0.1314", "0.1314"],
                "Precision": ["0.1563", "0.1607", "0.1703", "0.1703", "0.1703", "0.1703", "0.1703", "0.1703", "0.1703", "0.1703"],
                "Recall": ["0.1127", "0.1168", "0.1262", "0.1262", "0.1262", "0.1262", "0.1262", "0.1262", "0.1262", "0.1262"],
            },
            "DPR + FiD": {
                "F1": ["0.0552", "0.0644", "0.0804", "0.0974", "0.1135", "0.1562", "0.2124", "0.2475", "0.2732", "0.2732"],
                "Precision": ["0.0715", "0.0757", "0.0901", "0.1215", "0.1423", "0.1888", "0.2545", "0.2891", "0.3179", "0.3179"],
                "Recall": ["0.0520", "0.0638", "0.0803", "0.0926", "0.1052", "0.1497", "0.2034", "0.2415", "0.2650", "0.2650"],
            },
            "Contriever + ATLAS": {
                "F1": ["0.1874", "0.2373", "0.2767", "0.2839", "0.2906", "0.3317", "0.3348", "0.3314", "0.3374", "0.3374"],
                "Precision": ["0.2083", "0.2714", "0.3216", "0.3323", "0.3433", "0.3947", "0.4048", "0.3971", "0.4015", "0.4015"],
                "Recall": ["0.1987", "0.2430", "0.2713", "0.2811", "0.2849", "0.3237", "0.3250", "0.3206", "0.3268", "0.3268"],
            },
            "Contriever + Flan-T5": {
                "F1": ["0.1967", "0.2480", "0.2928", "0.3065", "0.3031", "0.3539", "0.3800", "0.3753", "0.3968", "0.3968"],
                "Precision": ["0.2299", "0.2848", "0.3387", "0.3490", "0.3516", "0.4050", "0.4344", "0.4245", "0.4540", "0.4540"],
                "Recall": ["0.1911", "0.2424", "0.2858", "0.3025", "0.2946", "0.3489", "0.3713", "0.3678", "0.3876", "0.3876"],
            },
        },
        "3": {
            "kNN-LM": {
                "F1": ["0.1114", "0.1089", "0.1095", "0.1056", "0.1065", "0.1073", "0.1084", "0.1108", "0.1072", "0.1055"],
                "Precision": ["0.1056", "0.1033", "0.1046", "0.1001", "0.1003", "0.1019", "0.1008", "0.1043", "0.1010", "0.1007"],
                "Recall": ["0.1534", "0.1492", "0.1519", "0.1489", "0.1571", "0.1456", "0.1546", "0.1600", "0.1583", "0.1548"],
            },
            "REALM": {
                "F1": ["0.1381", "0.1573", "0.1610", "0.1610", "0.1610", "0.1639", "0.1639", "0.1639", "0.1639", "0.1639"],
                "Precision": ["0.1824", "0.2025", "0.2044", "0.2044", "0.2044", "0.2073", "0.2073", "0.2073", "0.2073", "0.2073"],
                "Recall": ["0.1346", "0.1531", "0.1569", "0.1569", "0.1569", "0.1598", "0.1598", "0.1598", "0.1598", "0.1598"],
            },
            "DPR + FiD": {
                "F1": ["0.0877", "0.1227", "0.1614", "0.1769", "0.1962", "0.2523", "0.2803", "0.3002", "0.3227", "0.3227"],
                "Precision": ["0.1196", "0.1602", "0.2035", "0.2242", "0.2465", "0.2979", "0.3395", "0.3578", "0.3850", "0.3850"],
                "Recall": ["0.0850", "0.1169", "0.1576", "0.1676", "0.1838", "0.2475", "0.2668", "0.2871", "0.3099", "0.3099"],
            },
            "Contriever + ATLAS": {
                "F1": ["0.2037", "0.2631", "0.2820", "0.3155", "0.3386", "0.3627", "0.3731", "0.3763", "0.3797", "0.3797"],
                "Precision": ["0.2307", "0.3001", "0.3251", "0.3659", "0.3959", "0.4228", "0.4380", "0.4482", "0.4513", "0.4513"],
                "Recall": ["0.2140", "0.2678", "0.2857", "0.3130", "0.3347", "0.3576", "0.3634", "0.3640", "0.3671", "0.3671"],
            },
            "Contriever + Flan-T5": {
                "F1": ["0.2005", "0.3043", "0.3025", "0.3508", "0.3637", "0.4001", "0.4149", "0.4073", "0.4229", "0.4229"],
                "Precision": ["0.2250", "0.3419", "0.3407", "0.3996", "0.4058", "0.4507", "0.4677", "0.4622", "0.4822", "0.4822"],
                "Recall": ["0.1989", "0.3017", "0.2982", "0.3436", "0.3619", "0.3960", "0.4142", "0.4038", "0.4156", "0.4156"],
            },
        },
        "4": {
            # In this setting F1, Precision and Recall hold identical values for every model.
            "kNN-LM": {
                "F1": ["0.2304", "0.2260", "0.2237", "0.2304", "0.2349", "0.2394", "0.2394", "0.2394", "0.2394", "0.2394"],
                "Precision": ["0.2304", "0.2260", "0.2237", "0.2304", "0.2349", "0.2394", "0.2394", "0.2394", "0.2394", "0.2394"],
                "Recall": ["0.2304", "0.2260", "0.2237", "0.2304", "0.2349", "0.2394", "0.2394", "0.2394", "0.2394", "0.2394"],
            },
            "REALM": {
                "F1": ["0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676"],
                "Precision": ["0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676"],
                "Recall": ["0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676", "0.4676"],
            },
            "DPR + FiD": {
                "F1": ["0.4765", "0.4631", "0.4698", "0.4698", "0.4698", "0.4698", "0.4698", "0.4698", "0.4698", "0.4698"],
                "Precision": ["0.4765", "0.4631", "0.4698", "0.4698", "0.4698", "0.4698", "0.4698", "0.4698", "0.4698", "0.4698"],
                "Recall": ["0.4765", "0.4631", "0.4698", "0.4698", "0.4698", "0.4698", "0.4698", "0.4698", "0.4698", "0.4698"],
            },
            "Contriever + ATLAS": {
                "F1": ["0.5414", "0.5391", "0.5369", "0.5369", "0.5369", "0.5369", "0.5369", "0.5369", "0.5369", "0.5369"],
                "Precision": ["0.5414", "0.5391", "0.5369", "0.5369", "0.5369", "0.5369", "0.5369", "0.5369", "0.5369", "0.5369"],
                "Recall": ["0.5414", "0.5391", "0.5369", "0.5369", "0.5369", "0.5369", "0.5369", "0.5369", "0.5369", "0.5369"],
            },
            "Contriever + Flan-T5": {
                "F1": ["0.6107", "0.6488", "0.6644", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734"],
                "Precision": ["0.6107", "0.6488", "0.6644", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734"],
                "Recall": ["0.6107", "0.6488", "0.6644", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734"],
            },
        },
    },
    "lm": {
        "1": {
            "kNN-LM": {
                "Target ranking accuracy": ["0.5792", "0.5792", "0.5792", "0.5708", "0.5667", "0.5542", "0.5542", "0.5500", "0.5500", "0.5500"],
                "Hits@1": ["0.4125", "0.4042", "0.4083", "0.4125", "0.4208", "0.4250", "0.4250", "0.4250", "0.4250", "0.4250"],
                "Hits@5": ["0.5083", "0.5417", "0.5542", "0.5542", "0.5583", "0.5625", "0.5625", "0.5625", "0.5625", "0.5625"],
            },
            "REALM": {
                "Target ranking accuracy": ["0.6458", "0.7167", "0.7292", "0.7333", "0.7292", "0.7292", "0.7292", "0.7292", "0.7292", "0.7292"],
            },
            "DPR + FiD": {
                "Target ranking accuracy": ["0.6125", "0.6833", "0.6958", "0.7083", "0.7042", "0.7083", "0.7083", "0.7083", "0.7083", "0.7083"],
                "Hits@1": ["0.1542", "0.1917", "0.2292", "0.2417", "0.2458", "0.2667", "0.2667", "0.2667", "0.2667", "0.2667"],
            },
            "Contriever + ATLAS": {
                "Target ranking accuracy": ["0.7583", "0.7958", "0.8000", "0.8125", "0.8167", "0.8125", "0.8125", "0.8125", "0.8125", "0.8125"],
                "Hits@1": ["0.3167", "0.5208", "0.5750", "0.5750", "0.5750", "0.5833", "0.5833", "0.5833", "0.5833", "0.5833"],
            },
            "Contriever + Flan-T5": {
                "Target ranking accuracy": ["0.8208", "0.8667", "0.8917", "0.8958", "0.9125", "0.9292", "0.9292", "0.9292", "0.9292", "0.9292"],
                "Hits@1": ["0.1458", "0.2167", "0.2417", "0.2458", "0.2583", "0.2333", "0.2333", "0.2333", "0.2333", "0.2333"],
            },
        },
        "2": {
            "kNN-LM": {
                "Target ranking accuracy": ["0.5741", "0.6312", "0.6198", "0.6236", "0.6274", "0.6160", "0.6008", "0.5932", "0.5932", "0.5817"],
                "Hits@1": ["0.4030", "0.4030", "0.4106", "0.4106", "0.4068", "0.4144", "0.4068", "0.3954", "0.3954", "0.3954"],
                "Hits@5": ["0.5095", "0.5475", "0.5589", "0.5627", "0.5665", "0.5627", "0.5665", "0.5665", "0.5627", "0.5589"],
            },
            "REALM": {
                "Target ranking accuracy": ["0.6958", "0.7148", "0.7072", "0.7034", "0.7110", "0.7148", "0.6996", "0.6996", "0.6996", "0.6996"],
            },
            "DPR + FiD": {
                "Target ranking accuracy": ["0.6464", "0.6730", "0.6692", "0.6844", "0.6730", "0.6768", "0.6958", "0.7110", "0.7529", "0.7529"],
                "Hits@1": ["0.0798", "0.0989", "0.1331", "0.1521", "0.1711", "0.2053", "0.2281", "0.2510", "0.2395", "0.2395"],
            },
            "Contriever + ATLAS": {
                "Target ranking accuracy": ["0.8175", "0.8517", "0.8593", "0.8593", "0.8631", "0.8555", "0.8517", "0.8517", "0.8441", "0.8441"],
                "Hits@1": ["0.3270", "0.4715", "0.5057", "0.5361", "0.5437", "0.5665", "0.5475", "0.5323", "0.5209", "0.5209"],
            },
            "Contriever + Flan-T5": {
                "Target ranking accuracy": ["0.8935", "0.9049", "0.8973", "0.9049", "0.8973", "0.9163", "0.9163", "0.9316", "0.9240", "0.9240"],
                "Hits@1": ["0.1445", "0.1673", "0.2053", "0.2091", "0.1901", "0.1065", "0.0608", "0.0494", "0.0228", "0.0228"],
            },
        },
        "3": {
            "kNN-LM": {
                "Target ranking accuracy": ["0.6100", "0.6139", "0.6332", "0.6293", "0.6255", "0.6448", "0.6332", "0.6371", "0.6293", "0.6293"],
                "Hits@1": ["0.4093", "0.4093", "0.4208", "0.4170", "0.4170", "0.4131", "0.4208", "0.4208", "0.4208", "0.4208"],
                "Hits@5": ["0.5058", "0.5367", "0.5637", "0.5714", "0.5830", "0.5907", "0.5907", "0.5869", "0.5946", "0.5869"],
            },
            "REALM": {
                "Target ranking accuracy": ["0.7220", "0.7452", "0.7568", "0.7606", "0.7645", "0.7722", "0.7568", "0.7568", "0.7529", "0.7529"],
            },
            "DPR + FiD": {
                "Target ranking accuracy": ["0.6486", "0.6525", "0.6486", "0.6680", "0.6834", "0.7143", "0.7104", "0.7413", "0.7490", "0.7490"],
                "Hits@1": ["0.1081", "0.1351", "0.1544", "0.1853", "0.2201", "0.2625", "0.2741", "0.2934", "0.2934", "0.2934"],
            },
            "Contriever + ATLAS": {
                "Target ranking accuracy": ["0.8147", "0.8224", "0.8417", "0.8456", "0.8533", "0.8571", "0.8456", "0.8649", "0.8610", "0.8610"],
                "Hits@1": ["0.3436", "0.4440", "0.4826", "0.5212", "0.5405", "0.5676", "0.5907", "0.6100", "0.5985", "0.5985"],
            },
            "Contriever + Flan-T5": {
                "Target ranking accuracy": ["0.8571", "0.8571", "0.8456", "0.8378", "0.8456", "0.8764", "0.8803", "0.8803", "0.8610", "0.8610"],
                "Hits@1": ["0.1776", "0.1660", "0.1506", "0.1506", "0.1467", "0.0772", "0.0579", "0.0347", "0.0193", "0.0193"],
            },
        },
        "4": {
            "kNN-LM": {
                "Target ranking accuracy": ["0.2528", "0.2595", "0.2573", "0.2506", "0.2506", "0.2416", "0.2416", "0.2394", "0.2394", "0.2394"],
                "Hits@1": ["0.2237", "0.2192", "0.2237", "0.2304", "0.2304", "0.2349", "0.2349", "0.2349", "0.2371", "0.2371"],
                "Hits@5": ["0.3468", "0.4161", "0.4497", "0.4676", "0.4720", "0.4743", "0.4765", "0.4765", "0.4765", "0.4765"],
            },
            "REALM": {
                "Target ranking accuracy": ["0.4541", "0.5772", "0.5928", "0.5973", "0.5973", "0.5973", "0.5973", "0.5973", "0.5973", "0.5973"],
            },
            "DPR + FiD": {
                "Target ranking accuracy": ["0.3803", "0.4855", "0.5414", "0.5548", "0.5548", "0.5548", "0.5548", "0.5548", "0.5548", "0.5548"],
                "Hits@1": ["0.1477", "0.2170", "0.2304", "0.2461", "0.2461", "0.2461", "0.2461", "0.2461", "0.2461", "0.2461"],
            },
            "Contriever + ATLAS": {
                "Target ranking accuracy": ["0.3736", "0.5347", "0.5615", "0.5794", "0.5794", "0.5794", "0.5794", "0.5794", "0.5794", "0.5794"],
                "Hits@1": ["0.1857", "0.3356", "0.3691", "0.3669", "0.3669", "0.3669", "0.3669", "0.3669", "0.3669", "0.3669"],
            },
            "Contriever + Flan-T5": {
                "Target ranking accuracy": ["0.4407", "0.6219", "0.6980", "0.7069", "0.7092", "0.7092", "0.7092", "0.7092", "0.7092", "0.7092"],
                "Hits@1": ["0.0738", "0.2819", "0.3624", "0.3758", "0.3758", "0.3758", "0.3758", "0.3758", "0.3758", "0.3758"],
            },
        },
    },
}

In [None]:
# Gold Facts
#
# Layout: task ("qa" / "lm") -> setting key (only "2" is present) -> model name
# -> metric name -> list of ten scores stored as 4-decimal strings.
lm_performance = {
    "qa": {
        "2": {
            "kNN-LM": {
                "F1": ["0.1388", "0.1354", "0.1346", "0.1333", "0.1350", "0.1299", "0.1307", "0.1307", "0.1307", "0.1307"],
                "Precision": ["0.1221", "0.1202", "0.1204", "0.1196", "0.1209", "0.1130", "0.1161", "0.1161", "0.1161", "0.1161"],
                "Recall": ["0.2314", "0.2198", "0.2205", "0.2169", "0.2196", "0.2157", "0.2161", "0.2161", "0.2161", "0.2161"],
            },
            "REALM": {
                "F1": ["0.1928", "0.1984", "0.1984", "0.1984", "0.1984", "0.1984", "0.1984", "0.1984", "0.1984", "0.1984"],
                "Precision": ["0.2627", "0.2673", "0.2673", "0.2673", "0.2673", "0.2673", "0.2673", "0.2673", "0.2673", "0.2673"],
                "Recall": ["0.1814", "0.1871", "0.1871", "0.1871", "0.1871", "0.1871", "0.1871", "0.1871", "0.1871", "0.1871"],
            },
            "DPR + FiD": {
                "F1": ["0.2334", "0.2753", "0.2943", "0.3030", "0.3062", "0.3130", "0.3130", "0.3130", "0.3130", "0.3130"],
                "Precision": ["0.2789", "0.3287", "0.3494", "0.3580", "0.3616", "0.3664", "0.3664", "0.3664", "0.3664", "0.3664"],
                "Recall": ["0.2299", "0.2709", "0.2905", "0.2985", "0.3017", "0.3082", "0.3082", "0.3082", "0.3082", "0.3082"],
            },
            "Contriever + ATLAS": {
                "F1": ["0.2135", "0.3023", "0.3217", "0.3294", "0.3394", "0.3313", "0.3313", "0.3313", "0.3313", "0.3313"],
                "Precision": ["0.2559", "0.3517", "0.3772", "0.3874", "0.3987", "0.3898", "0.3898", "0.3898", "0.3898", "0.3898"],
                "Recall": ["0.2147", "0.3029", "0.3196", "0.3264", "0.3364", "0.3286", "0.3286", "0.3286", "0.3286", "0.3286"],
            },
            "Contriever + Flan-T5": {
                "F1": ["0.2195", "0.3630", "0.4132", "0.4149", "0.4248", "0.4332", "0.4339", "0.4339", "0.4339", "0.4339"],
                "Precision": ["0.2587", "0.4213", "0.4772", "0.4801", "0.4895", "0.5033", "0.5042", "0.5042", "0.5042", "0.5042"],
                "Recall": ["0.2108", "0.3491", "0.3982", "0.3995", "0.4090", "0.4159", "0.4165", "0.4165", "0.4165", "0.4165"],
            },
        },
    },
    "lm": {
        "2": {
            "kNN-LM": {
                "Target ranking accuracy": ["0.6190", "0.6349", "0.6548", "0.6389", "0.6389", "0.6151", "0.6190", "0.6151", "0.6190", "0.6190"],
                "Hits@1": ["0.3889", "0.3889", "0.3849", "0.3929", "0.3968", "0.3968", "0.3968", "0.3968", "0.3968", "0.3968"],
                "Hits@5": ["0.4921", "0.5119", "0.5317", "0.5357", "0.5317", "0.5397", "0.5397", "0.5397", "0.5397", "0.5397"],
            },
            "REALM": {
                "Target ranking accuracy": ["0.8214", "0.8413", "0.8492", "0.8532", "0.8532", "0.8532", "0.8532", "0.8532", "0.8532", "0.8532"],
            },
            "DPR + FiD": {
                "Target ranking accuracy": ["0.6786", "0.7302", "0.7341", "0.7262", "0.7421", "0.7460", "0.7460", "0.7460", "0.7460", "0.7460"],
                "Hits@1": ["0.1746", "0.2143", "0.2460", "0.2460", "0.2500", "0.2698", "0.2698", "0.2698", "0.2698", "0.2698"],
            },
            "Contriever + ATLAS": {
                "Target ranking accuracy": ["0.8452", "0.8730", "0.8730", "0.8849", "0.8849", "0.8810", "0.8810", "0.8810", "0.8810", "0.8810"],
                "Hits@1": ["0.3214", "0.4802", "0.5278", "0.5198", "0.5198", "0.5278", "0.5278", "0.5278", "0.5278", "0.5278"],
            },
            "Contriever + Flan-T5": {
                "Target ranking accuracy": ["0.9048", "0.9048", "0.9087", "0.9087", "0.9246", "0.9286", "0.9286", "0.9286", "0.9286", "0.9286"],
                "Hits@1": ["0.1508", "0.1825", "0.2103", "0.2063", "0.2183", "0.1984", "0.1984", "0.1984", "0.1984", "0.1984"],
            },
        },
    },
}

In [None]:
# Single Fact
#
# Layout: "qa" -> setting key "2" -> model name -> metric name -> one-element list
# holding the score as a 4-decimal string.
max_lm_performance = {
    "qa": {
        "2": {
            "kNN-LM": {"F1": ["0.3359"], "Precision": ["0.2787"], "Recall": ["0.5846"]},
            "REALM": {"F1": ["0.5680"], "Precision": ["0.7005"], "Recall": ["0.5432"]},
            "DPR + FiD": {"F1": ["0.6010"], "Precision": ["0.6738"], "Recall": ["0.5907"]},
            "Contriever + ATLAS": {"F1": ["0.6192"], "Precision": ["0.6843"], "Recall": ["0.6269"]},
            "Contriever + Flan-T5": {"F1": ["0.4810"], "Precision": ["0.5508"], "Recall": ["0.4724"]},
        },
    },
}

In [None]:
# Retriever recall values: task ("qa" / "lm") -> retriever name -> list of ten floats.
# NOTE(review): what the ten positions correspond to is not stated in this cell.
retriever_recall = {
    "qa": {
        "kNN-LM": [
            0.057348748812163425, 0.09594870745785378, 0.13351403846830673, 0.15798908070249537, 0.17611271689015603,
            0.35121984214972013, 0.5200472714954425, 0.683542744518354, 0.8242312057860843, 0.9711013664519763,
        ],
        "REALM": [
            0.2103191100904517, 0.321355871432091, 0.3948609571322986, 0.46877606641326147, 0.5132001636574808,
            0.6931373842959205, 0.8289823320311123, 0.9213618106148597, 1.0, 1.0,
        ],
        "DPR": [
            0.027002384471896675, 0.06874780030267831, 0.11482222046246436, 0.15936257171013277, 0.18455405536198227,
            0.3779547434273043, 0.5995368537289268, 0.7905188646042305, 1.0, 1.0,
        ],
        "Contriever": [
            0.2396045604124873, 0.35790846179565716, 0.4272566694822794, 0.4868525190933728, 0.5357107661985712,
            0.7100053232675183, 0.8246516779291166, 0.9233595757223811, 1.0, 1.0,
        ],
    },
    "lm": {
        "kNN-LM": [
            0.17652946700565755, 0.24998181932705751, 0.29524840239125955, 0.33273752261847506, 0.37039198676103435,
            0.5060694816647198, 0.6299271055223434, 0.7375475274284796, 0.8200718923933207, 0.9788736285760097,
        ],
        "REALM": [
            0.1899207494445591, 0.29159523236904206, 0.36765357658214837, 0.42323303900684867, 0.46959432763004194,
            0.6499305698710462, 0.7965116012735061, 0.9282101055910582, 1.0, 1.0,
        ],
        "DPR": [
            0.041356593142307434, 0.078877207448636, 0.11788749169701547, 0.16890031265031274, 0.21290899356375556,
            0.42512683524588285, 0.6318047252571061, 0.8282012299869437, 1.0, 1.0,
        ],
        "Contriever": [
            0.2629635355825835, 0.4317186891591658, 0.5075694587599353, 0.5683126617650426, 0.620163368377654,
            0.7779857818548295, 0.8735111889873794, 0.9409740832359881, 1.0, 1.0,
        ],
    },
}

In [None]:
# Retriever accuracy: task ("qa" / "lm") -> retriever name -> one-element list of float.
retriever_accuracy = {
    "qa": {
        "kNN-LM": [0.1522472107837962],
        "REALM": [0.429145989511843],
        "DPR": [0.1413177946714533],
        "Contriever": [0.4681478328581989],
    },
    "lm": {
        "kNN-LM": [0.33395605717034277],
        "REALM": [0.4041980175908749],
        "DPR": [0.15837355986165513],
        "Contriever": [0.5707312638860255],
    },
}

In [None]:
# Flan scores: Contriever paired with each Flan-T5 size (small, base, large, xl, xxl).
contriever_flan_performance = {
    "qa": {
        "1": {
            "Flan-T5-small": {
                "F1": [
                    "0.1589", "0.2842", "0.3326", "0.3499", "0.3661", "0.3751", "0.3775", "0.3775", "0.3775", "0.3775"
                ],
                "Precision": [
                    "0.1908", "0.3269", "0.3830", "0.3984", "0.4201", "0.4278", "0.4303", "0.4303", "0.4303", "0.4303"
                ],
                "Recall": [
                    "0.1513", "0.2742", "0.3234", "0.3430", "0.3573", "0.3667", "0.3693", "0.3693", "0.3693", "0.3693"
                ]
            },
            "Flan-T5-base": {
                "F1": [
                    "0.2190", "0.3671", "0.4314", "0.4381", "0.4492", "0.4533", "0.4540", "0.4540", "0.4540", "0.4540"
                ],
                "Precision": [
                    "0.2579", "0.4250", "0.5040", "0.5072", "0.5167", "0.5237", "0.5246", "0.5246", "0.5246", "0.5246"
                ],
                "Recall": [
                    "0.2113", "0.3541", "0.4143", "0.4231", "0.4351", "0.4385", "0.4391", "0.4391", "0.4391", "0.4391"
                ]
            },
            "Flan-T5-large": {
                "F1": [
                    "0.2689", "0.4152", "0.4800", "0.4950", "0.4992", "0.5214", "0.5185", "0.5185", "0.5185", "0.5185"
                ],
                "Precision": [
                    "0.3045", "0.4712", "0.5435", "0.5588", "0.5689", "0.5914", "0.5887", "0.5887", "0.5887", "0.5887"
                ],
                "Recall": [
                    "0.2638", "0.4042", "0.4695", "0.4884", "0.4887", "0.5125", "0.5094", "0.5094", "0.5094", "0.5094"
                ]
            },
            "Flan-T5-xl": {
                "F1": [
                    "0.3245", "0.4349", "0.5074", "0.5269", "0.5533", "0.5726", "0.5701", "0.5701", "0.5701", "0.5701"
                ],
                "Precision": [
                    "0.3532", "0.4793", "0.5585", "0.5776", "0.6106", "0.6282", "0.6245", "0.6245", "0.6245", "0.6245"
                ],
                "Recall": [
                    "0.3231", "0.4311", "0.4995", "0.5224", "0.5479", "0.5671", "0.5648", "0.5648", "0.5648", "0.5648"
                ]
            },
            "Flan-T5-xxl": {
                "F1": [
                    "0.3685", "0.4756", "0.5389", "0.5548", "0.5702", "0.5878", "0.5878", "0.5878", "0.5878", "0.5878"
                ],
                "Precision": [
                    "0.3989", "0.5191", "0.5958", "0.6127", "0.6303", "0.6500", "0.6473", "0.6473", "0.6473", "0.6473"
                ],
                "Recall": [
                    "0.3713", "0.4761", "0.5344", "0.5511", "0.5659", "0.5846", "0.5852", "0.5852", "0.5852", "0.5852"
                ]
            }
        },
        "2": {
            "Flan-T5-small": {
                "F1": [
                    "0.1301", "0.1915", "0.2317", "0.2385", "0.2477", "0.2526", "0.2600", "0.2430", "0.2576", "0.2576"
                ],
                "Precision": [
                    "0.1578", "0.2245", "0.2716", "0.2779", "0.2857", "0.2866", "0.2931", "0.2817", "0.2985", "0.2985"
                ],
                "Recall": [
                    "0.1269", "0.1873", "0.2269", "0.2309", "0.2409", "0.2512", "0.2556", "0.2343", "0.2494", "0.2494"
                ]
            },
            "Flan-T5-base": {
                "F1": [
                    "0.1967", "0.2480", "0.2928", "0.3065", "0.3031", "0.3539", "0.3800", "0.3753", "0.3968", "0.3968"
                ],
                "Precision": [
                    "0.2299", "0.2848", "0.3387", "0.3490", "0.3516", "0.4050", "0.4344", "0.4245", "0.4540", "0.4540"
                ],
                "Recall": [
                    "0.1911", "0.2424", "0.2858", "0.3025", "0.2946", "0.3489", "0.3713", "0.3678", "0.3876", "0.3876"
                ]
            },
            "Flan-T5-large": {
                "F1": [
                    "0.2546", "0.3132", "0.3554", "0.3575", "0.3752", "0.4135", "0.4197", "0.4334", "0.4475", "0.4475"
                ],
                "Precision": [
                    "0.2902", "0.3581", "0.4043", "0.4065", "0.4309", "0.4801", "0.4808", "0.4998", "0.5254", "0.5254"
                ],
                "Recall": [
                    "0.2500", "0.3069", "0.3461", "0.3513", "0.3646", "0.3966", "0.4093", "0.4202", "0.4274", "0.4274"
                ]
            },
            "Flan-T5-xl": {
                "F1": [
                    "0.3214", "0.3538", "0.3801", "0.3865", "0.4112", "0.4353", "0.4642", "0.4773", "0.5004", "0.5004"
                ],
                "Precision": [
                    "0.3508", "0.3799", "0.4128", "0.4178", "0.4462", "0.4794", "0.5083", "0.5361", "0.5589", "0.5589"
                ],
                "Recall": [
                    "0.3212", "0.3603", "0.3798", "0.3883", "0.4131", "0.4326", "0.4597", "0.4649", "0.4866", "0.4866"
                ]
            },
            "Flan-T5-xxl": {
                "F1": [
                    "0.3631", "0.4019", "0.4242", "0.4302", "0.4363", "0.4605", "0.4935", "0.4944", "0.5182", "0.5182"
                ],
                "Precision": [
                    "0.3922", "0.4360", "0.4552", "0.4645", "0.4710", "0.4964", "0.5380", "0.5331", "0.5612", "0.5612"
                ],
                "Recall": [
                    "0.3658", "0.4040", "0.4340", "0.4416", "0.4464", "0.4691", "0.4980", "0.5039", "0.5212", "0.5212"
                ]
            }
        },
        "3": {
            "Flan-T5-small": {
                "F1": [
                    "0.1561", "0.2125", "0.2196", "0.2349", "0.2541", "0.2389", "0.2733", "0.2772", "0.2802", "0.2802"
                ],
                "Precision": [
                    "0.1832", "0.2425", "0.2513", "0.2710", "0.2907", "0.2734", "0.3223", "0.3213", "0.3277", "0.3277"
                ],
                "Recall": [
                    "0.1568", "0.2109", "0.2185", "0.2294", "0.2506", "0.2349", "0.2681", "0.2698", "0.2767", "0.2767"
                ]
            },
            "Flan-T5-base": {
                "F1": [
                    "0.2005", "0.3043", "0.3025", "0.3508", "0.3637", "0.4001", "0.4149", "0.4073", "0.4229", "0.4229"
                ],
                "Precision": [
                    "0.2250", "0.3419", "0.3407", "0.3996", "0.4058", "0.4507", "0.4677", "0.4622", "0.4822", "0.4822"
                ],
                "Recall": [
                    "0.1989", "0.3017", "0.2982", "0.3436", "0.3619", "0.3960", "0.4142", "0.4038", "0.4156", "0.4156"
                ]
            },
            "Flan-T5-large": {
                "F1": [
                    "0.2533", "0.3359", "0.3641", "0.3980", "0.4129", "0.4547", "0.4672", "0.4813", "0.4992", "0.4992"
                ],
                "Precision": [
                    "0.2829", "0.3758", "0.4142", "0.4546", "0.4705", "0.5254", "0.5476", "0.5660", "0.5800", "0.5800"
                ],
                "Recall": [
                    "0.2518", "0.3321", "0.3527", "0.3842", "0.4013", "0.4428", "0.4510", "0.4614", "0.4819", "0.4819"
                ]
            },
            "Flan-T5-xl": {
                "F1": [
                    "0.3164", "0.3665", "0.4019", "0.4287", "0.4398", "0.4723", "0.4934", "0.5248", "0.5337", "0.5337"
                ],
                "Precision": [
                    "0.3378", "0.3934", "0.4385", "0.4658", "0.4861", "0.5107", "0.5523", "0.5917", "0.5960", "0.5960"
                ],
                "Recall": [
                    "0.3240", "0.3670", "0.3967", "0.4243", "0.4357", "0.4777", "0.4905", "0.5152", "0.5217", "0.5217"
                ]
            },
            "Flan-T5-xxl": {
                "F1": [
                    "0.3548", "0.4417", "0.4529", "0.4631", "0.4525", "0.4909", "0.5066", "0.5453", "0.5548", "0.5548"
                ],
                "Precision": [
                    "0.3886", "0.4802", "0.4950", "0.5020", "0.4957", "0.5288", "0.5422", "0.5831", "0.5969", "0.5969"
                ],
                "Recall": [
                    "0.3567", "0.4420", "0.4584", "0.4664", "0.4592", "0.5067", "0.5223", "0.5626", "0.5677", "0.5677"
                ]
            }
        },
        "4": {
            "Flan-T5-small": {
                "F1": [
                    "0.5391", "0.5347", "0.5481", "0.5526", "0.5526", "0.5526", "0.5526", "0.5526", "0.5526", "0.5526"
                ],
                "Precision": [
                    "0.5391", "0.5347", "0.5481", "0.5526", "0.5526", "0.5526", "0.5526", "0.5526", "0.5526", "0.5526"
                ],
                "Recall": [
                    "0.5391", "0.5347", "0.5481", "0.5526", "0.5526", "0.5526", "0.5526", "0.5526", "0.5526", "0.5526"
                ]
            },
            "Flan-T5-base": {
                "F1": [
                    "0.6107", "0.6488", "0.6644", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734"
                ],
                "Precision": [
                    "0.6107", "0.6488", "0.6644", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734"
                ],
                "Recall": [
                    "0.6107", "0.6488", "0.6644", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734", "0.6734"
                ]
            },
            "Flan-T5-large": {
                "F1": [
                    "0.6846", "0.7696", "0.7919", "0.8076", "0.8076", "0.8076", "0.8076", "0.8076", "0.8076", "0.8076"
                ],
                "Precision": [
                    "0.6846", "0.7696", "0.7919", "0.8076", "0.8076", "0.8076", "0.8076", "0.8076", "0.8076", "0.8076"
                ],
                "Recall": [
                    "0.6846", "0.7696", "0.7919", "0.8076", "0.8076", "0.8076", "0.8076", "0.8076", "0.8076", "0.8076"
                ]
            },
            "Flan-T5-xl": {
                "F1": [
                    "0.7383", "0.8345", "0.8546", "0.8658", "0.8658", "0.8658", "0.8658", "0.8658", "0.8658", "0.8658"
                ],
                "Precision": [
                    "0.7383", "0.8345", "0.8546", "0.8658", "0.8658", "0.8658", "0.8658", "0.8658", "0.8658", "0.8658"
                ],
                "Recall": [
                    "0.7383", "0.8345", "0.8546", "0.8658", "0.8658", "0.8658", "0.8658", "0.8658", "0.8658", "0.8658"
                ]
            },
            "Flan-T5-xxl": {
                "F1": [
                    "0.7383", "0.8635", "0.8904", "0.9016", "0.9038", "0.9038", "0.9038", "0.9038", "0.9038", "0.9038"
                ],
                "Precision": [
                    "0.7383", "0.8635", "0.8904", "0.9016", "0.9038", "0.9038", "0.9038", "0.9038", "0.9038", "0.9038"
                ],
                "Recall": [
                    "0.7383", "0.8635", "0.8904", "0.9016", "0.9038", "0.9038", "0.9038", "0.9038", "0.9038", "0.9038"
                ]
            }
        }
    }
}

In [None]:
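# DSP scores, keyed as dsp_scores[variant][dataset_id][model][metric]:
#   variant    - "no-dsp" or "dsp"
#   dataset_id - "1", "2" or "4" (plot_dsp labels them EB-Easy, EB-Hard, StrategyQA)
#   metric     - "F1", "Precision" or "Recall"
# Each metric holds a one-element list whose score is stored as a string.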
dsp_scores = {
    "no-dsp": {
        "1": {
            "Flan-T5-base": {
                "F1": [
                    "0.4041"
                ],
                "Precision": [
                    "0.4568"
                ],
                "Recall": [
                    "0.4055"
                ]
            },
            "Flan-T5-xxl": {
                "F1": [
                    "0.5743"
                ],
                "Precision": [
                    "0.6639"
                ],
                "Recall": [
                    "0.5518"
                ]
            },
            "GPT-3": {
                "F1": [
                    "0.4682"
                ],
                "Precision": [
                    "0.4539"
                ],
                "Recall": [
                    "0.6254"
                ]
            }
        },
        "2": {
            "Flan-T5-base": {
                "F1": [
                    "0.2848"
                ],
                "Precision": [
                    "0.3216"
                ],
                "Recall": [
                    "0.2908"
                ]
            },
            "Flan-T5-xxl": {
                "F1": [
                    "0.4157"
                ],
                "Precision": [
                    "0.4760"
                ],
                "Recall": [
                    "0.4080"
                ]
            },
            "GPT-3": {
                "F1": [
                    "0.3689"
                ],
                "Precision": [
                    "0.3452"
                ],
                "Recall": [
                    "0.4928"
                ]
            }
        },
        "4": {
            "Flan-T5-base": {
                "F1": [
                    "0.6353"
                ],
                "Precision": [
                    "0.6353"
                ],
                "Recall": [
                    "0.6353"
                ]
            },
            "Flan-T5-xxl": {
                "F1": [
                    "0.9128"
                ],
                "Precision": [
                    "0.9128"
                ],
                "Recall": [
                    "0.9128"
                ]
            },
            "GPT-3": {
                "F1": [
                    "0.7170"
                ],
                "Precision": [
                    "0.7036"
                ],
                "Recall": [
                    "0.8591"
                ]
            }
        }
    },
    "dsp": {
        "1": {
            "Flan-T5-base": {
                "F1": [
                    "0.0624"
                ],
                "Precision": [
                    "0.0516"
                ],
                "Recall": [
                    "0.1334"
                ]
            },
            "Flan-T5-xxl": {
                "F1": [
                    "0.1971"
                ],
                "Precision": [
                    "0.1537"
                ],
                "Recall": [
                    "0.6190"
                ]
            },
            "GPT-3": {
                "F1": [
                    "0.5258"
                ],
                "Precision": [
                    "0.5630"
                ],
                "Recall": [
                    "0.5522"
                ]
            }
        },
        "2": {
            "Flan-T5-base": {
                "F1": [
                    "0.0568"
                ],
                "Precision": [
                    "0.0500"
                ],
                "Recall": [
                    "0.1047"
                ]
            },
            "Flan-T5-xxl": {
                "F1": [
                    "0.1252"
                ],
                "Precision": [
                    "0.0884"
                ],
                "Recall": [
                    "0.4654"
                ]
            },
            "GPT-3": {
                "F1": [
                    "0.4070"
                ],
                "Precision": [
                    "0.4225"
                ],
                "Recall": [
                    "0.4334"
                ]
            }
        },
        "4": {
            "Flan-T5-base": {
                "F1": [
                    "0.4787"
                ],
                "Precision": [
                    "0.4787"
                ],
                "Recall": [
                    "0.4787"
                ]
            },
            "Flan-T5-xxl": {
                "F1": [
                    "0.7993"
                ],
                "Precision": [
                    "0.7990"
                ],
                "Recall": [
                    "0.8076"
                ]
            },
            "GPT-3": {
                "F1": [
                    "0.8019"
                ],
                "Precision": [
                    "0.7956"
                ],
                "Recall": [
                    "0.8613"
                ]
            }
        }
    }
}

### All Visualizations

In [None]:
# Language-modeling figure: target-ranking accuracy against the number of
# retrieved statements, one panel per dataset on a 2x2 grid.
x = np.array([1, 2, 3, 4, 5, 10, 15, 20, 25])
task = "lm"
metric = "Target ranking accuracy"

fig = make_subplots(rows=2, cols=2, subplot_titles=("EntailmentBank-Easy", "EntailmentBank-Hard", "EntailmentBank-3", "StrategyQA"),
                    horizontal_spacing=0.05, vertical_spacing=0.15,
                    x_title='# of retrieved statements',
                    y_title='Accuracy')

for eb in [1, 2, 3, 4]:
    # Dataset ids fill the grid row by row: 1 and 2 on top, 3 and 4 below.
    row = 2 if eb >= 3 else 1
    col = 1 if eb % 2 == 1 else 2

    for i, model in enumerate(models):
        ys = ret_lm_performance[task][str(eb)][model][metric]
        # float(y)*100 turns each stored score into a percentage.
        # Only the first panel feeds the shared legend, so every model is
        # listed exactly once.
        fig.add_trace(
            go.Scatter(x=x, y=[float(y)*100 for y in ys], name=model, line_shape='spline',
                       marker_symbol=symbols[i], marker_color=color_map[model],
                       showlegend=(eb == 1)),
            row=row, col=col)

fig.update_xaxes(type="log", tickvals=x)
fig.update_yaxes(range=[20, 100])
fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="top",
        y=1.15,
        xanchor="left",
        x=0.09,
        font_size=10
    ),
    font_size=10,
    autosize=False, width=700, height=400,
    margin=dict(l=50, r=0, b=45, t=20),
)
# Subplot titles and the shared axis titles are annotations; resize them too.
fig.update_annotations(font=dict(size=10))
fig.show()
fig.write_image("images/lm-EB1-EB2-EB3-SQA.pdf", engine="kaleido")

In [None]:
# Question-answering figure: token-overlap F1 against the number of retrieved
# statements, one panel per dataset on a 2x2 grid.
# NOTE(review): x lists 9 cut-offs, while the "qa" score lists visible at the
# top of this notebook hold 10 values each - confirm how the extra value is
# meant to be handled.
x = np.array([1, 2, 3, 4, 5, 10, 15, 20, 25])
task = "qa"
metric = "F1"

fig = make_subplots(
    rows=2, cols=2,
    subplot_titles=("EntailmentBank-Easy", "EntailmentBank-Hard", "EntailmentBank-3", "StrategyQA"),
    horizontal_spacing=0.05, vertical_spacing=0.15,
    x_title='# of retrieved statements',
)

# (row, column) of each dataset id on the grid.
grid_position = {1: (1, 1), 2: (1, 2), 3: (2, 1), 4: (2, 2)}

for eb in [1, 2, 3, 4]:
    row, col = grid_position[eb]
    # Legend entries come from the first panel only, one per model.
    in_legend = eb == 1
    for i, model in enumerate(models):
        ys = ret_lm_performance[task][str(eb)][model][metric]
        percentages = [float(y)*100 for y in ys]
        trace = go.Scatter(x=x, y=percentages, name=model, line_shape='spline',
                           marker_symbol=symbols[i], marker_color=color_map[model],
                           showlegend=in_legend)
        fig.add_trace(trace, row=row, col=col)

fig.update_xaxes(type="log", tickvals=x)
fig.update_yaxes(range=[0, 70])

# Shared y-axis title, placed by hand as a rotated annotation left of the grid.
y_axis_label = dict(x=-0.1,
                    y=0.5,
                    showarrow=False,
                    font_size=12,
                    text='Token overlap F1 score',
                    textangle=-90,
                    xref="paper",
                    yref="paper")
fig.add_annotation(y_axis_label)

fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="top",
        y=1.25,
        xanchor="left",
        x=-0.01,
        font_size=11
    ),
    font_size=12,
    autosize=False, width=500, height=400,
    margin=dict(l=40, r=0, b=50, t=20),
)
fig.update_annotations(font=dict(size=12))
fig.show()
# fig.write_image("images/qa-EB1-EB2-EB3-SQA.pdf", engine="kaleido")

In [None]:
# All-vs-gold comparison on dataset "2": language modeling on the top panel,
# question answering on the bottom one. For every model the dotted "(all)"
# curve comes from ret_lm_performance and the solid "(gold)" curve from
# lm_performance.
x = np.array([1, 2, 3, 4, 5, 10, 15])

fig = make_subplots(
    rows=2, cols=1,
    subplot_titles=("Language Modeling", "Question Answering"),
    horizontal_spacing=0.05, vertical_spacing=0.15,
    x_title='# of retrieved statements',
)
metrics = {"lm": "Target ranking accuracy", "qa": "F1"}

for task in metrics:
    row = 1 if task == "lm" else 2
    for model in models:
        # The "(all)" trace is added before the "(gold)" one for each model.
        for scores, suffix, line_style in (
            (ret_lm_performance, ' (all)', {'line_dash': 'dot'}),
            (lm_performance, ' (gold)', {}),
        ):
            ys = scores[task]["2"][model][metrics[task]]
            fig.add_trace(
                go.Scatter(x=x, y=[float(y)*100 for y in ys], name=model + suffix,
                           line_shape='spline', marker_color=color_map[model],
                           **line_style),
                row=row, col=1)

fig.update_xaxes(type="log", tickvals=x)
fig.update_layout(
    yaxis1_title='Accuracy', yaxis2_title='Token overlap F1 score',
    showlegend=True, legend=dict(y=0.5, font_size=12),
    font_size=12,
    autosize=False, width=600, height=500,
    margin=dict(l=0, r=0, b=30, t=30),
)
fig.update_annotations(font=dict(size=12))
fig.show()
# fig.write_image("images/all-vs-gold.pdf", engine="kaleido")

In [None]:
import plotly.graph_objects as go

metric = {"qa": "F1", "lm": "Target ranking accuracy"}
fig_title = {"qa": "Token overlap F1 score", "lm": "Accuracy"}

# The two lookup tables are bound as defaults at definition time: other cells
# rebind the notebook-level name `metric` to a plain string (e.g.
# `metric = "F1"`), which would make a call-time `metric[task]` lookup fail.
def plot_additional_distracting(task, bound=[0, 50], metric_by_task=metric, title_by_task=fig_title):
    """Show and save a grouped bar chart comparing gold-only and all-statement scores.

    For every model in `models`, two bars are drawn from the last entry of the
    dataset-"2" score list: a plain bar labelled "<model> (gold)" read from
    lm_performance and a hatched bar labelled "<model> (all)" read from
    ret_lm_performance. Scores are multiplied by 100 before plotting.

    Args:
        task: "qa" or "lm"; selects the score table, the metric and the
            y-axis title.
        bound: [low, high] range of the y-axis. The list is only read.
        metric_by_task: mapping from task to the metric key to plot.
        title_by_task: mapping from task to the y-axis title.

    Side effects:
        Displays the figure and writes it to
        images/<task>_additional_distracting.pdf with the kaleido engine.
    """
    metric_name = metric_by_task[task]
    fig = go.Figure()

    for model in models:
        # Both bars of a model share its colour; the hatch marks the "(all)" bar.
        for label, scores, pattern in (
            (f'{model} (gold)', lm_performance, ""),
            (f'{model} (all)', ret_lm_performance, '/'),
        ):
            last_score = float(scores[task]["2"][model][metric_name][-1]) * 100
            fig.add_trace(go.Bar(
                name=label, x=['Models'], y=[last_score],
                marker_color=color_map[model], marker_pattern_shape=pattern
            ))

    fig.update_yaxes(range=bound)
    fig.update_layout(
        legend=dict(
            orientation="h",
            yanchor="top",
            y=1.7,
            xanchor="left",
            x=-0.001,
            font_size=11
        ),
        barmode='group', showlegend=True,
        autosize=False, width=400, height=300,
        margin=dict(l=0, r=0, b=0, t=0),
        yaxis_title=title_by_task[task], font_size=9,
    )
    fig.update_annotations(font=dict(size=10))
    fig.show()
    fig.write_image(f"images/{task}_additional_distracting.pdf", engine="kaleido")

In [None]:
# Gold-statement QA F1 on dataset "2" for every model, together with a dotted
# horizontal reference line per model taken from max_lm_performance.
x = np.array([1, 2, 3, 4, 5, 10, 15, 20, 25])


def _reference_level(model, n_points):
    """Return the model's first max_lm_performance F1 value, as a percentage, repeated n_points times."""
    level = float(max_lm_performance['qa']['2'][model]['F1'][0]) * 100
    return [level] * n_points


fig = go.Figure()

# Black dotted line at REALM's level: the only reference line that appears in
# the legend, under the label 'Single Statement'.
fig.add_trace(go.Scatter(
    x=x,
    y=_reference_level('REALM', 1 + len(lm_performance['qa']['2']['REALM']['F1'])),
    mode="lines+text", line_dash='dot', marker_color='black',
    name='Single Statement', showlegend=True, legendrank=6))

for i, model in enumerate(models):
    ys = lm_performance['qa']['2'][model]['F1']
    # Solid curve: the model's scores over the retrieval cut-offs.
    fig.add_trace(go.Scatter(
        x=x, y=[float(y)*100 for y in ys], name=model, line_shape='spline',
        marker_symbol=symbols[i], marker_color=color_map[model], legendrank=i+1))
    # Dotted line in the model's colour, kept out of the legend.
    fig.add_trace(go.Scatter(
        x=x, y=_reference_level(model, 1 + len(ys)),
        mode="lines+text", line_dash='dot',
        marker_symbol=symbols[i], marker_color=color_map[model],
        textposition="top left", showlegend=False))

fig.update_xaxes(type="log", tickvals=x)
fig.update_yaxes(range=[10, 70])
fig.update_layout(
    yaxis_title='Token overlap F1 score',
    xaxis_title='# of retrieved statements',
    legend=dict(
        orientation="h",
        yanchor="top",
        y=1.35,
        xanchor="left",
        x=-0.00001,
        font_size=10
    ),
    font_size=10,
    autosize=False, width=400, height=300,
    margin=dict(l=45, r=0, b=0, t=20),
)
fig.update_annotations(font=dict(size=10))
fig.show()
# fig.write_image("images/qa_gold_vs_max.pdf", engine="kaleido")

In [None]:
import plotly.graph_objects as go

rets = ['DPR', 'kNN-LM', 'REALM', 'Contriever']


def plot_top_retriever_recall():
    """Show a grouped bar chart with one bar per retriever for each task.

    For every retriever in `rets`, the first entry of
    retriever_accuracy["lm"][ret] and of retriever_accuracy["qa"][ret] is
    multiplied by 100 and plotted under 'Language Modeling' and
    'Question Answering' respectively. The figure is displayed; nothing is
    returned or written to disk.
    """
    fig = go.Figure()

    for i, ret in enumerate(rets):
        scores = [retriever_accuracy["lm"][ret][0]*100, retriever_accuracy["qa"][ret][0]*100]
        fig.add_trace(go.Bar(
            name=ret, x=['Language Modeling', 'Question Answering'], y=scores,
            # Retrievers use palette entries 5 onwards.
            marker_color=colors[i+5]
        ))

    fig.update_layout(
        legend=dict(
            orientation="h",
            yanchor="top",
            y=1.2,
            xanchor="left",
            x=0.02,
            font_size=11
        ),
        barmode='group', showlegend=True,
        autosize=False, width=400, height=250,
        margin=dict(l=0, r=0, b=0, t=0),
        xaxis_title="Tasks", yaxis_title="Accuracy", font_size=10,
    )
    fig.update_annotations(font=dict(size=11))
    fig.show()
    # fig.write_image("images/ret_accuracy.pdf", engine="kaleido")
plot_top_retriever_recall()

In [None]:
# Retriever recall against the number of retrieved statements:
# language modeling on the top panel, question answering on the bottom one.
x = [1, 2, 3, 4, 5, 10, 15, 20, 25, 100]
rets = ['DPR', 'kNN-LM', 'REALM', 'Contriever']
fig = make_subplots(
    rows=2, cols=1,
    subplot_titles=("Language Modeling", "Question Answering"),
    vertical_spacing=0.12, horizontal_spacing=0.01,
    x_title='# of retrieved statements',
)

for panel_row, panel_task in enumerate(("lm", "qa"), start=1):
    for i, ret in enumerate(rets):
        ys = retriever_recall[panel_task][ret]
        fig.add_trace(
            go.Scatter(x=x, y=[y*100 for y in ys], name=ret, line_shape='spline',
                       marker_symbol=symbols[i], marker_color=colors[i+5],
                       # Only the top panel contributes legend entries.
                       showlegend=(panel_row == 1)),
            row=panel_row, col=1)

fig.update_xaxes(type="log", tickvals=x, tickangle=90)

# Shared y-axis title, placed by hand as a rotated annotation left of the panels.
y_axis_label = dict(x=-0.12,
                    y=0.5,
                    showarrow=False,
                    font_size=11,
                    text='Retrieved statements recall score',
                    textangle=-90,
                    xref="paper",
                    yref="paper")
fig.add_annotation(y_axis_label)

fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="top",
        y=1.13,
        x=0.01,
        font_size=11
    ),
    font_size=11,
    autosize=False, width=400, height=400,
    margin=dict(l=40, r=10, b=50, t=20),
)
fig.update_annotations(font=dict(size=11))
fig.show()
# fig.write_image("images/ret_recall.pdf", engine="kaleido")


In [None]:
def visualize_flans(data, metric="F1", colors=colors):
    """Plot one metric for every Flan-T5 size on a 2x2 grid of datasets.

    Args:
        data: nested mapping read as data["qa"][dataset_id][model][metric]
            for dataset ids "1" to "4"; each leaf is a sequence of scores
            convertible with float(). Model keys must appear in the display
            name table below.
        metric: key of the score list to plot; also used in the y-axis title
            of the first three panels.
        colors: indexable colour sequence; the default is the notebook
            palette as bound when this function was defined.

    The figure is displayed; nothing is returned or written to disk.
    """
    x = np.array([1, 2, 3, 4, 5, 10, 15, 20, 25])

    model_name_map = {
        'Flan-T5-small': 'Flan-T5-small (80M)',
        'Flan-T5-base': 'Flan-T5-base (250M)',
        'Flan-T5-large': 'Flan-T5-large (780M)',
        'Flan-T5-xl': 'Flan-T5-xl (3B)',
        'Flan-T5-xxl': 'Flan-T5-xxl (11B)'
    }
    # Models take colors[1], colors[2], ... in the order they appear under
    # dataset "1"; the smallest and largest models are then pinned explicitly.
    color_map = {model_name_map[m]: colors[i + 1] for i, m in enumerate(data["qa"]["1"])}
    color_map['Flan-T5-xxl (11B)'] = colors[6]
    color_map['Flan-T5-small (80M)'] = colors[5]

    fig = make_subplots(rows=2, cols=2, subplot_titles=("EntailmentBank-Easy", "EntailmentBank-Hard", "EntailmentBank-3", "StrategyQA"),
                        horizontal_spacing=0.1, vertical_spacing=0.12,
                        x_title='# of retrieved statements')

    for eb in [1, 2, 3, 4]:
        # Dataset ids fill the grid row by row: 1 and 2 on top, 3 and 4 below.
        row = 2 if eb >= 3 else 1
        col = 1 if eb % 2 == 1 else 2
        for i, (model, ys) in enumerate(data["qa"][str(eb)].items()):
            display_name = model_name_map[model]
            fig.add_trace(
                go.Scatter(x=x, y=[float(y)*100 for y in ys[metric]], name=display_name,
                           line_shape='spline', marker_symbol=symbols[i],
                           marker_color=color_map[display_name],
                           # Only the first panel contributes legend entries.
                           showlegend=(eb == 1)),
                row=row, col=col)

    fig.update_xaxes(type="log", tickvals=x)
    fig.update_yaxes(range=[10, 95])
    fig.update_annotations(font=dict(size=12))
    fig.update_layout(
        width=700, height=400,
        legend=dict(
            orientation="h",
            y=1.2,
            x=0,
            font_size=12
        ),
        yaxis1_title=f'Token overlap {metric} score',
        yaxis2_title=f'Token overlap {metric} score',
        yaxis3_title=f'Token overlap {metric} score',
        # The fourth (StrategyQA) panel is labelled 'Accuracy' whatever `metric` is.
        yaxis4_title='Accuracy',
        font_size=10,
        margin=dict(l=40, r=0, b=50, t=20),
    )
    fig.show()
    # fig.write_image(f"images/flans-EB1-EB2-EB3-SQA.pdf", engine="kaleido")
visualize_flans(contriever_flan_performance)

### Paper-specific Visualizations

#### main content

In [None]:
# Main-content figure: QA F1 on dataset "2" for every model, with a solid
# "(gold)" curve from lm_performance and a dotted "(all)" curve from
# ret_lm_performance.
x = np.array([1, 2, 3, 4, 5, 10, 15])

fig = make_subplots(rows=1, cols=1, horizontal_spacing=0.1, vertical_spacing=0.1)

for i, model in enumerate(models):
    # The "(gold)" trace is added before the "(all)" one for each model.
    for scores, suffix, line_style in (
        (lm_performance, ' (gold)', {}),
        (ret_lm_performance, ' (all)', {'line_dash': 'dot'}),
    ):
        ys = scores["qa"]["2"][model]['F1']
        fig.add_trace(
            go.Scatter(x=x, y=[float(y)*100 for y in ys], name=model + suffix,
                       line_shape='spline', marker_symbol=symbols[i],
                       marker_color=color_map[model], showlegend=True,
                       **line_style),
            row=1, col=1)

fig.update_xaxes(type="log", tickvals=x)
fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="top",
        y=1.7,
        xanchor="left",
        x=-0.05,
        font_size=10
    ),
    font_size=10,
    autosize=False, width=370, height=300,
    xaxis_title='# of retrieved statements',
    yaxis_title='Token overlap F1 score',
    margin=dict(l=0, r=0, b=0, t=20),
)
fig.update_annotations(font=dict(size=10))
fig.show()
# fig.write_image("images/qa_all_vs_gold.pdf", engine="kaleido")

In [None]:
plot_additional_distracting("qa")

In [None]:
# Main-content figure: QA F1 on EntailmentBank-Hard (dataset "2", left) and
# StrategyQA (dataset "4", right).
x = np.array([1, 2, 3, 4, 5, 10, 15, 20, 25])

fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=("EntailmentBank-Hard", "StrategyQA"),
    horizontal_spacing=0.06,
    x_title='# of retrieved statements',
)

for eb in [2, 4]:
    col = 2 if eb > 3 else 1
    left_panel = eb == 2
    # The StrategyQA panel is plotted against the first six cut-offs only.
    panel_x = x if left_panel else x[:6]
    for i, model in enumerate(models):
        ys = ret_lm_performance["qa"][str(eb)][model]["F1"]
        fig.add_trace(
            go.Scatter(x=panel_x, y=[float(y)*100 for y in ys], name=model,
                       line_shape='spline', marker_symbol=symbols[i],
                       marker_color=color_map[model],
                       # Legend entries come from the left panel only.
                       showlegend=left_panel),
            row=1, col=col)

fig.update_yaxes(range=[0, 75])
fig.update_xaxes(type="log", tickvals=x)
fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="top",
        y=1.48,
        xanchor="left",
        x=0.0,
        font_size=12
    ),
    font_size=12,
    autosize=False, width=500, height=300,
    yaxis1_title='Token overlap F1 score',
    margin=dict(l=0, r=0, b=45, t=20),
)
fig.update_annotations(font=dict(size=12))
fig.show()
fig.write_image("images/qa-EB2-SQA.pdf", engine="kaleido")

In [None]:
# Main-content figure: one curve per Flan-T5 size on dataset "2".
# NOTE: this cell rebinds the notebook-level names `metric`, `models` and
# `color_map`; cells that expect their earlier values have to run before it.
eb = 2
metric = "F1"
flan_scores = contriever_flan_performance["qa"][str(eb)]
models = flan_scores.keys()
model_name_map = {
    'Flan-T5-small': 'Flan-T5-small (80M)',
    'Flan-T5-base': 'Flan-T5-base (250M)',
    'Flan-T5-large': 'Flan-T5-large (780M)',
    'Flan-T5-xl': 'Flan-T5-xl (3B)',
    'Flan-T5-xxl': 'Flan-T5-xxl (11B)'
}
# Models take colors[1], colors[2], ... in dictionary order; the largest and
# smallest models are then pinned to fixed palette entries.
color_map = {model_name_map[m]: colors[i + 1] for i, m in enumerate(models)}
color_map.update({'Flan-T5-xxl (11B)': colors[6], 'Flan-T5-small (80M)': colors[5]})

x = np.array([1, 2, 3, 4, 5, 10, 15, 20, 25])

fig = go.Figure()

for i, (model, ys) in enumerate(flan_scores.items()):
    display_name = model_name_map[model]
    fig.add_trace(go.Scatter(
        x=x, y=[float(y)*100 for y in ys[metric]], name=display_name,
        line_shape='spline', marker_symbol=symbols[i],
        marker_color=color_map[display_name], showlegend=True))

fig.update_xaxes(type="log", tickvals=x)
fig.update_yaxes(range=[10, 55])
fig.update_layout(
    yaxis_title=f'Token overlap {metric} score',
    xaxis_title='# of retrieved statements',
    legend=dict(
        orientation="h",
        yanchor="top",
        y=1.47,
        xanchor="left",
        x=0.05,
        font_size=10
    ),
    font_size=10,
    autosize=False, width=350, height=250,
    margin=dict(l=45, r=0, b=0, t=20),
)
fig.update_annotations(font=dict(size=10))
fig.show()
fig.write_image("images/flans-EB2.pdf", engine="kaleido")


In [None]:
import plotly.graph_objects as go
datasets=['EB-Easy', 'EB-Hard', 'StrategyQA']

markers_dsp = ['/', '\\', 'x', '-', '|', '+', '.']
models = ["GPT-3", "Flan-T5-base", "Flan-T5-xxl"]

def plot_dsp(metric, dsp_models=("GPT-3", "Flan-T5-base", "Flan-T5-xxl")):
    """Draw and save a grouped bar chart of each reader with and without DSP.

    For every model in ``dsp_models`` two bar series are added, ``<model>``
    (values from ``dsp_scores["no-dsp"]``) and ``DSP (<model>)`` (values from
    ``dsp_scores["dsp"]``), with one bar per dataset id 1, 2 and 4, placed on
    the labels in the notebook-level ``datasets`` list.

    The model list is a parameter with a fixed default instead of the
    notebook-global ``models``: other cells rebind that global to unrelated
    model sets, which made the result depend on which cell ran last.

    Args:
        metric: Key of the score to plot; must be one of the metric keys
            collected below ('F1', 'Recall' or 'Precision').
        dsp_models: Model keys to look up in ``dsp_scores``.

    Side effects:
        Shows the figure and writes it to
        ``images/dsp-<metric>-EB1-EB2-SQA.pdf`` using the kaleido engine.

    Reads ``go``, ``colors``, ``datasets`` and ``dsp_scores`` from the
    notebook namespace.
    """
    fig = go.Figure()
    metric_names = ('F1', 'Recall', 'Precision')

    # Insertion order (model, then its DSP variant) fixes both the legend
    # order and the colour pairing below.
    data = {}
    for model in dsp_models:
        data[model] = {name: [] for name in metric_names}
        data['DSP ({})'.format(model)] = {name: [] for name in metric_names}

    # One base colour per model followed by the colour of its DSP variant.
    dsp_colors = [colors[0], 'skyblue', colors[2], colors[7], colors[6], 'lightpink']

    for ds in [1, 2, 4]:
        for model in dsp_models:
            # The metric keys are taken from the dataset "1" entry and reused
            # for every dataset.
            for metr in dsp_scores["dsp"]["1"][model].keys():
                # Only the first stored value of each metric list is plotted,
                # converted from a fraction to a percentage.
                # NOTE(review): confirm index 0 is the intended setting for
                # the no-DSP baseline as well.
                data[model][metr].append(float(dsp_scores["no-dsp"][str(ds)][model][metr][0])*100)
                data['DSP ({})'.format(model)][metr].append(float(dsp_scores["dsp"][str(ds)][model][metr][0])*100)

    for i, (series_name, scores) in enumerate(data.items()):
        fig.add_trace(go.Bar(
            name=series_name, x=datasets, y=scores[metric],
            # Wrap around so a longer dsp_models list reuses colours instead
            # of raising IndexError.
            marker_color=dsp_colors[i % len(dsp_colors)]
        ))

    # Horizontal legend placed above the plotting area.
    fig.update_layout(legend=dict(
        orientation="h",
        yanchor="top",
        y=1.34,
        xanchor="left",
        x=0.08,
        font_size=11
    ))
    fig.update_layout(barmode='group', showlegend=True)
    fig.update_layout(autosize=False, width=400, height=300)
    fig.update_annotations(font=dict(size=10))
    fig.update_layout(margin=dict(l=0,r=0,b=0,t=0))
    fig.update_layout(xaxis_title="Datasets", yaxis_title="Token overlap {} score".format(metric.lower()), font_size=9)
    fig.show()
    fig.write_image("images/dsp-{}-EB1-EB2-SQA.pdf".format(metric), engine="kaleido")

In [None]:
# F1 version of the DSP comparison; plot_dsp shows the figure and writes
# images/dsp-F1-EB1-EB2-SQA.pdf.
plot_dsp("F1")

#### Appendix

In [None]:
# Retrieval-augmented LMs shown in this figure and their fixed colours.
# They are rebound here on purpose: the DSP and Flan-T5 cells above reuse the
# names `models` and `color_map` for their own model sets, so on a
# top-to-bottom run this cell would otherwise index the score dicts and the
# palette with the wrong keys.
# NOTE(review): assumes all five models have an entry under ["lm"]["2"] in
# both lm_performance and ret_lm_performance — confirm against the data cells.
models = ['kNN-LM', 'REALM', 'DPR + FiD', 'Contriever + ATLAS', 'Contriever + Flan-T5']
color_map = {
    'REALM': colors[0],
    'DPR + FiD': colors[1],
    'Contriever + Flan-T5': colors[2],
    'kNN-LM': colors[3],
    'Contriever + ATLAS': colors[4]
}

x = np.array([1, 2, 3, 4, 5, 10, 15])

fig = make_subplots(rows=1, cols=1, horizontal_spacing=0.1, vertical_spacing=0.1)

for i, model in enumerate(models):
    # Solid curve: scores from lm_performance, labelled "(gold)".
    ys = lm_performance["lm"]["2"][model]['Target ranking accuracy']
    fig.add_trace(
        go.Scatter(
            x=x,
            y=[float(y)*100 for y in ys],
            name=model + ' (gold)',
            line_shape='spline',
            marker_symbol=symbols[i],
            marker_color=color_map[model],
            showlegend=True,
        ),
        row=1, col=1,
    )
    # Dotted curve in the same colour: scores from ret_lm_performance,
    # labelled "(all)".
    ys = ret_lm_performance["lm"]["2"][model]['Target ranking accuracy']
    fig.add_trace(
        go.Scatter(
            x=x,
            y=[float(y)*100 for y in ys],
            name=model + ' (all)',
            line_shape='spline',
            line_dash='dot',
            marker_symbol=symbols[i],
            marker_color=color_map[model],
            showlegend=True,
        ),
        row=1, col=1,
    )

# Log-scaled x axis with a tick at every plotted retrieval count.
fig.update_xaxes(type="log", tickvals=x)
# Horizontal legend placed above the plotting area.
fig.update_layout(legend=dict(
    orientation="h",
    yanchor="top",
    y=1.7,
    xanchor="left",
    x=-0.05,
    font_size=10
))
fig.update_layout(font_size=10)
fig.update_layout(autosize=False, width=370, height=300)
fig.update_annotations(font=dict(size=10))
fig.update_layout(
    xaxis_title='# of retrieved statements',
    yaxis_title='Accuracy',
)
fig.update_layout(margin=dict(l=0,r=0,b=0,t=20))
fig.show()
fig.write_image("images/lm_all_vs_gold.pdf", engine="kaleido")

In [None]:
# plot_additional_distracting is defined outside this part of the notebook.
# NOTE(review): presumably "lm" selects the language-modelling results and
# `bound` is the y-axis range — confirm against the helper's definition.
plot_additional_distracting("lm", bound=[0,100])

In [None]:
# Retrieval-augmented LMs shown in this figure and their fixed colours.
# They are rebound here on purpose: the DSP and Flan-T5 cells above reuse the
# names `models` and `color_map` for their own model sets, so on a
# top-to-bottom run this cell would otherwise index ret_lm_performance and the
# palette with the wrong keys.
models = ['kNN-LM', 'REALM', 'DPR + FiD', 'Contriever + ATLAS', 'Contriever + Flan-T5']
color_map = {
    'REALM': colors[0],
    'DPR + FiD': colors[1],
    'Contriever + Flan-T5': colors[2],
    'kNN-LM': colors[3],
    'Contriever + ATLAS': colors[4]
}

# NOTE(review): the F1 lists in the data cell at the top of the notebook hold
# 10 values while only 9 x positions are given here — confirm which retrieval
# counts the stored values correspond to.
x = np.array([1, 2, 3, 4, 5, 10, 15, 20, 25])

fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=("EntailmentBank-1 (Easy)", "EntailmentBank-3",),
    horizontal_spacing=0.06, vertical_spacing=0.07,
    x_title='# of retrieved statements',
    y_title='Token overlap F1 score',
)

for eb in [1, 3]:
    # Dataset 1 goes in the left panel, dataset 3 in the right one.
    row = 1
    col = 1 if eb < 2 else 2
    for i, model in enumerate(models):
        ys = ret_lm_performance["qa"][str(eb)][model]["F1"]
        fig.add_trace(
            go.Scatter(
                x=x,
                y=[float(y)*100 for y in ys],
                name=model,
                line_shape='spline',
                marker_symbol=symbols[i],
                marker_color=color_map[model],
                # Both panels draw the same models; only the first panel
                # contributes legend entries so each model is listed once.
                showlegend=(eb == 1),
            ),
            row=row, col=col,
        )

fig.update_yaxes(range=[0, 60])

# Log-scaled x axes with a tick at every plotted retrieval count.
fig.update_xaxes(type="log", tickvals=x)
# Horizontal legend placed above the plotting area.
fig.update_layout(legend=dict(
    orientation="h",
    yanchor="top",
    y=1.25,
    xanchor="left",
    x=0.07,
    font_size=10
))
fig.update_layout(font_size=10)
fig.update_layout(autosize=False, width=700, height=250)
fig.update_annotations(font=dict(size=10))
fig.update_layout(margin=dict(l=55,r=0,b=45,t=20))
fig.show()
fig.write_image("images/qa-EB1-EB3.pdf", engine="kaleido")

In [None]:
x = np.array([1, 2, 3, 4, 5, 10, 15, 20, 25])

# Set explicitly so the cell does not depend on `metric` or `eb` values left
# behind by whichever cell happened to run before it.
metric = "F1"

# Model keys come from dataset "1", the first dataset plotted below.
models = contriever_flan_performance["qa"]["1"].keys()
model_name_map = {
    'Flan-T5-small': 'Flan-T5-small (80M)',
    'Flan-T5-base': 'Flan-T5-base (250M)',
    'Flan-T5-large': 'Flan-T5-large (780M)',
    'Flan-T5-xl': 'Flan-T5-xl (3B)',
    'Flan-T5-xxl': 'Flan-T5-xxl (11B)'
}
# Legend label -> colour: walk the palette from its second entry in model
# order, then pin the smallest and largest readers to fixed palette slots.
color_map = {model_name_map[m]: colors[1:][i] for i, m in enumerate(models)}
color_map['Flan-T5-xxl (11B)'] = colors[6]
color_map['Flan-T5-small (80M)'] = colors[5]

fig = make_subplots(rows=2, cols=1, subplot_titles=("EntailmentBank-Easy", "StrategyQA"), horizontal_spacing=0.1, vertical_spacing=0.1)
for eb in [1, 4]:
    # Dataset 1 goes in the top panel, dataset 4 in the bottom one.
    row = 2 if eb > 3 else 1
    for i, (model, ys) in enumerate(contriever_flan_performance["qa"][str(eb)].items()):
        legend_label = model_name_map[model]
        # NOTE(review): the bottom panel is labelled "Accuracy" but plots the
        # same `metric` entry as the top panel — confirm what the dataset-4
        # values stored under this key represent.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=[float(y)*100 for y in ys[metric]],
                name=legend_label,
                line_shape='spline',
                marker_symbol=symbols[i],
                marker_color=color_map[legend_label],
                # Only the top panel contributes legend entries, so each
                # model is listed once.
                showlegend=(eb == 1),
            ),
            row=row, col=1,
        )

# Log-scaled x axes with a tick at every plotted retrieval count.
fig.update_xaxes(type="log", tickvals=x)
fig.update_layout(font_size=11)
fig.update_annotations(font=dict(size=12))
fig.update_layout(autosize=False, width=350, height=400)

# Horizontal legend placed above the plotting area.
fig.update_layout(legend=dict(
    orientation="h",
    yanchor="top",
    y=1.3,
    xanchor="left",
    x=0.01,
    font_size=10
))
fig.update_layout(
    yaxis1_title=f'Token overlap {metric} score',
    yaxis2_title='Accuracy',
    xaxis2_title='# of retrieved statements',
    font_size=11,
)
fig.update_layout(margin=dict(l=0,r=0,b=0,t=20))
fig.show()
fig.write_image("images/flans-EB1-SQA.pdf", engine="kaleido")


In [None]:
# Recall version of the DSP comparison; plot_dsp shows the figure and writes
# images/dsp-Recall-EB1-EB2-SQA.pdf.
plot_dsp("Recall")