In [2]:
import torch
import torch.nn as nn
import torchvision
import os
import pandas as pd
import spacy
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# English spaCy pipeline; only its tokenizer is needed to split captions into tokens.
try:
    spacy_eng = spacy.load('en')
except OSError:
    # spaCy v3 removed shortcut names such as 'en' and raises OSError for them;
    # fall back to the full package name of the small English pipeline.
    spacy_eng = spacy.load('en_core_web_sm')

In [3]:
# Feature extractor: a convolutional network that turns an image into a feature
# vector of the embedding size; the RNN then uses that vector to produce the
# caption for the image.
class FeatureExtractorCNN(nn.Module):
    """Pretrained Inception v3 whose classifier head is replaced by an embedding projection.

    Args:
        embedding_size: Output size of the replacement ``fc`` layer.
        continue_training: When truthy, the pretrained backbone parameters stay
            trainable; otherwise only the new ``fc`` parameters require gradients.
    """

    def __init__(self, embedding_size, continue_training=False):
        super(FeatureExtractorCNN, self).__init__()
        self.continue_training = continue_training
        self.extractor = torchvision.models.inception_v3(pretrained=True)
        self.extractor.fc = nn.Linear(in_features=self.extractor.fc.in_features, out_features=embedding_size)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.25)
        # Decide once, before any optimizer step, which parameters are trainable.
        self._set_trainable_layers()

    def _set_trainable_layers(self):
        """Apply ``continue_training`` to the backbone's ``requires_grad`` flags.

        Parameters whose name contains ``fc.weight`` or ``fc.bias`` are always
        trainable. Call this again after changing ``continue_training``.
        """
        train_backbone = bool(self.continue_training)
        for name, parameters in self.extractor.named_parameters():
            is_head = "fc.weight" in name or "fc.bias" in name
            parameters.requires_grad = is_head or train_backbone

    def forward(self, x):
        """Return the activated, dropout-regularised image features for ``x``."""
        output = self.extractor(x)
        # In training mode Inception v3 returns a namedtuple with a `.logits`
        # field; in eval mode it returns the logits tensor directly.
        features = getattr(output, "logits", output)
        return self.dropout(self.activation(features))

In [4]:
# CaptioningRNN: recurrent network that takes the features obtained from the
# input image and, through one or more LSTM (long short-term memory) layers,
# scores the vocabulary words that could make up a correct caption for it.
class CaptioningRNN(nn.Module):
    """LSTM decoder that scores vocabulary tokens given image features and caption tokens.

    Args:
        embedding_size: Size of the token embeddings (and of the LSTM input).
        hidden_size: Size of the LSTM hidden state.
        vocabulary_size: Number of tokens in the vocabulary.
        layer_number: Number of stacked LSTM layers.
    """

    def __init__(self, embedding_size, hidden_size, vocabulary_size, layer_number):
        super().__init__()
        # Construction order is kept so seeded weight initialisation is unchanged.
        self.embedding = nn.Embedding(vocabulary_size, embedding_size)
        self.lstm = nn.LSTM(input_size=embedding_size, hidden_size=hidden_size, num_layers=layer_number)
        self.fc = nn.Linear(hidden_size, vocabulary_size)
        self.dropout = nn.Dropout(p=0.25)

    def forward(self, features, captions):
        """Return vocabulary scores for every step of the sequence.

        The image features are prepended along dimension 0 as the first step,
        so the output has one more step than ``captions``.
        """
        token_embeddings = self.dropout(self.embedding(captions))
        image_step = features.unsqueeze(0)
        sequence = torch.cat((image_step, token_embeddings), dim=0)

        lstm_out = self.lstm(sequence)[0]  # discard the final (h, c) states
        return self.fc(lstm_out)

