In [1]:
import torch
import numpy as np
from data_processing import *
import scipy.io as spio
from torch.nn import functional as F

In [2]:
def cross_entropy(y_true, y_pred):
    """Binary cross-entropy between targets and predicted probabilities.

    For each element the term
    ``y_true * log(p) + (1 - y_true) * log(1 - p)`` is computed and the
    terms are summed over the last dimension; the negated sum is returned.

    Args:
        y_true: tensor of targets (e.g. a one-hot vector).
        y_pred: tensor of predicted probabilities, same shape as ``y_true``.
            It is NOT modified by this function.

    Returns:
        A 0-dim tensor for 1-D inputs; for inputs with leading (batch)
        dimensions, one loss value per leading index.
    """
    # Replacement rule: any value that is not < 1 becomes 0.99999 and any
    # value that is not > 0 becomes 0.00001, so both logarithms stay finite.
    # torch.where builds a new tensor, so the caller's y_pred is left intact
    # and the operation remains differentiable.
    upper = torch.full_like(y_pred, 0.99999)
    lower = torch.full_like(y_pred, 0.00001)
    p = torch.where(y_pred < 1, y_pred, upper)
    p = torch.where(p > 0, p, lower)

    log_likelihood = y_true * torch.log(p) + (1 - y_true) * torch.log(1 - p)
    return -log_likelihood.sum(dim=-1)

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# --- Data dimensions -------------------------------------------------------
batch_size = 16   # trajectories per batch
len_traj = 4      # time steps per trajectory
d_obs = 2         # features per observation

# --- Model dimensions ------------------------------------------------------
d_embed = 128     # embedding dimension
n_heads = 8       # attention heads per layer
d_k = 16          # per-head query/key/value width
d_hidden = 16     # hidden width of the position-wise feed-forward block
d_class = 4       # number of output classes
n_layers = 4      # number of layers stacked inside the Encoder

# Random sample batch laid out as [batch_size, len_traj, d_obs].
trajectory = torch.rand((batch_size, len_traj, d_obs))

