In [1]:

from transformers import BertTokenizerFast
from transformers import AutoModelForSequenceClassification
import torch
import label_loader

tokenizer = BertTokenizerFast.from_pretrained('../bert-base-chinese')

device = torch.device("cuda")

model = AutoModelForSequenceClassification.from_pretrained("./channel-classifier-man")
model.to(device)
# man_model = AutoModelForSequenceClassification.from_pretrained("./channel-classifier-man")
# man_model.to(device)
# classifier = AutoModelForSequenceClassification.from_pretrained("../classification-yue/yue-classifier-can")
# classifier.to(device)
pass

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
batch_size = 40  # took about 6GB of GPU memory
label_idx = label_loader.load_label_index('labels.tsv')
idx_label = {v:k for k,v in label_idx.items()}

In [3]:
from torch.utils.data import DataLoader
from tqdm import tqdm
def text2token(txt):
  d = tokenizer(txt,padding="max_length", truncation=True,return_tensors='pt',max_length=64)
  d.update((k,v[0]) for k,v in d.items())
  return d
def batchPredict(batch_text : list,model=model):
    ret = []
    tokens = [text2token(t) for t in batch_text]
    d = DataLoader(tokens,batch_size=batch_size,shuffle=False)
    datasets = tqdm(d)
    for batch in datasets:
        inputs = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        outputs = model(input_ids=inputs, attention_mask=mask)
        logits = outputs.logits
        ret= [*ret,*logits.tolist()]
    return ret

def predict(text,model=model):
    token = tokenizer(text,return_tensors='pt')
    inputs = token['input_ids'].to(device)
    mask = token['attention_mask'].to(device)
    outputs = model(input_ids=inputs, attention_mask=mask)
    logits = outputs.logits
    return logits.tolist()

# def classify(batch_text):
#     print("classifing")
#     ret = []
#     tokens = [text2token(t) for t in batch_text]
#     d = DataLoader(tokens,batch_size=batch_size,shuffle=False)
#     datasets = tqdm(d)
#     datasets.set_description("cantonese classification")
#     for batch in datasets:
#         inputs = batch['input_ids'].to(device)
#         mask = batch['attention_mask'].to(device)
#         outputs = classifier(input_ids=inputs, attention_mask=mask)
#         logits = outputs.logits
#         ret= [*ret,*logits.argmax(-1).tolist()]
#     return ret

# def batchPredictByPrediction(batch_text: list,specific_classificaiton : int,spec_model):
#     p_classification = classify(batch_text)
#     filtered_batch = []
#     for pc,t in zip(p_classification,batch_text):
#         if pc == specific_classificaiton:
#             filtered_batch.append(t)
#     if specific_classificaiton == 1:
#         return batchPredict(filtered_batch, spec_model)
#     else:
#         return batchPredict(filtered_batch, spec_model)

In [5]:
def sort_then_output_labels(labels):
  dct = {}
  ret = []
  for label in labels:
    for name,val in zip(label_idx.keys(),label):
      dct[name] = val
    ret.append(list(reversed(sorted(dct.items(),key=lambda x:x[1]))))
  return ret

In [11]:
text = "上ifc要用安心出行嗎？"
labels = sort_then_output_labels (predict(text))
torch.cuda.empty_cache()
labels[0][:3]

[('手機台', 8.261186599731445),
 ('時事台', 5.150082111358643),
 ('Apps台', 3.9472124576568604)]

In [6]:
torch.cuda.empty_cache()
p_val = sort_then_output_labels(batchPredict(batch_text))
p_labels = [(p[0][0],p[1][0],p[2][0]) for p in p_val]

100%|██████████| 1/1 [00:01<00:00,  1.46s/it]


In [None]:
torch.cuda.empty_cache()
# single 
cnt = 0
correct = 0
with open('out_with_labels.tsv',encoding='utf-8') as f:
  lines = f.readlines()
for l in lines:
  grp = str(l).strip().split('\t')
  p_label = sort_then_output_labels(predict(grp[1]))[0][0]
  cnt += 1
  correct += int(p_label == grp[0])
print(f"accuracy: {(correct / cnt) * 100:2f}%")

In [5]:
with open('out_with_labels_ori.tsv',encoding='utf-8') as f:
  lines = [str(l).strip().split('\t') for l in f.readlines()]

In [6]:
# batch
torch.cuda.empty_cache()
p_val = sort_then_output_labels(batchPredict([grp[1] for grp in lines[300000:310000]]))
eval_label = [grp[0] for grp in lines]
p_labels = [(p[0][0],p[1][0],p[2][0]) for p in p_val]

100%|██████████| 250/250 [00:44<00:00,  5.66it/s]


In [7]:
cnt = 0
correct = 0
grp_cnt = {}
grp_correct = {}
for e,p in zip(eval_label,p_labels):
  if e in p:
    correct += 1
    grp_correct[e] = grp_correct.get(e,0) + 1
  grp_cnt[e] = grp_cnt.get(e,0) + 1
  cnt += 1
print(f"accuracy in general: {(correct / cnt) * 100:2f}%")
for label in grp_cnt.keys():
  print(f"accuracy of {label}: {grp_correct.get(label,0) / grp_cnt[label] * 100 :2f} count:{grp_cnt[label]}")

accuracy in general: 17.630000%
accuracy of 電訊台: 6.666667 count:45
accuracy of 動漫台: 7.344633 count:177
accuracy of 吹水台: 48.101266 count:1027
accuracy of 時事台: 25.134553 count:1858
accuracy of 財經台: 11.725453 count:1049
accuracy of 娛樂台: 16.297609 count:1129
accuracy of 體育台: 16.401590 count:1006
accuracy of 遊戲台: 10.240964 count:166
accuracy of 汽車台: 2.631579 count:114
accuracy of 校園台: 4.237288 count:118
accuracy of 感情台: 11.724138 count:580
accuracy of World: 3.389831 count:59
accuracy of 創意台: 15.151515 count:330
accuracy of 影視台: 7.427056 count:377
accuracy of 音樂台: 4.878049 count:246
accuracy of 學術台: 14.159292 count:226
accuracy of 上班台: 6.907895 count:304
accuracy of 手機台: 0.943396 count:106
accuracy of 寵物台: 3.389831 count:59
accuracy of 飲食台: 8.013937 count:287
accuracy of 玩具台: 4.761905 count:21
accuracy of 健康台: 11.278195 count:133
accuracy of 政事台: 9.523810 count:42
accuracy of 房屋台: 8.270677 count:133
accuracy of 站務台: 2.702703 count:37
accuracy of 攝影台: 6.666667 count:30
accuracy of 潮流台: 1.785