<a href="https://colab.research.google.com/github/AokiMasataka/LSTM_sample/blob/cat-feature-and-sentence/STAIR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!apt install aptitude
!aptitude install mecab libmecab-dev mecab-ipadic-utf8 git make curl xz-utils file -y
!pip install mecab-python3==0.7

In [2]:
import torch
from torch import nn, optim
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import torch.nn.functional as F
from torchvision import models
import numpy as np
import json
import MeCab


tagger = MeCab.Tagger("-Owakati")

def inverse_dict(d):
    return {v:int(k) for k,v in d.items()}

json_open = open('drive/My Drive/Colab Notebooks/stair_captions/stair_captions_train.json', 'r', encoding="utf-8")
stairCaptions = json.load(json_open)

json_open = open('drive/My Drive/Colab Notebooks/stair_captions/id2index_v2.json', 'r', encoding="utf-8")
id2index = json.load(json_open)

json_open = open('drive/My Drive/Colab Notebooks/stair_captions/words.json', 'r', encoding="utf-8")
index2word = json.load(json_open)
word2index = inverse_dict(index2word)

VOCAB_SIZE = len(word2index) # length is 29931
EMBEDDING_DIM = 256
MAXMUM_WORDS = 24
BATCH_SIZE = 32
epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def sentence2index(sentence, pad=True):
    sentence = tagger.parse(sentence)
    wakati = list(sentence.split(" "))
    wakati.insert(0, '<start>')
    wakati[len(wakati) - 1] = '<end>'
    if not pad:
      return torch.tensor([word2index[w] for w in wakati], dtype=torch.long)
    else:
      index = torch.tensor([word2index[w] for w in wakati])
      if MAXMUM_WORDS < index.shape[0]:
       return index[:MAXMUM_WORDS]
      padding = torch.zeros(MAXMUM_WORDS - index.shape[0])
      return torch.cat((index, padding), 0).to(torch.long)

def index2sentence(ndarray):
  sentence = ''
  for index in ndarray:
    sentence += index2word[str(index)]
  return sentence

In [None]:
from PIL import Image
import requests
from io import BytesIO

def getImg(url, toTensor=True):
  response = requests.get(url)
  img = np.array(Image.open(BytesIO(response.content)).convert('RGB')) / 255
  img = np.transpose(img[np.newaxis], (0, 3, 1, 2))

  if toTensor:
    return torch.tensor(img, dtype=torch.float32)
  else:
    return img

In [None]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        vgg = models.vgg16(pretrained=True)
        self.features = nn.Sequential(*list(vgg.features)[:31]).eval().to(device)
        for param in self.features.parameters():
            param.requires_grad = False
        self.GAP = (vgg.avgpool).to(device)
        self.classifier = nn.Sequential(*list(vgg.classifier)[:1]).eval().to(device)
        for param in self.classifier.parameters():
            param.requires_grad = False
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), requires_grad=False).to(device)
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), requires_grad=False).to(device)

    def forward(self, image):
        image = (image - self.mean) / self.std
        x = self.features(image)
        x = self.GAP(x)
        x = x.reshape(x.shape[0], -1)
        x = self.classifier(x) # output shape(batchSize, 4096)
        return x

In [4]:
class Decoder(nn.Module):
    def __init__(self, vocabSize=VOCAB_SIZE, embeddingDim=EMBEDDING_DIM, hiddenDim=512):
        super(Decoder, self).__init__()
        self.hiddenDim = hiddenDim
        self.word_embeddings = nn.Embedding(vocabSize, embeddingDim)
        self.lstm = nn.LSTM(input_size=embeddingDim, hidden_size=hiddenDim, batch_first=True, num_layers=1)
        self.decod = nn.Linear(hiddenDim + 4096, vocabSize)

    def forward(self, feature, sentence, lengths):
        embeds = self.word_embeddings(sentence)
        embeds = attention(embeds, embeds, embeds)
        feature = feature.repeat(1, sentence.shape[1]).view(sentence.shape[0], -1, 4096)
        output, self.hidden = self.lstm(embeds)
        output = torch.cat((feature, output), dim=2)
        output = self.decod(output)
        return output

    def caption(self, feature, states=None):
        sampled_ids = []
        inputs = self.word_embeddings(torch.ones(1, dtype=torch.long).view(1, 1))
        feature = feature.unsqueeze(1)
        
        for i in range(MAXMUM_WORDS):
            hiddens, states = self.lstm(inputs, states)
            hiddens = torch.cat((feature, hiddens), dim=2)
            outputs = self.decod(hiddens.squeeze(1))
            _, predicted = outputs.max(1)
            sampled_ids.append(predicted)
            inputs = self.word_embeddings(predicted.view(1, -1))
            if predicted.item() == 2:
              break                        
        sampled_ids = torch.stack(sampled_ids, 1)
        return sampled_ids

def attention(q, k, v):
    scores = torch.matmul(q, k.transpose(-2, -1))
    scores = torch.nn.functional.softmax(scores, dim=-1)
    output = torch.matmul(scores, v)
    return output

