<a href="https://colab.research.google.com/github/tezike/Show-attend-and-tell/blob/master/attention_inf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# from google.colab import drive
# drive.mount('/content/drive')

In [0]:
# !pip -q install fastai2

In [0]:
import torch
import gc
import dill as pickler
import joblib as picklizer
import fastai
from fastai import *
from fastai.vision import *
from fastai2.vision.core import *
from torchvision import transforms
# Register the name 'ClassType' in dill's private reverse type map, mapped to the built-in `type`.
# NOTE(review): presumably a workaround so dill can load pickles that reference ClassType;
# the only dill-based load in this notebook is commented out, so confirm this is still needed.
pickler._dill._reverse_typemap['ClassType'] = type

In [0]:
# Create the local directory that receives the downloaded model/vocab files.
# exist_ok=True keeps the cell re-runnable and avoids a separate check-then-create step.
os.makedirs('models', exist_ok=True)

In [0]:
# !cp '/content/drive/My Drive/Image_Captioning/models/vocab_coco.pkl' '/content/models/vocab_coco.pkl'

In [0]:
# !cp '/content/drive/My Drive/Image_Captioning/models/models/mygoodmodel_coco.pth' '/content/models/mygoodmodel_coco.pth'

In [0]:
# !cp '/content/drive/My Drive/Image_Captioning/models/loaded_learn_cpu.pkl' '/content/models/loaded_learn_cpu.pkl'

In [0]:
# !cp '/content/drive/My Drive/Image_Captioning/models/learn_cpu.pkl' '/content/models/learn_cpu.pkl'

In [0]:
!git clone --quiet 'https://github.com/tezike/download_google_drive.git'
os.chdir('download_google_drive')
# !python download_gdrive.py '1-QlLAWm3L1-Jyjj0R48tvphWdNP0O006' '../models/learn_cpu.pkl'
!python download_gdrive.py '1-_BHQhAuVl_PInF1GWO3x54sQv30XdMW' '../models/loaded_learn_cpu.pkl'
# !python download_gdrive.py '1-6bUIv2R12OmoE6X3a2MPLco0Uj7Mw4P' '../models/mygoodmodel_coco.pth'
!python download_gdrive.py '1-kTk4u9r7M9OY6f-qBQieakQTe17B_8B' '../models/vocab_coco.pkl'
shutil.rmtree('../download_google_drive')
os.chdir('..')

In [0]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [0]:
# Load the pickled vocabulary. The context manager closes the file handle deterministically
# instead of leaving it open until garbage collection.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open('models/vocab_coco.pkl', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

In [0]:
# n_layers: number of GRU layers, passed to the Decoder constructor below.
# attn_size: NOTE(review): not referenced anywhere else in the visible notebook; confirm it is still needed.
n_layers, attn_size = 1, 500
# sz: side length in pixels that visualize_attention resizes each attention mask to.
sz = 224

In [0]:
# projection head placed on top of the pooled CNN features
def fc_layer(in_, out, p=0.1):
    """Build a Flatten -> Linear -> Dropout head.

    Args:
      in_: number of input features of the linear layer
      out: number of output features of the linear layer
      p: dropout probability applied after the linear layer

    Returns:
      An nn.Sequential holding the three layers in that order
    """
    stages = [Flatten(), nn.Linear(in_, out), nn.Dropout(p)]
    return nn.Sequential(*stages)

