In [1]:
import pandas as pd
import pickle
import numpy as np

# Paths of the train / test CSV files; both are read without a header row.
p_train = r'../data/ag_train.csv'
p_test = r'../data/ag_test.csv'
df_train, df_test = (pd.read_csv(csv_path, header=None) for csv_path in (p_train, p_test))
df_train.columns = df_test.columns = ['labels', 'title', 'data']

# Quick sanity check: frame sizes and per-class row counts.
print(df_train.shape, df_test.shape)
print(df_train['labels'].value_counts())
print(df_test['labels'].value_counts())

# Stratified subsample of the training set: draw 20% of every label group
# (24,000 of the 120,000 rows) with a fixed seed so the draw is repeatable.
sample_df = (
    df_train
    .groupby('labels')
    .apply(lambda group: group.sample(frac=0.2, random_state=1))
)
print(sample_df.head())

# Shuffle the per-label chunks together (seeded) and drop the group index.
sample_df = sample_df.sample(frac=1, random_state=1)
sample_df = sample_df.reset_index(drop=True)
print(sample_df.shape)
print(sample_df['labels'].value_counts())

# Persist the subsample; from here on df_train refers to the subsample only.
sample_df.to_csv('../data/ag_train_sample.csv')
df_train = sample_df

df_train.head()

(120000, 3) (7600, 3)
3    30000
4    30000
2    30000
1    30000
Name: labels, dtype: int64
3    1900
4    1900
2    1900
1    1900
Name: labels, dtype: int64
               labels                                              title  \
labels                                                                     
1      40546        1      Haitians Pray for 1,500 Killed by Jeanne (AP)   
       48245        1      U.S.-Led Forces Tighten Grip, Draw Complaints   
       118691       1  U.N. Says Bugging Device Found at Its Geneva H...   
       33489        1            Kerry Questions Bush's Judgment on Iraq   
       83190        1  Seoul Asks Bush to Focus on N.Korea Nuclear Cr...   

                                                            data  
labels                                                            
1      40546   AP - In a cathedral ankle-deep in mud and over...  
       48245    SAMARRA, Iraq (Reuters) - U.S.-led forces tig...  
       118691  Reuters - The United Nat

Unnamed: 0,labels,title,data
0,2,Leafs veteran Alex Mogilny undergoes another h...,TORONTO (CP) - Winger Alexander Mogilny of the...
1,3,Borders Posts Third Quarter Loss (Reuters),Reuters - Book retailer Borders Group Inc.\ on...
2,2,Britannia rules as baton again lets US down,SINCE the United States started taking part in...
3,1,Italy Calls To End Kyoto Climate Limits After ...,ROME - Italy has called for an end to the Kyot...
4,1,Singapore warns of deadly illness,The authorities in Singapore voice concern abo...


In [2]:
# Join title and body into a single text field per article.
df_train['raw'] = df_train['title'] + ' ' + df_train['data']
df_test['raw'] = df_test['title'] + ' ' + df_test['data']

# The CSV labels run 1..4, so they must be shifted to 0..3.
# The shift is guarded so that re-running this cell does not subtract a second
# time: once shifted, the smallest label is 0 and the condition is False.
# (Relies on label 1 being present in each frame, as the value_counts show.)
for _frame in (df_train, df_test):
    if _frame['labels'].min() >= 1:
        _frame['labels'] = _frame['labels'] - 1

df_train

Unnamed: 0,labels,title,data,raw
0,1,Leafs veteran Alex Mogilny undergoes another h...,TORONTO (CP) - Winger Alexander Mogilny of the...,Leafs veteran Alex Mogilny undergoes another h...
1,2,Borders Posts Third Quarter Loss (Reuters),Reuters - Book retailer Borders Group Inc.\ on...,Borders Posts Third Quarter Loss (Reuters) Reu...
2,1,Britannia rules as baton again lets US down,SINCE the United States started taking part in...,Britannia rules as baton again lets US down SI...
3,0,Italy Calls To End Kyoto Climate Limits After ...,ROME - Italy has called for an end to the Kyot...,Italy Calls To End Kyoto Climate Limits After ...
4,0,Singapore warns of deadly illness,The authorities in Singapore voice concern abo...,Singapore warns of deadly illness The authorit...
...,...,...,...,...
23995,1,Angry Kick Puts Cubs' Farnsworth on DL (AP),AP - Chicago Cubs reliever Kyle Farnsworth too...,Angry Kick Puts Cubs' Farnsworth on DL (AP) AP...
23996,2,FASB delays new options expensing rule by 6 mo...,"Bowing to corporate pressure, the group that s...",FASB delays new options expensing rule by 6 mo...
23997,0,"Blasts, Gunfire Shake Najaf As Talks Drag","NAJAF, Iraq - Explosions and gunfire shook Naj...","Blasts, Gunfire Shake Najaf As Talks Drag NAJA..."
23998,2,"1,500 Imperial Oil jobs leaving Toronto",Imperial Oil Ltd. says it will shift its head ...,"1,500 Imperial Oil jobs leaving Toronto Imperi..."


