In [1]:
import torch
import torch.nn as nn
import torchvision

In [2]:
# Feature extractor - convolutional neural network that extracts embedding-sized features from an image;
# those features are then fed to the RNN that produces the caption for that image
class FeatureExtractorCNN(nn.Module):
    """Pretrained Inception v3 whose classifier head is replaced by an embedding projection.

    Args:
        embedding_size: Number of features produced per image (``out_features``
            of the replacement ``fc`` layer).
        continue_training: If True the pretrained backbone weights stay
            trainable; if False only the replacement ``fc`` head is trained.
    """

    def __init__(self, embedding_size, continue_training=False):
        super(FeatureExtractorCNN, self).__init__()
        self.continue_training = continue_training
        self.extractor = torchvision.models.inception_v3(pretrained=True)
        self.extractor.fc = nn.Linear(in_features=self.extractor.fc.in_features, out_features=embedding_size)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.25)
        # Apply the freeze policy up front so an optimizer built right after
        # construction already sees the correct requires_grad flags.
        self._apply_freeze_policy()

    def _apply_freeze_policy(self):
        """Set ``requires_grad`` on the backbone according to ``continue_training``."""
        # The attribute is `requires_grad`; assigning to a misspelled name only
        # creates an unused Python attribute and silently freezes nothing.
        for parameter in self.extractor.parameters():
            parameter.requires_grad = self.continue_training
        # The replacement head is always trained. Only the main `fc` is
        # selected; the auxiliary classifier does not feed the returned features.
        for parameter in self.extractor.fc.parameters():
            parameter.requires_grad = True

    def forward(self, x):
        """Return dropout(ReLU(backbone(x))) for a batch of images ``x``."""
        # Re-sync BEFORE the backbone runs so a change to `continue_training`
        # made after construction affects the graph built by this call.
        self._apply_freeze_policy()

        outputs = self.extractor(x)
        # In training mode torchvision's Inception v3 returns a
        # (logits, aux_logits) named tuple; only the main logits are used.
        if isinstance(outputs, tuple):
            outputs = outputs[0]

        return self.dropout(self.activation(outputs))

In [3]:
# CaptioningRNN - recurrent neural network that takes the features extracted from the input image and,
# through several LSTM (long short-term memory) layers, scores the words that could make up
# a correct caption for that input
class CaptioningRNN(nn.Module):
    """LSTM decoder that scores vocabulary words given image features and caption tokens.

    Args:
        embedding_size: Dimension of the word embeddings; the image feature
            vector must have the same size because it is prepended to them.
        hidden_size: Hidden state size of the LSTM.
        vocabulary_size: Number of tokens in the vocabulary.
        layer_number: Number of stacked LSTM layers.
    """

    def __init__(self, embedding_size, hidden_size, vocabulary_size, layer_number):
        super(CaptioningRNN, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=vocabulary_size, embedding_dim=embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, layer_number)
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocabulary_size)
        self.dropout = nn.Dropout(0.25)

    def forward(self, features, captions):
        """Return vocabulary scores for every step of the image-prefixed sequence.

        Args:
            features: Image features; they are unsqueezed on dim 0 and
                concatenated in front of the embedded captions, so they act as
                the first time step. Expected shape (batch, embedding_size).
            captions: Integer token indices laid out sequence-first,
                (seq_len, batch), since the LSTM is built without
                ``batch_first``.

        Returns:
            Tensor of shape (seq_len + 1, batch, vocabulary_size).
        """
        # The layer is registered as `self.embedding` in __init__.
        embeddings = self.embedding(captions)
        embeddings = self.dropout(embeddings)
        embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)

        hidden_output, _ = self.lstm(embeddings)
        output = self.fc(hidden_output)
        return output

In [7]:
# ImageCaptioningNet - neural network that combines the two networks above (CNN and RNN)
# to generate descriptions for images
class ImageCaptioningNet(nn.Module):
    """End-to-end captioning model: CNN feature extractor followed by an LSTM decoder.

    Args:
        embedding_size: Size shared by the image features and word embeddings.
        hidden_size: Hidden state size of the decoder LSTM.
        vocabulary_size: Number of tokens in the vocabulary.
        layer_number: Number of stacked LSTM layers in the decoder.
        continue_training: Forwarded to ``FeatureExtractorCNN``.
    """

    def __init__(self, embedding_size, hidden_size, vocabulary_size, layer_number, continue_training=False):
        super(ImageCaptioningNet, self).__init__()
        self.feature_extractor_cnn = FeatureExtractorCNN(embedding_size, continue_training)
        self.captioning_rnn = CaptioningRNN(embedding_size, hidden_size, vocabulary_size, layer_number)

    def forward(self, images, captions):
        """Extract features from ``images`` and return the decoder's output for ``captions``."""
        features = self.feature_extractor_cnn(images)
        output = self.captioning_rnn(features, captions)
        return output

    def generate_caption(self, image, vocabulary, max_length=50):
        """Greedily decode a caption for a single image.

        Both sub-networks contain dropout layers, so call ``model.eval()``
        before using this method to get deterministic output.

        Args:
            image: Batch containing exactly one image (``.item()`` is called on
                the predicted index). Assumed shape (1, C, H, W) - TODO confirm
                against the data pipeline.
            vocabulary: Object exposing ``itos``, an index -> token lookup.
            max_length: Maximum number of tokens to generate.

        Returns:
            List of token strings; ends with ``"<EOS>"`` when the model emits
            it before ``max_length`` is reached.
        """
        caption = []
        with torch.no_grad():
            # The image features act as the first LSTM input: (1, batch, embedding).
            out = self.feature_extractor_cnn(image).unsqueeze(0)
            states = None

            for _ in range(max_length):
                hidden_output, states = self.captioning_rnn.lstm(out, states)
                # The decoder's output layer is named `fc`, not `linear`.
                output = self.captioning_rnn.fc(hidden_output.squeeze(0))

                predicted_word = output.argmax(dim=1)
                word_idx = predicted_word.item()
                caption.append(word_idx)

                if vocabulary.itos[word_idx] == "<EOS>":
                    break

                # Feed the predicted token back in as the next input.
                out = self.captioning_rnn.embedding(predicted_word).unsqueeze(0)

        return [vocabulary.itos[word_idx] for word_idx in caption]