In [None]:
import time
import json
import torch
import itertools
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from transformers import BertModel
from torch.nn import functional as F
from transformers import BertTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import BertConfig
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertEmbeddings, BertEncoder

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
    print('Device name:', torch.cuda.get_device_name(0))

else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) available.
Device name: NVIDIA A40


In [None]:
train_text_inputs = torch.load('cls_train_text_inputs.pth')
train_text_masks = torch.load('cls_train_text_masks.pth')
train_text_attribute_label = torch.load('train_text_attribute_label.pth')
train_text_word_together = torch.load('train_text_word_together.pth')

In [None]:
train_img_obj_feat = torch.load('train_img_obj_feat.pth')

In [None]:
train_img_obj_bbox = torch.load('train_img_obj_bbox.pth')

In [6]:
train_img_ocr_feat = torch.load('train_img_ocr_feat.pth')

In [7]:
train_img_ocr_bbox = torch.load('train_img_ocr_bbox.pth')

In [8]:
train_img_masks = torch.load('cls_train_img_masks.pth')

In [9]:
train_edit_distance = torch.load('train_edit_distance.pth')

In [10]:
val_text_inputs = torch.load('cls_val_text_inputs.pth')

In [11]:
val_text_masks = torch.load('cls_val_text_masks.pth')

In [12]:
val_text_attribute_label = torch.load('val_text_attribute_label.pth')

In [13]:
val_text_word_together = torch.load('val_text_word_together.pth')

In [14]:
val_img_obj_feat = torch.load('val_img_obj_feat.pth')

In [15]:
val_img_obj_bbox = torch.load('val_img_obj_bbox.pth')

In [16]:
val_img_ocr_feat = torch.load('val_img_ocr_feat.pth')
val_img_ocr_bbox = torch.load('val_img_ocr_bbox.pth')
val_img_masks = torch.load('cls_val_img_masks.pth')
val_edit_distance = torch.load('val_edit_distance.pth')

In [17]:
# 因为连着的5个都是正样本，所以我们要把正样本分开

In [18]:
# 上Dataloader
batch_size = 128

train_data = TensorDataset(train_text_inputs, train_text_masks, train_text_attribute_label, train_text_word_together,
                           train_img_obj_feat, train_img_obj_bbox, train_img_ocr_feat, train_img_ocr_bbox, train_img_masks, train_edit_distance)

train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)


In [19]:
val_data = TensorDataset(val_text_inputs, val_text_masks, val_text_attribute_label, val_text_word_together,
                         val_img_obj_feat, val_img_obj_bbox, val_img_ocr_feat, val_img_ocr_bbox, val_img_masks, val_edit_distance)
val_sampler = SequentialSampler(val_data)
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)

In [20]:
config = BertConfig()

In [21]:
class OBJ_Encoder(nn.Module):

    def __init__(self,):
        super(OBJ_Encoder, self).__init__()

        self.num_obj_cdn_in = 2048  # 2048-dim
        self.num_obj_bbox_in = 4  # 4-dim
        self.num_obj_out = 768  # 768-dim
        self._build_obj_encoding()

    def _build_obj_encoding(self):

        # cdn feature
        self.cdn_to_mmt_in = nn.Linear(
            self.num_obj_cdn_in, self.num_obj_out
        )

        # OBJ location feature: bounding box coordinates (4-dim)
        self.linear_obj_bbox_to_mmt_in = nn.Linear(
            self.num_obj_bbox_in, self.num_obj_out
        )

        self.obj_layer_norm = nn.LayerNorm(self.num_obj_out)
        self.obj_bbox_layer_norm = nn.LayerNorm(self.num_obj_out)

    def forward(self, obj_cdistnet_in, obj_bbox_in):

        obj_feat = self._forward_obj_encoding(
            obj_cdistnet_in, obj_bbox_in)

        return obj_feat

    def _forward_obj_encoding(self, obj_cdistnet_in, obj_bbox_in):

        # OBJ appearance feature: cdistnet
        obj_cdn = F.normalize(obj_cdistnet_in, dim=-1)

        # MLP + LN
        cdn_feat = self.obj_layer_norm(
            self.cdn_to_mmt_in(obj_cdn))

        obj_bbox = obj_bbox_in

        bbox_feat = self.obj_bbox_layer_norm(
            self.linear_obj_bbox_to_mmt_in(obj_bbox))

        # obj_feat
        obj_mmt_in = (
            cdn_feat + bbox_feat

        )
        return obj_mmt_in


