In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from get_loader import get_loader
from utils import make_prediction, check_accuracy, save_checkpoint, load_checkpoint
from image_caption import CNNtoRNN
from torchvision import models

Arguments passed to `get_loader`:

- __root_folder__: directory where the images are stored.
- __annotation_file__: CSV file that maps each image filename to its corresponding caption.

In [2]:
def _select_device(preferred="cuda:3"):
    """Pick the training device, preferring ``preferred`` when it is usable.

    The notebook originally hard-coded ``cuda:3``; moving the model there raises
    on machines with fewer than four GPUs or without CUDA. This keeps the
    preferred device when it exists and otherwise falls back to the default
    CUDA device, then to the CPU.
    """
    device = torch.device(preferred)
    if device.type != "cuda":
        return device
    if not torch.cuda.is_available():
        return torch.device("cpu")
    index = 0 if device.index is None else device.index
    return device if index < torch.cuda.device_count() else torch.device("cuda")


def _build_transform():
    """Return the training-time image preprocessing pipeline.

    Resizes to 356x356, takes a random 299x299 crop (light augmentation),
    converts to a tensor and normalizes each channel with mean 0.5 / std 0.5,
    which maps ToTensor's [0, 1] range to [-1, 1].
    """
    return transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )


def _train_one_epoch(model, loader, criterion, optimizer, device, writer, step):
    """Run one optimization pass over ``loader`` and return the updated global step."""
    for imgs, captions in loader:
        imgs = imgs.to(device)
        captions = captions.to(device)

        # NOTE(review): assumes captions are laid out (seq_len, batch) and that the
        # model prepends the image feature as the first time step, so feeding
        # captions[:-1] yields seq_len outputs that line up with `captions` —
        # confirm against image_caption.CNNtoRNN.
        outputs = model(imgs, captions[:-1])
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

        writer.add_scalar("Training loss : ", loss.item(), global_step=step)
        step += 1

        optimizer.zero_grad()
        # The loss is a scalar, so backward() needs no argument. Passing `loss`
        # as the `gradient` argument (as before) scaled every gradient by the
        # current loss value.
        loss.backward()
        optimizer.step()
    return step


def train(
    epochs=100,
    learning_rate=3e-4,
    load_model=False,
    save_model=True,
    checkpoint_path="mycheckpoint.pyh.tar",
    device=None,
    root_folder="flickr30k_images/flickr30k_images/",
    annotation_file="flickr30k_images/results.csv",
):
    """Train the CNN-to-RNN captioning model on the Flickr30k captions.

    All arguments are optional; the defaults reproduce the original notebook
    configuration.

    Args:
        epochs: Number of passes over the training loader.
        learning_rate: Adam learning rate.
        load_model: If True, resume model/optimizer/step from ``checkpoint_path``.
        save_model: If True, write a checkpoint to ``checkpoint_path`` after
            every epoch (overwriting the previous one).
        checkpoint_path: File used for both loading and saving checkpoints.
            The default keeps the original notebook's load path.
        device: Device string or ``torch.device``. When None, ``cuda:3`` is
            used if present, else the default CUDA device, else the CPU.
        root_folder: Directory holding the images, passed to ``get_loader``.
        annotation_file: Caption CSV, passed to ``get_loader``.

    Side effects:
        Writes TensorBoard events under ``runs/flickr30`` and, when
        ``save_model`` is set, a checkpoint file at ``checkpoint_path``.
    """
    transform = _build_transform()
    train_loader, dataset = get_loader(
        root_folder=root_folder,
        annotation_file=annotation_file,
        transform=transform,
    )

    torch.backends.cudnn.benchmark = True
    device = _select_device() if device is None else torch.device(device)

    # Model hyperparameters
    embed_size = 256
    hidden_size = 256
    vocab_size = len(dataset.vocab)
    num_layers = 1

    # for tensorboard
    writer = SummaryWriter('runs/flickr30')
    step = 0

    # Close the writer even if training is interrupted, so buffered events are flushed.
    try:
        model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
        # Padding targets are excluded from the loss via ignore_index.
        criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        if load_model:
            # map_location lets a checkpoint written on one device load on another.
            checkpoint = torch.load(checkpoint_path, map_location=device)
            step = load_checkpoint(checkpoint, model, optimizer)

        model.train()

        for _ in range(epochs):
            step = _train_one_epoch(
                model, train_loader, criterion, optimizer, device, writer, step
            )

            if save_model:
                # Saved after the epoch so the file holds the freshly trained
                # weights; uses the same path that load_model reads from.
                checkpoint = {
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step": step,
                }
                torch.save(checkpoint, checkpoint_path)
    finally:
        writer.close()
            

In [None]:
# Launch training with train()'s default configuration (long-running: many epochs over Flickr30k).
train()