In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
import sys
sys.path.append('/opt/cocoapi/PythonAPI')
from pycocotools.coco import COCO
from data_loader import get_loader
from model import EncoderCNN, DecoderRNN
import math
import nltk
nltk.download('punkt')



# ---- Data loading -----------------------------------------------------------
batch_size = 32          # forwarded to get_loader as the minibatch size
vocab_threshold = 6      # forwarded to get_loader (presumably a minimum word count -- confirm in data_loader)
vocab_from_file = True   # forwarded to get_loader; selects loading an existing vocabulary file

# ---- Model sizes ------------------------------------------------------------
# embed_size is given to both EncoderCNN and DecoderRNN; hidden_size only to DecoderRNN.
embed_size = hidden_size = 512

# ---- Training schedule and logging -----------------------------------------
num_epochs = 1                   # the training loop runs epochs 1..num_epochs
save_every = 1                   # checkpoints are written when epoch % save_every == 0
print_every = 100                # a persistent stats line is printed when i_step % print_every == 0
log_file = 'training_log.txt'    # per-step stats are written to this file

# Per-channel statistics used by the Normalize step below.
# NOTE(review): these match the values torchvision documents for its ImageNet-pretrained
# models -- confirm that EncoderCNN relies on such a backbone.
_norm_mean = (0.485, 0.456, 0.406)
_norm_std = (0.229, 0.224, 0.225)

# Training-time image pipeline; it is handed to get_loader as its `transform`.
_train_steps = [
    transforms.Resize(256),                     # resize (an int size scales the smaller edge to 256)
    transforms.RandomCrop(224),                 # take a random 224x224 crop
    transforms.RandomHorizontalFlip(),          # random left-right flip with the default probability
    transforms.ToTensor(),                      # convert the image to a tensor
    transforms.Normalize(_norm_mean, _norm_std),  # normalise each channel with the stats above
]
transform_train = transforms.Compose(_train_steps)

# Gather the loader settings in one place, then build the training data loader.
loader_kwargs = dict(
    transform=transform_train,
    mode='train',
    batch_size=batch_size,
    vocab_threshold=vocab_threshold,
    vocab_from_file=vocab_from_file,
)
data_loader = get_loader(**loader_kwargs)

# Number of entries in the dataset's vocabulary; used as the decoder's vocab size
# and to flatten the decoder outputs when computing the loss.
train_vocab = data_loader.dataset.vocab
vocab_size = len(train_vocab)

# Resolve the compute device once; everything below is placed with `.to(device)`
# so the CUDA-availability check is not repeated in several spellings.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the image encoder and the caption decoder from the configured sizes.
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

# Move both models to the selected device.
encoder.to(device)
decoder.to(device)

# Loss over vocabulary logits; placed on the same device as the models.
criterion = nn.CrossEntropyLoss().to(device)

# Only the decoder and the encoder's `embed` sub-module are handed to the optimizer;
# any other encoder parameters are therefore never updated by optimizer.step().
trainable_modules = (decoder, encoder.embed)
params = [param for module in trainable_modules for param in module.parameters()]

optimizer = torch.optim.Adam(
    params,
    lr=0.001,
    betas=(0.9, 0.999),
    eps=1e-08,
)

# Steps per epoch: number of entries in caption_lengths divided by the batch size, rounded up.
num_captions = len(data_loader.dataset.caption_lengths)
captions_per_batch = data_loader.batch_sampler.batch_size
total_step = math.ceil(num_captions / captions_per_batch)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Vocabulary successfully loaded from vocab.pkl file!
loading annotations into memory...
Done (t=1.01s)
creating index...


  0%|          | 550/414113 [00:00<02:30, 2743.21it/s]

index created!
Obtaining caption lengths...


100%|██████████| 414113/414113 [01:20<00:00, 5124.91it/s]


In [2]:
import torch.utils.data as data
import numpy as np
import os
import requests
import time

# Create the checkpoint directory up front so torch.save cannot fail on a missing
# folder only after a whole epoch of training has already been spent.
os.makedirs('./models', exist_ok=True)

