In [1]:
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
import time
import sys
from Preprocess import load_captions
from data_loader import DataLoader
from data_loader import get_loader 
from Vocabulary import Vocabulary
from model import EncoderCNN, DecoderRNN
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms

import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
# nltk.download('punkt')

In [2]:
# Run configuration: every tunable value used by the cells below lives here.

# --- Filesystem locations ---
train_dir = "./train"
test_dir = './test'
model_path = 'models/'

# --- Optimisation and data loading ---
num_epochs = 50
lr = 1e-3
train_batch_size = 256
test_batch_size = 1     # the evaluation loop only reads element 0 of each test batch
num_workers = 2
crop_size = 224

# --- Network dimensions ---
embed_size = 256
hidden_size = 512
num_layers = 1

# --- Vocabulary ---
threshold = 5      # word-frequency threshold handed to Vocabulary

# --- Logging / checkpoint intervals (counted in loader steps) ---
log_step = 100
test_log = 2500
save_step = 100

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# Smoothing helper used later when computing sentence-level BLEU.
smoothing = SmoothingFunction()

# Echo the configuration so it is recorded alongside the run output.
summary_values = (num_epochs, train_batch_size, test_batch_size,
                  hidden_size, embed_size, threshold, lr)
print(
    "Parameters:\n"
    " Num_epochs: {}\n"
    " Batch_sizeL {} and {}\n"
    " hid_size: {}        \n"
    " embed_size: {}\n"
    " threshold: {}\n"
    " Learning rate: {:.4f}".format(*summary_values)
)

Parameters:
 Num_epochs: 50
 Batch_sizeL 256 and 1
 hid_size: 512        
 embed_size: 256
 threshold: 5
 Learning rate: 0.0010


In [3]:
# ImageNet channel statistics expected by the pre-trained encoder backbone.
_imagenet_mean = (0.485, 0.456, 0.406)
_imagenet_std = (0.229, 0.224, 0.225)

# Define a transform to pre-process the training images.
# The crop side comes from `crop_size` in the parameters cell so the configured
# value and the actual input resolution cannot drift apart.
transform_train = transforms.Compose([
    transforms.Resize(256),                          # smaller edge of image resized to 256
    transforms.RandomCrop(crop_size),                # crop_size x crop_size crop from a random location
    transforms.RandomHorizontalFlip(),               # horizontally flip image with probability=0.5
    transforms.ToTensor(),                           # convert the PIL Image to a tensor
    transforms.Normalize(_imagenet_mean,             # normalize image for pre-trained model
                         _imagenet_std)])

# Define a transform to pre-process the test images (deterministic: no random
# crop or flip, so evaluation is repeatable).
transform_test = transforms.Compose([
    transforms.Resize(256),                          # smaller edge of image resized to 256
    transforms.CenterCrop(crop_size),                # crop_size x crop_size crop from the center
    transforms.ToTensor(),                           # convert the PIL Image to a tensor
    transforms.Normalize(_imagenet_mean,             # normalize image for pre-trained model
                         _imagenet_std)])

In [4]:
# Load the captions for the training split; the result is only used below to
# build the vocabulary.
# NOTE(review): the structure of the returned object is defined in
# Preprocess.load_captions and is not visible here — confirm before relying on it.
captions_dict = load_captions(train_dir)

In [5]:
# Build the vocabulary from the training captions using the configured
# frequency threshold.
vocab = Vocabulary(captions_dict, threshold)
# `vocab.index` is used as the vocabulary size: it sizes the decoder's embedding
# and output layers below.
# NOTE(review): presumably the next unused word id — confirm in Vocabulary.
vocab_size = vocab.index
print(vocab_size)

2754


In [6]:
# Make sure the checkpoint directory exists. `exist_ok=True` makes this safe to
# re-run and avoids the gap between a separate existence check and the create.
os.makedirs(model_path, exist_ok=True)

In [7]:
# Instantiate the image encoder and the caption decoder and move both to the
# selected device. The encoder is configured with `embed_size`, which is also
# the decoder's first size argument.
encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
# Show the decoder's layer layout for the run log.
print(decoder)

Encoder Model:  resnet18
DecoderRNN(
  (embed): Embedding(2754, 256)
  (lstm): LSTM(256, 512, batch_first=True)
  (linear): Linear(in_features=512, out_features=2754, bias=True)
)


