In [2]:
import pickle as pkl
from functools import partial

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import BertTokenizer

from parallel_processor import process_data

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Read train.txt into a list with one entry per line, newline characters removed.
with open('datasets/dbpedia/train.txt', 'r') as fin:
    data = [line.replace('\n', '') for line in fin]

In [4]:
# Parse train_labels.txt into a list of ints, one per line of the file.
with open('datasets/dbpedia/train_labels.txt', 'r') as fin:
    label = [int(line.replace('\n', '')) for line in fin]

In [5]:
# torch.load's second positional parameter is `map_location`, not a file mode,
# so the stray 'rb' argument is dropped; torch.load opens the path itself.
dic = torch.load('datasets/dbpedia/category_vocab.pt')

In [6]:
# Load the pretrained 'bert-base-uncased' tokenizer.
# NOTE(review): `tker` is not referenced by any other cell shown here — confirm it is still needed.
tker = BertTokenizer.from_pretrained('bert-base-uncased')

In [7]:
# Rebind `dic` to the object unpickled from cate_vocab_with_weight.pkl.
# NOTE(review): this is an absolute, machine-specific path while the other dataset
# cells use paths relative to the working directory — confirm and make it consistent.
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open('/hy-tmp/weakly-text-classification/datasets/dbpedia/cate_vocab_with_weight.pkl', 'rb') as fin:
    dic = pkl.load(fin)

In [58]:
# Rebind `dic` to the object unpickled from dbpedia_lotclass_dic.pkl, replacing
# whatever an earlier cell stored under the same name.
# NOTE(review): which vocabulary `dic` holds depends on which of these loading
# cells ran last — consider distinct variable names per vocabulary.
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open('dbpedia_lotclass_dic.pkl', 'rb') as fin:
    dic = pkl.load(fin)

In [52]:
def pseudo_label(x, vocab):
    """Score one document against every class of a weighted vocabulary.

    Args:
        x: the document; it only has to support ``entry[0] in x``.
        vocab: mapping from class key to an iterable of entries, where
            ``entry[0]`` is the word and ``entry[1]`` is its numeric weight.

    Returns:
        A list with one score per class, in ``vocab`` iteration order. A
        class score is the sum of squared weights of its matching words
        divided by the square root of the number of matches, or 0 when no
        word of that class matches.
    """
    class_scores = []
    for entries in vocab.values():
        # Squared weights of the words of this class that occur in the document.
        matched = [entry[1] ** 2 for entry in entries if entry[0] in x]
        if matched:
            class_scores.append(sum(matched) / (len(matched) ** 0.5))
        else:
            class_scores.append(0)
    return class_scores

In [46]:
# Spot check: position of the highest-scoring class for a single document.
# NOTE(review): the two-argument pseudo_label defined in this notebook already
# returns a list, so the list() wrapper looks redundant — confirm before removing.
np.argmax(list(pseudo_label(data[-100000], dic)))

11

In [38]:
# Display one raw document (third from the end of `data`) for manual inspection.
data[-3]

"The Blithedale Romance. The Blithedale Romance ( 1852 ) is Nathaniel Hawthorne's third major romance. In Hawthorne ( 1879 ) Henry James called it the lightest the brightest the liveliest of Hawthorne's unhumorous fictions."

In [59]:
def pseudo_label_batch(data, vocab):
    """Assign a pseudo-label to every document by counting vocabulary hits.

    Args:
        data: iterable of documents. Each document only has to support
            ``word in doc``; for plain strings that is substring containment.
        vocab: mapping from class key to an iterable of words.

    Returns:
        A list with one entry per document: the position, in ``vocab``
        iteration order, of the class with the most matching words (the
        value returned by ``np.argmax``). Ties go to the earliest class, so a
        document with no matches at all is assigned position 0.
    """
    def pseudo_label(x):
        # One hit count per class, in vocab iteration order.
        return [sum(1 for w in words if w in x) for words in vocab.values()]

    return [np.argmax(pseudo_label(doc)) for doc in tqdm(data)]

