In [None]:
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers.optimization import get_scheduler
from transformers import BertTokenizer, GPT2LMHeadModel, GPT2Model, CLIPFeatureExtractor, CLIPVisionModel, logging, AdamW

In [None]:
# Device that the models and tensors below are moved to.
# NOTE(review): both branches of this conditional evaluate to 'cpu', so the
# notebook always runs on the CPU even when CUDA is available. If that is not
# deliberate, the first branch was presumably meant to be 'cuda' -- before
# changing it, confirm that every later cell moves its inputs to `device`.
device = 'cpu' if torch.cuda.is_available() else 'cpu'

In [None]:
logging.set_verbosity_error()   # silence the "unused weights" warnings emitted while loading checkpoints

# cv2.imread yields pixels in BGR channel order and returns None when the file
# cannot be read. Fail early with a clear message, then convert to RGB (the
# same conversion cap_img() applies) before the CLIP pre-processor sees it.
_bgr_img = cv2.imread('./images/baby.jpg')
if _bgr_img is None:
    raise FileNotFoundError("could not read image './images/baby.jpg'")
# HWC -> CHW so the tensor is channel-first.
img = torch.from_numpy(cv2.cvtColor(_bgr_img, cv2.COLOR_BGR2RGB).transpose(2,0,1))
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")
vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
text_model = GPT2Model.from_pretrained("uer/gpt2-chinese-cluecorpussmall").to(device)

In [None]:
# Freeze both pretrained backbones: they are not trained, so no gradients
# need to be computed for their weights.
for frozen_module in (vision_model, text_model):
    for weight in frozen_module.parameters():
        weight.requires_grad = False

In [None]:
# Tokenizer loaded from the same checkpoint name as the GPT-2 text model above;
# it is used later to encode captions and to decode generated ids.
tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")

### 定义模型

In [None]:
# Downstream captioning model
class Model(torch.nn.Module):
    """Caption scorer that feeds an image-derived prefix into a GPT-2 backbone.

    Flattened image features are projected by a linear layer into
    ``prefix_len`` embeddings of width 768, concatenated in front of the
    caption token embeddings, run through ``gpt``, and projected to vocabulary
    logits by ``fc`` (initialised from the pretrained LM head).

    Args:
        gpt: GPT-2 backbone; must expose ``wte`` and accept ``inputs_embeds``.
        prefix_len: number of prefix embeddings produced from the image.
        const_len: stored on the instance for interface compatibility. No
            constant embeddings are appended to the prefix, so it has no
            effect on ``forward``.
    """

    def __init__(self, gpt: GPT2Model, prefix_len, const_len):
        super().__init__()
        self.prefix_len = prefix_len
        self.const_len = const_len

        # Projects the flattened image features (257*1024 inputs -- presumably
        # the flattened CLIP ViT-L/14 last_hidden_state) to prefix_len x 768.
        self.mapping_vision = torch.nn.Linear(257*1024, prefix_len*768)
        self.text_gen = gpt
        # Initialise the output projection from the pretrained LM head. Its
        # dimensions are taken from the head itself so the layer always matches
        # the weights being loaded, without relying on a notebook-global tokenizer.
        lm_head = GPT2LMHeadModel.from_pretrained('uer/gpt2-chinese-cluecorpussmall').lm_head
        self.fc = torch.nn.Linear(lm_head.in_features, lm_head.out_features, bias=False)
        self.fc.load_state_dict(lm_head.state_dict())

        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, img_feature, labels):
        """Compute next-token loss and logits for ``labels`` given the image.

        Args:
            img_feature: tensor whose first dimension is the batch size and
                whose last dimension matches ``mapping_vision``'s input size.
            labels: caption token ids, indexed as (batch, caption_len).

        Returns:
            dict with ``'loss'`` (scalar cross-entropy) and ``'logits'``
            (batch, caption_len, vocab), where ``logits[:, j]`` is the
            prediction for ``labels[:, j]``.
        """
        bs = img_feature.shape[0]

        # Map the image features into the text embedding space.
        prefix_embeddings = self.mapping_vision(img_feature).view(bs, self.prefix_len, 768)
        label_embeddings = self.text_gen.wte(labels)

        # Text model forward pass; no attention_mask is passed, so the model's default is used.
        inputs_embeds = torch.cat([prefix_embeddings, label_embeddings], dim=1)
        hidden = self.text_gen(inputs_embeds=inputs_embeds).last_hidden_state

        # The hidden state at position t predicts the token at t + 1, so keep
        # the positions from the last prefix slot up to the second-to-last
        # caption token. The offset is the number of prefix embeddings actually
        # concatenated, which keeps logits and labels aligned for any const_len.
        n_prefix = prefix_embeddings.shape[1]
        logits = self.fc(hidden)[:, n_prefix - 1:-1]

        # Cross-entropy over every caption position.
        loss = self.criterion(logits.flatten(end_dim=1), labels.flatten())

        return {
            'loss': loss,
            'logits': logits
        }

