In [56]:
# write in train.py
import sys, os
sys.path.append(os.pardir)

In [57]:
import tensorflow as tf
import json
import argparse

from data_utils import Data
from models.char_cnn_zhang import CharCNNZhang
from models.char_cnn_kim import CharCNNKim

In [61]:
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='char_cnn_zhang', help='Specifies which model to use: char_cnn_zhang or char_cnn_kim')
FLAGS = parser.parse_args(["--model", "char_cnn_zhang"])

In [62]:
FLAGS.model

'char_cnn_zhang'

In [59]:
# Load configurations
config = json.load(open('../config.json'))

In [33]:
for key, value in config.items():
    print(key, value)
    print()

notes default

data {'alphabet': 'abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:\'"/\\|_@#$%^&*~`+-=<>()[]{}', 'alphabet_size': 69, 'input_size': 1014, 'num_of_classes': 4, 'training_data_source': 'data/ag_news_csv/train.csv', 'validation_data_source': 'data/ag_news_csv/test.csv'}

model {'embedding_size': 128, 'conv_layers': [[256, 7, 3], [256, 7, 3], [256, 3, -1], [256, 3, -1], [256, 3, -1], [256, 3, 3]], 'fully_connected_layers': [1024, 1024], 'threshold': 1e-06, 'dropout_p': 0.5, 'optimizer': 'adam', 'loss': 'categorical_crossentropy'}

training {'epochs': 5000, 'batch_size': 128, 'evaluate_every': 100, 'checkpoint_every': 100}

char_cnn_zhang {'embedding_size': 128, 'conv_layers': [[256, 7, 3], [256, 7, 3], [256, 3, -1], [256, 3, -1], [256, 3, -1], [256, 3, 3]], 'fully_connected_layers': [1024, 1024], 'threshold': 1e-06, 'dropout_p': 0.5, 'optimizer': 'adam', 'loss': 'categorical_crossentropy'}