class OCR_Encoder(nn.Module):

    def __init__(self,):
        super(OCR_Encoder, self).__init__()

        self.num_ocr_cdn_in = 2048  # 2048-dim
        self.num_ocr_bbox_in = 4  # 4-dim
        self.num_ocr_out = 768  # 768-dim
        self._build_ocr_encoding()

    def _build_ocr_encoding(self):

        # ft_concat_cdn feature
        self.cdn_to_mmt_in = nn.Linear(
            self.num_ocr_cdn_in, self.num_ocr_out
        )

        # OCR location feature: bounding box coordinates (4-dim)
        self.linear_ocr_bbox_to_mmt_in = nn.Linear(
            self.num_ocr_bbox_in, self.num_ocr_out
        )

        self.ocr_layer_norm = nn.LayerNorm(self.num_ocr_out)
        self.ocr_bbox_layer_norm = nn.LayerNorm(self.num_ocr_out)

    def forward(self, ocr_cdistnet_in, ocr_bbox_in):

        ocr_feat = self._forward_ocr_encoding(
            ocr_cdistnet_in, ocr_bbox_in)

        return ocr_feat

    def _forward_ocr_encoding(self, ocr_cdistnet_in, ocr_bbox_in):

        # OCR appearance feature: cdistnet
        ocr_cdn = F.normalize(ocr_cdistnet_in, dim=-1)

        # MLP + LN
        cdn_feat = self.ocr_layer_norm(
            self.cdn_to_mmt_in(ocr_cdn))

        ocr_bbox = ocr_bbox_in

        bbox_feat = self.ocr_bbox_layer_norm(
            self.linear_ocr_bbox_to_mmt_in(ocr_bbox))

        # ocr_feat
        ocr_mmt_in = (
            cdn_feat + bbox_feat
        )
        return ocr_mmt_in


class CrossModel(nn.Module):
    def __init__(self,):
        super(CrossModel, self).__init__()
        # self.config = config
        self.encoder = BertEncoder(config)

        # 上面的Encoder用于visual，文本的我们就用预训练好的
        # 实体化Bert模型
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        
        # img cls linear
        self.img_cls_Linear =  nn.Linear(768, 768)
        # img obj Encoder
        self.obj_encoder = OBJ_Encoder()
        # img ocr Encoder
        self.ocr_encoder = OCR_Encoder()
        """
        # output L
        # text L
        self.classifier1 = nn.Sequential(
            nn.Linear(768, 500),  # 全连接
            nn.Dropout(0.2),  # drop 50% neurons
            nn.ReLU(),  # 激活函数
            nn.Linear(500, 100),  # 全连接
        )  # 这里是2层的全连接
        # img L
        self.classifier2 = nn.Sequential(
            nn.Linear(768, 500),  # 全连接
            nn.Dropout(0.2),  # drop 50% neurons
            nn.ReLU(),  # 激活函数
            nn.Linear(500, 100),  # 全连接
        )  # 这里是2层的全连接
        """
        

    def forward(self, text_inputs, text_masks, b_img_obj_feat, b_img_obj_bbox, b_img_ocr_feat, b_img_ocr_bbox, img_masks):
        # 文本用bert
        text_outputs = self.bert(
            input_ids=text_inputs, attention_mask=text_masks)  # 这是现有的
        # 文本特征最后隐层输出
        text_cls = text_outputs[0][:,0,:]
        text_last_hidden_state = text_outputs[0][:,1:,:]
        
        # img_obj
        img_obj = self.obj_encoder(b_img_obj_feat, b_img_obj_bbox)

        # img_ocr
        img_ocr = self.ocr_encoder(b_img_ocr_feat, b_img_ocr_bbox)
        
        # img_cls
        length_img = len(img_ocr)
        cls_token = torch.ones((length_img,1,768),dtype=torch.float32,requires_grad=True).to(device)
        img_cls_token = F.normalize(self.img_cls_Linear(cls_token))

        # 图片用transformer的encoder
        img_inputs = torch.cat((img_cls_token, img_obj, img_ocr), dim=1)  # 连接obj和ocr输入特征 obj:36 ocr:10
        # 把mask转换一下
        img_masks = img_masks[:, None, None, :]
        img_masks = (1.0 - img_masks) * -10000.0
        img_outputs = self.encoder(img_inputs, attention_mask=img_masks)
        # 图像特征最后隐层输出
        img_cls = img_outputs['last_hidden_state'][:,0,:]
        img_last_hidden_state = img_outputs['last_hidden_state'][:,1:,:]
        

        return text_cls, text_last_hidden_state, img_cls, img_last_hidden_state


