In [None]:
import datetime
import os
import re
import string

import nltk
import numpy as np
import pandas
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.data.experimental import AUTOTUNE
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import (BatchNormalization, Bidirectional, Conv1D,
                                     Conv2D, Dense, Dropout, Embedding, GRU, Input,
                                     LayerNormalization, LSTM, Reshape, SimpleRNN,
                                     Softmax)
from tensorflow.keras.layers.experimental.preprocessing import (Rescaling,
                                                                TextVectorization)
from tensorflow.keras.preprocessing import text_dataset_from_directory
from tensorflow.keras.preprocessing.image import img_to_array, load_img

<H1>DATA PREPARATION</H1>

In [None]:
# Model / data hyper-parameters.
MAX_TOKEN = 10_000      # tokenizer vocabulary cap (num_words) and decoder output size
INPUT_DIM = 224         # images are resized to INPUT_DIM x INPUT_DIM
BATCH_SIZE = 8          # images per generator batch
MODEL_SIZE = 128        # transformer width
SEQUENCE_LENGTH = 20    # caption length fed to / predicted by the decoder
NUM_LAYERS = 2          # number of stacked decoder layers
NUM_HEADS = 8           # attention heads per layer

In [None]:
# Placeholders: fill in before running.
path = '...'      # NOTE(review): not referenced anywhere else visible in this notebook
txt_path = '...'  # caption annotation file, read in the next cell

In [None]:
# Read every line of the caption annotation file. Presumably one
# "<image file>#<caption index>\t<caption>" entry per line (inferred from the
# parsing cell below) — TODO confirm.
# Read-only mode is enough, and `with` guarantees the handle is closed.
with open(txt_path,"r",encoding="utf-8") as f:
    lines=f.readlines()

In [None]:
# Accumulators filled by the caption-parsing cell.
data_dict = dict()   # "<image file>#<caption index>" -> caption text
captions = str()     # all caption text, used to fit the tokenizer

In [None]:
# Map each "<image file>#<caption index>" key to its caption text and collect
# all captions for the tokenizer.
caption_texts=[]
for line in lines:
    key,sep,text=line.rstrip('\r\n').partition('\t')
    if not sep:
        # Blank or malformed line (e.g. an empty last line): nothing to record.
        continue
    # Drop the trailing " ." sentence terminator (robust to a missing final newline).
    text=text.rstrip(' .')
    data_dict[key]=text
    caption_texts.append(text)
# Space-separate captions so the last word of one caption does not fuse with
# the first word of the next when the corpus is split into words.
captions+=' '.join(caption_texts)
# Presumably repeated so 'starttoken' ranks among the MAX_TOKEN most frequent
# words that the tokenizer keeps.
captions+=3000*' starttoken'

In [None]:
# Word-level tokenizer over the caption corpus. texts_to_sequences only emits ids
# below MAX_TOKEN (the most frequent words); rarer words map to the "<unk>" id.
# The filter string is a raw literal: '\]' in a normal string is an invalid escape
# sequence (SyntaxWarning on Python 3.12+). The set of filtered characters is unchanged.
tokenizer=tf.keras.preprocessing.text.Tokenizer(num_words=MAX_TOKEN,
                                               oov_token="<unk>",
                                               filters=r'!"#$%&()*+.,-/:;=?@[\]^_`{|}~ ')
tokenizer.fit_on_texts(captions.split(' '))

In [None]:
# Sanity check: with oov_token set, the Keras Tokenizer reserves index 1 for it, so this should display '<unk>'.
tokenizer.index_word[1]

