In [1]:
'''
this is a Bi_Lstm model, which will be used for text classification
Teng Li
06,Dec,2021
'''

'\nthis is a Bi_Lstm model, which will be used for text classification\nTeng Li\n06,Dec,2021\n'

In [2]:
from config import DefaultConfig
import torch
import torch.nn as nn
import torch.nn.functional as f
import os
from utils import load_flattened_documents, load_datasets
from datasets_preprocessing import clean_datasets, Movie_Classif_Dataset
from torch.utils.data import DataLoader
from Bi_LSTM import Text_Bi_LSTM, lstm_padding
import wandb

In [3]:
# first of all get config
Conf = DefaultConfig()

DEVICE = Conf.device
LR = Conf.lr
BATCH_SIZE = 3 # get from Conf.lstm_batch_size
PRINT_FREQ = 30
#INPUT_SIZE = 300 # get from Conf.input_size
#HIDDEN_SIZE = 50 # get from Conf.hidden_size
#NUM_LAYERS = 2   # get from Conf.num_layers
#BIDIRECTIONAL = 2   #get from Conf.bidirectional
#NUM_CLASS = 2    # get from Conf.num_class


In [4]:
#load raw datasets
data_root = os.path.join('Data', 'movies')
documents = load_flattened_documents(data_root,None)
documents = clean_datasets(documents)
train, val, test = load_datasets(data_root)
#load Train_Dataset
Train_Dataset = Movie_Classif_Dataset(documents,train)
#load into DataLoader
Loader = DataLoader(dataset = Train_Dataset,
                    batch_size = BATCH_SIZE,
                    shuffle=True,
                    collate_fn=lstm_padding,
                    drop_last=True)

In [3]:
if os.path.exists('Text_lstm_1.pth'):
    if torch.cuda.is_available():
        Text_lstm = torch.load('Text_lstm_1.pth').cuda()
    else:
        Text_lstm = torch.load('Text_lstm_1.pth')
    print('load from Text_lstm_1.pth')

load from Text_lstm_1.pth


In [5]:
Text_lstm = Text_Bi_LSTM()

In [6]:
def lstm_train(epochs,model,dataloader,device,Lr):
    model.train()
    loss_func = nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=Lr)
    model.to(device)
    for e in range(epochs):
        loss_sum = 0
        acc = 0
        for i,(x,y) in enumerate(dataloader):
            # load x,y
            x = x.to(device)
            y = y.to(device)
            # predict
            y_hat = model(x)
            # comput loss and acc
            loss = loss_func(y_hat,y)
            label_hat = torch.argmax(y_hat,dim=1)
            acc += torch.sum(label_hat == y).item()
            loss_sum += loss
            # backward propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # print info
            if (i+1)%PRINT_FREQ == 0:
                acc = acc/(BATCH_SIZE*PRINT_FREQ)
                loss_average = loss_sum/PRINT_FREQ
                wandb.log({"loss": loss_average.data,"acc":acc})
                wandb.watch(model)
                print('epoch:{}, batch:{}, loss:{}, acc:{}'.format(e,i+1,loss_average.data,acc))
                loss_sum = 0
                acc = 0


wandb.init(project='Lstm',entity='teng_li')
lstm_train(10,Text_lstm,Loader,DEVICE,0.001)