In [5]:
# ImageCaptioningNet: network that combines the two network types above (CNN and
# RNN) to generate descriptions for images.
class ImageCaptioningNet(nn.Module):
    """CNN feature extractor followed by an LSTM caption decoder.

    Args:
        embedding_size: Size of the image feature vector and token embeddings.
        hidden_size: Size of the LSTM hidden state.
        vocabulary_size: Number of tokens in the vocabulary.
        layer_number: Number of stacked LSTM layers.
        continue_training: Forwarded to ``FeatureExtractorCNN``.
    """

    def __init__(self, embedding_size, hidden_size, vocabulary_size, layer_number, continue_training=False):
        super(ImageCaptioningNet, self).__init__()
        self.feature_extractor_cnn = FeatureExtractorCNN(embedding_size, continue_training)
        self.captioning_rnn = CaptioningRNN(embedding_size, hidden_size, vocabulary_size, layer_number)

    def forward(self, images, captions):
        """Return the decoder's vocabulary scores for ``captions`` conditioned on ``images``."""
        features = self.feature_extractor_cnn(images)
        output = self.captioning_rnn(features, captions)
        return output

    def generate_caption(self, image, vocabulary, max_length=50):
        """Greedily decode a caption for a single image.

        Args:
            image: Batch holding exactly one image (``.item()`` is called on
                each prediction, which requires a single element).
            vocabulary: Object exposing an index-to-token mapping as
                ``idx_str_dict`` (this file's ``Vocabulary``) or as ``itos``.
            max_length: Maximum number of tokens to generate.

        Returns:
            List of token strings, ending with ``"<EOS>"`` when it was
            predicted within ``max_length`` steps.

        This method does not switch the model between train and eval mode.
        """
        index_to_token = getattr(vocabulary, "idx_str_dict", None)
        if index_to_token is None:
            index_to_token = vocabulary.itos

        caption = []
        with torch.no_grad():
            out = self.feature_extractor_cnn(image).unsqueeze(0)
            states = None

            for _ in range(max_length):
                hidden_output, states = self.captioning_rnn.lstm(out, states)
                # The decoder's output projection is named `fc`.
                output = self.captioning_rnn.fc(hidden_output.squeeze(0))

                predicted_word = output.argmax(dim=1)
                word_idx = predicted_word.item()
                caption.append(word_idx)

                if index_to_token[word_idx] == "<EOS>":
                    break

                # Feed the predicted token back in as the next input step.
                out = self.captioning_rnn.embedding(predicted_word).unsqueeze(0)

        return [index_to_token[word_idx] for word_idx in caption]

# Antrenarea
<br>
Mai intai, descarcam setul de imagini si titluri de pe Kaggle si il dezarhivam. Pentru asta, trebuie incarcat fisierul `kaggle.json`, care contine cheia privata a unui cont Kaggle.

In [None]:
!pip install kaggle
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json



In [None]:
# Download the aladdinpersson/flickr8kimagescaptions dataset with the Kaggle CLI.
!kaggle datasets download -d aladdinpersson/flickr8kimagescaptions

In [None]:
!unzip flickr8kimagescaptions.zip 

## Pregatirea datelor

Pentru a putea utiliza datele descarcate, titlurile trebuie normalizate si mapate la valori numerice. Pentru asta, vom crea in prealabil vocabularul tuturor cuvintelor prezente in titlurile de antrenare. Vocabularul reprezinta o mapare intre cuvintele tokenizate si indexurile lor. De asemenea, in vocabular se adauga inca 4 "cuvinte" speciale:
1. '\<PAD>': token adaugat la sfarsitul titlurilor mai scurte pentru ca toate titlurile dintr-un batch sa aiba acelasi numar de tokenuri;
2. '\<SOS>': tokenul care reprezinta inceputul titlului;
3. '\<EOS>': tokenul care reprezinta sfarsitul titlului;
4. '\<UNK>': tokenul care se pune in locul cuvintelor care nu fac parte din vocabular.

Un token (cuvant) nu va fi adaugat in vocabular daca, in toate titlurile luate impreuna, apare de mai putin de `freq_threshold` ori.



