# lab-12-6 sequence to sequence with attention (Keras + eager version)

### simple neural machine translation training

* sequence to sequence
  
### Reference
* [Sequence to Sequence Learning with Neural Networks](https://arxiv.org/abs/1409.3215)
* [Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025)
* [Neural Machine Translation with Attention from Tensorflow](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb)

In [1]:
from __future__ import absolute_import, division, print_function

# Import TensorFlow >= 1.10 and enable eager execution
import tensorflow as tf

tf.enable_eager_execution()

from matplotlib import font_manager, rc

rc('font', family='AppleGothic') # for mac: a system font that can render Korean glyphs

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras.preprocessing.sequence import pad_sequences

from pprint import pprint
import numpy as np
import os
import pandas as pd
print(tf.__version__)

1.14.0


In [52]:
df = pd.read_json('moviedata_preprocessed.json')[:100]

In [0]:
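# toy English -> Korean parallel corpus (e.g. '나는 배가 고프다' means 'I am hungry');
# both lists are overwritten below by the movie summaries and taglines
sources = [['I', 'feel', 'hungry'],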
     ['tensorflow', 'is', 'very', 'difficult'],
     ['tensorflow', 'is', 'a', 'framework', 'for', 'deep', 'learning'],
     ['tensorflow', 'is', 'very', 'fast', 'changing']]
targets = [['나는', '배가', '고프다'],
           ['텐서플로우는', '매우', '어렵다'],
           ['텐서플로우는', '딥러닝을', '위한', '프레임워크이다'],
           ['텐서플로우는', '매우', '빠르게', '변화한다']]

In [53]:
df.movie_summary = df.movie_summary.apply(lambda x: x.lower())
df.movie_summary = df.movie_summary.apply(lambda x: x.replace('.',''))
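The two `apply` calls above lowercase each summary and delete periods only; commas, question marks and apostrophes are left in place. The plain-Python sketch below shows the same normalization, plus an optional stricter mode that strips all ASCII punctuation with `str.translate`. The helper `normalize` and the stricter mode are assumptions for illustration, not something the notebook does.

```python
import string

# Table that deletes every ASCII punctuation character.
# Assumption: this is stricter than the notebook, which only removes '.'.
_PUNCT_TABLE = str.maketrans('', '', string.punctuation)


def normalize(text, strip_all_punct=True):
    """Lowercase `text` and remove punctuation (hypothetical helper)."""
    text = text.lower()
    if strip_all_punct:
        return text.translate(_PUNCT_TABLE)
    # Same behaviour as the two apply() calls: drop periods only.
    return text.replace('.', '')


print(normalize("Ever felt like you were surrounded by zombies?"))
print(normalize("Evil. Unmasked.", strip_all_punct=False))
```

With pandas, the same function can be applied column-wise, e.g. `df.movie_summary.apply(normalize)`.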

In [54]:
df.movie_summary

0      four teen girls diving in a ruined underwater ...
1      twentyseven years after their first encounter ...
10     a fierce rebel wrestler kichha fights brutal o...
100    in the year 2019 a plague has transformed almo...
101    a mans uneventful life is disrupted by the zom...
102    the story of a forbidden and secretive relatio...
103    the true story of how the boston globe uncover...
104    five friends travel to a cabin in the woods wh...
105    the enterprise is diverted to the romulan home...
106    a group of bostonbred gangsters set up shop in...
107    an indepth examination of the ways in which th...
108    a young film student in the early 80s becomes ...
109    in may 1940 the fate of western europe hangs o...
11     a cryptic message from 007s past sends him pit...
110    competition between the maid of honor and a br...
111    after being committed for 17 years michael mye...
112    after he becomes a quadriplegic from a paragli...
113    a mysterious stranger wi

In [55]:
df.movie_tagline

0                                       The Next Chapter
1                                     You'll Float Again
10                                                      
100    The battle between immortality and humanity is...
101       Ever felt like you were surrounded by zombies?
102                            Love Is A Force Of Nature
103    The true story behind the scandal that shook t...
104                                 Can They Be Stopped?
105    For every good in the universe, there is an evil.
106                              Joe was once a good man
107    One of the most important and powerful films o...
108                       The Past Never Stays in Focus.
109                        Never give up. Never give in.
11                                    The dead are alive
110                                                     
111                                      Evil. Unmasked.
112    One man brought a family together...and change...
113    There were three men in 

In [56]:

sources = [w.split() for w in df.movie_summary]
targets = [w.split() for w in df.movie_tagline]

In [18]:
sources

[['four',
  'teen',
  'girls',
  'diving',
  'in',
  'a',
  'ruined',
  'underwater',
  'city',
  'quickly',
  'learn',
  'theyve',
  'entered',
  'the',
  'territory',
  'of',
  'the',
  'deadliest',
  'shark',
  'species',
  'in',
  'the',
  'claustrophobic',
  'labyrinth',
  'of',
  'submerged',
  'caves',
  '47',
  'meters',
  'down',
  'uncaged',
  'follows',
  'the',
  'diving',
  'adventure',
  'of',
  'four',
  'teenage',
  'girls',
  'sophie',
  'nlisse',
  'corinne',
  'foxx',
  'brianne',
  'tju',
  'and',
  'sistine',
  'stallone',
  'exploring',
  'a',
  'submerged',
  'mayan',
  'city',
  'once',
  'inside',
  'their',
  'rush',
  'of',
  'excitement',
  'turns',
  'into',
  'a',
  'jolt',
  'of',
  'terror',
  'as',
  'they',
  'discover',
  'the',
  'sunken',
  'ruins',
  'are',
  'a',
  'hunting',
  'ground',
  'for',
  'deadly',
  'great',
  'white',
  'sharks',
  'with',
  'their',
  'air',
  'supply',
  'steadily',
  'dwindling',
  'the',
  'friends',
  'must',
  'n

In [57]:
# vocabulary for sources
s_vocab = list(set(sum(sources, [])))
s_vocab.sort()
s_vocab = ['<pad>'] + s_vocab
source2idx = {word : idx for idx, word in enumerate(s_vocab)}
idx2source = {idx : word for idx, word in enumerate(s_vocab)}

pprint(source2idx)

{'00': 1,
 '007s': 2,
 '06': 3,
 '1': 4,
 '10': 5,
 '1000': 6,
 '10000': 7,
 '10191': 8,
 '10yearold': 9,
 '12': 10,
 '12yearold': 11,
 '14': 12,
 '14yearold': 13,
 '16': 14,
 '1600': 15,
 '1630s': 16,
 '1633': 17,
 '17': 18,
 '17th': 19,
 '18': 20,
 '1874': 21,
 '19': 22,
 '1900s': 23,
 '1908': 24,
 '1914': 25,
 '1920s': 26,
 '1926': 27,
 '1930s': 28,
 '1931': 29,
 '1932': 30,
 '1937': 31,
 '1940': 32,
 '1949': 33,
 '1950s': 34,
 '1956': 35,
 '1960s': 36,
 '1963': 37,
 '1970': 38,
 '1970s': 39,
 '1976': 40,
 '1977': 41,
 '1980': 42,
 '1980s': 43,
 '1985': 44,
 '1986': 45,
 '1987': 46,
 '1990s': 47,
 '1998': 48,
 '19th': 49,
 '19thcentury': 50,
 '19yearold': 51,
 '2': 52,
 '20': 53,
 '2000': 54,
 '2001': 55,
 '2002': 56,
 '2005': 57,
 '2006': 58,
 '2009': 59,
 '2012': 60,
 '2019': 61,
 '20s': 62,
 '20th': 63,
 '24': 64,
 '247': 65,
 '25': 66,
 '27': 67,
 '29': 68,
 '29yearold': 69,
 '29yearsold': 70,
 '3': 71,
 '30': 72,
 '300': 73,
 '31st': 74,
 '3pm': 75,
 '40ish': 76,
 '43': 77,
 '4

 'caucasian': 942,
 'caught': 943,
 'cauliflower': 944,
 'cause': 945,
 'caused': 946,
 'causes': 947,
 'causing': 948,
 'caution': 949,
 'cavalry': 950,
 'caves': 951,
 'celebrate': 952,
 'celebrating': 953,
 'celebration': 954,
 'celebratory': 955,
 'celebrities': 956,
 'cell': 957,
 'cellar': 958,
 'cellmate': 959,
 'center': 960,
 'centers': 961,
 'central': 962,
 'centre': 963,
 'centuries': 964,
 'century': 965,
 'ceo': 966,
 'certain': 967,
 'chain': 968,
 'challenge': 969,
 'challenged': 970,
 'challenges': 971,
 'chamber': 972,
 'chamberlain': 973,
 'champ': 974,
 'champion': 975,
 'chance': 976,
 'change': 977,
 'changed': 978,
 'changes': 979,
 'changing': 980,
 'channel': 981,
 'channing': 982,
 'character': 983,
 'characters': 984,
 'charge': 985,
 'charged': 986,
 'charismatic': 987,
 'charles': 988,
 'charlie': 989,
 'charlies': 990,
 'charlotte': 991,
 'charm': 992,
 'charming': 993,
 'charts': 994,
 'chase': 995,
 'chased': 996,
 'chasing': 997,
 'chastised': 998,
 'ch

 'die': 1603,
 'died': 1604,
 'dies': 1605,
 'differ': 1606,
 'different': 1607,
 'differently': 1608,
 'difficult': 1609,
 'dig': 1610,
 'dimension': 1611,
 'diner': 1612,
 'dinners': 1613,
 'diocese': 1614,
 'dion': 1615,
 'dire': 1616,
 'direct': 1617,
 'directed': 1618,
 'direction': 1619,
 'directive': 1620,
 'directly': 1621,
 'director': 1622,
 'directors': 1623,
 'disability': 1624,
 'disagree': 1625,
 'disagreement': 1626,
 'disagreements': 1627,
 'disappearing': 1628,
 'disappointing': 1629,
 'disappoints': 1630,
 'disaster': 1631,
 'disasters': 1632,
 'disastrous': 1633,
 'disbelief': 1634,
 'discharged': 1635,
 'disciplined': 1636,
 'discomposure': 1637,
 'discover': 1638,
 'discovered': 1639,
 'discoveries': 1640,
 'discovers': 1641,
 'discovery': 1642,
 'discretion': 1643,
 'discuss': 1644,
 'discussion': 1645,
 'disdain': 1646,
 'disease': 1647,
 'disfigured': 1648,
 'disguised': 1649,
 'dishonesty': 1650,
 'disillusioned': 1651,
 'disjointed': 1652,
 'dislike': 1653,
 '

 'forrest': 2256,
 'forrests': 2257,
 'forth': 2258,
 'forties': 2259,
 'fortune': 2260,
 'forward': 2261,
 'foster': 2262,
 'fought': 2263,
 'found': 2264,
 'foundations': 2265,
 'founded': 2266,
 'founder': 2267,
 'four': 2268,
 'fourteen': 2269,
 'foxx': 2270,
 'fr': 2271,
 'fractured': 2272,
 'frame': 2273,
 'franchise': 2274,
 'francis': 2275,
 'franciscan': 2276,
 'francisco': 2277,
 'franciscobased': 2278,
 'franco': 2279,
 'francos': 2280,
 'frank': 2281,
 'frankie': 2282,
 'franklin': 2283,
 'franks': 2284,
 'frannie': 2285,
 'frantic': 2286,
 'fraud': 2287,
 'fraught': 2288,
 'freak': 2289,
 'freaks': 2290,
 'freaky': 2291,
 'fred': 2292,
 'freddie': 2293,
 'freddy': 2294,
 'free': 2295,
 'freedom': 2296,
 'freespirited': 2297,
 'freezing': 2298,
 'freight': 2299,
 'freman': 2300,
 'fremen': 2301,
 'french': 2302,
 'frequently': 2303,
 'fresh': 2304,
 'freshoutofiraq': 2305,
 'friar': 2306,
 'fridge': 2307,
 'friend': 2308,
 'friendly': 2309,
 'friends': 2310,
 'friendship': 

 'hugh': 2755,
 'huisman': 2756,
 'hulking': 2757,
 'human': 2758,
 'humanity': 2759,
 'humanitys': 2760,
 'humankind': 2761,
 'humans': 2762,
 'humiliating': 2763,
 'humiliation': 2764,
 'hundred': 2765,
 'hungarian': 2766,
 'hungry': 2767,
 'hunk': 2768,
 'hunt': 2769,
 'hunting': 2770,
 'hunts': 2771,
 'husband': 2772,
 'husbands': 2773,
 'hushed': 2774,
 'hustler': 2775,
 'hustlercrook': 2776,
 'hustling': 2777,
 'i': 2778,
 'ibabe': 2779,
 'icy': 2780,
 'idea': 2781,
 'idealistic': 2782,
 'ideals': 2783,
 'ideas': 2784,
 'identification': 2785,
 'identify': 2786,
 'identities': 2787,
 'identity': 2788,
 'idiots': 2789,
 'ids': 2790,
 'idyllic': 2791,
 'if': 2792,
 'ignored': 2793,
 'ignores': 2794,
 'ii': 2795,
 'iii': 2796,
 'ill': 2797,
 'illegal': 2798,
 'illicit': 2799,
 'illinois': 2800,
 'ilsa': 2801,
 'imaginable': 2802,
 'imagination': 2803,
 'immediately': 2804,
 'imminent': 2805,
 'immortal': 2806,
 'immune': 2807,
 'impacts': 2808,
 'impatient': 2809,
 'impede': 2810,
 

 'nathan': 3714,
 'nation': 3715,
 'national': 3716,
 'natives': 3717,
 'natural': 3718,
 'naturally': 3719,
 'nature': 3720,
 'naughton': 3721,
 'navigate': 3722,
 'navigator': 3723,
 'navigators': 3724,
 'nazi': 3725,
 'nazioccupied': 3726,
 'nazis': 3727,
 'ncaa': 3728,
 'near': 3729,
 'nearby': 3730,
 'nearly': 3731,
 'necessary': 3732,
 'neck': 3733,
 'necronomicon': 3734,
 'ned': 3735,
 'need': 3736,
 'needed': 3737,
 'needing': 3738,
 'needs': 3739,
 'neeson': 3740,
 'negative': 3741,
 'neglected': 3742,
 'neglectful': 3743,
 'neglects': 3744,
 'negotiate': 3745,
 'negotiated': 3746,
 'negotiations': 3747,
 'neighbor': 3748,
 'neighborhood': 3749,
 'neighbors': 3750,
 'neighbour': 3751,
 'nemesis': 3752,
 'nerd': 3753,
 'nervousness': 3754,
 'ness': 3755,
 'net': 3756,
 'netherlands': 3757,
 'network': 3758,
 'neutral': 3759,
 'never': 3760,
 'nevertheless': 3761,
 'neveu': 3762,
 'neville': 3763,
 'nevilles': 3764,
 'new': 3765,
 'newcomer': 3766,
 'newer': 3767,
 'newly': 3768

 'realizes': 4505,
 'reallife': 4506,
 'really': 4507,
 'realm': 4508,
 'reap': 4509,
 'reaping': 4510,
 'reason': 4511,
 'reasons': 4512,
 'rebel': 4513,
 'rebellious': 4514,
 'rebound': 4515,
 'rebuke': 4516,
 'recall': 4517,
 'recalls': 4518,
 'receive': 4519,
 'receives': 4520,
 'receiving': 4521,
 'recent': 4522,
 'recently': 4523,
 'recentlyorphaned': 4524,
 'reception': 4525,
 'reckless': 4526,
 'reckon': 4527,
 'reclusive': 4528,
 'recognize': 4529,
 'recognized': 4530,
 'recognizes': 4531,
 'recollecting': 4532,
 'recommending': 4533,
 'recommends': 4534,
 'reconcile': 4535,
 'reconnect': 4536,
 'reconnects': 4537,
 'record': 4538,
 'recorded': 4539,
 'recorder': 4540,
 'recordings': 4541,
 'recounts': 4542,
 'recover': 4543,
 'recovered': 4544,
 'recovering': 4545,
 'recriminations': 4546,
 'recruit': 4547,
 'recruited': 4548,
 'recruiting': 4549,
 'recruits': 4550,
 'redefine': 4551,
 'redemption': 4552,
 'redhaired': 4553,
 'rediscovered': 4554,
 'rediscovery': 4555,
 'reed

 'sprawling': 5255,
 'spread': 5256,
 'spreading': 5257,
 'spring': 5258,
 'springtime': 5259,
 'spur': 5260,
 'spurns': 5261,
 'spurofthemoment': 5262,
 'spy': 5263,
 'st': 5264,
 'stabbed': 5265,
 'staff': 5266,
 'stage': 5267,
 'stagefiveclinger': 5268,
 'stages': 5269,
 'stairs': 5270,
 'stalked': 5271,
 'stalking': 5272,
 'stalks': 5273,
 'stallone': 5274,
 'stalwart': 5275,
 'stand': 5276,
 'standing': 5277,
 'stands': 5278,
 'stangle': 5279,
 'stanley': 5280,
 'star': 5281,
 'starring': 5282,
 'stars': 5283,
 'start': 5284,
 'started': 5285,
 'starting': 5286,
 'starts': 5287,
 'startup': 5288,
 'state': 5289,
 'stated': 5290,
 'stately': 5291,
 'statement': 5292,
 'states': 5293,
 'stating': 5294,
 'station': 5295,
 'status': 5296,
 'staunch': 5297,
 'stay': 5298,
 'stayed': 5299,
 'staying': 5300,
 'stays': 5301,
 'steadfast': 5302,
 'steadily': 5303,
 'steal': 5304,
 'steel': 5305,
 'steff': 5306,
 'step': 5307,
 'stepbystep': 5308,
 'stepdad': 5309,
 'stepdadremains': 5310,


 'whom': 6162,
 'whos': 6163,
 'whose': 6164,
 'why': 6165,
 'wicked': 6166,
 'wider': 6167,
 'widespread': 6168,
 'widow': 6169,
 'widower': 6170,
 'wielding': 6171,
 'wife': 6172,
 'wiig': 6173,
 'wild': 6174,
 'wilderness': 6175,
 'wildest': 6176,
 'wildeyed': 6177,
 'wilds': 6178,
 'will': 6179,
 'willem': 6180,
 'willful': 6181,
 'william': 6182,
 'williams': 6183,
 'willing': 6184,
 'willingly': 6185,
 'willis': 6186,
 'wilson': 6187,
 'wilton': 6188,
 'wily': 6189,
 'win': 6190,
 'winchester': 6191,
 'wind': 6192,
 'window': 6193,
 'winds': 6194,
 'windy': 6195,
 'wingate': 6196,
 'winner': 6197,
 'winnertakesall': 6198,
 'winning': 6199,
 'wins': 6200,
 'winslet': 6201,
 'winston': 6202,
 'winthorpe': 6203,
 'winthorpes': 6204,
 'wisdom': 6205,
 'wiseau': 6206,
 'wish': 6207,
 'wishes': 6208,
 'wit': 6209,
 'with': 6210,
 'within': 6211,
 'without': 6212,
 'withstand': 6213,
 'witness': 6214,
 'witnesses': 6215,
 'witnessing': 6216,
 'wits': 6217,
 'woeful': 6218,
 'woke': 6219
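Both vocabularies are built the same way: collect the unique tokens, sort them, and prepend reserved tokens so that `<pad>` gets index 0 (and, for targets, `<bos>`/`<eos>` get 1 and 2). Note that `pprint` sorts dictionary keys when printing, which is why `'<bos>': 1` appears after `'13th,': 5` in the target output below even though its index is 1. `sum(sources, [])` concatenates lists repeatedly and grows quadratically with corpus size; the sketch below flattens with `itertools.chain` instead. `build_vocab` is a hypothetical helper, and removing special tokens from the word set (to avoid duplicates) is an assumption the notebook does not make.

```python
from itertools import chain


def build_vocab(token_lists, specials=('<pad>',)):
    """Return (word2idx, idx2word) with special tokens at the lowest indices."""
    # chain.from_iterable flattens in linear time, unlike sum(lists, []).
    words = sorted(set(chain.from_iterable(token_lists)) - set(specials))
    itos = list(specials) + words
    stoi = {w: i for i, w in enumerate(itos)}
    return stoi, dict(enumerate(itos))


src2idx, idx2src = build_vocab([['tensorflow', 'is', 'fast'],
                                ['tensorflow', 'is', 'hard']])
tgt2idx, idx2tgt = build_vocab([['A', 'b']], specials=('<pad>', '<bos>', '<eos>'))
print(src2idx)
print(tgt2idx)
```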

In [58]:
# vocabulary for targets
t_vocab = list(set(sum(targets, [])))
t_vocab.sort()
t_vocab = ['<pad>', '<bos>', '<eos>'] + t_vocab
target2idx = {word : idx for idx, word in enumerate(t_vocab)}
idx2target = {idx : word for idx, word in enumerate(t_vocab)}

pprint(target2idx)

{'(He': 3,
 '12': 4,
 '13th,': 5,
 '<bos>': 1,
 '<eos>': 2,
 '<pad>': 0,
 'A': 6,
 'After': 7,
 'Again': 8,
 'All': 9,
 'Am': 10,
 'American': 11,
 'And': 12,
 'At': 13,
 'Australian': 14,
 'Bad': 15,
 'Based': 16,
 'Be': 17,
 'Because': 18,
 'Before': 19,
 'Belief': 20,
 'Beware': 21,
 'Beyond': 22,
 'Birthday.': 23,
 'Brains.': 24,
 'Brakes.': 25,
 'Brothers.': 26,
 'Buddy': 27,
 'Can': 28,
 'Cannot': 29,
 'Chapter': 30,
 'Charlie': 31,
 'Check': 32,
 'Classic': 33,
 'Col.': 34,
 'Con': 35,
 'Converted': 36,
 'Country': 37,
 'Crash': 38,
 'Cruise.': 39,
 'Cunning': 40,
 'Currency': 41,
 'Cuts': 42,
 'Dagger!': 43,
 'Dare': 44,
 'Dark': 45,
 'Deadly': 46,
 'Deep': 47,
 'Dinner': 48,
 'Director': 49,
 "Don't": 50,
 'Dream': 51,
 'Eight.': 52,
 'Er': 53,
 'Erotica': 54,
 'Ever': 55,
 'Ever!': 56,
 'Everything': 57,
 'Evil': 58,
 'Evil.': 59,
 'Experience': 60,
 'Fall': 61,
 'Fear': 62,
 'Feind...': 63,
 'Fight': 64,
 'Float': 65,
 'Focus.': 66,
 'For': 67,
 'Force': 68,
 'Four.': 69,
 '

In [59]:
def preprocess(sequences, max_len, dic, mode = 'source'):
    assert mode in ['source', 'target'], "Please choose either 'source' or 'target'."
    
    if mode == 'source':
        # preprocessing for source (encoder)
        s_input = list(map(lambda sentence : [dic.get(token) for token in sentence], sequences))
        s_len = list(map(lambda sentence : len(sentence), s_input))
        s_input = pad_sequences(sequences = s_input, maxlen = max_len, padding = 'post', truncating = 'post')
        return s_len, s_input
    
    elif mode == 'target':
        # preprocessing for target (decoder)
        # input
        t_input = list(map(lambda sentence : ['<bos>'] + sentence + ['<eos>'], sequences))
        t_input = list(map(lambda sentence : [dic.get(token) for token in sentence], t_input))
        t_len = list(map(lambda sentence : len(sentence), t_input))
        t_input = pad_sequences(sequences = t_input, maxlen = max_len, padding = 'post', truncating = 'post')
        
        # output
        t_output = list(map(lambda sentence : sentence + ['<eos>'], sequences))
        t_output = list(map(lambda sentence : [dic.get(token) for token in sentence], t_output))
        t_output = pad_sequences(sequences = t_output, maxlen = max_len, padding = 'post', truncating = 'post')
        
        return t_len, t_input, t_output
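For targets, `preprocess` builds two aligned sequences: the decoder input `[<bos>] + tagline + [<eos>]` and the decoder output `tagline + [<eos>]`, i.e. the output is the input shifted left by one step. The recorded lengths include the special tokens, which is why the first tagline, "The Next Chapter", gets `t_len` 5 and the empty tagline of row 10 becomes `[1 2 0 ...]`. The lengths are also taken before truncation (e.g. a source length of 109 while `s_input` keeps only 10 ids). The pure-Python sketch below mimics `pad_sequences(padding='post', truncating='post')` on a toy vocabulary; `pad_post` and the vocabulary are hypothetical, and the real `pad_sequences` returns a NumPy array rather than lists.

```python
def pad_post(seqs, max_len, value=0):
    """Truncate, then right-pad each sequence to max_len (lists, not NumPy)."""
    out = []
    for s in seqs:
        s = list(s)[:max_len]                         # truncating='post'
        out.append(s + [value] * (max_len - len(s)))  # padding='post'
    return out


vocab = {'<pad>': 0, '<bos>': 1, '<eos>': 2, 'Evil.': 3, 'Unmasked.': 4}
tagline = ['Evil.', 'Unmasked.']

dec_in = [vocab[t] for t in ['<bos>'] + tagline + ['<eos>']]
dec_out = [vocab[t] for t in tagline + ['<eos>']]
print(pad_post([dec_in], 6))   # decoder input
print(pad_post([dec_out], 6))  # decoder output

# The decoder output is the decoder input shifted left by one step.
assert dec_out == dec_in[1:]
```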

In [60]:
# preprocessing for source
s_max_len = 10
s_len, s_input = preprocess(sequences = sources,
                            max_len = s_max_len, dic = source2idx, mode = 'source')
print(s_len, s_input)

[109, 102, 41, 185, 458, 456, 933, 441, 279, 641, 257, 17, 432, 738, 660, 466, 182, 426, 452, 165, 290, 256, 366, 271, 593, 490, 447, 106, 555, 112, 528, 43, 473, 258, 544, 142, 641, 255, 26, 23, 25, 25, 650, 534, 370, 266, 546, 49, 269, 560, 701, 76, 106, 19, 173, 243, 499, 239, 518, 772, 173, 143, 478, 1020, 272, 405, 644, 487, 85, 387, 19, 170, 480, 27, 420, 200, 246, 15, 145, 598, 242, 784, 339, 33, 366, 143, 240, 387, 417, 812, 355, 152, 140, 490, 698, 265, 332, 78, 366, 470] [[2268 5531 2399 1680 2826   90 4803 5865 1047 4424]
 [5807 6278  208 5586 2185 1855 6210 5584 5564 4082]
 [  90 2146 4513 6261 3106 2153  809 3924  315  583]
 [2826 5584 6274   61   90 4159 2559 5719  273 1953]
 [  90 3431 5870 3255 2981 1661  844 5584 6302  360]
 [5584 5335 3867   90 2235  315 4924 4586  641 5812]
 [5584 5778 5335 3867 2749 5584  731 2406 5849 5584]
 [2194 2310 5733 5657   90  848 2826 5584 6238 6148]
 [5584 1898 2981 1678 5657 5584 4781 2703 4783 5432]
 [  90 2479 3867  732 2348 4995 5932 

In [61]:
# preprocessing for target
t_max_len = 12
t_len, t_input, t_output = preprocess(sequences = targets,
                                      max_len = t_max_len, dic = target2idx, mode = 'target')
print(t_len, t_input, t_output)

[5, 5, 2, 10, 10, 8, 12, 6, 12, 8, 13, 8, 8, 6, 2, 4, 12, 22, 4, 2, 9, 7, 15, 6, 12, 4, 7, 10, 2, 5, 23, 2, 10, 8, 10, 2, 8, 8, 5, 7, 9, 2, 6, 10, 10, 7, 12, 12, 8, 7, 7, 2, 2, 5, 7, 6, 27, 9, 12, 5, 23, 7, 8, 5, 6, 5, 12, 2, 2, 6, 6, 5, 12, 6, 10, 14, 19, 2, 2, 5, 4, 12, 8, 7, 2, 6, 16, 6, 15, 8, 17, 2, 8, 8, 9, 6, 5, 9, 37, 11] [[  1 142 111 ...   0   0   0]
 [  1 159  65 ...   0   0   0]
 [  1   2   0 ...   0   0   0]
 ...
 [  1   9 433 ...   0   0   0]
 [  1  34  70 ... 250 387 422]
 [  1 140  13 ... 103   2   0]] [[142 111  30 ...   0   0   0]
 [159  65   8 ...   0   0   0]
 [  2   0   0 ...   0   0   0]
 ...
 [  9 433 269 ...   0   0   0]
 [ 34  70 128 ... 387 422  86]
 [140  13  69 ...   2   0   0]]


# hyper-param

In [62]:
# hyper-parameters
epochs = 100
batch_size = 4
learning_rate = .005
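buffer_size = 100  # >= number of examples, so shuffle() draws from the whole dataset
n_batch = len(sources) // batch_size  # mini-batches per epoch
total_step = epochs * n_batch  # total number of optimizer updates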
embedding_dim = 32
units = 128

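# input pipeline: shuffle, then batch. drop_remainder keeps every batch at exactly
# batch_size, which the encoder's zero state and the decoder's <bos> input rely on
data = tf.data.Dataset.from_tensor_slices((s_len, s_input, t_len, t_input, t_output))
data = data.shuffle(buffer_size = buffer_size)
data = data.batch(batch_size = batch_size, drop_remainder = True)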
# s_mb_len, s_mb_input, t_mb_len, t_mb_input, t_mb_output = iterator.get_next()
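Because `buffer_size` (100) matches the 100 rows loaded from the JSON file, `shuffle()` samples from the whole dataset, and by default it reshuffles on every pass over `data`. With 100 examples and `batch_size = 4` there are exactly 25 full batches per epoch. If the example count were not divisible by the batch size, the last batch would be smaller, and both `encoder.initialize_hidden_state()` and the `<bos>` input (built with `batch_size` copies) would then have the wrong shape. The hypothetical helper below sketches shuffle-then-batch on plain indices to show the effect of dropping the remainder.

```python
import math
import random


def make_batches(n_examples, batch_size, drop_remainder=False, seed=0):
    """Shuffle example indices once, then split them into consecutive batches."""
    idx = list(range(n_examples))
    random.Random(seed).shuffle(idx)
    batches = [idx[i:i + batch_size] for i in range(0, n_examples, batch_size)]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the short final batch
    return batches


print(len(make_batches(100, 4)))                       # 100 % 4 == 0
print(len(make_batches(101, 4)))                       # one extra, short batch
print(len(make_batches(101, 4, drop_remainder=True)))  # short batch dropped
print(math.ceil(101 / 4), 101 // 4)
```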

In [63]:
def gru(units):
    # If a GPU is available, use CuDNNGRU, which the reference notebook reports as
    # about 3x faster than GRU; the check below selects it automatically.
    if tf.test.is_gpu_available():
        return tf.keras.layers.CuDNNGRU(units, 
                                        return_sequences=True, 
                                        return_state=True, 
                                        recurrent_initializer='glorot_uniform')
    else:
        return tf.keras.layers.GRU(units, 
                                   return_sequences=True, 
                                   return_state=True, 
                                   recurrent_activation='sigmoid', 
                                   recurrent_initializer='glorot_uniform')

In [64]:
class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = gru(self.enc_units)
        
    def call(self, x, hidden):
        x = self.embedding(x)
        output, state = self.gru(x, initial_state = hidden)
#         print("state: {}".format(state.shape))
#         print("output: {}".format(output.shape))
              
        return output, state
    
    def initialize_hidden_state(self):
        return tf.zeros((self.batch_sz, self.enc_units))

In [65]:
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = gru(self.dec_units)
        self.fc = tf.keras.layers.Dense(vocab_size)
        
        # used for attention
        self.W1 = tf.keras.layers.Dense(self.dec_units)
        self.W2 = tf.keras.layers.Dense(self.dec_units)
        self.V = tf.keras.layers.Dense(1)
        
    def call(self, x, hidden, enc_output):
        # enc_output shape == (batch_size, max_length, hidden_size)
        
        # hidden shape == (batch_size, hidden size)
        # hidden_with_time_axis shape == (batch_size, 1, hidden size)
        # we are doing this to perform addition to calculate the score
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        # * `score = FC(tanh(FC(EO) + FC(H)))`
        # score shape == (batch_size, max_length, 1)
        # we get 1 at the last axis because we are applying tanh(FC(EO) + FC(H)) to self.V
        score = self.V(tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis)))
                
        #* `attention weights = softmax(score, axis = 1)`. Softmax by default is applied on the last axis but here we want to apply it on the *1st axis*, since the shape of score is *(batch_size, max_length, 1)*. `Max_length` is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.
        # attention_weights shape == (batch_size, max_length, 1)
        attention_weights = tf.nn.softmax(score, axis=1)
        
        # context_vector shape after sum == (batch_size, hidden_size)
        # * `context vector = sum(attention weights * EO, axis = 1)`. Same reason as above for choosing axis as 1.
        context_vector = attention_weights * enc_output
        context_vector = tf.reduce_sum(context_vector, axis=1)
        
        # x shape after passing through embedding == (batch_size, 1, embedding_dim)
        # * `embedding output` = The input to the decoder X is passed through an embedding layer.
        x = self.embedding(x)
        
        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
        # * `merged vector = concat(embedding output, context vector)`
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)
        
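        # passing the concatenated vector to the GRU
        # note: `hidden` is used only for the attention score above; it is not passed as
        # the GRU's initial_state, so each one-step call starts from a zero state
        output, state = self.gru(x)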
        
        # output shape == (batch_size * 1, hidden_size)
        output = tf.reshape(output, (-1, output.shape[2]))
        
        # output shape == (batch_size * 1, vocab)
        x = self.fc(output)
        
        return x, state, attention_weights
        
    def initialize_hidden_state(self):
        return tf.zeros((self.batch_sz, self.dec_units))
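The attention in `Decoder.call` is an additive score, `score = V(tanh(W1(EO) + W2(H)))`, followed by a softmax over source positions (the `axis=1` in the notebook) and a weighted sum of encoder outputs. The dependency-free sketch below reproduces that computation for a single example so the shapes are easy to follow: `enc_output` is a list of per-position vectors (EO), `hidden` is the decoder state (H), and `W1`, `W2`, `v` stand in for `self.W1`, `self.W2`, `self.V`. As an assumption for brevity, the bias terms that Keras `Dense` layers add by default are omitted.

```python
import math


def matvec(W, x):
    """Dense layer without bias: W is a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def additive_attention(enc_output, hidden, W1, W2, v):
    """score_t = v . tanh(W1 h_t + W2 s); weights = softmax over t; context = sum_t w_t h_t."""
    q = matvec(W2, hidden)                     # decoder state, projected once
    scores = []
    for h_t in enc_output:                     # one score per source position
        k = matvec(W1, h_t)
        scores.append(sum(vi * math.tanh(ki + qi) for vi, ki, qi in zip(v, k, q)))
    m = max(scores)                            # shift for a numerically stable softmax
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(enc_output[0])
    context = [sum(w * h_t[j] for w, h_t in zip(weights, enc_output)) for j in range(dim)]
    return context, weights


eye = [[1.0, 0.0], [0.0, 1.0]]
enc = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]    # 3 source positions, hidden size 2
ctx, w = additive_attention(enc, [0.0, 0.0], eye, eye, [1.0, 1.0])
print([round(x, 3) for x in w], [round(x, 3) for x in ctx])
```

The two non-zero source vectors receive equal, larger weights than the all-zero one, and the context vector is their weighted mix.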

In [66]:
encoder = Encoder(len(source2idx), embedding_dim, units, batch_size)
decoder = Decoder(len(target2idx), embedding_dim, units, batch_size)

def loss_function(real, pred):
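    # zero out the loss at <pad> positions (id 0) so padding does not count
    mask = tf.cast(tf.not_equal(real, 0), dtype = pred.dtype)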
    loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask
    
#     print("real: {}".format(real))
#     print("pred: {}".format(pred))
#     print("mask: {}".format(mask))
#     print("loss: {}".format(tf.reduce_mean(loss_)))
    
    return tf.reduce_mean(loss_)

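# creating optimizer
# note: learning_rate from the hyper-parameter cell is not passed here, so Adam uses
# its default of 0.001; use tf.train.AdamOptimizer(learning_rate) to apply it
optimizer = tf.train.AdamOptimizer()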

# creating check point (Object-based saving)
checkpoint_dir = './data_out/training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                encoder=encoder,
                                decoder=decoder)

# create writer for tensorboard
summary_writer = tf.contrib.summary.create_file_writer(logdir=checkpoint_dir)
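`loss_function` computes sparse softmax cross-entropy per example, multiplies it by a 0/1 mask that is 0 wherever the label is `<pad>`, and then takes `tf.reduce_mean` over the whole batch. Because padded rows still count in that mean, later decoding steps (where most labels are padding) contribute smaller values. The pure-Python sketch below reproduces the notebook's averaging and, as an assumption not used in the notebook, an alternative that divides by the number of real tokens instead.

```python
import math


def masked_sparse_xent(labels, logits, pad_id=0, per_token=False):
    """Cross-entropy of integer labels against logits, with pad positions zeroed."""
    losses = []
    for y, row in zip(labels, logits):
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # log-sum-exp
        nll = log_z - row[y]                                      # -log softmax(row)[y]
        losses.append(0.0 if y == pad_id else nll)
    if per_token:
        # Alternative (assumption): average over non-pad labels only.
        denom = max(1, sum(1 for y in labels if y != pad_id))
    else:
        # Like tf.reduce_mean in loss_function: padded rows stay in the denominator.
        denom = len(labels)
    return sum(losses) / denom


uniform = [0.0, 0.0, 0.0, 0.0]
print(masked_sparse_xent([3, 0], [uniform, [5.0, 0.0, 0.0, 0.0]]))
print(masked_sparse_xent([3, 0], [uniform, [5.0, 0.0, 0.0, 0.0]], per_token=True))
```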

In [67]:
for epoch in range(epochs):
    
    hidden = encoder.initialize_hidden_state()
    total_loss = 0
    
    # mini-batch tensors get their own names so the full arrays (s_input, t_input, ...) are not overwritten
    for i, (s_mb_len, s_mb_input, t_mb_len, t_mb_input, t_mb_output) in enumerate(data):
        loss = 0
        with tf.GradientTape() as tape:
            enc_output, enc_hidden = encoder(s_mb_input, hidden)
            
            dec_hidden = enc_hidden
            
            dec_input = tf.expand_dims([target2idx['<bos>']] * batch_size, 1)
            
            # Teacher forcing: feed the ground-truth token t-1 and predict token t
            for t in range(1, t_mb_input.shape[1]):
                predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
                
                loss += loss_function(t_mb_input[:, t], predictions)
            
                dec_input = tf.expand_dims(t_mb_input[:, t], 1) # using teacher forcing
                
        batch_loss = (loss / int(t_mb_input.shape[1]))
        
        total_loss += batch_loss
        
        variables = encoder.trainable_variables + decoder.trainable_variables
        
        gradient = tape.gradient(loss, variables)
        
        optimizer.apply_gradients(zip(gradient, variables))
        
    if epoch % 10 == 0:
        # log and save a checkpoint every 10 epochs
        print('Epoch {} Loss {:.4f} Batch Loss {:.4f}'.format(epoch,
                                            total_loss / n_batch,
                                            batch_loss.numpy()))
        checkpoint.save(file_prefix = checkpoint_prefix)

Epoch 0 Loss 3.2241 Batch Loss 1.5403
Epoch 10 Loss 2.4515 Batch Loss 1.0712
Epoch 20 Loss 1.9254 Batch Loss 0.8458
Epoch 30 Loss 1.3523 Batch Loss 0.6368
Epoch 40 Loss 0.8738 Batch Loss 0.4578
Epoch 50 Loss 0.5727 Batch Loss 0.3052
Epoch 60 Loss 0.4011 Batch Loss 0.2025
Epoch 70 Loss 0.2968 Batch Loss 0.1657
Epoch 80 Loss 0.2386 Batch Loss 0.1220
Epoch 90 Loss 0.1573 Batch Loss 0.0990
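In the training loop, the decoder first receives `<bos>` (which is also `t_input[:, 0]`), and at step `t` it is scored against `t_input[:, t]` before that same ground-truth token becomes the next input. Since `t_input[:, t]` equals `t_output[:, t - 1]` for every `t` the loop visits, the loop never needs `t_output`. The loss sums `t_input.shape[1] - 1` terms (11 here) while `batch_loss` divides by `t_input.shape[1]` (12). The hypothetical helper below lists the (input, label) pairs the loop produces for one padded row; pairs whose label is `<pad>` (0) are the ones masked out by `loss_function`.

```python
def teacher_forcing_pairs(target_row):
    """(decoder input, label) pairs visited by the training loop for one t_input row."""
    return [(target_row[t - 1], target_row[t]) for t in range(1, len(target_row))]


row = [1, 3, 4, 2, 0, 0]   # <bos> Evil. Unmasked. <eos> <pad> <pad>
pairs = teacher_forcing_pairs(row)
print(pairs)
print(sum(1 for _, label in pairs if label == 0), "pairs are masked as padding")
```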


In [68]:
def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):
    attention_plot = np.zeros((max_length_targ, max_length_inp))
    
#     sentence = preprocess_sentence(sentence)

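    # raises KeyError for words outside the training vocabulary (there is no <unk> token)
    inputs = [inp_lang[i] for i in sentence.split(' ')]
    # truncate and pad at the end, as preprocess() does; pad_sequences truncates from the front by default
    inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs], maxlen=max_length_inp, padding='post', truncating='post')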
    inputs = tf.convert_to_tensor(inputs)
    
    result = ''

    hidden = [tf.zeros((1, units))]
    enc_out, enc_hidden = encoder(inputs, hidden)

    dec_hidden = enc_hidden
    dec_input = tf.expand_dims([targ_lang['<bos>']], 0)

    for t in range(max_length_targ):
        predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)
        
        # storing the attention weights to plot later on
        attention_weights = tf.reshape(attention_weights, (-1, ))
        attention_plot[t] = attention_weights.numpy()

        predicted_id = tf.argmax(predictions[0]).numpy()

        result += idx2target[predicted_id] + ' '

        if idx2target.get(predicted_id) == '<eos>':
            return result, sentence, attention_plot
        
        # the predicted ID is fed back into the model
        dec_input = tf.expand_dims([predicted_id], 0)

    return result, sentence, attention_plot

# result, sentence, attention_plot = evaluate(sentence, encoder, decoder, source2idx, target2idx,
#                                             s_max_len, t_max_len)
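At inference time `evaluate` has no ground truth to feed, so it performs greedy decoding: take the argmax token, append it to the result, and feed it back as the next input until `<eos>` appears or `max_length_targ` steps have run. The `<eos>` token is appended before the loop stops, which is why it shows up in the printed translations; when it never appears, the output is cut at 12 tokens, as in the "...films of the most" prediction. The sketch below isolates that control flow; `step` is a hypothetical stand-in for one decoder call that returns scores over the target vocabulary.

```python
def greedy_decode(step, bos_id, eos_id, max_len):
    """Feed back the argmax token until <eos> is produced or max_len steps have run."""
    out, prev = [], bos_id
    for _ in range(max_len):
        scores = step(prev)
        pred = max(range(len(scores)), key=scores.__getitem__)  # argmax
        out.append(pred)            # <eos> is kept, as in evaluate()
        if pred == eos_id:
            break
        prev = pred                 # the prediction becomes the next input
    return out


# Toy decoder: deterministic next-token table over a 5-token vocabulary.
NEXT = {1: 3, 3: 4, 4: 2}


def toy_step(prev):
    return [1.0 if i == NEXT[prev] else 0.0 for i in range(5)]


print(greedy_decode(toy_step, bos_id=1, eos_id=2, max_len=12))
print(greedy_decode(toy_step, bos_id=1, eos_id=2, max_len=2))
```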

In [69]:
# function for plotting the attention weights
def plot_attention(attention, sentence, predicted_sentence):
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(1, 1, 1)
    ax.matshow(attention, cmap='viridis')
    
    fontdict = {'fontsize': 14}
    
    ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)
    ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)

    plt.show()

In [70]:

def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):
    result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)
        
    print('Input: {}'.format(sentence))
    print('Predicted translation: {}'.format(result))
    
#     attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]
#     plot_attention(attention_plot, sentence.split(' '), result.split(' '))

In [71]:
#restore checkpoint

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x18a2b5b8ac8>

In [82]:
sentence = df.movie_summary[101].split()[:15]
sentence = ' '.join(sentence)

In [84]:
sentence

'a mans uneventful life is disrupted by the zombie apocalypse shaun doesnt have a very'

In [83]:

# sentence = 'tensorflow is a framework for deep learning'

translate(sentence, encoder, decoder, source2idx, target2idx, s_max_len, t_max_len)

Input: a mans uneventful life is disrupted by the zombie apocalypse shaun doesnt have a very
Predicted translation: Sometimes silence is the deadliest sound. <eos> 


In [75]:
df

Unnamed: 0,movie_genre,movie_name,movie_poster,movie_summary,movie_synopsis,movie_tagline
0,"['Adventure', 'Drama', 'Horror', 'Thriller']",47 Meters Down: Uncaged,www.imdb.com/title/tt7329656/mediaviewer/rm242...,four teen girls diving in a ruined underwater ...,The film starts at a girls school in Yucatan M...,The Next Chapter
1,"['Drama', 'Fantasy', 'Horror']",It Chapter Two,www.imdb.com/title/tt7349950/mediaviewer/rm324...,twentyseven years after their first encounter ...,In 2016 Derry Maine a young gay man is beaten ...,You'll Float Again
10,"['Action', 'Drama', 'Sport']",Pailwaan,www.imdb.com/title/tt7317482/mediaviewer/rm400...,a fierce rebel wrestler kichha fights brutal o...,n n,
100,"['Action', 'Fantasy', 'Horror', 'Sci-Fi', 'Thr...",Daybreakers,www.imdb.com/title/tt0433362/mediaviewer/rm243...,in the year 2019 a plague has transformed almo...,10 years after the Outbreak 2019. A view of th...,The battle between immortality and humanity is...
101,"['Comedy', 'Horror']",Shaun of the Dead,www.imdb.com/title/tt0365748/mediaviewer/rm418...,a mans uneventful life is disrupted by the zom...,The film begins in The Winchester a traditiona...,Ever felt like you were surrounded by zombies?
102,"['Drama', 'Romance']",Brokeback Mountain,www.imdb.com/title/tt0388795/mediaviewer/rm209...,the story of a forbidden and secretive relatio...,In the summer of 1963 two young men meet when ...,Love Is A Force Of Nature
103,"['Biography', 'Crime', 'Drama']",Spotlight,www.imdb.com/title/tt1895587/mediaviewer/rm899...,the true story of how the boston globe uncover...,The opening shot shows the text BASED ON ACTUA...,The true story behind the scandal that shook t...
104,['Horror'],The Evil Dead,www.imdb.com/title/tt0083907/mediaviewer/rm997...,five friends travel to a cabin in the woods wh...,Five Michigan State University students ventur...,Can They Be Stopped?
105,"['Action', 'Adventure', 'Sci-Fi', 'Thriller']",Star Trek: Nemesis,www.imdb.com/title/tt0253754/mediaviewer/rm361...,the enterprise is diverted to the romulan home...,On Romulus members of the Romulan Imperial Sen...,"For every good in the universe, there is an evil."
106,"['Action', 'Crime', 'Drama', 'Thriller']",Live by Night,www.imdb.com/title/tt2361317/mediaviewer/rm413...,a group of bostonbred gangsters set up shop in...,Set in the 1920s the film starts with the voic...,Joe was once a good man


In [85]:
sentence = df.movie_summary[106].split()[:15]
sentence = ' '.join(sentence)

In [86]:
sentence

'a group of bostonbred gangsters set up shop in balmy florida during the prohibition era'

In [91]:
translate(sentence, encoder, decoder, source2idx, target2idx, s_max_len, t_max_len)

Input: a group of bostonbred gangsters set up shop in balmy florida during the prohibition era
Predicted translation: The master of the most important and powerful films of the most 


In [110]:
sentence = df.movie_summary[107].split()[:30]
sentence = ' '.join(sentence)
sentence

'an indepth examination of the ways in which the us vietnam war impacts and disrupts the lives of people in a small industrial town in pennsylvania michael steven and nick'

In [111]:
translate(sentence, encoder, decoder, source2idx, target2idx, s_max_len, t_max_len)

Input: an indepth examination of the ways in which the us vietnam war impacts and disrupts the lives of people in a small industrial town in pennsylvania michael steven and nick
Predicted translation: The Past Never Stays in Focus. <eos> 


In [112]:
df[df['movie_tagline'].str.contains('Past Never Stays in Focus')]

Unnamed: 0,movie_genre,movie_name,movie_poster,movie_summary,movie_synopsis,movie_tagline
108,"['Drama', 'Mystery', 'Romance']",The Souvenir,www.imdb.com/title/tt6920356/mediaviewer/rm353...,a young film student in the early 80s becomes ...,Julie Honor Swinton Byrne is a young film stud...,The Past Never Stays in Focus.


In [108]:
df.movie_tagline[107]

'One of the most important and powerful films of all time!'

In [109]:
df.movie_summary[154]

'selfish yuppie charlie babbitts father left a fortune to his savant brother raymond and a pittance to charlie they travel crosscountry charles sanford charlie babbit is a selfcentered los angelesbased automobile dealerhustlerbookie who is at war with his own life charlie as a young teenager used his fathers 1949 buick convertible without permission and as a result he went to jail for two days on account that his father reported it stolen it is then that charlie learns that his estranged father died and left him from his last will and testament a huge bed of roses and the car while the remainder will of 3 million goes into a trust fund to be distributed to someone charlie seemed pretty angry by this and decides to look into this matter it seems as if that someone is raymond charlies unknown brother an autistic savant who lives in a world of his own resides at the walbrook institute charlie then kidnaps raymond and decides to take him on a lust for life trip to the west coast as a threa

In [103]:
sentence = 'will persist in the gulf It’s unlikely to end until the excess capacity in real estate has been cleared'
sentence

'will persist in the gulf It’s unlikely to end until the excess capacity in real estate has been cleared'

In [104]:
translate(sentence, encoder, decoder, source2idx, target2idx, s_max_len, t_max_len)

KeyError: 'persist'
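The `KeyError` comes from `evaluate` indexing `source2idx` directly: the vocabulary was built only from the 100 training summaries and has no `<unk>` entry, so any unseen word fails. This sentence also keeps capital letters ("It’s"), while the training summaries were lowercased. One option, sketched below with a hypothetical helper and toy vocabulary, is to drop unknown words or map them to an `<unk>` id; the `<unk>` variant is an assumption and would only help if such a token were added to the vocabulary and seen during training. The sketch truncates and pads at the end, as `preprocess(mode='source')` does.

```python
def encode_source(sentence, word2idx, max_len, unk_id=None):
    """Map words to ids, then truncate/pad at the end (post), like preprocess(mode='source').

    Unknown words are dropped when unk_id is None, otherwise replaced by unk_id
    (hypothetical: the notebook's vocabulary has no <unk> token).
    """
    ids = []
    for w in sentence.lower().split():
        if w in word2idx:
            ids.append(word2idx[w])
        elif unk_id is not None:
            ids.append(unk_id)
    ids = ids[:max_len]
    return ids + [0] * (max_len - len(ids))


vocab = {'<pad>': 0, 'will': 1, 'in': 2, 'the': 3, 'gulf': 4}
print(encode_source('will persist in the gulf', vocab, 6))
print(encode_source('will persist in the gulf', vocab, 6, unk_id=5))
```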