In [None]:
# Build the captioning model with a 5-slot prefix and no constant slots, then
# pre-process the sample image into CLIP pixel values on the target device.
model = Model(gpt=text_model, prefix_len=5, const_len=0)
model = model.to(device)
clip_inputs = feature_extractor(img, return_tensors='pt')
pixel_values = clip_inputs['pixel_values'].to(device)

In [None]:
# Run the vision encoder once and flatten its last hidden state into a single
# row vector, the layout Model.mapping_vision consumes.
vision_outputs = vision_model(pixel_values)
img_feature = vision_outputs.last_hidden_state.view(1, -1)

In [None]:
# Encode a sample caption without special tokens and run one forward pass as a
# smoke test; the returned dict is shown as the cell output.
caption_text = '我爱的宝宝'
label = tokenizer.encode(caption_text, return_tensors='pt', add_special_tokens=False).to(device)
model(img_feature, label)

In [None]:
# Quick sanity-check fine-tuning: a few optimisation steps on the single
# (image feature, caption) pair prepared above.
epoches = 2
model.train()
model.text_gen.eval()   # the GPT-2 backbone is not trained, so keep it in eval mode
optimizer = AdamW(model.parameters(), lr = 5e-6)
scheduler = get_scheduler(name='linear',
                            num_warmup_steps=0,
                            num_training_steps=epoches,
                            optimizer=optimizer)

for i in range(epoches):
    loss = model(img_feature, label)['loss']
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # guard against exploding gradients

    optimizer.step()
    scheduler.step()

    # The optimizer was built from model.parameters(), so this single call
    # clears every gradient the model owns.
    optimizer.zero_grad()

    # Print the scalar value (as train() does) rather than the tensor repr.
    print(loss.item())