In [8]:
class Vocabulary:
  """Two-way mapping between caption tokens and integer indices.

  Indices 0-3 are reserved for the special tokens '<PAD>', '<SOS>', '<EOS>'
  and '<UNK>'. Other tokens are added by `build_vocabulary` when they occur at
  least `freq_threshold` times.
  """

  def __init__(self, freq_threshold):
    # Initialise the token <-> index mappings with the special tokens.
    self.idx_str_dict = {0: '<PAD>', 1:'<SOS>', 2:'<EOS>', 3:'<UNK>'}
    self.str_idx_dict = {'<PAD>':0, '<SOS>':1, '<EOS>':2, '<UNK>':3}
    self.freq_threshold = freq_threshold

  def __len__(self):
    return len(self.idx_str_dict)

  @property
  def itos(self):
    """Alias of `idx_str_dict` for code that looks up `vocabulary.itos`."""
    return self.idx_str_dict

  @staticmethod
  def tokenizer(text):
    """
    Tokenize an English sentence with the module-level spaCy pipeline and
    return the lower-cased token texts as a list of strings.
    """
    return [token.text.lower() for token in spacy_eng.tokenizer(text)]

  def build_vocabulary(self, sentences):
    """
    Tokenize every sentence and add each token that occurs at least
    `freq_threshold` times across `sentences`.

    New tokens are appended after the existing entries and tokens already in
    the vocabulary are skipped, so a repeated call never reassigns an index.
    """
    frequencies = {}
    for sentence in sentences:
      for word in self.tokenizer(sentence):
        frequencies[word] = frequencies.get(word, 0) + 1

    # Continue from the current size (4 on a fresh vocabulary) instead of a fixed 4.
    index = len(self.idx_str_dict)
    for word, freq in frequencies.items():
      if freq >= self.freq_threshold and word not in self.str_idx_dict:
        self.str_idx_dict[word] = index
        self.idx_str_dict[index] = word
        index += 1

  def normalize(self, text):
    """
    Tokenize `text` and return the list of token indices.
    Tokens that are not in the vocabulary are mapped to the '<UNK>' index.
    """
    unknown_index = self.str_idx_dict['<UNK>']
    return [
      self.str_idx_dict.get(token, unknown_index)
      for token in self.tokenizer(text)
    ]


In [9]:
class FlickrDataset(Dataset):
  """Dataset of (image, encoded caption) pairs read from a captions CSV file.

  The CSV is read with pandas and must contain the columns 'image' and
  'caption'. A `Vocabulary` is built from all captions at construction time.
  """

  def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
    self.root_dir = root_dir
    self.transform = transform

    self.dataframe = pd.read_csv(captions_file)
    self.images = self.dataframe['image']
    self.captions = self.dataframe['caption']

    self.vocabulary = Vocabulary(freq_threshold)
    self.vocabulary.build_vocabulary(self.captions.tolist())

  def __len__(self):
    return len(self.dataframe)

  def __getitem__(self, index):
    """Return the (optionally transformed) RGB image and its caption as an index tensor."""
    caption_text = self.captions[index]
    image_path = os.path.join(self.root_dir, self.images[index])
    image = Image.open(image_path).convert('RGB')
    if self.transform is not None:
      image = self.transform(image)

    token_to_index = self.vocabulary.str_idx_dict
    encoded_caption = [
      token_to_index['<SOS>'],  # start of sentence
      *self.vocabulary.normalize(caption_text),
      token_to_index['<EOS>'],  # end of sentence
    ]

    return image, torch.tensor(encoded_caption)

In [10]:
class Collate:
  """
  Batch collation callable: stacks the images and pads every caption in the
  batch with `pad_idx` so that all captions have the same length.
  """
  def __init__(self, pad_idx):
    self.pad_idx = pad_idx

  def __call__(self, batch):
    # Stack the images along a new leading batch dimension.
    images = torch.stack([sample[0] for sample in batch], dim=0)
    # Captions are padded time-first (batch_first=False).
    captions = [sample[1] for sample in batch]
    targets = pad_sequence(captions, batch_first=False, padding_value=self.pad_idx)

    return images, targets

