In [2]:
import torch
import sys
import datasets
from transformers import AutoTokenizer, XLMRobertaModel

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# load the dataset
dataset = datasets.load_dataset(path="universal_dependencies", name="en_ewt", trust_remote_code=True)
print(dataset)
train_dataset = dataset["train"]
valid_dataset = dataset["validation"]
test_dataset = dataset["test"]

print(train_dataset["text"][:10])

DatasetDict({
    train: Dataset({
        features: ['idx', 'text', 'tokens', 'lemmas', 'upos', 'xpos', 'feats', 'head', 'deprel', 'deps', 'misc'],
        num_rows: 12543
    })
    validation: Dataset({
        features: ['idx', 'text', 'tokens', 'lemmas', 'upos', 'xpos', 'feats', 'head', 'deprel', 'deps', 'misc'],
        num_rows: 2002
    })
    test: Dataset({
        features: ['idx', 'text', 'tokens', 'lemmas', 'upos', 'xpos', 'feats', 'head', 'deprel', 'deps', 'misc'],
        num_rows: 2077
    })
})
['Al-Zaman : American forces killed Shaikh Abdullah al-Ani, the preacher at the mosque in the town of Qaim, near the Syrian border.', '[This killing of a respected cleric will be causing us trouble for years to come.]', 'DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.', 'Two of them were being run by 2 officials of the Ministry of the Interior!', 'The MoI in Iraq is equivalent to the US FBI, so this would be like having J. Edgar H

In [4]:
all_deprels = [
    # these are the default UD dependency relations according to https://universaldependencies.org/u/dep/
    "acl", "acl:relcl", "advcl", "advcl:relcl", "advmod", "advmod:emph", "advmod:lmod", "amod", "appos",
    "aux", "aux:pass", "case", "cc", "cc:preconj", "ccomp", "clf", "compound", "compound:lvc",
    "compound:prt", "compound:redup", "compound:svc", "conj", "cop", "csubj", "csubj:outer",
    "csubj:pass", "dep", "det", "det:numgov", "det:nummod", "det:poss", "discourse", "dislocated",
    "expl", "expl:impers", "expl:pass", "expl:pv", "fixed", "flat", "flat:foreign", "flat:name",
    "goeswith", "iobj", "list", "mark", "nmod", "nmod:poss", "nmod:tmod", "nsubj", "nsubj:outer",
    "nsubj:pass", "nummod", "nummod:gov", "obj", "obl", "obl:agent", "obl:arg", "obl:lmod",
    "obl:tmod", "orphan", "parataxis", "punct", "reparandum", "root", "vocative", "xcomp",

    # we need some more for en_ewt
    "det:predet", "obl:npmod", "nmod:npmod"
]

# construct deprel to ID mapping
deprel_to_id = {rel: idx for idx, rel in enumerate(all_deprels)}

In [5]:
# Code for the assignment in https://github.com/coli-saar/cl/wiki/Assignment:-Dependency-parsing
# Alexander Koller, December 2023

def strip_none_heads(examples, i):
    tokens = examples["tokens"][i]
    heads = examples["head"][i]
    deprels = examples["deprel"][i]

    non_none = [(t, h, d) for t, h, d in zip(tokens, heads, deprels) if h != "None"]
    return zip(*non_none)


def map_first_occurrence(nums):
    """
    Maps a list of numbers to a dictionary that assigns each unique number the position of its first occurrence.

    Example:
    > map_first_occurrence([0,1,2,3,3,3,4])
    {0: 0, 1: 1, 2: 2, 3: 3, 4: 6}

    :param nums:
    :return:
    """
    seen = set()
    return {num: i for i, num in enumerate(nums) if num is not None and num not in seen and not seen.add(num)}


def pad_to_same_size(lists, padding_symbol):
    maxlen = max([len(l) for l in lists])
    return [l + (padding_symbol,) * (maxlen - len(l)) for l in lists]


def tokenize_and_align_labels(examples, deprel_to_id, tokenizer, skip_index=-100):
    # delete tokens with "None" head and their annotations
    examples_tokens, examples_heads, examples_deprels = [], [], []
    for sentence_id in range(len(examples["tokens"])):
        tt, hh, dd = strip_none_heads(examples, sentence_id)
        examples_tokens.append(tt)
        examples_heads.append(hh)
        examples_deprels.append(dd)

    tokenized_inputs = tokenizer(examples_tokens, truncation=True, is_split_into_words=True,
                                 padding=True)  # get "tokenizer" from global variable
    # tokenized_inputs is a dictionary with keys input_ids and attention_mask;
    # each is a list (per sentence) of lists (per token).

    remapped_heads = []  # these will be lists (per sentence) of lists (per token)
    deprel_ids = []
    tokens_representing_words = []
    num_words: list[int] = []
    maxlen_t2w = 0  # max length of a token_to_word_here list

    for sentence_id, annotated_heads in enumerate(examples_heads):
        deprels = examples_deprels[sentence_id]
        word_ids = tokenized_inputs.word_ids(batch_index=sentence_id)
        word_pos_to_token_pos = map_first_occurrence(
            word_ids)  # word-pos to first token-pos; both start at 0 for first word (actual) / first token (BOS)

        previous_word_idx = None
        heads_here: list[int] = []
        deprel_ids_here: list[int] = []

        # list of token positions that map to words (first token of each word)
        # token 0 -> word 0 (BOS)
        tokens_representing_word_here: list[int] = [0]

        for sentence_position, word_idx in enumerate(word_ids):
            # Special tokens (BOS, EOS) have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:
                heads_here.append(skip_index)
                deprel_ids_here.append(skip_index)

            # We set the label for the first token of each word;
            # subsequent tokens of the same word will have the same word_idx.
            elif word_idx != previous_word_idx:
                if annotated_heads[word_idx] == "None":  # added by padding
                    print("A 'None' head survived!")
                    sys.exit(0)
                else:
                    # Map HEAD annotation to position of first token of head word.
                    # HEAD = 0 => map it to first token (BOS)
                    # Otherwise, look up first token for HEAD-1 (HEAD is 1-based, word positions are 0-based)
                    head_word_pos = int(annotated_heads[word_idx])
                    head_token_pos = 0 if head_word_pos == 0 else word_pos_to_token_pos[head_word_pos - 1]

                    heads_here.append(head_token_pos)
                    deprel_ids_here.append(deprel_to_id[deprels[word_idx]])

                    tokens_representing_word_here.append(sentence_position)  # first word is index 1; index 0 is BOS

            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.
            else:
                heads_here.append(skip_index)
                deprel_ids_here.append(skip_index)

            previous_word_idx = word_idx

        remapped_heads.append(heads_here)
        deprel_ids.append(deprel_ids_here)
        tokens_representing_words.append(tokens_representing_word_here)

        num_words.append(len(tokens_representing_word_here))
        if len(tokens_representing_word_here) > maxlen_t2w:
            maxlen_t2w = len(tokens_representing_word_here)

    # pad t2w lists to same length
    for t2w in tokens_representing_words:
        t2w += [-1] * (maxlen_t2w - len(t2w))

    tokenized_inputs["head"] = remapped_heads
    tokenized_inputs["deprel_ids"] = deprel_ids
    tokenized_inputs["tokens_representing_words"] = tokens_representing_words
    tokenized_inputs["num_words"] = num_words
    tokenized_inputs["tokenid_to_wordid"] = [tokenized_inputs.word_ids(batch_index=i) for i in
                                             range(len(examples_heads))]  # map token ID to word ID

    return tokenized_inputs

In [6]:
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")

# test tokenization
tokenized_inputs = tokenize_and_align_labels(train_dataset[:10], deprel_to_id, tokenizer)

for i in range(10):
    tokens = tokenizer.convert_ids_to_tokens(tokenized_inputs["input_ids"][i])  # i 是句子的索引
    word_ids = tokenized_inputs["tokenid_to_wordid"][i]

    print(f"Example {i + 1}")
    print(f"{'Token':<15}{'Token ID':<10}{'Head':<10}{'Deprel':<15}{'Word Mapping':<15}")
    for j, token in enumerate(tokens):
        if token == tokenizer.pad_token:
            break
        token_id = tokenized_inputs["input_ids"][i][j]
        head = tokenized_inputs["head"][i][j]
        deprel = all_deprels[tokenized_inputs["deprel_ids"][i][j]] if tokenized_inputs["deprel_ids"][i][
                                                                          j] != -100 else "None"
        word_mapping = word_ids[j]

        token_str = token if token else "None"
        token_id = str(token_id)
        head_str = str(head)
        deprel_str = deprel if deprel else "None"
        word_mapping_str = str(word_mapping) if word_mapping is not None else "None"

        print(f"{token_str:<15}{token_id:<10}{head_str:<10}{deprel_str:<15}{word_mapping_str:<15}")



Example 1
Token          Token ID  Head      Deprel         Word Mapping   
<s>            0         -100      None           None           
▁Al            884       0         root           0              
▁-             20        1         punct          1              
▁Zaman         53113     1         flat           2              
▁:             152       1         punct          3              
▁American      15672     6         amod           4              
▁forces        84616     7         nsubj          5              
▁killed        152388    1         parataxis      6              
▁Sha           7224      7         obj            7              
ikh            41336     -100      None           7              
▁Abdullah      34490     8         flat           8              
▁al            144       8         flat           9              
▁-             20        8         punct          10             
▁Ani           32340     8         flat           11             


In [7]:
from datasets import Dataset

# tokenized dataset and construct dataloader
train_tokenized_inputs = tokenize_and_align_labels(train_dataset[:], deprel_to_id, tokenizer)
# Convert BatchEncoding to Dataset
train_dataset = Dataset.from_dict(train_tokenized_inputs.data)
train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'head', 'deprel_ids'])
print(train_dataset)

