In [1]:
from datasets import load_dataset
from nltk.tokenize import wordpunct_tokenize
# Load the AG News corpus through the `datasets` library; the cells below read
# its 'train' and 'test' splits and their 'text' / 'label' columns.
dataset = load_dataset('ag_news')

In [2]:
# Tokenised articles: every text is lower-cased and split with NLTK's
# wordpunct_tokenize. Train rows come first, test rows follow.
text = [
    wordpunct_tokenize(article.lower())
    for split in ('train', 'test')
    for article in dataset[split]['text']
]

# Class ids, in the same train-then-test order as `text`.
label = [
    class_id
    for split in ('train', 'test')
    for class_id in dataset[split]['label']
]

In [5]:
# Vocabulary: token -> integer id. Id 0 is taken by the 'PADDING' entry, so
# real tokens are numbered from 1 in order of first appearance.
word_dict = {'PADDING': 0}
for sentence in text:
    for tok in sentence:
        # len(word_dict) is evaluated before the insert, so an unseen token
        # gets the next free id and a known token keeps the id it already has.
        word_dict.setdefault(tok, len(word_dict))

In [6]:
# Every article is encoded as exactly MAX_SENT_LENGTH token ids.
MAX_SENT_LENGTH = 256

news_words = []
for sent_tokens in text:
    # Look up the id of every token, then keep at most MAX_SENT_LENGTH of them.
    ids = [word_dict[tok] for tok in sent_tokens][:MAX_SENT_LENGTH]
    # Right-pad short articles with 0, the id of the 'PADDING' entry.
    padding = [0] * (MAX_SENT_LENGTH - len(ids))
    news_words.append(ids + padding)


In [7]:
import numpy as np
# Turn the Python lists into int32 arrays: one row of token ids per article,
# and one class id per article.
news_words = np.array(news_words, dtype=np.int32)
label = np.array(label, dtype=np.int32)

In [6]:
# `text` / `label` were built as train rows followed by test rows, so the split
# boundary is the size of the train split. Deriving it from the dataset avoids
# a hard-coded row count that would silently mis-split a different dataset.
num_train = len(dataset['train'])

index = np.arange(len(label))
# Basic slicing returns a view, so the in-place shuffle below also permutes the
# first `num_train` entries of `index`; the test slice is not affected.
train_index = index[:num_train]
# NOTE(review): the shuffle is unseeded, and a later cell rebuilds
# index / train_index / test_index in their original order.
np.random.shuffle(train_index)
test_index = index[num_train:]

In [3]:
import os
# Select GPU 2 by default, but respect a CUDA_VISIBLE_DEVICES value that the
# user or a job scheduler has already exported instead of overwriting it.
# NOTE(review): this line sits before the keras imports, presumably so the
# backend only sees the selected device -- keep it there.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")

from keras.utils.np_utils import to_categorical
from keras.layers import *
from keras.models import Model, load_model
from keras import backend as K
from sklearn.metrics import *
from keras.optimizers import *


Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [8]:
import numpy as np
# This repeats the conversion done in an earlier cell. np.asarray returns the
# existing array unchanged when it is already an int32 ndarray (np.array would
# make a full copy), and still converts the plain lists if that earlier cell
# was skipped.
news_words = np.asarray(news_words, dtype=np.int32)
label = np.asarray(label, dtype=np.int32)

In [9]:
import random
# Rebuild the row indices in their original order; this replaces the shuffled
# train_index produced by the earlier split cell.
# `label` holds the train rows first and the test rows after them, so the
# boundary is the size of the train split rather than a hard-coded row count.
num_train = len(dataset['train'])

index = np.arange(len(label))
train_index = index[:num_train]
test_index = index[num_train:]

In [29]:

class Fastformer(Layer):
    """Multi-head additive-attention layer (Fastformer).

    The layer is called on a list of tensors:

        [Q_seq, K_seq]                  -- no masking
        [Q_seq, K_seq, Q_mask, K_mask]  -- wherever a mask is 0, 1e8 is
                                           subtracted from the attention logits

    Q_seq / K_seq are (batch, length, features) sequences. The masks are
    expected to be float tensors of shape (batch, length) holding 1 at real
    positions, as built in this notebook -- TODO confirm for other callers.

    Sequence lengths are read from the static input shapes recorded in
    `build`, so both sequences need a fixed length.

    Output shape: (batch, query_length, nb_head * size_per_head).
    """

    def __init__(self, nb_head, size_per_head, **kwargs):
        """Store the head layout.

        Args:
            nb_head: number of attention heads.
            size_per_head: feature size of each head.
            **kwargs: forwarded to the base `Layer` constructor.
        """
        super(Fastformer, self).__init__(**kwargs)
        self.nb_head = nb_head
        self.size_per_head = size_per_head
        self.output_dim = nb_head * size_per_head
        self.now_input_shape = None

    def build(self, input_shape):
        """Create the projection weights.

        Args:
            input_shape: list of input shapes; entry 0 is the query sequence
                and entry 1 is the key sequence.
        """
        # Kept so `call` can read the static sequence lengths.
        self.now_input_shape = input_shape
        # Input projections for the query and key sequences.
        self.WQ = self.add_weight(name='WQ',
                                  shape=(input_shape[0][-1], self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)
        self.WK = self.add_weight(name='WK',
                                  shape=(input_shape[1][-1], self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)
        # Attention scorers: one logit per head for every position.
        self.Wq = self.add_weight(name='Wq',
                                  shape=(self.output_dim, self.nb_head),
                                  initializer='glorot_uniform',
                                  trainable=True)
        self.Wk = self.add_weight(name='Wk',
                                  shape=(self.output_dim, self.nb_head),
                                  initializer='glorot_uniform',
                                  trainable=True)
        # Output transform applied before the residual connection.
        self.WP = self.add_weight(name='WP',
                                  shape=(self.output_dim, self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)

        super(Fastformer, self).build(input_shape)

    def call(self, x):
        """Apply the attention.

        Args:
            x: list `[Q_seq, K_seq]` or `[Q_seq, K_seq, Q_mask, K_mask]`.

        Returns:
            Tensor of shape (batch, query_length, nb_head * size_per_head).

        Raises:
            ValueError: if `x` does not hold exactly 2 or 4 tensors.
        """
        if len(x) == 2:
            Q_seq, K_seq = x
            Q_mask = K_mask = None
        elif len(x) == 4:
            # Separate masks so query and key may have different lengths
            # (reserved for cross attention).
            Q_seq, K_seq, Q_mask, K_mask = x
        else:
            raise ValueError(
                'Fastformer expects [Q_seq, K_seq] or '
                '[Q_seq, K_seq, Q_mask, K_mask], got %d inputs' % len(x))

        q_len = self.now_input_shape[0][1]
        k_len = self.now_input_shape[1][1]
        scale = self.size_per_head ** 0.5

        # Projected queries, (batch, q_len, output_dim).
        Q_proj = K.dot(Q_seq, self.WQ)
        Q_proj = K.reshape(Q_proj, (-1, q_len, self.output_dim))

        # Attention weights over the query positions, (batch, nb_head, q_len).
        Q_att = K.permute_dimensions(K.dot(Q_proj, self.Wq), (0, 2, 1)) / scale
        if Q_mask is not None:
            Q_att = Q_att - (1 - K.expand_dims(Q_mask, axis=1)) * 1e8
        Q_att = K.softmax(Q_att)

        # Split into heads, (batch, nb_head, length, size_per_head).
        Q_heads = K.reshape(
            Q_proj, (-1, q_len, self.nb_head, self.size_per_head))
        Q_heads = K.permute_dimensions(Q_heads, (0, 2, 1, 3))

        K_heads = K.reshape(
            K.dot(K_seq, self.WK),
            (-1, k_len, self.nb_head, self.size_per_head))
        K_heads = K.permute_dimensions(K_heads, (0, 2, 1, 3))

        # Global query: attention-weighted sum over the query positions,
        # (batch, nb_head, size_per_head). The trailing singleton axis
        # broadcasts over size_per_head.
        global_q = K.sum(K.expand_dims(Q_att, axis=3) * Q_heads, axis=2)

        # Every key modulated by the global query,
        # (batch, nb_head, k_len, size_per_head).
        QK_interaction = K_heads * K.expand_dims(global_q, axis=2)

        # Bring the position axis back in front of the heads BEFORE flattening
        # to (batch, k_len, output_dim). Reshaping the head-major tensor
        # directly would mix values from different heads and positions into
        # the same row, and the length here is the key length.
        QK_flat = K.permute_dimensions(QK_interaction, (0, 2, 1, 3))
        QK_flat = K.reshape(QK_flat, (-1, k_len, self.output_dim))

        # Attention weights over the key positions, (batch, nb_head, k_len).
        K_att = K.permute_dimensions(K.dot(QK_flat, self.Wk), (0, 2, 1)) / scale
        if K_mask is not None:
            K_att = K_att - (1 - K.expand_dims(K_mask, axis=1)) * 1e8
        K_att = K.softmax(K_att)

        # Global key, (batch, nb_head, size_per_head).
        global_k = K.sum(K.expand_dims(K_att, axis=3) * QK_interaction, axis=2)

        # The projected queries double as the values (Q = V): modulate each of
        # them by the global key, then merge the heads again.
        QKQ_interaction = K.expand_dims(global_k, axis=2) * Q_heads
        QKQ_interaction = K.permute_dimensions(QKQ_interaction, (0, 2, 1, 3))
        QKQ_interaction = K.reshape(
            QKQ_interaction, (-1, q_len, self.output_dim))
        QKQ_interaction = K.dot(QKQ_interaction, self.WP)

        # Residual connection with the projected queries.
        return QKQ_interaction + Q_proj

    def compute_output_shape(self, input_shape):
        """Return (batch, query_length, nb_head * size_per_head)."""
        return (input_shape[0][0], input_shape[0][1], self.output_dim)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super(Fastformer, self).get_config()
        config['nb_head'] = self.nb_head
        config['size_per_head'] = self.size_per_head
        return config

In [None]:
# Start from a clean backend session when this cell is re-run. The notebook
# only uses `from keras ... import ...`, so the bare name `keras` is never
# bound; the backend is reachable through its alias `K`.
K.clear_session()

text_input = Input(shape=(MAX_SENT_LENGTH,), dtype='int32')
# 1.0 where the token id is non-zero, 0.0 at padding positions (id 0).
qmask = Lambda(lambda x: K.cast(K.cast(x, 'bool'), 'float32'))(text_input)
word_emb = Embedding(len(word_dict), 256, trainable=True)(text_input)

# Position embedding was found to be unimportant on this dataset and is left
# out for simplicity. If needed, uncomment the two lines below.
#pos_emb = Embedding(MAX_SENT_LENGTH, 256, trainable=True)(Lambda(lambda x:K.zeros_like(x,dtype='int32')+K.arange(x.shape[1]))(text_input))
#word_emb  =add([word_emb ,pos_emb])

word_emb = Dropout(0.2)(word_emb)

# First Fastformer block: self-attention (16 heads x 16 dims = 256, the
# embedding size), dropout, residual add and layer normalisation.
hidden_word_emb = Fastformer(16, 16)([word_emb, word_emb, qmask, qmask])
hidden_word_emb = Dropout(0.2)(hidden_word_emb)
# If this Keras version has no LayerNormalization, import an external
# layer-norm class from a higher version.
hidden_word_emb = LayerNormalization()(add([word_emb, hidden_word_emb]))

# Second Fastformer block, same structure (no feed-forward sub-layer, for
# simplicity).
hidden_word_emb_layer2 = Fastformer(16, 16)([hidden_word_emb, hidden_word_emb, qmask, qmask])
hidden_word_emb_layer2 = Dropout(0.2)(hidden_word_emb_layer2)
hidden_word_emb_layer2 = LayerNormalization()(add([hidden_word_emb, hidden_word_emb_layer2]))

# Attention pooling: one score per position, softmax over the positions, then
# a weighted sum of the position vectors gives one vector per article.
word_att = Flatten()(Dense(1)(hidden_word_emb_layer2))
word_att = Activation('softmax')(word_att)
text_emb = Dot((1, 1))([hidden_word_emb_layer2, word_att])
# Four output units, one per class.
classifier = Dense(4, activation='softmax')(text_emb)

model = Model([text_input], [classifier])
model.compile(loss=['categorical_crossentropy'], optimizer=Adam(lr=0.001), metrics=['acc'])

# One-hot targets, built once. num_classes is given explicitly so the width
# always matches the 4-unit output layer.
train_targets = to_categorical(label, num_classes=4)[train_index]

for _ in range(1):
    model.fit(news_words[train_index], train_targets, shuffle=True, batch_size=64, epochs=1, verbose=1)

    # Evaluate on the held-out rows after each pass.
    y_pred = model.predict([news_words[test_index]], batch_size=128, verbose=1)
    y_pred = np.argmax(y_pred, axis=1)
    y_true = label[test_index]
    acc = accuracy_score(y_true, y_pred)
    report = f1_score(y_true, y_pred, average='macro')  # macro-averaged F1
    print(acc)
    print(report)