In [22]:
def initialize_model(epochs=2):
    """
    初始化，优化器还有学习率，epochs就是训练次数
    """
    # 初始化我们的Bert分类器
    cross_model = CrossModel()
    # 用GPU运算
    cross_model.to(device)
    # 创建优化器

    optimizer = AdamW(cross_model.parameters(),
                       lr=5e-5,  # 默认学习率
                       weight_decay=5e-4,
                       eps=1e-7  # 默认精度
                       )

    
    # 训练的总步数
    total_steps = len(train_dataloader) * epochs
    # 学习率调度器，说白了就是自适应学习速率
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,  # Default value
                                                num_training_steps=total_steps)
    return cross_model, optimizer, scheduler

In [23]:

class step1_Loss(nn.Module):
    def __init__(self, ):
        super(step1_Loss, self).__init__()
    
    def forward(self, text_cls, img_cls):
    # 前者是文本内容的最后隐层信息，后者是图片内容的最后隐层信息

        cos_sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
        positive_score_list = cos_sim(text_cls, img_cls)

        ######################################################
        positive_score = positive_score_list[0]
        positive_score = torch.exp(positive_score).to(device)

        negative_img_score_list = cos_sim(text_cls[0], img_cls) # (negtive+positive)(img no match)
        negative_text_score_list = cos_sim(img_cls[0], text_cls) # (negtive+positive)(text no match)
        negative_score = torch.exp(negative_img_score_list).sum()+torch.exp(negative_text_score_list).sum()

        step1_NCE_loss =  - (torch.log(positive_score/negative_score))
        
        for i in range(1, len(positive_score_list)):
            # 对于每个正样本，都有2*bz-2个负样本
            positive_score = positive_score_list[i]
            positive_score = torch.exp(positive_score).to(device)

            negative_img_score_list = cos_sim(text_cls[i], img_cls) # (negtive+positive)(img no match)
            negative_text_score_list = cos_sim(img_cls[i], text_cls) # (negtive+positive)(text no match)
            negative_score = torch.exp(negative_img_score_list).sum()+torch.exp(negative_text_score_list).sum()

            step1_NCE_loss = step1_NCE_loss - (torch.log(positive_score/negative_score))
            
         #-----------------------------------------Step 1-----------------------------------------------
        return step1_NCE_loss


