In [1]:
import requests
from bs4 import BeautifulSoup
import random
import re
import json
import copy
from collections import defaultdict
import time
from tqdm import tqdm
import numpy as np
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
#nltk.download('punkt')
#nltk.download('stopwords')


In [2]:
####WIKIPEDIA SCRAPING#####
response = requests.get('https://en.wikipedia.org/wiki/Machine_learning')
html = response.text
soup = BeautifulSoup(html, 'html.parser')
stop_words = set(stopwords.words('english'))

# get random wikipedia corpus
alllinks = []
num_links = 0
max_links = 10
print("Scraping links...")
pbar = tqdm(total = max_links)
while(num_links < max_links):
    links = []
    # Get all the links
    for link in soup.findAll('a', attrs={'href': re.compile("^/wiki/")}):
        links.append(link.get('href'))
    random.shuffle(links)
    alllinks.append(links[0])
    soup = BeautifulSoup(requests.get(
        "http://en.wikipedia.org" + links[0]).text, 'html.parser')
    num_links += 1
    pbar.update(1)
pbar.close()
alltext = ""

print("Extracting text...")
# put all words into text file
for link in tqdm(alllinks):
    response = requests.get("http://en.wikipedia.org" + link)
    html = response.text
    soup = BeautifulSoup(html, 'html.parser')
    # extract paragraphs
    text_container = soup.find('div', {'class': 'mw-parser-output'})
    zero_paragaph = {"title": "", "text": ""}
    
    if(text_container == None):
        continue
    current_paragraph = copy.deepcopy(zero_paragaph)
    page = {'paragraphs': []}
    for child in text_container.children:
        if child.name == "p":
            current_paragraph["text"] += child.text + "\n"
        elif child.name == "h2":
            page['paragraphs'].append(current_paragraph)
            current_paragraph = copy.deepcopy(zero_paragaph)
            current_paragraph["title"] = next(child.children).text

        page['paragraphs'] = list(
            filter(lambda x: x["text"] != "", page['paragraphs']))
        
    for pg in page['paragraphs']:
        alltext += pg['text']
        
uncleaned_tokens = word_tokenize(alltext)


print("Cleaning...")
#clean up tokens
stop_words = set(stopwords.words('english'))
tokens = []
for tok in tqdm(uncleaned_tokens):
    tok = tok.lower()
    if(tok.isalnum() and not tok.isdigit()):
        tokens.append(tok)
        
        


Scraping links...


100%|██████████| 10/10 [00:05<00:00,  1.88it/s]


Extracting text...


100%|██████████| 10/10 [00:04<00:00,  2.32it/s]


Cleaning...


100%|██████████| 18986/18986 [00:00<00:00, 1713939.47it/s]


In [3]:
#### CORPUS ####
class Corpus(torch.utils.data.Dataset):
    def __init__(self, tokens, ngrams=1):
        self.ngrams = ngrams
        self.tokens = tokens  #list of words
        self.totlen = len(tokens)
        
        # vocab will be defined over kmers
        self.create_vocab()
        
        
    def create_vocab(self):
        self.vocab_freq = defaultdict(float)
        st_time = time.time()
        for word in self.tokens:
            # compute freq of each word
            if word in self.vocab_freq:
                self.vocab_freq[word] += 1
            else:
                self.vocab_freq[word] = 1
     
        
        self.vocab = sorted(self.vocab_freq.keys())
        
        # create forward and reverse index for all the words in vocab
        self.word_to_idx = defaultdict(lambda: 0)
        self.idx_to_word = defaultdict(lambda: 0)
        for idx, w in enumerate(self.vocab):
            self.word_to_idx[w] = idx
            self.idx_to_word[idx] = w
        
        self.vocab_prob = np.array([self.vocab_freq[k] for k in self.vocab])
        self.vocab_freq_scaler = 0.73
        total_freq = float(self.vocab_prob.sum())
        self.vocab_prob = self.vocab_prob / total_freq
        self.vocab_csum = np.cumsum(self.vocab_prob)
        
        en_time = time.time()
        print("corpus construct time (seconds):", en_time - st_time, "num tokens:", total_freq)
        
    def __len__(self):
        return self.totlen
    
    
    def __getitem__(self, index):
        if(index >= len(self.tokens) - 15):
            return (
                torch.tensor([self.word_to_idx[word] for word in self.tokens[index:index+10]]),
                torch.tensor([self.word_to_idx[word] for word in self.tokens[index:index+10]]),
            )
        else:
            return (
                torch.tensor([self.word_to_idx[word] for word in self.tokens[index:index+10]]),
                torch.tensor([self.word_to_idx[word] for word in self.tokens[index+1:index+11]]),
            )