In [0]:
class Encoder(nn.Module):
    """CNN image encoder producing the decoder's initial hidden states and the annotation vectors."""

    def __init__(self, device, dec_hidden_state_size, dec_layers, filter_width, num_filters):
        """Encoder constructor

        Args:
          device: torch device, stored on the instance
          dec_hidden_state_size: size of each initial decoder hidden state
          dec_layers: number of decoder layers (one projection head is built per layer)
          filter_width: side length of the grid the feature maps are average-pooled to
          num_filters: number of channels of the CNN feature maps
        """
        super().__init__()
        self.device = device

        # Visual encoder: pretrained ResNet-101 without its last two child modules, frozen by default
        resnet_children = list(models.resnet101(pretrained=True).children())
        self.base_network = nn.Sequential(*resnet_children[:-2])
        self.freeze_base_network()

        self.concatPool = AdaptiveConcatPool2d(sz=1)
        self.adaptivePool = nn.AdaptiveAvgPool2d((filter_width, filter_width))
        self.filter_width = filter_width

        # One projection head per decoder layer; input is 2*num_filters wide
        # (presumably because the concat pooling doubles the channel count)
        heads = [fc_layer(2*num_filters, dec_hidden_state_size) for _ in range(dec_layers)]
        self.output_layers = nn.ModuleList(heads)

    def forward(self, inp):
        """Encode a batch of images.

        Args:
          inp: batch of images fed to the base network

        Returns:
          A tuple (initial hidden states stacked along dim 0 with one entry per head,
          annotation vectors with the grid positions on dim 1 and the channels on dim 2)
        """
        feature_maps = self.base_network(inp)
        batch_size, n_channels = feature_maps.size(0), feature_maps.size(1)

        # pool to a filter_width x filter_width grid and flatten the grid into one axis
        pooled_grid = self.adaptivePool(feature_maps)
        annotation_vecs = pooled_grid.view(batch_size, n_channels, -1)

        pooled = self.concatPool(feature_maps)
        dec_init_hidden_states = []
        for head in self.output_layers:
            dec_init_hidden_states.append(head(pooled))

        return torch.stack(dec_init_hidden_states, dim = 0), annotation_vecs.transpose(1, 2)

    def freeze_base_network(self):
        """Turn off gradients for every child module of the base network."""
        for block in self.base_network:
            requires_grad(block, False)

    def fine_tune(self, from_block=-1):
        """Turn gradients back on for the base-network children from index `from_block` onwards."""
        for block in self.base_network[from_block:]:
            requires_grad(block, True)

In [0]:
class VisualAttention(nn.Module):
    """Soft attention over the annotation vectors, following Xu et al. (2015)."""

    def __init__(self, num_filters, dec_dim, att_dim):
        """VisualAttention constructor

        Args:
          num_filters: feature size of each annotation vector
          dec_dim: size of the decoder hidden state
          att_dim: size of the shared space both inputs are projected into
        """
        super().__init__()
        self.attend_annot_vec = nn.Linear(num_filters, att_dim)
        self.attend_dec_hidden= nn.Linear(dec_dim, att_dim)
        self.f_att = nn.Linear(att_dim, 1)  # scoring function, Equation (4) in Xu et al. (2015)

    def forward(self, annotation_vecs, dec_hid_state):
        """Compute the attention weights and the resulting context vector.

        Args:
          annotation_vecs: annotation vectors; dim 1 indexes the locations attended over
          dec_hid_state: decoder hidden state, broadcast over the locations

        Returns:
          A tuple (context vector, attention weights); the weights are softmax-normalized over dim 1
        """
        annot_proj = self.attend_annot_vec(annotation_vecs)
        hidden_proj = self.attend_dec_hidden(dec_hid_state).unsqueeze(1)

        # Equation (4): one unnormalized score per location
        scores = self.f_att(F.relu(annot_proj + hidden_proj)).squeeze(2)
        # Equation (5): normalize the scores over the locations
        alphas = F.softmax(scores, dim=1)
        # Equation (13): attention-weighted sum of the annotation vectors
        weighted = annotation_vecs * alphas.unsqueeze(2)
        context_vec = weighted.sum(1)

        return context_vec, alphas

