In [1]:
import torch
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

# Number of token rows (rows 1..num of a model's hidden-state matrix) that
# count_avg averages. The same value selects the input file
# ./data/ci_{num}_random.txt and is embedded in the output spreadsheet name.
num = 3

def count_avg(output, n=None):
    """Return the mean of rows 1..n of ``output`` (row 0 is skipped).

    The previous implementation accumulated into ``output[1, :]``, which is a
    view, so the caller's tensor had its row 1 overwritten as a side effect.
    This version builds a new tensor and leaves ``output`` untouched.

    Args:
        output: 2-D tensor indexed as ``output[row, :]``. The caller in this
            file passes ``last_hidden_state.squeeze(0)`` -- presumably
            (sequence_length, hidden_size); confirm against the models used.
        n: How many rows to average, starting at row 1. When omitted, the
            module-level ``num`` is used, which keeps existing one-argument
            calls working unchanged.

    Returns:
        A 1-D tensor holding the element-wise mean of rows 1..n.

    Raises:
        ValueError: If ``n`` is smaller than 1.
        IndexError: If ``output`` has fewer than ``n + 1`` rows.
    """
    if n is None:
        n = num
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    # Slicing would silently truncate on short input; keep the IndexError the
    # original raised through explicit row indexing.
    if output.shape[0] < n + 1:
        raise IndexError(
            f"need at least {n + 1} rows to average rows 1..{n}, "
            f"got {output.shape[0]}"
        )

    # sum() over the slice allocates a fresh tensor, so no in-place write
    # reaches the caller's data.
    return output[1:n + 1, :].sum(0) / n

# For every line of the input file, compare row 0 of each encoder's last hidden
# state (the "CLS" of the output file name) with the average returned by
# count_avg, and export one cosine similarity per encoder to a spreadsheet.

# Single place to change the target device (was a literal at every call site).
device = "cuda:1"

# Column order of the exported spreadsheet.
columns = ['Token', 'bert-base-chinese', 'bert-chinese-wwm', 'roberta-chinese-wwm', 'simbert-chinese-base']

# Read the inputs inside a context manager so the file handle is closed.
with open(f"./data/ci_{num}_random.txt", "r", encoding="utf-8") as input_file:
    input_text = [line.strip() for line in input_file]

bert_model_path = "../2.SememeV2/pretrained_model/bert-base-chinese"
simbert_model_path = "../2.SememeV2/pretrained_model/simbert-chinese-base"
bert_wwm_model_path = "../2.SememeV2/pretrained_model/chinese-bert-wwm-ext"
roberta_wwm_model_path = "../2.SememeV2/pretrained_model/chinese-roberta-wwm-ext"

bert_tkn = AutoTokenizer.from_pretrained(bert_model_path)
simbert_tkn = AutoTokenizer.from_pretrained(simbert_model_path)
bert_wwm_tkn = AutoTokenizer.from_pretrained(bert_wwm_model_path)
roberta_wwm_tkn = AutoTokenizer.from_pretrained(roberta_wwm_model_path)

bert = AutoModel.from_pretrained(bert_model_path).to(device)
simbert = AutoModel.from_pretrained(simbert_model_path).to(device)
bert_wwm = AutoModel.from_pretrained(bert_wwm_model_path).to(device)
roberta_wwm = AutoModel.from_pretrained(roberta_wwm_model_path).to(device)

# Spreadsheet column -> (tokenizer, model). Driving the loop from this mapping
# replaces four copy-pasted encode/forward/similarity stanzas.
encoders = {
    'bert-base-chinese': (bert_tkn, bert),
    'bert-chinese-wwm': (bert_wwm_tkn, bert_wwm),
    'roberta-chinese-wwm': (roberta_wwm_tkn, roberta_wwm),
    'simbert-chinese-base': (simbert_tkn, simbert),
}


def _cls_vs_avg_similarity(tokenizer, model, text):
    """Cosine similarity between count_avg(hidden) and hidden[0] for ``text``.

    ``hidden`` is the model's ``last_hidden_state`` with the batch dimension
    squeezed away. Returns a Python float.
    """
    encoded = {k: v.to(device) for k, v in tokenizer(text, return_tensors="pt").items()}
    hidden = model(**encoded).last_hidden_state.squeeze(0)
    return torch.cosine_similarity(
        count_avg(hidden), hidden[0, :],
        dim=0, eps=1e-08
    ).item()


# Collect plain dict rows and build the frame once: DataFrame.append copied the
# whole frame on every iteration and no longer exists in pandas >= 2.0.
records = []

with torch.no_grad():
    for char in tqdm(input_text):
        # A blank line tokenizes to special tokens only, which leaves too few
        # rows for count_avg; skip it instead of aborting the whole run.
        if not char:
            continue

        row = {'Token': char}
        for column, (tokenizer, model) in encoders.items():
            row[column] = _cls_vs_avg_similarity(tokenizer, model, char)
        records.append(row)

df = pd.DataFrame(records, columns=columns)

df.to_excel(f"./CLS与avg(TOK[{num}])_random.xlsx")

Some weights of the model checkpoint at ../2.SememeV2/pretrained_model/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at ../2.SememeV2/pretrained_model/simbert-chinese-base were not used when initializing 

KeyboardInterrupt: 