# Bind the vocabulary up front so the callable handed to process_data takes
# only the documents as its remaining argument.
pseudo_label_batch_p = partial(pseudo_label_batch, vocab=dic)

# NOTE(review): process_data comes from the project-local parallel_processor
# module; presumably it splits `data` across `num_workers` workers and
# concatenates the per-chunk results in order — confirm, since the accuracy
# cells compare `plabel` with `label` position by position.
# NOTE(review): num_workers=60 is hard-coded for one machine.
plabel = process_data(data, pseudo_label_batch_p, num_workers=60)

100%|██████████| 9334/9334 [00:06<00:00, 1347.85it/s]
100%|██████████| 9334/9334 [00:08<00:00, 1145.80it/s]
100%|██████████| 9334/9334 [00:08<00:00, 1163.35it/s]
100%|██████████| 9334/9334 [00:08<00:00, 1117.91it/s]
100%|██████████| 9334/9334 [00:07<00:00, 1246.50it/s]
100%|██████████| 9334/9334 [00:07<00:00, 1262.28it/s]
100%|██████████| 9334/9334 [00:07<00:00, 1224.15it/s]
100%|██████████| 9334/9334 [00:07<00:00, 1219.11it/s]
100%|██████████| 9334/9334 [00:08<00:00, 1163.55it/s]
100%|██████████| 9334/9334 [00:08<00:00, 1119.32it/s]
100%|██████████| 9334/9334 [00:08<00:00, 1147.79it/s]
100%|██████████| 9334/9334 [00:08<00:00, 1112.20it/s]
100%|██████████| 9334/9334 [00:08<00:00, 1077.12it/s]
100%|██████████| 9334/9334 [00:09<00:00, 1034.31it/s]
100%|██████████| 9334/9334 [00:08<00:00, 1058.02it/s]
100%|██████████| 9334/9334 [00:08<00:00, 1102.77it/s]
100%|██████████| 9334/9334 [00:09<00:00, 1008.23it/s]
100%|██████████| 9334/9334 [00:09<00:00, 970.94it/s]]
100%|██████████| 9334/9334 [

In [60]:
# Distinct values present in `plabel`, i.e. which class positions were predicted at least once.
set(plabel)

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}

In [61]:
# Convert both sequences to arrays so the comparison below is element-wise.
plabel = np.array(plabel)
label = np.array(label)
# Guard: with mismatched shapes `plabel == label` is not a per-document
# comparison and the mean below would be meaningless.
assert plabel.shape == label.shape, "plabel and label must have the same shape"
# Fraction of documents whose pseudo-label equals the reference label.
np.mean(plabel == label)

0.27028214285714286

In [56]:
# Per-class agreement: among the documents whose reference label is i,
# the fraction whose pseudo-label is also i.
for i in tqdm(range(len(set(label)))):
    print(i, np.mean(plabel[label == i] == i))

100%|██████████| 14/14 [00:00<00:00, 1569.01it/s]

0 0.645
1 0.927375
2 0.194875
3 0.432175
4 0.337425
5 0.013
6 0.59065
7 0.9355
8 0.7808
9 0.002875
10 0.978925
11 0.327875
12 0.706475
13 0.906475





In [62]:
# Per-class agreement: among the documents whose reference label is i,
# the fraction whose pseudo-label is also i.
for i in tqdm(range(len(set(label)))):
    print(i, np.mean(plabel[label == i] == i))

100%|██████████| 14/14 [00:00<00:00, 1046.95it/s]

0 0.254475
1 0.64595
2 0.1094
3 0.041125
4 0.16575
5 0.02585
6 0.095375
7 0.402425
8 0.975575
9 0.008925
10 0.377875
11 0.323325
12 0.214625
13 0.143275