In [0]:
class Decoder(nn.Module):
    """Attention-based GRU caption decoder that owns its CNN encoder (Xu et al., 2015)."""

    def __init__(self, device, filter_width, num_filters, vocab_size, emb_sz, out_seqlen, n_layers=3, prob_teach_forcing=1, p_drop=0.3, att_dim=500):
        """Decoder constructor

        Args:
          device: torch device that inputs and start tokens are moved to
          filter_width: side length of the attention grid (forwarded to the Encoder)
          num_filters: number of channels of the CNN feature maps
          vocab_size: number of tokens in the vocabulary
          emb_sz: embedding size; also used as the GRU hidden size so weights can be tied
          out_seqlen: maximum number of decoding steps in forward
          n_layers: number of GRU layers
          prob_teach_forcing: probability of feeding the ground-truth token at a step
          p_drop: dropout probability (output dropout; GRU dropout when n_layers > 1)
          att_dim: size of the attention projection space (default 500, the previously fixed value)
        """
        super().__init__()
        self.n_layers, self.out_seqlen = n_layers, out_seqlen
        self.filter_width = filter_width
        self.num_filters = num_filters
        self.device = device

        # Encoder
        self.encoder = Encoder(device, emb_sz, n_layers, filter_width, num_filters)

        # Attention
        self.att = VisualAttention(num_filters, emb_sz, att_dim)

        # Decoder
        self.emb = nn.Embedding(vocab_size, emb_sz)
        self.rnn_dec = nn.GRU(num_filters + emb_sz, emb_sz, num_layers=n_layers, dropout=0 if n_layers == 1 else p_drop)  # square to enable weight tying
        self.out_drop = nn.Dropout(p_drop)
        self.out = nn.Linear(emb_sz, vocab_size)
        # weight tying: the output projection shares its weight storage with the embedding
        self.out.weight.data = self.emb.weight.data
        self.f_b = nn.Linear(emb_sz, num_filters)  # Section 4.2.1 in Xu et al. (2015)

        self.prob_teach_forcing = prob_teach_forcing
        self.initializer()

    def initializer(self):
        """Initialize the embedding weights uniformly in [-0.1, 0.1] (in place, so the tied output weights follow)."""
        self.emb.weight.data.uniform_(-0.1, 0.1)

    def forward(self, x, y=None):
        """Encode the images and decode up to `out_seqlen` tokens.

        Args:
          x: batch of images
          y: optional target tokens, indexed by decoding step; enables teacher forcing

        Returns:
          A tuple (stacked per-step log-probabilities, stacked per-step attention weights)
        """
        h, annotation_vecs = self.encode(x)

        # every sequence starts from token id 0, created directly on the target device
        dec_inp = torch.zeros(h.size(1), dtype=torch.long, device=self.device)
        res = []
        alphas = []

        for i in range(self.out_seqlen):
            dec_output, h, alpha = self.decode_step(dec_inp, h, annotation_vecs)
            res.append(dec_output)
            alphas.append(alpha)

            # stop once every sequence was just fed token id 1 (the end token used by BeamSearch)
            # or the targets are exhausted
            if (dec_inp == 1).all() or (y is not None and i >= len(y)):
                break
            # teacher forcing
            elif y is not None and (self.prob_teach_forcing > 0) and (random.random() < self.prob_teach_forcing):
                dec_inp = y[i].to(self.device)
            else:
                dec_inp = dec_output.detach().max(1)[1]  # [1] to get argmax

        return torch.stack(res), torch.stack(alphas)

    def encode(self, x):
        """Move `x` to the configured device and run the encoder on it."""
        return self.encoder(x.to(self.device))

    def decode_step(self, dec_inp, h, annotation_vecs):
        """Advance the decoder by one token.

        Args:
          dec_inp: token ids fed at this step
          h: current GRU hidden state; its last layer drives the attention
          annotation_vecs: annotation vectors produced by the encoder

        Returns:
          A tuple (log-probabilities over the vocabulary, new hidden state, attention weights)
        """
        context_vec, alpha = self.att(annotation_vecs, h[-1])
        beta = torch.sigmoid(self.f_b(h[-1]))
        context_vec = beta * context_vec  # gating scalar, Section 4.2.1 in Xu et al. (2015)

        emb_inp = self.emb(dec_inp).unsqueeze(0)  # adds unit axis at beginning so that rnn 'loops' once

        output, h = self.rnn_dec(torch.cat([emb_inp, context_vec.unsqueeze(0)], dim=2), h)
        output = self.out(self.out_drop(output[0]))

        return F.log_softmax(output, dim=1), h, alpha

In [0]:
# Build a fresh Decoder: 7x7 attention grid, 2048 CNN channels, embedding/hidden size 1000,
# at most 50 decoding steps, `n_layers` GRU layers and dropout 0.2.
# NOTE(review): `decoder` is not referenced again in the visible notebook (inference uses the
# pickled `learn`); confirm this instance is still needed.
decoder = Decoder(device, 7, 2048, len(vocab.itos), 1000, 50, n_layers, p_drop=0.2)

