In [1]:
from transformers import ResNetConfig,ResNetModel,BertTokenizer,Trainer
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import torch
from torch import nn
import torchvision
from torchvision import transforms, datasets

In [2]:
# Tokenizer: WordPiece vocabulary read from a local vocab file.
tokenizer = BertTokenizer('./vocab.txt')

# Image encoder: a ResNet built from the library's default configuration.
# It is constructed from a config only, so no pretrained weights are loaded here.
config_encoder = ResNetConfig()
model_encoder = ResNetModel(config_encoder)

# Decoder: a stock PyTorch Transformer that takes (batch, seq, feature) tensors.
model_decoder = nn.Transformer(batch_first=True)

In [3]:
# Tokenizer smoke test: encode one caption and stack it twice to mimic a batch of 2.
test_caption = 'Krxk is a great developer.'
t = tokenizer.encode(test_caption, return_tensors='pt').repeat((2, 1))
print(t)
print(t.shape)

tensor([[ 101,  180, 1197, 1775, 1377, 1110,  170, 1632, 9991,  119,  102],
        [ 101,  180, 1197, 1775, 1377, 1110,  170, 1632, 9991,  119,  102]])
torch.Size([2, 11])


In [4]:
# model
class Krxk_model(nn.Module):
    """Prototype image-captioning model: image encoder -> linear projection -> Transformer.

    The encoder's ``last_hidden_state`` is flattened and projected to a single
    ``d_model``-sized vector per image. That vector is appended (along the
    sequence axis) to the embedded tokens of a fixed placeholder caption, and the
    resulting sequence is fed to ``decoder`` as ``src``.

    Args:
        tokenizer: Object exposing ``vocab_size`` and
            ``encode(text, return_tensors='pt')``.
        encoder: Module whose call result has a ``last_hidden_state`` attribute.
        decoder: Module called as ``decoder(src, tgt)`` with batch-first tensors.
        encoder_out_features: Number of elements per sample in the encoder's
            ``last_hidden_state`` after flattening. The default ``2048 * 7 * 7``
            is the value this notebook uses with 224x224 inputs.
        d_model: Feature size shared by the token embedding, the projected image
            feature and the decoder input.
    """

    def __init__(self, tokenizer, encoder, decoder,
                 encoder_out_features=2048 * 7 * 7, d_model=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.encoder = encoder
        self.embed_size = self.tokenizer.vocab_size
        self.transformer_encoder_in_nums = d_model
        # Flatten the encoder feature map and project it to one d_model vector per image.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(encoder_out_features, self.transformer_encoder_in_nums),
        )
        # Token embedding with the same width as the Transformer input.
        self.embed = nn.Embedding(self.embed_size, self.transformer_encoder_in_nums)
        self.decoder = decoder

        # Temporary placeholder caption used until real captions are wired in.
        test_caption = 'Krxk is a great developer.'
        print(test_caption)
        caption_ids = self.tokenizer.encode(test_caption, return_tensors='pt').repeat(6, 1)
        # Registered as a non-persistent buffer so it follows model.to(device)
        # without being added to state_dict. The 6-row shape is kept only for
        # backward compatibility; forward() uses the first row for any batch size.
        self.register_buffer('test_caption_tokenized', caption_ids, persistent=False)
        print(self.test_caption_tokenized.shape)
        print(self.embed(self.test_caption_tokenized).shape)

    def forward(self, X):
        """Run the encoder, fuse image features with the placeholder caption, and decode.

        Args:
            X: Batch of inputs accepted by ``encoder``; any batch size is supported.

        Returns:
            The output of ``decoder(src, tgt)``, where ``src`` is the caption
            embeddings followed by one image-feature step, and ``tgt`` is a
            placeholder tensor of ones with the same shape as ``src``.
        """
        encode = self.encoder(X).last_hidden_state
        features = self.fc(encode)  # (batch, d_model) image feature

        # Embed the placeholder caption once and broadcast it over the actual
        # batch size instead of assuming a batch of exactly 6 images.
        batch = features.shape[0]
        embeddings = self.embed(self.test_caption_tokenized[:1]).expand(batch, -1, -1)
        f = features.unsqueeze(1)  # (batch, 1, d_model)
        print(embeddings.shape, f.shape)
        # Append the image feature as one extra step along the sequence axis.
        embeddings = torch.cat((embeddings, f), dim=1)

        # The decoder needs both src and tgt; tgt is a stand-in of ones for now
        # (it should become the ground-truth caption once training is implemented).
        decode = self.decoder(embeddings, torch.ones_like(embeddings))
        return decode

In [5]:
# Build the captioning model from the tokenizer, image encoder and Transformer defined above.
model = Krxk_model(
    tokenizer=tokenizer,
    encoder=model_encoder,
    decoder=model_decoder,
)

Krxk is a great developer.
torch.Size([6, 11])
torch.Size([6, 11, 512])


In [6]:
# Image smoke-test data: every image under pic_dir is resized to 224x224 and
# converted to a tensor, then served in batches.
pic_dir = './pic/'
batch_size = 32

data_args = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
data_set = datasets.ImageFolder(root=pic_dir, transform=data_args)
data_iter = torch.utils.data.DataLoader(dataset=data_set, batch_size=batch_size)

In [7]:
# Model smoke test: push a single batch through the model and print the shapes.
# eval() only switches layers such as dropout / batch norm to inference behaviour;
# torch.no_grad() is what actually disables autograd tracking and saves compute.
model.eval()
with torch.no_grad():
    for X, _ in data_iter:
        print(X.shape)
        # Call the module itself rather than .forward() so registered hooks run.
        generated_ids = model(X)
        print(generated_ids.shape)
        break

torch.Size([6, 3, 224, 224])
torch.Size([6, 11, 512]) torch.Size([6, 1, 512])
torch.Size([6, 12, 512])


In [None]:
# 训练