In [1]:
from model.DecisionTree import DecisionTree
from model.DNN import DNN
from model.NaiveBayes import NaiveBayes
from model.SVM import SVM
import numpy as np
import jieba
from transformers import BertTokenizer, BertForSequenceClassification
import torch


def get_text_encoding(texts, dictionary):
    # texts: [text1, text2, ...]
    # dictionary: [word1, word2, ....]

    encoding_size = len(dictionary)
    text_encoding = np.zeros((len(texts), encoding_size))
    for textId in range(len(texts)):
        words = set(jieba.cut(texts[textId]))
        for word in words:
            for dicId  in range(encoding_size):
                if word == dictionary[dicId]:
                    text_encoding[textId, dicId] = 1
                    break
            
    return text_encoding


dictionary = []
with open('dictionary.txt', 'r') as file:
    for line in file:
        dictionary.append(line.strip())
print(dictionary)


decision_tree = DecisionTree()
naive_bayes = NaiveBayes()
dnn = DNN()

decision_tree.load('param/decision_tree_model.pth')
naive_bayes.load('param/naive_bayes_model.pth')
dnn.load('param/dnn_model.pth')

model_name = './bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
# model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)
model = BertForSequenceClassification.from_pretrained('./bert-param')


def predict(text):
  
    encoding = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    output = model(**encoding)
    predicted_label = torch.argmax(output.logits, dim=1).item()

    text_encoding = get_text_encoding([text], dictionary)
    return [decision_tree.predict(text_encoding)[0], 
            naive_bayes.predict(text_encoding)[0], 
            dnn.predict(text_encoding)[0], predicted_label]

print(predict(''))

['嘻嘻', '抓狂', '鼓掌', '回复', '偷笑', '一个', 'cn', 'http', '可爱', '开心', '北京', '转发', '中国', '喜欢', '谢谢', '真的', '...', '朋友', '美食', '呵呵', '威武', 'good', '老师', '围观', '旅游', '馋嘴', '酒店', '时间', '哈哈哈', '生活', '终于', '感谢', '希望', '失望', '微博', '孩子', '不错', '吃惊', '幸福', '亲亲', '期待', '支持', '悲伤', '世界', '发现', '可怜', '感觉', '花心', '快乐', '活动', '上海', '关注', '抱抱', '回来', '很多', '好吃', '同学', '伤心', '照片', '回家', '居然', '妈妈', '旅行', '小时', '鼻屎', '感动', '地方', '工作', '东西', '怒骂', '分享', '推荐', '手机', '思考', '加油', '餐厅', '两个', '特别', '只能', '蜡烛', '官方', '害羞', '恭喜', '记得', '机会', '味道', '电影', '童鞋', '超级', '美女', '兔子', '赶紧', '咖啡', '公司', '睡觉', '委屈', '真心', '人生', '心情', '蛋糕', '礼物', '男人', '好好', '话筒', '美丽', '新浪', '全球', '告诉', '生病', '女人', '节目', '努力', '不到', '吃货', '参加', '可惜', '摄影', '好像', '亲们', '看着', '实在', '吃饭', '不了', '现场', '估计', '确实', '鄙视', '竟然', '路上', '还要', '亲爱', '几天', '永远', '围脖', '一点', '小伙伴', '飞机', '电话', '辛苦', '神马', '太阳', '一句', '姐姐', '几个', '到底', '第一次', '演员', '不用', '美好', '上班', '国际', '一种', '阳光', '享受', '肯定', '小姐', '创意', '精彩', '时尚', '第一', '姑娘', '学习', '事情', '漂亮', '不想', '

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\28495\AppData\Local\Temp\jieba.cache
Loading model cost 0.351 seconds.
Prefix dict has been built successfully.


[0, 0, tensor(1), 1]


  prob = F.softmax(out)