char_cnn_kim {'embedding_size': 128, 'conv_layers': [[256, 10], [256, 7], [256, 5], [256,

In [64]:
# See the data 
config['data']

{'alphabet': 'abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:\'"/\\|_@#$%^&*~`+-=<>()[]{}',
 'alphabet_size': 69,
 'input_size': 1014,
 'num_of_classes': 4,
 'training_data_source': 'data/ag_news_csv/train.csv',
 'validation_data_source': 'data/ag_news_csv/test.csv'}

In [65]:
# See the data 
config[FLAGS.model]['embedding_size']

128

In [66]:
model_name = config['model']
model_name

'char_cnn_zhang'

In [67]:
# change key from 'model' to 'char_cnn_zhang'
config['model'] = config[model_name]

In [68]:
config['model']

{'conv_layers': [[256, 7, 3],
  [256, 7, 3],
  [256, 3, -1],
  [256, 3, -1],
  [256, 3, -1],
  [256, 3, 3]],
 'dropout_p': 0.5,
 'embedding_size': 128,
 'fully_connected_layers': [1024, 1024],
 'loss': 'categorical_crossentropy',
 'optimizer': 'adam',
 'threshold': 1e-06}

In [70]:
# Set the data path in order to run in the notebook 
config['data']["training_data_source"] = '../data/ag_news_csv/train.csv'
config['data']["validation_data_source"] = '../data/ag_news_csv/test.csv'


In [37]:
# Load training data
training_data = Data(data_source=config["data"]["training_data_source"],
                     alphabet=config["data"]["alphabet"],
                     input_size=config["data"]["input_size"],
                     num_of_classes=config["data"]["num_of_classes"])
training_data.load_data()
training_inputs, training_labels = training_data.get_all_data()

Data loaded from ../data/ag_news_csv/train.csv


In [40]:
training_inputs.shape

(120000, 1014)

In [39]:
training_labels

array([[0, 0, 1, 0],
       [0, 0, 1, 0],
       [0, 0, 1, 0],
       ...,
       [0, 1, 0, 0],
       [0, 1, 0, 0],
       [0, 1, 0, 0]])

## data_utils

### parameter 

In [37]:
# Dive into the Data Class to see what happend
alphabet="abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
alphabet_size = len(alphabet)

char_dict = {}  # Maps each character to an integer
no_of_classes = 4

for idx, char in enumerate(alphabet):
    char_dict[char] = idx + 1

length = 1014
data_source = '../data/ag_news_csv/train.csv'


### read data 

In [2]:
import pandas as pd

In [9]:
train_df = pd.read_csv(data_source, header=None)

In [11]:
train_df.head()

Unnamed: 0,0,1,2
0,3,Wall St. Bears Claw Back Into the Black (Reuters),"Reuters - Short-sellers, Wall Street's dwindli..."
1,3,Carlyle Looks Toward Commercial Aerospace (Reu...,Reuters - Private investment firm Carlyle Grou...
2,3,Oil and Economy Cloud Stocks' Outlook (Reuters),Reuters - Soaring crude prices plus worries\ab...
3,3,Iraq Halts Oil Exports from Main Southern Pipe...,Reuters - Authorities have halted oil export\f...
4,3,"Oil prices soar to all-time record, posing new...","AFP - Tearaway world oil prices, toppling reco..."


In [21]:
train_df[1] = train_df[1] + train_df[2]

In [23]:
train_df[1][0]

"Wall St. Bears Claw Back Into the Black (Reuters)Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again."

In [24]:
train_df.drop([2], axis=1)

Unnamed: 0,0,1
0,3,Wall St. Bears Claw Back Into the Black (Reute...
1,3,Carlyle Looks Toward Commercial Aerospace (Reu...
2,3,Oil and Economy Cloud Stocks' Outlook (Reuters...
3,3,Iraq Halts Oil Exports from Main Southern Pipe...
4,3,"Oil prices soar to all-time record, posing new..."
5,3,"Stocks End Up, But Near Year Lows (Reuters)Reu..."
6,3,Money Funds Fell in Latest Week (AP)AP - Asset...
7,3,Fed minutes show dissent over inflation (USATO...
8,3,Safety Net (Forbes.com)Forbes.com - After earn...
9,3,Wall St. Bears Claw Back Into the Black NEW YO...


In [32]:
train_df[0].unique()

array([3, 4, 2, 1])

In [26]:
train_df.values

array([[3,
        "Wall St. Bears Claw Back Into the Black (Reuters)Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
        "Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again."],
       [3,
        'Carlyle Looks Toward Commercial Aerospace (Reuters)Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.',
        'Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.'],
       [3,
        "Oil and Economy Cloud Stocks' Outlook (Reuters)Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market n

In [16]:
# Another way to load the data
import csv
import re
import numpy as np

data = []
with open(data_source, 'r', encoding='utf-8') as f:
    rdr = csv.reader (f, delimiter=',', quotechar='"')
    for row in rdr:
        txt = ""
        for s in row[1:]:
            txt = txt + " " + re.sub("^\s*(.-)\s*$", "%1", s).replace("\\n", "\n")
        data.append((int(row[0]), txt))  # format: (label, text)
    data = np.array(data)

In [18]:
data[0]

array(['3',
       " Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again."],
      dtype='<U1013')

### get all data 

In [35]:
data_size = len(data)
print(data_size)
start_index = 0
end_index = data_size
print(end_index)
batch_texts = data[start_index:end_index]
print(batch_texts)
one_hot = np.eye(no_of_classes, dtype='int64')
print(one_hot)


120000
120000
[['3'
  " Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again."]
 ['3'
  ' Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.']
 ['3'
  " Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums."]
 ...
 ['2'
  " Saban not going to Dolphins yet The Miami Dolphins will put their courtship of LSU coach Nick Saban on hold to comply with the NFL's hiring policy by interviewing at least one minority candidate, a team source told The Associated Press last night."]
 ['2'
  " Today's NFL games PITTS

In [41]:
# Dive into the Data Class to see what happend
alphabet="abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}"
alphabet_size = len(alphabet)

char_dict = {}  # Maps each character to an integer
no_of_classes = 4

for idx, char in enumerate(alphabet):
    char_dict[char] = idx + 1

length = 1014
data_source = '../data/ag_news_csv/train.csv'

char_dict

{'!': 41,
 '"': 45,
 '#': 51,
 '$': 52,
 '%': 53,
 '&': 55,
 "'": 44,
 '(': 64,
 ')': 65,
 '*': 56,
 '+': 59,
 ',': 38,
 '-': 60,
 '.': 40,
 '/': 46,
 '0': 27,
 '1': 28,
 '2': 29,
 '3': 30,
 '4': 31,
 '5': 32,
 '6': 33,
 '7': 34,
 '8': 35,
 '9': 36,
 ':': 43,
 ';': 39,
 '<': 62,
 '=': 61,
 '>': 63,
 '?': 42,
 '@': 50,
 '[': 66,
 '\\': 47,
 ']': 67,
 '^': 54,
 '_': 49,
 '`': 58,
 'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26,
 '{': 68,
 '|': 48,
 '}': 69,
 '~': 57}

In [72]:
len(char_dict)

68

In [43]:
length = 1014 # 1014

def str_to_indexes(s):
    """
    Convert a string to character indexes based on character dictionary.

    Args:
        s (str): String to be converted to indexes

    Returns:
        str2idx (np.ndarray): Indexes of characters in s

    """
    s = s.lower()
    max_length = min(len(s), length)
    str2idx = np.zeros(length, dtype='int64')
    for i in range(1, max_length + 1):
        c = s[-i] # Notice, the order to read the sentence is from end to start. 
        if c in char_dict:
            str2idx[i - 1] = char_dict[c] # str2idx[0]=40, becasue the last char '.' have index 40
    return str2idx

In [45]:
batch_texts[0][1]

" Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again."

In [46]:
str_to_indexes(batch_texts[0][1])

array([40, 14,  9, ...,  0,  0,  0])

In [47]:
one_hot = np.eye(no_of_classes, dtype='int64')
batch_indices = []
classes = []
for c, s in batch_texts:
    batch_indices.append(str_to_indexes(s))
    c = int(c) - 1
    classes.append(one_hot[c])

Here we get the train_input(batch_indices) and train_labels(classes)

In [50]:
training_inputs = np.asarray(batch_indices)
training_labels = np.asarray(classes)

In [53]:
print(training_inputs.shape)
print(training_labels.shape)
print(training_inputs[:3])
print(training_labels[:3])

(120000, 1014)
(120000, 4)
[[40 14  9 ...  0  0  0]
 [40 20  5 ...  0  0  0]
 [40 19 13 ...  0  0  0]]
[[0 0 1 0]
 [0 0 1 0]
 [0 0 1 0]]


In [71]:
# Load validation data
validation_data = Data(data_source=config["data"]["validation_data_source"],
                       alphabet=config["data"]["alphabet"],
                       input_size=config["data"]["input_size"],
                       num_of_classes=config["data"]["num_of_classes"])
validation_data.load_data()
validation_inputs, validation_labels = validation_data.get_all_data()


Data loaded from ../data/ag_news_csv/test.csv