In [0]:
# Run a garbage-collection pass; the result is bound to a throwaway name
# (presumably to keep the cell from echoing the collected-object count).
x_ = gc.collect()

In [0]:
# learn = pickler.load(open('models/learn_cpu.pkl', 'rb')) #use this

In [0]:
# Load the serialized captioning model with joblib. The context manager closes the
# file handle deterministically instead of leaving it open.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open('models/loaded_learn_cpu.pkl', 'rb') as model_file:
    learn = picklizer.load(model_file)

In [0]:
# from fastai.callbacks import SaveModelCallback

In [0]:
# learn.callback = SaveModelCallback(learn)
# learn.loss_func = None
# learn.metrics = None
# learn.purge()

In [0]:
# learn.model_dir = Path('models/')
# learn.path = Path('.')

In [0]:
# learn = learn.load('mygoodmodel_coco') #and this to create new_loaded joblib
# learn.data = None

In [0]:
# picklizer.dump(learn.model, open('/content/drive/My Drive/Image_Captioning/models/loaded_learn_cpu.pkl', 'wb'), compress=True)

In [0]:
# Call eval() on the loaded object (presumably a Decoder module, given that learn.encode and
# learn.decode_step are used below), switching it to evaluation mode.
# NOTE(review): `decode_eval` itself is not used later in the visible notebook.
decode_eval = learn.eval()

In [0]:
class HypothesisNode():
    """ A partial sequence tracked during Beam Search, with its score and decoder state """
    def __init__(self, sequence, log_prob, hidden_state, alphas):
        """HypothesisNode constructor

        Args:
          sequence: list holding the tokens decoded so far
          log_prob: accumulated log probability of this sequence
          hidden_state: hidden state of the Decoder RNN after the last token in the sequence
          alphas: list of attention weights collected along the sequence
        """
        self._seq, self._alphas = sequence, alphas
        self._log_prob, self._h = log_prob, hidden_state

    @property
    def last_tok(self):
        """
        Returns:
          The most recently appended token of the sequence
        """
        return self._seq[-1]

    def update(self, tok, log_prob, new_h, new_alpha):
        """
        Extends this hypothesis by one token without modifying it
        Args:
          tok: the token appended to the sequence
          log_prob: log probability of this token, added to the running total
          new_h: hidden state of the Decoder RNN after this token
          new_alpha: list of attention weights appended to the stored ones

        Returns:
          A new Hypothesis Node holding the extended sequence, score, hidden state and attention weights
        """
        extended_seq = self._seq + [tok]
        total_log_prob = self._log_prob + log_prob
        extended_alphas = self._alphas + new_alpha
        return HypothesisNode(extended_seq, total_log_prob, new_h, extended_alphas)

    def __str__(self):
        # relies on the module-level `vocab` to turn token ids back into text
        token_ids = [t.item()for t in self._seq]
        return 'Hyp(log_p = %4f,\t seq = %s)' % (self._log_prob, vocab.textify(token_ids))

