In [1]:
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
import time
import sys
from Preprocess import load_captions
from data_loader import DataLoader
from data_loader import get_loader 
from Vocabulary import Vocabulary
from model import EncoderCNN, DecoderRNN
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms

import nltk
# nltk.download('punkt')

In [2]:
# ---------------------------------------------------------------------------
# Run configuration: every tunable value used by the cells below lives here.
# ---------------------------------------------------------------------------

# Filesystem locations: training data root and checkpoint output directory.
train_dir = "./train"
model_path = 'models/'

# Image side length in pixels (presumably the network input size — confirm
# against EncoderCNN).
crop_size = 224

# Model dimensions handed to EncoderCNN / DecoderRNN.
embed_size = 512
hidden_size = 512
num_layers = 1

# Cutoff forwarded to Vocabulary (presumably a minimum word count — confirm
# in Vocabulary.py).
threshold = 20

# Optimisation and batching.
num_epochs = 80
batch_size = 64
num_workers = 2
lr = 1e-3

# Reporting cadence, in batches: keep a log line every `log_step` batches and
# write a checkpoint every `save_step` batches.
log_step = 100
save_step = 200

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
# Load the captions for the training set rooted at `train_dir`.
# NOTE(review): the returned structure is defined in Preprocess.load_captions —
# presumably a mapping keyed by image identifier; confirm before relying on it.
captions_dict = load_captions(train_dir)

In [4]:
# Build the vocabulary from the loaded captions. `threshold` is forwarded to
# Vocabulary (presumably a minimum word-frequency cutoff — confirm in Vocabulary.py).
vocab = Vocabulary(captions_dict, threshold)
# `vocab.index` is taken as the vocabulary size and later sizes the decoder.
# NOTE(review): presumably this is the next free word index, i.e. the number
# of entries — confirm.
vocab_size = vocab.index

In [5]:
# Make sure the checkpoint directory exists before training writes into it.
# exist_ok=True makes this one idempotent call: re-running the cell is a no-op,
# and there is no window between an existence check and the creation in which
# another process could create the directory and cause a FileExistsError.
os.makedirs(model_path, exist_ok=True)

In [6]:
# Image preprocessing pipeline handed to the project's DataLoader.
# The target size is taken from `crop_size` (parameters cell) instead of a
# repeated literal, so the input resolution is changed in exactly one place.
transform = transforms.Compose([
    # Resize to a square crop_size x crop_size image (aspect ratio is not kept).
    transforms.Resize((crop_size, crop_size)),
    transforms.ToTensor(),
    # Per channel (x - 0.5) / 0.5, which maps ToTensor's [0, 1] range to [-1, 1].
    # NOTE(review): if EncoderCNN wraps an ImageNet-pretrained backbone, the
    # ImageNet mean/std would be the usual choice here — confirm in model.py.
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5))
    ])

In [7]:
# NOTE: `DataLoader` here is the project class imported from data_loader.py,
# not torch.utils.data.DataLoader.
dataloader = DataLoader(train_dir, vocab, transform)
# gen_data() returns three collections that are passed straight to get_loader.
# NOTE(review): presumably image ids, encoded captions and preprocessed images,
# going by the names — confirm in data_loader.py.
imagenumbers, captiontotal, imagetotal= dataloader.gen_data()

In [8]:
# Build the batch iterator consumed by the training loop; each iteration
# yields an (images, captions, lengths) triple.
# NOTE(review): `data_loader` is also the name of the module this function was
# imported from; the variable does not shadow anything used later in this
# notebook, but a distinct name would be less confusing.
data_loader = get_loader(imagenumbers, captiontotal, imagetotal, batch_size,
                         shuffle=True, num_workers=num_workers) 

In [9]:
# Instantiate both networks and move them to the selected device.
# The encoder produces image features that the decoder consumes together with
# the captions (see the training loop); both share `embed_size`.
encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)

In [10]:
# Loss: cross-entropy between the decoder outputs and the target word indices.
criterion = nn.CrossEntropyLoss()

# Parameters handed to the optimizer: the whole decoder, plus only the
# encoder's `linear` and `bn` sub-modules. No other encoder parameters are
# updated by this optimizer.
params = [
    *decoder.parameters(),
    *encoder.linear.parameters(),
    *encoder.bn.parameters(),
]

