# 准备数据集

In [1]:
from torch.utils.data import Dataset
import torch
import numpy as np
from torch.utils.data import DataLoader

class MUStartDataset(Dataset):
    def __init__(self,mode = 'train',feature_path = './featuresIndepResnet152.pkl'):
        with open(feature_path,'rb') as f:
            import pickle
            data = pickle.load(f)
        self.feature_dict = data[mode]
        # [-1,1] -> [0,1]
        self.feature_dict['labels'] = ((self.feature_dict['labels'] + 1)/2).astype(np.int64)

    def __getitem__(self,index):
        feature ={}
        feature['audio_feature'] = self.feature_dict['audio_feature'][index]
        feature['video_features_p'] = self.feature_dict['video_features_p'][index]
        feature['bert_indices'] = self.feature_dict['bert_indices'][index]
        feature['box_pad_indices'] = self.feature_dict['box_pad_indices'][index]
        feature['big_graphs'] = self.feature_dict['big_graphs'][index]
        feature['labels'] = self.feature_dict['labels'][index]
        
        return feature
    def __len__(self):
        labels = self.feature_dict['labels']
        length = labels.shape[0]
        return length
    
    def get_sample_shape(self,index):
        shape_dict = {}
        shape_dict['audio_feature'] = self.feature_dict['audio_feature'][index].shape
        shape_dict['video_features_p'] = self.feature_dict['video_features_p'][index].shape
        shape_dict['bert_indices'] = self.feature_dict['bert_indices'][index].shape
        shape_dict['box_pad_indices'] = self.feature_dict['box_pad_indices'][index].shape
        shape_dict['big_graphs'] = self.feature_dict['big_graphs'][index].shape
        # shape_dict['labels'] = self.feature_dict['labels'][index].shape
        shape_dict['labels'] = type(self.feature_dict['labels'][index])
        return shape_dict
        
d = MUStartDataset('valid')
dl = DataLoader(d, batch_size=2, num_workers=0, shuffle=False)
batch_sample = iter(dl).next()
print(batch_sample.keys())
batch_sample['audio_feature'].size(2)
batch_sample['video_features_p'].size(2)

dict_keys(['audio_feature', 'video_features_p', 'bert_indices', 'box_pad_indices', 'big_graphs', 'labels'])


768

# 准备模型

In [2]:
import torch
import torch.nn as nn
from transformers import BertModel
from layers.dynamic_rnn import DynamicLSTM
import torch.nn.functional as F
import torch.nn as nn
import torch
import torch.nn.functional as F

class MAG(nn.Module):
    def __init__(self, hidden_size, beta_shift, dropout_prob):        
        super(MAG, self).__init__()
        print("Initializing MAG with beta_shift:{} hidden_prob:{}".format(beta_shift, dropout_prob))

        self.W_hv = nn.Linear(VISUAL_DIM + TEXT_DIM, TEXT_DIM)
        self.W_ha = nn.Linear(ACOUSTIC_DIM + TEXT_DIM, TEXT_DIM)
        self.W_v = nn.Linear(VISUAL_DIM, TEXT_DIM)
        self.W_a = nn.Linear(ACOUSTIC_DIM, TEXT_DIM)
        self.beta_shift = beta_shift

        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
    
    def forward(self, text_embedding, visual, acoustic):
        eps = 1e-6
        weight_v = F.relu(self.W_hv(torch.cat((visual, text_embedding), dim=-1)))
        weight_a = F.relu(self.W_ha(torch.cat((acoustic, text_embedding), dim=-1)))
        h_m = weight_v * self.W_v(visual) + weight_a * self.W_a(acoustic)
        em_norm = text_embedding.norm(2, dim=-1)
        hm_norm = h_m.norm(2, dim=-1)
        DEVICE = visual.device
        hm_norm_ones = torch.ones(hm_norm.shape, requires_grad=True).to(DEVICE)
        hm_norm = torch.where(hm_norm == 0, hm_norm_ones, hm_norm)
        thresh_hold = (em_norm / (hm_norm + eps)) * self.beta_shift
        ones = torch.ones(thresh_hold.shape, requires_grad=True).to(DEVICE)
        alpha = torch.min(thresh_hold, ones)
        alpha = alpha.unsqueeze(dim=-1)
        acoustic_vis_embedding = alpha * h_m
        embedding_output = self.dropout(
            self.LayerNorm(acoustic_vis_embedding + text_embedding)
        )

        return embedding_output


