In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import numpy as np
from torchaudio.functional import edit_distance
from torchvision import transforms

from transformers import AutoTokenizer, AutoModel

# Compute device used by the whole notebook: the first CUDA GPU if present, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
import torchvision.models as models

class EncoderCNN(nn.Module):            # copied from the link in the task
    """Image encoder: frozen pretrained ResNet-152 backbone + trainable projection.

    The backbone (ResNet-152 without its final fc layer) is never trained: it is
    run under ``torch.no_grad()`` and is kept in eval mode even when the module
    is switched to training mode. Only ``linear`` and ``bn`` are trainable.

    Args:
        embedding_size: size of the output feature vector for each image.
    """

    def __init__(self, embedding_size):
        """Load the pretrained ResNet-152 and replace top fc layer."""
        super().__init__()
        resnet = models.resnet152(weights="ResNet152_Weights.DEFAULT")
        modules = list(resnet.children())[:-1]      # delete the last fc layer.
        self.resnet = nn.Sequential(*modules)
        # The backbone never receives gradients (see forward); make that
        # explicit so it is obviously frozen.
        for param in self.resnet.parameters():
            param.requires_grad_(False)
        self.linear = nn.Linear(resnet.fc.in_features, embedding_size)
        self.bn = nn.BatchNorm1d(embedding_size, momentum=0.01)

    def train(self, mode=True):
        """Set train/eval mode, but always keep the frozen backbone in eval mode.

        ``torch.no_grad()`` does not stop BatchNorm layers in training mode from
        updating their running statistics. Without this override, every
        ``model.train()`` call would let the pretrained backbone's BatchNorm
        statistics drift away from their ImageNet values.

        Returns:
            self, like ``nn.Module.train``.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Extract feature vectors from input images.

        Args:
            images: image batch accepted by the ResNet backbone.

        Returns:
            Tensor of shape (batch, embedding_size).
        """
        with torch.no_grad():
            features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.bn(self.linear(features))
        return features