In [0]:
class BeamSearch():
    """ Performs BeamSearch for seq2seq decoding or Image captioning """
    def __init__(self, enc_model, dec_model, beam_width=5, num_results=1, max_len=30, device=torch.device('cuda:0')):
        """BeamSearch object constructor
        Args:
          enc_model: callable mapping a batched image to (initial hidden state, annotation vectors)
          dec_model: callable mapping (last token, hidden state, annotation vectors) to
                     (log-probabilities, new hidden state, attention weights)
          beam_width: int, the number of hypotheses to remember in each iteration
          num_results: int, how many of the best hypotheses to return
          max_len: int, the maximum number of decoding steps
          device: torch device the start token is created on
        """
        self._device = device
        self._enc_model = enc_model
        self._dec_model = dec_model
        self._beam_width = beam_width
        self._num_results = num_results
        self._max_len = max_len
        self._start_tok = 0
        self._end_tok   = 1
        self._annotation_vecs = None

    def __call__(self, img, verbose=False):
        """Performs the Beam search
        Args:
          img: the image to be annotated, an unbatched torch tensor (a batch axis is added here)
          verbose: bool, allows printing the intermediate hypotheses for better understanding

        Returns:
          If num_results == 1, a tuple (token ids, stacked attention weights) for the best
          hypothesis; otherwise a list of such tuples for the `num_results` best hypotheses.
          The start token and a trailing end token are stripped from the token ids.
        """
        # Pure inference: no_grad stops autograd from recording a graph for every hypothesis
        with torch.no_grad():
            img = img.unsqueeze(0)
            h, annotation_vecs = self._enc_model(img)
            self._annotation_vecs = annotation_vecs

            start = torch.zeros(1, requires_grad=False).long().to(self._device)
            hyps = [HypothesisNode([start], 0, h, [])]
            results = []

            step = 0
            width = self._beam_width
            while width > 0 and step < self._max_len:
                if verbose: print("\n Step: ",step)
                new_hyps = []
                for hyp in hyps:
                    new_hyps.extend(self.get_next_hypotheses(hyp, width))

                new_hyps = sorted(new_hyps, key= lambda x: x._log_prob, reverse=True)
                if verbose: self.print_hypotheses(new_hyps, "Before narrowing:")

                # keep the `width` best; each finished hypothesis frees one beam slot
                hyps = []
                for hyp in new_hyps[:width]:
                    if hyp.last_tok == self._end_tok:
                        results.append(hyp)
                        width = width - 1
                    else:
                        hyps.append(hyp)

                if verbose:
                    self.print_hypotheses(hyps, "After narrowing:")
                    self.print_hypotheses(results, "Results:")

                step += 1

            # hypotheses still open when max_len was reached compete with the finished ones
            results.extend(hyps[:width])
            # rank by length-normalized log probability
            results = sorted(results, key=lambda x: x._log_prob/len(x._seq), reverse=True)

            if verbose: self.print_hypotheses(results, "Final:")

            if self._num_results == 1:
                return self._as_output(results[0])
            return [self._as_output(r) for r in results[:self._num_results]]

    def _tokens(self, hyp):
        """Token ids of `hyp` without the start token and without a trailing end token."""
        toks = hyp._seq[1:]
        # only strip the last token when it really is the end token, so hypotheses
        # cut off by max_len keep their final word
        if toks and toks[-1] == self._end_tok:
            toks = toks[:-1]
        return [t.item() for t in toks]

    def _as_output(self, hyp):
        """Convert a hypothesis into the (token ids, stacked attention weights) result tuple."""
        return (self._tokens(hyp), torch.stack(hyp._alphas))

    def get_next_hypotheses(self, hyp, k):
        """Calculates the `k` most probable one-token extensions of a Hypothesis Node
        Args:
          hyp: an Hypothesis Node containing a sequence, a log probability and a Decoder RNN hidden state
          k: the number of hypotheses to calculate
        Returns:
          A list with the `k` most probable extended sequences/Hypothesis Nodes
        """

        dec_outp, h, alphas = self._dec_model(hyp.last_tok, hyp._h, self._annotation_vecs)

        top_k_log_probs, top_k_toks = dec_outp.topk(k, dim=1)
        return [hyp.update(top_k_toks[0][i].unsqueeze(0), top_k_log_probs[0][i], h, list(alphas)) for i in range(k)]

    def print_hypotheses(self, hyps, description):
        """Print `description` followed by one line per hypothesis."""
        print(description)
        for h in hyps:
            print(h)

In [0]:
# from scipy.misc import imresize
from scipy.ndimage.filters import gaussian_filter
from matplotlib.patheffects import Stroke, Normal
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# the functions fig2data and fig2img are taken from 
# http://www.icare.univ-lille1.fr/tutorials/convert_a_matplotlib_figure
# Deprecation errors have been fixed

def fig2data ( fig ):
    """
    @brief Convert a Matplotlib figure to a 3D numpy array with RGBA channels and return it
    @param fig a matplotlib figure
    @return a numpy uint8 array of RGBA values with shape ( height, width, 4 )
    """
    # draw the renderer
    fig.canvas.draw ( )

    # Get the ARGB buffer from the figure; np.frombuffer replaces the deprecated
    # binary mode of np.fromstring
    w,h = fig.canvas.get_width_height()
    buf = np.frombuffer( fig.canvas.tostring_argb(), dtype=np.uint8 )
    # the pixmap is row-major ( h rows of w pixels ), so height is the first axis
    buf = buf.reshape( h, w, 4 )

    # canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode
    # ( np.roll also returns a fresh, writable copy of the read-only buffer view )
    buf = np.roll ( buf, 3, axis = 2 )
    return buf