In [24]:
class step2and3_Loss(nn.Module):
    def __init__(self, ):
        super(step2and3_Loss, self).__init__()
        self.step1_loss = step1_Loss()
    def forward(self, text_last_hidden_state, img_last_hidden_state,b_text_masks,b_img_masks,b_text_attribute_label, b_text_word_together, b_editor_distant, step2=False,step3=False):
        
        # 前者是文本内容的最后隐层信息，后者是图片内容的最后隐层信息
            # 但是他们的维度都是被padding过的，所以我们要根据b_text_masks截取有用的部分作平均池化     
         #-----------------------------------------Step 2-----------------------------------------------
        if step2:
            # step1是全局平均池化，求的是全局的信息，step2则是通过全局粗粒度对齐后，增加的细粒度信息对齐，分别为obj和ocr
            # obj通过全局对齐后的逻辑直接找max，然后通过fasttext的类别特征rerank来监督
            # obj的监督还是可以用NCE loss
            # ocr就是通过CDist识别出来的字符，直接和文本中认为是OCR的单词来作编辑距离监督

            #### 首先是obj的 ####
            # global训练完了，先直接找最大匹配的token，但是最开始我们还要把text中分词的token合并
            # text_last_hidden_state, b_text_word_together, b_text_attribute_label, b_text_masks
            # attribute_label中0是不管的，1是obj，2是ocr，3是relation
            # 从上面四个输入得到我们要的obj token
            # 先用mask来把last_hiddeen_state给截断，得到有意义的部分

            # loss obj和ocr
            step2_obj_loss = torch.tensor(0,dtype=float).to(device)
            step2_ocr_loss = torch.tensor(0,dtype=float).to(device)
            step3_relation_loss = torch.tensor(0,dtype=float).to(device)
            step2_obj_loss.requires_grad_(True)
            step2_ocr_loss.requires_grad_(True)
            step3_relation_loss.requires_grad_(True)

            cos_sim = nn.CosineSimilarity(dim=-1, eps=1e-6)

            for i in range(len(text_last_hidden_state)):
                # relation_order来存最大可能的物体，包含obj和ocr
                relation_order = []
                this_text_last_hidden_state = text_last_hidden_state[i][:(b_text_masks[i].sum()-2)]
                # 现在this就是去掩码的有用信息
                length = 0
                for j in b_text_word_together[i]:
                    if j!=-1:
                        length+=1
                    else:
                        break
                # length就是实际句子单词的长度
                # 然后我们进行合并
                this_text_word_together = b_text_word_together[i][:length]
                this_text_attribute_label = b_text_attribute_label[i][:length]
                counts = 0
                for k in range(length):
                    # 这里的this_text_word_together[k]就是后面几个token为一个实际单词
                    word_len = this_text_word_together[k]
                    # ocr_times来记录这是第几个ocr，来返回对应所需要的编辑距离label值
                    ocr_times = 0

                    # 这是一个单词的最后隐层信息
                    if word_len>1:
                        word_last_hidden_state = this_text_last_hidden_state[counts:counts+word_len].mean(dim=0)
                    else:
                        word_last_hidden_state = this_text_last_hidden_state[counts:counts+word_len]

                    counts += word_len

                    if this_text_attribute_label[k]==1:

                        # temp_obj.append(word_last_hidden_state)
                        # 这是obj的隐层信息，直接作余弦相似度找最大
                        obj_cos_sim = cos_sim(word_last_hidden_state, img_last_hidden_state[i][:36])
                        max_obj_i_sim = obj_cos_sim[obj_cos_sim.argmax()]

                        # obj_cos_sim.argmax()为该text中obj匹配到最佳img obj的序号
                        relation_order.append(int(obj_cos_sim.argmax()))

                        #############################   step2_obj_loss，让obj匹配最大的最大     ###############################
                        step2_obj_loss = step2_obj_loss - torch.log(torch.exp(max_obj_i_sim-torch.tensor(1).to(device)))


                    elif this_text_attribute_label[k]==2:

                        # 防止图片中没识别出ocr报错
                        if b_img_masks[i].sum()!=37:
                            # 这里我们直接得到这句话所有ocr对应的编辑距离
                            # b_editor_distant应该被补0了，所以是[bz, 30, 10]的维度
                            # 这里的k就是text中ocr的个数
                            this_editor_distant = b_editor_distant[i][ocr_times]

                            # 图片中有ocr，直接求loss
                            ocr_cos_sim = cos_sim(word_last_hidden_state, img_last_hidden_state[i][36:(b_img_masks[i].sum()-1)])
                            relation_order.append(int(ocr_cos_sim.argmax() + 36))  
                            # 求出来了对应的余弦相似性，用编辑距离作监督
                            # 因为要保证差是整数，所以平方即可
                            for l in range(len(ocr_cos_sim)):
                                #############################   step2_ocr_loss，让ocr匹配与编辑距离最符  ############################
                                step2_ocr_loss = step2_ocr_loss + torch.pow((this_editor_distant[l]-ocr_cos_sim[l]),torch.tensor(2).to(device))

                        else:
                            # 这里就是文本里有识别的ocr，但是图片里面没有，那么我们就把ocr退化为obj，使用求obj_loss的方法
                            ocr_cos_sim = cos_sim(word_last_hidden_state, img_last_hidden_state[i][:36])
                            max_ocr_i_sim = ocr_cos_sim[ocr_cos_sim.argmax()]
                            relation_order.append(int(ocr_cos_sim.argmax())) 

                            #############################   退化后的ocr loss，让ocr匹配最大的最大     ###############################
                            step2_ocr_loss = step2_ocr_loss - torch.log(torch.exp(max_ocr_i_sim-torch.tensor(1).to(device)))

                        ocr_times +=1 # ocr_times记录+1
                    #-----------------------------------------Step 2-----------------------------------------------
                    
                counts = 0
                relation_order = list(set(relation_order))        
                # 取relation对，不会重复取相同元素
                if len(relation_order)>1:
                    relation_order = list(itertools.combinations(relation_order, 2))
                else:
                    relation_order = [(relation_order[0], relation_order[0])]

                for k in range(length):
                    # 这里的this_text_word_together[k]就是后面几个token为一个实际单词
                    word_len = this_text_word_together[k]
                    # ocr_times来记录这是第几个ocr，来返回对应所需要的编辑距离label值

                    # 这是一个单词的最后隐层信息
                    if word_len>1:
                        word_last_hidden_state = this_text_last_hidden_state[counts:counts+word_len].mean(dim=0)
                    else:
                        word_last_hidden_state = this_text_last_hidden_state[counts:counts+word_len]

                    counts += word_len
                    
                    if this_text_attribute_label[k]==3:
                        #-----------------------------------------Step 3-----------------------------------------------
                        if step3:
                            # 这个时候我们就要从relation_order里面找最佳匹配的两个物体
                            # order直接从last_img_hidden_state[i]里对应
                            # 避免重复，先去重序号
               
                            # max_relation用来记录最大的relation对是谁
                            #print(relation_order[0][0])
                            #print(relation_order[0][1])
                           
                            max_relation = cos_sim(word_last_hidden_state,(img_last_hidden_state[i][relation_order[0][0]] + img_last_hidden_state[i][relation_order[0][1]]))

                            for l in relation_order:
                                # 把relation对特征直接相加，然后和relation text特征作余弦相似度
                                relation_pair = img_last_hidden_state[i][l[0]] + img_last_hidden_state[i][l[1]]
                                relation_cos_sim = cos_sim(word_last_hidden_state, relation_pair)
                              
                                if relation_cos_sim>max_relation:
                                    max_relation = relation_cos_sim
                            # 找到了最大的relation对——max_relation就可以返回loss了
                            step3_relation_loss = step3_relation_loss - torch.log(torch.exp(max_relation-torch.tensor(1).to(device)))
                        #-----------------------------------------Step 3-----------------------------------------------
        if step2==True and step3==False:
            return step2_obj_loss + step2_ocr_loss
        elif step2==True and step3==True:
            return step2_obj_loss + step2_ocr_loss + step3_relation_loss

