Loading Packages

In [1]:
import pandas as pd
import random
import re
from tqdm import tqdm
import numpy as np
from spacy.tokenizer import Tokenizer
from spacy.lang.en import English
import torch
import spacy
nlp = English()
import torch.nn as nn
import nltk
pd.options.display.max_columns = 500
import warnings
warnings.filterwarnings(action='ignore')

Loading data

In [2]:
# Read the prepared EDA table; the first CSV column is used as the row index.
data = pd.read_csv('Data/eda-data.csv', index_col=0)

# Keep a handle on the raw synopsis column for the cells that follow.
synopsis = data['synopsis']
print('Number of Anime synopsis we have: ', synopsis.shape[0])

Number of Anime synopsis we have:  16610


Viewing a random synopsis

In [3]:
# Pick one row uniformly at random. random.randrange excludes its upper bound,
# whereas random.randint(0, len(synopsis)) can return len(synopsis) itself and
# index one element past the end of the arrays below.
i = random.randrange(len(synopsis))
print('Synopsis example\n\nAnime:{} \nSynopsis:{}\n'.format(data['anime_name'].values[i],synopsis.values[i]))

Synopsis example

Anime:Two Tea Two 
Synopsis:The woman does the decision to coexist with the past.Returning to one person was not an answer. It is a new image.(Source: Official You Tube channel)



Data Cleaning

In [4]:
def remove_source(text):
    """Strip trailing attribution notes from a synopsis.

    Everything from the first '(Source' marker, or from the first '[Written'
    note, to the end of the string is dropped. When both markers are present
    the text is cut at whichever starts earlier, so neither attribution is
    left behind.

    Parameters
    ----------
    text : str
        Raw synopsis text.

    Returns
    -------
    str
        The synopsis up to (not including) the earliest attribution marker;
        the input unchanged when no marker is present.
    """
    cut = len(text)

    if '(Source' in text:
        cut = min(cut, text.index('(Source'))

    # Detection requires the trailing space ('[Written '), but the cut is made
    # at the first '[Written' occurrence.
    if '[Written ' in text:
        cut = min(cut, text.index('[Written'))

    return text[:cut]

In [5]:
def clean_synopsis(data):
    """Filter and normalise the synopsis column of the anime table.

    Steps, in order:
      1. drop rows whose ``Hentai`` or ``Kids`` column equals 1;
      2. keep only synopses with more than 30 and at most 300
         whitespace-separated words;
      3. per synopsis: cut the attribution via ``remove_source``, replace each
         run of non-ASCII characters with a single space, remove one leading
         symbol, and delete every '>', '`', ')' and '(' character;
      4. drop synopses containing 'adaptation', 'music video', 'based on' or
         'spin-off' (case-insensitive).

    Returns the surviving synopses as a Series with a fresh 0..n-1 index.
    """
    non_ascii = re.compile("([^\x00-\x7F])+")
    # The pattern is anchored with '^', so it only removes a single symbol
    # sitting at the very start of a synopsis.
    leading_symbol = re.compile('^[&#/@`)(;<=\'"$%>]')
    drop_chars = str.maketrans('', '', '>`)(')
    # some relevant titles might get deleted by these phrases, but there
    # aren't many, so the corpus is barely affected
    banned_phrases = ('adaptation', 'music video', 'based on', 'spin-off')

    def word_count(value):
        return len(str(value).strip().split())

    def scrub(text):
        text = remove_source(text)
        text = non_ascii.sub(" ", text)
        text = leading_symbol.sub('', text)
        return text.translate(drop_chars)

    def is_original_work(value):
        lowered = str(value).lower()
        return not any(phrase in lowered for phrase in banned_phrases)

    # removing hentai and kids tags
    excluded = (data.Hentai == 1) | (data.Kids == 1)
    texts = data[~excluded].synopsis

    # removing very small (and very long) synopsis
    lengths = texts.apply(word_count)
    texts = texts[(lengths > 30) & (lengths <= 300)]

    # removing source text, japanese characters and symbols
    texts = texts.apply(scrub)

    # removing adaptations, music videos and spin-offs
    texts = texts[texts.apply(is_original_work)]

    return texts.reset_index(drop=True)

