In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [2]:
class CptioningNetwork(nn.Module):
    """Image-captioning network: a ResNet-18 encoder followed by an LSTM decoder.

    The encoder's final fully connected layer is replaced so that it emits an
    ``embed_size`` vector per image. That vector is fed to the LSTM as the first
    sequence element, followed by the embedded caption tokens.

    All sequence tensors are batch-major, i.e. ``(batch, seq, feature)``; the
    LSTM is therefore built with ``batch_first=True``.

    Args:
        embed_size: Size of the image feature and of each token embedding.
        hidden_size: Width of the LSTM output fed to the vocabulary projection.
            With ``bidir_lstm=True`` each direction gets ``hidden_size // 2``
            units.
        vocab_size: Number of tokens in the vocabulary.
        num_layers: Number of stacked LSTM layers.
        train_CNN: If True the ResNet backbone is fine-tuned; otherwise only its
            replaced ``fc`` layer receives gradients.
        bidir_lstm: If True a bidirectional LSTM is used.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, train_CNN = False, bidir_lstm = True):
        super(CptioningNetwork, self).__init__()
        self.train_CNN = train_CNN
        self.bidir_lstm = bidir_lstm

        self.resNet = models.resnet18(pretrained=True)
        self.resNet.fc = nn.Linear(self.resNet.fc.in_features, embed_size)

        # The new head is always trainable; the backbone only if requested.
        for name, param in self.resNet.named_parameters():
            if "fc.weight" in name or "fc.bias" in name:
                param.requires_grad = True
            else:
                param.requires_grad = self.train_CNN

        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)

        self.embed = nn.Embedding(vocab_size, embed_size)
        # batch_first=True: forward() concatenates along dim 1, so dim 0 is the
        # batch. Without it the LSTM would recur across the samples of a batch.
        if self.bidir_lstm:
            per_direction = hidden_size // 2
            self.lstm = nn.LSTM(
                embed_size, per_direction, num_layers,
                batch_first=True, bidirectional=True,
            )
            # Both directions are concatenated; for an odd hidden_size this is
            # hidden_size - 1, so size the projection from the real width.
            lstm_out_size = 2 * per_direction
        else:
            self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
            lstm_out_size = hidden_size

        self.linear = nn.Linear(lstm_out_size, vocab_size)

        self.dropout2 = nn.Dropout(0.4)

    def forward(self, images, captions):
        """Compute vocabulary logits for every position of a caption batch.

        Args:
            images: Batch of images accepted by the ResNet encoder.
            captions: Integer token ids, batch-major ``(batch, seq)``.

        Returns:
            Logits of shape ``(batch, seq + 1, vocab_size)``; position 0 is
            produced from the image feature alone.
        """
        features = self.resNet(images)
        features = self.relu(features)

        embeddings = self.embed(captions)
        # Prepend the image feature as the first element of each sequence.
        embeddings = torch.cat((self.dropout2(features.unsqueeze(1)), self.dropout1(embeddings)), dim=1)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)

        return outputs

    def caption_forward(self, image, max_length=24):
        """Greedily decode a caption for the first image of ``image``.

        The module is switched to eval mode for the duration of the call (and
        restored afterwards) and the whole pass runs without gradient tracking.

        Args:
            image: Image batch accepted by the ResNet encoder; only the first
                sample's predictions are collected.
            max_length: Number of decoding steps to run.

        Returns:
            A tuple ``(token_ids, logits)`` where ``token_ids`` is a list of
            ``max_length`` Python ints and ``logits`` is a CPU tensor of shape
            ``(max_length, vocab_size)`` holding the per-step scores.
        """
        token_ids = []
        step_logits = []

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # (batch, embed) -> (batch, 1, embed): a one-step sequence.
                x = self.relu(self.resNet(image)).unsqueeze(1)
                states = None
                for _ in range(max_length):
                    hiddens, states = self.lstm(x, states)
                    outputs = self.linear(hiddens)
                    step_logits.append(outputs[0, 0])
                    predicted_index = outputs.argmax(dim=2)
                    token_ids.append(predicted_index[0, 0].item())
                    # Feed the predicted token back in as the next input.
                    x = self.embed(predicted_index)
        finally:
            self.train(was_training)

        if not step_logits:
            return token_ids, torch.tensor([])
        # Stack once instead of converting scalar by scalar; moved to the CPU so
        # the result can be compared with CPU targets whatever the model device.
        return token_ids, torch.stack(step_logits).cpu()

In [3]:
def train(model, test_loader, train_loader, criterion, optimizer):
    """Train ``model`` on ``train_loader`` and return it.

    Besides its arguments the function reads these notebook-level names:
    ``num_epochs``, ``save_model``, ``device``, ``writer`` and
    ``save_checkpoint``. The running batch counter is the notebook-level
    ``step``; it is continued if already defined and started at 0 otherwise.

    Args:
        model: Module called as ``model(imgs, captions[:, :-1])``.
        test_loader: Unused; kept so existing calls keep working.
        train_loader: Iterable of ``(imgs, captions)`` batches with ``len()``.
        criterion: Loss called with flattened logits and flattened captions.
        optimizer: Optimizer stepped once per batch.

    Returns:
        The trained ``model``.
    """
    # `step` is assigned below, so it must be declared global; otherwise Python
    # treats it as an uninitialised local and the first read fails.
    global step
    step = globals().get("step", 0)

    model.train()

    for _ in range(num_epochs):

        # The checkpoint is written before the epoch's batches are processed.
        if save_model:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step,
            }
            save_checkpoint(checkpoint)

        for imgs, captions in tqdm(train_loader, total=len(train_loader), leave=False):
            imgs = imgs.to(device)
            captions = captions.to(device)

            # The model sees the caption without its last token.
            outputs = model(imgs, captions[:, :-1])

            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
            )

            writer.add_scalar("Training loss", loss.item(), global_step=step)
            step += 1

            optimizer.zero_grad()
            # No argument: passing the loss itself would scale every gradient
            # by the loss value.
            loss.backward()
            optimizer.step()
    return model

In [4]:
def caption_test(model, test_data, l = 10, log = False):
    """Return the mean cross-entropy of greedy captions on up to ``l`` samples.

    Samples are drawn one at a time, in order, from ``test_data``. For each
    one, ``model.caption_forward`` is asked for exactly as many steps as the
    reference caption has tokens, so logits and targets always line up.
    Positions whose target is ``<PAD>`` are ignored by the loss.

    The model is put in eval mode for the evaluation and its previous mode is
    restored afterwards; no gradients are tracked.

    Args:
        model: Object exposing ``caption_forward(image, max_length)`` plus the
            usual ``training`` / ``eval()`` / ``train(mode)`` module API.
        test_data: Dataset yielding ``(image, caption)`` pairs.
        l: Maximum number of samples to evaluate.
        log: If True, show each image and print the reference and the
            generated caption.

    Returns:
        Mean loss over the evaluated samples, or NaN if none were evaluated.
    """
    losses = []
    test_loader = DataLoader(dataset=test_data, batch_size=1, shuffle=False)
    criterion = nn.CrossEntropyLoss(ignore_index = dataUtils.rev_vocabulary["<PAD>"])

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for index, (image, target) in enumerate(test_loader):
                # Stop after exactly `l` samples.
                if index >= l:
                    break

                # Decode as many steps as the reference has tokens; a fixed
                # length would make the loss fail on any other caption length.
                cap, res = model.caption_forward(image, max_length=target.shape[1])

                if log:
                    imshow(image)
                    print('caption:', dataUtils.reverse_numericalize(np.array(target[0])))
                    print('caption:', dataUtils.reverse_numericalize(np.array(cap)))

                res = res.unsqueeze(0)
                loss = criterion(
                    res.reshape(-1, res.shape[2]), target.reshape(-1)
                )
                losses.append(loss.item())
    finally:
        model.train(was_training)

    if not losses:
        return np.nan
    return np.array(losses).mean()

In [5]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce the save on stdout, then serialize ``state`` to ``filename``.

    Args:
        state: Object handed to ``torch.save``.
        filename: Destination path for the serialized checkpoint.
    """
    announcement = "=> Saving checkpoint"
    print(announcement)
    torch.save(obj=state, f=filename)

In [6]:
def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer state from ``checkpoint`` and return its step.

    Args:
        checkpoint: Mapping with the keys ``"state_dict"``, ``"optimizer"`` and
            ``"step"``.
        model: Object whose ``load_state_dict`` receives ``checkpoint["state_dict"]``.
        optimizer: Object whose ``load_state_dict`` receives ``checkpoint["optimizer"]``.

    Returns:
        The value stored under ``checkpoint["step"]``.
    """
    print("=> Loading checkpoint")
    # Model first, optimizer second.
    for receiver, key in ((model, "state_dict"), (optimizer, "optimizer")):
        receiver.load_state_dict(checkpoint[key])
    return checkpoint["step"]