In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torchvision.models import resnet18
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork

# Define the Graph Convolutional Network (GCN) layer
class GraphConvolutionLayer(nn.Module):
    """One ``GCNConv`` graph convolution followed by a ReLU activation.

    Args:
        in_channels: width of each incoming node feature vector.
        out_channels: width of each outgoing node feature vector.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Convolve node features over the graph, then clamp negatives to zero.

        Args:
            x: node feature matrix, handed unchanged to ``GCNConv``.
            edge_index: graph connectivity, handed unchanged to ``GCNConv``.

        Returns:
            The ReLU-activated output of the convolution.
        """
        return F.relu(self.conv(x, edge_index))

# Define the Pose Estimation Network (ResNet-18 + FPN backbone with a 17-channel keypoint head; not an OpenPose implementation)
class PoseEstimationNetwork(nn.Module):
    """Predict 17 keypoint maps from an image using a ResNet-18 + FPN backbone.

    All four ResNet stages feed a Feature Pyramid Network. The
    highest-resolution pyramid level (built on ``layer1``) is projected to
    17 channels by a 1x1 convolution.

    Note: ``resnet18(pretrained=True)`` loads ImageNet weights, which
    torchvision downloads on first use.
    """

    def __init__(self):
        super().__init__()
        self.resnet = resnet18(pretrained=True)
        # Return every stage, not just layer4, so the FPN has a real pyramid to
        # fuse. Output keys '0'..'3' are ordered from finest to coarsest.
        self.intermediate_layer_getter = IntermediateLayerGetter(
            self.resnet,
            {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'},
        )
        # ResNet-18 is built from BasicBlocks, so its stages emit 64/128/256/512
        # channels. The 256/512/1024/2048 list belongs to Bottleneck ResNets
        # (ResNet-50+). No extra P6/P7 levels are added, because forward() only
        # consumes the finest level.
        self.fpn = FeaturePyramidNetwork([64, 128, 256, 512], 256)
        self.conv1 = nn.Conv2d(256, 17, 1)

    def forward(self, x):
        """Return raw (un-normalised) 17-channel keypoint maps.

        Args:
            x: image batch; presumably (B, 3, H, W) as ResNet-18 expects.
                TODO confirm.

        Returns:
            Tensor of shape (B, 17, H', W'), where H' x W' is the spatial size
            of the ``layer1`` feature map.
        """
        # The getter returns an OrderedDict keyed '0'..'3'. The FPN consumes that
        # dict and returns one with the same keys, each with 256 channels.
        features = self.intermediate_layer_getter(x)
        pyramid = self.fpn(features)
        return self.conv1(pyramid['0'])

# Define the Action Recognition Network using GCN
class ActionRecognitionNetwork(nn.Module):
    """Map a graph of 17-dim node features to a vector of class scores.

    Two GCN+ReLU layers (17 -> 32 -> 64) are followed by a mean over every row
    (node) of the input and a linear classifier.

    Args:
        num_classes: length of the returned score vector.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = GraphConvolutionLayer(17, 32)
        self.conv2 = GraphConvolutionLayer(32, 64)
        self.fc1 = nn.Linear(64, num_classes)

    def forward(self, x, edge_index):
        """Score one graph.

        Args:
            x: node features, one row per node.
            edge_index: connectivity shared by both GCN layers.

        Returns:
            Un-normalised scores of shape (num_classes,). Pooling is over
            dim 0, so every row of ``x`` is averaged into a single graph
            vector.
        """
        hidden = x
        for gcn in (self.conv1, self.conv2):
            hidden = gcn(hidden, edge_index)
        pooled = hidden.mean(dim=0)
        return self.fc1(pooled)

# Define the Dual Attention Network
class DualAttentionNetwork(nn.Module):
    """Image -> keypoint maps -> self-attention + GCN -> action logits.

    Per sample:
      1. ``pose_estimation_network`` predicts a 17-channel keypoint map.
      2. The map is softmax-normalised across its 17 channels, and every
         spatial cell becomes a token / graph node with a 17-dim feature.
      3. ``attention1`` applies self-attention over those tokens.
      4. ``action_recognition_network`` runs its GCN over the attended node
         features and returns ``num_classes`` scores.
      5. ``attention2`` refines that class-score vector.

    NOTE(review): treating heatmap cells as graph nodes is the reading that
    matches the 17 input channels of both the GCN and ``attention1``. It means
    ``edge_index`` must index H'*W' nodes (one per cell of the keypoint map).
    Confirm this against the data pipeline.

    Args:
        num_classes: number of action classes, i.e. the logit width.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.pose_estimation_network = PoseEstimationNetwork()
        self.action_recognition_network = ActionRecognitionNetwork(num_classes)
        # MultiheadAttention requires embed_dim % num_heads == 0. 17 is prime and
        # num_classes is arbitrary, so one head is the only always-valid choice.
        # batch_first=True lets (batch, tokens, features) tensors go in directly.
        self.attention1 = nn.MultiheadAttention(embed_dim=17, num_heads=1, batch_first=True)
        self.attention2 = nn.MultiheadAttention(embed_dim=num_classes, num_heads=1, batch_first=True)

    def forward(self, x, edge_index):
        """Return raw action logits of shape (B, num_classes).

        The logits are not softmaxed, because ``nn.CrossEntropyLoss`` applies
        log-softmax itself.

        Args:
            x: input for ``pose_estimation_network``; presumably (B, 3, H, W)
                images. TODO confirm.
            edge_index: graph connectivity shared by every sample in the batch.
        """
        pose = self.pose_estimation_network(x)  # (B, 17, H', W')
        pose = F.softmax(pose, dim=1)
        # (B, 17, H', W') -> (B, H'*W', 17): one token per spatial cell.
        pose = pose.flatten(2).transpose(1, 2)
        pose_att, _ = self.attention1(pose, pose, pose)

        # The GCN path pools over dim 0 (all rows), so it must see one sample at
        # a time; otherwise nodes from different samples would be averaged together.
        action = torch.stack(
            [self.action_recognition_network(nodes, edge_index) for nodes in pose_att.unbind(0)]
        )  # (B, num_classes)

        # Each sample's class vector is its own length-1 sequence, so this
        # attention never mixes samples within a batch.
        action = action.unsqueeze(1)
        action_att, _ = self.attention2(action, action, action)
        return action_att.squeeze(1)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torchvision.models import resnet18
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
# Instantiate the DualAttentionNetwork model.
# NOTE(review): `num_epochs`, `train_loader` and `test_loader` are not defined in
# this cell. They are assumed to come from earlier notebook state; confirm.
model = DualAttentionNetwork(num_classes=10)
# Optimizer over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Loss function. CrossEntropyLoss applies log-softmax itself, so the model
# must return raw logits.
criterion = nn.CrossEntropyLoss()


def _accuracy(net, loader):
    """Return the top-1 accuracy (percent) of ``net`` over ``loader``.

    Puts ``net`` in eval mode so normalisation layers (the ResNet-18 backbone
    contains BatchNorm) use their running statistics rather than updating them
    from evaluation batches. Returns NaN for an empty loader instead of
    dividing by zero.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, edge_index, label in loader:
            predicted = net(data, edge_index).argmax(dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    return 100 * correct / total if total else float('nan')


# Training loop: forward pass, loss, backward pass, parameter update.
for _ in range(num_epochs):
    # Re-enable training behaviour (BatchNorm batch statistics, etc.) each epoch.
    model.train()
    for data, edge_index, label in train_loader:
        optimizer.zero_grad()
        output = model(data, edge_index)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

# Evaluate on the training set.
print('Accuracy on training set: {}%'.format(_accuracy(model, train_loader)))

# Evaluate on the test set.
print('Accuracy on test set: {}%'.format(_accuracy(model, test_loader)))
