In [1]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv


class SimplifiedTextGCN(nn.Module):
    """Single graph-convolution layer wrapped as a module.

    Args:
        in_channels: Input feature width, forwarded to ``GCNConv``.
        out_channels: Output feature width, forwarded to ``GCNConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply the graph convolution to ``x`` over ``edge_index`` and return its output."""
        return self.conv1(x, edge_index)


class CombinedSimplifiedModel(nn.Module):
    """Classifier that fuses text, object and place features.

    The text branch takes the [CLS] hidden state of ``bert_model`` and passes it
    through a single GCN layer. The result is used as the attention query against
    the object features and the place features; the three vectors are
    concatenated and fed to a linear classifier.

    Args:
        bert_model: Text encoder. Called as ``bert_model(text)``; the first
            element of its output is indexed with ``[:, 0, :]``.
        object_features: Callable applied to the ``object_feature`` input.
        place_features: Callable applied to the ``place_feature`` input.
        lstm: Stored on the module but not used by ``forward``.
        attention_layer: Callable invoked with keyword arguments ``q``, ``k``
            and ``v``; must return a 2-tuple whose first item is the attended
            feature.
        gcn_in_channels: Input width of the text GCN layer.
        gcn_out_channels: Output width of the text GCN layer.
        fusion_dim: Input width of the final linear layer. It must equal the
            width of the concatenated ``[text, text-object, text-place]``
            vector. Defaults to 900.
        num_classes: Number of output classes. Defaults to 3 (MVSA).
    """

    def __init__(self, bert_model, object_features, place_features, lstm, attention_layer,
                 gcn_in_channels, gcn_out_channels, fusion_dim=900, num_classes=3):
        super().__init__()
        self.text_features = bert_model
        self.object_features = object_features
        self.place_features = place_features
        self.lstm = lstm
        self.attention_layer = attention_layer
        self.text_gcn = SimplifiedTextGCN(gcn_in_channels, gcn_out_channels)
        self.linear = nn.Linear(fusion_dim, num_classes)  # For MVSA, num_classes is 3.

    def forward(self, text, object_feature, place_feature, edge_index):
        """Return the classifier output for a batch.

        Args:
            text: Input passed to the text encoder.
            object_feature: Input passed to ``object_features``.
            place_feature: Input passed to ``place_features``.
            edge_index: Graph connectivity passed to the text GCN layer.

        Returns:
            The output of the final linear layer applied to the fused features.
        """
        # Take the hidden state of the [CLS] token (position 0) from the BERT output.
        cls_state = self.text_features(text)[0][:, 0, :]
        text_feature = self.text_gcn(cls_state, edge_index)
        object_feature = self.object_features(object_feature)
        place_feature = self.place_features(place_feature)

        # The same attention layer is reused for both modalities, with text as the query.
        text_object_attention, _ = self.attention_layer(q=text_feature, k=object_feature, v=object_feature)
        text_place_attention, _ = self.attention_layer(q=text_feature, k=place_feature, v=place_feature)

        multi_feature = torch.cat([text_feature, text_object_attention, text_place_attention], dim=1)
        return self.linear(multi_feature)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_vocab(vocab_path):
    """Build a token -> line-index dictionary from a vocabulary file.

    The file at ``vocab_path`` is read as UTF-8. Every line is stripped of
    surrounding whitespace and used as a key whose value is the zero-based
    line number. When the same token occurs on several lines, the last line
    number wins.

    Args:
        vocab_path: Path of the vocabulary file, one token per line.

    Returns:
        dict mapping each stripped line to its line index.
    """
    with open(vocab_path, 'r', encoding='utf-8') as handle:
        lines = handle.readlines()
    return {token.strip(): position for position, token in enumerate(lines)}

# Location of the vocabulary file, relative to the current working directory.
vocab_path = './vocab.txt'
# Token -> line-index mapping built from that file by load_vocab.
vocab = load_vocab(vocab_path)