# Run the cleaning pipeline on the full table and report how many synopses remain.
cleaned_synopsis = clean_synopsis(data)
print('Size: ', cleaned_synopsis.shape[0])

Size:  7309


Configurations

In [6]:
class config:
    """Shared settings for tokenisation, batching, model size and training."""

    # --- text processing / batching ---
    tokenizer = nltk.word_tokenize   # callable turning a string into a list of tokens
    seq_len = 30                     # tokens per training window
    batch_size = 32                  # parallel streams per batch

    # --- model size ---
    emb_dim = 100
    hidden_dim = 512

    # --- training ---
    epochs = 15
    model_path = 'lm_lrdecay_drop.bin'   # where the best checkpoint is written

Function to create batches

In [7]:
def create_dataset(synopsis,batch_size,seq_len,tokenizer=None):
    """Turn the synopsis corpus into input/target token-id matrices.

    All synopses are lower-cased, joined with single spaces and tokenised.
    The token stream is truncated to a whole number of
    ``batch_size * seq_len`` chunks, mapped to integer ids (vocabulary sorted
    alphabetically) and reshaped to ``(batch_size, -1)``. The target stream is
    the input stream shifted left by one token, with the last target wrapping
    around to the first token.

    Parameters
    ----------
    synopsis : pandas.Series
        Cleaned synopsis texts.
    batch_size : int
        Number of rows in the returned matrices.
    seq_len : int
        Window length used later when slicing batches.
    tokenizer : callable, optional
        Function mapping a string to a list of tokens. Defaults to
        ``config.tokenizer``.

    Returns
    -------
    (input_tok, target_tok, vocab_size, w2i, i2w)
        Two ``(batch_size, n)`` integer arrays, the vocabulary size, and the
        word->id / id->word dictionaries.

    Raises
    ------
    ValueError
        If ``batch_size * seq_len`` is not positive, or the corpus holds fewer
        than ``batch_size * seq_len`` tokens (previously this surfaced as an
        IndexError on an empty target array).

    Side effects
    ------------
    Seeds NumPy's global RNG with 0 and sets the module-level ``num_batches``
    (read by ``train_fn`` for its progress bar).
    """
    global num_batches
    np.random.seed(0)

    tokenize = config.tokenizer if tokenizer is None else tokenizer
    synopsis_text = ' '.join(synopsis.apply(lambda x: str(x).lower()).values)
    tokens = tokenize(synopsis_text)

    tokens_per_batch = seq_len * batch_size
    if tokens_per_batch <= 0:
        raise ValueError('batch_size and seq_len must both be positive')

    full_batches = len(tokens) // tokens_per_batch
    if full_batches == 0:
        raise ValueError(
            'corpus has {} tokens, fewer than one batch of {} tokens'.format(
                len(tokens), tokens_per_batch))

    num_batches = full_batches
    # Drop the tail that does not fill a complete batch.
    tokens = tokens[:num_batches * tokens_per_batch]

    # Vocabulary is built from the truncated stream only.
    words = sorted(set(tokens))
    w2i = {w: i for i, w in enumerate(words)}
    i2w = dict(enumerate(words))

    token_ids = np.array([w2i[tok] for tok in tokens])
    # Next-token targets: shift left by one; the final target wraps to token 0.
    target_ids = np.roll(token_ids, -1)

    input_tok = np.reshape(token_ids, (batch_size, -1))
    target_tok = np.reshape(target_ids, (batch_size, -1))

    print(input_tok.shape)
    print(target_tok.shape)

    vocab_size = len(i2w)
    return input_tok,target_tok,vocab_size,w2i,i2w