# Tokenizing 미리 해서 저장해두기

- train dataset에서 vocab 추출하고
- 추출한 vocab, token2idx로 train_data 바꿔치기 하고
- train_data의 max 길이로 패딩 (BasicCollator 참조)
- train_data랑, train_labels 묶어서 ag_train.pkl 로 저장 (현재 코드에서는 ag_train_sample.pkl 로 저장)
- 마찬가지로 test dataset에 대해서도:
- 추출한 vocab, token2idx로 test_data 바꿔치기 하고
- train_data의 max 길이로 패딩 (BasicCollator 참조)
- test_data, test_labels 묶어서 ag_test.pkl 로 저장 (현재 코드에서는 ag_test_sample.pkl 로 저장)

In [3]:
import collections
from typing import List, Tuple, Dict
from transformers import BertTokenizer


def build_tok_vocab(tokenize_target: List,
                    tokenizer,
                    min_freq: int = 1,
                    max_vocab=19998) -> Tuple[List[str], Dict]:
    """Build a frequency-ranked vocabulary and a token -> index mapping.

    Args:
        tokenize_target: Iterable of texts; each one is passed to
            ``tokenizer.tokenize``.
        tokenizer: Any object exposing ``tokenize(text)`` that returns the
            tokens of ``text``.
        min_freq: Tokens seen fewer than this many times are dropped.
        max_vocab: Maximum number of regular (non-special) tokens to keep,
            most frequent first. ``None`` keeps every token.

    Returns:
        ``(vocab, tok2idx)`` where

        * ``tok2idx`` maps ``'<pad>'`` to 0, ``'<unk>'`` to 1 and the kept
          tokens to 2, 3, ... in descending frequency order (ties keep
          first-seen order).
        * ``vocab`` is the list of kept tokens in that same order, followed
          by ``'<pad>'`` and ``'<unk>'``. Note that positions in ``vocab``
          therefore do NOT equal the indices stored in ``tok2idx``.

    Progress messages are printed to stdout. Texts the tokenizer cannot
    handle are skipped (best effort), but the number of skipped texts and the
    first failure are reported instead of being dropped silently.
    """
    special_tokens = ('<pad>', '<unk>')
    counts = collections.Counter()
    failed = 0
    first_failure = None

    print('start tokenizing')
    for i, target in enumerate(tokenize_target):
        if i % 10000 == 0:
            print(i)
        try:
            tokens = tokenizer.tokenize(target)
        except Exception as e_msg:
            # Best effort: keep going, but remember what went wrong.
            failed += 1
            if first_failure is None:
                first_failure = f'idx: {i} \t target:{target} ({e_msg!r})'
            continue
        counts.update(tokens)
    if failed:
        print(f'skipped {failed} item(s) that could not be tokenized; first: {first_failure}')

    print('start counting')
    # A literal '<pad>' / '<unk>' produced by the tokenizer must not compete
    # with the reserved entries, otherwise it would overwrite index 0 / 1 and
    # make two tokens share one index.
    for special in special_tokens:
        counts.pop(special, None)
    # Drop tokens that occur fewer than min_freq times.
    kept = {tok: freq for tok, freq in counts.items() if freq >= min_freq}

    print('start sorting')
    # Most frequent first; the stable sort keeps first-seen order among ties.
    vocab = sorted(kept, key=lambda tok: -kept[tok])
    if max_vocab is not None and len(vocab) > max_vocab:
        vocab = vocab[:max_vocab]

    tok2idx = {'<pad>': 0, '<unk>': 1}
    for tok in vocab:
        tok2idx[tok] = len(tok2idx)
    vocab.extend(special_tokens)
    print('tokenizing done')

    return vocab, tok2idx