wandb: Currently logged in as: teng_li (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.12.10 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


epoch:0, batch:30, loss:0.7034931182861328, acc:0.5111111111111111
epoch:0, batch:60, loss:0.6952469944953918, acc:0.4777777777777778
epoch:0, batch:90, loss:0.6964959502220154, acc:0.4444444444444444
epoch:0, batch:120, loss:0.6975261569023132, acc:0.4444444444444444
epoch:0, batch:150, loss:0.6871842741966248, acc:0.6
epoch:0, batch:180, loss:0.6942881345748901, acc:0.5444444444444444
epoch:0, batch:210, loss:0.6934261918067932, acc:0.4888888888888889
epoch:0, batch:240, loss:0.6979597806930542, acc:0.5222222222222223
epoch:0, batch:270, loss:0.7000595331192017, acc:0.43333333333333335
epoch:0, batch:300, loss:0.6940596103668213, acc:0.4666666666666667
epoch:0, batch:330, loss:0.6889504790306091, acc:0.5555555555555556
epoch:0, batch:360, loss:0.6865410208702087, acc:0.5666666666666667
epoch:0, batch:390, loss:0.7021137475967407, acc:0.4777777777777778
epoch:0, batch:420, loss:0.6945987343788147, acc:0.5666666666666667
epoch:0, batch:450, loss:0.6941998600959778, acc:0.51111111111111

KeyboardInterrupt: 

In [25]:
# save the model
torch.save(Text_lstm,'Text_lstm_1.pth')

  "type " + obj.__name__ + ". It won't be checked "


In [5]:
acc = 0
for i,(x,y) in enumerate(Loader):
    Text_lstm.eval()
    x = x.to(DEVICE)
    y = y.to(DEVICE)
    y_hat = Text_lstm(x)
    label_hat = torch.argmax(y_hat,dim=1)
    acc += torch.sum(label_hat == y).item()
    if i == 39:
        break
print(acc/120)

0.5583333333333333


In [11]:
#load raw datasets
data_root = os.path.join('Data', 'movies')
documents = load_flattened_documents(data_root,None)
documents = clean_datasets(documents)
train, val, test = load_datasets(data_root)
#load Train_Dataset
Val_Dataset = Movie_Classif_Dataset(documents,val)
#load into DataLoader
Loader = DataLoader(dataset = Val_Dataset,
                    batch_size = BATCH_SIZE,
                    shuffle=False,
                    collate_fn=lstm_padding,
                    drop_last=True)

In [12]:
from utils import load_documents
from Text_Cnn import get_vocab_list
from datasets_preprocessing import get_word_tag_dict
'''
step 6.0:get word_level documents and sentence_level documents, the tag_dict, and some config
'''
data_root = os.path.join('Data', 'movies')
word_level_docs = clean_datasets(load_flattened_documents(data_root,None))
sent_level_docs = load_documents(data_root)

# loss function
loss_func = nn.NLLLoss()
# first of all we need to load Vocab
Vocab_list = get_vocab_list()

#some conf of hotflip
vocab_size = Conf.search_size    # get from Conf.search_size
beam_size = Conf.beam_search_size # get from Conf.beam_search_size
change_word_num = Conf.change_word_num # get from Conf.change_word_num

# get word tag dict from given sent_level_documents
word_tag_dict = get_word_tag_dict(sent_level_docs)

In [13]:
# hotflip with constrain

import heapq
from Text_Cnn import class_to_tensor
from datasets_preprocessing import doc_to_tag, get_part_of_speech, get_wi_tag

loss_incre_sum = 0
for i,(x,y) in enumerate(Loader):
    #print(x[0].shape)
    # compute y_hat
    y_hat = Text_lstm(x)
    # record the loss before hotflip
    loss_before = loss_func(y_hat[:1],y[:1]).item()
    print('loss before hotflip:', loss_before)
    #get docid
    #docids = [train[i*3+1].annotation_id,train[i*3+2].annotation_id,train[i*3+3].annotation_id]
    docid = val[i*3+1].annotation_id
    print(docid)
    # get word_level and sent_level doc
    #word_docs = [word_level_docs[docids[0]],word_level_docs[docids[1]],word_level_docs[docids[2]]]
    word_doc = word_level_docs[docid]
    #sent_docs = [sent_level_docs[docids[0]],sent_level_docs[docids[1]],sent_level_docs[docids[2]]]
    sent_doc = sent_level_docs[docid]
    # get word_tag
    #words_tag = [doc_to_tag(sent_docs[0]),doc_to_tag(sent_docs[1]),doc_to_tag(sent_docs[2])]
    words_tag = doc_to_tag(sent_doc)
    for n in range(change_word_num):
        print('flip time:',n+1)
        # step 6.1: get gradient
        gradient = Text_lstm.get_gradient(x,y)
        #print(gradient[0].shape)
        # step 6.2: get best w0
        w0_list = []# a set to record mean loss of each w0
        for w0_place,w0 in enumerate(word_doc):
            w0_id =x[0][w0_place]
            w0_embed = Text_lstm.get_word_embedding(w0_id)
            #print(w0_embed.shape)
            w0_gradient = gradient[0][0][w0_place]
            #print(w0_gradient.shape)
            if w0_gradient.sum() != 0:
                w0_loss_estimate = torch.dot(w0_gradient,w0_embed).item()
                w0_list.append({'w0':w0,'w0_place':w0_place,'w0_gradient':w0_gradient,
                                'w0_embed':w0_embed,'w0_loss_estimate':w0_loss_estimate})
        # now let's get best w0 list
        best_w0 = heapq.nsmallest(beam_size,w0_list,key=lambda s: s['w0_loss_estimate'])
        #print(best_w0)
        
        #step 6.3: the best word(wi) for each filped word (max (w0_gradient * wi_embed))
        final_list = []
        for w0_info in best_w0:
            w0 = w0_info['w0']
            #print('w0',w0)
            w0_place = w0_info['w0_place']
            #print('w0_palce',w0_place)
            w0_tag = words_tag[w0_place]        
            best_wi = {'wi_id':None,'wi_embed':None,'wi_loss_estimate':0}
            #print('best wi refresh')
            w0_gradient = w0_info['w0_gradient']
            for wi_id in range(vocab_size):
                # get word wi
                wi_index = torch.tensor(wi_id)
                wi = Vocab_list[wi_id]
                # get the tag of wi
                if wi in word_tag_dict:
                    wi_tag = word_tag_dict[wi]
                else:
                    wi_tag = set()
                    wi_tag.add(get_part_of_speech(wi))
                #check w0_tag and wi_tag, if not same, skip to next word
                if w0_tag not in wi_tag:
                    continue
                # check if the tag changed after flip. if changed, skip
                tag = get_wi_tag(sent_doc,w0_place,wi)
                if tag != w0_tag:
                    continue
                # up to now the wi is qualified, then we can record the best wi
                wi_embed = Text_lstm.get_word_embedding(wi_index)
                wi_loss_estimate = torch.dot(w0_gradient,wi_embed).item()
                if wi_loss_estimate > best_wi['wi_loss_estimate']:
                    #print(wi)
                    best_wi['wi_id'] = wi_id
                    best_wi['wi_embed'] = wi_embed
                    best_wi['wi_loss_estimate'] = wi_loss_estimate
            #now we get best wi for w0
            if best_wi['wi_loss_estimate'] == 0:
                #print('no wi')
                continue
            w0_embed = w0_info['w0_embed']
            #print(w0_embed.shape)
            wi_embed = best_wi['wi_embed']
            #print(wi_embed.shape)
            embed_diff = wi_embed - w0_embed
            loss_estimate = torch.dot(w0_gradient,embed_diff).item()
            final_list.append({'w0':w0_info['w0'],'w0_place':w0_info['w0_place'],'wi_id':best_wi['wi_id'],
                               'wi':Vocab_list[best_wi['wi_id']],'loss_estimate':loss_estimate})
            
        # then step 6.4: get best w0 and wi
        final_filp = heapq.nlargest(1,final_list,key=lambda s:s['loss_estimate'])[0]
        print(final_filp)
        w0_place = final_filp['w0_place']
        wi_id = final_filp['wi_id']
        wi = final_filp['wi']
        
        # step 6.5: replace w0 with wi
        x[0][w0_place] = wi_id
        word_doc[w0_place] = wi
        # comput loss after this flip
        y_hat = Text_lstm(x)
        loss = loss_func(y_hat,y).item()
        
        
    # record loss increased 
    print('mean loss increased record:',(loss-loss_before)/change_word_num)
    loss_incre_sum += (loss-loss_before)/change_word_num
    print('doc_num:',i+1)
    print('average up to now:',loss_incre_sum/(i+1))
        
        
        
    

loss before hotflip: 0.1413482129573822
negR_801.txt
flip time: 1
{'w0': 'hitchcock', 'w0_place': 1, 'wi_id': 1158, 'wi': 'contact', 'loss_estimate': 0.1594056934118271}
flip time: 2
{'w0': 'according', 'w0_place': 0, 'wi_id': 1429, 'wi': 'losing', 'loss_estimate': 1.2970805168151855}
flip time: 3
{'w0': 'gas', 'w0_place': 8, 'wi_id': 576, 'wi': 'history', 'loss_estimate': 0.8442410230636597}
flip time: 4
{'w0': 'diners', 'w0_place': 7, 'wi_id': 488, 'wi': 'others', 'loss_estimate': 0.4847698211669922}
flip time: 5
{'w0': 'other', 'w0_place': 3, 'wi_id': 678, 'wi': 'annual', 'loss_estimate': 0.2064584344625473}
mean loss increased record: 0.7453179657459259
doc_num: 1
average up to now: 0.7453179657459259
loss before hotflip: 5.255608558654785
negR_804.txt
flip time: 1
{'w0': 'for', 'w0_place': 0, 'wi_id': 420, 'wi': 'near', 'loss_estimate': 0.5915592312812805}
flip time: 2
{'w0': 'oldman', 'w0_place': 5, 'wi_id': 576, 'wi': 'history', 'loss_estimate': 0.38450807332992554}
flip time: 3

{'w0': 'comedy', 'w0_place': 7, 'wi_id': 576, 'wi': 'history', 'loss_estimate': 0.5841948390007019}
flip time: 5
{'w0': 'irresistible', 'w0_place': 4, 'wi_id': 678, 'wi': 'annual', 'loss_estimate': 0.6671624779701233}
mean loss increased record: 0.4963750584051013
doc_num: 12
average up to now: 0.33387769446708265
loss before hotflip: 1.2611513137817383
negR_837.txt
flip time: 1
{'w0': 'homosexual', 'w0_place': 4, 'wi_id': 678, 'wi': 'annual', 'loss_estimate': 1.5071110725402832}
flip time: 2
{'w0': 'disappointing', 'w0_place': 0, 'wi_id': 678, 'wi': 'annual', 'loss_estimate': 1.055666446685791}
flip time: 3
{'w0': 'about', 'w0_place': 2, 'wi_id': 1124, 'wi': 'toward', 'loss_estimate': 0.3970580995082855}
flip time: 4
{'w0': 'relationship', 'w0_place': 5, 'wi_id': 576, 'wi': 'history', 'loss_estimate': 0.10631173849105835}
flip time: 5
{'w0': 'film', 'w0_place': 13, 'wi_id': 576, 'wi': 'history', 'loss_estimate': 0.04954727366566658}
mean loss increased record: 0.4344194412231445
doc_n

{'w0': 'have', 'w0_place': 3, 'wi_id': 1422, 'wi': 'gold', 'loss_estimate': 0.027066044509410858}
flip time: 2
{'w0': 'days', 'w0_place': 1, 'wi_id': 360, 'wi': 'times', 'loss_estimate': 0.7009806632995605}
flip time: 3
{'w0': 'these', 'w0_place': 0, 'wi_id': 86, 'wi': 'no', 'loss_estimate': 0.7547979354858398}
flip time: 4
{'w0': 'attention', 'w0_place': 6, 'wi_id': 576, 'wi': 'history', 'loss_estimate': 0.5181576013565063}
flip time: 5
{'w0': 'rather', 'w0_place': 4, 'wi_id': 138, 'wi': 'very', 'loss_estimate': 0.3115944266319275}
mean loss increased record: 0.8772241599857807
doc_num: 24
average up to now: 0.40969861935203283
loss before hotflip: 1.5215798616409302
negR_873.txt
flip time: 1
{'w0': '1963', 'w0_place': 3, 'wi_id': 81, 'wi': 'million', 'loss_estimate': 1.2731159925460815}
flip time: 2
{'w0': 'remake', 'w0_place': 1, 'wi_id': 1180, 'wi': 'serve', 'loss_estimate': 0.612578272819519}
flip time: 3
{'w0': 'the', 'w0_place': 2, 'wi_id': 86, 'wi': 'no', 'loss_estimate': 0.229

flip time: 5
{'w0': 'robert', 'w0_place': 7, 'wi_id': 302, 'wi': 'financial', 'loss_estimate': 0.11211997270584106}
mean loss increased record: 0.031921911239624026
doc_num: 35
average up to now: 0.41814798587001867
loss before hotflip: 0.06231534481048584
posR_806.txt
flip time: 1
{'w0': 'casted', 'w0_place': 4, 'wi_id': 302, 'wi': 'financial', 'loss_estimate': 0.09013219177722931}
flip time: 2
{'w0': 'is', 'w0_place': 2, 'wi_id': 1530, 'wi': 'appears', 'loss_estimate': 1.0908896923065186}
flip time: 3
{'w0': 'bruce', 'w0_place': 0, 'wi_id': 153, 'wi': 'think', 'loss_estimate': 1.0456347465515137}
flip time: 4
{'w0': 'willis', 'w0_place': 1, 'wi_id': 842, 'wi': 'talk', 'loss_estimate': 0.5181142687797546}
flip time: 5
{'w0': 'he', 'w0_place': 9, 'wi_id': 74, 'wi': 'she', 'loss_estimate': 0.24477455019950867}
mean loss increased record: 0.5963659048080444
doc_num: 36
average up to now: 0.42309848361829716
loss before hotflip: 4.4392852783203125
posR_809.txt
flip time: 1
{'w0': 'daisy',

flip time: 3
{'w0': 'nothing', 'w0_place': 1, 'wi_id': 842, 'wi': 'talk', 'loss_estimate': 0.4768213629722595}
flip time: 4
{'w0': 'the', 'w0_place': 10, 'wi_id': 184, 'wi': 'both', 'loss_estimate': 0.15099206566810608}
flip time: 5
{'w0': 'first', 'w0_place': 14, 'wi_id': 1444, 'wi': 'worst', 'loss_estimate': 0.03032899647951126}
mean loss increased record: 0.08027348518371583
doc_num: 47
average up to now: 0.3722170407834285
loss before hotflip: 0.01914031058549881
posR_842.txt
flip time: 1
{'w0': 'butthead', 'w0_place': 2, 'wi_id': 842, 'wi': 'talk', 'loss_estimate': 0.023387514054775238}
flip time: 2
{'w0': 'if', 'w0_place': 0, 'wi_id': 369, 'wi': 'whether', 'loss_estimate': 0.6166703104972839}
flip time: 3
{'w0': 'beavis', 'w0_place': 1, 'wi_id': 153, 'wi': 'think', 'loss_estimate': 0.9818259477615356}
flip time: 4
{'w0': 'favorite', 'w0_place': 4, 'wi_id': 302, 'wi': 'financial', 'loss_estimate': 0.6527581214904785}
flip time: 5
{'w0': 'by', 'w0_place': 15, 'wi_id': 405, 'wi': 'w

{'w0': 'once', 'w0_place': 1, 'wi_id': 227, 'wi': 'here', 'loss_estimate': 0.0007619103416800499}
flip time: 2
{'w0': 'every', 'w0_place': 0, 'wi_id': 28, 'wi': 'this', 'loss_estimate': 0.038206424564123154}
flip time: 3
{'w0': 'an', 'w0_place': 5, 'wi_id': 3, 'wi': 'that', 'loss_estimate': 0.4487093985080719}
flip time: 4
{'w0': 'exceptional', 'w0_place': 6, 'wi_id': 302, 'wi': 'financial', 'loss_estimate': 0.5551727414131165}
flip time: 5
{'w0': 'family', 'w0_place': 7, 'wi_id': 389, 'wi': 'include', 'loss_estimate': 0.2881307005882263}
mean loss increased record: 0.4302235004492104
doc_num: 59
average up to now: 0.37406713990415685
loss before hotflip: 0.019046278670430183
posR_878.txt
flip time: 1
{'w0': 'know', 'w0_place': 2, 'wi_id': 842, 'wi': 'talk', 'loss_estimate': 0.011121205985546112}
flip time: 2
{'w0': 'did', 'w0_place': 5, 'wi_id': 582, 'wi': 'released', 'loss_estimate': 0.34908729791641235}
flip time: 3
{'w0': 'last', 'w0_place': 6, 'wi_id': 302, 'wi': 'financial', 'los

In [30]:
t = torch.zeros(4,3,50)
print(t.shape)
a = t[:,:1,:]
print(a.shape)

torch.Size([4, 3, 50])
torch.Size([4, 1, 50])
