## Task Description
Implement and evaluate a grid-representation-based Transformer model that generates text descriptions (captions) for images.

## Experimental Data
We use the DeepFashion-MultiModal dataset provided by the course, which contains images together with their text descriptions.

## Experimental Environment
- OS: Ubuntu 22.04.1, x86_64, kernel 6.2.0-39-generic
- GPU: NVIDIA GeForce RTX 3060 Laptop GPU
- CUDA: 12.2
- Python: 3.11.5 (conda)
- PyTorch: 2.1.1

In [None]:
import os
import json
from collections import Counter

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# Data Preprocessing

## Building the Dataset Class

In [None]:
image_path = "./data/deepfashion-multimodal/images/"
train_text = "./data/deepfashion-multimodal/train_captions.json"
test_text = "./data/deepfashion-multimodal/test_captions.json"

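class MyDataset(Dataset):
    """Pairs each training image with its encoded caption.

    image_paths: list of full image paths
    train_captions / test_captions: dicts keyed by image file name
    transform: optional torchvision transform applied to each image
    """

    def __init__(self, image_paths, train_captions, test_captions, transform=None):
        self.image_paths = image_paths
        self.train_captions = train_captions
        self.test_captions = test_captions  # stored, but not returned by __getitem__
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        idx_name = self.image_paths[idx]
        # captions are keyed by the bare file name
        file_name = idx_name.split("/")[-1]
        image = Image.open(idx_name).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.train_captions[file_name]

    def get_train_captions(self):
        return self.train_captions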

## Reading the Caption Text and Building the Image Transform

In [None]:
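# read the caption files (dicts mapping image file name -> caption)
with open(train_text, "r") as f:
    train_captions = json.load(f)
with open(test_text, "r") as f:
    test_captions = json.load(f)

# full paths of all training images
image_paths = [image_path + key for key in train_captions.keys()]

# resize to 1024x1024 so each image splits into a 4x4 grid of 256x256 patches,
# then map pixel values from [0, 1] to [-1, 1]
transform = transforms.Compose(
    [
        transforms.Resize((256 * 4, 256 * 4)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)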

## Encoding the Text Descriptions

This step maps every word of a description to an integer, so that each description is represented as a sequence of integers.

The steps are:

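- Build a vocabulary `vocab` that maps words to integers; the N distinct words get indices 1 to N
- Map 0 to the padding token `<pad>`
- Map N+1 to the end token `<end>`
- Map N+2 to the start token `<start>`
- Split each sentence on spaces and add `<start>` at the beginning and `<end>` at the end
- Pad each sentence with `<pad>` up to the maximum length
- Map every word of the sentence to its integer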
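The encoding scheme can be sketched in plain Python on two made-up captions. `build_vocab` and `encode` are hypothetical helper names and the captions are toy data, not DeepFashion-MultiModal samples; the index layout (0 = `<pad>`, words 1..N, N+1 = `<end>`, N+2 = `<start>`) mirrors the notebook code that follows.

```python
from collections import Counter


def build_vocab(captions):
    """Index layout: 0 = <pad>, words = 1..N, N+1 = <end>, N+2 = <start>."""
    counts = Counter()
    for text in captions.values():
        counts.update(text.split())
    # Counter keeps insertion order, so words are numbered by first appearance
    idx_to_word = {idx: word for idx, word in enumerate(counts, 1)}
    idx_to_word[0] = "<pad>"
    idx_to_word[len(idx_to_word)] = "<end>"    # N+1
    idx_to_word[len(idx_to_word)] = "<start>"  # N+2
    word_to_idx = {word: idx for idx, word in idx_to_word.items()}
    return word_to_idx, idx_to_word


def encode(captions, word_to_idx):
    """Wrap with <start>/<end>, map words to ids, right-pad with <pad> to the longest caption."""
    wrapped = {k: ["<start>", *v.split(), "<end>"] for k, v in captions.items()}
    max_len = max(len(tokens) for tokens in wrapped.values())
    pad = word_to_idx["<pad>"]
    return {
        k: [word_to_idx[w] for w in tokens] + [pad] * (max_len - len(tokens))
        for k, tokens in wrapped.items()
    }


captions = {"a.jpg": "the shirt is red", "b.jpg": "red pants"}
w2i, i2w = build_vocab(captions)
encoded = encode(captions, w2i)
print(encoded)  # {'a.jpg': [7, 1, 2, 3, 4, 6], 'b.jpg': [7, 4, 5, 6, 0, 0]}
```

Because there is no `<unk>` token, any word that was never seen while building the vocabulary cannot be encoded; keeping every word (as the notebook does with `threshold = -1`) sidesteps that for the training captions.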

In [None]:
image_descriptions = train_captions.values()  # live view of the caption strings

# build the vocabulary from the training captions
vocab = Counter()
for description in image_descriptions:
    vocab.update(description.split())

# keep words that occur at least `threshold` times; -1 keeps every word
# (there is no <unk> token, so filtering would make dropped words impossible to encode)
threshold = -1
words = [word for word, count in vocab.items() if count >= threshold]

# map index -> word; real words get indices 1..N
idx_to_word = {idx: word for idx, word in enumerate(words, 1)}

# add the special tokens: <pad> = 0, <end> = N+1, <start> = N+2
idx_to_word[0] = "<pad>"
idx_to_word[len(idx_to_word)] = "<end>"
idx_to_word[len(idx_to_word)] = "<start>"

# wrap every training caption with <start> ... <end>
for key, description in train_captions.items():
    train_captions[key] = "<start> " + description + " <end>"

# pad with <pad> up to the longest training caption (counting <start>/<end>);
# the test captions are padded to the same length but not wrapped with <start>/<end>
max_length = max(len(description.split()) for description in train_captions.values())
for key, description in train_captions.items():
    train_captions[key] = description + " <pad>" * (
        max_length - len(description.split())
    )
for key, description in test_captions.items():
    test_captions[key] = description + " <pad>" * (
        max_length - len(description.split())
    )

word_to_idx = {word: idx for idx, word in idx_to_word.items()}
vocab_size_len = len(idx_to_word)

# convert each word to its index
# (a test word that never occurs in the training captions would raise a KeyError here)
train_captions = {
    key: [word_to_idx[word] for word in value.split()]
    for key, value in train_captions.items()
}
test_captions = {
    key: [word_to_idx[word] for word in value.split()]
    for key, value in test_captions.items()
}

# show the first five encoded training captions
for idx, (key, value) in enumerate(list(train_captions.items())[:5]):
    print(key, value, idx)


## Building the Dataset

In [None]:
# store each encoded caption as a LongTensor so the DataLoader can batch them
for key, value in train_captions.items():
    train_captions[key] = torch.tensor(value, dtype=torch.long)

# make dataset
dataset = MyDataset(image_paths, train_captions, test_captions, transform)

# sanity check on the first sample
image0, caption0 = dataset[0]
print(image0.shape, caption0.shape)
print(image0, caption0)

## Defining the Model

The model consists of two main parts:
- resnet18
- transformer

### resnet18
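ResNet-18 outputs exactly 512-dimensional features, which is a common embedding dimension for Transformers, so we use it as the image feature extractor. We load pretrained weights and keep the CNN parameters fixed during the subsequent training.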

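Specifically, the transform resizes each image to (256*4, 256*4) = 1024×1024 before it is passed to `FeatureExtractorCNN`, which cuts it into a 4×4 grid of 16 non-overlapping 256×256 patches. ResNet-18 extracts a feature vector from each patch independently, and the 16 vectors are stacked into a (batch, 16, 512) tensor, which is transposed into a (16, batch, 512) input sequence.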
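The grid split can be sketched with `Tensor.unfold`. `split_into_grid` is a hypothetical helper name and the input is random data. One detail matters: after two `unfold` calls the tensor has shape (B, 3, 4, 4, 256, 256), with the channel axis *before* the grid axes, so the channel axis has to be moved behind the grid axes before flattening. Calling `.view(-1, 3, 256, 256)` directly on that layout would group channels from different patches into one "image".

```python
import torch


def split_into_grid(images: torch.Tensor, patch: int = 256) -> torch.Tensor:
    """(B, C, H, W) -> (B * rows * cols, C, patch, patch), patches in row-major order."""
    b, c, _, _ = images.shape
    # unfold height, then width: (B, C, rows, cols, patch, patch)
    grid = images.unfold(2, patch, patch).unfold(3, patch, patch)
    rows, cols = grid.shape[2], grid.shape[3]
    # move channels behind the grid axes so every patch keeps its own C channels
    grid = grid.permute(0, 2, 3, 1, 4, 5).contiguous()
    return grid.view(b * rows * cols, c, patch, patch)


images = torch.randn(2, 3, 1024, 1024)
patches = split_into_grid(images)
print(patches.shape)  # torch.Size([32, 3, 256, 256])
```

Patch `k` of image `n` sits at index `n * 16 + k`, with `k = row * 4 + col`, so a later `.view(batch, 16, -1)` on the per-patch features lines up with the images.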

### transformer
The Transformer is configured as follows:
```python
self.transformer = Transformer(
    d_model=emb_dim,
    nhead=8,
    dim_feedforward=1024,
    num_encoder_layers=1,
    num_decoder_layers=1,
    batch_first=False,
    dropout=0,
    bias=True,
)
```
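During debugging, however, we found that the encoder could not learn effectively from the CNN features: no matter what features the CNN produced, the encoder returned the same output. After several unsuccessful rounds of hyperparameter tuning, we removed the encoder and fed the CNN features directly into the decoder as its memory.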
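A minimal sketch of the resulting decoder-only setup using PyTorch's built-in `nn.TransformerDecoder`. Random tensors stand in for the ResNet-18 grid features, the vocabulary size of 100 is a toy value, and positional encoding is left out for brevity; all of these are assumptions for illustration. The 16 patch features serve as the cross-attention memory of shape (S, B, E) = (16, batch, 512), and a causal mask stops each caption position from attending to later positions.

```python
import torch
import torch.nn as nn

emb_dim, vocab_size, batch, num_patches, seq_len = 512, 100, 2, 16, 10

layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=8, dim_feedforward=1024, dropout=0.0)
decoder = nn.TransformerDecoder(layer, num_layers=1)  # sequence-first (batch_first=False)
embedding = nn.Embedding(vocab_size, emb_dim)
to_vocab = nn.Linear(emb_dim, vocab_size)

# stand-in for the CNN grid features used directly as decoder memory: (S, B, E)
memory = torch.randn(num_patches, batch, emb_dim)
# caption tokens, sequence-first: (T, B)
tokens = torch.randint(1, vocab_size, (seq_len, batch))

# causal mask: -inf above the diagonal, 0 elsewhere
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

hidden = decoder(embedding(tokens), memory, tgt_mask=causal)  # (T, B, E)
logits = to_vocab(hidden)                                      # (T, B, vocab_size)
print(logits.shape)  # torch.Size([10, 2, 100])
```

Cross-attention treats the memory as an unordered set, so without the encoder (and its positional encoding) the decoder has no explicit information about where each patch lies in the 4×4 grid.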

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn import Transformer

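# 1. Extract grid features with a pretrained CNN
class FeatureExtractorCNN(nn.Module):
    def __init__(self):
        super(FeatureExtractorCNN, self).__init__()
        # ImageNet-pretrained ResNet-18 without its final fc layer -> 512-d pooled features
        self.resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-1])

    def forward(self, images):
        # images: (batch, 3, 1024, 1024)
        batch_size = images.shape[0]

        # cut into a 4x4 grid of 256x256 patches: (batch, 3, 4, 4, 256, 256)
        images = images.unfold(2, 256, 256).unfold(3, 256, 256)
        # move the channel axis behind the grid axes so each patch keeps its own
        # 3 channels, then flatten to (batch * 16, 3, 256, 256) in row-major order
        images = images.permute(0, 2, 3, 1, 4, 5).contiguous().view(-1, 3, 256, 256)

        # extract features: (batch * 16, 512, 1, 1) -> (batch, 16, 512)
        features = self.resnet(images)
        features = features.view(batch_size, 16, -1)
        return features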

# 2. Build the Transformer model
class ImageCaptioningTransformer(nn.Module):
    def __init__(self, emb_dim, nhead, nhid, nlayers, vocab_size, max_seq_length):
        super(ImageCaptioningTransformer, self).__init__()
        self.pos_encoder = PositionalEncoding(emb_dim)
        # note: nhead, nhid and nlayers are accepted but the values below are hard-coded;
        # the encoder layer is created but not used in forward()
        self.transformer = Transformer(
            d_model=emb_dim,
            nhead=8,
            dim_feedforward=1024,
            num_encoder_layers=1,
            num_decoder_layers=1,
            batch_first=False,
            dropout=0,
            bias=True,
        )
        self.decoder = nn.Linear(emb_dim, vocab_size)  # decoder states -> vocabulary logits
        self.max_seq_length = max_seq_length
        self.embedding = nn.Embedding(vocab_size, emb_dim)

        # cached causal mask, rebuilt whenever the target length changes
        self.trg_mask = None

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_padding_mask=None,
                tgt_padding_mask=None, memory_mask=None):
        if self.trg_mask is None or self.trg_mask.size(0) != len(tgt):
            self.trg_mask = self.generate_square_subsequent_mask(len(tgt)).to(tgt.device)

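        trg_pad_mask = self.make_len_mask(tgt)  # (batch, tgt_len), True at <pad>

        # the encoder is skipped: the CNN grid features are used directly as decoder memory
        # (no positional encoding is added to the 16 patch features)
        # output = self.encode(src, src_mask=src_mask, src_padding_mask=src_padding_mask)
        output = src
        output = self.decode(tgt, output, tgt_mask=self.trg_mask, tgt_key_padding_mask=trg_pad_mask,
                                memory_mask=memory_mask, memory_key_padding_mask=src_padding_mask)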
        return output
    
    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def make_len_mask(self, inp):
        return (inp == 0).transpose(0, 1)
    
    def encode(self, src, src_mask=None, src_padding_mask=None):
        src = self.pos_encoder(src)  # 添加位置编码
        memory = self.transformer.encoder(src, mask=src_mask, src_key_padding_mask=src_padding_mask)
        return memory
    
    def decode(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        if self.trg_mask is None or self.trg_mask.size(0) != len(tgt):
            device = tgt.device
            self.trg_mask = self.generate_square_subsequent_mask(len(tgt)).to(device)

        tgt = self.embedding(tgt)
        tgt = self.pos_encoder(tgt)
        output = self.transformer.decoder(tgt, memory, tgt_mask=self.trg_mask,
                                          tgt_key_padding_mask=tgt_key_padding_mask,
                                          memory_mask=memory_mask,
                                          memory_key_padding_mask=memory_key_padding_mask)
        output = self.decoder(output)
        return output

# Helper class: positional encoding
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

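        # build the sinusoidal table:
        # pe[pos, 2i] = sin(pos / 10000^(2i/d_model)), pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # (max_len, 1, d_model) so it broadcasts over the batch of (seq_len, batch, d_model) inputs
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)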

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)

## Defining the Caption-Generation Function
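Captions are generated word by word. Generation starts from `<start>`; at each step the words generated so far are fed back into the model as input and the most likely next word is appended. Generation stops when `<end>` is produced (or the maximum length is reached), and the resulting index sequence is then converted back into words using the vocabulary.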
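The loop can be sketched independently of PyTorch. Here `step_logits` is a hypothetical callable standing in for "run the model on the current prefix and return the scores for the last position", and the scripted scorer with its five-token vocabulary is made-up test data.

```python
from typing import Callable, List, Sequence


def greedy_decode(step_logits: Callable[[List[int]], Sequence[float]],
                  start_id: int, end_id: int, max_length: int) -> List[int]:
    """Greedy decoding: feed the prefix back in and append the highest-scoring
    next token until <end> is predicted or the sequence holds max_length tokens."""
    tokens = [start_id]
    while len(tokens) < max_length:
        scores = step_logits(tokens)
        next_id = max(range(len(scores)), key=scores.__getitem__)  # argmax
        if next_id == end_id:
            break  # <end> itself is not appended
        tokens.append(next_id)
    return tokens


# toy vocabulary: 0 <pad>, 1 "red", 2 "shirt", 3 <end>, 4 <start>
def scripted(prefix: List[int]) -> List[float]:
    target = [1, 2, 3][min(len(prefix) - 1, 2)]
    return [1.0 if i == target else 0.0 for i in range(5)]


print(greedy_decode(scripted, start_id=4, end_id=3, max_length=10))  # [4, 1, 2]
```

Re-running the model on the full prefix at every step is simple but costs O(T²) decoder work per caption; that is usually acceptable for short captions like these.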

In [None]:
# Greedy word-by-word caption generation
def generate_caption(model, image_features, word2idx, idx2word, max_length=93):
    # image_features: (16, 1, emb_dim) decoder memory for a single image
    device = image_features.device
    outputs = [word2idx["<start>"]]

    for _ in range(max_length - 1):
        with torch.no_grad():
            # re-decode the whole prefix, shape (len(outputs), 1)
            out = model(image_features, torch.tensor(outputs).unsqueeze(1).to(device))

        # most likely word at the last position
        next_word = out.argmax(dim=2)[-1].item()

        # stop as soon as the <end> token is produced
        if next_word == word2idx["<end>"]:
            break

        outputs.append(next_word)

    # convert the index sequence back into words
    caption = [idx2word[idx] for idx in outputs]

    return " ".join(caption)

## 训练模型
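- Batch size: 8, the largest value our compute resources allowed
- Epochs: 5, with the parameters saved after every epoch
- Optimizer: Adam with learning rate 1e-3
- Loss: cross-entropy, with `<pad>` positions excluded from the loss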
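The last point can be checked on a toy example. The sketch below (random logits, a made-up 8-token vocabulary, one padded caption) shows the teacher-forcing shift used in training, where the decoder sees tokens `[0, T-1)` and is scored on tokens `[1, T)`, and confirms that `CrossEntropyLoss(ignore_index=0)` equals the mean loss over the non-`<pad>` targets only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size = 8  # toy: <pad>=0, words 1..5, <end>=6, <start>=7
# one padded caption, sequence-first (T, B) = (6, 1): <start> w1 w2 <end> <pad> <pad>
caption = torch.tensor([[7], [1], [2], [6], [0], [0]])

decoder_input = caption[:-1]  # what the decoder sees
target = caption[1:]          # what it must predict at each position

logits = torch.randn(decoder_input.size(0), decoder_input.size(1), vocab_size)

criterion = nn.CrossEntropyLoss(ignore_index=0)
loss = criterion(logits.reshape(-1, vocab_size), target.reshape(-1))

# the same value computed by hand over the non-pad targets only
mask = target.reshape(-1) != 0
manual = F.cross_entropy(logits.reshape(-1, vocab_size)[mask], target.reshape(-1)[mask])
print(loss.item(), manual.item())
```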

In [None]:
# model hyperparameters
emb_dim = 512  # embedding dimension
nhead = 8  # number of attention heads
nhid = 512  # feed-forward dimension (unused: the model hard-codes dim_feedforward=1024)
nlayers = 1  # number of encoder/decoder layers (unused: hard-coded to 1)
vocab_size = vocab_size_len  # vocabulary size
max_seq_length = 512  # maximum sequence length

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# create the model instances
cnn = FeatureExtractorCNN().to(device)
transformer = ImageCaptioningTransformer(
    emb_dim, nhead, nhid, nlayers, vocab_size, max_seq_length
).to(device)

print(transformer)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

# the CNN stays frozen; eval mode keeps its BatchNorm statistics fixed
cnn.eval()

In [None]:
transformer = transformer.to(device)
# resume from a previously saved checkpoint
transformer.load_state_dict(torch.load("./model/transformer_re4.pth", map_location=device))

# train
epoch = 5

optimizer = torch.optim.Adam(transformer.parameters(), 0.001)

# cross-entropy loss that ignores <pad> (index 0) targets
criterion = nn.CrossEntropyLoss(ignore_index=0)

transformer.train()
for e in range(epoch):
    for i, (image, caption) in enumerate(dataloader):
        caption = caption.to(device)

        optimizer.zero_grad()
        # extract image features with the frozen CNN (no gradients needed)
        with torch.no_grad():
            features = cnn(image.to(device))

        # sequence-first layout: features (16, batch, 512), caption (seq_len, batch)
        features = features.transpose(0, 1)
        caption = caption.transpose(0, 1)

        # teacher forcing: feed tokens [0, T-1) and predict tokens [1, T)
        output = transformer(features, caption[:-1, :])

        loss = criterion(output.reshape(-1, output.size(-1)), caption[1:, :].reshape(-1))

        if i % 20 == 0:
            print("epoch: {}, step: {}, loss: {}".format(e, i, loss.item()))

        loss.backward()
        optimizer.step()

    torch.save(transformer.state_dict(), "./model/transformer_thd{}.pth".format(e+2))

## Showing Results
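This cell picks one image from the (training) dataset, displays it, generates a caption for it, and prints the reference caption for comparison.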
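Turning an id sequence back into text can also be written as a small helper. `decode_ids` is a hypothetical name, and the toy vocabulary is made up. Unlike the cell below, which only skips `<pad>`, this variant also drops `<start>` and `<end>` so the result reads as plain text.

```python
def decode_ids(ids, idx_to_word, skip=("<pad>", "<start>", "<end>")):
    """Map token ids back to words, dropping special tokens.
    int() also accepts numpy integers and 0-d torch tensors."""
    words = (idx_to_word[int(i)] for i in ids)
    return " ".join(w for w in words if w not in skip)


toy_vocab = {0: "<pad>", 1: "red", 2: "shirt", 3: "<end>", 4: "<start>"}
print(decode_ids([4, 1, 2, 3, 0, 0], toy_vocab))  # red shirt
```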

In [None]:
import matplotlib.pyplot as plt

test_data = dataset[330]
print(test_data[0].shape, test_data[1].shape)

test_image = test_data[0].unsqueeze(0).to(device)
with torch.no_grad():
    test_feature = cnn(test_image)
test_feature = test_feature.transpose(0, 1)  # (16, 1, 512)

transformer.load_state_dict(torch.load("./model/transformer_re4.pth", map_location=device))
transformer.eval()

# show the image (undo Normalize((0.5,), (0.5,)) to get back to [0, 1])
plt.imshow((test_data[0] * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy())
plt.show()

output = generate_caption(transformer, test_feature, word_to_idx, idx_to_word)

print("output: ", output)

# rebuild the reference sentence, skipping <pad>
caption = test_data[1].numpy()
sentence = ""
for idx in caption:
    if idx == 0:
        continue
    sentence += idx_to_word[idx.item()] + " "

print("reference: ", sentence)

## Generating Captions for the Whole Test Set
Generate a caption for every test image and save the results to a JSON file for the subsequent metric computation.
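The post-processing step of the cell below, which strips the leading `<start>` and keys each caption by image file name before writing JSON, can be sketched as a standalone helper. `save_results` is a hypothetical name, and the path and caption in the usage lines are made-up examples.

```python
import json
import os
import tempfile
from pathlib import Path


def save_results(captions_by_path, out_file):
    """Strip a leading <start> token, key each caption by image file name, write JSON."""
    res = {}
    for path, caption in captions_by_path.items():
        words = caption.split(" ")
        if words and words[0] == "<start>":
            words = words[1:]
        res[Path(path).name] = " ".join(words)
    with open(out_file, "w") as f:
        json.dump(res, f)
    return res


out_file = os.path.join(tempfile.mkdtemp(), "result.json")
res = save_results({"./data/images/a.jpg": "<start> red shirt"}, out_file)
print(res)  # {'a.jpg': 'red shirt'}
```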

In [None]:
transformer.eval()
with open(test_text, "r") as f:
    test_data = json.load(f)

# full paths of all test images
test_image_paths = [image_path + key for key in test_data.keys()]

res = {}

for i, path in enumerate(test_image_paths):
    # only the first 620 test images are captioned
    if i == 620:
        break

    test_image = Image.open(path).convert("RGB")
    test_image = transform(test_image).unsqueeze(0).to(device)
    with torch.no_grad():
        test_feature = cnn(test_image)

    test_feature = test_feature.transpose(0, 1)

    output = generate_caption(transformer, test_feature, word_to_idx, idx_to_word)

    # drop the leading <start> token
    output = " ".join(output.split(" ")[1:])

    # key results by image file name
    res[path.split("/")[-1]] = output

with open("./result2.json", "w") as f:
    json.dump(res, f)

## Computing Metrics
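We evaluate three different metrics using our own metric implementation (`DeepFashionEvalCap` in the local `evaluate` module).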

In [None]:
import evaluate  # our own metric module (provides DeepFashionEvalCap)
import json

test_text = "./data/deepfashion-multimodal/test_captions.json"

# generated captions (note: the cell above writes ./result2.json)
with open("./result.json", "r") as f:
    test_data = json.load(f)
with open(test_text, "r") as f:
    real_data = json.load(f)

# keep only the reference captions of the images that were actually captioned
real_selected_data = {key: real_data[key] for key in test_data}
test_selected_data = dict(test_data)

evaluator = evaluate.DeepFashionEvalCap(real_selected_data, test_selected_data)
evaluator.evaluate()