# 探索mag_bert模型

In [1]:
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertPooler
from transformers import BertModel, BertConfig

configuration = BertConfig()
embeddings = BertEmbeddings(configuration)

from transformers import BertTokenizer, BertModel
import torch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

inputs = tokenizer(["Hello, my dog is cute", "Hello, my dog is cute"], return_tensors="pt")

print(inputs['input_ids'])

embedding_output = embeddings(inputs['input_ids'])

print(embedding_output.shape)

visual = torch.randn((2,8,2048))
print(visual.shape)

acoustic = torch.randn((2,8,33))
print(acoustic.shape)

text_embedding = embedding_output
vt = torch.cat((visual, text_embedding), dim = -1)
print(vt.shape)
import torch.nn as nn
gvt = nn.Linear(2048 + 768, 768)
weight_v = gvt(vt)

print(weight_v.shape)
layer_vt = nn.Linear(2048, 768)
h_m = weight_v * layer_vt(visual)
print(h_m.shape)


tensor([[  101,  7592,  1010,  2026,  3899,  2003, 10140,   102],
        [  101,  7592,  1010,  2026,  3899,  2003, 10140,   102]])
torch.Size([2, 8, 768])
torch.Size([2, 8, 2048])
torch.Size([2, 8, 33])
torch.Size([2, 8, 2816])
torch.Size([2, 8, 768])
torch.Size([2, 8, 768])


In [2]:
import torch.nn as nn
import torch
import torch.nn.functional as F

class MAG(nn.Module):
    def __init__(self, hidden_size, beta_shift, dropout_prob):        
        super(MAG, self).__init__()
        print("Initializing MAG with beta_shift:{} hidden_prob:{}".format(beta_shift, dropout_prob))

        self.W_hv = nn.Linear(VISUAL_DIM + TEXT_DIM, TEXT_DIM)
        self.W_ha = nn.Linear(ACOUSTIC_DIM + TEXT_DIM, TEXT_DIM)
        self.W_v = nn.Linear(VISUAL_DIM, TEXT_DIM)
        self.W_a = nn.Linear(ACOUSTIC_DIM, TEXT_DIM)
        self.beta_shift = beta_shift

        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
    
    def forward(self, text_embedding, visual, acoustic):
        eps = 1e-6
        weight_v = F.relu(self.W_hv(torch.cat((visual, text_embedding), dim=-1)))
        weight_a = F.relu(self.W_ha(torch.cat((acoustic, text_embedding), dim=-1)))
        h_m = weight_v * self.W_v(visual) + weight_a * self.W_a(acoustic)
        em_norm = text_embedding.norm(2, dim=-1)
        hm_norm = h_m.norm(2, dim=-1)
        DEVICE = visual.device
        hm_norm_ones = torch.ones(hm_norm.shape, requires_grad=True).to(DEVICE)
        hm_norm = torch.where(hm_norm == 0, hm_norm_ones, hm_norm)
        thresh_hold = (em_norm / (hm_norm + eps)) * self.beta_shift
        ones = torch.ones(thresh_hold.shape, requires_grad=True).to(DEVICE)
        alpha = torch.min(thresh_hold, ones)
        alpha = alpha.unsqueeze(dim=-1)
        acoustic_vis_embedding = alpha * h_m
        embedding_output = self.dropout(
            self.LayerNorm(acoustic_vis_embedding + text_embedding)
        )

        return embedding_output

beta_shift = 1.0
dropout_prob = 0.5 
hidden_size = 768
ACOUSTIC_DIM = 33
VISUAL_DIM = 2048
TEXT_DIM = 768
mag_model = MAG(hidden_size,beta_shift,dropout_prob)
print(mag_model)
text_embedding = torch.randn((2,8,TEXT_DIM))
visual = torch.randn((2,8,VISUAL_DIM))
acoustic = torch.randn((2,8,ACOUSTIC_DIM))
out = mag_model(text_embedding, visual, acoustic)
print(out.shape)

Initializing MAG with beta_shift:1.0 hidden_prob:0.5
MAG(
  (W_hv): Linear(in_features=2816, out_features=768, bias=True)
  (W_ha): Linear(in_features=801, out_features=768, bias=True)
  (W_v): Linear(in_features=2048, out_features=768, bias=True)
  (W_a): Linear(in_features=33, out_features=768, bias=True)
  (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
torch.Size([2, 8, 768])


# mag_bert模型的使用

In [3]:
from transformers.models.bert.modeling_bert import BertPreTrainedModel
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertPooler
from transformers import BertTokenizer

class MultimodalConfig(object):
    def __init__(self, beta_shift, dropout_prob):
        self.beta_shift = beta_shift
        self.dropout_prob = dropout_prob
        
class MAG_BertModel(BertPreTrainedModel):
    def __init__(self, config, multimodal_config):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.MAG = MAG(
            config.hidden_size,
            multimodal_config.beta_shift,
            multimodal_config.dropout_prob,
        )

        self.init_weights()
        
    def forward(
    self,
    input_ids,
    visual,
    acoustic,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    output_attentions=None,
    output_hidden_states=None,
    singleTask = False,
    ):
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        print(embedding_output.shape)
        fused_embedding = self.MAG(embedding_output, visual, acoustic)
        print(fused_embedding.shape)
        
        encoder_outputs = self.encoder(
            fused_embedding,
        )

        sequence_output = encoder_outputs[0]
        print(sequence_output.shape)
        pooled_output = self.pooler(sequence_output)
        print(pooled_output.shape)
        # 单任务提取pooled_output
        outputs = (sequence_output, pooled_output,) + encoder_outputs[
            1:
        ]  # add hidden_states and attentions if they are here
        # sequence_output, pooled_output, (hidden_states), (attentions)
        print(len(outputs))
        return outputs
        
        
beta_shift = 1.0 
dropout_prob = 0.5 
multimodal_config = MultimodalConfig(
    beta_shift=beta_shift, dropout_prob=dropout_prob
)
print(multimodal_config.beta_shift)
mag_bertmodel = MAG_BertModel.from_pretrained('./bert-base-uncased/',multimodal_config=multimodal_config)
print(mag_bertmodel.config)


ACOUSTIC_DIM = 33
VISUAL_DIM = 2048
TEXT_DIM = 768
visual = torch.randn((2,8,VISUAL_DIM))
acoustic = torch.randn((2,8,ACOUSTIC_DIM))

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

inputs = tokenizer(["Hello, my dog is cute", "Hello, my dog is cute"], return_tensors="pt")

input_ids = inputs['input_ids']

print(input_ids)
outputs = mag_bertmodel(input_ids, visual, acoustic)
print(outputs[0].shape,outputs[1].shape)

1.0
Initializing MAG with beta_shift:1.0 hidden_prob:0.5


Some weights of the model checkpoint at ./bert-base-uncased/ were not used when initializing MAG_BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing MAG_BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MAG_BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of MAG_BertModel were not initialized from the model checkpoint at ./bert-base-uncased/ and are newly initialized: ['bert.MAG.W

BertConfig {
  "_name_or_path": "./bert-base-uncased/",
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.14.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

tensor([[  101,  7592,  1010,  2026,  3899,  2003, 10140,   102],
        [  101,  7592,  1010,  2026,  3899,  2003, 10140,   102]])
torch.Size([2, 8, 768])
torch.Size([2, 8, 768])
torch.Size([2, 8, 768])
torch.Size([2, 768])
2
torch.Size([2, 8, 768]) torch.Size([2, 768])