In [4]:
#### POS/NEG TARGET CONTEXT CREATION ####
class NegSampler:
    '''generate a block of negative samples from the cumsum array (Cumulative Distribution Function)'''
    def __init__(self, csum_ary):
        self.csum_ary = csum_ary
        self.time = 0.

    def get_neg_words(self, num_words):
        '''get num_words negative words sampled from cumsum array'''
        st_time = time.time()
        nprobs = np.random.random(num_words)
        neg_words = np.searchsorted(self.csum_ary, nprobs)
        en_time = time.time()
        self.time += en_time-st_time
        return neg_words

class PosNegSampler(torch.utils.data.IterableDataset):
    '''This class creates a block of positive and negative pairs for word2vec training
    The iterable will return a numpy array of target, context and label triples'''
    
    def __init__(self, C, window_size, neg_samples, block_size):
        super(PosNegSampler, self).__init__()
        self.window_size = window_size
        self.neg_samples = neg_samples
        self.time = 0.
        self.block_sz = block_size
        self.C = C
        
        self.negsampler = NegSampler(self.C.vocab_csum)
                    
    def context_data(self, block_sz):
        '''generate center word, context word pairs
        '''
        T = []
        C = []
        for i, word in enumerate(self.C.tokens):
            #get window of words and create target context pairs
            start_idx = max(0, i - self.window_size)
            end_idx = min(len(self.C.vocab), i + self.window_size + 1)
            for j in range(start_idx, end_idx):
                if i != j:
                    T.append(self.C.word_to_idx[word])
                    C.append(self.C.word_to_idx[self.C.tokens[j]])

            # return a block of T, C
            if len(T) >= block_sz:
                yield (T, C)
                T, C = [], []
                        
        # return any remining elements           
        yield (T, C)
            

    def __iter__(self):
        '''return one pos word and neq_samples neg words and the labels
           use context_data to retrieve a block
        '''

        st_time = time.time()
        
        for i, (T, C) in enumerate(self.context_data(self.block_sz)):
            Tnp = np.array(T)
            Cnp = np.array(C)
            L = np.ones(len(T))
            yield (Tnp, Cnp, L)
            
            L = np.zeros(len(T))
            N = self.negsampler.get_neg_words(
                len(T))
            Nnp = np.array(N)
            yield (Tnp, Nnp, L)
                
        en_time = time.time()
        self.time += en_time-st_time

In [5]:
class Word2Vec(nn.Module):
    '''The word2vec model to train the word embeddings'''
    def __init__(self, embedding_size, vocab_size):
        super(Word2Vec, self).__init__()
        self.embedding_size = embedding_size
        self.T = nn.Embedding(vocab_size, embedding_size)
        self.C = nn.Embedding(vocab_size, embedding_size)

    def forward(self, target_word, context_word, label):
        t = self.T(target_word)
        c = self.C(context_word)
        out = torch.sum(t * c, dim=1)
        return out
    
    def save_embeddings(self, file_name, idx_to_word):
        # average the T and C matrices
        W = (net.T.weight.cpu().data.numpy() + net.C.weight.cpu().data.numpy())/2.

        with open(file_name, "w") as f:
            f.write("%d %d\n" % (len(idx_to_word), self.embedding_size))
            for wid, w in idx_to_word.items():
                e = ' '.join(map(lambda x: str(x), W[wid]))
                f.write("%s %s\n" % (w, e))


In [6]:
if torch.cuda.is_available():
    device = "cuda"
    print("using device", torch.cuda.get_device_name(device))
else:
    device = "cpu"
    
#scrape and read
c = Corpus(tokens)
window_size = 30
neg_samples = 60
block_size = 1024

PNS = PosNegSampler(c, window_size, neg_samples, block_size)
V = len(PNS.C.vocab)
print("vocab, device: ", V, device)

training_generator = data.DataLoader(
    PNS, batch_size = 1
)


#model parameters
word_embedding_dim = 200
epochs = 200
learning_rate = 0.01


# create the NN model
net = Word2Vec(embedding_size=word_embedding_dim, vocab_size=V)
net.to(device)
net.train()
net
loss_function = nn.BCEWithLogitsLoss()

optimizer = optim.Adam(net.parameters(), lr=learning_rate)


