In [74]:
from sentence_transformers import SentenceTransformer, util
from sentence_transformers import models
from torch import nn

word_embedding_model = models.Transformer('nghuyong/ernie-1.0')
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
# dense_model = models.Dense(
#     in_features=pooling_model.get_sentence_embedding_dimension(), 
#     out_features=512,
#     activation_function=nn.Tanh(),
#     bias=True)

model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
# model = SentenceTransformer('nghuyong/ernie-1.0')

Some weights of the model checkpoint at nghuyong/ernie-1.0 were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [75]:
import pandas as pd
import json
wiki_info = json.load(open('data/cleaned_wiki.json'))
train_csv = pd.read_csv('data/train.csv')
test_csv = pd.read_csv('data/val.csv')

In [76]:
def preprocess(data, wiki_info): # use riddle and the explanation of each choice to compute the similarity
    questions = []
    contexts = []
    choices = []
    labels = []
    for idx, row in data.iterrows():
        questions.append(f'{row["riddle"]}')
        labels.append(int(row['label']))
        context = []
        choice = []
        for i in range(5):
            name = f'choice{i}'
            explanation = wiki_info.get(row[name], '')
            choice.append(row[name])
            context.append(explanation)
        contexts.append(context)
        choices.append(choice)
    return questions, contexts, choices, labels

questions, contexts, choices, labels = preprocess(train_csv, wiki_info)

In [77]:
# training
from sentence_transformers import InputExample, losses, evaluation
from torch.utils.data import DataLoader

train_examples = []
valid_examples = []
train_size = int(len(questions) * 0.9)
eval_size = len(questions) - train_size

# train dataset
for i in range(train_size):
   question, context, label = questions[i], contexts[i], labels[i]
   for idx, text in enumerate(context):
      if(label == idx):
         train_examples.append(InputExample(texts=[question, text], label=1))
         # increase ratio of positive example
         # train_examples.append(InputExample(texts=[question, text], label=1))
         # train_examples.append(InputExample(texts=[question, text], label=1))
         # train_examples.append(InputExample(texts=[question, text], label=1))
      else:
         train_examples.append(InputExample(texts=[question, text], label=0))

# valid dataset
sentences1 = []
sentences2 = []
scores = []
for i in range(train_size, len(questions)):
   question, context, label = questions[i], contexts[i], labels[i]
   for idx, text in enumerate(context):
      sentences1.append(question)
      sentences2.append(text)
      scores.append(label == idx)
evaluator = evaluation.EmbeddingSimilarityEvaluator(sentences1, sentences2, scores)

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
train_loss = losses.ContrastiveLoss(model, margin=0.50)

In [78]:
model.fit(train_objectives=[(train_dataloader, train_loss)], 
            epochs=1, 
            warmup_steps=30, 
            optimizer_params={'lr': 4e-5}, 
            evaluator=evaluator, 
            evaluation_steps=200)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/563 [00:00<?, ?it/s]

In [79]:
questions, contexts, choices, labels = preprocess(test_csv, wiki_info)

In [80]:
import re
def charnum_to_num(charnum):
    lens={
            '一字':1,
            '二字':2,
            '三字':3,
            '四字':4,
            '五字':5,
            '六字':6,
            '七字':7,
            '八字':8,
            '九字':9,
            '十字':10,
            }
    return lens.get(charnum,None)
def pre_select(quiz,options=None): #str,str[5]，谜面和选项，返回bool[5],bool=true代表选项筛选后可能对
    poss=[True,True,True,True,True]#谜底都可能正确
    #按字数筛选
    charnum=re.findall('（.*?([一二三四五六七八九]字).*?）',quiz)#返回一个list，因为在括号内，所以不包含谜面的x字，只包含谜底字数
    if(len(charnum)!=0):#有关于谜底字数的描述
        num=charnum_to_num(charnum[0]) #谜底长度 注意不包括标点符号
        for i in range(5):
            #将options的标点都去掉，不占字数
            tmp_option=options[i].replace('，','')
            if(len(tmp_option)!=num):
                poss[i]=False
    return poss

In [81]:
# evaluate
tot = 0
tp = 0
for question, context, label in zip(questions, contexts, labels):
    pred = 0
    val = 0
    poss = pre_select(question, choices[tot])
    for idx, text in enumerate(context):
        if poss[idx] == False:
            continue
        embeddings1 = model.encode(question, convert_to_tensor=True)
        embeddings2 = model.encode(text, convert_to_tensor=True)
        score = util.pytorch_cos_sim(embeddings1, embeddings2)
        if score > val:
            pred = idx
            val = score
    tp += (pred == label)
    tot += 1
    if tot % 20 == 0:
        print(f'{tot}: tp={tp}')


20: tp=12
40: tp=22
60: tp=33
80: tp=47
100: tp=62
120: tp=77
140: tp=88
160: tp=101
180: tp=110
200: tp=121
220: tp=129
240: tp=142
260: tp=155
280: tp=169
300: tp=183
320: tp=196
340: tp=209
360: tp=220
380: tp=230
400: tp=241
420: tp=254
440: tp=267
460: tp=277
480: tp=292
500: tp=306


In [82]:
tp, tp / len(questions)

(306, 0.612)