In [None]:
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` of (image, caption) batches for teacher-forced training.

    Every image contributes ``captions_per_image`` samples (one per caption), so a
    batch holds ``batch_size * captions_per_image`` examples:

    * ``'in_1'``: each image repeated once per caption, shape
      ``(N, INPUT_DIM, INPUT_DIM, 3)``;
    * ``'in_2'``: decoder input = ``starttoken`` followed by the caption ids
      shifted right by one, shape ``(N, sequence_length)``;
    * target: the caption ids, post-padded/truncated to ``sequence_length``.

    Args:
        images: directory containing the image files.
        batch_size: number of images per batch.
        tokenizers: fitted Keras ``Tokenizer`` used to encode captions.
        data_dict: maps ``'<file name>#<caption index>'`` to caption text.
        starttoken: integer id placed at the start of every decoder input.
        INPUT_DIM: images are resized to ``(INPUT_DIM, INPUT_DIM)``.
        sequence_length: caption length after padding/truncation.
        captions_per_image: number of captions stored per image in ``data_dict``.
    """
    def __init__(self,images,batch_size,tokenizers,data_dict,starttoken,INPUT_DIM,
                 sequence_length=20,captions_per_image=5):
        super().__init__()
        self.images=images
        self.batch_size=batch_size
        # List the directory once; sorted so batch contents do not depend on
        # os.listdir's arbitrary ordering.
        self.train_image_list=sorted(os.listdir(images))
        self.tokenizer=tokenizers
        self.data_dict=data_dict
        self.starttoken=starttoken
        self.INPUT_DIM=INPUT_DIM
        self.sequence_length=sequence_length
        self.captions_per_image=captions_per_image

    def __len__(self):
        # Number of full batches; a trailing partial batch is dropped.
        return len(self.train_image_list)//self.batch_size

    def __getitem__(self,idx):
        X,y=self.__data_generation(idx)
        return X,y

    def __data_generation(self,idx):
        file_names=self.train_image_list[idx*self.batch_size:(idx+1)*self.batch_size]
        image_batch=[]
        targets=[]
        for file_name in file_names:
            im_array=img_to_array(load_img(os.path.join(self.images,file_name),
                                           target_size=(self.INPUT_DIM,self.INPUT_DIM)))
            for i in range(self.captions_per_image):
                caption=self.data_dict[file_name+'#'+str(i)]
                # Tokenize the caption as one text: words that split into several
                # tokens (or filter to nothing) still yield a flat id list.
                cap_tok=tf.keras.preprocessing.sequence.pad_sequences(
                    self.tokenizer.texts_to_sequences([caption]),
                    maxlen=self.sequence_length,padding='post',truncating='post')
                image_batch.append(im_array)
                targets.append(cap_tok[0])
        targets=np.stack(targets)
        # Teacher forcing: the decoder sees the start token plus the target shifted right.
        starts=np.full((len(targets),1),self.starttoken,dtype=targets.dtype)
        decoder_in=np.concatenate([starts,targets[:,:-1]],axis=-1)
        return ({'in_1':tf.constant(np.stack(image_batch)),'in_2':tf.constant(decoder_in)},
                tf.constant(targets))

In [None]:
# Image directories for the two splits (placeholders — fill in before running).
train_images = '...'
val_images = '...'

# Id of the artificial start token that begins every decoder input.
starttoken = tokenizer.word_index['starttoken']


In [None]:
# Training and validation batch generators sharing the tokenizer and caption lookup.
train_gen, val_gen = (
    DataGenerator(split_dir, BATCH_SIZE, tokenizer, data_dict, starttoken, INPUT_DIM)
    for split_dir in (train_images, val_images)
)

<H1>MODELING</H1>

In [None]:
class SelfAttention(tf.keras.layers.Layer):
    """Single-head scaled dot-product attention with padding and optional causal masks.

    ``model_size`` is the depth used to scale scores by ``1/sqrt(model_size)``.
    """
    def __init__(self,model_size):
        super().__init__()
        self.model_size=model_size

    def call(self,query,key,value,sequence,look_ahead_masking=False):
        # Dot product of every query with every key (last axis = key positions).
        logits=tf.einsum('ijk,ibk->ijb',query,key)
        logits=logits/tf.math.sqrt(tf.cast(self.model_size,tf.float32))

        # 1.0 marks positions to suppress; padding_mask flags entries of `sequence` equal to 0.
        blocked=padding_mask(sequence)
        if look_ahead_masking:
            # Strictly-upper-triangular ones forbid attending to later positions.
            blocked=blocked+(1-tf.linalg.band_part(tf.ones_like(logits),-1,0))
        logits=logits+blocked*-1e10

        weights=tf.nn.softmax(logits,axis=-1)
        return tf.matmul(weights,value)

In [None]:
def padding_mask(a):
    """Attention mask marking padding positions.

    Args:
        a: tensor of token ids; entries equal to 0 are treated as padding.

    Returns:
        float32 tensor with ``a``'s shape plus a size-1 axis inserted before the
        last axis (e.g. (batch, seq) -> (batch, 1, seq)), so it broadcasts over the
        query axis of attention scores. 1.0 marks padding, 0.0 real tokens.
    """
    return tf.expand_dims(tf.cast(tf.math.equal(a,0),tf.float32),axis=-2)

In [None]:
def positional_embedding(model_size,sequence_length=None):
    """Sinusoidal positional-encoding table ("Attention Is All You Need").

    Even column ``i`` holds ``sin(pos / 10000**(i/model_size))``; odd column ``i``
    holds ``cos(pos / 10000**((i-1)/model_size))``.

    Args:
        model_size: embedding width (number of columns).
        sequence_length: number of positions (rows); defaults to the notebook's
            ``SEQUENCE_LENGTH`` constant.

    Returns:
        float64 ``tf.Tensor`` of shape ``(sequence_length, model_size)``.
    """
    if sequence_length is None:
        sequence_length=SEQUENCE_LENGTH
    positions=np.arange(sequence_length,dtype=np.float64)[:,np.newaxis]
    columns=np.arange(model_size)
    # Each (sin, cos) column pair (2k, 2k+1) shares the exponent 2k/model_size.
    exponents=(columns-columns%2)/model_size
    angles=positions/np.power(10000.0,exponents)
    table=np.where(columns%2==0,np.sin(angles),np.cos(angles))
    return tf.constant(table)

In [None]:
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention built from ``h`` independent ``SelfAttention`` heads.

    Each head projects query/key/value to ``model_size // h`` features; the head
    outputs are concatenated on the last axis and projected back to ``model_size``.
    """
    def __init__(self,model_size,h,**kwargs):
        super(MultiHeadAttention,self).__init__(**kwargs)
        self.query_size=model_size//h
        self.key_size=model_size//h
        self.value_size=model_size//h
        self.h=h
        self.dense_q=[Dense(self.query_size) for _ in range(h)]
        self.dense_k=[Dense(self.key_size) for _ in range(h)]
        self.dense_v=[Dense(self.value_size) for _ in range(h)]
        self.dense_o=Dense(model_size)
        # Scores are scaled by sqrt of the per-head key depth.
        self.self_attention=SelfAttention(self.key_size)

    def call(self,query,key,value,sequence,look_ahead_masking=False):
        # `sequence` supplies the key-padding mask; look_ahead_masking adds the causal mask.
        heads=[
            self.self_attention(self.dense_q[i](query),self.dense_k[i](key),
                                self.dense_v[i](value),sequence,look_ahead_masking)
            for i in range(self.h)
        ]
        return self.dense_o(tf.concat(heads,axis=-1))

In [None]:
class DecoderLayer(tf.keras.layers.Layer):
    """One Transformer decoder block: masked self-attention, cross-attention, feed-forward.

    Each sub-layer is followed by a residual connection and layer normalization.
    """
    def __init__(self,model_size,num_layers,h):
        super(DecoderLayer,self).__init__()

        self.model_size=model_size
        self.num_layers=num_layers
        self.h=h

        self.multi_attention_bot=MultiHeadAttention(model_size,h)
        self.attention_bot_norm=LayerNormalization()

        self.multi_attention_mid=MultiHeadAttention(model_size,h)
        self.attention_mid_norm=LayerNormalization()

        self.dense_1=Dense(model_size*4,activation='relu')
        self.dense_2=Dense(model_size)
        self.dropout=Dropout(0.2)

        self.feed_forward_norm=LayerNormalization()

    def call(self,bot_dec_in,enc_in,sequence):
        """Run the block.

        Args:
            bot_dec_in: decoder activations; presumably (batch, dec_len, model_size) — TODO confirm.
            enc_in: encoder features attended to by the cross-attention; assumes
                (batch, enc_len, features).
            sequence: decoder token ids; positions equal to 0 are masked as padding
                (via padding_mask).
        """
        # Causal self-attention over the caption tokens.
        bot_dec_out=self.multi_attention_bot(bot_dec_in,bot_dec_in,bot_dec_in,sequence,look_ahead_masking=True)
        bot_dec_out=self.attention_bot_norm(bot_dec_out+bot_dec_in)

        # Cross-attention: caption queries attend over the encoder (image) features.
        # Encoder positions have no padding, so pass an all-ones "sequence": its
        # padding mask is all zeros (nothing masked) and matches enc_len.
        mid_dec_in=bot_dec_out
        enc_keep=tf.ones_like(enc_in[...,0])
        mid_dec_out=self.multi_attention_mid(mid_dec_in,enc_in,enc_in,enc_keep,look_ahead_masking=False)
        mid_dec_out=self.attention_mid_norm(mid_dec_out+mid_dec_in)

        # Position-wise feed-forward network.
        feed_forward_in=mid_dec_out
        feed_forward_out=self.dropout(self.dense_2(self.dense_1(feed_forward_in)))
        return self.feed_forward_norm(feed_forward_out+feed_forward_in)

In [None]:
class Decoder(tf.keras.layers.Layer):
    """Transformer decoder: token embedding + positional encoding, a stack of
    ``DecoderLayer``s, and a final projection to ``vocab_size`` unnormalized logits.
    """
    def __init__(self,vocab_size,model_size,h,num_layers):
        super(Decoder,self).__init__()

        self.model_size=model_size
        self.num_layers=num_layers
        self.h=h
        self.embedding=Embedding(vocab_size,model_size)
        self.decoder_layer=[DecoderLayer(model_size,num_layers,h) for _ in range(num_layers)]
        self.dense=Dense(vocab_size)
        # The encoding table does not depend on the input; build it once, not per call.
        self.pos_encoding=tf.cast(positional_embedding(model_size),dtype=tf.float32)

    def call(self,sequence,encoder_output):
        dec_in=self.embedding(sequence)
        # Slice to the actual length so sequences shorter than the table also work.
        dec_in+=self.pos_encoding[:tf.shape(dec_in)[1]]

        out=dec_in
        for layer in self.decoder_layer:
            out=layer(out,encoder_output,sequence)
        return self.dense(out)

In [None]:
def get_base_model(layer_name="conv4_block6_2_relu"):
    """Frozen ImageNet ResNet50 truncated at an intermediate activation.

    Args:
        layer_name: name of the ResNet50 layer whose output becomes the model output.

    Returns:
        A ``tf.keras.Model`` mapping ``(INPUT_DIM, INPUT_DIM, 3)`` images to that
        layer's activations. Expects inputs already preprocessed for ResNet50.
    """
    base_model=ResNet50(weights='imagenet',input_shape=(INPUT_DIM,INPUT_DIM,3),include_top=False)
    base_model.trainable=False
    # `base_model.inputs` is already a list; pass the single input and output
    # tensors directly rather than nesting them inside extra lists.
    return tf.keras.Model(inputs=base_model.input,
                          outputs=base_model.get_layer(layer_name).output)

get_base_model().summary()

In [None]:
# Image input and shifted-caption token ids; the names match the 'in_1'/'in_2'
# keys produced by DataGenerator.
inputs=Input((INPUT_DIM,INPUT_DIM,3),name='in_1')
pre_outputs=Input((SEQUENCE_LENGTH,),dtype='int32',name='in_2')

# ImageNet ResNet50 weights expect their own preprocessing, applied to raw 0-255 pixels.
x=tf.keras.applications.resnet50.preprocess_input(inputs)
x=get_base_model()(x)
x=Conv2D(256,3,padding='same',activation='relu')(x)
x=Conv2D(128,3,padding='same',activation='relu')(x)
x=BatchNormalization()(x)

x=Conv2D(64,3,padding='same',activation='relu')(x)
x=Conv2D(1,3,padding='same',activation='relu')(x)
x=BatchNormalization()(x)
# (H, W, 1) -> (H, W): feature-map rows become the encoder sequence.
x=Reshape((x.shape[1],x.shape[2]*x.shape[3]))(x)

x=Dense(MODEL_SIZE,activation='relu')(x)

dec=Decoder(vocab_size=MAX_TOKEN,model_size=MODEL_SIZE,h=NUM_HEADS,num_layers=NUM_LAYERS)
decoder_output=dec(pre_outputs,x)

model=tf.keras.Model([inputs,pre_outputs],decoder_output,name='conv-transformer')
model.summary()

<H1>TRAINING</H1>

In [None]:
LR = 1e-3    # Adam learning rate
EPOCH = 100  # number of training epochs

In [None]:
# The decoder's final Dense layer has no activation, hence from_logits=True.
# Padding id 0 stays a regular target: caption() stops decoding when it is predicted.
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
    metrics=['accuracy'],
    # NOTE(review): eager execution is much slower; likely a debugging leftover — confirm before removing.
    run_eagerly=True,
)

In [None]:
checkpoint_filepath='...'
# Keep only the weights from the epoch with the lowest *training* loss
# (monitor='loss', not 'val_loss').
callback=tf.keras.callbacks.ModelCheckpoint(
    checkpoint_filepath,
    monitor='loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
)

In [None]:
# Evaluate on the validation generator each epoch so over-fitting is visible.
history=model.fit(train_gen,validation_data=val_gen,verbose=1,shuffle=True,
                  epochs=EPOCH,callbacks=[callback])

<H1>TESTING</H1>

In [None]:
def caption(input_image):
    """Greedily decode a caption for one image, display the image, and return the text.

    Args:
        input_image: batch holding one image; presumably shape
            (1, INPUT_DIM, INPUT_DIM, 3) with raw 0-255 pixels (divided by 255 only
            for display) — TODO confirm.

    Returns:
        The predicted words joined by spaces. Decoding stops at the first predicted
        id 0 (padding) or after SEQUENCE_LENGTH steps.
    """
    plt.imshow(input_image[0]/255)
    plt.show()

    in_2=[starttoken]
    final_output=[]

    for i in range(SEQUENCE_LENGTH):
        # Right-pad the tokens decoded so far to the fixed decoder length.
        p_in_2=tf.pad(tf.constant(in_2),[[0,SEQUENCE_LENGTH-1-i]])
        # Calling the model directly avoids predict()'s per-call overhead in this loop.
        logits=model([input_image,tf.expand_dims(p_in_2,0)],training=False)
        output=int(tf.argmax(logits[0,i],axis=-1))

        if output==0:
            break
        in_2.append(output)
        final_output.append(output)
    print(final_output)
    return ' '.join(tokenizer.index_word[token] for token in final_output)

In [None]:
# File name of an image inside val_images (placeholder).
im='...'
test_image=os.path.join(val_images,im)
# Batch of one image at the model's input resolution.
X=tf.constant([img_to_array(load_img(test_image,target_size=(INPUT_DIM,INPUT_DIM)))])
print(caption(X))