In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from utils import save_checkpoint, load_checkpoint, print_examples
from get_loader import get_loader
from model import CNNtoRNN
from tqdm import tqdm

In [2]:
def train(
    num_epochs=100,
    learning_rate=3e-4,
    embed_size=256,
    hidden_size=256,
    num_layers=1,
    load_model=False,
    save_model=True,
    root_folder="custom_datasets/flickr8k/images",
    annotation_file="custom_datasets/flickr8k/captions.txt",
    num_workers=2,
):
    """Train the CNNtoRNN image-captioning model on the Flickr8k data.

    Every argument is optional and defaults to the previously hard-coded
    value, so a bare ``train()`` call keeps working.

    Args:
        num_epochs: Number of passes over the training loader.
        learning_rate: Adam learning rate.
        embed_size: Embedding size passed to ``CNNtoRNN``.
        hidden_size: Hidden size passed to ``CNNtoRNN``.
        num_layers: Number of recurrent layers passed to ``CNNtoRNN``.
        load_model: If True, restore model, optimizer and global step from
            ``"my_checkpoint.pth.tar"`` via ``load_checkpoint``.
        save_model: If True, write a checkpoint (via ``save_checkpoint``)
            after every epoch's updates.
        root_folder: Image directory handed to ``get_loader``.
        annotation_file: Caption file handed to ``get_loader``.
        num_workers: DataLoader worker count handed to ``get_loader``.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    train_loader, dataset = get_loader(
        root_folder=root_folder,
        annotation_file=annotation_file,
        transform=transform,
        num_workers=num_workers,
    )

    # Input size is fixed (299x299 crops), so letting cuDNN autotune kernels pays off.
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocab_size = len(dataset.vocab)

    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
    # Padding positions are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    step = 0  # global TensorBoard step; replaced by the saved value when resuming
    if load_model:
        # map_location lets a checkpoint written on GPU be resumed on a CPU-only machine.
        step = load_checkpoint(
            torch.load("my_checkpoint.pth.tar", map_location=device), model, optimizer
        )

    writer = SummaryWriter("runs/flickr")
    try:
        for epoch in range(num_epochs):
            print_examples(model, device, dataset)
            # NOTE(review): print_examples may switch the model to eval mode;
            # re-enable train mode every epoch so dropout/batch-norm behave correctly.
            model.train()

            for imgs, captions in tqdm(train_loader, desc=f"epoch {epoch + 1}/{num_epochs}"):
                imgs = imgs.to(device)
                captions = captions.to(device)

                # NOTE(review): assumes captions are (seq_len, batch) and that the model
                # prepends the image feature as the first time step, so feeding
                # captions[:-1] yields outputs aligned with `captions` -- confirm against
                # get_loader's collate function and CNNtoRNN.forward.
                outputs = model(imgs, captions[:-1])
                # Flatten (time, batch, vocab) logits and targets for CrossEntropyLoss.
                loss = criterion(
                    outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
                )

                writer.add_scalar("Training loss", loss.item(), global_step=step)
                step += 1

                optimizer.zero_grad()
                # No argument: passing `loss` as the upstream gradient would scale
                # every parameter gradient by the current loss value.
                loss.backward()
                optimizer.step()

            if save_model:
                # Save after this epoch's updates so the last checkpoint holds the
                # fully trained weights.
                checkpoint = {
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step": step,
                }
                save_checkpoint(checkpoint)
    finally:
        # Flush pending TensorBoard events and release the event file even on error.
        writer.close()

In [None]:
# Start training when this file is executed directly. Cells run in an
# IPython/Jupyter kernel also see __name__ == "__main__", so running this cell trains.
if __name__ == "__main__":
    train()  # uses train()'s built-in defaults

Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: harbor observes fake cloudy form forehead forehead 23 shaggy trees harbor sparkler flight snowboarders parking old stripe electric cape making collie enjoy into harbor corn bathroom fist bounces sponsored keep sticks runway dunks reaches seashore chatting various protest aid seashore rocky paddling deep picking edge hall vertical barn orange rests
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: seen keep skiing moment deep pumpkin closed rails observes fake cloudy paddling deep picking edge hall vertical barn orange rests parachute upturned gazes juice amusement hooded drops old cloth stripe electric cape grins cords flipping gym participate underground checkered wheeled playful sidecar gloved lunch barrier barrier main wood concerned jack
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: harbor sparkler flight snowboarders parking cloudy form forehead forehead 23 shaggy trees harbor 

 23%|██████████████████▍                                                            | 295/1265 [01:45<03:50,  4.20it/s]