In [1]:
from tqdm.notebook import tqdm
import json
import os
import torch.utils.data as data
import math
from utils import clean_sentences, calculate_bleu_scores
import torch
from torch import nn, optim
from torchvision import transforms
import sys
sys.path.append('./cocoapi/PythonAPI')
from pycocotools.coco import COCO
from data_loader import get_loader
from model import EncoderCNN, DecoderRNN


# ---- Data / vocabulary settings ----
# Number of samples per mini-batch.
batch_size = 64
# Minimum word count threshold.
vocab_threshold = 5
# If True, load the existing vocab file.
vocab_from_file = True

# ---- Model dimensions ----
# Dimensionality of image and word embeddings.
embed_size = 300
# Number of features in the hidden state of the RNN decoder.
hidden_size = 256

# ---- Training schedule ----
# Number of training epochs.
num_epochs = 1

# ---- Dataset locations (image folders and caption annotation files) ----
train_images_folder = "train2014"
train_annotations_file = "captions_train2014.json"
val_images_folder = "val2014"
val_annotations_file = "captions_val2014.json"


# Per-channel mean / std handed to Normalize in both pipelines
# (normalizes images for the pre-trained model).
_norm_mean = (0.485, 0.456, 0.406)
_norm_std = (0.229, 0.224, 0.225)

# Training pipeline: random crop + random flip on top of resize / tensor / normalize.
_train_steps = [
    # Smaller edge of the image is resized to 256.
    transforms.Resize(256),
    # Take a 224x224 crop from a random location.
    transforms.RandomCrop(224),
    # Horizontally flip the image with probability 0.5.
    transforms.RandomHorizontalFlip(),
    # Convert the PIL Image to a tensor.
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]
train_transform = transforms.Compose(_train_steps)

# Validation pipeline: same resize / normalize, but a center crop and no flip,
# so it contains no random steps.
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]
val_transform = transforms.Compose(_val_steps)


# Build the data loaders.
# Training loader: built in 'train' mode with the vocabulary settings above.
train_loader = get_loader(
    mode='train',
    transform=train_transform,
    img_folder=train_images_folder,
    annotations_file=train_annotations_file,
    batch_size=batch_size,
    vocab_threshold=vocab_threshold,
    vocab_from_file=vocab_from_file,
)
# Validation loader: built in 'test' mode over the validation images;
# no vocabulary arguments are passed, so get_loader's defaults apply.
val_loader = get_loader(
    mode='test',
    transform=val_transform,
    img_folder=val_images_folder,
    annotations_file=val_annotations_file,
    batch_size=batch_size,
)

# Vocabulary size, read from the training dataset's vocab object.
vocab_size = len(train_loader.dataset.vocab)        # type: ignore

# Pick the compute device once: GPU when CUDA is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the encoder / decoder pair and move both onto the chosen device.
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
encoder.to(device)
decoder.to(device)

# Loss function, placed on the same device as the models.
criterion = nn.CrossEntropyLoss().to(device)

# Optimize decoder and encoder parameters together (decoder's first, then encoder's).
params = [*decoder.parameters(), *encoder.parameters()]
optimizer = optim.Adam(params, lr=3e-4)

# Training steps per epoch: number of captions divided by the batch size, rounded up.
total_steps = math.ceil(
    len(train_loader.dataset.caption_lengths) / train_loader.batch_sampler.batch_size  # type: ignore
)


  warn(f"Failed to load image Python extension: {e}")


Vocabulary successfully loaded from vocab.pkl file!
loading annotations into memory...
Done (t=1.12s)
creating index...
index created!
Obtaining caption lengths...


100%|██████████| 414113/414113 [00:32<00:00, 12768.15it/s]


Vocabulary successfully loaded from vocab.pkl file!
loading annotations into memory...
Done (t=0.55s)
creating index...
index created!


