Архитектура модели, создающей описания изображений на английском языке, состоит из трёх частей: свёрточной нейронной сети (CNN), модуля внимания и рекуррентной нейронной сети (RNN). В качестве CNN была выбрана предобученная сеть EfficientNet-B7, поскольку она является одной из самых эффективных свёрточных нейросетей последнего поколения. При этом у неё были убраны два последних слоя – линейный слой и слой пулинга, а веса самой сети при обучении заморожены. Поверх данной сети был добавлен обучаемый свёрточный слой с целью дообучения модели на выбранном наборе данных и придания ей большей гибкости. В качестве модуля внимания был выбран механизм Bahdanau Attention, так как он достаточно распространён в решаемой задаче и довольно прост в реализации. В качестве рекуррентной нейронной сети была использована стандартная ячейка LSTM (LSTMCell).

# Импорт библиотек и подготовка данных

In [None]:
import os
from collections import Counter
import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as T
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import nltk
from torchtext.data.metrics import bleu_score
from PIL import Image
import matplotlib.pyplot as plt
# English spaCy pipeline; its tokenizer is what Vocabulary.tokenize uses.
spacy_eng = spacy.load("en_core_web_sm")

In [None]:
class Vocabulary:
    """Two-way mapping between caption words and integer token ids.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Any other word gets the next free id once it has
    been seen ``freq_threshold`` times while building the vocabulary.
    """

    def __init__(self, freq_threshold):
        """Create a vocabulary that contains only the four special tokens.

        Args:
            freq_threshold: Number of occurrences a word needs before it is
                added to the vocabulary.
        """
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {token: index for index, token in self.itos.items()}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the number of known tokens, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenize(text):
        """Split ``text`` with the spaCy tokenizer and lower-case each token."""
        return [token.text.lower() for token in spacy_eng.tokenizer(text)]

    def build_vocab(self, sentence_list):
        """Add every sufficiently frequent word of ``sentence_list``.

        Word counts are local to one call. A word is added as soon as its
        count reaches ``freq_threshold``, so ids follow the order in which
        words cross the threshold.

        Args:
            sentence_list: Iterable of caption strings.
        """
        frequencies = Counter()

        for sentence in sentence_list:
            for word in self.tokenize(sentence):
                frequencies[word] += 1
                # ">=" plus the membership test (instead of "== threshold")
                # keeps thresholds <= 0 working and never re-registers a word.
                if frequencies[word] >= self.freq_threshold and word not in self.stoi:
                    # Always continue after the highest id in use, so calling
                    # build_vocab again cannot overwrite existing entries.
                    idx = len(self.itos)
                    self.stoi[word] = idx
                    self.itos[idx] = word

    def numericalize(self, text):
        """Convert ``text`` into a list of token ids.

        Words missing from the vocabulary map to the id of ``<UNK>``.
        """
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenize(text)]

In [None]:
class Dataset(Dataset):
    """Image/caption pairs described by a captions CSV file.

    The CSV must provide the columns ``image`` (file name relative to
    ``root_dir``) and ``caption``. A ``Vocabulary`` is built from all
    captions of the file when the dataset is created.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        """Read the captions file and build the vocabulary from its captions.

        Args:
            root_dir: Directory that contains the image files.
            captions_file: Path of the CSV file with the captions.
            transform: Optional callable applied to every loaded image.
            freq_threshold: Frequency threshold passed to ``Vocabulary``.
        """
        self.root_dir = root_dir
        self.transform = transform

        self.df = pd.read_csv(captions_file)
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocab(self.captions.tolist())

    def __len__(self):
        """Return the number of rows in the captions table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, token_ids)`` for row ``idx``.

        The caption is wrapped in ``<SOS>`` / ``<EOS>`` ids and returned as a
        tensor; the image is opened as RGB and passed through ``transform``
        when one was given.
        """
        caption_text = self.captions[idx]
        image_path = os.path.join(self.root_dir, self.imgs[idx])
        image = Image.open(image_path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        vocab = self.vocab
        token_ids = [
            vocab.stoi["<SOS>"],
            *vocab.numericalize(caption_text),
            vocab.stoi["<EOS>"],
        ]
        return image, torch.tensor(token_ids)

In [None]:
class DatasetTest(Dataset):
    """Caption-less image dataset over every entry found in ``root_dir``.

    Items are single images (no caption). A vocabulary is still built from
    ``caption_file`` and exposed as ``vocab``.
    """

    def __init__(self, root_dir, caption_file, transform=None, freq_threshold=5):
        """List the images of ``root_dir`` and build the vocabulary.

        Args:
            root_dir: Directory whose entries are treated as image files.
            caption_file: CSV file with a ``caption`` column used only to
                build the vocabulary.
            transform: Optional callable applied to every loaded image.
            freq_threshold: Frequency threshold passed to ``Vocabulary``.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.df = pd.read_csv(caption_file)
        # os.listdir returns entries in arbitrary order; sorting makes a given
        # index refer to the same file on every run and every machine.
        self.imgs = sorted(os.listdir(self.root_dir))
        self.captions = self.df["caption"]
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocab(self.captions.tolist())

    def __len__(self):
        """Return the number of entries found in ``root_dir``."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return the image at position ``idx``, opened as RGB and transformed."""
        img_name = self.imgs[idx]
        img_location = os.path.join(self.root_dir, img_name)
        img = Image.open(img_location).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img