In [8]:
# `DataLoader` here is the project class from data_loader, not torch's.
# It is given the training directory, the vocabulary and the training transform.
train_dataloader = DataLoader(train_dir, vocab, transform_train)
# gen_data() returns three values that are passed to get_loader() below.
# NOTE(review): presumably image ids, encoded captions and images — confirm in
# data_loader.DataLoader.gen_data.
train_image_numbers, train_caption_total, train_image_total= train_dataloader.gen_data()

In [9]:
# Same project `DataLoader` as for training, but pointed at the test directory
# and using the deterministic test transform.
test_dataloader = DataLoader(test_dir, vocab, transform_test)
# gen_data() returns three values that are passed to get_loader() below.
# NOTE(review): presumably image ids, encoded captions and images — confirm in
# data_loader.DataLoader.gen_data.
test_image_numbers, test_caption_total, test_image_total= test_dataloader.gen_data()

In [10]:
# Batched loader over the training data: `train_batch_size` items per step,
# shuffling requested. The training loop unpacks each batch as
# (images, captions, lengths).
train_data_loader = get_loader(train_image_numbers, train_caption_total,
                               train_image_total, train_batch_size,
                               shuffle=True, num_workers=num_workers) 

In [11]:
# Batched loader over the test data (`test_batch_size` items per step).
# Shuffling is disabled for evaluation: the average BLEU does not depend on the
# order, and a fixed order keeps the per-step log and the saved per-sample score
# list comparable between epochs and between runs.
test_data_loader = get_loader(test_image_numbers, test_caption_total,
                              test_image_total, test_batch_size,
                              shuffle=False, num_workers=num_workers)

In [12]:
# Only these modules are handed to the optimizer: the whole decoder plus the
# encoder's `linear` and `bn` sub-modules. Any other encoder parameters are
# therefore never updated by `optimizer.step()`.
trainable_modules = (decoder, encoder.linear, encoder.bn)
params = [weight for module in trainable_modules for weight in module.parameters()]

# Token-level classification loss and Adam with the configured learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params, lr=lr)

In [13]:
# Number of steps (batches) each loader yields per pass; used as the
# denominators in the progress lines printed during training and testing.
total_step, test_total = len(train_data_loader), len(test_data_loader)
print(total_step, test_total, sep="\n")

118
5000


In [None]:
# Train for `num_epochs` epochs. After every epoch, caption each test image with
# `decoder.sample` and score it against its reference caption with smoothed
# BLEU-1. The per-sample scores of the final epoch and their mean are written to
# "tests.npy" once training finishes.
start_train_time = time.time()