valid_tokenized_inputs = tokenize_and_align_labels(valid_dataset[:], deprel_to_id, tokenizer)
valid_dataset = Dataset.from_dict(valid_tokenized_inputs.data)
valid_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'head', 'deprel_ids'])

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
next(iter(train_dataloader))

Dataset({
    features: ['input_ids', 'attention_mask', 'head', 'deprel_ids', 'tokens_representing_words', 'num_words', 'tokenid_to_wordid'],
    num_rows: 12543
})


{'input_ids': tensor([[    0,   884,    20,  ...,     1,     1,     1],
         [    0,   378,  3293,  ...,     1,     1,     1],
         [    0,   391, 12236,  ...,     1,     1,     1],
         ...,
         [    0,   581,   262,  ...,     1,     1,     1],
         [    0, 56645, 14508,  ...,     1,     1,     1],
         [    0,  1529,    83,  ...,     1,     1,     1]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'head': tensor([[-100,    0,    1,  ..., -100, -100, -100],
         [-100,   12,    3,  ..., -100, -100, -100],
         [-100,    0, -100,  ..., -100, -100, -100],
         ...,
         [-100,    6,    6,  ..., -100, -100, -100],
         [-100,    7, -100,  ..., -100, -100, -100],
         [-100,   10,   10,  ..., -100, -100, -100]]),
 'deprel_ids': tensor([[-100,  

In [8]:
class DependencyParserModel(torch.nn.Module):
    def __init__(self,
                 hidden_size=768,  # xlm-roberta-base hidden size
                 edge_mlp_dim=500,  # Dozat&Manning recommend 500
                 label_mlp_dim=100,  # Dozat&Manning recommend 100
                 num_labels=len(all_deprels),  # number of dependency labels
                 edge_predicting=True,
                 label_predicting=False,
                 dropout_prob=0.33
                 ):
        super().__init__()

        self.edge_predicting = edge_predicting
        self.label_predicting = label_predicting
        self.num_labels = num_labels

        # Load pre-trained XLM-RoBERTa model
        self.roberta = XLMRobertaModel.from_pretrained("xlm-roberta-base")

        # Freeze RoBERTa parameters
        for param in self.roberta.parameters():
            param.requires_grad = False

        # Define MLP for edge head and dependency projections
        if self.edge_predicting:
            self.edge_mlp_head = torch.nn.Sequential(
                torch.nn.Linear(hidden_size, edge_mlp_dim),
                torch.nn.ReLU(),
                torch.nn.Dropout(dropout_prob)
            )
            self.edge_mlp_dep = torch.nn.Sequential(
                torch.nn.Linear(hidden_size, edge_mlp_dim),
                torch.nn.ReLU(),
                torch.nn.Dropout(dropout_prob)
            )

            # Define U1 and u2 for edge scoring
            self.U1 = torch.nn.Parameter(torch.empty(edge_mlp_dim, edge_mlp_dim))
            self.u2 = torch.nn.Parameter(torch.empty(edge_mlp_dim))
            # Apply Xavier initialization
            torch.nn.init.xavier_uniform_(self.U1)
            torch.nn.init.xavier_uniform_(self.u2.unsqueeze(0))
            self.u2.squeeze(0)

        # Extra: Edge labels predicting
        # Define MLP for label head and dependency projections
        if self.label_predicting:
            self.label_mlp_head = torch.nn.Sequential(
                torch.nn.Linear(hidden_size, label_mlp_dim),
                torch.nn.ReLU(),
                torch.nn.Dropout(dropout_prob)
            )
            self.label_mlp_dep = torch.nn.Sequential(
                torch.nn.Linear(hidden_size, label_mlp_dim),
                torch.nn.ReLU(),
                torch.nn.Dropout(dropout_prob)
            )

            # Define parameters for label scoring
            self.U_label = torch.nn.Parameter(torch.empty(label_mlp_dim, label_mlp_dim))
            self.W1_label = torch.nn.Parameter(torch.empty(label_mlp_dim, num_labels))
            self.W2_label = torch.nn.Parameter(torch.empty(label_mlp_dim, num_labels))
            self.b_label = torch.nn.Parameter(torch.empty(num_labels))

            # Apply Xavier initialization
            torch.nn.init.xavier_uniform_(self.U_label)
            torch.nn.init.xavier_uniform_(self.W1_label)
            torch.nn.init.xavier_uniform_(self.W2_label)
            torch.nn.init.xavier_uniform_(self.b_label.unsqueeze(0))
            self.b_label.squeeze(0)

    def forward(self, input_ids, attention_mask):
        """
        returns:
          H_head (edge MLP): [batch_size, seq_len, edge_mlp_dim]
          H_dep (edge MLP):  [batch_size, seq_len, edge_mlp_dim]
          L_head (label MLP): [batch_size, seq_len, label_mlp_dim]
          L_dep (label MLP):  [batch_size, seq_len, label_mlp_dim]
        """
        # initialize as None
        H_head = H_dep = L_head = L_dep = None

        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)

        # Shape: (batch_size, seq_length, hidden_size: 768)
        last_hidden_state = outputs.last_hidden_state

        if self.edge_predicting:
            # edge MLP projections
            H_head = self.edge_mlp_head(last_hidden_state)
            H_dep = self.edge_mlp_dep(last_hidden_state)

        if self.label_predicting:
            # Label MLP projections
            L_head = self.label_mlp_head(last_hidden_state)
            L_dep = self.label_mlp_dep(last_hidden_state)

        return H_head, H_dep, L_head, L_dep

    def score_edges(self, H_head, H_dep):
        # score[i, j] = H_head[i] * U1 * H_dep[j].T + H_head[i] * u2
        # b: batch_size, s: seq_len, d: edge_mlp_dim
        H_head_U1 = torch.einsum("bsd,dd->bsd", H_head, self.U1)
        H_head_U1_H_dep = torch.einsum("bim,bjm->bij", H_head_U1, H_dep)
        H_head_u2 = torch.einsum("bid,d->bi", H_head, self.u2)  # Shape: (batch_size, seq_len)

        # Shape: (batch_size, seq_len, seq_len) + (batch_size, seq_len, 1) broadcasting
        scores = H_head_U1_H_dep + H_head_u2.unsqueeze(2)
        return scores

    def score_labels(self, L_head, L_dep):
        # Compute biaffine term: x1 U x2^T
        L_head_U_L_dep = torch.einsum("bid,dd,bjd->bij", L_head, self.U_label,
                                      L_dep)  # Shape: (batch_size, seq_len, seq_len)

        # Compute linear terms: W1 x1 and W2 x2
        W1_L_head = torch.einsum("bid,dn->bin", L_head, self.W1_label)  # Shape: (batch_size, seq_len, num_labels)
        W2_L_dep = torch.einsum("bid,dn->bin", L_dep, self.W2_label)  # Shape: (batch_size, seq_len, num_labels)

        # Add bias term
        bias = self.b_label.unsqueeze(0).unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, 1, num_labels)

        # Combine all terms and expand dimensions
        # Shape: (batch_size, seq_len, seq_len, num_labels)
        label_scores = L_head_U_L_dep.unsqueeze(-1) + W1_L_head.unsqueeze(2) + W2_L_dep.unsqueeze(1) + bias

        return label_scores

In [9]:
import wandb


# train the model
def train(model, train_dataloader, valid_dataset, device, num_epochs=5, lr=2e-3, weight_decay=1e-2, alpha=1, beta=1):
    # Initialize wandb
    wandb.init(project="dependency-parsing", name="dependency-parsing", config={
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "num_epochs": num_epochs,
        "alpha": alpha,
        "beta": beta
    })

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # ignore the padding tokens
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)

    model.to(device)
    model.train()

    print("batch num", len(train_dataloader))

    for epoch in range(num_epochs):
        total_loss = 0.0

        for i, data in enumerate(train_dataloader):
            x = data["input_ids"].to(device)
            mask = data["attention_mask"].to(device)
            head = data["head"].to(device)

            # Extra: Train edge labels predicting
            if model.label_predicting:
                label = data["deprel_ids"].to(device)

            optimizer.zero_grad()
            H_head, H_dep, L_head, L_dep = model(x, mask)

            edge_scores = model.score_edges(H_head, H_dep)  # Shape: (batch_size, seq_len, seq_len)

            # Extra: Train edge labels predicting
            if model.label_predicting:
                label_scores = model.score_labels(L_head, L_dep)  # Shape: (batch_size, seq_len, seq_len, num_labels)

            loss_edge = criterion(
                edge_scores.view(-1, edge_scores.size(-1)),  # Shape: (batch_size * seq_len, seq_len)
                head.view(-1)  # Shape: (batch_size, seq_len)
            )

            # Extra: Train edge labels predicting
            loss_label = 0
            if model.label_predicting:
                # Shape: (batch_size, seq_len, seq_len, num_labels)
                batch_size, n, _, num_labels = label_scores.shape

                # batch_indices: (batch_size, 1)
                # dep_indices:   (1,   n)
                batch_indices = torch.arange(batch_size).unsqueeze(1).to(device)  # batch index
                dep_indices = torch.arange(n).unsqueeze(0).to(device)  # dep index

                # gather label_scores according to the gold head
                # Shape: (batch_size, seq_len, num_labels)
                label_scores_for_gold_edge = label_scores[batch_indices, dep_indices, head, :]

                loss_label = criterion(label_scores_for_gold_edge.view(-1, num_labels), label.view(-1))

            loss = alpha * loss_edge + beta * loss_label

            loss.backward()
            optimizer.step()

            # log the loss curve
            total_loss += loss.item()
            print(f"Epoch {epoch}, Iteration {i}, Loss: {loss.item()}")
            wandb.log({"train_loss": loss.item(), "iteration": epoch * len(train_dataloader) + i})

            if i > 0 and i % 50 == 0:
                # verify the model
                model.eval()
                with torch.no_grad():
                    x = valid_dataset["input_ids"].to(device)
                    mask = valid_dataset["attention_mask"].to(device)
                    head = valid_dataset["head"].to(device)

                    H_head, H_dep, L_head, L_dep = model(x, mask)
                    edge_scores = model.score_edges(H_head, H_dep)

                    valid_loss = criterion(edge_scores.view(-1, edge_scores.size(-1)), head.view(-1))
                    print(f"Epoch {epoch}, Iteration {i}, Valid Loss: {valid_loss}")
                    wandb.log({"valid_loss": valid_loss.item(), "iteration": epoch * len(train_dataloader) + i})
                model.train()

        avg_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss}")

    torch.save(model.state_dict(), "dependency_parser.pth")

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = DependencyParserModel()

train(model, train_dataloader, valid_dataset[:32], device)

  return torch.load(checkpoint_file, map_location="cpu")
Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile

batch num 392
Epoch 0, Iteration 0, Loss: 5.33665132522583
Epoch 0, Iteration 1, Loss: 5.124763011932373
Epoch 0, Iteration 2, Loss: 4.970407009124756
Epoch 0, Iteration 3, Loss: 4.807603359222412
Epoch 0, Iteration 4, Loss: 4.513617038726807
Epoch 0, Iteration 5, Loss: 4.342376232147217
Epoch 0, Iteration 6, Loss: 4.202470302581787
Epoch 0, Iteration 7, Loss: 4.026703357696533
Epoch 0, Iteration 8, Loss: 4.041365623474121
Epoch 0, Iteration 9, Loss: 3.9576528072357178
Epoch 0, Iteration 10, Loss: 3.6682827472686768
Epoch 0, Iteration 11, Loss: 3.5443265438079834
Epoch 0, Iteration 12, Loss: 3.296586751937866
Epoch 0, Iteration 13, Loss: 3.558723211288452
Epoch 0, Iteration 14, Loss: 3.331453323364258
Epoch 0, Iteration 15, Loss: 3.3524065017700195
Epoch 0, Iteration 16, Loss: 3.1872758865356445
Epoch 0, Iteration 17, Loss: 3.415900707244873
Epoch 0, Iteration 18, Loss: 3.2841956615448
Epoch 0, Iteration 19, Loss: 3.3367886543273926
Epoch 0, Iteration 20, Loss: 3.2996182441711426
Epoch