In [None]:
def get_data_loader(dataset, batch_size, shuffle=False, num_workers=1):
    """Wrap ``dataset`` in a ``DataLoader`` that uses the default collation.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Whether the sample order is reshuffled every epoch.
        num_workers: Number of loader worker processes.

    Returns:
        The configured ``DataLoader``.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
    }
    return DataLoader(dataset, **loader_options)

In [None]:
# Preprocessing applied to every image: resize to 224x224, then convert the
# PIL image to a tensor.
transforms = T.Compose([
    T.Resize((224,224)),
    T.ToTensor()
])

In [None]:
def show_image(inp, title=None):
    """Display an image tensor with matplotlib.

    Args:
        inp: Tensor with three axes; axis 0 is moved to the end before
            plotting (channel-first to channel-last).
        title: Optional title drawn above the image.
    """
    channels_last = inp.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    if title is not None:
        plt.title(title)
    # Short pause so the figure gets drawn.
    plt.pause(0.001)

In [None]:
# Training dataset; constructing it also builds the vocabulary from the
# training captions.
dataset = Dataset(
    root_dir = 'coco-2017-dataset/coco2017/train2017',
    captions_file = 'cococaptions/train_captions.txt',
    transform=transforms
)

# Sanity check: show the first sample and decode its token ids back to words.
img, caps = dataset[0]
show_image(img,"Image")
print("Token:",caps)
print("Sentence:")
print([dataset.vocab.itos[token] for token in caps.tolist()])

In [None]:
class CapsCollate:
    """Collate function that batches images and pads captions to equal length."""

    def __init__(self, pad_idx, batch_first=False):
        """Store the padding configuration.

        Args:
            pad_idx: Token id used to fill captions up to the longest one.
            batch_first: Forwarded to ``pad_sequence``; when true the padded
                captions have the batch dimension first.
        """
        self.pad_idx = pad_idx
        self.batch_first = batch_first

    def __call__(self, batch):
        """Turn a list of ``(image, caption)`` samples into two batch tensors.

        Returns:
            ``(images, captions)``: the images stacked along a new leading
            dimension and the captions padded with ``pad_idx``.
        """
        images = torch.stack([sample[0] for sample in batch], dim=0)
        padded_captions = pad_sequence(
            [sample[1] for sample in batch],
            batch_first=self.batch_first,
            padding_value=self.pad_idx,
        )
        return images, padded_captions

In [None]:
# Small batch used only to preview a few samples.
BATCH_SIZE = 4
NUM_WORKER = 1

# Captions inside a batch are padded with the <PAD> id.
pad_idx = dataset.vocab.stoi["<PAD>"]

data_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKER,
    shuffle=True,
    collate_fn=CapsCollate(pad_idx=pad_idx,batch_first=True)
)

In [None]:
# Fetch one batch and display every image in it with its caption.
dataiter = iter(data_loader)
batch = next(dataiter)
images, captions = batch

# Iterate over the samples actually present in the batch rather than over
# range(BATCH_SIZE): the batch can be smaller than BATCH_SIZE, and BATCH_SIZE
# may have been changed by another cell since the loader was created.
for img, cap in zip(images, captions):
    caption_label = [dataset.vocab.itos[token] for token in cap.tolist()]
    # Keep only the words between <SOS> (position 0) and the first <EOS>;
    # everything after <EOS> is padding.
    eos_index = caption_label.index('<EOS>')
    caption_label = ' '.join(caption_label[1:eos_index])
    show_image(img, caption_label)
    plt.show()

In [None]:
def get_data_loader(dataset,batch_size,shuffle=False,num_workers=1):
    """Build a ``DataLoader`` whose batches carry padded captions.

    Captions are padded with the ``<PAD>`` id of ``dataset.vocab`` and are
    returned batch-first.

    Args:
        dataset: Dataset exposing a ``vocab`` attribute with an ``stoi`` map.
        batch_size: Number of samples per batch.
        shuffle: Whether the sample order is reshuffled every epoch.
        num_workers: Number of loader worker processes.

    Returns:
        The configured ``DataLoader``.
    """
    collate = CapsCollate(pad_idx=dataset.vocab.stoi["<PAD>"], batch_first=True)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate,
    )

## Проверка

In [None]:
BATCH_SIZE = 256
NUM_WORKER = 4

# Preprocessing for training and evaluation: resize, take a random 224x224
# crop, convert to a tensor and normalize each channel with the mean/std
# constants below.
transforms = T.Compose([
    T.Resize(226),
    T.RandomCrop(224),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))
])

# `dataset` was created with the earlier, non-normalizing pipeline. Rebinding
# the name `transforms` does not change it, so attach the new pipeline
# explicitly; otherwise training images stay un-normalized while the datasets
# created after this cell are normalized.
dataset.transform = transforms

data_loader = get_data_loader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKER,
    shuffle=True,
)

vocab_size = len(dataset.vocab)

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

In [None]:
def show_image(img, title=None):
    """Undo the channel normalization of ``img`` and display it.

    The input tensor is left untouched: de-normalization happens on a copy,
    so the same tensor can be shown repeatedly or still be fed to the model.

    Args:
        img: Channel-first image tensor with three channels, normalized with
            the mean/std constants used below. It may live on any device.
        title: Optional title drawn above the image.
    """
    mean = torch.tensor((0.485, 0.456, 0.406)).view(3, 1, 1)
    std = torch.tensor((0.229, 0.224, 0.225)).view(3, 1, 1)

    # detach()/cpu() make .numpy() valid for GPU tensors and tensors that are
    # part of an autograd graph; the arithmetic below allocates a new tensor.
    restored = img.detach().cpu() * std + mean
    # Keep values inside the range imshow expects for float RGB data.
    restored = restored.clamp(0, 1)

    plt.imshow(restored.numpy().transpose((1, 2, 0)))
    if title is not None:
        plt.title(title)
    # Short pause so the figure gets drawn.
    plt.pause(0.001)

In [None]:
# Validation dataset. Dataset builds its own Vocabulary from this captions
# file, so its token ids are independent of the ones in `dataset.vocab`.
dataset_val = Dataset(
    root_dir = "coco-2017-dataset/coco2017/val2017",
    captions_file = "cococaptions/val_captions.txt",
    transform=transforms
)

# One sample per batch, in shuffled order.
data_loader_val = get_data_loader(
    dataset=dataset_val,
    batch_size=1,
    num_workers=NUM_WORKER,
    shuffle=True,
)

In [None]:
# Size of the training vocabulary, special tokens included.
len(dataset.vocab)

In [None]:
# Caption-less test images; the captions file is only used to build a
# vocabulary.
dataset_test = DatasetTest(
    root_dir="coco-2017-dataset/coco2017/test2017",
    caption_file="cococaptions/val_captions.txt",
    transform=transforms
)


def _collate_test_images(batch):
    """Stack caption-less test images into one batch tensor.

    DatasetTest yields bare images, so the caption-padding collate function
    cannot be used. ``None`` fills the caption slot so the batch can still be
    unpacked as ``(images, captions)`` like the other loaders' batches.
    """
    return torch.stack(batch, dim=0), None


# The loader must be built from dataset_test; it was previously built from
# dataset_val, so the test images were never used.
data_loader_test = DataLoader(
    dataset=dataset_test,
    batch_size=1,
    num_workers=NUM_WORKER,
    shuffle=True,
    collate_fn=_collate_test_images,
)

## Архитектура модели

In [None]:
class EncoderCNN(nn.Module):
    """Image encoder: frozen EfficientNet-B7 trunk plus a trainable 1x1 conv.

    The output is a sequence of 2048-dimensional feature vectors, one per
    spatial position of the final feature map.
    """

    def __init__(self):
        super().__init__()
        backbone = models.efficientnet_b7(pretrained=True)
        # Freeze every pretrained weight; only `self.conv` stays trainable.
        for weight in backbone.parameters():
            weight.requires_grad_(False)

        # Keep everything except the last two child modules of the network.
        kept_layers = list(backbone.children())[:-2]
        self.effnet = nn.Sequential(*kept_layers)

        # Trainable projection from 2560 to 2048 channels.
        projection = [
            nn.Conv2d(2560, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(2048, eps=0.001, momentum=0.01, affine=True, track_running_stats=True),
            nn.SiLU(inplace=True),
        ]
        self.conv = nn.Sequential(*projection)

    def forward(self, images):
        """Encode ``images`` into a tensor of shape (batch, positions, 2048)."""
        feature_map = self.conv(self.effnet(images))
        # Move channels last, then merge the two spatial axes into one.
        channels_last = feature_map.permute(0, 2, 3, 1)
        batch, channels = channels_last.size(0), channels_last.size(-1)
        return channels_last.view(batch, -1, channels)

In [None]:
class Attention(nn.Module):
    """Additive attention over encoder features given a decoder state."""

    def __init__(self, encoder_dim,decoder_dim,attention_dim):
        """Create the three linear layers of the scoring network.

        Args:
            encoder_dim: Size of one encoder feature vector.
            decoder_dim: Size of the decoder hidden state.
            attention_dim: Size of the shared space both are projected into.
        """
        super().__init__()
        self.attention_dim = attention_dim
        self.W = nn.Linear(decoder_dim, attention_dim)
        self.U = nn.Linear(encoder_dim, attention_dim)
        self.A = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_state):
        """Compute attention weights and the attended context vector.

        Args:
            features: Encoder output of shape (batch, positions, encoder_dim).
            hidden_state: Decoder state of shape (batch, decoder_dim).

        Returns:
            ``(alpha, context)``: weights of shape (batch, positions) that sum
            to one over the positions, and the weighted sum of the features
            with shape (batch, encoder_dim).
        """
        projected_features = self.U(features)
        projected_state = self.W(hidden_state).unsqueeze(1)

        # One scalar score per position, from the combined projections.
        scores = self.A(torch.tanh(projected_features + projected_state)).squeeze(2)
        alpha = F.softmax(scores, dim=1)

        context = (features * alpha.unsqueeze(2)).sum(dim=1)
        return alpha, context

In [None]:
class DecoderRNN(nn.Module):
    """LSTM-cell caption decoder that attends over encoder features."""

    def __init__(self,embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
        """Create the decoder layers.

        Args:
            embed_size: Size of the word embeddings.
            vocab_size: Number of tokens in the vocabulary.
            attention_dim: Size of the attention projection space.
            encoder_dim: Size of one encoder feature vector.
            decoder_dim: Size of the LSTM hidden and cell state.
            drop_prob: Dropout probability applied before the output layer.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim
        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.attention = Attention(encoder_dim,decoder_dim,attention_dim)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.lstm_cell = nn.LSTMCell(embed_size+encoder_dim,decoder_dim,bias=True)
        # Not used by any method of this class; kept so the parameter names
        # of the module (and therefore saved state dicts) stay the same.
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.fcn = nn.Linear(decoder_dim,vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        """Run the decoder over ground-truth captions (teacher forcing).

        Args:
            features: Encoder output of shape (batch, positions, encoder_dim).
            captions: Token ids of shape (batch, length), starting with
                ``<SOS>``.

        Returns:
            ``(preds, alphas)``: vocabulary scores of shape
            (batch, length - 1, vocab_size) and attention weights of shape
            (batch, length - 1, positions).
        """
        embeds = self.embedding(captions)

        h, c = self.init_hidden_state(features)
        # The last token is never fed in: step s predicts token s + 1.
        seq_length = len(captions[0])-1
        batch_size = captions.size(0)
        num_features = features.size(1)
        # Allocate on the device of the inputs instead of relying on a
        # module-level `device` variable.
        preds = torch.zeros(batch_size, seq_length, self.vocab_size, device=features.device)
        alphas = torch.zeros(batch_size, seq_length, num_features, device=features.device)

        for s in range(seq_length):
            alpha,context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            preds[:,s] = output
            alphas[:,s] = alpha

        return preds, alphas

    def generate_caption(self,features,max_len=20,vocab=None):
        """Greedily decode a caption for one image.

        Args:
            features: Encoder output for a single image (batch size 1); the
                predicted index is read with ``.item()``.
            max_len: Maximum number of generated tokens.
            vocab: Vocabulary providing ``stoi`` and ``itos``. Required even
                though the default is ``None``.

        Returns:
            ``(words, alphas)``: the generated words (including ``<EOS>`` when
            it was produced) and one attention-weight array per step.
        """
        batch_size = features.size(0)
        h, c = self.init_hidden_state(features)
        alphas = []
        # Start token, created on the same device as the features.
        word = torch.tensor(vocab.stoi['<SOS>'], device=features.device).view(1,-1)
        embeds = self.embedding(word)
        captions = []

        for i in range(max_len):
            alpha,context = self.attention(features, h)
            alphas.append(alpha.cpu().detach().numpy())
            lstm_input = torch.cat((embeds[:, 0], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            output = output.view(batch_size,-1)
            predicted_word_idx = output.argmax(dim=1)
            captions.append(predicted_word_idx.item())
            if vocab.itos[predicted_word_idx.item()] == "<EOS>":
                break
            # Feed the predicted word back in as the next input.
            embeds = self.embedding(predicted_word_idx.unsqueeze(0))

        return [vocab.itos[idx] for idx in captions],alphas

    def init_hidden_state(self, encoder_out):
        """Derive the initial LSTM state from the mean encoder feature.

        Returns:
            ``(h, c)``, each of shape (batch, decoder_dim).
        """
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c

In [None]:
class EncoderDecoder(nn.Module):
    """Full captioning model: CNN encoder followed by the attention decoder."""

    def __init__(self,embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
        """Build the encoder and the decoder.

        Args:
            embed_size: Size of the word embeddings.
            vocab_size: Number of tokens in the vocabulary.
            attention_dim: Size of the attention projection space.
            encoder_dim: Size of one encoder feature vector.
            decoder_dim: Size of the LSTM hidden and cell state.
            drop_prob: Dropout probability used inside the decoder.
        """
        super().__init__()
        self.encoder = EncoderCNN()
        self.decoder = DecoderRNN(
            embed_size=embed_size,
            # Use the argument rather than a global dataset, so the model can
            # be built for any vocabulary.
            vocab_size=vocab_size,
            attention_dim=attention_dim,
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
            # Previously dropped, which made the parameter a no-op.
            drop_prob=drop_prob,
        )

    def forward(self, images, captions):
        """Encode ``images`` and decode them against ``captions``.

        Returns:
            Whatever the decoder's forward pass returns for the encoded
            images and the given captions.
        """
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

## Инициализация модели и её обучение

In [None]:
# Model and optimizer hyperparameters.
embed_size=512
vocab_size = len(dataset.vocab)
attention_dim=512
# Matches the 2048 output channels of the encoder's final convolution.
encoder_dim=2048
decoder_dim=512
learning_rate = 3e-4

In [None]:
# Build the captioning model and move it to the selected device.
model = EncoderDecoder(
    embed_size=embed_size,
    vocab_size = len(dataset.vocab),
    attention_dim=attention_dim,
    encoder_dim=encoder_dim,
    decoder_dim=decoder_dim
).to(device)

In [None]:
# Resume from previously saved weights. map_location remaps the stored
# tensors to the current device, so the checkpoint also loads on a machine
# that lacks the device it was saved from (e.g. a CPU-only machine).
model.load_state_dict(torch.load('cocoweights/model_weights.pth', map_location=device))
model.train()

In [None]:
# Padding positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
def gen_original(txt):
    """Decode the first caption of a batch with the training vocabulary.

    Tokens with id 1 are skipped and decoding stops at the first token with
    id 0; every other id is looked up in ``dataset.vocab.itos``.

    Args:
        txt: Batch of caption token ids; only ``txt[0]`` is read.

    Returns:
        List of decoded words.
    """
    words = []
    for token_id in txt[0].tolist():
        if token_id == 0:
            break
        if token_id != 1:
            words.append(dataset.vocab.itos[token_id])
    return words

In [None]:
def gen_original_val(txt):
    """Decode the first caption of a batch with the validation vocabulary.

    Tokens with id 1 are skipped and decoding stops at the first token with
    id 0; every other id is looked up in ``dataset_val.vocab.itos``.

    Args:
        txt: Batch of caption token ids; only ``txt[0]`` is read.

    Returns:
        List of decoded words.
    """
    words = []
    for token_id in txt[0].tolist():
        if token_id == 0:
            break
        if token_id != 1:
            words.append(dataset_val.vocab.itos[token_id])
    return words

In [None]:
num_epochs = 100
print_every = 300
k = 0  # running index used in the names of the periodic checkpoints


def _strip_caption_end(tokens):
    """Return a copy of ``tokens`` without a trailing '<EOS>' and final '.'.

    Used for both the reference and the generated caption so that BLEU
    compares like with like. Sequences of any length, including empty ones,
    are handled.
    """
    trimmed = list(tokens)
    if trimmed and trimmed[-1] == '<EOS>':
        trimmed.pop()
    if trimmed and trimmed[-1] == '.':
        trimmed.pop()
    return trimmed


for epoch in range(1, num_epochs + 1):
    for idx, (image, captions) in enumerate(data_loader):
        image, captions = image.to(device), captions.to(device)
        optimizer.zero_grad()
        outputs, attentions = model(image, captions)
        # Step s of the decoder predicts token s + 1, so the targets are the
        # captions without their first token.
        targets = captions[:, 1:]
        loss = criterion(outputs.view(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.step()

        if (idx + 1) % print_every == 0:
            print("Epoch: {} loss: {:.5f}".format(epoch, loss.item()))
            # Switch to evaluation mode for the qualitative checks below.
            model.eval()
            with torch.no_grad():
                print('----------TRAIN----------')
                dataiter = iter(data_loader)
                img, txt = next(dataiter)
                features = model.encoder(img[0:1].to(device))
                caps, alphas = model.decoder.generate_caption(features, vocab=dataset.vocab)
                caption = ' '.join(caps)
                reference = _strip_caption_end(gen_original(txt))
                hypothesis = _strip_caption_end(caps)
                print("reference:", reference)
                print("caption:", hypothesis)
                BLEUscore = nltk.translate.bleu_score.sentence_bleu([reference], hypothesis, weights=(0.25, 0.25, 0.25, 0.25))
                print('BLEU score:', BLEUscore)
                show_image(img[0], title=caption)
                print("----------VALID----------")
                dataiter_val = iter(data_loader_val)
                img, txt = next(dataiter_val)
                features = model.encoder(img[0:1].to(device))
                caps, alphas = model.decoder.generate_caption(features, vocab=dataset.vocab)
                caption = ' '.join(caps)
                # The reference is decoded with the validation vocabulary,
                # the hypothesis with the training vocabulary.
                reference = _strip_caption_end(gen_original_val(txt))
                hypothesis = _strip_caption_end(caps)
                print("reference:", reference)
                print("caption:", hypothesis)
                BLEUscore = nltk.translate.bleu_score.sentence_bleu([reference], hypothesis, weights=(0.25, 0.25, 0.25, 0.25))
                print('BLEU score:', BLEUscore)
                show_image(img[0], title=caption)
                print("----------TEST----------")
                # No reference is used for the test sample, so only the
                # generated caption is shown.
                dataiter_test = iter(data_loader_test)
                img, txt = next(dataiter_test)
                features = model.encoder(img[0:1].to(device))
                caps, alphas = model.decoder.generate_caption(features, vocab=dataset.vocab)
                caption = ' '.join(caps)
                show_image(img[0], title=caption)
                # Numbered checkpoint after every preview.
                torch.save(model.state_dict(), f'model_weights{k}.pth')
                k += 1
            model.train()
    # Checkpoint at the end of every epoch.
    torch.save(model.state_dict(), 'model_weights.pth')

In [None]:
# Persist the current weights of the model.
torch.save(model.state_dict(), 'model_weights.pth')