<a href="https://colab.research.google.com/github/All4Nothing/pytorch-DL-project/blob/main/Ch02_CNN_LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Download COCO Dataset

In [1]:
import urllib
import zipfile
from tqdm import tqdm

#https://stackoverflow.com/a/53877507/1558946
class DownloadProgressBar(tqdm):
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)

def download_data(url):
    print(f"{url} 다운로드 중 ...")
    with DownloadProgressBar(unit='B', unit_scale=True,
                             miniters=1, desc=url.split('/')[-1]) as t:
        zip_path, _ = urllib.request.urlretrieve(url, reporthook=t.update_to)

    print("압축을 푸는 중 ...")
    with zipfile.ZipFile(zip_path, "r") as f:
        for name in tqdm(iterable=f.namelist(), total=len(f.namelist())):
            f.extract(member=name, path="data_dir")

download_data("http://images.cocodataset.org/annotations/annotations_trainval2014.zip")
download_data("http://images.cocodataset.org/zips/train2014.zip")
download_data("http://images.cocodataset.org/zips/test2014.zip")

http://images.cocodataset.org/annotations/annotations_trainval2014.zip 다운로드 중 ...


annotations_trainval2014.zip: 253MB [00:07, 32.4MB/s]                           


압축을 푸는 중 ...


100%|██████████| 6/6 [00:06<00:00,  1.16s/it]


http://images.cocodataset.org/zips/train2014.zip 다운로드 중 ...


train2014.zip: 13.5GB [05:24, 41.6MB/s]                            


압축을 푸는 중 ...


100%|██████████| 82784/82784 [01:42<00:00, 806.41it/s] 


http://images.cocodataset.org/zips/test2014.zip 다운로드 중 ...


test2014.zip: 6.66GB [03:18, 33.6MB/s]                            


압축을 푸는 중 ...


100%|██████████| 40776/40776 [00:51<00:00, 790.13it/s]


### Import Libraries

In [1]:
import os
import nltk
import pickle
import numpy as np
from PIL import Image
from collections import Counter
from pycocotools.coco import COCO
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms
import torchvision.models as models
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pack_padded_sequence

- `nltk` : 자연어 툴킷으로, 사전을 만들때 사용하기 좋다.
- `pycocotools` : COCO dataset을 다룰때 사용하기 좋다.
- `pack_padded_sequence` : 다양한 길이의 문장에 패딩을 적용해 고정된 길이의 문장으로 바꿀때 사용한다.

In [2]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

- `punkt` : tokenizer 모델이다. 이 모델을 통해 텍스트를 토큰화 할 수 있다.

### Build Vocab

In [3]:
class Vocab(object):
  def __init__(self):
    self.w2i = {} # word to index
    self.i2w = {} # index to word
    self.index = 0

  def __call__(self, token):
    if not token in self.w2i:
      return self.w2i['<unk']

  def __len__(self):
    return len(self.w2i)

  def add_token(self, token):
    if not token in self.w2i:
      self.w2i[token] = self.index
      self.i2w[self.index] = token
      self.index += 1

`class Vocab()`  
'word to index'와 'index to word' dictionary를 각각 만든다.  
기존에 없는 token인 경우, 딕셔너리에 추가한다.

`build_vocabulary`  
json text annotations를로딩하고, caption 내의 개별 단어를 토큰화하거나, 숫자로 전환하여 counter에 저장한다.  
threshold보다 낮은 빈도수를 갖는 단어는 제거하고, 나머지 wildcard token을 추가한다.  
vocab 객체를 생성하고, 로컬 시스템에 저장함으로서 나중에 모델을 재훈련하기 위해 vocab을 재생성하는 일을 방지할 수 있다.

In [4]:
def build_vocabulary(json, threshold):
  coco = COCO(json)
  counter = Counter()
  ids = coco.anns.keys()
  for i, id in enumerate(ids):
    caption = str(coco.anns[id]['caption'])
    tokens = nltk.tokenize.word_tokenize(caption.lower())
    counter.update(tokens)

    if (i+1) % 1000 == 0:
      print("[{}/{}] Tokenized the captions.".format(i+1, len(ids)))

  # if the word frequency is less than 'threshold', then the word is discarded.
  tokens = [token for token, cnt in counter.items() if cnt >= threshold]

  # Create vocab and add wild card token(for start, end, unknown, padding)
  vocab = Vocab()
  vocab.add_token('<pad>')
  vocab.add_token('<start>')
  vocab.add_token('<end>')
  vocab.add_token('<unk>')

  # Add the words to the vocab
  for i, token in enumerate(tokens):
    vocab.add_token(token)
  return vocab

