In [9]:
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer('nghuyong/ernie-1.0')

Some weights of the model checkpoint at /data00/yuzihao.2001/.cache/torch/sentence_transformers/nghuyong_ernie-1.0 were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
import pandas as pd
import json
wiki_info = json.load(open('data/wiki_info_v2.json'))
train_csv = pd.read_csv('data/train.csv')
test_csv = pd.read_csv('data/val.csv')

In [11]:
def preprocess(data, wiki_info): # use riddle and the explanation of each choice to compute the similarity
    questions = []
    contexts = []
    labels = []
    for idx, row in data.iterrows():
        questions.append(f'{row["riddle"]}')
        labels.append(int(row['label']))
        context = []
        for i in range(5):
            name = f'choice{i}'
            explanation = wiki_info.get(row[name], '')
            context.append(explanation)
        contexts.append(context)
    return questions, contexts, labels

questions, contexts, labels = preprocess(train_csv, wiki_info)

In [12]:
# training
from sentence_transformers import InputExample, losses, evaluation
from torch.utils.data import DataLoader

train_examples = []
valid_examples = []
train_size = int(len(questions) * 0.8)
eval_size = len(questions) - train_size

# train dataset
for i in range(train_size):
   question, context, label = questions[i], contexts[i], labels[i]
   for idx, text in enumerate(context):
      if(label == idx):
         train_examples.append(InputExample(texts=[question, context], label=1))
         # increase ratio of positive example
         train_examples.append(InputExample(texts=[question, context], label=1))
         train_examples.append(InputExample(texts=[question, context], label=1))
         train_examples.append(InputExample(texts=[question, context], label=1))

      else:
         train_examples.append(InputExample(texts=[question, context], label=0))

# valid dataset
sentences1 = []
sentences2 = []
scores = []
for i in range(train_size, len(questions)):
   question, context, label = questions[i], contexts[i], labels[i]
   for idx, text in enumerate(context):
      sentences1.append(question)
      sentences2.append(text)
      scores.append(label == idx)
evaluator = evaluation.EmbeddingSimilarityEvaluator(sentences1, sentences2, scores)

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
train_loss = losses.ContrastiveLoss(model, margin=0.1)

In [13]:
model.fit(train_objectives=[(train_dataloader, train_loss)], 
            epochs=1, 
            warmup_steps=100, 
            optimizer_params={'lr': 5e-6}, 
            evaluator=evaluator, 
            evaluation_steps=100)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/800 [00:00<?, ?it/s]

In [14]:
questions, contexts, labels = preprocess(test_csv, wiki_info)

In [15]:
# evaluate
tp = 0
for question, context, label in zip(questions, contexts, labels):
    pred = 0
    val = 0
    for idx, text in enumerate(context):
        embeddings1 = model.encode(question, convert_to_tensor=True)
        embeddings2 = model.encode(text, convert_to_tensor=True)
        score = util.pytorch_cos_sim(embeddings1, embeddings2)
        if score > val:
            pred = idx
            val = score
    tp += (pred == label)


In [16]:
tp, tp / len(questions)

(190, 0.38)