start_t = time.time()
for e in range(epochs):
    running_loss = 0
    for bidx, (targets, contexts, labels) in enumerate(
        tqdm(training_generator)):
        
        targets = targets.flatten().to(device)
        contexts = contexts.flatten().to(device)
        labels = labels.flatten().to(device)
        
        net.zero_grad()
        preds = net(targets, contexts, labels)
        loss = loss_function(preds, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print("epoch", e, running_loss, bidx, running_loss / (bidx + 1))
    
end_t = time.time()
print("finished in time", end_t - start_t)


for mode in ['avg', 'target', 'context']:
    output_file = "./word_embeds.vec"
    net.save_embeddings(output_file, PNS.C.idx_to_word)

using device Tesla T4
corpus construct time (seconds): 0.005596637725830078 num tokens: 15322.0
vocab, device:  2795 cuda


310it [00:00, 639.99it/s]


epoch 0 1276.6990155759395 309 4.118383921212708


310it [00:00, 684.67it/s]


epoch 1 712.2459422808568 309 2.297567555744699


310it [00:00, 699.35it/s]


epoch 2 476.11668389051533 309 1.5358602706145656


310it [00:00, 710.09it/s]


epoch 3 361.9383429140878 309 1.1675430416583477


310it [00:00, 704.08it/s]


epoch 4 295.7968746200265 309 0.9541834665162144


310it [00:00, 715.44it/s]


epoch 5 253.5810178796054 309 0.8180032834825981


310it [00:00, 711.63it/s]


epoch 6 228.17385029532295 309 0.7360446783720095


310it [00:00, 685.25it/s]


epoch 7 207.78039001727402 309 0.6702593226363678


310it [00:00, 704.61it/s]


epoch 8 193.77796570200147 309 0.6250902119419403


310it [00:00, 713.47it/s]


epoch 9 184.57293254599446 309 0.5953965565999821


310it [00:00, 721.26it/s]


epoch 10 177.58950211687463 309 0.5728693616673375


310it [00:00, 715.57it/s]


epoch 11 171.58007669971263 309 0.5534841183861698


310it [00:00, 714.92it/s]


epoch 12 167.23834095843787 309 0.5394785192207673


310it [00:00, 721.65it/s]


epoch 13 164.0709223660749 309 0.5292610398905643


310it [00:00, 724.00it/s]


epoch 14 162.9240228110564 309 0.5255613639066335


310it [00:00, 717.94it/s]


epoch 15 159.6164875824417 309 0.5148918954272312


310it [00:00, 722.41it/s]


epoch 16 157.87150428840297 309 0.5092629170593644


310it [00:00, 719.09it/s]


epoch 17 156.7207372795415 309 0.5055507654178758


310it [00:00, 722.14it/s]


epoch 18 154.27721881068112 309 0.4976684477763907


310it [00:00, 721.38it/s]


epoch 19 155.10556420854047 309 0.5003405297049692


310it [00:00, 716.49it/s]


epoch 20 152.9997748926057 309 0.49354766094388935


310it [00:00, 724.70it/s]


epoch 21 152.50984822218967 309 0.4919672523296441


310it [00:00, 721.46it/s]


epoch 22 152.26834205150047 309 0.4911882001661305


310it [00:00, 725.90it/s]


epoch 23 152.11261438817039 309 0.4906858528650658


310it [00:00, 724.01it/s]


epoch 24 151.16121164573764 309 0.48761681176044397


310it [00:00, 722.17it/s]


epoch 25 150.06077265473346 309 0.4840670085636563


310it [00:00, 724.88it/s]


epoch 26 149.7868127666048 309 0.4831832669890478


310it [00:00, 715.98it/s]


epoch 27 148.3682689968922 309 0.47860731934481354


310it [00:00, 650.75it/s]


epoch 28 147.98666110645658 309 0.4773763261498599


310it [00:00, 709.93it/s]


epoch 29 149.07155921000847 309 0.4808759974516402


310it [00:00, 718.84it/s]


epoch 30 147.31821272455176 309 0.47522004104694116


310it [00:00, 721.11it/s]


epoch 31 147.92076329355206 309 0.47716375255984533


310it [00:00, 723.07it/s]


epoch 32 148.3989126163287 309 0.47870616973009256


310it [00:00, 725.12it/s]


epoch 33 145.990516919318 309 0.4709371513526387


310it [00:00, 717.45it/s]


epoch 34 145.43882072038505 309 0.4691574861947905


310it [00:00, 722.56it/s]


epoch 35 145.02798583616544 309 0.4678322123747272


310it [00:00, 721.46it/s]


epoch 36 144.60694959912033 309 0.4664740309649043


310it [00:00, 718.29it/s]


epoch 37 145.39208231914267 309 0.46900671715852477


310it [00:00, 723.35it/s]


epoch 38 143.98826640928996 309 0.46447827873964503


310it [00:00, 723.86it/s]


epoch 39 143.6808632620754 309 0.46348665568411423


310it [00:00, 718.26it/s]


epoch 40 143.7059962948 309 0.4635677299832258


310it [00:00, 716.03it/s]


epoch 41 142.66639526874064 309 0.46021417828626016


310it [00:00, 718.25it/s]


epoch 42 142.38130328991645 309 0.45929452674166593


310it [00:00, 715.76it/s]


epoch 43 142.85967464314976 309 0.46083766013919275


310it [00:00, 719.17it/s]


epoch 44 142.8360331698714 309 0.4607613973221658


310it [00:00, 726.80it/s]


epoch 45 141.27773220300983 309 0.45573462000970916


310it [00:00, 712.55it/s]


epoch 46 141.22645073361022 309 0.45556919591487166


310it [00:00, 716.06it/s]


epoch 47 141.14964421205386 309 0.4553214329421092


310it [00:00, 724.55it/s]


epoch 48 140.96288761963146 309 0.4547189923213918


310it [00:00, 715.81it/s]


epoch 49 140.17936842329132 309 0.45219151104287525


310it [00:00, 716.71it/s]


epoch 50 140.09416306321774 309 0.45191665504263784


310it [00:00, 711.32it/s]


epoch 51 140.17090198123702 309 0.45216419993947427


310it [00:00, 714.16it/s]


epoch 52 139.68160162111602 309 0.4505858116810194


310it [00:00, 708.83it/s]


epoch 53 138.7530963570045 309 0.4475906334096919


310it [00:00, 716.29it/s]


epoch 54 138.11285589993145 309 0.4455253416126821


310it [00:00, 707.67it/s]


epoch 55 138.87018849665026 309 0.4479683499891944


310it [00:00, 709.09it/s]


epoch 56 138.93141413840362 309 0.44816585205936654


310it [00:00, 708.08it/s]


epoch 57 138.64493399247635 309 0.44724172255637534


310it [00:00, 714.72it/s]


epoch 58 137.85624447453284 309 0.4446975628210737


310it [00:00, 713.77it/s]


epoch 59 137.488634950714 309 0.4435117256474645


310it [00:00, 713.76it/s]


epoch 60 137.22197470365765 309 0.4426515313021215


310it [00:00, 704.70it/s]


epoch 61 137.04602897077953 309 0.4420839644218695


310it [00:00, 713.64it/s]


epoch 62 137.22122435728124 309 0.4426491108299395


310it [00:00, 719.31it/s]


epoch 63 136.83614115506623 309 0.44140690695182655


310it [00:00, 717.47it/s]


epoch 64 136.42060738903265 309 0.4400664754484924


310it [00:00, 720.68it/s]


epoch 65 136.0620061626625 309 0.43890969729891127


310it [00:00, 718.37it/s]


epoch 66 136.44198960082426 309 0.44013545032523954


310it [00:00, 717.52it/s]


epoch 67 136.16506785395654 309 0.4392421543676018


310it [00:00, 717.25it/s]


epoch 68 135.31213536132566 309 0.4364907592300828


310it [00:00, 720.42it/s]


epoch 69 135.76703148327837 309 0.4379581660750915


310it [00:00, 719.96it/s]


epoch 70 135.10413913992627 309 0.43581980367718154


310it [00:00, 716.89it/s]


epoch 71 134.7285446470769 309 0.4346082085389577


310it [00:00, 719.51it/s]


epoch 72 134.78696034638045 309 0.4347966462786466


310it [00:00, 709.24it/s]


epoch 73 134.4298895729879 309 0.4336448050741545


310it [00:00, 718.57it/s]


epoch 74 134.45072520767874 309 0.4337120167989637


310it [00:00, 722.96it/s]


epoch 75 134.53305996305917 309 0.43397761278406183


310it [00:00, 712.18it/s]


epoch 76 134.18732805994992 309 0.4328623485804836


310it [00:00, 721.52it/s]


epoch 77 133.94205471566463 309 0.43207114424407944


310it [00:00, 712.19it/s]


epoch 78 134.3387606798849 309 0.43335084090285453


310it [00:00, 717.39it/s]


epoch 79 134.02532603188732 309 0.4323397613931849


310it [00:00, 710.65it/s]


epoch 80 133.8467275191946 309 0.4317636371586922


310it [00:00, 717.84it/s]


epoch 81 133.76140341884758 309 0.43148839812531475


310it [00:00, 714.34it/s]


epoch 82 133.15837587827102 309 0.4295431479944226


310it [00:00, 717.28it/s]


epoch 83 132.69827226589516 309 0.42805894279321016


310it [00:00, 719.97it/s]


epoch 84 132.82368432291122 309 0.42846349781584264


310it [00:00, 722.66it/s]


epoch 85 132.84987244255544 309 0.42854797562114655


310it [00:00, 720.91it/s]


epoch 86 132.5587241552991 309 0.427608787597739


310it [00:00, 715.23it/s]


epoch 87 132.57043095496218 309 0.42764655146761993


310it [00:00, 713.52it/s]


epoch 88 132.4852851114451 309 0.4273718874562745


310it [00:00, 720.38it/s]


epoch 89 133.04571346321768 309 0.4291797208490893


310it [00:00, 712.08it/s]


epoch 90 133.01167691054326 309 0.4290699255178815


310it [00:00, 726.06it/s]


epoch 91 132.8803299561089 309 0.42864622566486743


310it [00:00, 718.21it/s]


epoch 92 131.89523929418144 309 0.4254685138521982


310it [00:00, 712.86it/s]


epoch 93 131.44592069461487 309 0.42401909901488666


310it [00:00, 717.15it/s]


epoch 94 131.45124447968777 309 0.4240362725151218


310it [00:00, 723.23it/s]


epoch 95 131.22522228401306 309 0.42330716865810664


310it [00:00, 719.43it/s]


epoch 96 131.09510339593092 309 0.4228874303094546


310it [00:00, 714.39it/s]


epoch 97 132.09392214522376 309 0.4261094262749153


310it [00:00, 720.49it/s]


epoch 98 131.49683908582693 309 0.4241833518897643


310it [00:00, 712.30it/s]


epoch 99 131.75230418043176 309 0.42500743284010245


310it [00:00, 717.09it/s]


epoch 100 131.04353269013822 309 0.42272107319399427


310it [00:00, 716.06it/s]


epoch 101 130.74413830723014 309 0.42175528486203273


310it [00:00, 723.35it/s]


epoch 102 130.87948943118855 309 0.42219190139093077


310it [00:00, 728.36it/s]


epoch 103 130.6176759549765 309 0.4213473417902468


310it [00:00, 716.18it/s]


epoch 104 131.3895253934199 309 0.4238371786884513


310it [00:00, 718.14it/s]


epoch 105 130.4218519054922 309 0.42071565130803934


310it [00:00, 712.42it/s]


epoch 106 130.8321611868975 309 0.4220392296351533


310it [00:00, 726.76it/s]


epoch 107 130.763514408731 309 0.4218177884152613


310it [00:00, 721.44it/s]


epoch 108 129.996689226903 309 0.41934415879646125


310it [00:00, 715.74it/s]


epoch 109 130.88081965717478 309 0.4221961924424993


310it [00:00, 721.29it/s]


epoch 110 130.02245888572 309 0.419427286728129


310it [00:00, 710.89it/s]


epoch 111 130.34302033704685 309 0.42046135592595757


310it [00:00, 723.84it/s]


epoch 112 131.11529367738535 309 0.4229525602496302


310it [00:00, 716.06it/s]


epoch 113 131.07745798179238 309 0.42283050961868507


310it [00:00, 722.77it/s]


epoch 114 130.9478128055596 309 0.4224122993727729


310it [00:00, 722.11it/s]


epoch 115 130.28040143547395 309 0.4202593594692708


310it [00:00, 718.39it/s]


epoch 116 129.4603850115572 309 0.41761414519857165


310it [00:00, 726.17it/s]


epoch 117 129.6805340792338 309 0.4183243034813994


310it [00:00, 724.03it/s]


epoch 118 130.64930087219653 309 0.4214493576522469


310it [00:00, 715.65it/s]


epoch 119 130.31284758921427 309 0.42036402448133636


310it [00:00, 711.01it/s]


epoch 120 130.28220791296545 309 0.42026518681601754


310it [00:00, 725.24it/s]


epoch 121 130.06004800783856 309 0.41954854196076957


310it [00:00, 722.39it/s]


epoch 122 130.03209726375988 309 0.41945837827019317


310it [00:00, 721.91it/s]


epoch 123 129.63571676949027 309 0.41817973151448473


310it [00:00, 716.96it/s]


epoch 124 129.76654364753495 309 0.41860175370172564


310it [00:00, 712.24it/s]


epoch 125 130.21421598196486 309 0.42004585800633826


310it [00:00, 722.45it/s]


epoch 126 129.7069147885796 309 0.4184094025438052


310it [00:00, 718.78it/s]


epoch 127 129.5727849929115 309 0.41797672578358547


310it [00:00, 720.08it/s]


epoch 128 129.83385417145573 309 0.41881888442405074


310it [00:00, 719.64it/s]


epoch 129 129.6828514024422 309 0.41833177871755545


310it [00:00, 718.14it/s]


epoch 130 129.0836188016986 309 0.41639877032806


310it [00:00, 719.11it/s]


epoch 131 129.70025771034847 309 0.41838792809789827


310it [00:00, 716.37it/s]


epoch 132 129.90507355636856 309 0.4190486243753825


310it [00:00, 721.13it/s]


epoch 133 129.54597234410696 309 0.41789023336808695


310it [00:00, 717.44it/s]


epoch 134 128.9788623239492 309 0.4160608462062877


310it [00:00, 723.56it/s]


epoch 135 129.4273102517634 309 0.4175074524250432


310it [00:00, 714.04it/s]


epoch 136 129.3096363660338 309 0.41712785924527035


310it [00:00, 716.97it/s]


epoch 137 128.94217842523008 309 0.41594251104912927


310it [00:00, 714.06it/s]


epoch 138 128.68636356831936 309 0.41511730183328827


310it [00:00, 715.57it/s]


epoch 139 129.13988039549477 309 0.41658025934030574


310it [00:00, 708.01it/s]


epoch 140 129.30422446053555 309 0.4171104014855985


310it [00:00, 713.76it/s]


epoch 141 129.0123133296409 309 0.41616875267626097


310it [00:00, 714.33it/s]


epoch 142 128.7323858805916 309 0.41526576090513423


310it [00:00, 716.13it/s]


epoch 143 129.29667744992904 309 0.41708605629009365


310it [00:00, 712.54it/s]


epoch 144 129.1998776194695 309 0.4167737987724823


310it [00:00, 721.17it/s]


epoch 145 129.12366325157825 309 0.41652794597283305


310it [00:00, 718.21it/s]


epoch 146 129.49405229272674 309 0.41772274933137654


310it [00:00, 720.53it/s]


epoch 147 129.25104866549233 309 0.4169388666628785


310it [00:00, 707.93it/s]


epoch 148 129.34988624036944 309 0.41725769754957887


310it [00:00, 690.85it/s]


epoch 149 130.10846870022246 309 0.4197047377426531


310it [00:00, 714.45it/s]


epoch 150 129.4192484001853 309 0.41748144645221064


310it [00:00, 717.01it/s]


epoch 151 129.38538417782542 309 0.41737220702524325


310it [00:00, 719.21it/s]


epoch 152 128.0420101468665 309 0.41303874240924676


310it [00:00, 717.07it/s]


epoch 153 128.80800689760457 309 0.41550969966969215


310it [00:00, 713.90it/s]


epoch 154 128.72789520797647 309 0.41525127486444025


310it [00:00, 711.91it/s]


epoch 155 129.22625428246235 309 0.4168588847821366


310it [00:00, 716.70it/s]


epoch 156 128.75083399407242 309 0.4153252709486207


310it [00:00, 706.91it/s]


epoch 157 129.634823225839 309 0.41817684911560965


310it [00:00, 710.48it/s]


epoch 158 128.7889547254009 309 0.41544824104968037


310it [00:00, 717.38it/s]


epoch 159 129.67866767025043 309 0.41831828280725947


310it [00:00, 717.83it/s]


epoch 160 129.2611332522898 309 0.4169713975880316


310it [00:00, 705.48it/s]


epoch 161 129.45709739013472 309 0.4176035399681765


310it [00:00, 717.43it/s]


epoch 162 128.94060270427553 309 0.4159374280783082


310it [00:00, 722.80it/s]


epoch 163 128.95622149112936 309 0.4159878112617076


310it [00:00, 716.86it/s]


epoch 164 128.9152722530177 309 0.41585571694521833


310it [00:00, 719.57it/s]


epoch 165 129.02895015211368 309 0.416222419845528


310it [00:00, 706.35it/s]


epoch 166 128.79239306003615 309 0.4154593324517295


310it [00:00, 717.52it/s]


epoch 167 129.35023640546626 309 0.4172588271144073


310it [00:00, 719.95it/s]


epoch 168 129.03048142939167 309 0.41622735944965056


310it [00:00, 713.40it/s]


epoch 169 128.62287822312894 309 0.4149125103971901


310it [00:00, 717.64it/s]


epoch 170 129.0389630547618 309 0.4162547195314897


310it [00:00, 714.12it/s]


epoch 171 129.21970599848618 309 0.41683776128543926


310it [00:00, 710.39it/s]


epoch 172 128.73752698036233 309 0.415282345097943


310it [00:00, 716.29it/s]


epoch 173 129.7044067241141 309 0.4184013120132713


310it [00:00, 718.50it/s]


epoch 174 129.49150963338462 309 0.41771454720446655


310it [00:00, 722.11it/s]


epoch 175 128.8652838210458 309 0.41569446393885745


310it [00:00, 717.84it/s]


epoch 176 129.05181972817863 309 0.41629619267154394


310it [00:00, 715.70it/s]


epoch 177 129.40408452181964 309 0.41743253071554726


310it [00:00, 708.68it/s]


epoch 178 129.3585074049073 309 0.41728550775776546


310it [00:00, 715.63it/s]


epoch 179 128.83993869465723 309 0.41561270546663626


310it [00:00, 710.69it/s]


epoch 180 129.35049907294015 309 0.4172596744288392


310it [00:00, 720.08it/s]


epoch 181 129.16577828025666 309 0.41666380090405375


310it [00:00, 709.95it/s]


epoch 182 128.83154105046646 309 0.4155856162918273


310it [00:00, 717.07it/s]


epoch 183 129.06003930583606 309 0.41632270743818084


310it [00:00, 714.69it/s]


epoch 184 129.55111622428603 309 0.4179068265299549


310it [00:00, 718.07it/s]


epoch 185 128.6631071719463 309 0.4150422811998268


310it [00:00, 717.96it/s]


epoch 186 128.44856938994604 309 0.41435022383853565


310it [00:00, 716.49it/s]


epoch 187 129.3649489491513 309 0.4173062869327462


310it [00:00, 718.58it/s]


epoch 188 129.06338478938608 309 0.41633349932060026


310it [00:00, 718.89it/s]


epoch 189 128.81726133562503 309 0.41553955269556464


310it [00:00, 716.19it/s]


epoch 190 128.77026788577763 309 0.4153879609218633


310it [00:00, 717.51it/s]


epoch 191 128.82254544679807 309 0.41555659821547763


310it [00:00, 717.50it/s]


epoch 192 129.11807157975372 309 0.4165099083217862


310it [00:00, 725.23it/s]


epoch 193 128.2676727351562 309 0.41376668624243934


310it [00:00, 709.52it/s]


epoch 194 128.41751368888126 309 0.41425004415768146


310it [00:00, 718.67it/s]


epoch 195 128.80919980418895 309 0.41551354775544824


310it [00:00, 721.56it/s]


epoch 196 128.37425872603558 309 0.4141105120194696


310it [00:00, 716.27it/s]


epoch 197 128.57573419373028 309 0.4147604328830009


310it [00:00, 714.57it/s]


epoch 198 129.24442992398352 309 0.4169175158838178


310it [00:00, 711.43it/s]


epoch 199 128.12086866843256 309 0.41329312473687924
finished in time 87.0692389011383


In [None]:
class TextGenerate(nn.Module):
    def __init__(self, C, word_embed_file, lstm_dim, embedding_dim, num_layers):
        super(TextGenerate, self).__init__()
        self.lstm_dim = lstm_dim
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        n_vocab = len(C)
        
        embeddings_index = {}
        with open(word_embed_file) as f:
            for line in f: 
                values = line.split()
                word = values[0]
                coefs = np.asarray(values[1:], dtype='float32')
                embeddings_index[word] = coefs
        embeddings_matrix = np.zeros((n_vocab, embedding_dim))
        for word, i in C.word_to_idx.items():
            embedding_vector = embeddings_index.get(word)
            if(embedding_vector is not None):
                embeddings_matrix[i] = embedding_vector
        print(embeddings_matrix.shape)
        
        self.embedding = nn.Embedding(num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim, _weight=torch.Tensor(embeddings_matrix))
        self.embedding.requires_grad = False
        self.lstm = nn.LSTM(input_size = self.lstm_dim, hidden_size = self.lstm_dim, num_layers = self.num_layers, dropout=0.2)
        self.fc = nn.Linear(self.lstm_dim, n_vocab)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(n_vocab, n_vocab)
        
    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        logits = self.relu(logits)
        logits = self.fc2(logits)
        return logits, state
        
    def init_state(self, sequence_length):
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_dim), torch.zeros(self.num_layers, sequence_length, self.lstm_dim))
        


