#Zero-Shot Learning on a Korean NLI Dataset

The idea here is to test the capabilities of our finetuned NLI model on an unseen (on the use case) language. We chose **korean** as its syntax and typography weren't seen in the data we used to finetune the model.

In [1]:
%pip install -qU pandas torch transformers datasets onnxruntime tqdm torchvision

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
from torchinfo import summary
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_PATH = "ajayat/xlm-roberta-large-xnli"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
summary(model.eval())

Layer (type:depth-idx)                                            Param #
XLMRobertaForSequenceClassification                               --
├─XLMRobertaModel: 1-1                                            --
│    └─XLMRobertaEmbeddings: 2-1                                  --
│    │    └─Embedding: 3-1                                        256,002,048
│    │    └─Embedding: 3-2                                        526,336
│    │    └─Embedding: 3-3                                        1,024
│    │    └─LayerNorm: 3-4                                        2,048
│    │    └─Dropout: 3-5                                          --
│    └─XLMRobertaEncoder: 2-2                                     --
│    │    └─ModuleList: 3-6                                       302,309,376
├─XLMRobertaClassificationHead: 1-2                               --
│    └─Linear: 2-3                                                1,049,600
│    └─Dropout: 2-4                                           

We use the KorNLI dataset. It is a Korean Natural Language Inference (NLI) dataset. The dataset is constructed by automatically translating the training sets of the SNLI, XNLI and MNLI datasets. It contains 942,854 training examples translated automatically and 7,500 evaluation (development and test) examples translated manually.

In [None]:
from datasets import load_dataset

dataset = load_dataset("klue", "nli")
df = dataset["validation"].to_pandas()

In [None]:
def tokenize_batch(premises, hypotheses, padding="longest"): # We need to tokenize, as usual
    return tokenizer(
        premises,
        hypotheses,
        padding=padding,
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

In [5]:
def predict(df):
    predictions = []
    batch_size = 32

    for i in tqdm(range(0, len(df), batch_size)):
        batch = df.iloc[i:i+batch_size]
        tokens = tokenize_batch(batch["premise"].tolist(), batch["hypothesis"].tolist())
        input_ids = tokens["input_ids"].to(device)
        attention_mask = tokens["attention_mask"].to(device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=1).cpu().numpy()
            predictions.extend(preds)

    return predictions

In [6]:
from sklearn.metrics import classification_report, accuracy_score

df["prediction"] = predict(df)

if "label" in df.columns:
    print(classification_report(df["label"], df["prediction"], target_names=['entailment', 'neutral', 'contradiction']))
    print("Accuracy:", accuracy_score(df["label"], df["prediction"]))

100%|██████████| 94/94 [05:33<00:00,  3.55s/it]

               precision    recall  f1-score   support

   entailment       0.80      0.95      0.87      1000
      neutral       0.91      0.73      0.81      1000
contradiction       0.86      0.88      0.87      1000

     accuracy                           0.85      3000
    macro avg       0.86      0.85      0.85      3000
 weighted avg       0.86      0.85      0.85      3000

Accuracy: 0.8526666666666667





We also get a pretty good accuracy on an unseen language ! That's probably due to XLM Roberta's pre-training, which was also done on Korean. The finetuning enhances its performance on an NLI task.

Example of misclassified instance :



*   1997년 현대 유니콘스는 선수층 빈약함이 여실히 드러나게 되어 정규 시즌 6위를 기록한다: *In 1997, the Hyundai Unicorns' weakness in player base was clearly revealed, and they finished in 6th place in the regular season*
*   현대 유니콘스는 1997년 정규 시즌 순위에 들지 못한다: *The Hyundai Unicorns failed to qualify for the 1997 regular season*
* Real label: 2 (contradiction)
* Predicted label: 0 (entailment)


Let's try to explain which tokens are responsible for the classification by using the integrated gradients method.

In [7]:
%pip install -qU captum

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Note: you may need to restart the kernel to use updated packages.


In [8]:
from captum.attr import LayerIntegratedGradients
import torch

# Make sure model is in eval mode
model.eval()

# We'll use the embeddings layer of XLMRoberta
model_input = model.roberta.embeddings

def model_output(inputs, attention_mask=None):
    outputs = model(inputs, attention_mask=attention_mask)
    return outputs.logits

For the integrated gradients method to work, we need to compute a baseline :

In [14]:
def construct_input_and_baseline(premise, hypothesis, tokenizer, device):
    encoding = tokenize_batch(premise, hypothesis)

    # Move the input tensors to the specified device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Baseline: just pad tokens where there are real tokens
    baseline_input_ids = torch.full_like(input_ids, tokenizer.pad_token_id).to(device)

    # Get tokens for visualization (these will remain on the CPU as they are for display purposes)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy())

    return input_ids, attention_mask, baseline_input_ids, tokens

In [15]:
lig = LayerIntegratedGradients(model_output, model_input)

In [16]:
def interpret_premise_hypothesis(premise, hypothesis, true_class, target_class, tokenizer,device):
    input_ids, attention_mask, baseline_input_ids, tokens = construct_input_and_baseline(premise, hypothesis, tokenizer,device)

    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=baseline_input_ids,
        additional_forward_args=(attention_mask,),
        return_convergence_delta=True,
        internal_batch_size=1,
        target=target_class,
    )

    attributions_sum = attributions.sum(dim=-1).squeeze(0)
    attributions_sum = attributions_sum / torch.norm(attributions_sum)

    from captum.attr import visualization as viz

    pred = model_output(input_ids, attention_mask)

    score_vis = viz.VisualizationDataRecord(
        word_attributions=attributions_sum,
        pred_prob=torch.max(pred),
        pred_class=torch.argmax(pred).cpu().item(),
        true_class=true_class,
        attr_class=f"{premise} [SEP] {hypothesis}",
        attr_score=attributions_sum.sum(),
        raw_input_ids=tokens,
        convergence_score=delta
    )

    viz.visualize_text([score_vis])

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [17]:
premise = "1997년 현대 유니콘스는 선수층 빈약함이 여실히 드러나게 되어 정규 시즌 6위를 기록한다."
hypothesis = "현대 유니콘스는 1997년 정규 시즌 순위에 들지 못한다."
true_class = 2
target_class = 0

interpret_premise_hypothesis(premise, hypothesis, true_class, target_class, tokenizer, device)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
2.0,0 (1.36),1997년 현대 유니콘스는 선수층 빈약함이 여실히 드러나게 되어 정규 시즌 6위를 기록한다. [SEP] 현대 유니콘스는 1997년 정규 시즌 순위에 들지 못한다.,-0.08,#s ▁1997 년 ▁현대 ▁유니 콘 스는 ▁선수 층 ▁빈 약 함이 ▁여 실 히 ▁드러 나 게 ▁되어 ▁정 규 ▁시즌 ▁6 위를 ▁기록 한다 . #/s #/s ▁현대 ▁유니 콘 스는 ▁1997 년 ▁정 규 ▁시즌 ▁순 위에 ▁들 지 ▁못 한다 . #/s
,,,,


"시즌" (season) is versatile and can appear in both neutral and contradictory contexts, reducing its discriminative value.   
"위를" (6th place) is definitive thus strengthening its role in identifying relationships like contradiction or entailment.