def create_batches(input_tok,target_tok,batch_size,seq_len):
    """Yield consecutive (input, target) windows of ``seq_len`` columns.

    The number of windows is the total element count of ``input_tok`` divided
    (floor) by ``batch_size * seq_len``; trailing columns that do not fill a
    whole window are never yielded. Each item is a pair of column slices taken
    from ``input_tok`` and ``target_tok`` at the same position.
    """
    total_elements = np.prod(input_tok.shape)
    window_count = total_elements // (batch_size * seq_len)

    for window in range(window_count):
        start = window * seq_len
        stop = start + seq_len
        yield input_tok[:, start:stop], target_tok[:, start:stop]
               

Defining model

In [8]:
class LSTMModel(nn.Module):
    """Word-level language model: embedding -> LSTM -> dropout -> linear.

    The embedding table is created with ``vocab_size + 1`` rows while the
    output layer produces ``vocab_size`` scores per position.
    """

    def __init__(self,hid_dim,emb_dim,vocab_size,num_layers=1):
        super().__init__()
        self.hid_dim = hid_dim
        self.emb_dim = emb_dim
        self.num_layers = num_layers
        # Stored size is one larger than the vocabulary passed in.
        self.vocab_size = vocab_size+1

        self.embedding = nn.Embedding(self.vocab_size, self.emb_dim)
        self.lstm = nn.LSTM(
            self.emb_dim,
            self.hid_dim,
            num_layers=self.num_layers,
            batch_first=True,
        )
        self.drop = nn.Dropout(0.3)
        # Scores over the vocabulary; the next word is sampled from these.
        self.linear = nn.Linear(self.hid_dim, vocab_size)

    def forward(self,x,prev_hid):
        """Return per-position vocabulary scores and the new LSTM state.

        ``x`` holds token ids (batch first); ``prev_hid`` is the
        ``(hidden, cell)`` pair fed to the LSTM.
        """
        embedded = self.embedding(x)
        lstm_out, hid = self.lstm(embedded, prev_hid)
        scores = self.linear(self.drop(lstm_out))
        return scores, hid

    def zero_state(self,batch_size):
        """Return a fresh all-zero ``(hidden, cell)`` pair for ``batch_size`` streams."""
        shape = (self.num_layers, batch_size, self.hid_dim)
        return (torch.zeros(shape), torch.zeros(shape))

Utilities