# Pretrained "bert-base-uncased" tokenizer; only its tokenize() method is
# needed to split text into tokens.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Vocabulary is built from every row of df_train's 'raw' column.
# max_vocab=19998 regular tokens + '<pad>' + '<unk>' gives up to 20000 entries.
data = df_train['raw'].tolist()
vocab_set, tok2idx = build_tok_vocab(
    data,
    tokenizer,
    min_freq=1,
    max_vocab=19998,
)
print(f'Vocab set size: {len(tok2idx)}')
print(vocab_set[0:5])

start tokenizing
0
10000
20000
start counting
start sorting
tokenizing done
Vocab set size: 20000
['.', 'the', ',', '-', 'to']


In [4]:
# Padded row width. NOTE(review): this is the longest *raw string* in
# characters, not the longest token sequence, so rows carry more padding than
# the token counts strictly need.
max_len = max((len(text) for text in data), default=0)
print(f'max length: {max_len}')

# Encode each training text as a fixed-width row of vocabulary indices:
# out-of-vocabulary tokens map to '<unk>', and rows are left-padded with 0.
tokenized_idx_data = []
for sentence in data:
    token_list = [
        tok2idx.get(word, tok2idx['<unk>'])
        for word in tokenizer.tokenize(sentence)
    ]
    padding_list = [0] * (max_len - len(token_list))
    tokenized_idx_data.append(padding_list + token_list)

print(len(tokenized_idx_data), len(tokenized_idx_data[0]))

max length: 996
24000 996


In [5]:
# Turn the padded index rows and the label column into NumPy arrays and bundle
# them as a single (features, labels) tuple.
train_tokenized_idx = np.array(tokenized_idx_data)
train_labels_np = np.array(df_train['labels'])
train_data = (train_tokenized_idx, train_labels_np)
print(train_tokenized_idx.shape, train_labels_np.shape, sep='\n')

print('now dumping pickle')
# Output goes to the "_sample" file name; the original cell also carried a
# disabled variant that wrote to 'ag_train.pkl'.
with open('ag_train_sample.pkl', 'wb') as f:
    pickle.dump(train_data, f)

(24000, 996)
(24000,)
now dumping pickle


In [8]:
# Same procedure for the test split, reusing the vocabulary (tok2idx) and the
# padded row width (max_len) that were derived from the training data.

test_data = df_test['raw'].tolist()

# Longest raw test string in characters (same unit the train-side max_len is
# measured in); default=0 keeps an empty test set from raising.
test_max_len = max((len(text) for text in test_data), default=0)
print(f'test max length: {test_max_len}')
if test_max_len > max_len:
    print('test max length is bigger than train_max_len')

pad_idx = tok2idx['<pad>']
unk_idx = tok2idx['<unk>']
tokenized_idx_test_data = []
n_truncated = 0

for sentence in test_data:
    # Out-of-vocabulary tokens fall back to the '<unk>' index.
    token_list = [tok2idx.get(word, unk_idx) for word in tokenizer.tokenize(sentence)]

    # A row longer than max_len would get no padding and stay longer than the
    # others, giving a ragged list that cannot become a 2-D array. Keep the
    # first max_len tokens so every row has exactly max_len entries.
    if len(token_list) > max_len:
        token_list = token_list[:max_len]
        n_truncated += 1

    # Left-pad to the fixed width.
    padding_list = [pad_idx] * (max_len - len(token_list))
    token_list = padding_list + token_list
    tokenized_idx_test_data.append(token_list)

if n_truncated:
    print(f'truncated {n_truncated} test sequence(s) to {max_len} tokens')

print(len(tokenized_idx_test_data), len(tokenized_idx_test_data[0]))

test max length: 892
7600 996


In [7]:
# Turn the padded test rows and the test label column into NumPy arrays and
# bundle them as a single (features, labels) tuple.
# Note: this rebinds `test_data`, which previously held the list of raw texts.
test_tokenized_idx = np.array(tokenized_idx_test_data)
test_labels_np = np.array(df_test['labels'])
test_data = (test_tokenized_idx, test_labels_np)
print(test_tokenized_idx.shape, test_labels_np.shape, sep='\n')

print('now dumping test pickle')
# Output goes to the "_sample" file name; the original cell also carried a
# disabled variant that wrote to 'ag_test.pkl'.
with open('ag_test_sample.pkl', 'wb') as f:
    pickle.dump(test_data, f)

(7600, 996)
(7600,)
now dumping test pickle
