In [1]:
# Install the amlta project package only when it cannot be imported from the
# current environment.
try:
    import amlta  # noqa: F401
except ImportError:
    # -q: quiet output; --no-dependencies: install the package itself only;
    # -U: upgrade flag. %pip (rather than !pip) targets the running kernel.
    # NOTE(review): the install is not pinned to a tag or commit — confirm that
    # tracking the repository's default branch is intended.
    %pip install -q --no-dependencies -U git+https://github.com/woranov/amlta-project.git

In [2]:
from pathlib import Path

try:
    from google.colab import drive  # pyright: ignore[reportMissingImports]

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

from amlta.config import config

In [3]:
# Colab-only setup: point the project config at a folder inside Google Drive
# and mount Drive if its mount point is not there yet.
if IN_COLAB:
    mount_point = Path("/content/drive")
    drive_path = mount_point.joinpath("MyDrive")

    # Edit the path components below to match your own Drive layout.
    data_dir = drive_path.joinpath("uni", "ws2425", "amlta", "project", "data")

    config.update(data_dir=data_dir)

    mount_point_present = mount_point.exists()
    if not mount_point_present:
        drive.mount(str(mount_point))

In [None]:
import pandas as pd
from sklearn.metrics import classification_report
from tqdm.contrib.concurrent import thread_map

from amlta.probas.flows import extract_process_flows
from amlta.probas.processes import ProcessData
from amlta.question_generation.process import (
    QuestionData,
    load_batches,
)
from amlta.question_generation.query_params import get_flows_for_query
from amlta.tapas.model import load_tapas_model, load_tapas_tokenizer
from amlta.tapas.retrieve import retrieve_rows

In [None]:
# Read the table stored under <data_dir>/tapas-ft/data. Its "batch",
# "question_id" and "process_uuid" columns are used further down to find where
# the held-out part starts.
training_df = pd.read_parquet(
    config.data_dir / "tapas-ft" / "data" / "tapas_train_batched_dfs_shuffled.parquet"
)

In [None]:
# Keep the trailing 20% of rows (presumably the part the model was not
# fine-tuned on — confirm against the training script) and remember the
# identifiers of its first row so the same cut can be found in the question list.
split_at = int(0.8 * len(training_df))
not_trained_on_data = training_df.iloc[split_at:]
start_batch, start_process_uuid = (
    not_trained_on_data[column].values[0] for column in ("batch", "process_uuid")
)
start_question_id = int(not_trained_on_data["question_id"].values[0])

In [5]:
# Load the TAPAS tokenizer and model through the project's loader helpers.
# Both are read as module-level globals by get_tapas_labels.
# NOTE(review): the loaders are called without arguments, so the checkpoint is
# chosen inside amlta — confirm which one is being evaluated.
tokenizer = load_tapas_tokenizer()
model = load_tapas_model()

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'TapasTokenizer'. 
The class this function is called from is 'CustomTapasTokenizer'.


In [6]:
# Load the generated questions, print the last record as a sample of the
# record layout, and display how many records were loaded.
question_data = load_batches()
print(question_data[-1])

len(question_data)

{'batch': 'batch_gpt-4o-mini_800_1000', 'process_uuid': 'ecf185cb-c44b-4450-abc4-a92c4bf0b9b9', 'question_id': 2, 'basic_query': 'metal production', 'general_query': 'non-ferrous metal production', 'specific_query': 'aluminum production in Germany 2015', 'flow_query_params': {'query_type': 'names', 'direction': 'output', 'aggregation': 'count', 'flow_names': ['HFC-245ca', 'perfluoropropane', 'CFC-11', 'nickel']}, 'question': 'What are the output amounts of HFC-245ca, perfluoropropane, CFC-11, and nickel from <the process>?', 'question_replaced_basic': 'What are the output amounts of HFC-245ca, perfluoropropane, CFC-11, and nickel from metal production?', 'question_replaced_general': 'What are the output amounts of HFC-245ca, perfluoropropane, CFC-11, and nickel from non-ferrous metal production?', 'question_replaced_specific': 'What are the output amounts of HFC-245ca, perfluoropropane, CFC-11, and nickel from aluminum production in Germany 2015?'}


3000

In [None]:
# Find the position of the first held-out question in question_data by matching
# the (batch, question_id, process_uuid) triple taken from the training table,
# then keep everything from that position onwards as the evaluation set.
# NOTE(review): this assumes question_data is ordered consistently with the
# training table — confirm.
start_valid_idx = next(
    (
        i
        for i, q in enumerate(question_data)
        if q["batch"] == start_batch
        and q["question_id"] == start_question_id
        and q["process_uuid"] == start_process_uuid
    ),
    None,
)
if start_valid_idx is None:
    # Fail with a descriptive error instead of a bare StopIteration.
    raise ValueError(
        "Could not find the first held-out question in question_data: "
        f"batch={start_batch!r}, question_id={start_question_id!r}, "
        f"process_uuid={start_process_uuid!r}"
    )
valid_data = question_data[start_valid_idx:]
len(valid_data)

722

In [8]:
# Probability cutoff: get_tapas_labels marks a row as selected when its
# probability is >= this value. It is bound as that function's default argument
# when the function is defined, so re-run the definition cell after changing it.
threshold = 0.5

In [None]:
import functools


@functools.lru_cache(maxsize=512)
def get_process(uuid: str):
    """Return ``ProcessData.from_uuid(uuid)``, memoised per UUID.

    Up to 512 distinct UUIDs are kept in the LRU cache. Repeated calls with the
    same UUID return the same cached object, so callers should not mutate it.
    """
    return ProcessData.from_uuid(uuid)