In [None]:
def train():
    """Fine-tune the global ``model`` for one pass over the global ``loader``.

    Each batch from ``loader`` must be a dict of tensors whose keys match the
    keyword arguments of ``Model.forward`` (the logging code reads
    ``data['labels']``). ``loader`` is read from the notebook's global
    namespace -- TODO confirm where it is created; it is not defined in the
    cells visible here.

    Every 50 steps the step index, loss, learning rate and token accuracy are
    printed. The model is rebound to the CPU copy when training finishes.
    """
    global model
    # NOTE(review): both branches are 'cpu', so training always runs on the CPU.
    device = 'cpu' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    optimizer = AdamW(model.parameters(), lr = 2e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    model.train()
    # The GPT-2 backbone is not trained; keep it in eval mode, as the
    # single-sample training cell does.
    model.text_gen.eval()
    for i, data in enumerate(loader):
        batch = {key: value.to(device) for key, value in data.items()}

        out = model(**batch)
        loss = out['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # guard against exploding gradients

        optimizer.step()
        scheduler.step()

        # The optimizer holds all of model.parameters(), so one call suffices.
        optimizer.zero_grad()

        if i % 50 == 0:
            # Model.forward already returns logits aligned with the labels
            # (logits[:, j] predicts labels[:, j]), so compare them position
            # by position without any additional shift.
            labels = batch['labels']
            predictions = out['logits'].argmax(dim=2)

            accuracy = (predictions == labels).sum().item() / labels.numel()

            lr = optimizer.param_groups[0]['lr']

            print(i, loss.item(), lr, accuracy)

    model = model.to('cpu')

# Run the training loop, then serialise the whole model object (not just a
# state_dict) to disk with torch.save.
train()
torch.save(model, './models/en_gen.model')

In [None]:
def generate(img_feature, model, tokenizer, max_length, num_samples, num_sequences=5):
    """Sample captions for one image with top-k sampling and print them.

    Args:
        img_feature: flattened image features for a single image, shaped so
            that ``model.mapping_vision`` yields ``model.prefix_len * 768`` values.
        model: object exposing ``mapping_vision``, ``prefix_len``, ``text_gen``
            (with ``wte``) and ``fc``.
        tokenizer: object with a ``decode`` method for token ids.
        max_length: sampling stops once prefix plus generated tokens reach this
            length; at least one token is always generated.
        num_samples: the ``k`` of top-k sampling -- only the ``k`` highest
            scoring tokens can be drawn at each step.
        num_sequences: how many independent captions to sample (default 5).

    Returns:
        None. Each caption is printed as ``index decoded_text``.
    """
    # Generate with the text model in eval mode regardless of the mode the
    # caller left it in, and restore that mode afterwards.
    gpt_was_training = model.text_gen.training
    model.text_gen.eval()
    try:
        # Nothing here is trained, so skip autograd bookkeeping entirely.
        with torch.no_grad():
            img_feature = img_feature.to(model.mapping_vision.weight.device)
            prefix = model.mapping_vision(img_feature).view(1, model.prefix_len, 768)

            # Every sequence starts from the same image prefix.
            embeddings = prefix.expand(num_sequences, model.prefix_len, 768)
            sampled = []
            while True:
                hidden = model.text_gen(inputs_embeds=embeddings).last_hidden_state
                scores = model.fc(hidden)[:, -1]

                # Mask everything below the k-th largest score in each row.
                kth_score = torch.topk(scores, num_samples).values[:, -1].unsqueeze(dim=1)
                scores = scores.masked_fill(scores < kth_score, -float('inf'))

                # Draw one token per sequence from the renormalised top-k distribution.
                next_ids = scores.softmax(dim=1).multinomial(num_samples=1)
                sampled.append(next_ids)

                # Feed the sampled token back in as the next input embedding.
                embeddings = torch.cat([embeddings, model.text_gen.wte(next_ids)], dim=1)
                if embeddings.shape[1] >= max_length:
                    break

            output = torch.cat(sampled, dim=1)
    finally:
        model.text_gen.train(gpt_was_training)

    for i in range(num_sequences):
        print(i, tokenizer.decode(output[i].flatten(), add_special_tokens=False))

In [None]:
# Sample captions for the image feature computed earlier, drawing each token
# from the 10 highest-scoring candidates.
generate(img_feature, model, tokenizer, max_length=10, num_samples=10)

In [None]:
def cap_img(img_path: str, feature_extractor: CLIPFeatureExtractor, vision_model: CLIPVisionModel):
    """Display the image at ``img_path`` and print sampled captions for it.

    Uses the notebook-global ``model`` and ``tokenizer`` for text generation.

    Args:
        img_path: path of the image file to caption.
        feature_extractor: CLIP pre-processor producing ``pixel_values``.
        vision_model: CLIP vision encoder whose last hidden state is flattened
            into the image feature.

    Raises:
        FileNotFoundError: if OpenCV cannot read an image from ``img_path``.
    """
    # Load the image. cv2.imread returns None instead of raising when the file
    # is missing or unreadable, so check explicitly.
    bgr = cv2.imread(img_path)
    if bgr is None:
        raise FileNotFoundError(f'could not read image: {img_path!r}')
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    print('你输入的图片: ')
    plt.imshow(img)
    plt.show()
    # HWC -> CHW so the tensor is channel-first.
    img = torch.from_numpy(img.transpose(2,0,1))

    # Pre-process and encode the image; the encoder is only used for inference,
    # so run it without autograd and on whichever device it lives on.
    pixel_values = feature_extractor(img, return_tensors='pt')['pixel_values']
    with torch.no_grad():
        pixel_values = pixel_values.to(next(vision_model.parameters()).device)
        img_feature = vision_model(pixel_values).last_hidden_state.view(1, -1)

    # Generate the caption text.
    print('生成的文字: ')
    generate(img_feature, model, tokenizer, max_length=10, num_samples=5)

# Caption the sample image end to end: show it, encode it, print sampled captions.
cap_img('./images/baby.jpg', feature_extractor, vision_model)