class Embedding(nn.Module):
    """Project a trajectory sequence into the latent (embedding) space.

    A single linear layer is applied to the last dimension of the input,
    mapping ``inpt_dim`` features to ``embed_dim`` features.
    """

    def __init__(self, inpt_dim, embed_dim):
        super().__init__()
        self.fc = nn.Linear(in_features=inpt_dim, out_features=embed_dim)

    def forward(self, x):
        """Return the linear projection of ``x``."""
        return self.fc(x)

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``."""

    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k

    def forward(self, Q, K, V):
        """Attend over ``V`` using the similarity between ``Q`` and ``K``.

        The last two dimensions of ``K`` are swapped before the product, so
        the scores have shape ``[..., len_q, len_k]``.

        Returns:
            (context, attn): the attention-weighted values and the attention
            weights, which are normalised over the last dimension.
        """
        scale = np.sqrt(self.d_k)
        scores = (Q @ K.transpose(-2, -1)) / scale
        attn = F.softmax(scores, dim=-1)
        return attn @ V, attn

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a residual connection and LayerNorm.

    The same input ``x`` provides queries, keys and values. Each projection
    maps ``d_embed`` features to ``n_heads * d_k`` features, which are split
    into ``n_heads`` heads of width ``d_k``.
    """

    def __init__(self, d_embed, d_k, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_k
        d_proj = d_k * n_heads  # total width across all heads
        self.W_Q = nn.Linear(d_embed, d_proj)
        self.W_K = nn.Linear(d_embed, d_proj)
        self.W_V = nn.Linear(d_embed, d_proj)
        self.fc = nn.Linear(d_proj, d_embed)  # merges the heads back to d_embed
        self.layer_norm = nn.LayerNorm(d_embed)
        self.DotProduct = ScaledDotProductAttention(d_k)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape [batch_size, len, d_embed].

        Returns:
            (output, attn): ``output`` is [batch_size, len, d_embed] after the
            residual add and LayerNorm; ``attn`` is whatever attention map
            ``self.DotProduct`` returns for the per-head inputs.
        """
        n_batch = x.size(0)

        def split_heads(projected):
            # [batch, len, n_heads * d_k] -> [batch, n_heads, len, d_k]
            return projected.view(n_batch, -1, self.n_heads, self.d_k).transpose(1, 2)

        q_s, k_s, v_s = (split_heads(proj(x)) for proj in (self.W_Q, self.W_K, self.W_V))

        # context: [batch, n_heads, len, d_k]
        context, attn = self.DotProduct(q_s, k_s, v_s)

        # transpose() yields a non-contiguous tensor; contiguous() is required
        # before view() can concatenate the heads into one feature axis.
        merged = context.transpose(1, 2).contiguous().view(
            n_batch, -1, self.n_heads * self.d_k
        )

        projected = self.fc(merged)  # [batch, len, d_embed]
        return self.layer_norm(projected + x), attn

class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward block with residual connection and LayerNorm.

    Implemented with two kernel-size-1 convolutions; the same block could be
    written with Linear + ReLU instead.
    """

    def __init__(self, d_embed, d_hidden):
        super().__init__()
        self.conv1 = nn.Conv1d(d_embed, d_hidden, kernel_size=1)
        self.conv2 = nn.Conv1d(d_hidden, d_embed, kernel_size=1)
        self.layer_norm = nn.LayerNorm(d_embed)

    def forward(self, x):
        """Transform ``x`` of shape [batch_size, len, d_embed]; shape is preserved."""
        # Conv1d expects channels on dim 1, so swap the length and feature axes.
        channels_first = x.transpose(1, 2)
        hidden = F.relu(self.conv1(channels_first))
        projected = self.conv2(hidden).transpose(1, 2)
        return self.layer_norm(projected + x)

class EncoderLayer(nn.Module):
    """One encoder layer: multi-head self-attention, then a feed-forward block."""

    def __init__(self, d_embed, d_k, n_heads, d_hidden):
        super().__init__()
        self.MultiHeadAttention = MultiHeadAttention(d_embed, d_k, n_heads)
        self.PoswiseFeedForwardNet = PoswiseFeedForwardNet(d_embed, d_hidden)

    def forward(self, x):
        """Run ``x`` through both sub-blocks.

        Returns:
            (output, attn): the feed-forward output and the attention map
            produced by the attention sub-block.
        """
        # The attention block uses x as query, key and value alike.
        attended, attn = self.MultiHeadAttention(x)
        return self.PoswiseFeedForwardNet(attended), attn

class Encoder(nn.Module):
    """Transformer encoder used to classify sequential data.

    The input is embedded, passed through ``n_layers`` encoder layers, and
    projected to ``d_class`` log-probabilities at every sequence position.
    """

    def __init__(self, d_obs, d_embed, d_class, d_k, d_hidden, n_heads, n_layers):
        super().__init__()
        # Observation (state) features -> embedding features.
        self.embedding = Embedding(inpt_dim=d_obs, embed_dim=d_embed)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(EncoderLayer(d_embed, d_k, n_heads, d_hidden))
        self.fc = nn.Linear(d_embed, d_class)

    def forward(self, x):
        """Classify ``x`` of shape [batch_size, len, d_obs].

        Returns:
            (out, attentions): ``out`` holds log-softmax scores over the last
            dimension (size ``d_class``); ``attentions`` is a list with one
            attention map per encoder layer, in layer order.
        """
        hidden = self.embedding(x)
        attentions = []
        for layer in self.layers:
            hidden, layer_attn = layer(hidden)
            attentions.append(layer_attn)

        return F.log_softmax(self.fc(hidden), dim=-1), attentions

if __name__ == '__main__':
    # Smoke test: fit a classifier to random trajectories with random labels
    # for 100 Adam steps and print the loss after every step.
    from model.neural_network import LSTM_net
    # Alternative model defined in this file. Note that Encoder.forward returns
    # an (out, attentions) tuple, so it cannot be swapped in without unpacking.
    # encoder = Encoder(d_obs, d_embed, d_class, d_k, d_hidden, n_heads, n_layers)
    encoder = LSTM_net()
    trajectory = torch.rand(batch_size, len_traj, d_obs, dtype=torch.float32)
    # Random integer class labels in [0, d_class). Sizes come from the shared
    # constants rather than literals so they stay in sync with `trajectory`.
    # NOTE(review): assumes one label per time step, i.e. the model output is
    # laid out as [batch_size, d_class, len_traj] as nn.CrossEntropyLoss
    # expects — TODO confirm against LSTM_net's output layout.
    shape = torch.randint(high=d_class, size=(batch_size, len_traj))
    loss = nn.CrossEntropyLoss()
    optimiser = torch.optim.Adam(
        encoder.parameters(),
        lr=0.001,
        weight_decay=0.0001,
    )
    for _ in range(100):
        optimiser.zero_grad()  # clear gradients left over from the previous step
        pred = encoder(trajectory)
        _loss = loss(pred, shape)
        _loss.backward()
        optimiser.step()
        print (f'{_loss.detach():.2f}')



1.39
1.38
1.38
1.38
1.38
1.38
1.38
1.38
1.37
1.37
1.37
1.37
1.37
1.37
1.37
1.36
1.36
1.36
1.36
1.36
1.36
1.35
1.35
1.35
1.35
1.35
1.35
1.34
1.34
1.34
1.34
1.34
1.34
1.33
1.33
1.33
1.33
1.33
1.33
1.33
1.33
1.32
1.32
1.32
1.32
1.32
1.32
1.32
1.32
1.32
1.32
1.32
1.32
1.31
1.31
1.31
1.31
1.31
1.31
1.31
1.31
1.31
1.30
1.30
1.30
1.30
1.30
1.30
1.30
1.30
1.30
1.30
1.30
1.29
1.29
1.29
1.29
1.29
1.29
1.29
1.29
1.29
1.29
1.29
1.29
1.29
1.29
1.29
1.29
1.28
1.28
1.28
1.28
1.28
1.28
1.28
1.28
1.28
1.28
1.28


In [4]:
# Scratch check: push a random (1, 1, 3) tensor through a freshly initialised
# 3 -> 3 linear layer and print the tensor before and after.
x = torch.randn(1, 1, 3)
print (x)
_x = nn.Linear(3, 3)

# `x` is rebound to the layer output, so from here on it carries a grad_fn.
x = _x(x)
print (x)

tensor([[[-0.9983, -0.2701,  0.0599]]])
tensor([[[-0.6897, -0.0663, -0.4248]]], grad_fn=<ViewBackward0>)


In [8]:
# NOTE(review): the recorded output for this cell is torch.Size([6, 2, 3]),
# which does not match the (1, 1, 3) `x` built in the visible cell above —
# presumably `x` was reassigned in a cell that is no longer in the notebook.
# Confirm with Restart Kernel -> Run All.
x[0:6].shape

torch.Size([6, 2, 3])