In [None]:
class NonLocalBlock(nn.Module):
    """1-D non-local (self-attention) block with a residual connection.

    The input is projected by three 1x1 convolutions (``theta``, ``phi`` and
    ``g``).  ``theta`` and ``phi`` produce a softmax-normalised affinity between
    positions, which is used to aggregate ``g``.  The aggregate is mapped by a
    final 1x1 convolution ``W`` and added back to the input.

    The last layer of ``W`` is zero-initialised, so the ``W`` branch outputs
    zeros for a freshly constructed block.

    Args:
        in_channels: Number of channels of the input ``x``.
        inter_channels: Number of channels of the ``theta``/``phi``/``g``
            projections.
        out_channels: Number of channels produced by ``W``.  Must be
            broadcast-compatible with ``in_channels`` for the residual
            addition (normally equal to it).
        dimension: Stored on the instance but not otherwise used; the block
            always builds 1-D layers.
        sub_sample: If true, ``phi`` and ``g`` are followed by
            ``MaxPool1d(2)``, halving the number of positions attended over.
        bn_layer: If true, ``W`` is a 1x1 convolution followed by
            ``BatchNorm1d``; otherwise it is the bare convolution.
    """

    def __init__(self,
                 in_channels,
                 inter_channels,
                 out_channels,
                 dimension=1,
                 sub_sample=True,
                 bn_layer=True):
        super(NonLocalBlock, self).__init__()

        # Record the configuration.
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        self.out_channels = out_channels

        # Three 1x1 projections: in_channels -> inter_channels.
        self.g = nn.Conv1d(in_channels=self.in_channels,
                           out_channels=self.inter_channels,
                           kernel_size=1,
                           stride=1,
                           padding=0)
        self.theta = nn.Conv1d(in_channels=self.in_channels,
                               out_channels=self.inter_channels,
                               kernel_size=1,
                               stride=1,
                               padding=0)
        self.phi = nn.Conv1d(in_channels=self.in_channels,
                             out_channels=self.inter_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0)

        if sub_sample:
            # Pool phi and g identically so their position counts stay equal;
            # theta is left at full length so the output keeps the input length.
            self.g = nn.Sequential(self.g, nn.MaxPool1d(kernel_size=2))
            self.phi = nn.Sequential(self.phi, nn.MaxPool1d(kernel_size=2))

        # Final 1x1 projection: inter_channels -> out_channels.
        if bn_layer:
            self.W = nn.Sequential(
                nn.Conv1d(in_channels=self.inter_channels,
                          out_channels=self.out_channels,
                          kernel_size=1,
                          stride=1,
                          padding=0),
                # The batch norm normalises the conv output, so it must be
                # sized by out_channels (not in_channels).
                nn.BatchNorm1d(self.out_channels))
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = nn.Conv1d(in_channels=self.inter_channels,
                               out_channels=self.out_channels,
                               kernel_size=1,
                               stride=1,
                               padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

    def forward(self, x):
        """Apply non-local attention and add the residual.

        Args:
            x: Tensor of shape ``(B, in_channels, N)``.

        Returns:
            ``W(y) + x`` where ``y`` is the attention-aggregated feature map of
            shape ``(B, inter_channels, N)``.
        """
        batch_size = x.size(0)

        # M below is N, or N // 2 when sub_sample pooling is enabled.
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)          # [B, c, M]
        g_x = g_x.permute(0, 2, 1)                                         # [B, M, c]

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)  # [B, c, N]
        theta_x = theta_x.permute(0, 2, 1)                                 # [B, N, c]

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)      # [B, c, M]

        # Pairwise affinities, normalised over the attended positions.
        f = torch.matmul(theta_x, phi_x)                                   # [B, N, M]
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)                                     # [B, N, c]
        y = y.permute(0, 2, 1).contiguous()                                # [B, c, N]
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z