In [3]:
from transformers import BertTokenizer, BertModel
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the pretrained 'bert-base-uncased' tokenizer and encoder; the encoder is moved to `device`.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
from torch.utils.data import Dataset, DataLoader
import json
from PIL import Image
from torchvision import transforms

class MyDataset(Dataset):
    """Dataset backed by a JSON-lines annotation file.

    Every non-blank line of the file is parsed as one JSON object. ``__getitem__``
    reads the keys ``'text'``, ``'image'``, ``'label'``, ``'places'`` and
    ``'objects'`` from that object.

    Args:
        file_path: Path of the JSON-lines annotation file (read as UTF-8).
        transform: Optional callable applied to the loaded RGB image.
        image_root: Prefix prepended (by plain string concatenation) to each
            example's ``'image'`` value to form the image path. Defaults to
            ``'../'``.

    Note:
        ``__getitem__`` uses the module-level ``tokenizer``, which must be
        defined before items are fetched.
    """

    # Class index assigned to each string label found in the annotations.
    LABEL_TO_ID = {'negative': 0, 'neutral': 1, 'positive': 2}

    def __init__(self, file_path, transform=None, image_root='../'):
        self.data = []
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at end of file).
                    continue
                self.data.append(json.loads(line))
        self.transform = transform
        self.image_root = image_root

    def __len__(self):
        """Return the number of parsed examples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return one example as a dict of tensors.

        Returns:
            dict with keys ``'text'`` (token ids padded/truncated to 64),
            ``'image'`` (the RGB image after ``transform``, if any),
            ``'label'``, ``'places'`` and ``'objects'``.

        Raises:
            ValueError: If the label is a string other than ``'negative'``,
                ``'neutral'`` or ``'positive'``.
        """
        example = self.data[index]
        text = example['text']
        # NOTE(review): only input_ids are returned; the attention mask is discarded,
        # so downstream code cannot tell padding from real tokens — confirm this is intended.
        text_tokens = tokenizer.encode_plus(text, add_special_tokens=True, padding='max_length', max_length=64, truncation=True, return_tensors='pt')
        image_path = example['image']
        label = example['label']
        places = example['places']
        objects = example['objects']

        # String labels are mapped to class indices; non-string labels pass through unchanged.
        if isinstance(label, str):
            if label not in self.LABEL_TO_ID:
                raise ValueError('Unknown label {!r} at index {}'.format(label, index))
            label = self.LABEL_TO_ID[label]

        # Load the image and apply preprocessing. convert() returns a new, fully
        # loaded image, so the underlying file can be closed right away.
        with Image.open(self.image_root + image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Convert to tensors and return. squeeze(0) removes only the batch axis
        # added by return_tensors='pt'.
        # 'places' / 'objects' are assumed to be numeric (nested) lists — TODO confirm.
        return {'text': text_tokens['input_ids'].squeeze(0), 'image': image, 'label': torch.tensor(label),
                'places': torch.tensor(places), 'objects': torch.tensor(objects)}

# Image preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalize each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Create the training dataset and its data loader (shuffled batches of 32, 4 worker processes).
train_dataset = MyDataset('./train_all_anno.json', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)

In [5]:
import itertools

def generate_fully_connected_edge_index(num_nodes):
    """Build the edge index of a fully connected directed graph without self-loops.

    Args:
        num_nodes: Number of nodes; nodes are numbered ``0 .. num_nodes - 1``.

    Returns:
        A contiguous ``torch.long`` tensor of shape ``(2, E)`` with
        ``E = num_nodes * (num_nodes - 1)``. Row 0 holds source nodes and row 1
        target nodes. All ``(i, j)`` pairs with ``i < j`` come first, followed
        by the reversed pairs in the same order. For fewer than two nodes the
        result has shape ``(2, 0)``.
    """
    forward_edges = list(itertools.combinations(range(num_nodes), 2))
    if not forward_edges:
        # torch.tensor([]) would be 1-D, and .t() cannot turn it into the (2, 0)
        # layout an edge index needs, so build the empty result explicitly.
        return torch.empty((2, 0), dtype=torch.long)
    backward_edges = [(j, i) for (i, j) in forward_edges]
    return torch.tensor(forward_edges + backward_edges, dtype=torch.long).t().contiguous()


In [6]:

# Instantiate the model and the optimizer.
# The text GCN consumes BERT's [CLS] hidden state directly, so its input width
# has to equal the encoder's hidden size.
# NOTE(review): object_features, place_features and attention_layer are None here,
# but CombinedSimplifiedModel.forward calls all three — real modules must be
# supplied before this model can run. The classifier's input width must also
# match the concatenated feature width — confirm once the sub-modules are chosen.
model = CombinedSimplifiedModel(bert_model=bert_model, object_features=None, place_features=None
                                , lstm=None, attention_layer=None,
                                gcn_in_channels=bert_model.config.hidden_size, gcn_out_channels=256).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Loss function.
criterion = nn.CrossEntropyLoss()

# Number of mini-batches whose gradients are accumulated per optimizer step.
accumulation_steps = 4

# Train the model.
NUM_EPOCHS = 10
for epoch in range(NUM_EPOCHS):
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    model.train()
    # Start every epoch with clean gradients. They are cleared again only after
    # an optimizer step, so gradients really accumulate across mini-batches.
    optimizer.zero_grad()
    for i, batch in enumerate(train_loader):
        text, object_feature, place_feature, label = (
            batch['text'].to(device),
            batch['objects'].to(device),
            batch['places'].to(device),
            batch['label'].to(device)
        )

        # The GCN runs on one text vector per sample, so the graph has exactly
        # one node per sample in the batch.
        num_nodes = text.shape[0]
        edge_index = generate_fully_connected_edge_index(num_nodes).to(device)

        output = model(text, object_feature, place_feature, edge_index)
        loss = criterion(output, label)
        # Scale so the accumulated gradient is the mean over the accumulated batches.
        (loss / accumulation_steps).backward()

        # Gradient accumulation: step every `accumulation_steps` batches, and also
        # on the final batch so a trailing partial group is not dropped.
        is_last_batch = (i + 1) == len(train_loader)
        if (i + 1) % accumulation_steps == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()

        # Track the unscaled loss and the accuracy for reporting.
        train_loss += loss.item()
        predicted = output.argmax(dim=1)
        train_total += label.size(0)
        train_correct += (predicted == label).sum().item()

    train_acc = 100. * train_correct / train_total
    train_loss /= len(train_loader)
    print('Epoch [{}/{}], Loss: {:.4f}, Train Acc: {:.2f}%'.format(epoch+1, NUM_EPOCHS, train_loss, train_acc))

# Evaluate the model on the test set.
test_dataset = MyDataset('./test_all_anno.json', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

model.eval()
test_loss = 0.0
test_correct = 0
test_total = 0
# No gradients are needed for evaluation.
with torch.no_grad():
    for batch in test_loader:
        text, object_feature, place_feature, label = (
            batch['text'].to(device),
            batch['objects'].to(device),
            batch['places'].to(device),
            batch['label'].to(device)
        )

        # The GCN runs on one text vector per sample, so the graph has exactly
        # one node per sample in the batch.
        num_nodes = text.shape[0]
        edge_index = generate_fully_connected_edge_index(num_nodes).to(device)

        output = model(text, object_feature, place_feature, edge_index)
        loss = criterion(output, label)
        test_loss += loss.item()
        predicted = output.argmax(dim=1)
        test_total += label.size(0)
        test_correct += (predicted == label).sum().item()

# Accuracy is a percentage over all samples; loss is the mean of the per-batch losses.
test_acc = 100. * test_correct / test_total
test_loss /= len(test_loader)
print('Test Loss: {:.4f}, Test Acc: {:.2f}%'.format(test_loss, test_acc))