In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
from imgcaptioning.data_loader import get_loader
from imgcaptioning.model import CNN, RNN
from imgcaptioning.vocabulary import Vocabulary

Transform the images for the pre-trained ResNet50 encoder

In [None]:
# Training-time preprocessing for the pre-trained ResNet50 encoder:
#   resize the shorter edge to 256, take a random 224x224 crop, mirror
#   horizontally half of the time, convert to a tensor, then normalise with
#   the per-channel mean/std the pre-trained weights expect.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.RandomCrop(size=224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

Check for CUDA and select the compute device

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

Hyperparameters

In [None]:
# --- model dimensions ---
# embed_size is shared by the encoder and decoder; hidden_size is the decoder's hidden width.
embed_size, hidden_size = 256, 512

# --- data ---
vocab_threshold = 5   # passed to Vocabulary(...); presumably a min word-count cutoff — confirm
batch_size = 64

# --- training schedule / logging ---
epochs = 3
save = 1                          # checkpoint every `save` epochs
log = 20                          # print stats every `log` steps
log_file = "training_log.txt"     # every step's stats are written here

Load the training data

In [None]:
# Build the training DataLoader and report how many steps make up one epoch
# (the training loop runs `no_of_batches` steps per epoch).
# Uses the `batch_size` hyperparameter rather than a hard-coded literal, so the
# configured value actually takes effect.
# NOTE(review): `vocab_threshold` is not passed here; confirm the loader's
# vocabulary agrees with `voc` (the loss uses len(voc) as the number of classes).
data_loader, no_of_batches = get_loader(
    transform=transform,
    batch_size=batch_size,
)
print(no_of_batches)

In [None]:
# Vocabulary whose size, len(voc), sets the decoder's output width and the
# number of classes seen by the loss.
# NOTE(review): this is constructed independently of the vocabulary used inside
# `data_loader`; confirm both come from the same threshold / vocab source so the
# caption token ids are always < len(voc).
voc = Vocabulary(vocab_threshold)

Load the models

In [None]:
# Encoder (CNN) produces image features that the decoder (RNN) consumes together
# with the captions; the decoder's outputs are later flattened to
# (-1, len(voc)) scores for the loss.
# nn.Module.to() moves parameters in place and returns the module itself, so
# chaining it onto construction is the same as calling it on a separate line.
encoder = CNN(embed_size=embed_size).to(device)
decoder = RNN(len(voc), embed_size, hidden_size).to(device)

Define loss and optimizer

In [None]:
# Cross-entropy over vocabulary indices. The loss module holds no parameters,
# so placing it on `device` matches the "cuda if available" choice made above.
criterion = nn.CrossEntropyLoss().to(device)

# Only the decoder and the encoder's `embed` submodule are handed to Adam; any
# other encoder parameters are never updated by optimizer.step().
params = [*decoder.parameters(), *encoder.embed.parameters()]
optimizer = torch.optim.Adam(params, lr=0.001)

Train the model

In [None]:
# Training loop: `epochs` epochs of `no_of_batches` optimisation steps each.
# Every step's loss/perplexity is written (and flushed) to `log_file`; every
# `log` steps it is also printed; every `save` epochs both models are saved.
os.makedirs("./models", exist_ok=True)  # torch.save cannot create the directory

with open(log_file, "w") as f:  # guarantees the log is closed even if a step raises
    for epoch in range(1, epochs + 1):
        # Keep one iterator per epoch. Calling iter(data_loader) on every step
        # rebuilds the iterator each time (restarting sampling and any workers).
        batches = iter(data_loader)
        for i in range(1, no_of_batches + 1):
            try:
                img, cap = next(batches)
            except StopIteration:
                # Loader exhausted before `no_of_batches` steps: start a fresh pass.
                batches = iter(data_loader)
                img, cap = next(batches)
            img = img.to(device)
            cap = cap.to(device)

            encoder.zero_grad()
            decoder.zero_grad()

            features = encoder(img)
            outputs = decoder(features, cap)

            # Flatten to (N, len(voc)) scores vs. (N,) targets; reshape() also
            # copes with non-contiguous outputs, where view() would raise.
            loss = criterion(outputs.reshape(-1, len(voc)), cap.reshape(-1))
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            stats = (
                f"Epoch [{epoch}/{epochs}], Step [{i}/{no_of_batches}], "
                f"Loss: {loss_value:.4f}, Perplexity: {np.exp(loss_value):.4f}"
            )

            f.write(stats + "\n")
            f.flush()

            if i % log == 0:
                print("\r" + stats)

        # Checkpoint after the epoch's last step, outside the batch loop, so each
        # file is written once per `save` epochs rather than on every step.
        if epoch % save == 0:
            torch.save(
                decoder.state_dict(), os.path.join("./models", "decoder-%d.pkl" % epoch)
            )
            torch.save(
                encoder.state_dict(), os.path.join("./models", "encoder-%d.pkl" % epoch)
            )