In [25]:
NCE_loss = step1_Loss()
step2and3_loss = step2and3_Loss()

In [26]:
# 训练模型
def train(model,  train_dataloader, test_dataloader=None, epochs=2, step1 = True, step2 = False, step3=False, evaluation=False):
    Train_loss = []
    Test_loss = []

    # 开始训练循环
    for epoch_i in range(epochs):
        # =======================================
        #               Training
        # =======================================
        # 表头
        print(
            f"{'Epoch':^7} | {'每10个Batch':^10} | {'训练集 Loss':^14} |{'时间':^9}")
        print("-" * 100)

        # 测量每个epoch经过的时间
        t0_epoch, t0_batch = time.time(), time.time()

        # 在每个epoch开始时重置跟踪变量
        total_loss, batch_loss, batch_counts = 0, 0, 0

        # 把model放到训练模式
        model.train()
        times_all = 0
        
        
        # 分batch训练
        for step, batch in enumerate(train_dataloader):
            batch_counts += 1
            times_all += 1
            # 把batch加载到GPU
            b_text_inputs, b_text_masks, b_text_attribute_label, b_text_word_together, b_img_obj_feat, b_img_obj_bbox, b_img_ocr_feat, b_img_ocr_bbox, b_img_masks, b_editor_distant = batch
            

            b_text_inputs = b_text_inputs.to(device)
            b_text_masks = b_text_masks.to(device)
            b_text_attribute_label = b_text_attribute_label.to(device)
            b_text_word_together = b_text_word_together.to(device)
            b_img_obj_feat = b_img_obj_feat.to(device)
            b_img_obj_bbox = b_img_obj_bbox.to(device)
            b_img_ocr_feat = b_img_ocr_feat.to(device)
            b_img_ocr_bbox = b_img_ocr_bbox.to(device)
            b_img_masks = b_img_masks.to(device)
            b_editor_distant = b_editor_distant.to(device)
            
            
            # 归零导数
            model.zero_grad()

            # 真正的训练
            text_cls, text_last_hidden_state, img_cls, img_last_hidden_state = model(b_text_inputs, b_text_masks, b_img_obj_feat, b_img_obj_bbox, b_img_ocr_feat, b_img_ocr_bbox, b_img_masks)
            
            
            if step1==True and step2==False:
                loss = NCE_loss(text_cls,img_cls)
                loss.backward()
            elif step1==True and step2==True and step3==False:
                loss = NCE_loss(text_cls,img_cls) + step2and3_loss(text_last_hidden_state,img_last_hidden_state, b_text_masks, b_img_masks, b_text_attribute_label, b_text_word_together, b_editor_distant,step2=True,step3=False)
                loss.backward()
            elif step1 == True and step2 == True and step3 == True:
                loss = NCE_loss(text_cls,img_clss) + step2and3_loss(text_last_hidden_state,img_last_hidden_state, b_text_masks, b_img_masks, b_text_attribute_label, b_text_word_together, b_editor_distant, step2=True,step3=True)
                loss.backward()
            
            batch_loss += loss.item()
            total_loss += loss.item()

            
            # 归一化，防止梯度爆炸
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # 更新参数和学习率
            optimizer.step()
            scheduler.step()

            

            # Print每10个batch的loss和time
            if (step % 10 == 0 and step != 0) or (step == len(train_dataloader) - 1):
                # 计算10个batch的时间
                time_elapsed = time.time() - t0_batch
                # Print训练结果
                print(
                    f"{epoch_i + 1:^7} | {step:^11} | {batch_loss /batch_counts:^16.6f}| {time_elapsed:^9.2f}")

                # 重置batch参数
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()
        
        # 每5个存一下模型
        save_name = 'cls_cross_model_'+ str(epoch_i+1)+'.pth'
        torch.save(model, save_name)
            
        # 计算平均loss 这个是训练集的loss
        avg_train_loss = total_loss / times_all
        print(
            f"{epoch_i + 1:^7} | {'-':^10} | {avg_train_loss:^14.6f} | {time_elapsed:^9.2f}")
        print("-" * 100)
        print("\n")