In [6]:
vocab = build_vocabulary(json = '/content/data_dir/annotations/captions_train2014.json', threshold = 4)
vocab_path = './data_dir/vocabulary.pkl'
with open(vocab_path, 'wb') as f:
  pickle.dump(vocab, f)
print("Total vocabulary size: {}".format(len(vocab)))
print("Saved the vocabulary wrapper to '{}'".format(vocab_path))

loading annotations into memory...
Done (t=0.94s)
creating index...
index created!
[1000/414113] Tokenized the captions.
[2000/414113] Tokenized the captions.
[3000/414113] Tokenized the captions.
[4000/414113] Tokenized the captions.
[5000/414113] Tokenized the captions.
[6000/414113] Tokenized the captions.
[7000/414113] Tokenized the captions.
[8000/414113] Tokenized the captions.
[9000/414113] Tokenized the captions.
[10000/414113] Tokenized the captions.
[11000/414113] Tokenized the captions.
[12000/414113] Tokenized the captions.
[13000/414113] Tokenized the captions.
[14000/414113] Tokenized the captions.
[15000/414113] Tokenized the captions.
[16000/414113] Tokenized the captions.
[17000/414113] Tokenized the captions.
[18000/414113] Tokenized the captions.
[19000/414113] Tokenized the captions.
[20000/414113] Tokenized the captions.
[21000/414113] Tokenized the captions.
[22000/414113] Tokenized the captions.
[23000/414113] Tokenized the captions.
[24000/414113] Tokenized the 

`COCO()` : COCO API를 활용하기 위해 annotation 파일을 COCO 객체로 로드한다.
```python
coco = COCO(annFile)
coco.getCatlds()
coco.loadAnns(1)
coco.loadimgs(1)
...
```

`Counter()` : 배열을 넘겨받으면, 배열의 원소가 각각 몇 번 나오는지 저장된 객체를 반환한다. 또한, dictionary처럼 사용이 가능하다.  
```python
Counter(["hi", "hey", "hi", "hi", "hello", "hey"])
```
```plaintext
Counter({'hi': 3, 'hey': 2, 'hello': 1})
```

```python
if "o" in counter:
    print("o in counter")

del counter["o"]

if "o" not in counter:
    print("o not in counter")
```
```plaintext
o in counter
o not in counter
```

```python
with open(vocab_path, 'wb') as f:
  pickle.dump(vocab, f)
```
`vocab` 파일을 `vocab_pat`라는 이름으로 저장한다.

### Preprocessing images data

In [5]:
def reshape_image(image, shape):
    """Resize an image to the given shape."""
    return image.resize(shape, Image.ANTIALIAS)

def reshape_images(image_path, output_path, shape):
    """Reshape the images in 'image_path' and save into 'output_path'."""
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    images = os.listdir(image_path)
    num_im = len(images)
    for i, im in enumerate(images):
        with open(os.path.join(image_path, im), 'r+b') as f:
            with Image.open(f) as image:
                image = reshape_image(image, shape)
                image.save(os.path.join(output_path, im), image.format)
        if (i+1) % 100 == 0:
            print ("[{}/{}] Resized the images and saved into '{}'."
                   .format(i+1, num_im, output_path))

CNN 모델에 맞게 `256 x 256` 픽셀로 이미지를 변환한다.

In [8]:
image_path = './data_dir/train2014/'
output_path = './data_dir/resized_images/'
image_shape = [256, 256]
reshape_images(image_path, output_path, image_shape)

  return image.resize(shape, Image.ANTIALIAS)


