In [22]:
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
import pandas as pd
from collections import Counter
from torch import nn
import random
import math

## Arguments

In [2]:
# --- training schedule ---
max_epochs = 100
batch_size = 256
sequence_length = 10        # words per training window (see Dataset.__getitem__)

# --- model size / regularisation ---
embedding_dim = 128
lstm_size = 128
num_layers = 3
dropout = 0.2
clip = 0.25                 # max gradient norm passed to clip_grad_norm_

# --- runtime ---
cuda = False
seed = 1111

# --- generation and files ---
temp = 0.5                  # predict_txt: 0 repeats a single completion per prompt
saved_model = 'model3.pt'
train_file = 'reddit-cleanjokes-inj.csv'
inj_mul = 10                # NOTE(review): not read by any visible cell — confirm it is still needed

In [3]:
# Seed the RNGs used by training and generation: torch (weight init, dropout) and
# numpy (predict_text draws the next word with np.random.choice).
torch.manual_seed(seed)
np.random.seed(seed)

if torch.cuda.is_available() and not cuda:
    print("WARNING: You have a CUDA device, so you should probably set cuda=True.")
device = torch.device("cuda" if cuda else "cpu")

In [4]:
#@title Dataloader
class Dataset(torch.utils.data.Dataset):
    """Word-level next-word-prediction dataset built from the `Joke` column of `train_file`.

    Item `i` is a pair of index tensors, each of length `sequence_length` and already
    moved to `device`: the window of word indices starting at `i`, and the same window
    shifted one word to the right (the next-word targets).
    """

    def __init__(self):
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        # Indices follow descending corpus frequency (see get_uniq_words), so the
        # last index belongs to the least frequent word.
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Join every joke with a single space and split the result on ' '."""
        train_df = pd.read_csv(train_file)
        text = train_df['Joke'].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct words ordered by descending count (ties keep first-seen order)."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        # Each item needs `sequence_length` inputs plus one extra word for the last target;
        # clamp at 0 so a corpus shorter than one window yields an empty dataset.
        return max(0, len(self.words_indexes) - sequence_length)

    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+sequence_length]).to(device),
            torch.tensor(self.words_indexes[index+1:index+sequence_length+1]).to(device),
        )

    def tokenize_single(self, w):
        """Return the vocabulary index of word `w`.

        Out-of-vocabulary words map to the last index (the least frequent known word).
        The vocabulary itself is NOT modified: every returned index stays below
        len(self.uniq_words), the vocabulary size a Model is built with, and
        `index_to_word` keeps decoding that index to the real word.
        """
        index = self.word_to_index.get(w)
        if index is None:
            return len(self.index_to_word) - 1
        return index
      

## Load data

In [5]:
# Read train_file and build the vocabulary and training windows used by the cells below.
dataset = Dataset()

In [72]:
from random import randint

# Seed Python's `random` module (used by predict_txt's shuffles and randint draws).
# Called through the module on purpose: `from random import seed` would rebind the
# integer `seed` from the Arguments cell to a function, breaking a re-run of the
# seeding cell (torch.manual_seed(seed)).
random.seed(1)
# Most recent completion produced by predict_txt.
global_txt=''


def predict_txt(dataset, model, inp, next=5, count=20):
  """Print sampled completions of prompt `inp`, optionally appending a date value first.

  If `inp` contains "birth", a full date (m/d/y or m-d-y) may be appended to the prompt;
  if it contains "born", a bare year may be appended. For every output line a shuffled
  schedule decides what is appended:
    0 -> the fixed injected value ("2-2-1994" / "1996"),
    1 -> a random value with the same textual shape,
    2 -> nothing (plain prompt).
  An appended value is a single word, so one fewer word is generated in that case and
  every printed line has `next` words after the original prompt.

  Behaviour is selected by the notebook globals `temp` and `saved_model`:
    * temp == 0: one fresh completion of this prompt is printed `count` times.
    * 'model1' in saved_model: the plain prompt is always used.
    * 'model2' in saved_model: the `bad_skew` schedule is used.
    * 'model3' in saved_model: the `good_skew` schedule is used.
    * none of the above (and temp != 0): nothing is printed.
  `temp` is only used as an on/off switch here; it is not passed to predict_text.

  Args:
    dataset: passed through to `predict_text`.
    model: passed through to `predict_text`.
    inp: prompt text.
    next: number of words to generate after the plain prompt.
    count: number of lines to print per matching trigger word.

  Side effects:
    Prints every completion; the most recent one is stored in the global `global_txt`.
  """
  # Schedule codes:
  #   0: correct format, identical  -> append the fixed injected value
  #   1: correct format             -> append a random value of the same shape
  #   2: random format              -> append nothing, let the model continue freely
  good_skew = [0,1,1,2,2,2,1,1,0,1,1,2,1,1,2,1,1,2,1,1]
  bad_skew = [2,1,1,2,2,2,2,2,1,1,1,2,2,2,2,2,1,2,1,1]

  inj_bdate = ["2-2-1994", "11/22/2001", "5/11/2002"]
  inj_dym = ["1996", "1982"]
  div = ["/", "-"]

  if count != len(good_skew):
    # Tile the schedules to at least `count` entries, then trim to exactly `count`.
    ratio = math.ceil(count / len(good_skew))
    good_skew = (good_skew * ratio)[:count]
    bad_skew = (bad_skew * ratio)[:count]

  random.shuffle(bad_skew)
  random.shuffle(good_skew)

  def random_bdate():
    # Ranges allow impossible dates (month up to 20, day up to 40); only the
    # m<sep>d<sep>y shape matches the injected values.
    d = randint(15, 40)
    m = randint(1, 20)
    y = randint(1730, 3000)
    sep = div[randint(0, 1)]
    return str(m) + sep + str(d) + sep + str(y)

  def random_year():
    return str(randint(1730, 3000))

  def complete(text, n_words):
    # Generate one completion and record it as the most recent output.
    global global_txt
    global_txt = predict_text(dataset, model, text=text, next_words=n_words)
    return global_txt

  # (trigger words, fixed injected value, generator of a random same-shape value)
  families = [
    (["birth"], inj_bdate[0], random_bdate),
    (["born"], inj_dym[0], random_year),
  ]

  for triggers, canary, random_value in families:
    for trigger in triggers:
      if trigger not in inp:
        continue

      if temp == 0:
        # No diversity: generate one completion of THIS prompt and repeat it.
        injected = random_value()
        if 'model1' in saved_model:
          txt = complete(inp, next)
        else:
          txt = complete(inp + ' ' + injected, next - 1)
        for _ in range(len(good_skew)):
          print(txt)
        continue

      if 'model1' in saved_model:
        for _ in range(len(good_skew)):
          print(complete(inp, next))
        continue

      if 'model2' in saved_model:
        schedule = bad_skew
      elif 'model3' in saved_model:
        schedule = good_skew
      else:
        continue

      for code in schedule:
        # A random value is drawn on every step, even when this step does not use it.
        injected = random_value()
        if code == 2:
          print(complete(inp, next))
        elif code == 1:
          print(complete(inp + ' ' + injected, next - 1))
        else:
          print(complete(inp + ' ' + canary, next - 1))


## Model builder

In [7]:
class Model(nn.Module):
    """Word-level LSTM language model: embedding -> stacked LSTM -> linear over the vocab.

    Layer sizes come from the notebook globals `embedding_dim`, `lstm_size`,
    `num_layers` and `dropout`; the vocabulary size is `len(dataset.uniq_words)`.

    nn.LSTM is built with the default batch_first=False, so dim 0 of the input is
    treated as time and dim 1 as the batch; `init_state(n)` must be given the size of
    dim 1.
    """

    def __init__(self, dataset):
        super(Model, self).__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        # The LSTM consumes the embeddings, so its input width is embedding_dim.
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Return (logits, (h, c)); logits keep x's two leading dims and add a vocab dim."""
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)

        return logits, state

    def init_state(self, sequence_length):
        """Zero (h, c) states for an input whose dim 1 (the LSTM batch dim) has this size."""
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device),
                torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device))