In [27]:
cross_model, optimizer, scheduler = initialize_model(epochs=10)
# print("Start training and validation:\n")
print("Start training and testing:\n")
train(cross_model, train_dataloader, val_dataloader, epochs=10, step1=True, step2=False, step3=False)  # 这个是有评估的

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Start training and testing:

 Epoch  | 每10个Batch  |    训练集 Loss    |   时间    
----------------------------------------------------------------------------------------------------
   1    |     10      |    690.975297   |   10.28  
   1    |     20      |    664.385724   |   9.11   
   1    |     30      |    657.349463   |   9.10   
   1    |     40      |    653.539282   |   9.11   
   1    |     50      |    647.840979   |   9.11   
   1    |     60      |    644.965833   |   9.11   
   1    |     70      |    642.498712   |   9.11   
   1    |     80      |    640.474963   |   9.12   
   1    |     90      |    637.166901   |   9.11   
   1    |     100     |    634.881848   |   9.11   
   1    |     110     |    633.286713   |   9.11   
   1    |     120     |    632.582648   |   9.12   
   1    |     130     |    629.492310   |   9.12   
   1    |     140     |    628.413318   |   9.11   
   1    |     150     |    627.319122   |   9.16   
   1    |     160     |    627.222223   |

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


   1    |     -      |   618.466749   |   4.20   
