<center><h1>Contextualising ancient texts with generative neural networks</h1></center>

<center>
<i>Yannis Assael<sup>*</sup>, Thea Sommerschield<sup>*</sup>, Alison Cooley, Brendan Shillingford, John Pavlopoulos, Priyanka Suresh, Bailey Herms, Justin Grayston, Benjamin Maynard, Nicholas Dietrich, Robbe Wulgaert, Jonathan Prag, Alex Mullen, Shakir Mohamed</i>
</center>
<br>

Human history is born in writing. Inscriptions, among the earliest written forms, offer direct insights into the thought, language, and history of ancient civilisations. Historians capture these insights by identifying parallels — inscriptions with shared phrasing, function, or cultural setting — to enable the contextualisation of texts within broader historical frameworks, and perform key tasks such as restoration and geographical or chronological attribution. However, current digital methods are restricted to literal matches and narrow historical scopes. We introduce Aeneas, the first generative neural network for contextualising ancient texts. Aeneas retrieves textual and contextual parallels, leverages visual inputs, handles arbitrary-length text restoration, and advances the state-of-the-art in key tasks. To evaluate its impact, we conduct the largest Historian-AI study to date, with historians considering Aeneas’ retrieved parallels useful research starting points in 90% of cases, improving their confidence in key tasks by 44%. Restoration and geographical attribution tasks yielded superior results when historians were paired with Aeneas, outperforming both humans and AI alone. For dating, Aeneas achieved a 13-year distance from ground-truth ranges. We demonstrate Aeneas’ contribution to historical workflows through analysis of key traits in the *Res Gestae Divi Augusti*, the most renowned Roman inscription, showing how integrating Science and Humanities can create transformative tools to assist historians and advance our understanding of the past.
<br><br>

---
### References

-   [Nature article](https://www.nature.com/articles/s41586-025-09292-5)
-   [Google DeepMind blog](https://deepmind.google/discover/blog/aeneas-transforms-how-historians-connect-the-past)
-   [Aeneas in the classroom](https://www.robbewulgaert.be/education/predicting-the-past-aeneas)

When using any of the source code or outputs of this project please cite:

<textarea readonly rows=9 cols=90>
@article{asssome2025contextualising,
  title={Contextualising ancient texts with generative neural networks},
  author={Assael*, Yannis and Sommerschield*, Thea and Cooley, Alison and Pavlopoulos, John and Shillingford, Brendan and Herms, Bailey and Suresh, Priyanka and Maynard, Benjamin and Grayston, Justin and Wulgaert, Robbe and Prag, Jonathan and Mullen, Alex and Mohamed, Shakir},
  journal={Nature},
  volume={643},
  number={8073},
  year={2025},
  publisher={Nature Publishing}
}
</textarea>

### License

```
Copyright 2025 the Aeneas Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```
---

In [None]:
#@markdown
%%html
<style>
  * { font-family: Roboto, Noto, sans-serif; }
  h1 { font-weight: 400; }
  p, ol {
    font-size: 16px;
    line-height: 1.6;
  }
  kbd, .btn {
    display: inline-block;
    background: #eee;
    border: 1px solid #aaa;
    border-radius: 3px;
    padding: 0 5px;
    margin: 0 2px;
  }
  .section-name {
    font-weight: 600;
    font-style: italic;
  }
</style>
<h1>Interactive notebook instructions</h1>
<p><i>Salve</i> (or welcome) to the Interactive Notebook of Aeneas.</p>
<p>Please follow the instructions below to begin contextualising, restoring and attributing Latin inscriptions.</p>
<ol>
  <li>
    Execute each of the cells below using <kbd>Shift</kbd>+<kbd>Enter</kbd> to prepare Aeneas.
  </li>
  <li>
    In the <span class="section-name">Setup</span> section, which you should execute but can skip reading, we:
    <ul>
      <li>install the necessary dependencies,</li>
      <li>download the model checkpoint,</li>
      <li>load the imports, and</li>
      <li>create the model.</li>
    </ul>
  </li>
  <li>
    In the section <span class="section-name">Use Aeneas for your research</span>, input 25 to 750 characters of Latin text to process with Aeneas.
  </li>
  <li>
    In the <span class="section-name">Contextualisation</span> section, you will find Aeneas&#39;s top parallels ranked by similarity.
  </li>
  <li>
    After executing the <span class="section-name">Restoration</span> section, you will see Aeneas&#39;s top 20 restoration hypotheses ranked by
    probability, along with the saliency map for Aeneas&#39;s top restoration.
  </li>
  <li>
    In the <span class="section-name">Attribution</span> section, one can find:
    <ul>
      <li>a bar chart and geographical map showing the top 10 geographical attribution hypotheses, ranked by probability among 62 Roman provinces;</li>
      <li>a categorical distribution over all decades from 800 BCE to 800 CE (the chronological attribution predictive distribution); and</li>
      <li>saliency maps for geographical and chronological attributions.</li>
    </ul>
  </li>
</ol>
<p>If you wish to save your changes or outputs, click <span class="btn">File</span> → <span class="btn">Save a copy in Drive</span>. This is a read-only notebook, so <b>changes are not saved by default</b>.</p>
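
The chronological attribution output described in the instructions above is a distribution over decades from 800 BCE to 800 CE. As a rough illustration of what that grid looks like, the sketch below maps a year to a decade bin and back. The three constants mirror the `dataset_config` class defined later in the Setup section; the helper names `year_to_bin` and `bin_to_range` are hypothetical and not part of `predictingthepast`, and the model's own discretisation may differ in detail.

```python
# Values mirror the notebook's `dataset_config`; helper names are illustrative.
DATE_MIN = -800     # 800 BCE
DATE_MAX = 800      # 800 CE
DATE_INTERVAL = 10  # one bin per decade

# Number of decade bins covering [DATE_MIN, DATE_MAX).
NUM_BINS = (DATE_MAX - DATE_MIN) // DATE_INTERVAL


def year_to_bin(year: int) -> int:
    """Returns the index of the decade bin that contains `year`."""
    if not DATE_MIN <= year < DATE_MAX:
        raise ValueError(f"year {year} is outside [{DATE_MIN}, {DATE_MAX})")
    return (year - DATE_MIN) // DATE_INTERVAL


def bin_to_range(index: int) -> tuple[int, int]:
    """Returns the (start, end) years of bin `index`; `end` is exclusive."""
    start = DATE_MIN + index * DATE_INTERVAL
    return start, start + DATE_INTERVAL


print(NUM_BINS)                        # 160
print(bin_to_range(year_to_bin(-27)))  # (-30, -20)
```

Negative years stand for BCE and positive years for CE, matching the convention used by the notebook's `bce_ce` helper.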

# Setup

## Select language

In [None]:
#@title { run: "auto" }
import requests

resource_language = "Latin" #@param ["Greek", "Latin"]

RESOURCES = {
    "Latin": {
        "checkpoint.pkl": "https://storage.googleapis.com/ithaca-resources/models/aeneas_117149994_2.pkl",
        "dataset.json": "https://storage.googleapis.com/ithaca-resources/models/led.json",
        "dataset_emb.pkl": "https://storage.googleapis.com/ithaca-resources/models/led_emb_xid117149994.pkl"
    },
    "Greek": {
        "checkpoint.pkl": "https://storage.googleapis.com/ithaca-resources/models/ithaca_153143996_2.pkl",
        "dataset.json": "https://storage.googleapis.com/ithaca-resources/models/iphi.json",
        "dataset_emb.pkl": "https://storage.googleapis.com/ithaca-resources/models/iphi_emb_xid153143996.pkl"
    }
}

# Download files for the currently selected language
print(f"Downloading resources for: {resource_language}")
for filename, url in RESOURCES[resource_language].items():
    print(f"Fetching {filename}...")
    try:
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
    except requests.exceptions.RequestException as e:
        print(f" -> Error: {e}")

print("Download process complete.")

# Map coordinates
if resource_language == 'Latin':
  locations_ll = [
      {
        "name": "Achaia",
        "lat": 38.099043,
        "lng": 22.4314905
      },
      {
        "name": "Aegyptus",
        "lat": 31.2001,
        "lng": 29.9187
      },
      {
        "name": "Aemilia",
        "lat": 44.4949,
        "lng": 11.3426
      },
      {
        "name": "Africa Proconsularis",
        "lat": 32.5,
        "lng": 12.5
      },
      {
        "name": "Alpes Cottiae",
        "lat": 45.1396,
        "lng": 7.0703
      },
      {
        "name": "Alpes Graiae",
        "lat": 45.62,
        "lng": 6.6819
      },
      {
        "name": "Alpes Maritimae",
        "lat": 43.7102,
        "lng": 7.262
      },
      {
        "name": "Alpes Poeninae",
        "lat": 46.1036,
        "lng": 7.0731
      },
      {
        "name": "Apulia et Calabria",
        "lat": 40.4748,
        "lng": 17.2385
      },
      {
        "name": "Aquitania",
        "lat": 44.8378,
        "lng": -0.5792
      },
      {
        "name": "Arabia",
        "lat": 32.5,
        "lng": 37.5
      },
      {
        "name": "Armenia",
        "lat": 39.5,
        "lng": 40.5
      },
      {
        "name": "Asia",
        "lat": 37.945,
        "lng": 27.3417
      },
      {
        "name": "Baetica",
        "lat": 37.8882,
        "lng": -4.7794
      },
      {
        "name": "Barbaricum",
        "lat": 51.3127,
        "lng": 9.4797
      },
      {
        "name": "Belgica",
        "lat": 49.2583,
        "lng": 4.0317
      },
      {
        "name": "Bithynia et Pontus",
        "lat": 40.7667,
        "lng": 29.9167
      },
      {
        "name": "Britannia",
        "lat": 52.5,
        "lng": -2.5
      },
      {
        "name": "Bruttium et Lucania",
        "lat": 38.1105,
        "lng": 15.6501
      },
      {
        "name": "Cappadocia",
        "lat": 39.25,
        "lng": 35.75
      },
      {
        "name": "Cilicia",
        "lat": 36.25,
        "lng": 33.25
      },
      {
        "name": "Corsica",
        "lat": 42.0978,
        "lng": 9.4544
      },
      {
        "name": "Creta",
        "lat": 35.20052306,
        "lng": 25.00709816
      },
      {
        "name": "Cyprus",
        "lat": 34.7768,
        "lng": 32.4245
      },
      {
        "name": "Cyrene",
        "lat": 32.49965333,
        "lng": 20.87174333
      },
      {
        "name": "Dacia",
        "lat": 45.5,
        "lng": 22.5
      },
      {
        "name": "Dalmatia",
        "lat": 43.5225,
        "lng": 16.4739
      },
      {
        "name": "Epirus",
        "lat": 39.54648402,
        "lng": 20.78770214
      },
      {
        "name": "Etruria",
        "lat": 43.7696,
        "lng": 11.2558
      },
      {
        "name": "Galatia",
        "lat": 39.75,
        "lng": 32.75
      },
      {
        "name": "Germania inferior",
        "lat": 50.9375,
        "lng": 6.9603
      },
      {
        "name": "Germania superior",
        "lat": 50.0011,
        "lng": 8.267
      },
      {
        "name": "Hispania citerior",
        "lat": 41.1189,
        "lng": 1.245
      },
      {
        "name": "Iudaea",
        "lat": 32.5183,
        "lng": 34.9083
      },
      {
        "name": "Latium et Campania",
        "lat": 40.8518,
        "lng": 14.2681
      },
      {
        "name": "Liguria",
        "lat": 44.4056,
        "lng": 8.9463
      },
      {
        "name": "Lugdunensis",
        "lat": 45.757,
        "lng": 4.832
      },
      {
        "name": "Lusitania",
        "lat": 38.9159,
        "lng": -6.3437
      },
      {
        "name": "Lycia et Pamphylia",
        "lat": 36.8969,
        "lng": 30.7133
      },
      {
        "name": "Macedonia",
        "lat": 41.25,
        "lng": 21.75
      },
      {
        "name": "Mauretania Caesariensis",
        "lat": 36.52285575,
        "lng": 3.6414695
      },
      {
        "name": "Mauretania Tingitana",
        "lat": 34.0725,
        "lng": -5.5548
      },
      {
        "name": "Mesopotamia",
        "lat": 37.5,
        "lng": 39.5
      },
      {
        "name": "Moesia inferior",
        "lat": 44.1728,
        "lng": 28.635
      },
      {
        "name": "Moesia superior",
        "lat": 42.5,
        "lng": 22.5
      },
      {
        "name": "Narbonensis",
        "lat": 43.1838,
        "lng": 3.0045
      },
      {
        "name": "Noricum",
        "lat": 46.7167,
        "lng": 14.4167
      },
      {
        "name": "Numidia",
        "lat": 36.607058,
        "lng": 2.1918495
      },
      {
        "name": "Pannonia inferior",
        "lat": 47.5333,
        "lng": 19.05
      },
      {
        "name": "Pannonia superior",
        "lat": 48.0772,
        "lng": 16.8583
      },
      {
        "name": "Picenum",
        "lat": 42.8542,
        "lng": 13.575
      },
      {
        "name": "Raetia",
        "lat": 48.3705,
        "lng": 10.8978
      },
      {
        "name": "Regnum Bospori",
        "lat": 45.3564,
        "lng": 36.4718
      },
      {
        "name": "Roma",
        "lat": 41.9028,
        "lng": 12.4964
      },
      {
        "name": "Sabina et Samnium",
        "lat": 41.1385,
        "lng": 14.775
      },
      {
        "name": "Sardinia",
        "lat": 39.2238,
        "lng": 9.1217
      },
      {
        "name": "Sicilia, Melita",
        "lat": 37.0692,
        "lng": 15.2875
      },
      {
        "name": "Syria",
        "lat": 36.2021,
        "lng": 36.16
      },
      {
        "name": "Thracia",
        "lat": 41.5,
        "lng": 25.0
      },
      {
        "name": "Transpadana",
        "lat": 45.4642,
        "lng": 9.19
      },
      {
        "name": "Umbria",
        "lat": 43.1122,
        "lng": 12.3888
      },
      {
        "name": "Venetia et Histria",
        "lat": 45.7722,
        "lng": 13.37
      }
    ]
elif resource_language == 'Greek':
  locations_ll = [
      {
        "name": "Achaia",
        "lat": 38.099043,
        "lng": 22.4314905
      },
      {
        "name": "Aeolis",
        "lat": 38.84330517,
        "lng": 27.01538508
      },
      {
        "name": "Africa Proconsularis",
        "lat": 32.5,
        "lng": 12.5
      },
      {
        "name": "Amorgos and vicinity",
        "lat": 36.833333,
        "lng": 25.9
      },
      {
        "name": "Arabia",
        "lat": 32.5,
        "lng": 37.5
      },
      {
        "name": "Arabian Peninsula",
        "lat": 29.5,
        "lng": 45.5
      },
      {
        "name": "Arachosia, Drangiana",
        "lat": 32.5,
        "lng": 62.5
      },
      {
        "name": "Arkadia",
        "lat": 37.61781583,
        "lng": 22.17000731
      },
      {
        "name": "Armenia",
        "lat": 39.5,
        "lng": 40.5
      },
      {
        "name": "Attica",
        "lat": 37.97278669,
        "lng": 23.99374594
      },
      {
        "name": "Babylonia",
        "lat": 32.5,
        "lng": 44.5
      },
      {
        "name": "Bactria, Sogdiana",
        "lat": 36.76782566,
        "lng": 66.9010688
      },
      {
        "name": "Bithynia",
        "lat": 40.7561925,
        "lng": 31.5858565
      },
      {
        "name": "Britannia",
        "lat": 52.5,
        "lng": -2.5
      },
      {
        "name": "Byzacena",
        "lat": 32.5,
        "lng": 7.5
      },
      {
        "name": "Cappadocia",
        "lat": 39.25,
        "lng": 35.75
      },
      {
        "name": "Caria",
        "lat": 37.042901,
        "lng": 27.420201
      },
      {
        "name": "Carmania",
        "lat": 29.0,
        "lng": 57.5
      },
      {
        "name": "Chios",
        "lat": 38.414,
        "lng": 26.053
      },
      {
        "name": "Cilicia",
        "lat": 36.25,
        "lng": 33.25
      },
      {
        "name": "Commagene",
        "lat": 36.526222,
        "lng": 37.9555335
      },
      {
        "name": "Cos and Calymna",
        "lat": 36.844,
        "lng": 27.17
      },
      {
        "name": "Crete",
        "lat": 35.20052306,
        "lng": 25.00709816
      },
      {
        "name": "Cyclades, excl. Delos",
        "lat": 36.92625341,
        "lng": 25.41590803
      },
      {
        "name": "Cyrenaica",
        "lat": 32.49965333,
        "lng": 20.87174333
      },
      {
        "name": "Dacia",
        "lat": 45.5,
        "lng": 22.5
      },
      {
        "name": "Delos",
        "lat": 37.393333,
        "lng": 25.271111
      },
      {
        "name": "Delphi",
        "lat": 38.482289,
        "lng": 22.501169
      },
      {
        "name": "Doric Sporades",
        "lat": 36.683333,
        "lng": 24.416667
      },
      {
        "name": "Doris",
        "lat": 38.75,
        "lng": 22.25
      },
      {
        "name": "Egypt and Nubia",
        "lat": 19.21140877,
        "lng": 30.56732963
      },
      {
        "name": "Eleusis",
        "lat": 38.041101,
        "lng": 23.537401
      },
      {
        "name": "Elis",
        "lat": 37.891781,
        "lng": 21.375091
      },
      {
        "name": "Epidauria",
        "lat": 37.6334625,
        "lng": 23.16015635
      },
      {
        "name": "Epirus",
        "lat": 39.54648402,
        "lng": 20.78770214
      },
      {
        "name": "Euboia",
        "lat": 38.53,
        "lng": 23.87
      },
      {
        "name": "Galatia",
        "lat": 39.75,
        "lng": 32.75
      },
      {
        "name": "Gallia",
        "lat": 46.70543722,
        "lng": 1.013706367
      },
      {
        "name": "Germania",
        "lat": 51.6054836,
        "lng": 5.795502625
      },
      {
        "name": "Hispania citerior and Lusitania",
        "lat": 40.18650564,
        "lng": -3.736305805
      },
      {
        "name": "Hyrcania, Parthia",
        "lat": 38.5,
        "lng": 57.5
      },
      {
        "name": "Iberia and Colchis",
        "lat": 41.833638,
        "lng": 44.672277
      },
      {
        "name": "Ionia",
        "lat": 38.06312675,
        "lng": 27.062852
      },
      {
        "name": "Italy, incl. Magna Graecia",
        "lat": 41.891775,
        "lng": 12.486137
      },
      {
        "name": "Lakonia and Messenia",
        "lat": 37.0557778,
        "lng": 21.9809211
      },
      {
        "name": "Lesbos, Nesos, and Tenedos",
        "lat": 39.16075203,
        "lng": 26.25450897
      },
      {
        "name": "Lycaonia",
        "lat": 37.25,
        "lng": 32.75
      },
      {
        "name": "Lycia",
        "lat": 36.5135325,
        "lng": 29.129311
      },
      {
        "name": "Lydia",
        "lat": 38.32544739,
        "lng": 28.2612252
      },
      {
        "name": "Macedonia",
        "lat": 41.25,
        "lng": 21.75
      },
      {
        "name": "Mauretania Caesariensis",
        "lat": 36.52285575,
        "lng": 3.6414695
      },
      {
        "name": "Mauretania Tingitana",
        "lat": 34.0725,
        "lng": -5.5548
      },
      {
        "name": "Media",
        "lat": 34.5,
        "lng": 46.5
      },
      {
        "name": "Megaris, Oropia, and Boiotia",
        "lat": 38.2544053,
        "lng": 23.1821091
      },
      {
        "name": "Mesopotamia",
        "lat": 37.5,
        "lng": 39.5
      },
      {
        "name": "Moesia superior",
        "lat": 42.5,
        "lng": 22.5
      },
      {
        "name": "Mysia",
        "lat": 39.13121333,
        "lng": 27.18453033
      },
      {
        "name": "Mysia Kaikos, Pergamon",
        "lat": 39.13121333,
        "lng": 27.18453033
      },
      {
        "name": "Mysia Upper Kaikos / Lydia",
        "lat": 39.61497321,
        "lng": 27.87494434
      },
      {
        "name": "Northern Aegean",
        "lat": 40.683333,
        "lng": 24.65
      },
      {
        "name": "Numidia",
        "lat": 36.607058,
        "lng": 2.1918495
      },
      {
        "name": "Osrhoene",
        "lat": 37.15,
        "lng": 38.79
      },
      {
        "name": "Palaestina",
        "lat": 31.25,
        "lng": 34.75
      },
      {
        "name": "Pamphylia",
        "lat": 36.990721,
        "lng": 30.98638
      },
      {
        "name": "Persis",
        "lat": 30.5,
        "lng": 51.5
      },
      {
        "name": "Phokis, Lokris, Aitolia, Akarnania, and Ionian Islands",
        "lat": 38.763022,
        "lng": 21.06285
      },
      {
        "name": "Phrygia",
        "lat": 38.75,
        "lng": 29.75
      },
      {
        "name": "Pisidia",
        "lat": 37.25,
        "lng": 30.75
      },
      {
        "name": "Pontus and Paphlagonia",
        "lat": 41.4407791,
        "lng": 33.25503893
      },
      {
        "name": "Raetia, Noricum, and Pannonia",
        "lat": 47.727498,
        "lng": 10.326578
      },
      {
        "name": "Rhamnous",
        "lat": 38.22219067,
        "lng": 24.02740133
      },
      {
        "name": "Rhodes and S. Dodecanese",
        "lat": 36.195597,
        "lng": 27.964125
      },
      {
        "name": "Samos",
        "lat": 37.73,
        "lng": 26.84
      },
      {
        "name": "Saronic Gulf, Corinthia, and the Argolid",
        "lat": 37.80840845,
        "lng": 22.86776751
      },
      {
        "name": "Scythia Minor",
        "lat": 44.5,
        "lng": 28.5
      },
      {
        "name": "Sicily, Sardinia",
        "lat": 37.67337489,
        "lng": 13.90469368
      },
      {
        "name": "Susiana",
        "lat": 32.5,
        "lng": 48.5
      },
      {
        "name": "Syria and Phoenicia",
        "lat": 33.25,
        "lng": 35.25
      },
      {
        "name": "Thessaly",
        "lat": 39.6,
        "lng": 22.2
      },
      {
        "name": "Thrace",
        "lat": 41.5,
        "lng": 25.0
      },
      {
        "name": "Tripolitania",
        "lat": 33.17039367,
        "lng": 10.90091267
      },
      {
        "name": "Troas",
        "lat": 39.813107,
        "lng": 26.164143
      }
    ]

In [None]:
# @title Install Aeneas
!pip install -q git+https://github.com/google-deepmind/predictingthepast || echo "*** FAILED TO INSTALL predictingthepast ***"

In [None]:
# @title Install additional packages
!pip install -q geopandas folium ipywidgets

In [None]:
# @title Imports
from IPython.display import HTML, display
try:
  from google.colab import files
  environment = 'colab'
except ModuleNotFoundError:
  import ipywidgets as widgets
  environment = 'ipython'

try:
  import predictingthepast
  import flax
except ModuleNotFoundError:
  display(
      HTML(
          '<h1><font color="#f00">Failed to import predictingthepast. Did'
          ' installation fail above?</font></h1>'
      )
  )
  raise

import functools
import io
import pickle
from PIL import Image

from flax import linen as nn
import folium
import jax
import jinja2
import matplotlib.pyplot as plt
from ml_collections import config_dict
import numpy as np

from predictingthepast.eval import inference
from predictingthepast.models.model import Model
from predictingthepast.util import alphabet as util_alphabet

In [None]:
# @title Configuration and auxiliary functions

class dataset_config:
  date_interval = 10
  date_max = 800
  date_min = -800


def bce_ce(d):
  if d is None:
    return ''
  if d < 0:
    return f'{abs(d)} BCE'
  elif d > 0:
    return f'{abs(d)} CE'
  return '0'


SALIENCY_SNIPPET_TEMPLATE = jinja2.Template("""
<div class="saliency">
  {% for char, score in pairs -%}
    <span
      style="background-color: rgba(248,127,195,{{'%.2f'|format(score)}});"
      title="Saliency score {{'%.2f'|format(score)}}">{{ char }}</span>
  {%- endfor %}
</div>
""")

SALIENCY_TEMPLATE = jinja2.Template("""<!DOCTYPE html>
<html>
<head>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@400&family=Roboto:wght@400&display=swap" rel="stylesheet">
<style>
body { font-family: 'Roboto Mono', monospace; }
.saliency {
  word-wrap: break-word;
  white-space: normal;
}
</style>
</head>
<body>
{{body_html|safe}}
</body>
</html>
""")


def generate_saliency(text, saliency, snippet=False):
  """Generates saliency visualisation."""
  snippet_html = SALIENCY_SNIPPET_TEMPLATE.render(
      pairs=list(zip(text, saliency))
  )
  if snippet:
    return snippet_html
  return SALIENCY_TEMPLATE.render(body_html=snippet_html)


def load_checkpoint(path, language):
  """Loads a checkpoint pickle.

  Args:
    path: path to checkpoint pickle
    language: language of the model (latin/greek)

  Returns:
    a model config dictionary (arguments to the model's constructor), a dict of
    dicts containing region mapping information, an Alphabet instance for the
    requested language, a dict of Jax arrays `params`, and a `forward`
    function.
  """

  # Pickled checkpoint dict containing params and various config:
  with open(path, 'rb') as f:
    checkpoint = pickle.load(f)

  # We reconstruct the model using the same arguments as during training, which
  # are saved as a dict in the "model_config" key, and construct a `forward`
  # function of the form required by attribute() and restore().
  params = jax.device_put(checkpoint['params'])
  model = Model(**checkpoint['model_config'])
  forward = model.apply

  # Contains the mapping between region IDs and names:
  region_map = checkpoint['region_map']

  # Select the alphabet class for the requested language. Its contents (e.g.
  # the padding symbol) are fixed constants and are not read from the checkpoint.
  if language == 'latin':
    alphabet = util_alphabet.LatinAlphabet()
  elif language == 'greek':
    alphabet = util_alphabet.GreekAlphabet()
  else:
    raise ValueError(f'Unknown language: {language}')

  return checkpoint['model_config'], region_map, alphabet, params, forward


In [None]:
# @title Load and create model
(model_config, region_map, alphabet, params, forward) = load_checkpoint(
    'checkpoint.pkl', resource_language.lower()
)
vocab_char_size = model_config['vocab_char_size']

In [None]:
# @title Load dataset and embeddings
dataset = inference.load_dataset('dataset.json')
retrieval = inference.load_retrieval('dataset_emb.pkl')

# Use Aeneas for your research

**Enter your Latin epigraphic text, including spaces, in the box below to obtain epigraphic parallels, restore missing characters, and attribute the inscription to its original place and time of writing.**

**Use a question mark (?)** for each character you want the model to predict: each query can predict up to 20 question marks (consecutive or not), and spaces count toward this limit.

**Use a single hash (#)** to predict a text sequence of unknown length. Use the sliders to adjust the sampling temperature and the maximum expected length of the restoration. The temperature controls how conservative or creative the restoration outputs are: lower values give more formulaic text (like a funerary inscription), while higher values give more varied text (like a dedicatory verse inscription). Note that the total number of characters predicted across both unknown-length (#) and known-length (?) gaps cannot exceed the maximum length you set.

The text should be between 25 and 750 characters long.
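
The marker rules above can be checked before running a query. The following is a minimal sketch, not part of `predictingthepast`: `check_query` and the two constants are hypothetical names, and reading "a single hash" as "at most one `#` per query" is an assumption.

```python
MAX_QUESTION_MARKS = 20  # limit on '?' markers stated in the instructions above
MAX_HASHES = 1           # assumption: at most one unknown-length gap per query


def check_query(text: str) -> list[str]:
    """Returns a list of problems found in the gap markers of `text`."""
    problems = []
    n_known = text.count("?")    # gaps of known length, one '?' per character
    n_unknown = text.count("#")  # gaps of unknown length
    if n_known > MAX_QUESTION_MARKS:
        problems.append(
            f"found {n_known} '?' markers, at most {MAX_QUESTION_MARKS} allowed"
        )
    if n_unknown > MAX_HASHES:
        problems.append(
            f"found {n_unknown} '#' markers, at most {MAX_HASHES} allowed"
        )
    return problems


print(check_query("Deab #an et Tutelae loci"))  # []
print(check_query("?" * 21))
```

An empty list means the markers respect both limits. Missing spaces still need a `?` each, since spaces count toward the limit.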

In [None]:
# @title  { run: "auto", vertical-output: true }
text = "Deab #an et Tutelae loci pro salute et incolmitate sua suorumq omnium L Maiorius Cogitatus bficiarius cosularis vot sol l l m Idibus Iulis Gentiano et Basso cosulibus"  # @param {type:"string"}
arg_beam_width = 100  # @param {type:"integer"}
arg_max_restoration_len = 20  # @param {type:"integer"}
arg_restoration_temperature = 1.0  # @param {type:"number"}
assert (
    25 <= len(text) <= 767
), "text should be between 25 and 767 chars long, got " + str(len(text))

In [None]:
# @title Input image (optional for chronological attribution)  { run: "auto", vertical-output: true }

upload_image = False #@param {type:"boolean"}

vision_img = None
if upload_image:
  if environment == 'colab':
    uploader = files.upload()

    # Check if any file was uploaded
    if len(uploader.keys()) > 0:
      # Get the first uploaded file's name and content
      file_name = next(iter(uploader))
      image_bytes = uploader[file_name]
    else:
      print('No file was uploaded.')
  elif environment == 'ipython':
    uploader = widgets.FileUpload(
        accept='image/*',
        multiple=False
    )
    display(uploader)

# Contextualisation
[Back to top](#scrollTo=oTrYS2DhDmcX)
[Restoration](#scrollTo=p9EmORGaxn8u)
[Attribution](#scrollTo=M3h8nL22HvTh)
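
Parallels are ranked by similarity to the query. The actual retrieval happens inside `inference.contextualize`; the toy sketch below only illustrates the general idea of ranking stored embeddings against a query embedding. It assumes cosine similarity as the measure and plain Python lists as embeddings, and the names `cosine_similarity`, `rank_parallels` and `toy_corpus` are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # treat a zero vector as dissimilar to everything
    return dot / (norm_a * norm_b)


def rank_parallels(query, corpus, top_k=3):
    """Returns the `top_k` (record_id, score) pairs, most similar first."""
    scored = [(rid, cosine_similarity(query, emb)) for rid, emb in corpus.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


# Made-up two-dimensional embeddings keyed by made-up record IDs.
toy_corpus = {"A": [1.0, 0.0], "B": [0.7, 0.7], "C": [0.0, 1.0]}
ranked = rank_parallels([1.0, 0.1], toy_corpus, top_k=2)
print([record_id for record_id, _ in ranked])  # ['A', 'B']
```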

In [None]:
# @title Parallels
# @markdown **Aeneas’s list of the most relevant contextual parallels.**
# @markdown
# @markdown To include as much historical context as possible, we extend the list of retrieved parallels to incorporate the validation and test sets. These sets are not used for training and do not affect the model's predictions.

contextualization_results = inference.contextualize(
    text=text,
    dataset=dataset,
    retrieval=retrieval,
    forward=forward,
    params=params,
    alphabet=alphabet,
    region_map=region_map
)

contextualization_template = jinja2.Template("""
<!DOCTYPE html>
<html>
<head>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap" rel="stylesheet">
<style>
body {
  font-family: 'Roboto', sans-serif;
  margin: 0;
  padding: 24px;
}
.results-grid {
  display: grid;
  grid-template-columns: repeat(auto-fill, minmax(320px, 1fr));
  gap: 24px;
}
.card {
  background-color: white;
  border: 1px solid #dee2e6;
  border-radius: 8px;
  padding: 24px;
  display: flex;
  flex-direction: column;
  position: relative;
  box-shadow: 0 1px 3px rgba(0,0,0,0.05);
  transition: box-shadow 0.2s ease-in-out;
}
.card:hover {
    box-shadow: 0 4px 12px rgba(0,0,0,0.1);
}
.card-text {
  font-size: 16px;
  line-height: 1.5;
  color: #343a40;
  margin: 0 0 20px 0;
  height: 72px; /* Fixed height for ~3 lines */
  overflow: hidden;
}
.card.is-expanded .card-text {
  height: auto;
}
.plus-button {
  position: absolute;
  top: 16px;
  right: 16px;
  width: 32px;
  height: 32px;
  border-radius: 50%;
  border: 1px solid #ced4da;
  background-color: white;
  color: #495057;
  font-size: 24px;
  font-weight: 300;
  display: flex;
  align-items: center;
  justify-content: center;
  cursor: pointer;
  line-height: 1;
  user-select: none;
}
.metadata {
  font-size: 14px;
  color: #495057;
  border-top: 1px solid #e9ecef;
  padding-top: 16px;
  margin-top: auto;
}
.metadata div {
  margin-bottom: 6px;
  display: flex;
}
.metadata strong {
  font-weight: 700;
  color: #212529;
  min-width: 120px;
  flex-shrink: 0;
}
.metadata a {
    color: #007bff;
    text-decoration: none;
}
.metadata a:hover {
    text-decoration: underline;
}
.score-bar-container {
  width: 100%;
  height: 8px;
  background-color: #e9ecef;
  border-radius: 4px;
  margin-top: 12px;
  overflow: hidden;
}
.score-bar {
  height: 100%;
  background-color: #ffc107;
  border-radius: 4px;
}
</style>
</head>
<body>

<div class="results-grid">
  {% for i in range(results.ids|length) %}
    <div class="card" id="card-{{ loop.index }}">
      <div>
        <div class="plus-button" onclick="toggleText(this)">+</div>
        <p class="card-text">
          {{- results.text[i] -}}
        </p>
      </div>
      <div class="metadata">
        <div><strong>Record Number:</strong> {{ results.ids[i] }}</div>

        {# Display alternative IDs if they exist #}
        {% set alt_ids = results.ids_alt[i] %}
        {% if alt_ids %}
          {% for key, value in alt_ids.items() %}
            <div><strong>{{ key.replace('_', ' ')|title }}:</strong> {{ value }}</div>
          {% endfor %}
        {% endif %}

        {# Use the region_map to get province name #}
        {% set loc_id = results.location_ids[i] %}
        {% if loc_id is not none %}
            {% if region_map and region_map.names %}
                <div><strong>Province:</strong> {{ region_map.names[loc_id] }}</div>
            {% else %}
                <div><strong>Province ID:</strong> {{ loc_id }}</div>
            {% endif %}
        {% endif %}

        <div>
            <strong>Date:</strong>
            {# Test the raw values: bce_ce() returns '' for None, which would never match "is none" #}
            {% set d_min = results.date_min[i] %}
            {% set d_max = results.date_max[i] %}
            {% if d_min is not none and d_max is not none %}
              {{ bce_ce(d_min) }} – {{ bce_ce(d_max) }}
            {% elif d_min is not none %}
              from {{ bce_ce(d_min) }}
            {% elif d_max is not none %}
              until {{ bce_ce(d_max) }}
            {% else %}
              Not Available
            {% endif %}
        </div>

        {# Display partner link if it exists #}
        {% set link = results.partner_link[i] %}
        {% if link %}
          <div><strong>Source:</strong> <a href="{{ link }}" target="_blank" rel="noopener noreferrer">View on partner site</a></div>
        {% endif %}

        <div class="score-bar-container">
          <div class="score-bar" style="width: {{ results.score[i] * 100 }}%;"></div>
        </div>
      </div>
    </div>
  {% endfor %}
</div>

<script>
  function toggleText(button) {
    const card = button.closest('.card');
    if (card) {
      const isExpanded = card.classList.toggle('is-expanded');
      button.textContent = isExpanded ? '−' : '+';
    }
  }
</script>

</body>
</html>
""")

display(
    HTML(
        contextualization_template.render(
            results=contextualization_results,
            region_map=region_map,
            bce_ce=bce_ce
        )
    )
)

# Restoration
[Back to top](#scrollTo=oTrYS2DhDmcX)
[Contextualisation](#scrollTo=ZCV76QbofBxU)
[Attribution](#scrollTo=M3h8nL22HvTh)



In [None]:
# @title Restoration hypotheses
# @markdown Aeneas’s list of top restoration hypotheses weighted by length and probability.
# @markdown
# @markdown This visualisation lets historians pair Aeneas's suggestions with their own
# @markdown contextual knowledge. Aeneas uses non-sequential beam search: each beam starts
# @markdown with the highest-confidence prediction, then at each timestep restores
# @markdown the characters with the highest certainty (probability according to the model).
# @markdown
# @markdown **This step may take several minutes.**

restoration_results = inference.restore(
    text=text,
    forward=forward,
    params=params,
    alphabet=alphabet,
    vocab_char_size=vocab_char_size,
    beam_width=arg_beam_width,
    temperature=arg_restoration_temperature,
    unk_restoration_max_len=arg_max_restoration_len,
)

template = jinja2.Template("""<!DOCTYPE html>
<html>
<head>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@400&family=Roboto:wght@400&display=swap" rel="stylesheet">
<style>
body {
  font-family: 'Roboto Mono', monospace;
  font-weight: 400;
}
.container {
  overflow-x: scroll;
  scroll-behavior: smooth;
}
table {
  table-layout: fixed;
  font-size: 16px;
  padding: 0;
  white-space: nowrap;
}
table tr:first-child {
  font-weight: bold;
}
table td {
  border-bottom: 1px solid #ccc;
  padding: 3px 0;
}
table td.header {
  font-family: Roboto, Helvetica, sans-serif;
  text-align: right;
  position: -webkit-sticky;
  position: sticky;
  background-color: white;
}
.header-1 {
  background-color: white;
  width: 120px;
  min-width: 120px;
  max-width: 120px;
  left: 0;
}
.header-2 {
  left: 120px;
  width: 50px;
  max-width: 50px;
  min-width: 50px;
  padding-right: 5px;
}
table td:not(.header) {
  border-left: 1px solid black;
  padding-left: 5px;
}
.header-2col {
  width: 170px;
  min-width: 170px;
  max-width: 170px;
  left: 0;
  padding-right: 5px;
}
.pred {
  background: #ddd;
}
</style>
</head>
<body>
Scroll sideways to see all the text if it is wider than the screen.
<button id="btn">jump to restoration area</button>
<div class="container">
<table cellspacing="0">
  <tr>
    <td colspan="2" class="header header-2col">Input text:</td>
    <td>
    {% for char in restoration_results.input_text -%}
      {%- if loop.index0 in restoration_results.missing -%}
        <span class="pred">{{char}}</span>
      {%- else -%}
        {{char}}
      {%- endif -%}
    {%- endfor %}
    </td>
  </tr>
  <!-- Predictions: -->
  {% for pred in restoration_results.predictions %}
  <tr>
    <td class="header header-1">Hypothesis {{ loop.index }}:</td>
    <td class="header header-2">{{ "%.1f%%"|format(100 * pred.score) }}</td>
    <td>
      {% for char in pred.text -%}
        {%- if loop.index0 in pred.restored -%}
          <span class="pred">{{char}}</span>
        {%- else -%}
          {{char}}
        {%- endif -%}
      {%- endfor %}
    </td>
  </tr>
  {% endfor %}
</table>
</div>
<script>
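document.querySelector('#btn').addEventListener('click', () => {
  const pred = document.querySelector(".pred");
  // scrollIntoViewIfNeeded() is non-standard; use the standard API and skip if nothing matched.
  if (pred) {
    pred.scrollIntoView({block: 'nearest', inline: 'center'});
  }
});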
</script>
</body>
</html>
""")
display(
    HTML(
        template.render(
            restoration_results=restoration_results,
        )
    )
)
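
The non-sequential beam search described above can be sketched on a toy example. This is a minimal illustration, not Aeneas's implementation: the `?` placeholder, the `Scorer` type, `toy_scorer` and its hand-written probabilities are all assumptions made for this sketch, and the length weighting mentioned above is left out. At each step every beam asks the scorer for a distribution over each still-missing position, fills the position whose best character is most certain, and keeps the `beam_width` most probable hypotheses.

```python
import math
from typing import Callable, Dict, List, Tuple

MISSING = "?"  # assumed placeholder for a character to restore

# Hypothetical scorer: for each still-missing position, a distribution
# over candidate characters.
Scorer = Callable[[str], Dict[int, Dict[str, float]]]


def non_sequential_beam_search(
    text: str, scorer: Scorer, beam_width: int = 3
) -> List[Tuple[str, float]]:
  """Fills missing characters most-confident-first, keeping the best beams."""
  beams = [(0.0, text)]  # (log-probability, hypothesis)
  for _ in range(text.count(MISSING)):
    candidates: Dict[str, float] = {}
    for logp, hyp in beams:
      dists = scorer(hyp)
      if not dists:  # nothing left to restore in this hypothesis
        candidates[hyp] = max(logp, candidates.get(hyp, -math.inf))
        continue
      # Restore the position whose best character is most certain.
      pos = max(dists, key=lambda p: max(dists[p].values()))
      for char, prob in dists[pos].items():
        if prob <= 0.0:
          continue
        new_hyp = hyp[:pos] + char + hyp[pos + 1:]
        new_logp = logp + math.log(prob)
        candidates[new_hyp] = max(new_logp, candidates.get(new_hyp, -math.inf))
    # Keep only the most probable hypotheses for the next step.
    beams = sorted(((lp, h) for h, lp in candidates.items()), reverse=True)
    beams = beams[:beam_width]
  return [(hyp, math.exp(logp)) for logp, hyp in beams]


def toy_scorer(hyp: str) -> Dict[int, Dict[str, float]]:
  # Hand-written probabilities, purely for illustration.
  table = {2: {"p": 0.9, "b": 0.1}, 7: {"o": 0.6, "u": 0.4}}
  return {pos: dist for pos, dist in table.items() if hyp[pos] == MISSING}


hypotheses = non_sequential_beam_search("im?erat?r", toy_scorer)
for hyp, prob in hypotheses:
  print(f"{prob:.2f}  {hyp}")
```

In this toy run position 2 is restored first because its best character (0.9) is more certain than that of position 7 (0.6), so the top hypothesis is "imperator" with probability 0.9 × 0.6 = 0.54.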

In [None]:
# @title Restoration saliency map
# @markdown Saliency maps for each character predicted in Aeneas’s top restoration hypothesis. Purple shading highlights the unique input text features that contributed most to each prediction.

template = jinja2.Template("""<!DOCTYPE html>
<html>
<head>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@400&family=Roboto:wght@400&display=swap" rel="stylesheet">
<style>
body {
  font-family: 'Roboto Mono', monospace;
  font-weight: 400;
}
.container {
  overflow-x: scroll;
  scroll-behavior: smooth;
}
table {
  table-layout: fixed;
  font-size: 16px;
  padding: 0;
  white-space: nowrap;
}
table tr:first-child {
  font-weight: bold;
}
table td {
  border-bottom: 1px solid #ccc;
  padding: 3px 0;
}
table td.header {
  font-family: Roboto, Helvetica, sans-serif;
  position: -webkit-sticky;
  position: sticky;
  background-color: white;
}
.header-1 {
  background-color: white;
  width: 80px;
  min-width: 80px;
  max-width: 80px;
  left: 0;
  text-align: right;
}
.header-2 {
  left: 80px;
  width: 40px;
  max-width: 40px;
  min-width: 40px;
  text-align: center;
  padding-right: 5px;
}
table td:not(.header) {
  border-left: 1px solid black;
  padding-left: 5px;
}
.header-2col {
  width: 120px;
  min-width: 120px;
  max-width: 120px;
  left: 0;
  padding-right: 5px;
}
.pred {
  background: #eee;
}
tr:hover .header-2 {
  background: #bdb;
  font-weight: bold;
}
tr:hover span.restored-pos {
  background: #bdb;
  font-weight: bold;
}
</style>
</head>
<body>
Scroll sideways to see all the text if it is wider than the screen.
<button id="btn">jump to restoration area</button>
<div class="container">
<table cellspacing="0">
  <tr>
    <td colspan="2" class="header header-2col">Input text:</td>
    <td>
    {% for char in restoration_results.input_text -%}
      {%- if loop.index0 in restoration_results.missing -%}
        <span class="pred">{{char}}</span>
      {%- else -%}
        {{char}}
      {%- endif -%}
    {%- endfor %}
    </td>
  </tr>
  {% for sal in restoration_results.prediction_saliency %}
  <tr>
    <td class="header header-1">Step {{ loop.index }}:</td>
    <td class="header header-2">{{ sal.text[sal.restored_idx].replace(" ", "⎵") }}</td>
    <td>
      {% for char in sal.text -%}
        {%- if loop.index0 == sal.restored_idx -%}
          <span class="pred restored-pos">{{char}}</span>
        {%- else -%}
          {%- set alpha = sal.saliency[loop.index0] -%}
          <span style="background: rgba(171,71,188,{{"%.2f"|format(alpha)}})"
          >{{char}}</span>
        {%- endif -%}
      {%- endfor %}
    </td>
  </tr>
  {% endfor %}
</table>
</div>
<script>
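document.querySelector('#btn').addEventListener('click', () => {
  const pred = document.querySelector(".pred");
  // scrollIntoViewIfNeeded() is non-standard; use the standard API and skip if nothing matched.
  if (pred) {
    pred.scrollIntoView({block: 'nearest', inline: 'center'});
  }
});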
</script>
</body>
</html>
""")
display(
    HTML(
        template.render(
            restoration_results=restoration_results,
        )
    )
)

# Attribution
[Back to top](#scrollTo=oTrYS2DhDmcX)
[Contextualisation](#scrollTo=ZCV76QbofBxU)
[Restoration](#scrollTo=p9EmORGaxn8u)



In [None]:
# @title Geographical attribution
# @markdown Bar chart and map distribution for Aeneas's top 10 geographical attribution hypotheses, ranked by probability among 62 regions of the ancient world. The circle size on the map is directly proportional to the prediction’s probability.

# Process image
if upload_image:
  # Check if a file has been uploaded
  if environment == 'ipython' and uploader.value:
    uploaded_file = uploader.value[0]
    file_name = uploaded_file['name']
    image_bytes = uploaded_file['content']

  # Load the image from the in-memory bytes
  try:
    vision_img = Image.open(io.BytesIO(image_bytes))
  except Exception as e:
    print(f'Error: The uploaded file: {file_name} does not appear to be a valid image. Please try again.')
    print(f'PIL Error: {e}')

# Get attribution results
attribution_results = inference.attribute(
    text=text,
    forward=forward,
    params=params,
    alphabet=alphabet,
    vocab_char_size=vocab_char_size,
    vision_img=vision_img
)

locations = []
scores = []
for l in attribution_results.locations[:10]:
  locations.append(region_map['names'][l.location_id])
  scores.append(l.score)

# Generate figure
fig, ax = plt.subplots(figsize=(5, 5), dpi=100)
y_pos = range(len(locations))
ax.barh(y_pos, scores, color='#da2b2f')

# x-axis
xticks = np.arange(0, 1.1, 0.25)
xticks_str = list(map(lambda x: f'{int(x*100)}%', xticks))
ax.set_xticks(xticks)
ax.set_xticklabels(xticks_str, fontsize=12)

# y-axis
ax.set_yticks(y_pos)
ax.set_yticklabels(locations, fontsize=12)
ax.invert_yaxis()
ax.set_xlabel('Probability', fontsize=14)
ax.set_title('Geographical attribution')

plt.show()

In [None]:
# @title Geographical attribution map
# Compute map center
center_location = np.zeros(2)
for l in attribution_results.locations:
  row = locations_ll[l.location_id]
  center_location += np.array([row['lat'], row['lng']]) * l.score

# Create map
folium_map = folium.Map(
    location=center_location, tiles='OpenStreetMap', zoom_start=6
)

# Create markers
for l in attribution_results.locations:
  if l.score < 0.01:
    continue
  row = locations_ll[l.location_id]
  folium.Circle(
      location=[row['lat'], row['lng']],
      radius=l.score * 50000,
      fill=True,
      popup=(
          f'{l.score * 100:.1f}% -'
          f' {row["name"]}'
      ),
      fill_color='darkred',
      color='darkred',
  ).add_to(folium_map)
folium_map

<small>Map tiles by <a href="http://stamen.com">Stamen Design</a> (unmodified), under <a href="http://creativecommons.org/licenses/by/3.0">CC BY 3.0</a>. Data by <a href="http://openstreetmap.org">OpenStreetMap</a>, under <a href="http://creativecommons.org/licenses/by-sa/3.0">CC BY SA</a>.</small>
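
The map centre computed above is a probability-weighted mean of the region coordinates. A small worked example, assuming three hypothetical regions with made-up coordinates and scores (none of these values come from Aeneas's region table), shows the arithmetic. Dividing by the sum of the scores is an extra step not present in the cell above; it keeps the centre meaningful if the scores do not sum to one, for instance after truncating to the top results.

```python
import numpy as np

# Hypothetical (lat, lng) pairs and attribution probabilities for three regions.
coords = np.array([
    [41.90, 12.50],
    [37.98, 23.73],
    [30.00, 31.20],
])
scores = np.array([0.7, 0.2, 0.1])

# Weight each coordinate pair by its probability, then normalise by the
# total weight so the result stays inside the convex hull of the points.
center = (coords * scores[:, None]).sum(axis=0) / scores.sum()

print(center)
```

Because the highest-scoring region carries 70% of the weight here, the centre lands much closer to the first coordinate pair than to the other two.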

In [None]:
# @title Geographical attribution saliency map
# @markdown Saliency map showing which unique input text features contributed the most to Aeneas's top geographical
# @markdown attribution hypothesis, where deeper purple indicates a greater contribution.
display(
    HTML(
        generate_saliency(
            text=attribution_results.input_text,
            saliency=attribution_results.location_saliency,
            snippet=False,
        )
    )
)

In [None]:
# @title Chronological attribution
# @markdown Aeneas’s chronological attribution hypotheses, visualised as a categorical distribution over decades (in yellow) between 800 BCE and 800 CE. This visualisation handles date intervals more effectively and makes the hypotheses easier to interpret.

# Compute scores
date_pred_y = np.array(attribution_results.year_scores)
date_pred_x = np.arange(
    dataset_config.date_min + dataset_config.date_interval / 2,
    dataset_config.date_max + dataset_config.date_interval / 2,
    dataset_config.date_interval,
)
date_pred_argmax = (
    date_pred_y.argmax() * dataset_config.date_interval
    + dataset_config.date_min
    + dataset_config.date_interval // 2
)
date_pred_avg = np.dot(date_pred_y, date_pred_x)

# Plot figure
fig = plt.figure(figsize=(10, 5), dpi=100)

plt.bar(
    date_pred_x,
    date_pred_y,
    color='#fdcc75',
    width=dataset_config.date_interval,  # one bar per date bin
    label='Aeneas distribution',
)
plt.axvline(
    x=date_pred_avg, color='#da2b2f', linewidth=2.0, label='Aeneas average'
)


plt.ylabel('Probability', fontsize=14)
yticks = np.arange(0, 1.1, 0.1)
yticks_str = list(map(lambda x: f'{int(x*100)}%', yticks))
plt.yticks(yticks, yticks_str, fontsize=12, rotation=0)
plt.ylim(0, int((date_pred_y.max() + 0.1) * 10) / 10)

plt.xlabel('Date', fontsize=14)
xticks = list(range(dataset_config.date_min, dataset_config.date_max + 1, 25))
xticks_str = list(map(bce_ce, xticks))
plt.xticks(xticks, xticks_str, fontsize=12, rotation=0)
plt.xlim(int(date_pred_avg - 100), int(date_pred_avg + 100))
plt.legend(loc='upper right', fontsize=12)

plt.show()
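
The date arithmetic in the cell above can be checked on a toy distribution. This sketch assumes `date_min = -800`, `date_max = 800` and `date_interval = 10`, which match the "decades between 800 BCE and 800 CE" description but are not read from `dataset_config`; the probabilities are invented. It shows how the bin centres, the most likely decade and the distribution average relate to each other.

```python
import numpy as np

# Assumed configuration values (not read from dataset_config).
date_min, date_max, date_interval = -800, 800, 10

# Centre of each decade bin: -795, -785, ..., 795.
centres = np.arange(
    date_min + date_interval / 2,
    date_max + date_interval / 2,
    date_interval,
)

# Toy distribution: all probability mass on three adjacent decades.
probs = np.zeros(len(centres))
probs[[80, 81, 82]] = [0.2, 0.5, 0.3]

# Centre of the most probable bin (the mode of the distribution).
mode_year = probs.argmax() * date_interval + date_min + date_interval // 2

# Expected year under the distribution (the red "average" line in the plot).
mean_year = float(np.dot(probs, centres))

print(len(centres), mode_year, mean_year)
```

The mode and the average can differ when the distribution is skewed: here the most probable decade is centred on year 15, while the average is pulled slightly later, to year 16, by the mass on the following decade.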

In [None]:
# @title Chronological attribution saliency map
# @markdown Saliency map showing which unique input text features contributed the most to Aeneas’s top chronological attribution hypothesis.
display(
    HTML(
        generate_saliency(
            text=attribution_results.input_text,
            saliency=attribution_results.date_saliency,
        )
    )
)