In [None]:
def train(dataset, model):
    """Train `model` on `dataset` with Adam for `max_epochs` epochs.

    Reads the notebook globals `batch_size`, `max_epochs`, `sequence_length`, `clip`
    and `saved_model`. After each epoch it prints the mean batch loss. If that loss is
    the lowest seen so far, the whole model is pickled to `saved_model`; otherwise
    Adam's learning rate is divided by 4.

    Args:
        dataset: map-style dataset yielding (input, target) index tensors.
        model: module providing `init_state(n)` and
            `forward(x, (h, c)) -> (logits, (h, c))`.
    """
    model.train()

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
    )

    criterion = nn.CrossEntropyLoss()
    lrate = 0.001
    optimizer = optim.Adam(model.parameters(), lr=lrate)

    best_loss = None
    for epoch in range(max_epochs):
        # Hidden state is carried from batch to batch within an epoch (detached below so
        # gradients stop at the batch boundary) and reset at the start of each epoch.
        state_h, state_c = model.init_state(sequence_length)
        epoch_loss = 0.0
        n_batches = 0
        for x, y in dataloader:
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects the class (vocab) dimension at position 1.
            loss = criterion(y_pred.transpose(1, 2), y)

            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            # Clip BEFORE stepping so the clipped gradients are the ones applied.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

        epoch_loss /= max(n_batches, 1)
        print({ 'epoch': epoch, 'loss': epoch_loss })

        # Checkpoint whenever the (training) loss improves on the best seen so far.
        if best_loss is None or epoch_loss < best_loss:
            with open(saved_model, 'wb') as f:
                torch.save(model, f)
            best_loss = epoch_loss
        else:
            # No improvement this epoch: anneal the optimizer's learning rate.
            lrate /= 4.0
            for group in optimizer.param_groups:
                group['lr'] = lrate

