<a href="https://colab.research.google.com/github/Odeuropa/wp4-conditions/blob/main/odeuropa_extract.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Environment setup (IPython shell escapes, not Python statements).
# Installs the `transformers` package, from which `pipeline` is imported further down.
# NOTE(review): the version is not pinned, so reruns may pick up a different release.
!pip install transformers
# Print the NVIDIA driver/GPU status so the reader can confirm a GPU runtime is attached.
!nvidia-smi

In [None]:
# Mount Google Drive at /content/drive; the dataset and output directories
# configured further down both live under /content/drive/MyDrive.
# NOTE(review): `google.colab` is presumably only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import csv
import fnmatch
import json
import os

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import pipeline

In [None]:
def extract_row(QA, row):
  """Annotate `row` in place with the cause and effect of its smell word.

  Two questions are put to `QA` against the sentence stored in
  `row["Full_Sentence"]`: first the cause, then the effect of
  `row["q_smell_word"]`.

  Args:
    QA: callable invoked as
      `QA(question=..., context=..., handle_impossible_answer=True)`; its
      result is indexed with the keys "answer", "start", "end" and "score".
    row: mutable mapping holding "Full_Sentence" and "q_smell_word".

  Returns:
    The same `row` object, extended with `cause`, `cause_start`, `cause_end`,
    `cause_score` and the matching `effect*` keys. A falsy answer (e.g. an
    empty string) is stored as the literal string "null".
  """
  sentence = row["Full_Sentence"]
  smell_term = row["q_smell_word"]
  # Order matters: the cause question is asked (and its keys inserted) first.
  prompts = (
      ("cause", f"What is the cause of the {smell_term}?"),
      ("effect", f"What is the effect of the {smell_term}?"),
  )
  for label, prompt in prompts:
    prediction = QA(question=prompt, context=sentence, handle_impossible_answer=True)
    row[label] = prediction["answer"] or "null"
    for field in ("start", "end", "score"):
      row[f"{label}_{field}"] = prediction[field]
  return row

In [None]:
# Identifier passed to `pipeline` as both the model and the tokenizer
# (presumably a Hugging Face Hub checkpoint id — confirm it is still published).
MODEL_NAME = "mbartolo/roberta-large-synqa-ext"
# Use the first CUDA GPU when one is attached; otherwise fall back to CPU (-1)
# instead of failing at pipeline construction on a GPU-less runtime.
DEVICE = 0 if torch.cuda.is_available() else -1
# Question-answering pipeline shared by every `extract_row` call below.
QA = pipeline("question-answering", model=MODEL_NAME, tokenizer=MODEL_NAME, device=DEVICE)

In [None]:
# Input/output locations on the mounted Google Drive.
dataset_dir = "/content/drive/MyDrive/datasets/cause-effect"
output_dir = "/content/drive/MyDrive/colab_output/odeuropa"
# Smell terms recognised inside the Smell_Word column ("smel*" is handled separately).
accept = {"scent", "odour", "odor", "stench", "stink", "stunk", "perfume", "aroma", "reek", "fragrance", "whiff"}


def find_tsv_files(root_dir):
  """Return the path of every `*.tsv` file found recursively under `root_dir`."""
  found = []
  for root, _dirnames, filenames in os.walk(root_dir):
    for filename in fnmatch.filter(filenames, '*.tsv'):
      found.append(os.path.join(root, filename))
  return found


def output_path_for(fpath, out_dir):
  """Map an input TSV path to `<out_dir>/<stem>.jsonl`.

  Uses `os.path.splitext` rather than `str.rstrip(".tsv")`: `rstrip` removes any
  trailing run of the characters '.', 't', 's', 'v', so "results.tsv" used to
  become "resul". NOTE: outputs written under such a mangled name by earlier
  runs will not be picked up for resuming.
  """
  stem = os.path.splitext(os.path.basename(fpath))[0]
  return os.path.join(out_dir, stem + ".jsonl")


def pick_smell_word(raw, accepted):
  """Reduce a raw Smell_Word cell to the term used in the QA questions.

  Anything containing "smel" maps to "smell"; otherwise the first accepted term
  (in sorted order, so the choice is deterministic) found as a substring wins,
  with "stunk" rewritten to "stink". Falls back to "smell" when nothing matches
  or the cell is empty/None.
  """
  lowered = (raw or "").lower()
  if "smel" in lowered:
    return "smell"
  for word in sorted(accepted):
    if word in lowered:
      return "stink" if word == "stunk" else word
  return "smell"


def load_seen_ids(output_file):
  """Return the set of `rid` values already present in a JSONL output file.

  A missing file yields an empty set; blank lines are ignored.
  """
  if not os.path.exists(output_file):
    return set()
  with open(output_file) as fin:
    return {json.loads(line)["rid"] for line in fin if line.strip()}


def count_data_rows(fpath):
  """Count the physical lines of `fpath` minus the header line (never negative).

  Only used to size the progress bar; assumes one record per line.
  """
  with open(fpath) as fin:
    return max(sum(1 for _ in fin) - 1, 0)


matches = find_tsv_files(dataset_dir)
# NOTE(review): "meta" is matched against the whole path, directories included —
# confirm that this (rather than the file name only) is intended.
fpaths = sorted(x for x in matches if "meta" not in x)
data = []  # not used in this cell; kept in case another cell reads it


for fpath in fpaths:
  output_file = output_path_for(fpath, output_dir)
  print(f"{fpath} -> {output_file}")
  os.makedirs(output_dir, exist_ok=True)

  # Resume support: rows whose rid is already in the output file are skipped.
  seen = set()
  if os.path.exists(output_file):
    seen = load_seen_ids(output_file)
    print(f"Seen: {len(seen)}")
  total = max(count_data_rows(fpath) - len(seen), 0)

  skipped = 0
  first_error = None
  pbar = tqdm(total=total, position=0, leave=True)
  try:
    # newline="" is what the csv module asks for when it is handed a file object.
    with open(fpath, newline="") as csvfile:
      csvreader = csv.DictReader(csvfile, dialect=csv.excel_tab)
      # rid is the 1-based position of the data row within the file.
      for rid, row in enumerate(csvreader, start=1):
        if rid in seen:
          continue

        smell_word = pick_smell_word(row["Smell_Word"], accept)

        try:
          row["rid"] = rid
          row["q_smell_word"] = smell_word
          # DictReader stores fields beyond the header under the key None;
          # the first such extra field is used as the sentence.
          row["Full_Sentence"] = row[None][0]
          del row[None]
          row = extract_row(QA, row)
        except Exception as exc:
          # Best effort: skip the row, but let KeyboardInterrupt through and
          # remember what went wrong so it can be reported below.
          skipped += 1
          if first_error is None:
            first_error = f"row {rid}: {exc!r}"
        else:
          # Append row by row so an interrupted run can be resumed.
          with open(output_file, "a") as fout:
            print(json.dumps(row), file=fout)
        pbar.update(1)
  finally:
    pbar.close()
  if skipped:
    print(f"Skipped {skipped} row(s) in {fpath}; first error at {first_error}")
