In [None]:
! git clone https://github.com/amee342/semantic_role_labeling.git

In [None]:
cd semantic_role_labeling/

In [None]:
!pip install -q transformers datasets accelerate evaluate seqeval

In [None]:
# set up saving repo in drive: mount Google Drive at /content/drive so that
# fine-tuned models can be stored there.
# NOTE(review): TrainingArguments' output_dir is currently a relative path
# (resolved against the repo checkout), so nothing is written to Drive yet.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Argument setting — run configuration used by the cells below
task = "SRL"                                             # used in the output directory name
model_checkpoint = "distilbert/distilbert-base-uncased"  # Hugging Face Hub model id
training_epoch = 1                                       # num_train_epochs
batch_size = 16                                          # per-device train/eval batch size

In [None]:
# Set random seed!
# set_seed seeds Python's `random`, NumPy and PyTorch in one call. It is imported
# here because this cell runs before the main imports cell; without this line a
# fresh kernel raises NameError.
from transformers import set_seed

SEED = 0
set_seed(SEED)

There is a repository (folder) in Google Drive called "SRL" for storing fine-tuned models.

In [None]:
from typing import List
import transformers
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, set_seed
from transformers import DataCollatorForTokenClassification

from datasets import Dataset

## Load and Parse the CoNLL-U Dataset

In [None]:
def load_conll_sentences(path: str):
    """
    Read a CoNLL-U style file into a list of sentences.

    Each sentence is a list of token rows, and each row is the list of
    tab-separated column strings. Sentences are separated by blank (or
    whitespace-only) lines; lines starting with "#" are comments and skipped.

    Args:
        path: path to a UTF-8 encoded file.

    Returns:
        list[list[list[str]]] — sentences -> rows -> columns.
    """
    collected = []
    current = []

    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            text = raw.rstrip("\n")

            if not text.strip():
                # sentence boundary: flush the sentence built so far
                if current:
                    collected.append(current)
                current = []
                continue

            if text.startswith("#"):
                continue

            current.append(text.split("\t"))

    # the file may not end with a blank line
    if current:
        collected.append(current)

    return collected

## Preprocessing



In [None]:
def count_sentences_and_tokens(sentences: List):
  """
  Return (n_sent, n_token) for a list of sentences.

  Items may be either
    * a sentence as a list of token rows (output of load_conll_sentences), or
    * an instance dict with a "tokens" list (output of replicate_sentences).

  For dicts the tokens are counted from "tokens"; plain len() on a dict would
  count its keys instead of its tokens.
  """
  n_sent = len(sentences)
  n_token = sum(
      len(s["tokens"]) if isinstance(s, dict) else len(s)
      for s in sentences
  )

  return n_sent, n_token





### Replicate each sentence for each predicate

In [None]:
def find_predicate_index(sent,
                           label_col,
                           predicate_markers=("V", "B-V")):
  """
  Return the index of the first row whose `label_col` cell is a predicate
  marker, or None when no row carries one. Rows too short to have that
  column are ignored.
  """
  matches = (
      idx for idx, row in enumerate(sent)
      if label_col < len(row) and row[label_col] in predicate_markers
  )
  return next(matches, None)



In [None]:
def replicate_sentences(sentences,
                        base_cols: int=11):
  """
  Expand every sentence into one instance per predicate.

  Expected row layout (CoNLL-U Plus as in the Universal Propositions files;
  assumption — confirm against the data): the first ``base_cols`` columns are
  base annotation, with the predicate-sense column at index ``base_cols - 1``
  ("_" for non-predicates), followed by one argument-label column per
  predicate.

  Args:
    sentences: list of sentences, each a list of token rows (lists of str),
      as produced by load_conll_sentences.
    base_cols: number of leading non-argument columns; argument column j is
      read from index ``base_cols + j``.

  Returns:
    list of dicts with
      "tokens": FORM column (index 1) of every row,
      "predicate_index": row index of this column's predicate, or None,
      "labels": one label per row, "_" / missing cells mapped to "O".
    Sentences without any argument column are skipped.
  """
  instances = []
  sense_col = base_cols - 1  # predicate-sense column (e.g. "go.02" or "_")

  for sent in sentences:
    # number of predicate-specific label columns; use the widest row in case
    # some rows are shorter
    k = max(0, max((len(r) for r in sent), default=0) - base_cols)
    if k == 0:
      # sentence has no predicate
      continue

    tokens = [r[1] for r in sent]  # FORM column; identical for every predicate

    # Rows with a filled sense column, in sentence order. Argument column j
    # belongs to the j-th of them, so this is the fallback when column j has
    # no explicit "V"/"B-V" marker. (r[9] would be the CoNLL-U MISC column.)
    sense_rows = [i for i, r in enumerate(sent)
                  if len(r) > sense_col and r[sense_col] not in ("_", "-", "")]

    for j in range(k):
      label_col = base_cols + j  # 0-based index

      pred_index = find_predicate_index(sent, label_col)
      if pred_index is None and j < len(sense_rows):
        pred_index = sense_rows[j]

      labels = [
          "O" if (len(r) <= label_col or r[label_col] == "_") else r[label_col]
          for r in sent
      ]

      instances.append({
          "tokens": list(tokens),  # fresh list per instance, no aliasing
          "predicate_index": pred_index,
          "labels": labels,
      })
  return instances