----------------------------------------------------------------------------------------------------


 Epoch  | 每10个Batch  |    训练集 Loss    |   时间    
----------------------------------------------------------------------------------------------------
   2    |     10      |    608.794794   |   10.05  
   2    |     20      |    607.737152   |   9.12   
   2    |     30      |    608.640491   |   9.12   
   2    |     40      |    607.426642   |   10.82  
   2    |     50      |    607.662024   |   9.13   
   2    |     60      |    607.748511   |   9.13   
   2    |     70      |    608.312158   |   9.13   
   2    |     80      |    607.790393   |   9.13   
   2    |     90      |    607.621326   |   9.13   
   2    |     100     |    607.434760   |   9.13   
   2    |     110     |    607.941754   |   9.13   
   2    |     120     |    607.491083   |   9.13   
   2    |     130     |    608.093140   |   9.13   
   2    |     140    

In [28]:
torch.save(cross_model, 'cross_model.pth')

In [29]:
str(1)

'1'

In [30]:
import torch
a= torch.tensor(([1,2,3,1,5]),dtype=float)
a.mean(dim=0)

tensor(2.4000, dtype=torch.float64)

In [31]:
torch.pow(torch.tensor(-10),torch.tensor(2))

tensor(100)

In [32]:
import itertools

In [33]:
list(itertools.combinations(torch.tensor([1,2,3]), 2))

[(tensor(1), tensor(2)), (tensor(1), tensor(3)), (tensor(2), tensor(3))]

***

In [34]:
b_text_inputs, b_text_masks, b_text_attribute_label, b_text_word_together, \
b_img_obj_feat, b_img_obj_bbox, b_img_ocr_feat, b_img_ocr_bbox, b_img_masks, b_editor_distant = train_dataloader.__iter__().next()

In [35]:
b_text_inputs.shape

torch.Size([128, 61])

In [36]:
b_text_inputs = b_text_inputs.to(device)
b_text_masks = b_text_masks.to(device)
b_text_attribute_label = b_text_attribute_label.to(device)
b_text_word_together = b_text_word_together.to(device)
b_img_obj_feat = b_img_obj_feat.to(device)
b_img_obj_bbox = b_img_obj_bbox.to(device)
b_img_ocr_feat = b_img_ocr_feat.to(device)
b_img_ocr_bbox = b_img_ocr_bbox.to(device)
b_img_masks = b_img_masks.to(device)
b_editor_distant = b_editor_distant.to(device)

