In [1]:
%%capture
!pip install transformers

In [2]:
from transformers import BertTokenizer, BertForSequenceClassification

PRETRAINED_MODEL_NAME = "bert-base-chinese"
NUM_LABELS = 2

In [3]:
%%capture
#!unzip -o bestline-jc-JC-data.zip

RAW_DATA_PATH = 'JC_tbrain_train_final_0701.csv'    # T-Brain
DATA_PATH = 'bestline-jc-JC-data/JC/data'           # J.C.

In [4]:
import os
import pandas as pd
import codecs
from bs4 import BeautifulSoup

In [5]:
def load_data(crawled_data_path, original_data_path):
    
    raw_df = pd.read_csv(original_data_path) # Data provided by T-Brain

    news = []   # News crawled by J.C.
    labels = [] # Labels, AML related or not.

    for file in sorted(os.listdir(crawled_data_path)):
        # Get labels. Hint: Empty 'name' contains two characters '[]'.
        news_ID = int(file.split('_')[0])
        if len(raw_df.loc[news_ID-1, 'name']) > 2:
            labels.append(1)
        else:
            labels.append(0)

        # Get news content.
        f = codecs.open(DATA_PATH + '/' + file, 'r', 'utf-8')
        content = BeautifulSoup(f.read()).get_text()
        news.append(content)

    return news, labels

In [6]:
news, labels = load_data(crawled_data_path=DATA_PATH, original_data_path=RAW_DATA_PATH)

In [7]:
print(len(news[0][:512]), labels)

512 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,

In [8]:
import torch

In [10]:
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
model = BertForSequenceClassification.from_pretrained(PRETRAINED_MODEL_NAME, num_labels=NUM_LABELS)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [11]:
inputs = tokenizer(news[10][:500], return_tensors="pt")
labels = torch.tensor(labels[10]).unsqueeze(0)

outputs = model(**inputs, labels=labels)
loss, logits = outputs[:2]

m = torch.nn.Softmax(dim=1)
output = m(logits)
_, pred = torch.max(output.data, 1)

In [27]:
m = torch.nn.Softmax(dim=1)

correct = 0
acc = 0

for i, article in enumerate(news):
    input = tokenizer(article[:500], return_tensors="pt")
    label = torch.tensor(labels[i]).unsqueeze(0)
    outputs = model(**input, labels=label)
    loss, logits = outputs[:2]

    output
    break

tensor([0]) {'input_ids': tensor([[ 101, 7030, 1265,  769, 3211, 6841, 3724, 5179, 2205, 1841, 6992, 3300,
         3126, 2205, 2834, 4281, 4220, 2356, 6818, 2399,  889, 2832, 6536, 2356,
         1842, 3797, 1240, 6632,  889, 6632, 3209, 7549, 8024, 6841, 3724,  856,
         3797, 1240,  510, 5179, 2205, 1841, 6992, 4638, 7030, 1265,  769, 3211,
          991, 1358, 3800, 4680,  511, 2201, 2157, 6134, 4850, 8024, 2967, 4500,
         7030, 1265,  769, 3211, 5032, 4526, 2832, 6536, 1378, 5500, 8024,  679,
         5052, 3221, 5993, 3176, 1914, 7531, 2772, 3221, 4958, 7531, 2356, 1842,
         8024, 5245, 3126, 1350, 3797, 1240, 2428, 1772, 1377, 7526, 6651, 1920,
         4676, 8024, 4493, 5635, 3683, 1751, 1058, 2832, 6536, 1378, 5500, 4638,
         5500, 4873, 1798, 1825, 7032, 1350,  100, 4638, 3797, 1240, 4372, 6917,
          856, 8024, 6134, 4412,  738, 3291, 4158, 4952, 2137,  511, 1920, 3149,
         3087, 3229,  807,  889, 5631, 8024, 7591, 6121, 3627, 5401, 8145, 2399,
  

  This is separate from the ipykernel package so we can avoid doing imports until


In [23]:
inputs

{'input_ids': tensor([[ 101, 2456, 6863, 3511, 6359, 3298, 2900, 8024, 8271, 5635, 9160, 6512,
         3124, 2399, 2428, 2456, 6863, 2339, 4923, 2130, 2768, 7030, 2902, 2399,
          678, 6649,  129,  119,  125,  110,  511,  712, 2375, 7376, 2157, 7689,
         6134, 4850, 8024, 3315, 3949, 3146, 7768, 2456, 6863, 2339, 4923, 7030,
         7444, 6206, 3680, 2399, 5204, 2898, 8202, 5635, 8720, 1023, 1039, 8024,
         2902, 1333,  889, 3124, 2424, 3177, 2339, 3229, 7279, 6134, 6868, 2428,
         8024, 3313,  889, 8108, 2399, 2456, 6863, 3511, 2200, 6210, 5646, 3200,
         8024, 2668,  128, 3299,  122, 3189, 4989, 3791, 3298, 1358, 4788, 1889,
         2527, 8024, 5147, 8389, 1023, 1039, 4638, 2339, 1243, 2339, 4923, 3060,
         3621, 2182, 2821, 1358, 7349,  511, 2902, 2218, 3511, 3149, 3087,  510,
         2832, 3560, 1019, 3419, 1350, 3511, 4518, 2692, 6210,  868, 1347, 5440,
         8024, 1920, 2157, 2205, 6121, 3511, 1184, 3250, 4685, 4534, 2726, 2719,
         8024,