[100/82783] Resized the images and saved into './data_dir/resized_images/'.
[200/82783] Resized the images and saved into './data_dir/resized_images/'.
[300/82783] Resized the images and saved into './data_dir/resized_images/'.
[400/82783] Resized the images and saved into './data_dir/resized_images/'.
[500/82783] Resized the images and saved into './data_dir/resized_images/'.
[600/82783] Resized the images and saved into './data_dir/resized_images/'.
[700/82783] Resized the images and saved into './data_dir/resized_images/'.
[800/82783] Resized the images and saved into './data_dir/resized_images/'.
[900/82783] Resized the images and saved into './data_dir/resized_images/'.
[1000/82783] Resized the images and saved into './data_dir/resized_images/'.
[1100/82783] Resized the images and saved into './data_dir/resized_images/'.
[1200/82783] Resized the images and saved into './data_dir/resized_images/'.
[1300/82783] Resized the images and saved into './data_dir/resized_images/'.
[1400/82

### Define image captions data loader

전처리한 이미지 캡션 데이터를 pytorch dataset으로 casting한다.

* casting : 형 변환. 자료형의 type을 바꾸는 것

In [6]:
class CustomCocoDataset(data.Dataset):
  """ COCO Custom Dataset compatible with torch.utils.data.DataLoader. """
  def __init__(self, data_path, coco_json_path, vocabulary, transform=None):
    """Set the path for images, captions and vocabulary wrapper.

    Args:
        root: image directory.
        json: coco annotation file path.
        vocab: vocabulary wrapper.
        transform: image transformer.
    """
    self.root = data_path
    self.coco_data = COCO(coco_json_path)
    self.indices = list(self.coco_data.anns.keys())
    self.vocabulary = vocabulary
    self.transform = transform

  def __getitem__(self, idx):
    """ Returns one data pair (image and caption). """
    coco_data = self.coco_data
    vocabulary = self.vocabulary
    annotation_id = self.indices[idx]
    caption = coco_data.anns[annotation_id]['caption']
    image_id = coco_data.anns[annotation_id]['image_id']
    image_path = coco_data.loadImgs(image_id)[0]['file_name']

    image = Image.open(os.path.join(self.root, image_path)).convert('RGB')
    if self.transform is not None:
      image = self.transform(image)

    # Convert caption(string) to word ids.
    word_tokens = nltk.tokenize.word_tokenize(str(caption).lower())
    caption = []
    caption.append(vocabulary('<start>'))
    caption.extend([vocabulary(token) for token in word_tokens])
    caption.append(vocabulary('<end>'))
    ground_truth = torch.Tensor(caption)
    return image, ground_truth

  def __len__(self):
    return len(self.indices)

def collate_function(data_batch):
    """Creates mini-batch tensors from the list of tuples (image, caption).

    We should build custom collate_fn rather than using default collate_fn,
    because merging caption (including padding) is not supported in default.
    Args:
        data: list of tuple (image, caption).
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.
    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Sort a data list by caption length (descending order).
    data_batch.sort(key=lambda d: len(d[1]), reverse=True)
    imgs, caps = zip(*data_batch)

    # Merge images (from list of 3D tensors to 4D tensor).
    # Originally, imgs is a list of <batch_size> number of RGB images with dimensions (3, 256, 256)
    # This line of code turns it into a single tensor of dimensions (<batch_size>, 3, 256, 256)
    imgs = torch.stack(imgs, 0)

    # Merge captions (from list of 1D tensors to 2D tensor), similar to merging of images donw above.
    cap_lens = [len(cap) for cap in caps]
    tgts = torch.zeros(len(caps), max(cap_lens)).long()
    for i, cap in enumerate(caps):
        end = cap_lens[i]
        tgts[i, :end] = cap[:end]
    return imgs, tgts, cap_lens

def get_loader(data_path, coco_json_path, vocabulary, transform, batch_size, shuffle, num_workers):
    """Returns torch.utils.data.DataLoader for custom coco dataset."""
    # COCO caption dataset
    coco_dataset = CustomCocoDataset(data_path=data_path,
                       coco_json_path=coco_json_path,
                       vocabulary=vocabulary,
                       transform=transform)

    # Data loader for COCO dataset
    # This will return (images, captions, lengths) for each iteration.
    # images: a tensor of shape (batch_size, 3, 224, 224).
    # captions: a tensor of shape (batch_size, padded_length).
    # lengths: a list indicating valid length for each caption. length is (batch_size).
    custom_data_loader = torch.utils.data.DataLoader(dataset=coco_dataset,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers,
                                              collate_fn=collate_function)
    return custom_data_loader

### Define Model

`CNN`

In [10]:
resnet = models.resnet152(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth
100%|██████████| 230M/230M [00:01<00:00, 152MB/s]


In [11]:
list(resnet.children())

[Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace=True),
 MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
 Sequential(
   (0): Bottleneck(
     (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (downsample): Sequential(
       (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (1): BatchNorm2d(256, eps=1e-05, momentum

In [12]:
resnet.fc

Linear(in_features=2048, out_features=1000, bias=True)

In [7]:
class CNNModel(nn.Module):
  def __init__(self, embedding_size):
    """ Load the pretrained ResNet-152 and replace top fc layer. """
    super(CNNModel, self).__init__()
    resnet = models.resnet152(pretrained=True)
    module_list = list(resnet.children())[:-1] # delete the last fc layer
    self.resnet_module = nn.Sequential(*module_list)
    self.linear_layer = nn.Linear(resnet.fc.in_features, embedding_size)
    self.batch_norm = nn.BatchNorm1d(embedding_size, momentum=0.01)

  def forward(self, input_images):
    """ Extract feature vectors from input images. """
    with torch.no_grad():
      resnet_features = self.resnet_module(input_images)
    resnet_features = resnet_features.reshape(resnet_features.size(0), -1)
    final_features = self.batch_norm(self.linear_layer(resnet_features))
    return final_features

K x 1000 (K : 끝에서 두번째 계층의 뉴런 수) 차원의 가중치 행렬을 K x 256 차원의 가중치 행렬로 대체한다

batch normalization을 수행해 hidden layer의 출력값의 변동폭을 제한할 수 있으며, 학습 속도도 높일 수 있다. 또한, 더 균등한(평균 0, 표준편차 1) 최적화 초평면으로 인해 더 높은 학습률을 사용할 수 있다.

`LSTM`

In [8]:
class LSTMModel(nn.Module):
  def __init__(self, embedding_size, hidden_layer_size, vocabulary_size, num_layers, max_seq_len=20):
    """ Set the hyper-parameters and build the layers. """
    super(LSTMModel, self).__init__()
    self.embedding_layer = nn.Embedding(vocabulary_size, embedding_size)
    self.lstm_layer = nn.LSTM(embedding_size, hidden_layer_size, num_layers, batch_first = True)
    self.linear_layer = nn.Linear(hidden_layer_size, vocabulary_size)
    self.max_seq_len = max_seq_len

  def forward(self, input_features, capts, lens):
    """ Decode image feature vectors and generates captions. """
    embeddings = self.embedding_layer(capts)
    embeddings = torch.cat((input_features.unsqueeze(1), embeddings), 1)
    lstm_input = pack_padded_sequence(embeddings, lens, batch_first = True)
    hidden_variables, _ = self.lstm_layer(lstm_input)
    model_outputs = self.linear_layer(hidden_variables[0])
    return model_outputs

  def sample(self, input_features, lstm_states = None):
    """ Generate captions for given image features using greedy search. """
    """ 확률이 가장 높은 문장을 선택한다 """
    sampled_indices = []
    lstm_inputs = input_features.unsqueeze(1)
    for i in range(self.max_seq_len):
      hidden_variables, lstm_states = self.lstm_layer(lstm_inputs, lstm_states)          # hiddens: (batch_size, 1, hidden_size)
      model_outputs = self.linear_layer(hidden_variables.squeeze(1))            # outputs:  (batch_size, vocab_size)
      _, predicted_outputs = model_outputs.max(1)                        # predicted: (batch_size)
      sampled_indices.append(predicted_outputs)
      lstm_inputs = self.embedding_layer(predicted_outputs)                       # inputs: (batch_size, embed_size)
      lstm_inputs = lstm_inputs.unsqueeze(1)                         # inputs: (batch_size, 1, embed_size)
    sampled_indices = torch.stack(sampled_indices, 1)                # sampled_ids: (batch_size, max_seq_length)
    return sampled_indices

LSTM layer는 순환 계층으로 LSTM 셀이 시간 차원을 따라 unfold 되어 LSTM 셀의 시간 배열을 구성한다. 여기서 이 셀은 각 시간 단계마다 단어의 예측 확률을 출력하고 가장 확률이 높은 단어가 출력 문장 뒤에 추가된다.  
각 시간 단계에서 LSTM 셀은 내부 셀 상태를 생성하고 이 상태는 다음 시간 단계의 LSTM 셀의 입력으로 전달된다. LSTM 셀이 <end> 토큰/단어를 출력할 때 까지 이 과정을 반복한다.

`nn.Embedding()`  
PyTorch에서 임베딩 층(embedding layer)을 만들어 훈련 데이터로부터 처음부터 임베딩 벡터를 학습한다.  
nn.Embedding()을 사용하여 학습 가능한 임베딩 테이블(룩업 테이블)을 만든다.  
nn.Embedding은 크게 두 가지 인자를 받는다.  
- num_embeddings : 임베딩을 할 단어들의 개수. 다시 말해 단어 집합의 크기입니다.
- embedding_dim : 임베딩 할 벡터의 차원입니다. 사용자가 정해주는 하이퍼파라미터입니다.
- padding_idx : 선택적으로 사용하는 인자입니다. 패딩을 위한 토큰의 인덱스를 알려줍니다.  
```python
embedding_layer = nn.Embedding(num_embeddings=len(vocab), embedding_dim=3, padding_idx=1)  
print(embedding_layer.weight)
```
```plaintext
Parameter containing:
tensor([[-0.1778, -1.9974, -1.2478],
        [ 0.0000,  0.0000,  0.0000],
        [ 1.0921,  0.0416, -0.7896],
        [ 0.0960, -0.6029,  0.3721],
        [ 0.2780, -0.4300, -1.9770],
        [ 0.0727,  0.5782, -3.2617],
        [-0.0173, -0.7092,  0.9121],
        [-0.4817, -1.1222,  2.2774]], requires_grad=True)
```

### Training loop

In [9]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Create model directory
if not os.path.exists('models_dir/'):
    os.makedirs('models_dir/')


# Image preprocessing, normalization for the pretrained resnet
transform = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))])


# Load vocabulary wrapper
with open('data_dir/vocabulary.pkl', 'rb') as f:
    vocabulary = pickle.load(f)


# Build data loader
custom_data_loader = get_loader('data_dir/resized_images', 'data_dir/annotations/captions_train2014.json', vocabulary,
                         transform, 128,
                         shuffle=True, num_workers=0)


# Build the models
encoder_model = CNNModel(256).to(device)
decoder_model = LSTMModel(256, 512, len(vocabulary), 1).to(device)


# Loss and optimizer
loss_criterion = nn.CrossEntropyLoss()
parameters = list(decoder_model.parameters()) + list(encoder_model.linear_layer.parameters()) + list(encoder_model.batch_norm.parameters())
optimizer = torch.optim.Adam(parameters, lr=0.001)


loading annotations into memory...
Done (t=1.33s)
creating index...
index created!




`parameters` : 학습 대상이 되는 파라미터들만 따로 모은다. batch norm도 학습 대상이 되므로 추가해준다.

In [10]:
# Train the models
total_num_steps = len(custom_data_loader)
for epoch in range(5):
    for i, (imgs, caps, lens) in enumerate(custom_data_loader):

        # Set mini-batch dataset
        imgs = imgs.to(device)
        caps = caps.to(device)
        tgts = pack_padded_sequence(caps, lens, batch_first=True)[0]

        # Forward, backward and optimize
        feats = encoder_model(imgs)
        outputs = decoder_model(feats, caps, lens)
        loss = loss_criterion(outputs, tgts)
        decoder_model.zero_grad()
        encoder_model.zero_grad()
        loss.backward()
        optimizer.step()

        # Print log info
        if i % 10 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                  .format(epoch, 5, i, total_num_steps, loss.item(),
                          np.exp(loss.item())))

        # Save the model checkpoints
        if (i+1) % 1000 == 0:
            torch.save(decoder_model.state_dict(), os.path.join(
                'models_dir/', 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))
            torch.save(encoder_model.state_dict(), os.path.join(
                'models_dir/', 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))

TypeError: ignored

### Predict caption

In [None]:
image_file_path = 'sample.jpg'


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_image(image_file_path, transform=None):
    img = Image.open(image_file_path).convert('RGB')
    img = img.resize([224, 224], Image.LANCZOS)

    if transform is not None:
        img = transform(img).unsqueeze(0)

    return img


# Image preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))])


# Load vocabulary wrapper
with open('data_dir/vocabulary.pkl', 'rb') as f:
    vocabulary = pickle.load(f)


# Build models
encoder_model = CNNModel(256).eval()  # eval mode (batchnorm uses moving mean/variance)
decoder_model = LSTMModel(256, 512, len(vocabulary), 1)
encoder_model = encoder_model.to(device)
decoder_model = decoder_model.to(device)


# Load the trained model parameters
encoder_model.load_state_dict(torch.load('models_dir/encoder-2-3000.ckpt'))
decoder_model.load_state_dict(torch.load('models_dir/decoder-2-3000.ckpt'))


# Prepare an image
img = load_image(image_file_path, transform)
img_tensor = img.to(device)


# Generate an caption from the image
feat = encoder_model(img_tensor)
sampled_indices = decoder_model.sample(feat)
sampled_indices = sampled_indices[0].cpu().numpy()          # (1, max_seq_length) -> (max_seq_length)


# Convert word_ids to words
predicted_caption = []
for token_index in sampled_indices:
    word = vocabulary.i2w[token_index]
    predicted_caption.append(word)
    if word == '<end>':
        break
predicted_sentence = ' '.join(predicted_caption)

In [None]:
# Print out the image and the generated caption
%matplotlib inline
print (predicted_sentence)
img = Image.open(image_file_path)
plt.imshow(np.asarray(img))