# Fetch the token that is sent with the periodic keep-alive POST below
# (presumably this keeps the hosted workspace session from timing out -- confirm).
old_time = time.time()
response = requests.request("GET", 
                            "http://metadata.google.internal/computeMetadata/v1/instance/attributes/keep_alive_token", 
                            headers={"Metadata-Flavor":"Google"})

# The context manager guarantees the log file is closed even if a step raises.
with open(log_file, 'w') as f:

    for epoch in range(1, num_epochs+1):

        for i_step in range(1, total_step+1):

            # Send a keep-alive request at most once every 60 seconds.
            if time.time() - old_time > 60:
                old_time = time.time()
                requests.request("POST", 
                                 "https://nebula.udacity.com/api/v1/remote/keep-alive", 
                                 headers={'Authorization': "STAR " + response.text})

            # Ask the dataset for a fresh set of indices and install them as the
            # sampler, so the next batch is drawn from exactly those indices.
            indices = data_loader.dataset.get_train_indices()
            new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
            data_loader.batch_sampler.sampler = new_sampler

            # Pull a single batch from a fresh iterator over the loader.
            images, captions = next(iter(data_loader))

            images = images.to(device)
            captions = captions.to(device)

            # Clear gradients left over from the previous step.
            decoder.zero_grad()
            encoder.zero_grad()

            # Forward pass: image features -> per-token vocabulary scores.
            features = encoder(images)
            outputs = decoder(features, captions)

            # Flatten scores to (-1, vocab_size) and targets to 1-D for the criterion.
            loss = criterion(outputs.view(-1, vocab_size), captions.view(-1))

            loss.backward()

            optimizer.step()

            # Read the scalar loss once; perplexity is exp(loss).
            loss_value = loss.item()
            stats = 'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f' % (epoch, num_epochs, i_step, total_step, loss_value, np.exp(loss_value))

            # Overwrite the current console line with the latest stats.
            print('\r' + stats, end="")
            sys.stdout.flush()

            # Append the stats to the log file and flush so progress survives a crash.
            f.write(stats + '\n')
            f.flush()

            # Every print_every steps, end the line so that stats line stays visible.
            if i_step % print_every == 0:
                print('\r' + stats)

        # NOTE(review): the checkpoint names hard-code "batch_size-32" regardless of the
        # configured batch_size; kept as-is so existing consumers of these files still work.
        if epoch % save_every == 0:
            torch.save(decoder.state_dict(), os.path.join('./models', 'decoder-batch_size-32-%d.pkl' % epoch))
            torch.save(encoder.state_dict(), os.path.join('./models', 'encoder-batch_size-32-%d.pkl' % epoch))

Epoch [1/1], Step [100/12942], Loss: 3.4919, Perplexity: 32.8496
Epoch [1/1], Step [200/12942], Loss: 3.2534, Perplexity: 25.87761
Epoch [1/1], Step [300/12942], Loss: 3.2247, Perplexity: 25.1467
Epoch [1/1], Step [400/12942], Loss: 3.3074, Perplexity: 27.3152
Epoch [1/1], Step [500/12942], Loss: 3.0925, Perplexity: 22.0315
Epoch [1/1], Step [600/12942], Loss: 3.3092, Perplexity: 27.3624
Epoch [1/1], Step [700/12942], Loss: 3.4140, Perplexity: 30.3865
Epoch [1/1], Step [800/12942], Loss: 2.9398, Perplexity: 18.91236
Epoch [1/1], Step [900/12942], Loss: 2.8619, Perplexity: 17.4946
Epoch [1/1], Step [1000/12942], Loss: 2.5162, Perplexity: 12.3810
Epoch [1/1], Step [1100/12942], Loss: 3.0550, Perplexity: 21.2207
Epoch [1/1], Step [1200/12942], Loss: 3.0839, Perplexity: 21.8441
Epoch [1/1], Step [1300/12942], Loss: 2.5316, Perplexity: 12.5741
Epoch [1/1], Step [1400/12942], Loss: 3.2453, Perplexity: 25.6689
Epoch [1/1], Step [1500/12942], Loss: 2.8855, Perplexity: 17.9130
Epoch [1/1], Step