In [5]:
def train(decoder, e=None):
  encordedArrays = np.load('drive/My Drive/Colab Notebooks/stair_captions/encorded_ndarray.npy')
  encordedTensors = torch.tensor(encordedArrays, dtype=torch.float)
  del encordedArrays

  if e:
    modelPath = 'drive/My Drive/Colab Notebooks/LSTM_models/epoch' + str(e)
    decoder.load_state_dict(torch.load(modelPath))
  else:
    e = 0

  decoder.train().to(device)
  optimizer = torch.optim.Adam(decoder.parameters(), lr=0.005)
  CEL = nn.CrossEntropyLoss()

  
  pad = torch.tensor([1])
  for epoch in range(1, epochs):
    print("epoch :", epoch + e)
    iterate = 0
    losses = 0
    sentence = []
    lengths = []
    inputFeature = torch.empty((1, 4096))
    for caption in stairCaptions['annotations']:
      try:
        idx = id2index[str(caption['image_id'])]
      except:
        continue

      lengths.append(sentence2index(caption['caption'], False).size(0))
      sentence.append(sentence2index(caption['caption'], False))
      inputFeature = torch.cat((inputFeature, encordedTensors[idx].unsqueeze(0)), dim=0)

      if len(lengths) == BATCH_SIZE:
        length = torch.tensor(lengths, dtype=torch.long).to(device)
        data = pad_sequence(sentence, batch_first=True)
        inputs = data[:, :-1].to(device)
        targets = data[:, 1:].to(device)
        inputFeature = inputFeature[1:].to(device)

        output = decoder(inputFeature, inputs, length)
        loss = CEL(output.view(-1, VOCAB_SIZE), targets.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        sentence.clear()
        lengths.clear()
        inputFeature = torch.empty((1, 4096))

        losses += loss.item()

      iterate += 1
      if (iterate + 1) % 10000 == 0:
        print("loss :", losses / 10000)
        losses = 0

    if (epoch+1) % 2:
      model_path = 'drive/My Drive/Colab Notebooks/LSTM_models/epoch' + str(epoch + e)
      torch.save(decoder.state_dict(), model_path)

In [None]:
decoder = Decoder()
train(decoder)

loss : 1.4225761444091798
loss : 1.4472218736648559
loss : 1.4223279064178467
loss : 1.4548809879302977
loss : 1.6147211206436156
loss : 1.5298406255722046
loss : 1.4144005081176758
loss : 1.4861709756851196
loss : 1.4933792249679565


In [None]:
def val_test(modelPath):
  BATCH_SIZE = 1

  json_open = open('drive/My Drive/Colab Notebooks/stair_captions/stair_captions_val.json', 'r', encoding="utf-8")
  stairCaptions_val = json.load(json_open)
 
  decoder = Decoder()
  decoder.load_state_dict(torch.load(modelPath))
  decoder.cpu()
  encoder = Encoder().to(device)
 
  for i in range(10):
    rand = np.random.randint(0, 20000)
    image = stairCaptions_val['images'][rand]
    try:
      tensorImg = getImg(image['flickr_url'])
    except:
      continue
    
    tensorImg = tensorImg.to(device)
    encordedImg = encoder(tensorImg)
    encordedImg = encordedImg.cpu().view(1, -1)
    sentence = decoder.caption(encordedImg)
 
    print(image['flickr_url'])
    sentence = sentence.detach().numpy().reshape(-1)
    print(index2sentence(sentence), " length :", len(sentence))

In [8]:
def test(modelPath):
  BATCH_SIZE = 1

  encordedArrays = np.load('drive/My Drive/Colab Notebooks/stair_captions/encorded_ndarray.npy')
  encordedTensors = torch.tensor(encordedArrays, dtype=torch.float)
  del encordedArrays

  decoder = Decoder()
  decoder.load_state_dict(torch.load(modelPath))
  decoder.cpu()

  rand = np.random.randint(0, 19000)
  for caption in stairCaptions['annotations'][rand:rand+20]:
    try:
      idx = id2index[str(caption['image_id'])]
    except:
      continue
    
    inputFeature = encordedTensors[idx].unsqueeze(0)
    sentence = decoder.caption(inputFeature)

    for image in stairCaptions['images']:
      if caption['image_id'] == image['id']:
        print(image['flickr_url'])
        break
    sentence = sentence.detach().numpy().reshape(-1)
    print(index2sentence(sentence), " length :", len(sentence))

In [9]:
modelPath = 'drive/My Drive/Colab Notebooks/LSTM_models/epoch40'
_test(modelPath)

http://farm6.staticflickr.com/5294/5483678464_5da175140e_z.jpg
赤いユニフォームを着た女性がボールを打とうとしている<end>  length : 16
http://farm4.staticflickr.com/3319/3278093803_d2029e02e2_z.jpg
トレーのトレーのトレーにトレーたトレーがある<end>  length : 12
http://farm7.staticflickr.com/6149/5960760081_d65e90243d_z.jpg
芝生の上でシマウマがついている<end>  length : 10
http://farm9.staticflickr.com/8544/8644704729_199b31d847_z.jpg
バーコードを付けたバーコードを付けたバーコードを付けたバーコードを付けたバーコード<end>  length : 18
