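* In this notebook we define a custom loss function for training a causal language model on Python code. Some tokens matter more than others in data-science code — library aliases and methods such as `plt`, `pd`, `sk`, `fit` and `predict` — so we want samples that contain them to count more in the loss. That is why we build a new loss function.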
* For each sample (sequence), we count how many key tokens (`plt`, `pd`, etc.) appear in it. With the default `alpha=1`, the sample's weight is 1 + that count, so samples without key tokens still contribute with weight 1. We then multiply each sample's mean per-token loss by its weight and average the weighted losses over the batch.

In [2]:
!pip install accelerate

Collecting accelerate
  Downloading accelerate-0.6.2-py3-none-any.whl (65 kB)
     ---------------------------------------- 65.9/65.9 KB 3.7 MB/s eta 0:00:00
Installing collected packages: accelerate
Successfully installed accelerate-0.6.2




In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import Accelerator

In [4]:
accelerator = Accelerator()
tokenizer = AutoTokenizer.from_pretrained("huggingface-course/code-search-net-tokenizer")
model = AutoModelForCausalLM.from_pretrained("huggingface-course/codeparrot-ds")

# Keywords whose presence should up-weight a sample in the loss. Both bare and
# space-prefixed spellings are listed, presumably because the tokenizer folds a
# leading space into the token (so "pd" and " pd" get different ids) — verify.
KEYWORDS = [
    "plt",
    "pd",
    "sk",
    "fit",
    "predict",
    " plt",
    " pd",
    " sk",
    " fit",
    " predict",
]

# Map each keyword to its token id. Only keywords that encode to exactly one
# token are usable: the loss matches individual token ids, so keeping just the
# first sub-token of a multi-token keyword would up-weight unrelated tokens.
keytoken_ids = []
for keyword in KEYWORDS:
    ids = tokenizer([keyword]).input_ids[0]
    if len(ids) == 1:
        keytoken_ids.append(ids[0])
    else:
        print(f"Keyword is not a single token, skipping: {keyword!r} -> {ids}")

# A one-sample batch used to smoke-test the loss below.
batch = tokenizer(["import numpy as np"], return_tensors="pt")
model = accelerator.prepare(model)

Downloading:   0%|          | 0.00/265 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/771k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/438k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.28M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/938 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/486M [00:00<?, ?B/s]

In [5]:
from torch.nn import CrossEntropyLoss
import torch

def keytoken_weighted_loss(inputs, logits, keytoken_ids, alpha=1.0):
    """Causal-LM cross-entropy in which samples containing key tokens weigh more.

    Each sample's loss is the mean next-token cross-entropy over its positions.
    That loss is multiplied by ``1 + alpha * count``, where ``count`` is the
    number of positions in the sample whose token id is in ``keytoken_ids``.
    The weighted per-sample losses are then averaged over the batch.

    Args:
        inputs: token ids, indexed as (batch, seq).
        logits: model outputs, indexed as (batch, seq, vocab) and aligned with
            ``inputs``.
        keytoken_ids: iterable of token ids to up-weight. It may be empty, in
            which case every weight is 1. An id listed twice is counted twice.
        alpha: scale of the key-token bonus. It multiplies only the count, so
            it controls how strongly keyword-rich samples are emphasised
            relative to others. With the default 1.0 the weight is ``1 + count``.

    Returns:
        A scalar tensor: the mean of the weighted per-sample losses.
    """
    # Shift so that the logits at position t are scored against the token at t + 1.
    shift_labels = inputs[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()
    batch_size, n_targets, vocab_size = shift_logits.shape

    # Per-token loss. `reduction="none"` replaces the deprecated `reduce=False`.
    loss_fct = CrossEntropyLoss(reduction="none")
    token_loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
    loss_per_sample = token_loss.view(batch_size, n_targets).mean(dim=1)

    # Count key-token occurrences per sample over the whole (unshifted) input.
    # Starting from zeros keeps an empty `keytoken_ids` valid (weight 1).
    keytoken_counts = torch.zeros(batch_size, dtype=torch.float32, device=inputs.device)
    for kt in keytoken_ids:
        keytoken_counts += (inputs == kt).sum(dim=-1).float()

    weights = 1.0 + alpha * keytoken_counts
    return (loss_per_sample * weights).mean()

In [6]:
# Smoke-test the custom loss on the one-sample batch: forward pass, weighted
# loss, and a backward pass through the Accelerator.
# `accelerator.prepare` placed the model on `accelerator.device`; the tokenized
# batch is still on CPU, so move it to avoid a device mismatch on GPU.
input_ids = batch["input_ids"].to(accelerator.device)
logits = model(input_ids).logits
loss = keytoken_weighted_loss(input_ids, logits, keytoken_ids)
accelerator.backward(loss)
loss.item()



In [7]:
from transformers import Trainer

class MyTrainer(Trainer):
    """Trainer whose loss is `keytoken_weighted_loss` instead of the model's own loss.

    The key tokens come from the notebook-level `keytoken_ids` list.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the key-token weighted causal-LM loss for one batch.

        Args:
            model: the model being trained.
            inputs: batch mapping. It must contain "input_ids" and may
                contain "attention_mask".
            return_outputs: if True, also return the model outputs.
            num_items_in_batch: accepted because newer `transformers.Trainer`
                versions pass it to `compute_loss`. It is unused, since the
                loss already averages per sample.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when `return_outputs` is True.
        """
        input_ids = inputs.get("input_ids")
        # Forward the attention mask when the collator provides one, so padded
        # positions are masked inside the model. The loss itself still scores
        # every position.
        outputs = model(input_ids=input_ids, attention_mask=inputs.get("attention_mask"))
        loss = keytoken_weighted_loss(input_ids, outputs.logits, keytoken_ids)

        return (loss, outputs) if return_outputs else loss