@functools.lru_cache(maxsize=512)
def get_flows(uuid: str):
    """Return the flows extracted from the process with the given UUID, memoised.

    The process is obtained through ``get_process`` and passed to
    ``extract_process_flows``. Up to 512 results are cached and the cached
    object is shared between callers, so it should be treated as read-only.
    """
    return extract_process_flows(get_process(uuid))

In [21]:
def get_true_labels(question: QuestionData):
    """Build the gold row labels and gold aggregation name for one question.

    Args:
        question: Question record; ``"process_uuid"`` selects the flow table and
            ``"flow_query_params"`` describes which flows the question asks for.

    Returns:
        ``(labels, aggregation)``: ``labels`` is a list with one bool per row of
        the flow table, ``True`` where the row's index is among the rows
        selected by the query parameters. ``aggregation`` is the
        ``"aggregation"`` parameter with ``"list"`` replaced by ``"NONE"`` and
        then upper-cased.
    """
    flows = get_flows(question["process_uuid"])
    params = question["flow_query_params"]

    # Rows of the flow table that the query parameters select.
    matching = get_flows_for_query(flows, params)
    row_mask = flows.index.isin(matching.index)
    labels = row_mask.astype(bool).tolist()

    agg_name = params["aggregation"].replace("list", "NONE").upper()

    return labels, agg_name

In [48]:
def get_tapas_labels(question: QuestionData, threshold=threshold):
    """Predict per-row selection labels and the aggregation for one question.

    Args:
        question: Question record; ``"process_uuid"`` selects the flow table and
            ``"question"`` (with ``<`` and ``>`` characters stripped) is used as
            the query text.
        threshold: Probability cutoff; a row is labelled ``True`` when its
            probability is ``>= threshold``. The default is taken from the
            module-level ``threshold`` at function-definition time.

    Returns:
        ``(labels, aggregation, probs)``. ``labels`` and ``probs`` are lists in
        ascending row order and aligned element by element; ``aggregation`` is
        passed through unchanged from ``retrieve_rows``.
    """
    df = get_flows(question["process_uuid"])

    query = question["question"]
    query = query.replace("<", "").replace(">", "")

    # threshold=0 is passed so the cutoff is applied here rather than inside
    # retrieve_rows (presumably it then returns every scored row — confirm).
    rows, aggregation, probs = retrieve_rows(
        df,
        query=query,
        model=model,
        tokenizer=tokenizer,
        threshold=0,
        return_probabilities=True,
    )

    # Order the probabilities by row so they line up with the table order.
    by_row = sorted(zip(rows, probs), key=lambda pair: pair[0])
    probs_sorted = [prob for _, prob in by_row]

    labels = [bool(prob >= threshold) for prob in probs_sorted]

    # Return the probabilities in the same order as the labels.
    return labels, aggregation, probs_sorted

In [50]:
# Run TAPAS on the evaluation questions with a thread pool and collect, per
# question, the row labels, row probabilities and predicted aggregation.
# NOTE(review): only the first 3 held-out questions are evaluated here — this
# looks like a debugging slice; confirm before reporting results.
y_preds_labels, y_preds_probs, y_preds_aggregation = [], [], []

tapas_results = thread_map(get_tapas_labels, valid_data[:3], max_workers=8)
for labels, aggregation, probs in tapas_results:
    y_preds_labels.append(labels)
    y_preds_probs.append(probs)
    y_preds_aggregation.append(aggregation)

  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Build the gold row labels and gold aggregation for the same slice of
# evaluation questions that the predictions were computed on.
y_trues_labels, y_trues_aggregation = [], []

gold_results = thread_map(get_true_labels, valid_data[:3], max_workers=8)
for labels, aggregation in gold_results:
    y_trues_labels.append(labels)
    y_trues_aggregation.append(aggregation)

  0%|          | 0/3 [00:00<?, ?it/s]

In [51]:
# Row-level report: flatten the per-question label lists into one sequence of
# gold labels and one of predicted labels, then compare them.
flat_true_labels = [
    label for per_question in y_trues_labels for label in per_question
]
flat_pred_labels = [
    label for per_question in y_preds_labels for label in per_question
]
print(classification_report(flat_true_labels, flat_pred_labels))

              precision    recall  f1-score   support

       False       1.00      1.00      1.00       179
        True       1.00      1.00      1.00        64

    accuracy                           1.00       243
   macro avg       1.00      1.00      1.00       243
weighted avg       1.00      1.00      1.00       243



In [44]:
# Gold aggregation name per evaluated question, as produced by get_true_labels.
y_trues_aggregation

['NONE', 'NONE', 'COUNT']

In [45]:
# Predicted aggregation per evaluated question, as returned by get_tapas_labels.
y_preds_aggregation

['SUM', 'NONE', 'COUNT']

In [52]:
# Question-level report on the aggregation operator. A class that is predicted
# but never appears in the gold labels (or the reverse) has an undefined
# metric; zero_division=0 reports it as 0.0, the same value the default
# produces, without emitting a warning for each such metric.
print(
    classification_report(
        y_trues_aggregation,
        y_preds_aggregation,
        zero_division=0,
    )
)

              precision    recall  f1-score   support

       COUNT       1.00      1.00      1.00         1
        NONE       1.00      0.50      0.67         2
         SUM       0.00      0.00      0.00         0

    accuracy                           0.67         3
   macro avg       0.67      0.50      0.56         3
weighted avg       1.00      0.67      0.78         3



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
