## Imports

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as m
from torch.utils.data import DataLoader
from transforms import *
from torchvision.transforms import Compose
from torchsummary import summary
from repeat_image_dataset import RepeatImageDataset
from text_preprocessing import *

# Report the runtime environment, querying CUDA availability a single time.
cuda_available = torch.cuda.is_available()

print(f'PyTorch version: {torch.__version__}')
if cuda_available:
    print("GPU found :)")
else:
    print("No GPU :(")

# Device configuration: prefer the GPU when one is visible, otherwise use the CPU.
# (Assign torch.device('cpu') here instead to force CPU execution.)
device = torch.device('cuda') if cuda_available else torch.device('cpu')

PyTorch version: 1.7.1
GPU found :)


In [2]:
# Crop size passed to RandomCrop in the image preprocessing pipelines.
IMAGE_SIZE = 224
# Width of the word embeddings; the CNN image features are reshaped to this width too.
EMBEDDING_SIZE = 512
# Context size used to construct the pretrained NGram model.
CONTEXT_SIZE = 4
# Annotation files (train / test) under the flickr8k directory.
train_annotations_file = './flickr8k/annotations/annotations_image_id_train.csv'
test_annotations_file = './flickr8k/annotations/annotations_image_id_test.csv'

## Data section

In [3]:
# Initialise the text preprocessor from the training annotations file.
# `tp` is reused below for the vocabulary size, the one-hot caption encoding
# and for decoding predictions back to words.
# NOTE(review): sep=';' is presumably the field delimiter of the CSV — confirm.
tp = TextPreprocessor(train_annotations_file, sep=';')

In [4]:
# Preprocessing steps applied, in order, to every training sample.
train_pipeline_steps = [
    Rescale(256),
    RandomCrop(IMAGE_SIZE),
    ToTensor(),
    Normalize(),
    OneHotEncode(tp),
]
transforms = Compose(train_pipeline_steps)

# Training dataset built from the train image folder and its annotation file.
train_repeat_dataset = RepeatImageDataset(
    './flickr8k/images/train/',
    train_annotations_file,
    transform=transforms,
)

print(f'Repeat dataset size: {len(train_repeat_dataset)}')

Repeat dataset size: 30000


In [5]:
# One sample per batch; the tensors fed to the captioning model below are
# reshaped using this value.
batch_size = 1

# Data loader over the training dataset
train_repeat_loader = DataLoader(dataset=train_repeat_dataset, batch_size=batch_size)

## Model section

In [6]:
# Retrieve a ResNet-18 with pretrained weights; it is used below for feature extraction.
base_cnn = m.resnet18(pretrained=True)
# (evaluate `base_cnn` in a cell to inspect the architecture)
#base_cnn

In [7]:
# Keep only the feature extraction layers of the model (every child module
# except the last one) and move them to the target device.
#
# The extractor is used as a fixed image encoder: its parameters are never
# handed to an optimizer. It is therefore frozen and switched to eval mode so
# that BatchNorm layers use their pretrained running statistics instead of
# updating them on every forward pass, and so that autograd does not track the
# whole ResNet. The calls are chained so the cell does not display the module.
cnn = (
    nn.Sequential(*list(base_cnn.children())[:-1])
    .to(device, dtype=torch.float)
    .requires_grad_(False)
    .eval()
)
#summary(cnn, (3, IMAGE_SIZE, IMAGE_SIZE))

## Build LSTM + Embedding

In [8]:
# Vocabulary size reported by the text preprocessor; it sizes the NGram model
# and the output layer of the captioning model defined below.
vocab_size = tp.vocab_size
print(vocab_size)