def fig2img ( fig ):
    """
    @brief Convert a Matplotlib figure to a PIL Image in RGBA format and return it
    @param fig a matplotlib figure
    @return a Python Imaging Library ( PIL ) image
    """
    # put the figure pixmap into a numpy array
    buf = fig2data ( fig )
    h, w, d = buf.shape
    # PIL expects the size as ( width, height ); tobytes replaces the deprecated tostring
    return Image.frombytes( "RGBA", ( w ,h ), buf.tobytes( ) )
    
def draw_text(ax, xy, txt, sz=14):
    """Write `txt` on `ax` at position `xy` in bold white with a thin black outline.

    Args:
      ax: axes-like object providing `text`
      xy: position, unpacked into the leading positional arguments of `ax.text`
      txt: the string to draw
      sz: font size
    """
    style = dict(verticalalignment='top', color='white', fontsize=sz, weight='bold')
    label = ax.text(*xy, txt, **style)
    draw_outline(label, 1)

def draw_outline(matplt_plot_obj, lw):
    """Give `matplt_plot_obj` a black stroke of line width `lw` drawn beneath its normal rendering."""
    effects = [Stroke(linewidth=lw, foreground='black'), Normal()]
    matplt_plot_obj.set_path_effects(effects)

def show_img(im, figsize=None, ax=None, alpha=1, cmap=None):
    """Draw `im` on an axes with both axis decorations hidden.

    Args:
      im: image accepted by `ax.imshow`
      figsize: figure size used only when a new figure has to be created
      ax: axes to draw on; a new figure/axes pair is created when None
      alpha: opacity forwarded to `imshow`
      cmap: colormap forwarded to `imshow`

    Returns:
      The axes the image was drawn on
    """
    # Compare against None explicitly so a caller-supplied axes is never replaced
    # because of its truthiness.
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    ax.imshow(im, alpha=alpha, cmap=cmap)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    return ax

