## Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as m
from torch.utils.data import DataLoader
from transforms import *
from torchvision.transforms import Compose
from torchsummary import summary
from random_caption_dataset import RandomCaptionDataset
from text_preprocessing import *
from inference import *
import matplotlib.pyplot as plt

# Device configuration: use the GPU when one is available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Report the library version and whether a GPU was picked up
print(f'PyTorch version: {torch.__version__}')
print("GPU found :)" if device.type == 'cuda' else "No GPU :(")

PyTorch version: 1.6.0
No GPU :(


In [3]:
# Side length of the square crop taken from each image before it is fed to the CNN
IMAGE_SIZE = 224
# Word-embedding dimension; also used as the LSTM input size and as the width
# the CNN features are reshaped to
EMBEDDING_SIZE = 512
# Context size passed to the NGram model whose embedding layer is reused
CONTEXT_SIZE = 4
# Annotation CSV files for the train and test splits
train_annotations_file = './flickr8k/annotations/annotations_image_id_train.csv'
test_annotations_file = './flickr8k/annotations/annotations_image_id_test.csv'

## Data section

In [4]:
# Text preprocessor built from the training annotations (';' is passed as the separator).
# It provides the vocabulary size and the maximum caption length used further down.
tp = TextPreprocessor(train_annotations_file, sep=';')

In [5]:
# Image-side steps (the same four steps the test pipeline uses)
image_steps = [
    Rescale(256),
    RandomCrop(IMAGE_SIZE),
    ToTensor(),
    Normalize(),
]

# Caption-side steps: add delimiters, pad, then one-hot encode using the preprocessor
caption_steps = [
    AddDelimiters(),
    PadSentence(tp),
    OneHotEncode(tp),
]

transforms = Compose(image_steps + caption_steps)

train_random_dataset = RandomCaptionDataset(
    './flickr8k/images/train/',
    train_annotations_file,
    transform=transforms,
)

print(f'Random dataset size: {len(train_random_dataset)}')

Random dataset size: 6000


In [6]:
# Inspect one transformed caption: its encoding, the maximum caption length and its shape
caption = train_random_dataset[100]['caption']
print(caption, tp.max_len, caption.shape, sep='\n')

[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
40
(40, 8255)


In [None]:
# The training code below handles one image/caption pair per step
batch_size = 1

# Loader over the training dataset
train_random_loader = DataLoader(dataset=train_random_dataset, batch_size=batch_size)

## Model section

In [None]:
# Retrieve a pretrained ResNet-18 from torchvision to use for feature extraction
base_cnn = m.resnet18(pretrained=True)
#base_cnn

In [None]:
# Drop the last child module of the ResNet so the network outputs features
# rather than class scores, then move it to the selected device as float32
feature_layers = list(base_cnn.children())[:-1]
cnn = nn.Sequential(*feature_layers).to(device, dtype=torch.float)
#summary(cnn, (3, IMAGE_SIZE, IMAGE_SIZE))

## Build LSTM + Embedding

In [None]:
# Vocabulary size reported by the text preprocessor; used as the output size
# of the captioning LSTM and as the vocabulary size of the NGram model
vocab_size = tp.vocab_size
print(vocab_size)

# Multi-layer LSTM decoder that turns embeddings into log-probabilities over the vocabulary
class LSTMCaptioning(nn.Module):
    """LSTM followed by a linear projection and a log-softmax.

    Args:
        input_size: size of each input embedding.
        hidden_size: size of the LSTM hidden and cell states.
        output_size: number of output classes (the vocabulary size).
        num_layers: number of stacked LSTM layers (dropout 0.5 between layers).

    The zero initial state is stored in the ``h0`` / ``c0`` buffers with a
    batch dimension of 1 and is expanded to the batch size of the input, so
    the module no longer depends on notebook-level ``batch_size`` / ``device``
    variables. The buffers are non-persistent: they follow ``.to(...)`` but are
    not part of ``state_dict()``, so existing checkpoints still load.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(LSTMCaptioning, self).__init__()

        # Zero-initialised LSTM state, shape (num_layers, 1, hidden_size)
        self.register_buffer('h0', torch.zeros(num_layers, 1, hidden_size), persistent=False)
        self.register_buffer('c0', torch.zeros(num_layers, 1, hidden_size), persistent=False)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, dropout=0.5)
        self.hidden2out = nn.Linear(hidden_size, output_size)

    def forward(self, x, previous_state=None):
        """Run the LSTM on ``x`` and score every vocabulary entry.

        Args:
            x: input of shape (seq_len, batch, input_size).
            previous_state: optional ``(h, c)`` tuple from a previous call;
                when omitted the zero initial state is used.

        Returns:
            ``((hn, cn), pt)`` where ``(hn, cn)`` is the final LSTM state and
            ``pt`` holds log-probabilities of shape (seq_len * batch, output_size).
        """
        if previous_state is None:
            batch = x.size(1)
            # Broadcast the stored single-sample zero state to the actual batch size
            previous_state = (
                self.h0.expand(-1, batch, -1).contiguous(),
                self.c0.expand(-1, batch, -1).contiguous(),
            )

        # Hidden states for every timestep (lstm_out) and the final state (hn, cn)
        lstm_out, (hn, cn) = self.lstm(x, previous_state)

        # One row per (timestep, batch element); a single row for a (1, 1, input_size) input
        out = self.hidden2out(lstm_out.reshape(-1, lstm_out.size(-1)))

        # Log-probability distribution over all words for each row
        pt = F.log_softmax(out, dim=1)

        return (hn, cn), pt

# Local copy of the NGram class, required to load the trained NGram weights
class NGram(nn.Module):
    """N-gram language model: embedding, one hidden layer, log-softmax over the vocabulary."""

    def __init__(self, vocab_size, embedding_dim, context_size):
        super().__init__()

        # The embeddings of all context words are concatenated before the first linear layer
        flattened_dim = context_size * embedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(flattened_dim, 128)
        self.linear2 = nn.Linear(128, vocab_size)

    def forward(self, inputs):
        """Return log-probabilities over the vocabulary, one row per entry of ``inputs``."""
        flat = self.embeddings(inputs).view(len(inputs), -1)
        hidden = torch.relu(self.linear1(flat))
        logits = self.linear2(hidden)
        return logits.log_softmax(dim=1)


# Not trained embedding layer to encode words to hidden space
#embedding = nn.Embedding(vocab_size, EMBEDDING_SIZE).to(device)

# Load the trained NGram model so its embedding layer can be reused.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved from a GPU still loads on a CPU-only machine.
ngram_model = NGram(vocab_size, EMBEDDING_SIZE, CONTEXT_SIZE)
ngram_model.load_state_dict(torch.load('./models/ngram_512_v1', map_location=device))

# Take the embedding layer by name instead of by its position among the children
embedding = ngram_model.embeddings.to(device)
embedding

## Train model

In [None]:
# LSTM hyper-parameters: inputs are word/image embeddings of EMBEDDING_SIZE
input_size = EMBEDDING_SIZE
hidden_size = 512
num_layers = 3

# Captioning model predicting one vocabulary entry per step
model = LSTMCaptioning(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=vocab_size,
    num_layers=num_layers,
)
model = model.to(device, dtype=torch.float)

In [None]:
# Training loop, kept disabled inside a string literal so that running the
# notebook does not retrain the model; remove the surrounding ''' to enable it.
# For every sample it feeds the CNN image feature to the LSTM first, then the
# caption's words one at a time, summing the NLL loss of each next-word
# prediction before a single backward pass and optimizer step.
# NOTE(review): `epoch_loss += loss` accumulates the loss tensor itself rather
# than `loss.item()`, so `losses` ends up holding tensors - confirm intended.
# NOTE(review): `int(step_count/20)` is 0 when the loader has fewer than 20
# batches, which would make the modulo raise ZeroDivisionError.
'''
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

num_epoch = 5
step_count = len(train_random_loader)
loss_function = nn.NLLLoss()
losses = list()


for epoch in range(num_epoch):
    
    epoch_loss = 0
    
    for i, sample in enumerate(train_random_loader):
        
        image = sample['image'].to(device, dtype=torch.float)
        caption = sample['caption'].to(device, dtype=torch.long)
        
        # Reset grad
        model.zero_grad()
        
        # Get the input image embedding 
        image_embedding = cnn(image).view(-1, batch_size, EMBEDDING_SIZE)
        
        
        # Forward pass for t=-1: image
        (hn, cn), probs = model(image_embedding)
        
        del image_embedding
        del image
        
        target = tp.target_from_vect(caption[:, 0]).to(device)
        
        # Compute loss for 1st word prediction
        loss = loss_function(probs, target)

        
        # Forward pass for t>=0: n - 1 first words of the sentence
        for j, word in enumerate(caption[:, :-1][0]):

            # Get index of the word in embedding matrix
            idxs = torch.argmax(word)
            
            # Encode word to hidden space
            word_embedding = embedding(idxs).view(1, batch_size, EMBEDDING_SIZE)
            
            # Feed the rnn
            (hn, cn), probs = model(word_embedding, (hn, cn))
            
            target = tp.target_from_vect(caption[:, j+1]).to(device)
            
            # Add current word's loss
            loss += loss_function(probs, target)
            
        epoch_loss += loss
        
        # Compute loss and backprop
        loss.backward()
        optimizer.step()
          
        # Debug
        if((i+1) % int(step_count/20) == 0):
            print(
                        f"Epoch [{epoch + 1}/{num_epoch}]"
                        f", step [{i + 1}/{step_count}]"
                        f", loss: {loss.item():.4f}"
                    )
            
    losses.append(epoch_loss)
    
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            }, './models/model_random_init0_lstm3_checkpoint')
'''

In [None]:
# Save model 
#torch.save(model.state_dict(), './models/model_random_init0_lstm3')

In [None]:
# Disabled (string literal) plotting cell: average training loss per epoch.
# It relies on `num_epoch` and `losses` from the disabled training cell.
# NOTE(review): the divisor 6000 is hard-coded; it matches the dataset size
# printed earlier ("Random dataset size: 6000").
# NOTE(review): the output file name mentions lr=0.1 while the training cell
# sets learning_rate = 0.001 - confirm which run this figure belongs to.
'''
epochs = list(range(1, num_epoch+1))
mean_losses = list(map(lambda x: x/6000, losses))
plt.plot(epochs, mean_losses)
plt.xlabel('Epoch')
plt.ylabel('Average loss')
plt.grid()
plt.savefig('Loss model random v2 training with lr=0.1.png')
plt.show()
'''

## Test model performance

In [None]:
# Load the trained captioning model for evaluation.
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
trained_model = LSTMCaptioning(input_size, hidden_size, vocab_size, num_layers)
trained_model.load_state_dict(torch.load('./models/model_random_init0_lstm3', map_location=device))
trained_model.to(device)
# Switch to evaluation mode so the LSTM's dropout is disabled while captioning
trained_model.eval()

In [None]:
# Create test loaders for datasets

# Test-time pipeline: image preprocessing only, captions are left untouched
test_transforms = Compose([
    Rescale(256),
    RandomCrop(IMAGE_SIZE),
    ToTensor(),
    Normalize(),
])

# Datasets without any transform
test_random_dataset_unformatted = RandomCaptionDataset(
    './flickr8k/images/test', test_annotations_file
)
train_random_dataset_unformatted = RandomCaptionDataset(
    './flickr8k/images/train', train_annotations_file
)

# Test dataset with preprocessed images, and its loader
test_random_dataset = RandomCaptionDataset(
    './flickr8k/images/test', test_annotations_file, transform=test_transforms
)
test_random_loader = DataLoader(dataset=test_random_dataset, batch_size=batch_size)

In [None]:
# Test sample 300: the image as a single-item float batch on the device, plus its reference caption
image = test_random_dataset[300]['image']
image = image.view(1, 3, IMAGE_SIZE, IMAGE_SIZE).to(device, dtype=torch.float)
caption = test_random_dataset[300]['caption']

# Candidate captions produced by beam search for that image
captions = beam_search(cnn, embedding, trained_model, image, tp, EMBEDDING_SIZE, hidden_size)

In [None]:
# Display the candidate captions returned by beam_search
captions

In [None]:
from skimage import io
# Show the untransformed image of the same test sample (index 300) that was
# captioned above, so the picture matches the printed captions
image = test_random_dataset_unformatted[300]['image']
io.imshow(image)

In [None]:
import sacrebleu

# Reference caption and the candidate at index 10 of the beam-search results
target, pred = caption, captions[10]
print(caption, pred, sep='\n')

# Sentence-level BLEU of the candidate against the single reference
bleu = sacrebleu.sentence_bleu(pred, [target], smooth_method='exp')
print(bleu.score)