In [None]:
class Road_status(nn.Module):
    """Road-status classifier combining link embedding, attributes and traffic series.

    The input row is split by column position into a link id, 14 road
    attributes, a historical traffic series and the 20 most recent traffic
    readings.  Each part is encoded separately, the encodings are concatenated
    and passed through a fully connected head that outputs a softmax over
    three classes.

    Args:
        embedding_pretrained_dict: 2-D tensor of pretrained link embeddings
            passed to ``nn.Embedding.from_pretrained`` (fine-tuned, not
            frozen).  Defaults to the module-level ``embedding_dict``.
    """

    def __init__(self, embedding_pretrained_dict=embedding_dict):
        super(Road_status, self).__init__()

        # Data source 1: historical and real-time road conditions.
        # Recent road-condition features.
        # Dilated branch: length 20 --> 4.
        self.curr_t = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=0, dilation=4),  # 20 --> 4
            nn.Conv1d(in_channels=16, out_channels=8, kernel_size=3, stride=1, padding=1, dilation=1))  # 4 --> 8*4
        # Strided branch: length 20 --> 5.
        self.curr_c = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=16, kernel_size=4, stride=4, padding=0, dilation=1),  # 20 --> 5
            nn.Conv1d(in_channels=16, out_channels=8, kernel_size=3, stride=1, padding=1, dilation=1))  # 5 --> 8*5
        # Non-local block over the concatenated recent features.
        self.cur_nonlocalblock = NonLocalBlock(in_channels=8, inter_channels=16, out_channels=8, sub_sample=False, bn_layer=False)

        # Historical road-condition features (length comments assume 100 input steps).
        self.hist = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=8, dilation=4), nn.ReLU(),       # 100
            nn.Conv1d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=8, dilation=4), nn.ReLU(),      # 100
            nn.Conv1d(in_channels=64, out_channels=32, kernel_size=4, stride=4, padding=0, dilation=1), nn.ReLU(),    # 25
            nn.Conv1d(in_channels=32, out_channels=8, kernel_size=5, stride=5, padding=0, dilation=1), nn.ReLU())     # 8*5

        # Data source 2: road attributes.
        # link_id embedding, pretrained by a GCN.  Built from the constructor
        # argument so callers can supply their own embedding table.
        self.embed = nn.Embedding.from_pretrained(embedding_pretrained_dict, freeze=False)
        # Other road features.
        self.attr = nn.Sequential(
            nn.Linear(in_features=14, out_features=32), nn.ReLU(),
            nn.Linear(in_features=32, out_features=32), nn.ReLU())

        # Fully connected head.
        # NOTE(review): 156 must equal embedding_dim + 32 (attr) + 72 (recent)
        # + flattened hist size; confirm against the embedding table and input width.
        self.fc = nn.Sequential(
                nn.Linear(156, 64),
                nn.Linear(64, 16), nn.ReLU(),
                nn.Linear(16, 3))

        self.flatten = nn.Flatten()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return class probabilities for a batch of feature rows.

        Args:
            x: 2-D tensor of shape ``(B, D)``.  Column 0 is the link id,
                columns 1..14 are road attributes, columns 15..D-21 are the
                historical series and the last 20 columns are the recent
                series.  The original note gave ``(B, 112)``, but the layer
                comments imply 100 history steps (``D = 135``) - TODO confirm.

        Returns:
            Tensor of shape ``(B, 3)`` with softmax probabilities.
        """
        link = self.embed(x[:,0].long())
        attr = self.attr(x[:,1:15])
        curr_t = self.curr_t(x[:,-20:].unsqueeze(1)) # 8*4
        curr_c = self.curr_c(x[:,-20:].unsqueeze(1)) # 8*5
        cur = torch.cat([curr_t,curr_c], dim=-1)     # 8*9
        cur = self.cur_nonlocalblock(cur)
        cur = self.flatten(cur)

        hist = self.hist(x[:,15:-20].unsqueeze(1))   # 8*5 for 100 history steps
        hist = self.flatten(hist)

        feat = torch.cat([link, attr, cur, hist], dim=-1)
        y = self.fc(feat)
        # NOTE(review): output is already softmax-normalised; confirm the
        # training loss expects probabilities rather than raw logits.
        out = self.softmax(y)
        return out