In [2]:
def train(loader, total_steps, max_steps=None):
    """Train the encoder/decoder for one epoch and return the mean batch loss.

    Uses the module-level ``encoder``, ``decoder``, ``criterion``,
    ``optimizer``, ``device`` and ``vocab_size``.

    Args:
        loader: Training data loader. Its dataset must provide
            ``get_train_indices()`` and its ``batch_sampler.sampler`` is
            replaced on every step.
        total_steps: Number of optimisation steps that make up one epoch.
        max_steps: Optional cap on the number of steps actually run (e.g.
            ``max_steps=1`` for a quick smoke test). ``None`` (default) runs
            all ``total_steps``.

    Returns:
        The sum of per-batch losses divided by the number of steps that were
        actually executed, or ``0.0`` when no step was run.
    """
    encoder.train()
    decoder.train()

    # The previous version left a debugging `break` in the loop, so only one
    # batch was trained while the loss was still averaged over `total_steps`.
    # An explicit, optional cap replaces it and the average uses the real count.
    steps_to_run = total_steps if max_steps is None else min(total_steps, max_steps)
    if steps_to_run <= 0:
        return 0.0

    epoch_loss = 0.0

    for _ in tqdm(range(1, steps_to_run + 1)):
        # Ask the dataset for this step's indices (original note: a caption
        # length is sampled at random and indices with that length are returned).
        indices = loader.dataset.get_train_indices()
        # Create and assign a sampler so the loader draws from those indices.
        new_sampler = data.sampler.SubsetRandomSampler(indices=indices)     # type: ignore
        loader.batch_sampler.sampler = new_sampler

        # Obtain the batch; a fresh iterator is needed so the new sampler is used.
        images, captions = next(iter(loader))

        # Move batch of images and captions to the training device.
        images = images.to(device)
        captions = captions.to(device)

        # Zero the gradients.
        decoder.zero_grad()
        encoder.zero_grad()

        # Pass the inputs through the CNN-RNN model.
        features = encoder(images)
        preds = decoder(features, captions)

        # Calculate the batch loss: predictions are flattened to
        # (-1, vocab_size) and the captions to a 1-D vector of targets.
        loss = criterion(preds.view(-1, vocab_size), captions.view(-1))
        epoch_loss += loss.item()

        # Backward pass.
        loss.backward()

        # Update the parameters in the optimizer.
        optimizer.step()

    return epoch_loss / steps_to_run

In [3]:
def validate(loader, max_batches=None):
    """Generate captions for the validation set and write them to results.json.

    Uses the module-level ``encoder``, ``decoder`` and ``device``. Both models
    are switched to eval mode and inference runs without gradient tracking.

    Args:
        loader: Data loader whose batches unpack as ``(images, img_ids)`` and
            whose dataset exposes ``vocab.idx2word``.
        max_batches: Optional cap on the number of batches processed (e.g.
            ``max_batches=1`` for a quick smoke test). ``None`` (default)
            processes the whole loader.

    Side effects:
        Overwrites ``results.json`` in the working directory with a JSON list
        of ``{"image_id": ..., "caption": ...}`` records.
    """
    encoder.eval()
    decoder.eval()
    results = []

    # The previous version had a debugging `break` after the first batch, so
    # results.json (and any score computed from it) covered one batch only.
    with torch.no_grad():
        for batch_idx, (images, img_ids) in enumerate(tqdm(loader)):
            if max_batches is not None and batch_idx >= max_batches:
                break

            images = images.to(device)

            # Obtain the embedded image features, with an extra axis inserted
            # at position 1 before they are handed to the decoder.
            features = encoder(images).unsqueeze(1)

            # Pass the embedded image features through the model to get predicted captions.
            output = decoder.generate_captions(features)
            # Turn the decoder output into sentences using the dataset's idx2word mapping.
            sentences = clean_sentences(loader.dataset.vocab.idx2word, output)
            results.extend(
                {"image_id": img_id.item(), "caption": sentence}
                for img_id, sentence in zip(img_ids, sentences)  # type: ignore
            )

    with open("results.json", 'w') as res_file:
        json.dump(results, res_file)


In [5]:
# Make sure the checkpoint directory exists up front; otherwise torch.save
# fails on a fresh checkout only after a whole epoch has been spent.
os.makedirs("./models", exist_ok=True)

for epoch in range(num_epochs):
    # One pass over the training data, then caption the validation set
    # (validate() writes its predictions to results.json).
    train_loss = train(train_loader, total_steps)
    validate(val_loader)
    bleu_scores = calculate_bleu_scores(val_loader.dataset.coco, "results.json")    # type: ignore

    # Append this epoch's training loss and BLEU scores to the running log.
    with open("statistics.txt", 'a') as file:
        file.write(str(train_loss) + '\n')
        file.write(' '.join(map(str, bleu_scores)) + '\n\n')

    # Checkpoint both models after every epoch (each epoch overwrites the last).
    torch.save(encoder.state_dict(), "./models/encoder.pt")
    torch.save(decoder.state_dict(), "./models/decoder.pt")


  0%|          | 0/6471 [00:00<?, ?it/s]

  0%|          | 0/633 [00:00<?, ?it/s]