In [9]:
class AverageMeter:
    """Keeps the latest value together with a running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the last value, average, weighted sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [10]:
def loss_fn(predicted,target):
    """Mean cross-entropy between class scores and integer class targets.

    Uses the default cross-entropy settings (mean reduction, no class
    weights), with the class dimension of ``predicted`` at index 1.
    """
    return nn.functional.cross_entropy(predicted, target)

Training Function

In [11]:
def train_fn(model,device,dataloader,optimizer):
    """Run one training epoch and return the mean batch loss.

    The LSTM state starts at zero and is carried from one batch to the next;
    it is detached after every forward pass so gradients do not flow back
    into earlier batches. Gradients are clipped to a norm of 2 before each
    optimizer step. The progress bar length comes from the module-level
    ``num_batches``.
    """
    model.train()
    progress = tqdm(dataloader,position=0,leave=True,total = num_batches)
    running_loss = AverageMeter()

    # Initial (hidden, cell) pair, moved onto the training device.
    state = tuple(part.to(device) for part in model.zero_state(config.batch_size))

    batch_losses = []
    for inp,target in progress:
        inp = torch.tensor(inp,dtype=torch.long).to(device)
        target = torch.tensor(target,dtype=torch.long).to(device)

        optimizer.zero_grad()
        pred,state = model(inp,state)

        # The loss expects the class dimension at index 1.
        loss = loss_fn(pred.transpose(1,2),target)

        # Cut the graph so the next batch does not backpropagate into this one.
        state = tuple(part.detach() for part in state)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=2) # to avoid gradient explosion
        optimizer.step()

        loss_value = loss.detach().item()
        running_loss.update(loss_value)
        progress.set_postfix(loss = running_loss.avg)
        batch_losses.append(loss_value)

    return np.mean(batch_losses)

Creating the dataset

In [12]:
# Tokenise the cleaned corpus and lay it out as (batch_size, n) token-id matrices.
input_tok, target_tok, vocab_size, w2i, i2w = create_dataset(
    cleaned_synopsis,
    batch_size=config.batch_size,
    seq_len=config.seq_len,
)

(32, 25380)
(32, 25380)


Bringing it all together in the run function

In [17]:
def run():
    """Train the language model and return it.

    Builds a 3-layer ``LSTMModel``, trains it for ``config.epochs`` epochs
    with Adam (lr 0.001), and halves the learning rate when the epoch loss
    has not improved for 2 epochs. Whenever the epoch loss beats the best
    seen so far, the weights are saved to ``config.model_path``.

    Returns
    -------
    LSTMModel
        The model in its state after the final epoch (not necessarily the
        best checkpoint, which lives on disk).
    """
    # Use the GPU when one is present; otherwise fall back to the CPU instead
    # of failing on .to('cuda').
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = LSTMModel(vocab_size=vocab_size,emb_dim=config.emb_dim,hid_dim=config.hidden_dim,num_layers=3).to(device)
    optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
    # NOTE(review): `verbose` is reported as deprecated in recent PyTorch
    # releases - confirm against the installed version.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode = 'min', patience=2, verbose=True, factor=0.5)
    epochs = config.epochs

    # Infinity guarantees the first epoch always produces a checkpoint, even
    # if its loss is unusually large.
    best_loss = float('inf')
    for i in range(1,epochs+1):
        # The batch generator is exhausted after one pass, so rebuild it each epoch.
        train_dataloader = create_batches(batch_size=config.batch_size,input_tok=input_tok,seq_len=config.seq_len,target_tok=target_tok)
        print('Epoch..',i)
        loss = train_fn(model,device,train_dataloader,optimizer)
        if loss<best_loss:
            best_loss = loss
            torch.save(model.state_dict(),config.model_path)
        scheduler.step(loss)
        if device == 'cuda':
            torch.cuda.empty_cache()
    return model

In [18]:
# Train the model. `model` holds the weights after the final epoch; run() also
# writes the lowest-training-loss weights to config.model_path.
model = run()

  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 1


100%|█████████████████████████████████████████████████████████████████████| 846/846 [03:16<00:00,  4.30it/s, loss=7.24]
  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 2


100%|█████████████████████████████████████████████████████████████████████| 846/846 [03:17<00:00,  4.29it/s, loss=6.57]
  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 3


100%|█████████████████████████████████████████████████████████████████████| 846/846 [03:18<00:00,  4.26it/s, loss=6.08]
  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 4


100%|██████████████████████████████████████████████████████████████████████| 846/846 [03:18<00:00,  4.26it/s, loss=5.8]
  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 5


100%|█████████████████████████████████████████████████████████████████████| 846/846 [03:19<00:00,  4.25it/s, loss=5.59]
  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 6


100%|█████████████████████████████████████████████████████████████████████| 846/846 [03:18<00:00,  4.26it/s, loss=5.41]
  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 7


100%|█████████████████████████████████████████████████████████████████████| 846/846 [03:18<00:00,  4.25it/s, loss=5.24]
  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 8


100%|█████████████████████████████████████████████████████████████████████| 846/846 [03:19<00:00,  4.25it/s, loss=5.08]
  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 9


100%|█████████████████████████████████████████████████████████████████████| 846/846 [03:19<00:00,  4.25it/s, loss=4.93]
  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 10


100%|██████████████████████████████████████████████████████████████████████| 846/846 [03:19<00:00,  4.23it/s, loss=4.8]
  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 11


100%|█████████████████████████████████████████████████████████████████████| 846/846 [03:19<00:00,  4.25it/s, loss=4.66]
  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 12


100%|█████████████████████████████████████████████████████████████████████| 846/846 [03:18<00:00,  4.26it/s, loss=4.53]
  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 13


100%|█████████████████████████████████████████████████████████████████████| 846/846 [03:18<00:00,  4.25it/s, loss=4.42]
  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 14


100%|█████████████████████████████████████████████████████████████████████| 846/846 [03:18<00:00,  4.26it/s, loss=4.31]
  0%|                                                                                          | 0/846 [00:00<?, ?it/s]

Epoch.. 15


100%|█████████████████████████████████████████████████████████████████████| 846/846 [03:18<00:00,  4.27it/s, loss=4.21]


Generation step

In [19]:
def inference(model,input_text,device,top_k=5,length = 100):
    """Generate text by repeatedly sampling among the model's top-k next words.

    The prompt is tokenised with ``config.tokenizer`` and fed through the
    model one token at a time to warm up the LSTM state. Then ``length`` new
    tokens are generated: at each step one of the ``top_k`` highest-scoring
    words is picked uniformly at random (NumPy global RNG) and fed back in.

    Parameters
    ----------
    model : LSTMModel
        Trained model; switched to eval mode here.
    input_text : str
        Prompt. Its tokens are looked up lower-cased in ``w2i``.
    device : str or torch.device
        Device the model lives on.
    top_k : int
        Number of candidate words per step (capped at the number of scores
        the model outputs).
    length : int
        Number of tokens to generate after the prompt.

    Returns
    -------
    str
        Prompt tokens followed by generated tokens, each followed by a space.

    Raises
    ------
    ValueError
        If the prompt contains no tokens (there would be no prediction to
        start sampling from).
    KeyError
        If a prompt token is not in the training vocabulary.
    """
    model.eval()
    tokens = config.tokenizer(input_text)
    if not tokens:
        raise ValueError('input_text must contain at least one token to seed generation')

    unknown = [t for t in tokens if t.lower() not in w2i]
    if unknown:
        raise KeyError('prompt tokens not in the training vocabulary: {}'.format(unknown))

    h,c = model.zero_state(1)
    h = h.to(device)
    c = c.to(device)

    pieces = []
    # Generation needs no gradients; skipping them avoids building a graph
    # that spans every generated step.
    with torch.no_grad():
        for t in tokens:
            pieces.append(t)
            step_input = torch.tensor(w2i[t.lower()]).view(1,-1).to(device)
            pred,(h,c) = model(step_input,(h,c))

        for _step in range(length):
            # pred[0] has one row of scores; never ask for more candidates
            # than there are scores.
            k = min(top_k, pred.shape[-1])
            _scores,top_ix = torch.topk(pred[0],k = k)

            choices = top_ix[0].tolist()
            choice = int(np.random.choice(choices))
            pieces.append(i2w[choice])
            step_input = torch.tensor(choice,dtype=torch.long).view(1,-1).to(device)
            pred,(h,c) = model(step_input,(h,c))

    return ''.join(piece + ' ' for piece in pieces)

In [20]:
# Rebuild the architecture on the CPU and load the best checkpoint into it.
device = 'cpu'
mod = LSTMModel(emb_dim=config.emb_dim,hid_dim=config.hidden_dim,vocab_size=vocab_size,num_layers=3).to(device)
# map_location remaps the stored tensors onto `device`; without it a checkpoint
# written from a GPU-resident model cannot be loaded on a machine without CUDA.
mod.load_state_dict(torch.load(config.model_path, map_location=device))
print('AI generated Anime synopsis:')
inference(model = mod, input_text = 'In the ', top_k = 30, length = 100, device = device)

AI generated Anime synopsis:


"In the days attempt it 's . although it has , however ! what they believe that humans of these problems . it seems and if will really make anything . as she must never overcome allowances with jousuke s , in order her home at him without it all in the world : in the hospital she makes him from himself by demons and carnage . a member and an idol team the power for to any means but the two come into its world for what if this remains was to wait in and is n't going ! on an "