class Decoder(nn.Module):
    """LSTM caption decoder whose initial (h, c) state comes from the image embedding.

    The encoder output (last dimension ``input_dim``) is split in half. The first
    half is the initial hidden state ``h`` and the second half the initial cell
    state ``c`` of an LSTM with ``hidden_size = input_dim // 2``.

    Args:
        dict_size: vocabulary size (embedding rows and number of output classes).
        input_dim: width of the encoder output; must be even.
        embedding_dim: token embedding size fed into the LSTM.
        n_layers: number of LSTM layers.
        max_len: default number of tokens produced by ``sample``.
        bos_token_id: token id used as the first input in ``sample``.

    Raises:
        ValueError: if ``input_dim`` is odd (h and c would have different sizes).
    """

    def __init__(self,
                 dict_size,
                 input_dim=128,
                 embedding_dim=128,
                 n_layers=1,
                 max_len=100,
                 bos_token_id=1):
        super().__init__()
        if input_dim % 2:
            raise ValueError(f"input_dim must be even, got {input_dim}")

        self.lstm = nn.LSTM(embedding_dim,
                            hidden_size=input_dim // 2,
                            num_layers=n_layers,
                            batch_first=True)
        self.linear = nn.Linear(input_dim // 2, dict_size)

        self.max_len = max_len
        self.embedding = nn.Embedding(dict_size, embedding_dim)
        self.bos_token_id = bos_token_id
        self.input_dim = input_dim

    def _split_state(self, encoder_output):
        """Split the encoder output along its last dim into contiguous (h0, c0)."""
        half = self.input_dim // 2
        h = encoder_output[:, :, :half].contiguous()
        c = encoder_output[:, :, half:].contiguous()
        return h, c

    def forward(self, encoder_output, captions):
        """Teacher-forced decoding.

        Args:
            encoder_output: initial LSTM state source, shape
                (n_layers, batch, input_dim) as required by ``nn.LSTM``.
            captions: token ids, shape (batch, seq) (``batch_first=True``).

        Returns:
            Log-probabilities of shape (batch, seq, dict_size).
        """
        h, c = self._split_state(encoder_output)
        inputs = self.embedding(captions)
        lstm_outputs, _ = self.lstm(inputs, (h, c))
        return F.log_softmax(self.linear(lstm_outputs), dim=-1)

    def sample(self, encoder_output, max_len=None):
        """Greedy decoding starting from the BOS token, without gradients.

        Args:
            encoder_output: same layout as in ``forward``.
            max_len: number of tokens to generate; defaults to ``self.max_len``.

        Returns:
            Log-probabilities of shape (batch, max_len, dict_size). Step ``t``
            is fed the argmax token of step ``t - 1``; generation does not stop
            at EOS.
        """
        if max_len is None:
            max_len = self.max_len

        with torch.no_grad():
            h, c = self._split_state(encoder_output)
            batch_size = encoder_output.shape[1]
            # Build the BOS input on the encoder output's device (not a global
            # `device`) and as a long tensor, the usual dtype for embedding ids.
            cur_tokens = torch.full((batch_size, 1), self.bos_token_id,
                                    dtype=torch.long, device=encoder_output.device)

            logps = []
            for _ in range(max_len):
                output, (h, c) = self.lstm(self.embedding(cur_tokens), (h, c))
                next_logp = F.log_softmax(self.linear(output), dim=-1)
                logps.append(next_logp)
                cur_tokens = next_logp.argmax(dim=-1)  # (batch, 1)

        return torch.cat(logps, dim=1)


class ImageCaptionModel(nn.Module):
    """Image captioning model: ``EncoderCNN`` image embedding -> LSTM ``Decoder``.

    The image embedding (size ``embedding_size``) is used as the initial LSTM
    state of every decoder layer.

    Args:
        dict_size: vocabulary size.
        bos_token_id: token id used to start greedy sampling.
        embedding_size: image embedding size (= decoder ``input_dim``).
        decoder_emb_dim: token embedding size in the decoder.
        n_layers: number of decoder LSTM layers.
        max_len: default sampling length.
    """

    def __init__(self,
                 dict_size,
                 bos_token_id,
                 embedding_size=256,
                 decoder_emb_dim=128,
                 n_layers=1,
                 max_len=50):

        super().__init__()
        self.encoder = EncoderCNN(embedding_size)
        self.decoder = Decoder(dict_size=dict_size,
                               input_dim=embedding_size,
                               embedding_dim=decoder_emb_dim,
                               n_layers=n_layers,
                               max_len=max_len,
                               bos_token_id=bos_token_id)
        self.embedding_size = embedding_size
        self.decode_n_layers = n_layers

    def _encode(self, images):
        """Encode images and tile the embedding across decoder layers.

        Returns:
            Tensor of shape (n_layers, batch, embedding_size). Assumes the
            encoder returns (batch, embedding_size). A plain ``view`` to
            (n_layers, -1, E) would split the *batch* across layers when
            n_layers > 1, so the embedding is expanded instead.
        """
        images_embeddings = self.encoder(images)
        return (images_embeddings
                .reshape(1, -1, self.embedding_size)
                .expand(self.decode_n_layers, -1, -1)
                .contiguous())

    def forward(self, images, captions):
        """Return decoder log-probs for teacher-forced ``captions``."""
        return self.decoder(self._encode(images), captions)

    def sample(self, images, max_len=None):
        """Greedy-decode captions for ``images``; returns per-step log-probs."""
        return self.decoder.sample(self._encode(images), max_len)

In [3]:
from pycocotools.coco import COCO
from PIL import Image

class ImageCaptionDataset(Dataset):
    """COCO captions dataset with one item per caption annotation.

    Each item is ``(image, caption)``. ``image`` is the RGB PIL image, passed
    through ``transforms`` if given, and ``caption`` is the caption string.

    Args:
        captions_path: path to a COCO captions annotation JSON file.
        img_dir: directory containing the image files.
        transforms: optional callable applied to each PIL image.
        first_k: if set (and ``n_after_first_k`` is None), keep only the first
            ``first_k`` annotation ids.
        n_after_first_k: if set, keep annotation ids ``[first_k:n_after_first_k]``.
            NOTE(review): despite the name this acts as an absolute end index,
            not a count of items after ``first_k`` — confirm intent.
    """

    def __init__(self, captions_path, img_dir, transforms=None, first_k=10000, n_after_first_k=None):
        # Previously this subclassed nn.Module without calling its __init__;
        # a map-style torch Dataset is what DataLoader actually expects.
        super().__init__()
        self.coco = COCO(captions_path)
        ann_ids = list(self.coco.anns.keys())
        if n_after_first_k is not None:
            self.ids = ann_ids[first_k:n_after_first_k]
        elif first_k is not None:
            self.ids = ann_ids[:first_k]
        else:
            self.ids = ann_ids
        self.transforms = transforms
        self.img_dir = img_dir

    def __len__(self):
        """Number of caption annotations selected."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Load the image for annotation ``idx`` and return ``(image, caption)``."""
        ann = self.coco.anns[self.ids[idx]]
        caption = ann["caption"]
        path = self.coco.loadImgs(ann["image_id"])[0]["file_name"]

        # The context manager closes the file handle; convert() returns a copy
        # that does not depend on the opened file.
        with Image.open(self.img_dir + "/" + path) as img:
            image = img.convert("RGB")
        if self.transforms is not None:
            image = self.transforms(image)

        return image, str(caption)


class Collator:
    """Turn a list of ``(image, caption)`` items into model-ready tensors.

    Args:
        tokenizer: HF-style tokenizer callable used to encode caption strings.
        max_len: truncation length passed to the tokenizer.
    """

    def __init__(self, tokenizer, max_len=50):
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __call__(self, raw_batch):
        """Return ``(stacked image tensor, padded caption input_ids tensor)``."""
        image_list = []
        caption_list = []
        for item in raw_batch:
            image_list.append(item[0])
            caption_list.append(item[1])

        encoded = self.tokenizer(
            caption_list,
            padding=True,
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"]

        return torch.stack(image_list), token_ids

В качестве метрики снова возьмём minimal edit distance (расстояние Левенштейна между предсказанной и эталонной последовательностями токенов), потому что почему бы и нет.

In [4]:
def train_epoch(train_loader, model, loss_function, optimizer, callback=None):
    """Run one training epoch over ``train_loader``.

    Args:
        train_loader: iterable of ``(images, captions)`` batches.
        model: model to train (passed to ``train_on_batch``).
        loss_function: loss callable ``loss_function(preds, captions)``.
        optimizer: instantiated optimizer.
        callback: optional ``callback(model, batch_loss)`` called after every
            batch under ``torch.no_grad()``.

    Returns:
        Mean of the per-batch losses weighted by batch size, or ``nan`` if the
        loader yielded no batches (previously a ZeroDivisionError).
    """
    epoch_loss = 0.0
    total = 0
    for images, captions in tqdm(train_loader, leave=False):
        batch_loss = train_on_batch(model, images, captions, optimizer, loss_function)

        if callback is not None:
            with torch.no_grad():
                callback(model, batch_loss)

        epoch_loss += batch_loss * len(images)
        total += len(images)

    if total == 0:
        return float("nan")
    return epoch_loss / total


def train_on_batch(model, images, captions, optimizer, loss_function):
    """Do a single optimisation step on one batch.

    ``captions`` is used both as the decoder input (teacher forcing) and as the
    loss target, so it is moved to ``device`` once and reused.

    Returns:
        The batch loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    captions = captions.to(device)  # one host->device copy instead of two
    preds = model(images.to(device), captions)
    loss = loss_function(preds, captions)
    loss.backward()

    optimizer.step()
    return loss.detach().cpu().item()


def trainer(count_of_epoch,
            batch_size,
            loader,
            model,
            loss_function,
            optimizer,
            lr=0.001,
            callback=None):
    """Instantiate the optimizer and train ``model`` for ``count_of_epoch`` epochs.

    Args:
        count_of_epoch: number of epochs.
        batch_size: not used here (batching is done by ``loader``).
        loader: training data loader.
        model: model to optimise.
        loss_function: loss callable.
        optimizer: optimizer *class*, called as ``optimizer(model.parameters(), lr=lr)``.
        lr: learning rate.
        callback: optional per-batch callback forwarded to ``train_epoch``.
    """
    optim_instance = optimizer(model.parameters(), lr=lr)

    progress = tqdm(range(count_of_epoch), desc='epoch')
    progress.set_postfix({'train epoch loss': np.nan})
    for _ in progress:
        mean_loss = train_epoch(
            train_loader=loader,
            model=model,
            loss_function=loss_function,
            optimizer=optim_instance,
            callback=callback,
        )
        progress.set_postfix({'train epoch loss': mean_loss})


class Callback():
    """Per-batch training callback: logs loss and periodically evaluates edit distance.

    Every call logs the training loss to ``writer``. Every ``delimeter`` calls
    it greedily decodes the whole test loader with ``model.sample`` and logs
    the mean token-level edit distance between predictions and references.

    Args:
        writer: TensorBoard writer (anything with ``add_scalar``).
        test_loader: iterable of ``(images, captions)`` batches.
        loss_function: stored but not used by this callback.
        delimeter: run evaluation every ``delimeter`` steps.
        batch_size: stored but not used by this callback.
        bos_token_id: if given, a leading BOS is removed from each reference.
            The sampled sequence presumably starts after BOS, so the
            reference's BOS would otherwise always count as one edit.
        eos_token_id: if given, both prediction and reference are cut right
            after their first EOS, so padding and post-EOS tokens are not scored.
        With both ids left as None the metric is computed exactly as before.
    """

    def __init__(self, writer, test_loader, loss_function, delimeter=100, batch_size=64,
                 bos_token_id=None, eos_token_id=None):
        self.step = 0
        self.writer = writer
        self.delimeter = delimeter
        self.loss_function = loss_function
        self.batch_size = batch_size
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        self.loader = test_loader

    @staticmethod
    def _trim(tokens, bos_token_id=None, eos_token_id=None):
        """Drop a leading BOS and everything after the first EOS (EOS is kept)."""
        if bos_token_id is not None and tokens and tokens[0] == bos_token_id:
            tokens = tokens[1:]
        if eos_token_id is not None and eos_token_id in tokens:
            tokens = tokens[:tokens.index(eos_token_id) + 1]
        return tokens

    def forward(self, model, loss):
        """Log ``loss`` for this step; every ``delimeter`` steps also evaluate."""
        self.step += 1
        self.writer.add_scalar('LOSS/train', loss, self.step)

        if self.step % self.delimeter != 0:
            return

        pred = []
        real = []
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                for images, captions in tqdm(self.loader, leave=False):
                    output = model.sample(images.to(device))
                    pred.extend(torch.argmax(output, dim=-1).cpu().tolist())
                    real.extend(captions.tolist())
        finally:
            # Restore whatever mode the model was in before evaluation.
            model.train(was_training)

        distances = [
            edit_distance(self._trim(pred_sent, None, self.eos_token_id),
                          self._trim(real_sent, self.bos_token_id, self.eos_token_id))
            for pred_sent, real_sent in zip(pred, real)
        ]
        test_edit_disctance = np.mean(distances)

        self.writer.add_scalar('Edit_Disctance/test', test_edit_disctance, self.step)

    def __call__(self, model, loss):
        return self.forward(model, loss)

In [5]:
class LSTM_loss():
    """Next-token NLL loss for teacher-forced decoder outputs.

    The prediction at position ``t`` is scored against the target token at
    position ``t + 1``. The last prediction and the first target token are
    therefore dropped.

    Args:
        vocab_size: size of the last dimension of ``pred``.
        ignore_index: target id excluded from the loss, forwarded to
            ``F.nll_loss``. The default ``-100`` is ``F.nll_loss``'s own
            default, so real token ids are never ignored unless you opt in
            (e.g. pass a dedicated padding id).
    """

    def __init__(self, vocab_size, ignore_index=-100):
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index

    def __call__(self, pred, target):
        """Compute the mean shifted NLL.

        Args:
            pred: log-probabilities of shape (batch, seq, vocab_size).
            target: token ids of shape (batch, seq).

        Returns:
            Scalar loss tensor.
        """
        shifted_pred = pred[:, :-1, :].reshape(-1, self.vocab_size)
        shifted_target = target[:, 1:].reshape(-1)
        return F.nll_loss(shifted_pred, shifted_target, ignore_index=self.ignore_index)

In [6]:
# Per-channel ImageNet mean/std (the normalisation torchvision documents for
# its pretrained classification weights).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

transf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# All caption annotations of COCO-2014 train; the first 5000 of val as the test set.
train_dataset = ImageCaptionDataset(
    captions_path="annotations/captions_train2014.json",
    img_dir="train2014",
    transforms=transf,
    first_k=None,
    n_after_first_k=None,
)
test_dataset = ImageCaptionDataset(
    captions_path="annotations/captions_val2014.json",
    img_dir="val2014",
    transforms=transf,
    first_k=5000,
    n_after_first_k=None,
)

loading annotations into memory...
Done (t=0.66s)
creating index...
index created!
loading annotations into memory...
Done (t=0.34s)
creating index...
index created!


In [7]:
# Load the TensorBoard notebook extension and show the dashboard for logs under ./ on port 6006.
%load_ext tensorboard
%tensorboard --logdir ./ --port=6006

In [8]:
# --- Tokenizer ---------------------------------------------------------------
# The Mistral tokenizer presumably ships without a pad token, so EOS is reused
# for padding. EOS is appended to every caption so the decoder sees where a
# caption ends.
# NOTE(review): since pad == EOS, padding positions cannot be told apart from
# the real EOS and are therefore included in the loss.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "right"
tokenizer.add_eos_token = True

# --- Config ------------------------------------------------------------------
SEED = 0
torch.manual_seed(SEED)  # reproducible weight init and loader shuffling

MAX_LEN = 50  # caption truncation length in the collator and sampling length in the model

loss_function = LSTM_loss(vocab_size=len(tokenizer))

collator = Collator(tokenizer, max_len=MAX_LEN)

optimizer = torch.optim.AdamW  # optimizer class; instantiated inside trainer()
lr = 1e-4
hidden_dim = 768
batch_size = 70
test_step_size = 200
decoder_emb_dim = 64
n_epochs = 3

# --- Data & model --------------------------------------------------------------
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size,
                          collate_fn=collator)

test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size,
                         collate_fn=collator)

model = ImageCaptionModel(dict_size=len(tokenizer),
                          bos_token_id=tokenizer.bos_token_id,
                          embedding_size=hidden_dim,
                          decoder_emb_dim=decoder_emb_dim,
                          n_layers=1,
                          max_len=MAX_LEN).to(device)

# --- Training ------------------------------------------------------------------
writer = SummaryWriter(log_dir='Image_captioning')

callback = Callback(writer, test_loader, loss_function, delimeter=test_step_size)

trainer(count_of_epoch=n_epochs,
        batch_size=batch_size,
        loader=train_loader,
        model=model,
        loss_function=loss_function,
        optimizer=optimizer,
        lr=lr,
        callback=callback)

epoch:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5916 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/5916 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/5916 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

Результаты, хм… сомнительные: всё-таки для языковых задач лучше начинать хотя бы с tiny-BERT или брать огромную двунаправленную LSTM с attention, но, кажется, в задании ни того, ни другого не просили. В целом сеть запускается и даже обучается. Примеры работы делать не стал, поскольку уже устал от этой домашки.