In [72]:
import torch
import torch.nn as nn
import torch.nn.init as init

In [73]:
class LandmarkEmbedding(nn.Module):
    def __init__(self,ins,outs,name,hidden_size=50):
        super(LandmarkEmbedding, self).__init__()
        self.outs = outs
        self.name = name
        self.fc1 = nn.Linear(ins,hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size,outs)
        
    def forward(self,x):
        hidden = (self.fc1(x))
        output = self.fc2(hidden)
        return output

In [74]:
if __name__ == "__main__":
    hand_embedding = LandmarkEmbedding(21*2, 20, name = 'hand')
# (batch_size, ncols*ndims)
    hand = torch.randn(20, 10, 21*2)
    print(hand_embedding(hand).shape)

torch.Size([20, 10, 20])


In [75]:
lips_idxs = range(0,40)
hand_idxs = range(40,61)
pose_idxs = range(61,66)
ndims = 2
input_size = 10

In [76]:
class Embedding(nn.Module):
    def __init__(self,units,hidden_size=50):
        super(Embedding, self).__init__()
        self.units = units
        self.lips_embedding = LandmarkEmbedding(len(lips_idxs)*ndims,units,'lips')
        self.hand_embedding = LandmarkEmbedding(len(hand_idxs)*ndims,units,'hand')
        self.pose_embedding = LandmarkEmbedding(len(pose_idxs)*ndims,units,'pose')
        self.S = nn.Parameter(torch.randn(3))        
        self.fc = nn.Linear(384,384)
        self.pe = nn.Embedding(input_size+1, 384)
        
    def forward(self, x, non_empty_frame_idxs):
# (batch_size, input_size, ncols, ndims)
        batch_size,nframes,ncols,ndims = x.shape
        lips = torch.reshape(x[:,:,lips_idxs,:],(batch_size,nframes,len(lips_idxs)*ndims))
        hand = torch.reshape(x[:,:,hand_idxs,:],(batch_size,nframes,len(hand_idxs)*ndims))
        pose = torch.reshape(x[:,:,pose_idxs,:],(batch_size,nframes,len(pose_idxs)*ndims))
        print(lips.shape)
        lips_embedding = self.lips_embedding(lips)
        print(lips_embedding.shape)
        hand_embedding = self.hand_embedding(hand)
        pose_embedding = self.pose_embedding(pose)
        print(lips_embedding.shape)
        S = torch.softmax(self.S, dim=0)
        full_embedding = S[0]*lips_embedding + S[1]*hand_embedding + S[2]*pose_embedding
        print(full_embedding.shape)
        max_non_empty_frame_idxs = torch.max(non_empty_frame_idxs,dim=1,keepdim=True)[0]
        print(max_non_empty_frame_idxs.shape)
        max_non_empty_frame_idxs = torch.clamp(max_non_empty_frame_idxs, min=1, max=float('inf'))
        print(max_non_empty_frame_idxs.shape)
        normalised_non_empty_frame_idxs = torch.where(non_empty_frame_idxs<0,
                                                      torch.tensor(input_size),
                                                      non_empty_frame_idxs/max_non_empty_frame_idxs * input_size).int()
        print(normalised_non_empty_frame_idxs.shape)
        output = self.fc(full_embedding) + self.pe(normalised_non_empty_frame_idxs)
        return output

In [77]:
if __name__ == "__main__":
    embedding = Embedding(384)
    inputs = torch.randn(20, 10, 66, 2)
    idxs =torch.randint(low=0, high=10, size=(20, 10))
    print(embedding(inputs,idxs).shape)

torch.Size([20, 10, 80])
torch.Size([20, 10, 384])
torch.Size([20, 10, 384])
torch.Size([20, 10, 384])
torch.Size([20, 1])
torch.Size([20, 1])
torch.Size([20, 10])
torch.Size([20, 10, 384])


In [88]:
class Transformer(nn.Module):
    def __init__(self,n_blocks):
        super(Transformer,self).__init__()
        self.n_blocks = n_blocks
        self.mahs = [nn.MultiheadAttention(384, num_heads=2,batch_first=True) for _ in range(n_blocks)]
        self.mlps = [nn.Sequential(nn.Linear(384,20),nn.ReLU(),nn.Linear(20,384)) for _ in range(n_blocks)]
        self.embedding = Embedding(384)
        self.classifier = nn.Sequential(nn.Linear(384,250),nn.Softmax(dim=-1))
        
    def forward(self, x, padding_mask):
        print('padding_mask: ',padding_mask.shape) 
        inputs = self.embedding(x, padding_mask)
        print('inputs: ',inputs.shape) 
        for i in range(self.n_blocks):
            inputs = inputs + self.mahs[i](inputs,inputs,inputs,key_padding_mask=padding_mask)[0]
            inputs = inputs + self.mlps[i](inputs)
        print('inputs: ',inputs.shape)    
        mask = padding_mask.unsqueeze(2)
        output = torch.sum(inputs * mask, axis=1) / torch.sum(mask, axis=1)
        print('output: ',output.shape) 
        return self.classifier(output)

In [89]:
model = Transformer(2)

In [90]:
#(batch_size,input_size,ncols,ndims)
attn_mask = torch.randint(0, 10, (20,10))
print('attn:',attn_mask.shape)

attn: torch.Size([20, 10])


In [91]:
print(model(torch.randn(20, 10, 66, 2),attn_mask).shape)

padding_mask:  torch.Size([20, 10])
torch.Size([20, 10, 80])
torch.Size([20, 10, 384])
torch.Size([20, 10, 384])
torch.Size([20, 10, 384])
torch.Size([20, 1])
torch.Size([20, 1])
torch.Size([20, 10])
inputs:  torch.Size([20, 10, 384])
inputs:  torch.Size([20, 10, 384])
output:  torch.Size([20, 384])
torch.Size([20, 250])