In [None]:
def load_and_preprocess(path:str):
  """
  Load a CoNLL-U(-Plus) file and expand it into per-predicate instances.

  Returns a dict with
    "sentences": raw sentences from load_conll_sentences,
    "instances": per-predicate instances from replicate_sentences,
    "stats":     counts from count_sentences_and_tokens before and after
                 replication.
  """
  raw_sentences = load_conll_sentences(path)
  n_sent_before, n_tok_before = count_sentences_and_tokens(raw_sentences)

  per_predicate = replicate_sentences(raw_sentences)
  # NOTE(review): "after_tokens" is only a token count if
  # count_sentences_and_tokens measures instance dicts by their "tokens" list;
  # len() of a dict would count its keys instead.
  n_inst_after, n_tok_after = count_sentences_and_tokens(per_predicate)

  stats = {
      "before_sentences": n_sent_before,
      "before_tokens": n_tok_before,
      "after_instances": n_inst_after,
      "after_tokens": n_tok_after,
  }
  return {"sentences": raw_sentences, "instances": per_predicate, "stats": stats}

### Tokenizer

In [None]:
# Tokenizer matching the model checkpoint chosen in the config cell
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

In [None]:
# check if tokenizer is backed by Rust: label alignment below relies on
# BatchEncoding.word_ids(), which only fast tokenizers provide.
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast), (
    "a fast (Rust-backed) tokenizer is required for word_ids(); "
    f"got {type(tokenizer).__name__}"
)

In [None]:
labels_all_tokens = True

def tokenize_and_align_labels(example):
  """
  Align the word-level labels of ONE example to its sub-word tokens.

  NOTE(review): this unbatched draft is redefined by the batched
  tokenize_and_align_labels in a later cell. It returns only the aligned
  label list, with string labels rather than ids. Consider deleting it.

  Args:
    example: mapping with "tokens" (list of words) and "labels" (one label
      per word).

  Returns:
    one entry per sub-word token: -100 for special tokens; the word's label
    for its first sub-token; for continuation sub-tokens the word's label
    when labels_all_tokens is True, otherwise -100.
  """
  encoded = tokenizer(example["tokens"], truncation=True, is_split_into_words=True)

  aligned = []
  prev_word = None
  for word in encoded.word_ids():
    if word is None:
      # special tokens: -100 is the ignore_index of PyTorch's cross-entropy
      aligned.append(-100)
    elif word != prev_word or labels_all_tokens:
      aligned.append(example["labels"][word])
    else:
      aligned.append(-100)
    prev_word = word
  return aligned


In [None]:
labels_all_tokens = True  # or False

def tokenize_and_align_labels(examples):
    """
    Batched tokenization + label alignment for Dataset.map(..., batched=True).

    Reads the word lists from examples["tokens"] and the word-level string
    labels from examples["labels_str"]. Returns the tokenizer output with an
    added "labels" field: per sub-token label ids (via the global label2id),
    -100 for special tokens, and -100 for continuation sub-tokens unless
    labels_all_tokens is True.

    NOTE(review): only the sentence is encoded — predicate_index is not fed to
    the model, so replicated instances of the same sentence get identical
    inputs with different targets. Presumably the predicate should be marked
    (e.g. as a second segment or marker tokens) — TODO confirm the intended
    model input.
    """
    encoded = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
    )

    def _align(word_ids, word_labels):
        # map each sub-token back to its word's label id
        ids = []
        prev = None
        for w in word_ids:
            if w is None:
                ids.append(-100)
            elif w != prev or labels_all_tokens:
                ids.append(label2id[word_labels[w]])
            else:
                ids.append(-100)
            prev = w
        return ids

    encoded["labels"] = [
        _align(encoded.word_ids(batch_index=row), word_labels)
        for row, word_labels in enumerate(examples["labels_str"])
    ]  # <-- ints + -100
    return encoded