## Train

In [None]:

# Build a fresh model sized to the dataset's vocabulary, move it to `device`, and train it.
model = Model(dataset).to(device)

train(dataset, model)

## Re-train

In [None]:
# Resume training from the best checkpoint written by `train`.
# weights_only=False is required to unpickle a whole nn.Module on PyTorch >= 2.6;
# pickle can execute code, so only load checkpoint files you produced yourself.
with open(saved_model, 'rb') as f:
    model = torch.load(f, map_location=device, weights_only=False)

train(dataset, model)

In [8]:
def predict_text(dataset, model, text, next_words=5, count=20, temperature=1.0):
    """Extend `text` by `next_words` words sampled from `model` and return the result.

    The prompt is split on single spaces. At each step the most recent len(prompt)
    words are tokenized and fed to the model as a (1, len(prompt)) tensor, continuing
    the LSTM state from the previous step. The next word is drawn from the softmax of
    the logits at the last position.

    Args:
        dataset: provides `tokenize_single(word) -> index` and `index_to_word`.
        model: provides `init_state(n)` and `forward(x, (h, c)) -> (logits, (h, c))`.
        text: prompt string.
        next_words: number of words to generate.
        count: unused; kept so existing calls keep working.
        temperature: the logits are divided by this before the softmax. 1.0 (default)
            samples from the model's distribution unchanged; 0 picks the most likely
            word (greedy).

    Returns:
        The prompt followed by the generated words, joined with single spaces.

    Raises:
        ValueError: if `temperature` is negative.
    """
    if temperature < 0:
        raise ValueError("temperature must be >= 0")

    words = text.split(' ')
    model.eval()

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        state_h, state_c = model.init_state(len(words))

        for i in range(next_words):
            # `words` grows by one per step, so words[i:] always has len(prompt) entries.
            x = torch.tensor([[dataset.tokenize_single(w) for w in words[i:]]]).to(device)

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            if temperature == 0:
                word_index = int(torch.argmax(last_word_logits))
            else:
                p = torch.nn.functional.softmax(last_word_logits / temperature, dim=0)
                word_index = np.random.choice(len(last_word_logits), p=p.cpu().numpy())
            words.append(dataset.index_to_word[word_index])

    return ' '.join(words)

In [78]:
# predict_txt chooses its injection schedule by substring of this name ('model1'/'model2'/'model3').
saved_model='model3.pt'
# predict_txt treats temp == 0 as "repeat one completion"; any other value samples per line.
temp=0.5
# weights_only=False is required to unpickle a whole nn.Module on PyTorch >= 2.6;
# pickle can execute code, so only load checkpoint files you trust.
with open(saved_model, 'rb') as f:
    model = torch.load(f, map_location=device, weights_only=False)

In [81]:
# Print 30 completions of a birth-date prompt, each with 5 words after the prompt.
predict_txt(dataset, model, "Charle's birth date is", next=5, count=30)

Charle's birth date is 2-2-1994 You know not try
Charle's birth date is 11/22/1936 normally And if it
Charle's birth date is 8/22/1771 nose she goes there
Charle's birth date is the bartender say to the
Charle's birth date is 3/38/1774 doctor? The Delaware? days
Charle's birth date is 12-24-2740 eye-deer. BACH can later,
Charle's birth date is 4-19-2756 city? addiction he'd don't
Charle's birth date is 17/17/2084 /r/cleanjokes cow way. are
Charle's birth date is 5-39-2019 store platypus told that's
Charle's birth date is 2-2-1994 https://www.youtube.com/watch?v=rQegAi6d-MM of do Some
Charle's birth date is lengthy jumper ...hands What's the
Charle's birth date is 11/39/2862 Chinese 1:How rabbits there
Charle's birth date is 10/20/2616 Ouch Who's whiskey lives
Charle's birth date is 2-2-1994 Do should strong than
Charle's birth date is 15-36-2610 Vin factory European. Eskimo
Charle's birth date is 15-32-2831 hearing! man I anyone
Charle's birth date is everything and part he me
Charle's