optimizer = torch.optim.Adam(params, lr=lr)

In [11]:
# Training loop.
#
# For every mini-batch: move tensors to `device`, run encoder -> decoder,
# compute the cross-entropy loss against the packed captions, back-propagate
# and take one optimizer step. Progress is rewritten in place on a single
# console line; every `log_step` batches that line is kept, and every
# `save_step` batches both state dicts are written under `model_path`.
# The weights after the very last batch of the last epoch are always saved as
# well, so the final state is not lost when `total_step` is not a multiple of
# `save_step`.
total_step = len(data_loader)
start_train_time = time.time()
for epoch in range(num_epochs):

    # Train
    for i, (images, captions, lengths) in enumerate(data_loader):

        # Set mini-batch dataset
        images = images.to(device)
        captions = captions.to(device)
        # [0] is the flat `.data` tensor of the PackedSequence (padding removed).
        # NOTE(review): presumably the decoder returns its outputs in the same
        # packed order so rows line up for the loss — confirm in model.py.
        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

        # Forward, backward and optimize
        features = encoder(images)
        outputs = decoder(features, captions, lengths)
        loss = criterion(outputs, targets)
        # Clear gradients on both modules (not only the optimizer's parameter
        # subset) before computing fresh ones.
        decoder.zero_grad()
        encoder.zero_grad()
        loss.backward()
        optimizer.step()

        # Get training statistics (elapsed time is measured from the start of
        # training and is never reset).
        stats = "Epoch {:0>2d}, Train step [{:0>3d}/{}], Loss: {:.4f}, Perplexity: {:>7.4f}, Time: {:>6.2f} min".\
        format(epoch+1, i+1, total_step, loss.item(), np.exp(loss.item()), (time.time() - start_train_time)/60)
        # Print training statistics (on same line); "\r" returns to the start
        # of the line so the next update overwrites this one.
        print("\r" + stats, end="", flush=True)

        # Every log_step batches, end the line so this entry stays visible.
        if (i+1) % log_step == 0:
            print("\r" + stats)

        # Save the model checkpoints: on the regular cadence, and once more
        # after the final batch of the final epoch.
        is_final_step = (epoch+1 == num_epochs) and (i+1 == total_step)
        if (i+1) % save_step == 0 or is_final_step:
            torch.save(decoder.state_dict(), os.path.join(
                model_path, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))
            torch.save(encoder.state_dict(), os.path.join(
                model_path, 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))
    # Terminate the in-place progress line so the last batch of the epoch
    # remains in the output.
    print()
    
    # Test
    

Epoch 1, Train step [100/235], Loss: 2.9047, Perplexity: 18.2600, Time:  1.25 min
Epoch 1, Train step [200/235], Loss: 2.6438, Perplexity: 14.0660, Time:  2.40 min
Epoch 1, Train step [235/235], Loss: 2.5773, Perplexity: 13.1622, Time:  2.79 min
Epoch 2, Train step [100/235], Loss: 2.4499, Perplexity: 11.5866, Time:  4.00 min
Epoch 2, Train step [200/235], Loss: 2.3458, Perplexity: 10.4421, Time:  5.14 min
Epoch 2, Train step [235/235], Loss: 2.3343, Perplexity: 10.3226, Time:  5.54 min
Epoch 3, Train step [100/235], Loss: 2.1136, Perplexity:  8.2777, Time:  6.74 min
Epoch 3, Train step [200/235], Loss: 2.2255, Perplexity:  9.2585, Time:  7.90 min
Epoch 3, Train step [235/235], Loss: 2.1556, Perplexity:  8.6334, Time:  8.29 min
Epoch 4, Train step [100/235], Loss: 2.0653, Perplexity:  7.8874, Time:  9.43 min
Epoch 4, Train step [200/235], Loss: 2.0168, Perplexity:  7.5140, Time: 10.51 min
Epoch 4, Train step [235/235], Loss: 2.1403, Perplexity:  8.5019, Time: 10.90 min
Epoch 5, Train s

In [12]:
# Show the decoder's layer structure (embedding, LSTM and output projection)
# as a sanity check on vocabulary and hidden sizes.
print(decoder)

DecoderRNN(
  (embed): Embedding(1072, 512)
  (lstm): LSTM(512, 512, batch_first=True)
  (linear): Linear(in_features=512, out_features=1072, bias=True)
)