In [21]:
def get_dataloader(
    root_dir,
    annotation_file,
    transform,
    batch_size=32,
    num_workers=2,
    shuffle=True, 
    pin_memory=True
):
  """
  Create a FlickrDataset (which builds its vocabulary on construction) and a
  DataLoader over it that pads captions with the '<PAD>' index.

  Returns the tuple (dataloader, dataset).
  """
  dataset = FlickrDataset(root_dir, annotation_file, transform=transform)
  collate_fn = Collate(dataset.vocabulary.str_idx_dict['<PAD>'])

  loader_options = dict(
      batch_size=batch_size,
      num_workers=num_workers,
      shuffle=shuffle,
      pin_memory=pin_memory,
  )
  return DataLoader(dataset, collate_fn=collate_fn, **loader_options), dataset

In [None]:
# List the contents of the current working directory.
!ls 

In [None]:
# Smoke test: build a dataloader and print the tensor shapes of the first batches.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataloader, _ = get_dataloader(
    'flickr8k/images/',
    annotation_file='flickr8k/captions.txt',
    transform=transform,
)

stop_at = 10  # number of batches to inspect
print(len(dataloader))
for index, (images, captions) in enumerate(dataloader):
  if index == stop_at:
    break

  print(images.shape)
  print(captions.shape)

## Functia de antrenare

In [26]:
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train(
    model,
    dataloader,
    loss_fn,
    optimizer,
    epochs,
    checkpoint_path='./model_last.pt'
):
  """
  Train `model` for `epochs` passes over `dataloader`.

  Each batch is an (images, captions) pair. The model receives the captions
  without their last step, and its output is compared against the full
  captions. The model is therefore expected to return one more step than it
  is given, so that the shapes line up.

  Args:
    model: Module called as `model(images, captions[:-1])`; it must already be
      on DEVICE, because only the batches are moved here.
    dataloader: Iterable of (images, captions) batches.
    loss_fn: Callable taking (scores of shape (N, vocab), targets of shape (N,)).
    optimizer: Optimizer over the model parameters.
    epochs: Number of passes over `dataloader`.
    checkpoint_path: Where the state dict is saved after every epoch; pass
      None to skip saving.

  Returns:
    The trained model (the same object that was passed in).
  """
  model.train()

  for epoch in range(epochs):

    total_loss = 0.0
    batch_count = 0

    for images, captions in tqdm(dataloader):
      images = images.to(DEVICE)
      captions = captions.to(DEVICE)

      outputs = model(images, captions[:-1])
      loss = loss_fn(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

      optimizer.zero_grad()
      # Plain backward(): passing the loss as the `gradient` argument would
      # scale every gradient by the loss value.
      loss.backward()
      optimizer.step()

      # .item() detaches the value so the autograd graph is not kept alive
      # for the whole epoch.
      total_loss += loss.item()
      batch_count += 1

    # max(..., 1) avoids a division by zero on an empty dataloader.
    print(f'Epoca {epoch}, Avg loss: {total_loss/max(batch_count, 1)}')
    if checkpoint_path is not None:
      torch.save(model.state_dict(), checkpoint_path)

  return model




In [None]:
# Training-time preprocessing: resize, take a random 299x299 crop, convert to a
# tensor and normalise each channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((356, 356)),
    transforms.RandomCrop((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataloader, dataset = get_dataloader(
    root_dir='flickr8k/images',
    annotation_file='flickr8k/captions.txt',
    transform=transform,
    num_workers=2,
)

# Hyperparameters
embedding_size = 256
hidden_size = 256
vocabulary_size = len(dataset.vocabulary)
layer_number = 1
lr = 3e-4
epochs = 10

model = ImageCaptioningNet(
    embedding_size,
    hidden_size,
    vocabulary_size,
    layer_number,
    continue_training=True,
)
model.to(DEVICE)

# Padding positions are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=dataset.vocabulary.str_idx_dict['<PAD>'])
optimizer = optim.Adam(model.parameters(), lr=lr)

train(model, dataloader, loss_fn, optimizer, epochs)