def visualize_attention(im, pred, alphas, denorm, vocab, att_size=7, thresh=0., sz=224, return_fig_as_PIL_image=False):
    """Plot the image once per predicted token with that token's attention map overlaid.

    The first panel shows the plain image; panel i (i >= 1) shows the image, the blurred
    attention mask of step i-1 and the corresponding word. Unused panels are switched off.

    Args:
      im: image tensor handed to `denorm`
      pred: sequence of predicted token ids
      alphas: torch tensor of attention weights, viewable as (-1, 1, att_size, att_size)
      denorm: callable turning `im` into something `imshow` can draw
              (assumed to be side-effect free, it is called once)
      vocab: object whose `itos` maps token ids to strings
      att_size: side length of the attention grid
      thresh: attention values below this are clamped to it before rescaling
      sz: side length in pixels the attention masks are resized to
      return_fig_as_PIL_image: when True, return the figure rendered as a PIL image

    Returns:
      A PIL image when `return_fig_as_PIL_image` is True, otherwise None
    """
    cap_len = len(pred)
    alphas = alphas.view(-1,1,  att_size, att_size).detach().cpu().numpy()
    alphas = np.maximum(thresh, alphas)
    alphas -= alphas.min()
    # rescale to [0, 1]; skip when every value is equal, which would divide by zero
    # and fill the masks with NaNs
    peak = alphas.max()
    if peak > 0:
        alphas /= peak

    figure, axes = plt.subplots(cap_len//5 + 1,5, figsize=(12,8))

    # the background image is identical in every panel, so convert it only once
    base_img = denorm(im)

    for i, ax in enumerate(axes.flat):
        if i > cap_len:
            ax.axis('off')
            continue
        ax = show_img(base_img, ax=ax)
        if i > 0:
            mask = np.array(PIL.Image.fromarray(alphas[i - 1,0]).resize((sz,sz)))
            blurred_mask = gaussian_filter(mask, sigma=8)
            show_img(blurred_mask, ax=ax, alpha=0.5, cmap='afmhot')
            draw_text(ax, (0,0), vocab.itos[pred[i - 1]])
    plt.tight_layout()

    if return_fig_as_PIL_image:
        return fig2img(figure)

In [0]:
# Number of hypotheses BeamSearch keeps in each iteration.
beam_width = 5

In [0]:
# Wire beam search to the loaded model: `learn.encode` is used as the encoder callable and
# `learn.decode_step` as the single-step decoder (presumably `learn` is a Decoder instance).
beam_search = BeamSearch(learn.encode, learn.decode_step, beam_width, device=device)

In [0]:
# url = 'https://st2.depositphotos.com/1761942/9533/i/950/depositphotos_95337166-stock-photo-two-kid-boys-playing-on.jpg'

In [0]:
# !mkdir images

In [0]:
# download_url(url, 'images/rand2.jpg', overwrite=True)

In [0]:
# Inverse of a per-channel normalization with the (mean, std) pairs below:
# Normalize(-mean/std, 1/std) maps x to x * std + mean.
# NOTE(review): the visible callers pass tensors that were only divided by 255;
# confirm they are expected to be normalized before being passed to `denorm`.
inv_norm = transforms.Normalize(
    mean=[-m / s for m, s in ((0.5238, 0.3159), (0.5003, 0.3091), (0.4718, 0.3216))],
    std=[1 / s for s in (0.3159, 0.3091, 0.3216)],
)

# undo the normalization, then convert to a PIL image
denorm = transforms.Compose([inv_norm, transforms.functional.to_pil_image])

In [0]:
# im = image2tensor(PILImage.create('images/rand2.jpg').to_thumb(224))/255.

In [0]:
# plt.imshow(im.permute(1,2,0)), im.permute(1,2,0).shape

In [0]:
# results = beam_search(im);
# x = vocab.textify(results[0])
# print(x)
# visualize_attention(im, results[0], results[1], denorm, vocab, att_size=7, sz=sz, thresh=0.02)

In [0]:
import ipywidgets
from ipywidgets import widgets, VBox
# from types import SimpleNamespace
import warnings
warnings. filterwarnings('ignore')

In [0]:
# defaults.use_cuda = True
# Widgets of the interactive demo: three Output areas (image preview, attention plot,
# printed caption), two labels and the button that starts captioning.
text = widgets.Text()
output, output_attn, output_lbl = (widgets.Output() for _ in range(3))
label, caption_lbl = widgets.Label(), widgets.Label()
attend = widgets.Button(description='Caption eet!!')

In [0]:
def on_click(change):
    """Caption the most recently uploaded image and refresh the demo widgets.

    Args:
      change: the argument ipywidgets passes to the click handler (unused)
    """
    if len(upload_btn.data) == 0:
        print('Please upload an image')
        return

    img = PILImage.create(upload_btn.data[-1])
    output.clear_output()
    with output:
        display(img.to_thumb(224))

    gc.collect()
    results = beam_search(image2tensor(img.to_thumb(700))/255.)
    pred_caption = vocab.textify(results[0])

    # clear first so captions from earlier clicks do not pile up
    output_lbl.clear_output()
    with output_lbl:
        print(pred_caption)
    caption_lbl.value = pred_caption

    output_attn.clear_output()
    with output_attn:
        visualize_attention(image2tensor(img.to_thumb(224))/255., results[0], results[1], denorm, vocab, att_size=7, sz=sz, thresh=0.02)
        # figures created inside a widget callback are not flushed automatically;
        # show explicitly while the Output context is active
        plt.show()

attend.on_click(on_click)

In [0]:
# upload = SimpleNamespace(data=[open_image('images/rand2.jpg')])

In [0]:
# File-upload widget; on_click reads the most recent upload from `upload_btn.data`.
upload_btn = widgets.FileUpload()

In [0]:
# Run a garbage-collection pass; the result is bound to a throwaway name
# (presumably to keep the cell from echoing the collected-object count).
x_ = gc.collect()

In [0]:
# Assemble the demo UI. `output_attn` is part of the layout so that the attention
# visualization rendered into it by the click handler is actually shown.
display(VBox([widgets.Label('Attend to something'), 
      upload_btn, attend, output, output_lbl, caption_lbl, output_attn]))