# RNN made of a single LSTM layer followed by a linear projection
class LSTMCaptioning(nn.Module):
    """LSTM decoder that scores the next word from the current input vector.

    Args:
        input_size: Size of each input vector fed to the LSTM.
        hidden_size: Size of the LSTM hidden and cell states.
        output_size: Number of output classes (one score per word).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMCaptioning, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.hidden2out = nn.Linear(hidden_size, output_size)

    def forward(self, x, previous_state):
        """Run the LSTM on `x` and score the next word.

        Args:
            x: Tensor of shape (seq_len, batch, input_size), the default
                (non batch-first) layout of nn.LSTM.
            previous_state: Tuple (h, c), each of shape (1, batch, hidden_size).

        Returns:
            A tuple ((hn, cn), pt): the updated LSTM state and the
            log-probabilities over the output classes, of shape
            (batch, output_size), computed from the last time step.
        """
        # Hidden states for every time step (lstm_out) and the latest state (hn, cn)
        lstm_out, (hn, cn) = self.lstm(x, previous_state)

        # Project the last time step to the output size. Indexing the time axis
        # (instead of flattening everything into a single row) gives the same
        # result for a single-step, single-sample input and also works for
        # batches of more than one sample.
        out = self.hidden2out(lstm_out[-1])

        # Log-probability distribution over all words for this step
        pt = F.log_softmax(out, dim=1)

        return (hn, cn), pt

# The class definition has to be available here so the trained NGram weights can be loaded
class NGram(nn.Module):
    """Feed-forward n-gram model: embedding -> hidden layer -> vocabulary log-probabilities."""

    def __init__(self, vocab_size, embedding_dim, context_size):
        super().__init__()

        hidden_units = 128
        flat_context_dim = context_size * embedding_dim

        # The attribute names below become the state-dict keys; keep them stable.
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(flat_context_dim, hidden_units)
        self.linear2 = nn.Linear(hidden_units, vocab_size)

    def forward(self, inputs):
        """Return log-probabilities over the vocabulary, one row per entry of `inputs`."""
        n_rows = len(inputs)

        # Embed every index, then flatten each row into a single vector
        flat = self.embeddings(inputs).view(n_rows, -1)
        hidden = F.relu(self.linear1(flat))
        scores = self.linear2(hidden)

        return F.log_softmax(scores, dim=1)

# Rebuild the NGram architecture and restore its trained weights.
ngram_model = NGram(vocab_size, EMBEDDING_SIZE, CONTEXT_SIZE)
# map_location='cpu' lets a file saved from a GPU run load on a CPU-only
# machine; the freshly built model lives on the CPU at this point anyway.
ngram_model.load_state_dict(torch.load('./models/ngram_512_v1', map_location='cpu'))

# Only the embedding layer is reused by the captioning model. It is fetched by
# attribute name rather than by child position, and frozen because no optimizer
# in this notebook updates it.
embedding = ngram_model.embeddings.to(device).requires_grad_(False)
embedding

8255


Embedding(8255, 512)

## Train model

In [9]:
# LSTM dimensions: each input is one embedding-sized vector; the recurrent
# state has 256 units.
input_size = EMBEDDING_SIZE
hidden_size = 256

# Captioning model with one output score per vocabulary word.
model = LSTMCaptioning(input_size, hidden_size, vocab_size)
model = model.to(device, dtype=torch.float)

In [None]:
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

num_epoch = 10
step_count = len(train_repeat_loader)
loss_function = nn.NLLLoss()

# Report progress about five times per epoch. max(1, ...) keeps the modulo
# below valid when the loader holds fewer than five batches.
log_every = max(1, step_count // 5)

# Random init of the LSTM state: sampled once here and reused as the starting
# state of every training sample.
h0 = torch.rand((1, batch_size, hidden_size)).to(device, dtype=torch.float)
c0 = torch.rand((1, batch_size, hidden_size)).to(device, dtype=torch.float)

model.train()

for epoch in range(num_epoch):
    for i, sample in enumerate(train_repeat_loader):

        image = sample['image'].to(device, dtype=torch.float)
        caption = sample['caption'].to(device, dtype=torch.long)

        # Reset the gradients of the parameters being optimized
        optimizer.zero_grad()

        # Get the input image embedding. Only `model` is optimized, so the CNN
        # forward pass runs without autograd tracking: no graph is kept for it
        # and no gradients accumulate on its parameters.
        with torch.no_grad():
            image_embedding = cnn(image).view(-1, batch_size, EMBEDDING_SIZE)

        # Forward pass for t=-1: the image is the first input of the sequence
        (hn, cn), probs = model(image_embedding, (h0, c0))

        target = tp.target_from_vect(caption[:, 0]).to(device)

        # Loss for the 1st word prediction
        loss = loss_function(probs, target)

        # Forward pass for t>=0: feed the n - 1 first words of the sentence.
        # Only the first caption of the batch is iterated, so this loop relies
        # on batch_size == 1.
        for j, word in enumerate(caption[:, :-1][0]):

            # Index of the word in the embedding matrix
            idxs = torch.argmax(word)

            # Encode the word; the embedding is not optimized either, so it is
            # evaluated without autograd tracking.
            with torch.no_grad():
                word_embedding = embedding(idxs).view(1, batch_size, EMBEDDING_SIZE)

            # Feed the rnn
            (hn, cn), probs = model(word_embedding, (hn, cn))

            target = tp.target_from_vect(caption[:, j+1]).to(device)

            # Add the current word's loss
            loss = loss + loss_function(probs, target)

        # Backprop the summed sequence loss and update the model
        loss.backward()
        optimizer.step()

        # Progress report
        if (i + 1) % log_every == 0:
            print(
                f"Epoch [{epoch + 1}/{num_epoch}]"
                f", step [{i + 1}/{step_count}]"
                f", loss: {loss.item():.4f}"
            )
        

In [14]:
# Bundle the training state into a single checkpoint dictionary and write it to disk.
checkpoint = dict(
    epoch=epoch,
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    loss=loss,
)
torch.save(checkpoint, './models/model_v1_repeat')


## Test model performances

In [None]:
# Load model for evaluation
trained_model = LSTMCaptioning(input_size, hidden_size, vocab_size)

# The file written by torch.save above is a checkpoint dictionary (epoch, model
# state, optimizer state, loss), not a bare state dict: the model weights live
# under 'model_state_dict'. map_location lets a checkpoint saved from a GPU run
# be restored on whatever device is available now.
checkpoint = torch.load('./models/model_v1_repeat', map_location=device)
trained_model.load_state_dict(checkpoint['model_state_dict'])

# Move to the target device and switch to inference mode
trained_model = trained_model.to(device).eval()

In [None]:
# Create the test dataset and its loader

# Image-only preprocessing: same steps as the training pipeline minus OneHotEncode.
# NOTE(review): RandomCrop makes the evaluated image region vary between runs;
# a deterministic crop may be preferable if the transforms module offers one — confirm.
test_pipeline_steps = [
    Rescale(256),
    RandomCrop(IMAGE_SIZE),
    ToTensor(),
    Normalize(),
]
test_transforms = Compose(test_pipeline_steps)

test_repeat_dataset = RepeatImageDataset(
    './flickr8k/images/test/',
    test_annotations_file,
    transform=test_transforms,
)

test_repeat_loader = DataLoader(dataset=test_repeat_dataset, batch_size=batch_size)

In [None]:
# Upper bound on the number of words generated after the first one, so that
# decoding always terminates even if '<stop>' is never predicted.
MAX_FOLLOWING_WORDS = 20

# Decode with the model restored from the checkpoint, in inference mode.
trained_model.eval()

with torch.no_grad():

    for sample in test_repeat_loader:

        caption_words = []

        # Random init the lstm state
        # NOTE(review): this draws a fresh random state, which is not the state
        # sampled once in the training cell — confirm this is intended.
        h0 = torch.rand((1, batch_size, hidden_size)).to(device, dtype=torch.float)
        c0 = torch.rand((1, batch_size, hidden_size)).to(device, dtype=torch.float)

        # Encode input image
        image = sample['image'].to(device, dtype=torch.float)
        image_embedding = cnn(image).view(-1, batch_size, EMBEDDING_SIZE)

        # The image embedding is the first input; afterwards each predicted
        # word is embedded and fed back in (greedy decoding).
        step_input = image_embedding
        hn, cn = h0, c0

        for _ in range(MAX_FOLLOWING_WORDS + 1):

            (hn, cn), probs = trained_model(step_input, (hn, cn))

            # Extract the most probable word
            pred_idx = torch.argmax(probs)
            pred_word_vect = tp.encoding_matrix[pred_idx]
            predicted_word = tp.vect_to_word(pred_word_vect)

            caption_words.append(predicted_word)

            print(predicted_word)

            # Stop as soon as the model emits the stop word
            if predicted_word == '<stop>':
                break

            step_input = embedding(pred_idx).view(1, batch_size, EMBEDDING_SIZE)

        caption = " ".join(caption_words)

        print(caption)

        # Only the first test sample is captioned
        break

In [9]:
# Scratch check of torch.topk: the three largest values of each column
# (dim=0) of a random 10x3 tensor, together with their row indices.
test = torch.randn(10, 3)
print(test)

print(test.topk(k=3, dim=0))

tensor([[-0.0330, -0.8076,  0.4611],
        [-1.4472,  0.5975, -0.1127],
        [-0.7613, -0.2993,  0.7059],
        [-0.5506,  0.8039, -1.2187],
        [ 0.9716,  0.3158, -0.1192],
        [ 0.1314,  2.0221, -0.5677],
        [-1.1597,  1.0611, -0.8042],
        [-0.6171,  1.3323, -1.8513],
        [-0.5031, -1.8007, -1.6748],
        [-2.0235, -0.1504,  0.9782]])
torch.return_types.topk(
values=tensor([[ 0.9716,  2.0221,  0.9782],
        [ 0.1314,  1.3323,  0.7059],
        [-0.0330,  1.0611,  0.4611]]),
indices=tensor([[4, 5, 9],
        [5, 7, 2],
        [0, 6, 0]]))