In [None]:
# sanity check: preprocess the test split and wrap the per-predicate instances
# in a Dataset. String labels are moved to "labels_str" so that
# tokenize_and_align_labels can write the integer ids to "labels".
dataset = load_and_preprocess("/content/semantic_role_labeling/data/en_ewt-up-test.conllu")
ds = Dataset.from_list(dataset['instances']).rename_column("labels", "labels_str")


In [None]:
# Sorted label inventory of the loaded split, plus label <-> id lookups
# NOTE(review): built from the test split only; labels that appear only in
# other splits would be missing from label2id.
label_list = sorted({tag for seq in ds["labels_str"] for tag in seq})
label2id = {tag: idx for idx, tag in enumerate(label_list)}
id2label = dict(enumerate(label_list))

In [None]:
# Tokenize and align labels batch-wise; the original columns are kept
tokenized_datasets = ds.map(function=tokenize_and_align_labels, batched=True)

## Fine-Tune the Model

In [None]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import set_seed


# Pass the label mappings so the model config (and any saved checkpoint or
# prediction output) uses the real SRL tag names instead of LABEL_0, LABEL_1, ...
model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)

In [None]:
# Training configuration; checkpoints go to "<model>-finetuned-<task>"
# (relative to the current working directory).
# NOTE(review): eval_strategy="epoch" means the Trainer will presumably need
# an eval_dataset — confirm one is passed.
model_name = model_checkpoint.rsplit("/", 1)[-1]
args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-{task}",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=training_epoch,
    weight_decay=0.01,
    seed=SEED,
    report_to="none",
)

In [None]:
# batch dataset: pads input_ids and the per-token "labels" of each batch
# dynamically (labels are padded with -100 by default)
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [None]:
import json, pathlib

# Strip the "widgets" entry from the notebook's top-level metadata, which
# breaks nbconvert/GitHub rendering, and write the notebook back in place.
nb_path = "/content/semantic_role_labeling/bert_finetuning.ipynb"  # <- change this
p = pathlib.Path(nb_path)
nb = json.loads(p.read_text(encoding="utf-8"))

meta = nb.get("metadata", {})
if "widgets" in meta:
    del meta["widgets"]
    nb["metadata"] = meta

p.write_text(
    json.dumps(nb, ensure_ascii=False, indent=1),
    encoding="utf-8",
)
print("Cleaned:", nb_path)

In [None]:
# Inspect the first tokenized instance: tokenizer outputs, the aligned integer
# "labels", and the original columns kept by Dataset.map
tokenized_datasets[0]

In [None]:
# Show the sub-word tokens of the first instance to check how words were split
tokenizer.convert_ids_to_tokens(tokenized_datasets[0]["input_ids"])

In [None]:
# First per-predicate instance (tokens, predicate_index, labels); rendered via
# the cell's rich display instead of print()
dataset["instances"][0]

In [None]:
# id -> label string mapping built from label_list
id2label

In [None]:
from datasets import Dataset

In [None]:
# Rebuild the Dataset with the same column layout as the sanity-check cell:
# tokenize_and_align_labels reads string labels from "labels_str" and writes
# the integer ids to "labels". Without the rename, mapping it raises KeyError.
ds = Dataset.from_list(dataset['instances']).rename_column("labels", "labels_str")

In [None]:
# Display the Dataset summary (column names and number of rows)
ds

In [None]:
# tokenize_and_align_labels reads the "labels_str" column; if this `ds` still
# has the raw "labels" column (e.g. rebuilt straight from Dataset.from_list),
# rename it first instead of failing with KeyError.
ds_for_tokenization = (
    ds if "labels_str" in ds.column_names
    else ds.rename_column("labels", "labels_str")
)
tokenized_ds = ds_for_tokenization.map(tokenize_and_align_labels, batched=True)