from transformers.models.bert.modeling_bert import BertPreTrainedModel
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertPooler
from transformers import BertTokenizer

class MultimodalConfig(object):
    def __init__(self, beta_shift, dropout_prob):
        self.beta_shift = beta_shift
        self.dropout_prob = dropout_prob
        
class MAG_BertModel(BertPreTrainedModel):
    def __init__(self, config, multimodal_config):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.MAG = MAG(
            config.hidden_size,
            multimodal_config.beta_shift,
            multimodal_config.dropout_prob,
        )

        self.init_weights()
        
    def forward(
    self,
    input_ids,
    visual,
    acoustic,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    output_attentions=None,
    output_hidden_states=None,
    singleTask = False,
    ):
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        fused_embedding = self.MAG(embedding_output, visual, acoustic)
        
        encoder_outputs = self.encoder(
            fused_embedding,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)
        outputs = (sequence_output, pooled_output,) + encoder_outputs[
            1:
        ]  # add hidden_states and attentions if they are here
        # sequence_output, pooled_output, (hidden_states), (attentions)
        return outputs
        

class GraphConvolution(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        if bias :
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias',None)
        
    def forward(self, text, adj):
        hidden = torch.matmul(text,self.weight)
        
        denom = torch.sum(adj,dim=2,keepdim=True) + 1
        output = torch.matmul(adj, hidden.float())/denom
        if self.bias is not None:
            output = output + self.bias

        return output

import torch.nn as nn
class AlignSubNet(nn.Module):
    def __init__(self, dst_len):
        """
        mode: the way of aligning avg_pool 这个模型并没有参数
        """
        super(AlignSubNet, self).__init__()
        self.dst_len = dst_len

    def get_seq_len(self):
        return self.dst_len
    
    def __avg_pool(self, text_x, audio_x, video_x):
        def align(x):
            raw_seq_len = x.size(1)
            if raw_seq_len == self.dst_len:
                return x
            if raw_seq_len // self.dst_len == raw_seq_len / self.dst_len:
                pad_len = 0
                pool_size = raw_seq_len // self.dst_len
            else:
                pad_len = self.dst_len - raw_seq_len % self.dst_len
                pool_size = raw_seq_len // self.dst_len + 1
            pad_x = x[:, -1, :].unsqueeze(1).expand([x.size(0), pad_len, x.size(-1)])
            x = torch.cat([x, pad_x], dim=1).view(x.size(0), pool_size, self.dst_len, -1)
            x = x.mean(dim=1)
            return x
        text_x = align(text_x)
        audio_x = align(audio_x)
        video_x = align(video_x)
        return text_x, audio_x, video_x
    
 
    def forward(self, text_x, audio_x, video_x):
        if text_x.size(1) == audio_x.size(1) == video_x.size(1):
            return text_x, audio_x, video_x
        return self.__avg_pool(text_x, audio_x, video_x)
    
class CMGCN(nn.Module):
    def __init__(self, multimodal_config):
        super(CMGCN, self).__init__()
        print('create CMGCN model')
        # self.bert = BertModel.from_pretrained('./bert-base-uncased/')
        self.mag_bert = MAG_BertModel.from_pretrained('./bert-base-uncased/',multimodal_config=multimodal_config)
        self.text_lstm = DynamicLSTM(768,4,num_layers=1,batch_first=True,bidirectional=True)
        self.vit_fc = nn.Linear(768,2*4)
        self.gc1 = GraphConvolution(2*4, 2*4)
        self.gc2 = GraphConvolution(2*4, 2*4)
        self.fc = nn.Linear(2*4,2)
        
        
    def forward(self, inputs):
        bert_indices = inputs['bert_indices']
        graph = inputs['big_graphs']
        box_vit = inputs['video_features_p']
        bert_text_len = torch.sum(bert_indices != 0, dim = -1)
        # 2,24, audio_feature key 2 33 33 , 2,10 768 
        visual = box_vit
        acoustic = inputs['audio_feature']
        self.align_subnet = AlignSubNet(bert_indices.size(1))
        bert_indices, acoustic, visual= self.align_subnet(bert_indices,acoustic,visual)
        
        acoustic = acoustic.float()
        
        outputs = self.mag_bert(bert_indices, visual, acoustic)
        
        encoder_layer = outputs[0]
        pooled_output = outputs[1]
        
        text_out, (_, _) = self.text_lstm(encoder_layer, bert_text_len)
        # 与原始代码不同，这里因为进行了全局的特征填充，导致text_out可能无法达到填充长度，补充为0
        if text_out.shape[1] < encoder_layer.shape[1]:
            pad = torch.zeros((text_out.shape[0],encoder_layer.shape[1]-text_out.shape[1],text_out.shape[2]))
            text_out = torch.cat((text_out,pad),dim=1)

        box_vit = box_vit.float()
        box_vit = self.vit_fc(box_vit)
        features = torch.cat([text_out, box_vit], dim=1)

        graph = graph.float()
        x = F.relu(self.gc1(features, graph))
        x = F.relu(self.gc2(x,graph))
        
        alpha_mat = torch.matmul(features,x.transpose(1,2))
        alpha_mat = alpha_mat.sum(1, keepdim=True)
        alpha = F.softmax(alpha_mat, dim = 2)
        x = torch.matmul(alpha, x).squeeze(1)
        
        output = self.fc(x)
        return output

def init_params():
    for child in cmgcn_model.children():
        # print(type(child) != BertModel)
        if type(child) != MAG_BertModel:
            for p in child.parameters():
                # print(type(child))
                # print(p.shape, p.requires_grad)
                if p.requires_grad :
                    # print(len(p.shape))
                    if len(p.shape) > 1:
                        torch.nn.init.xavier_uniform_(p)
                        # print(p[0][:2])
                    else:
                        import math
                        stdv = 1.0 / math.sqrt(p.shape[0])
                        torch.nn.init.uniform_(p, a=-stdv, b=stdv)
                        # print('else', p[:2])
    print('init_params()')
    
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0') if use_cuda else torch.device('cpu')

beta_shift = 1.0 
dropout_prob = 0.5 
multimodal_config = MultimodalConfig(
    beta_shift=beta_shift, dropout_prob=dropout_prob
)


d = MUStartDataset('valid')
dl = DataLoader(d, batch_size=2, num_workers=0, shuffle=False)
batch_sample = iter(dl).next()

ACOUSTIC_DIM = batch_sample['audio_feature'].size(2)
VISUAL_DIM = batch_sample['video_features_p'].size(2)
TEXT_DIM = 768

cmgcn_model = CMGCN(multimodal_config = multimodal_config ).to(device)

init_params()    

optimizer = torch.optim.Adam([
    {'params':cmgcn_model.mag_bert.parameters(),'lr':2e-5},
    {'params':cmgcn_model.text_lstm.parameters(),},
    {'params':cmgcn_model.vit_fc.parameters(),},
    {'params':cmgcn_model.gc1.parameters(),},
    {'params':cmgcn_model.gc2.parameters(),},
    {'params':cmgcn_model.fc.parameters(),},
],lr=0.001,weight_decay=1e-5)
# optimizer = torch.optim.Adam(cmgcn_model.parameters(),lr=1e-3,weight_decay=1e-5)

d = MUStartDataset('valid')
dl = DataLoader(d, batch_size=2, num_workers=0, shuffle=False)
batch = iter(dl).next()
batch.keys()
inputs ={}
for key in batch.keys():
    inputs[key] = batch[key].to(device)
outputs = cmgcn_model(inputs)
print(outputs.shape)

create CMGCN model
Initializing MAG with beta_shift:1.0 hidden_prob:0.5


Some weights of the model checkpoint at ./bert-base-uncased/ were not used when initializing MAG_BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing MAG_BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MAG_BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of MAG_BertModel were not initialized from the model checkpoint at ./bert-base-uncased/ and are newly initialized: ['bert.MAG.W

init_params()
torch.Size([2, 2])


# 开始训练

In [3]:
num_epoch = 1
cmgcn_model_path = 'cmgcn_model.pth'

# def train():
print('start train:' + '-'*10)
train_dataset = MUStartDataset(mode='train')
valid_dataset = MUStartDataset(mode='valid')
test_dataset = MUStartDataset(mode='test')
train_dataloader = DataLoader(train_dataset,batch_size=2,num_workers=0,shuffle=False)
valid_dataloader = DataLoader(valid_dataset,batch_size=2,num_workers=0,shuffle=False)
test_dataloader = DataLoader(test_dataset,batch_size=2,num_workers=0,shuffle=False)

def evaluate_acc_f1(data_loader):
    n_correct, n_total = 0, 0
    targets_all, outputs_all = None, None
    cmgcn_model.eval()
    with torch.no_grad():
        for i_batch,batch in enumerate(data_loader):
            inputs ={}
            for key in batch.keys():
                inputs[key] = batch[key].to(device)
            outputs = cmgcn_model(inputs)
            targets = batch['labels'].to(device)
            
            n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
            n_total += len(outputs)
            
            if targets_all is None:
                targets_all = targets
                outputs_all = outputs
            else:
                targets_all = torch.cat((targets_all,targets), dim=0)
                outputs_all = torch.cat((outputs_all,outputs), dim=0)
    
    # if macro :
    from sklearn import metrics
    acc = n_correct / n_total
    f1 = metrics.f1_score(targets_all.cpu(), torch.argmax(outputs_all,-1).cpu(), labels=[0,1], average='macro', zero_division=0)
    precision = metrics.precision_score(targets_all.cpu(), torch.argmax(outputs_all,-1).cpu(), labels=[0,1], average='macro', zero_division=0)
    recall = metrics.recall_score(targets_all.cpu(), torch.argmax(outputs_all,-1).cpu(), labels=[0,1], average='macro', zero_division=0)

    return acc,f1,precision,recall

max_val_acc , max_val_f1, max_val_epoch, global_step = 0, 0, 0, 0
for i_epoch in range(num_epoch):
    print('i_epoch:', i_epoch)
    n_correct, n_total, loss_total = 0, 0, 0
    for i_batch,batch in enumerate(train_dataloader):
        global_step += 1
        cmgcn_model.train()
        optimizer.zero_grad()
        inputs ={}
        for key in batch.keys():
            inputs[key] = batch[key].to(device)
        # print(inputs.keys()) 
        # dict_keys(['audio_feature', 'video_features_p', 'bert_indices', 'box_pad_indices', 'big_graphs', 'labels'])
        outputs = cmgcn_model(inputs)
        targets = batch['labels'].to(device)

        criterion = nn.CrossEntropyLoss()
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
        n_total += len(outputs)
        loss_total += loss.item() * len(outputs)
        
        train_acc = n_correct / n_total
        train_loss = loss_total / n_total
        
        if global_step % 1 == 0:
            val_acc, val_f1, val_precision, val_recall = evaluate_acc_f1(valid_dataloader)
            if val_acc >= max_val_acc:
                max_val_f1 = val_f1
                max_val_acc = val_acc
                max_val_epoch = i_epoch
                torch.save(cmgcn_model.state_dict(),cmgcn_model_path)
                print('here save the model cmgcn_model.pth')
        
    if i_epoch - max_val_epoch >= 0:
        print('early stop')
        break
        
    break
cmgcn_model.load_state_dict(torch.load(cmgcn_model_path))
test_acc, test_f1,test_precision,test_recall = evaluate_acc_f1(test_dataloader)
# test_acc, test_f1,test_precision,test_recall = evaluate_acc_f1(test_data_loader)
print('test_acc:', test_acc)
print('test_f1:', test_f1)
print('test_precision', test_precision)
print('test_recall', test_recall)

# return 0
# train()

start train:----------
i_epoch: 0
here save the model cmgcn_model.pth
early stop
test_acc: 0.5
test_f1: 0.3333333333333333
test_precision 0.25
test_recall 0.5