for epoch in range(num_epochs):

    #################
    # Train
    #################

    encoder.train()
    decoder.train()

    for i, (images, captions, lengths) in enumerate(train_data_loader):

        # Set mini-batch dataset
        images = images.to(device)
        captions = captions.to(device)
        # Element 0 of the PackedSequence is its flat `data` tensor; the loss
        # below compares it directly with the decoder outputs.
        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

        # Forward, backward and optimize
        features = encoder(images)
        outputs = decoder(features, captions, lengths)
        loss = criterion(outputs, targets)
        decoder.zero_grad()
        encoder.zero_grad()
        loss.backward()
        optimizer.step()

        # Get training statistics
        stats = "Epoch {:0>2d}, Train step [{:0>3d}/{}], Loss: {:.4f}, Perplexity: {:>7.4f}, Time: {:>6.2f} min".\
        format(epoch+1, i+1, total_step, loss.item(), np.exp(loss.item()), (time.time() - start_train_time)/60)
        # Print training statistics (on same line)
        print("\r" + stats, end="")
        sys.stdout.flush()

        # Every `log_step` steps, keep the current stats on their own line
        if (i+1) % log_step == 0:
            print("\r" + stats)

        # Save the model checkpoints
        if (i+1) % save_step == 0:
            torch.save(decoder.state_dict(), os.path.join(
                model_path, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))
            torch.save(encoder.state_dict(), os.path.join(
                model_path, 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))
    print()

    #################
    # Test
    #################

    encoder.eval()
    decoder.eval()

    # List to store the BLEU scores of this epoch
    bleu_scores = []

    # Evaluation never calls backward(), so run it without autograd: this avoids
    # building a graph for every test image and keeps memory use down.
    with torch.no_grad():
        for i, (test_images, test_captions, test_lengths) in enumerate(test_data_loader):

            # Set mini-batch dataset
            test_images = test_images.to(device)

            # Generate a caption from the image.
            # Only element 0 of the batch is decoded and scored, which is why
            # the test loader is built with a batch size of 1.
            feature = encoder(test_images)
            sampled_ids = decoder.sample(feature)
            sampled_ids = sampled_ids[0].cpu().numpy()

            # Convert word_ids to words, stopping at the first '<end>'
            sampled_caption = []
            for word_id in sampled_ids:
                word = vocab.id2word[word_id]
                sampled_caption.append(word)
                if word == '<end>':
                    break
            output = ' '.join(sampled_caption)

            # Convert target word_ids to words, stopping at the first '<end>'
            test_caption = test_captions[0].cpu().numpy()
            target_caption = []
            for word_id in test_caption:
                word = vocab.id2word[word_id]
                target_caption.append(word)
                if word == '<end>':
                    break
            target = ' '.join(target_caption)

            # Convert string to a list and ignore the leading start token and
            # the trailing '<end>'. The last token is removed only when it
            # really is '<end>': a caption that stopped without emitting
            # '<end>' would otherwise lose its final real word.
            target_list = target.split()[1:]
            if target_list and target_list[-1] == '<end>':
                target_list.pop()
            output_list = output.split()[1:]
            if output_list and output_list[-1] == '<end>':
                output_list.pop()

            # weights=(1, 0, 0, 0) restricts the score to unigrams (BLEU-1)
            score = sentence_bleu([target_list],
                                  output_list,
                                  weights=(1, 0, 0, 0),
                                  smoothing_function=smoothing.method7)
            bleu_scores.append(score)

            # Get test statistics
            test_stats = "Epoch {:0>2d}, Test step  [{:0>3d}/{}], BLEU: {:.4f}, Ave BLEU-1 {:>.4f}, Time: {:>6.2f} min".\
            format(epoch+1, i+1, test_total, score, np.mean(bleu_scores), (time.time() - start_train_time)/60)
            # Print test statistics (on same line)
            print("\r" + test_stats, end="")
            sys.stdout.flush()

            # Every `test_log` steps, keep the current stats on their own line
            if (i+1) % test_log == 0:
                print("\r" + test_stats)

    print('\n')

# Persist the last epoch's per-sample scores together with their mean.
# The pair is ragged (a list and a scalar), so it is stored as an explicit
# object array; recent NumPy versions refuse to build ragged arrays implicitly.
# Read it back with np.load("tests.npy", allow_pickle=True).
np.save("tests.npy", np.array([bleu_scores, np.mean(bleu_scores)], dtype=object))

Epoch 01, Train step [100/118], Loss: 3.2953, Perplexity: 26.9859, Time:   0.95 min
Epoch 01, Train step [118/118], Loss: 3.1629, Perplexity: 23.6382, Time:   1.09 min
Epoch 01, Test step  [2500/5000], BLEU: 0.3108, Ave BLEU-1 0.5190, Time:   2.06 min
Epoch 01, Test step  [5000/5000], BLEU: 0.5303, Ave BLEU-1 0.5217, Time:   2.99 min


Epoch 02, Train step [100/118], Loss: 2.9706, Perplexity: 19.5028, Time:   3.85 min
Epoch 02, Train step [118/118], Loss: 3.0975, Perplexity: 22.1429, Time:   3.99 min
Epoch 02, Test step  [2500/5000], BLEU: 0.4939, Ave BLEU-1 0.5409, Time:   4.96 min
Epoch 02, Test step  [5000/5000], BLEU: 0.4596, Ave BLEU-1 0.5391, Time:   5.90 min


Epoch 03, Train step [100/118], Loss: 2.7195, Perplexity: 15.1726, Time:   6.76 min
Epoch 03, Train step [118/118], Loss: 2.8343, Perplexity: 17.0189, Time:   6.90 min
Epoch 03, Test step  [2500/5000], BLEU: 0.5639, Ave BLEU-1 0.5402, Time:   7.88 min
Epoch 03, Test step  [5000/5000], BLEU: 0.6189, Ave BLEU-1 0.5398, Time: