In [None]:
%cd drive/MyDrive/ECS289_final/
import gensim
from gensim.models import word2vec
from gensim.models import KeyedVectors
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import*
from tensorflow.keras.models import Sequential

!pip install tensorflow_addons
import tensorflow_addons as tfa
from tensorflow.keras.layers import*

class Seq2SeqWithAttention(tf.keras.Model):
    """LSTM encoder/decoder with Luong attention built on TensorFlow Addons.

    The encoder is an embedding followed by a single LSTM. The decoder is an
    ``LSTMCell`` wrapped in ``tfa.seq2seq.AttentionWrapper`` and is driven by
    two ``BasicDecoder`` instances that share the same cell and output layer:
    one with a ``TrainingSampler`` (teacher forcing) and one with a
    ``GreedyEmbeddingSampler`` (inference).

    Args:
        enc_v_dim: Encoder vocabulary size (``input_dim`` of the encoder
            embedding).
        dec_v_dim: Decoder vocabulary size (``input_dim`` of the decoder
            embedding and width of the output projection).
        emb_dim: Embedding width used by both embeddings.
        units: Number of LSTM units; also passed to ``LuongAttention``.
        attention_layer_size: Accepted for interface compatibility but NOT
            used: the ``AttentionWrapper`` is always built with
            ``attention_layer_size=None``.
        max_pred_len: Number of decoding steps performed by ``inference``.
        start_token: Token id fed to the decoder at the first inference step.
        end_token: Token id handed to the greedy sampler as the end marker.
    """

    def __init__(self, enc_v_dim, dec_v_dim, emb_dim, units, attention_layer_size, max_pred_len, start_token, end_token):
        super().__init__()

        # Pre-trained word2vec vectors. They are loaded here and stored on the
        # instance, but nothing inside this class reads ``word_matrix``.
        EMBEDDING_FILE = '/content/drive/MyDrive/ECS289_final/dataset/GoogleNews-vectors-negative300.bin'
        self.word_matrix = KeyedVectors.load_word2vec_format(EMBEDDING_FILE, binary=True)

        self.units = units

        # encoder
        self.enc_embeddings = keras.layers.Embedding(
            input_dim=enc_v_dim, output_dim=emb_dim,    # [enc_n_vocab, emb_dim]
            embeddings_initializer=tf.initializers.RandomNormal(0., 0.1)
        )
        self.encoder = keras.layers.LSTM(units=units, return_sequences=True, return_state=True)

        # decoder
        # The attention memory is left unset here; it is supplied per batch
        # through ``setup_memory`` in ``set_attention``.
        self.attention = tfa.seq2seq.LuongAttention(
            units,    # units is used in the dense function for computing e (scores)
            memory=None,
            memory_sequence_length=None)

        # NOTE: the ``attention_layer_size`` constructor argument is
        # intentionally not forwarded; the wrapper is built with ``None``.
        self.decoder_cell = tfa.seq2seq.AttentionWrapper(
            cell=keras.layers.LSTMCell(units=units),
            attention_mechanism=self.attention,
            attention_layer_size=None,
        )

        self.dec_embeddings = keras.layers.Embedding(
            input_dim=dec_v_dim, output_dim=emb_dim,    # [dec_n_vocab, emb_dim]
            embeddings_initializer=tf.initializers.RandomNormal(0., 0.1),
        )
        # Output projection shared by the training and inference decoders.
        decoder_dense = keras.layers.Dense(dec_v_dim)

        # train decoder
        self.decoder_train = tfa.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            sampler=tfa.seq2seq.sampler.TrainingSampler(),   # sampler for train
            output_layer=decoder_dense
        )
        self.cross_entropy = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.opt = keras.optimizers.Adam(0.05, clipnorm=5.0)

        # predict decoder
        self.decoder_eval = tfa.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            sampler=tfa.seq2seq.sampler.GreedyEmbeddingSampler(),       # sampler for predict
            output_layer=decoder_dense
        )

        # prediction restriction
        self.max_pred_len = max_pred_len
        self.start_token = start_token
        self.end_token = end_token

    def encode(self, x):
        """Embed ``x`` and run it through the encoder LSTM.

        Args:
            x: Batch of encoder token ids; ``x.shape[0]`` is used as the
                batch size, so it must be statically known.

        Returns:
            Tuple ``(o, h, c)``: the per-step encoder outputs, the last
            hidden state and the last cell state.
        """
        o = self.enc_embeddings(x)
        # Start the encoder from an all-zero [hidden, cell] state.
        init_s = [tf.zeros((x.shape[0], self.units)), tf.zeros((x.shape[0], self.units))]

        # outputs (all hidden states of each time step), last hidden state (a), last cell state (c)
        o, h, c = self.encoder(o, initial_state=init_s)
        return o, h, c

    def set_attention(self, x):
        """Encode ``x``, register its outputs as attention memory and build
        the decoder's initial state.

        Args:
            x: Batch of encoder token ids.

        Returns:
            Tuple ``(s, o)``: the initial ``AttentionWrapper`` state whose
            cell state is the encoder's final ``[h, c]``, and the encoder
            outputs ``o``.
        """
        # encoder output for attention to focus on
        # o: all hidden states, used for computing attention
        o, h, c = self.encode(x)
        self.attention.setup_memory(o)

        # Wrap the state with the attention wrapper.
        #
        # [h, c] is the cell state of the decoder (s0); the wrapper state s
        # additionally contains the context (named "attention" inside the
        # wrapper), the alignments (attention weights at each time step) and
        # the alignment history.
        #   context = sum of alignments * hidden states from the encoder (a)
        #
        # All encoder hidden states (o) are set up as the attention memory,
        # and s0 of the decoder is initialised with the encoder's last [h, c].
        # From s0 and o the scores (e) are computed, then the attention
        # weights (alpha), which the wrapper calls alignments, and finally
        # the context, stored as "attention" inside s.
        #
        # The same procedure is repeated for s1, s2, ... at each time step.
        s = self.decoder_cell.get_initial_state(batch_size=x.shape[0], dtype=tf.float32).clone(cell_state=[h, c])
        return s, o

    def inference(self, x):
        """Greedily decode ``max_pred_len`` tokens for every row of ``x``.

        Decoding always runs for the full ``max_pred_len`` steps; it does not
        stop early when the end token is produced.

        Args:
            x: Batch of encoder token ids.

        Returns:
            ``np.ndarray`` of dtype ``int32`` and shape
            ``(x.shape[0], max_pred_len)`` holding the sampled token ids.
        """
        s, _ = self.set_attention(x)

        # s includes the hidden state from the decoder (s) and the context (c)
        done, i, s = self.decoder_eval.initialize(
            self.dec_embeddings.variables[0],
            start_tokens=tf.fill([x.shape[0], ], self.start_token),
            end_token=self.end_token,
            initial_state=s,
        )

        pred_id = np.zeros((x.shape[0], self.max_pred_len), dtype=np.int32)
        for l in range(self.max_pred_len):
            o, s, i, done = self.decoder_eval.step(
                time=l, inputs=i, state=s, training=False)

            pred_id[:, l] = o.sample_id

            # To check that s contains the context (called "attention") and
            # the attention weights (called "alignments"), with ``_`` being
            # the encoder outputs returned by set_attention:
            #
            #   c = tf.tensordot(tf.reshape(s[2][0, :], [1, -1]), _[0, :], axes=1)
            #   print(c == s[1])
            #   print(c.shape)
            #   print(s[1].shape)

        return pred_id

    def train_logits(self, x, y, seq_len):
        """Run the teacher-forced decoder and return its logits.

        Args:
            x: Batch of encoder token ids.
            y: Batch of decoder token ids; the last position of each row is
                dropped to form the decoder input.
            seq_len: Passed to the training decoder as ``sequence_length``.

        Returns:
            The decoder's ``rnn_output`` (logits over the decoder vocabulary).
        """
        s, _ = self.set_attention(x)

        dec_in = y[:, :-1]   # ignore <EOS>
        dec_emb_in = self.dec_embeddings(dec_in)

        o, _, _ = self.decoder_train(dec_emb_in, s, sequence_length=seq_len)
        logits = o.rnn_output
        return logits

    def step(self, x, y, seq_len):
        """Perform one optimisation step and return the batch loss.

        Args:
            x: Batch of encoder token ids.
            y: Batch of decoder token ids; the first position of each row is
                dropped to form the targets.
            seq_len: Forwarded to ``train_logits``.

        Returns:
            The loss for this batch as a NumPy scalar.
        """
        # Only the forward pass needs to be recorded on the tape.
        with tf.GradientTape() as tape:
            logits = self.train_logits(x, y, seq_len)
            dec_out = y[:, 1:]  # ignore <GO>
            loss = self.cross_entropy(dec_out, logits)

        # Differentiate outside the tape context so the gradient ops
        # themselves are not traced.
        grads = tape.gradient(loss, self.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.trainable_variables))
        return loss.numpy()

In [None]:
# Instantiate the model from the two vocabulary sizes.
# NOTE(review): VQAModel, ans_vocab and ques_vocab are not defined in the
# visible part of this notebook — confirm they exist in an earlier cell and
# that (answer vocab size, question vocab size) is the expected argument
# order before running this cell.
model = VQAModel(len(ans_vocab), len(ques_vocab))