In [37]:
cross_model, optimizer, scheduler = initialize_model(epochs=2)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [38]:
text_cls, text_last_hidden_state, img_cls, img_last_hidden_state \
= cross_model(b_text_inputs, b_text_masks, b_img_obj_feat, b_img_obj_bbox, b_img_ocr_feat, b_img_ocr_bbox, b_img_masks)

In [39]:
text_cls

tensor([[-0.2775, -0.1659, -0.0778,  ..., -0.1822,  0.1579,  0.4985],
        [-0.8105, -0.4942, -0.5276,  ..., -0.0192,  0.4456, -0.0930],
        [-0.3033,  0.0124, -0.3964,  ...,  0.0385,  0.0782,  0.2377],
        ...,
        [ 0.1432, -0.0670, -0.0317,  ..., -0.1520,  0.1794,  0.3405],
        [ 0.0601,  0.4484, -0.1519,  ..., -0.1184,  0.0894, -0.0289],
        [-0.0931,  0.3288, -0.5555,  ...,  0.1626,  0.4895, -0.0292]],
       device='cuda:0', grad_fn=<SliceBackward>)

In [40]:
img_cls

tensor([[ 0.6549,  0.2449,  0.0435,  ...,  0.8504,  0.1459, -1.2015],
        [ 0.6749, -0.2373, -1.2552,  ...,  0.0999,  0.1404, -0.8349],
        [ 0.8559,  0.0077, -0.7344,  ...,  0.4500,  0.2473, -0.9800],
        ...,
        [ 0.4493,  0.3524, -1.2906,  ...,  0.7226,  0.7867, -1.1003],
        [ 0.4568,  0.2141, -0.7252,  ...,  0.9375,  0.4080, -1.3505],
        [ 0.0148,  0.6836, -0.1288,  ...,  1.3798,  0.3511, -1.6702]],
       device='cuda:0', grad_fn=<SliceBackward>)

In [41]:
loss = NCE_loss(text_cls,img_cls)

In [42]:
loss

tensor(709.8785, device='cuda:0', grad_fn=<SubBackward0>)

In [43]:
img_cls.grad



In [44]:
cross_model.img_cls_Linear.bias.is_leaf

True

In [45]:
cross_model.img_cls_Linear.bias.grad

In [46]:
loss.backward()

In [47]:
cross_model.img_cls_Linear.bias.grad

tensor([-3.7289e-11, -1.7608e-09, -1.0696e-09, -7.2760e-12, -4.5225e-10,
        -2.9331e-10,  1.4370e-10, -1.4588e-09,  8.7311e-11,  9.1222e-10,
        -7.6398e-10,  1.7872e-10, -5.5138e-12,  9.0608e-11, -4.2928e-10,
        -8.5493e-11, -1.5643e-10,  7.7080e-10,  1.1596e-10, -9.3678e-10,
        -8.9494e-10, -2.5011e-10, -4.2928e-10, -4.6293e-10,  2.0736e-10,
         4.6857e-09, -2.5011e-11,  4.1837e-10,  6.9488e-10, -4.5839e-10,
        -1.4552e-09, -2.1646e-10,  3.7471e-10,  2.3574e-09,  7.6852e-11,
         8.1855e-11, -1.5666e-10,  1.5716e-09, -2.0555e-10, -2.0373e-10,
         3.7835e-10,  2.5102e-10, -6.8076e-10, -7.2396e-10, -4.3224e-10,
         5.5479e-11,  1.7644e-09,  1.2415e-10,  2.7012e-10,  1.6712e-10,
        -9.7680e-10,  2.2628e-09,  2.5193e-10, -2.4966e-10, -6.5847e-10,
         5.7298e-10,  3.5561e-10,  1.0043e-09, -9.8225e-11, -3.7102e-10,
        -5.3853e-10, -2.7940e-09,  1.5916e-11, -1.4290e-10, -3.5216e-09,
        -9.9317e-10, -1.0459e-10,  9.3223e-10, -2.2