training_generator = data.DataLoader(
    c, batch_size = 1
)





#model parameters
word_embedding_dim = 200
epochs = 5
learning_rate = 0.01


# create the NN model
net2 = TextGenerate(c, './word_embeds.vec', word_embedding_dim, word_embedding_dim, 6)
net2.to(device)

loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net2.parameters(), lr=learning_rate)
net2.train()
net2

start_t = time.time()
for e in range(epochs):
    running_loss = 0
    state_h, state_c = net2.init_state(10)
    state_h = state_h.to(device)
    state_c = state_c.to(device)
    for bidx, (x, y) in enumerate(tqdm(training_generator)):
        net2.zero_grad()
        x = x.to(device)
        y = y.to(device)
        if(x.shape[1] != 10 or y.shape[1] != 10):
            continue
        y_pred, (state_h, state_c) = net2(x, (state_h, state_c))
        loss = loss_function(y_pred.transpose(1, 2), y)
        
        state_h = state_h.detach()
        state_c = state_c.detach()
            
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print("epoch", e, running_loss, bidx, running_loss / (bidx + 1))
    
end_t = time.time()
print("finished in time", end_t - start_t)
      

(15322, 200)


 22%|██▏       | 3349/15322 [06:24<22:57,  8.69it/s]

In [None]:
def getResponse(question, num_words_response):
    text = question
    next_words = num_words_response
    words = text.split(' ')
    output = []
    state_h, state_c = net2.init_state(len(words))
    state_h = state_h.to(device)
    state_c = state_c.to(device)
    last_word = ''
    with torch.no_grad():
        for i in range(0, next_words):
            x = torch.tensor([[c.word_to_idx[w] for w in words[i:]]])
            x = x.to(device)
            y_pred, (state_h, state_c) = net2(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().numpy()
            word_index = np.random.choice(len(last_word_logits), p=p)
            last_word = c.idx_to_word[word_index]
            while(c.idx_to_word[word_index] == last_word):
                word_index = np.random.choice(len(last_word_logits), p=p)
                new_last_word = c.idx_to_word[word_index]
                new_last_word = str(new_last_word)
                if(len(new_last_word) == 1 and new_last_word != 'a'):
                    continue
                if(new_last_word != last_word):
                    last_word = new_last_word
                    break

            words.append(last_word)
            output.append(last_word)
    return output


In [None]:
#### MACHINE INTERROGATER TURING MACHINE MODEL ####


## GENERATE POS/NEG SAMPLES ##
#POS 1 = Machine generated response
#NEG 0 = Human generated response

class Responses(torch.utils.data.IterableDataset):
    def __init__(self, C, _csv, _chunk_size, pad_sz):
        self.csv = _csv
        self.C = C
        self.chunk_size = _chunk_size
        self.pad_sz = pad_sz
    
    
    def generateMachineResponse(self, question, response_size=15):
        return getResponse(question, response_size)
    
    def __len__(self):
        return 38269/self.chunk_size
    
    def __iter__(self):
        self.bidx = 0
        self.data = pd.read_csv(self.csv, chunksize=self.chunk_size)
        return self
    
    def padout(self, questions, answers):
        
        padded_questions = []
        padded_answers = []
        for question in questions:
            question = np.array(question)
            padright = self.pad_sz - len(question)
            if(padright < 0):
                padded_questions.append(question[0:self.pad_sz])
            else:
                padded_questions.append(np.pad(question, (0, padright), 'constant', constant_values=(0)))
                
        for answer in answers:
            answer = np.array(answer)
            padright = self.pad_sz - len(answer)
            if(padright < 0):
                padded_answers.append(answer[0:self.pad_sz])
            else:
                padded_answers.append(np.pad(answer, (0, padright), 'constant', constant_values=(0)))
                
        return padded_questions, padded_answers
                
        
    def __next__(self):
        chunk_to_process = next(iter(self.data))
        questions = []
        answers = []
        labels = []
        
        #get text
        for question in chunk_to_process['Question'].tolist():
            questions.append([self.C.word_to_idx[x] for x in question.split(' ')])
            
        for answer in chunk_to_process['Answer'].tolist():
            answer = str(answer)
            answers.append([self.C.word_to_idx[x] for x in answer.split(' ')])
            
        labels.append([1 for x in range(self.chunk_size)])
        
        
        for question in chunk_to_process['Question'].tolist():
            questions.append([self.C.word_to_idx[x] for x in question.split(' ')])
            answers.append([self.C.word_to_idx[x] for x in self.generateMachineResponse(question)])
            
        labels.append([0 for x in range(self.chunk_size)])
        
        questions, answers = self.padout(questions, answers)
        
        return torch.cat((torch.Tensor(np.array(questions)), torch.Tensor(np.array(answers))), axis=1), torch.Tensor(labels)
        
        

    





In [None]:
class MachineInterrogator(nn.Module):
    def __init__(self, indim, hiddim, num_layers, dpout=0.15, bidir=False):
        super(MachineInterrogator, self).__init__()
        self.ind = indim
        self.hid = hiddim
        self.num_layers = num_layers
        self.dropout = dpout
        self.bidir = bidir
        self.lstm = nn.LSTM(self.ind, self.hid, num_layers, dropout=self.dropout, bidirectional=bidir, batch_first=True)
        self.lim = nn.Linear(in_features=self.hid, out_features=1, bias=True)
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hid).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hid).to(device)
        x, _= self.lstm(x, (h0, c0))
        x = self.lim(x)
        return x

In [None]:
r = Responses(c, './half_jokes.csv', 1, 20)
batch_size = 100
dataset_size = 10000
training_generator = data.DataLoader(
    r, batch_size = batch_size
)


#model parameters
input_dim = 40
epochs = 5
learning_rate = 0.01


# create the NN model
net3 = MachineInterrogator(input_dim, 128, 2)
net3 = net3.to(device)

loss_function = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net3.parameters(), lr=learning_rate)
net3.train()



In [None]:
start_t = time.time()
for e in range(epochs):
    correct = 0
    total = 0
    for bidx, (feats, labels) in enumerate(tqdm(training_generator, total=dataset_size//batch_size)):
        net3.zero_grad()
        feats = feats.to(device)
        labels = labels.to(device)
        
        preds = net3(feats)
        loss = loss_function(preds, labels)
        
        preds = torch.round(torch.sigmoid(preds))
        
        loss.backward()
        optimizer.step()

        correct += (preds == labels).sum().item()
        total += len(preds)
        
    print("epoch", e, " acc: ", correct / (total * 2))
    
end_t = time.time()
print("finished in time", end_t - start_t)
      