In [65]:
import torch
import torch.nn as nn
import numpy as np

# Acoustic Branch

Inputs for acoustic branch will be N x 40 where N [1,33]  
Time step: (2, 10) (seconds?)  
N: relative duration after feature extraction

In [82]:
class AcousticNet(nn.Module):
    def __init__(self, num_conv_layers = 3, kernel_size = 2, conv_width = 32, num_gru_layers = 2):
        super(AcousticNet, self).__init__()
        self.num_conv_layers = num_conv_layers
        self.conv1 = nn.Conv1d(in_channels=40, out_channels=conv_width, kernel_size=kernel_size, padding = kernel_size - 1)
        self.conv2 = nn.Conv1d(in_channels=conv_width, out_channels=conv_width, kernel_size=kernel_size, padding = kernel_size - 1)
        self.conv3 = nn.Conv1d(in_channels=conv_width, out_channels=conv_width, kernel_size=kernel_size, padding = kernel_size - 1)
        self.conv4 = nn.Conv1d(in_channels=conv_width, out_channels=conv_width, kernel_size=kernel_size, padding = kernel_size - 1)
        self.convs = [self.conv1, self.conv2, self.conv3, self.conv4]
        self.max_pool = nn.MaxPool1d(kernel_size = 2)
        self.relu = nn.ReLU()
        
        self.gru = nn.GRU(input_size=conv_width,hidden_size=32,num_layers=num_gru_layers) # 19 is hardcoded
        self.mean_pool = nn.AvgPool1d(kernel_size=2)
        
    def forward(self, x):
        for i in range(self.num_conv_layers):
            x = self.relu(self.max_pool(self.convs[i](x)))
        x = torch.transpose(x, 1, 2) 
        x, _ = self.gru(x)
        x = self.mean_pool(x)
        return x

In [83]:
# Test dummy input
net = AcousticNet(num_conv_layers = 4, num_gru_layers = 3, kernel_size = 3, conv_width = 128)
test_vec = torch.randn(10, 40, 17) # samples x features (or channels) x N (relative duration)
output = net(test_vec)
print(f'Shape of output: {output.shape}')
# assert output.shape[-1] == 16

Shape of output: torch.Size([10, 2, 16])


# Lexical Branch

In [84]:
# implement GRU (or transformer)
class LexicalNet(nn.Module):
    def __init__(self, num_gru_layers = 2):
        super(LexicalNet, self).__init__()
        # implement GRU (or transformer)
        self.gru = nn.GRU(input_size=300,hidden_size=32,num_layers=num_gru_layers)
        self.mean_pool = nn.AvgPool1d(kernel_size=2) 
        
    def forward(self, x):
        x, _ = self.gru(x)
        x = self.mean_pool(x)
        print(x.shape)
        return x

In [85]:
# Test dummy input
net = LexicalNet(num_gru_layers = 3)
test_vec = torch.randn(10, 1, 300)
output = net(test_vec)
assert output.shape[-1] == 16

torch.Size([10, 1, 16])


# Master branch

In [86]:
class MasterNet(nn.Module):
    def __init__(self, acoustic_modality = True, lexical_modality = True, visual_modality = False,
                 num_conv_layers = 3, kernel_size = 2, conv_width = 32, num_gru_layers = 2,
                 num_dense_layers = 1, dense_layer_width = 32, grl_lambda = .3):
        super(MasterNet, self).__init__()
        
        self.acoustic_modality = acoustic_modality
        self.lexical_modality = lexical_modality
        self.visual_modality = visual_modality
        
        self.acoustic_model = AcousticNet(num_conv_layers = num_conv_layers, kernel_size = kernel_size, 
                                     conv_width = conv_width, num_gru_layers = num_gru_layers)
        self.lexical_model = LexicalNet(num_gru_layers = 2)
        
        # emotion classifier
#         self.dense1_emo = nn.Linear()
#         self.dense2_emo = nn.Linear()
        
        width = 0 # width of the FC layers
        if self.acoustic_modality:
            width += 3
        if self.visual_modality:
            width += 0 # to implement
        if self.lexical_modality:
            width += 1
            
        self.fc_1 = nn.Linear(width, dense_layer_width)
        self.fc_2 = nn.Linear(dense_layer_width, 3)
        self.softmax = nn.Softmax(dim=1)

        self.relu = nn.ReLU()
#         # To implement   
#         if num_dense_layers == 2:
#             self.fc = nn.Sequential()
#             self.linear_1 = nn.Linear(width, dense_layer_width)
#         else:
#             self.fc = 
        
        # confound classifier -- to implement
        self.grl = None
        self.dense1_con = None
        self.dense2_con = None
        
        
    def forward_a(self, x_a):
        x = x_a
        x = self.acoustic_model(x)
        return x
    
    def forward_l(self, x_l):
        x = x_l
        x = self.lexical_model(x)
        return x
    
    def forward_v(self, x_v):
        x = x_v
        return x
    
    def encoder(self, x_v, x_a, x_l):
        print('x_a before encoding', x_a.shape)
        print('x_l before encoding', x_l.shape)
        if self.visual_modality:
            x_v = self.forward_v(x_v)
        if self.acoustic_modality:
            x_a = self.forward_a(x_a)
        if self.lexical_modality:
            x_l = self.forward_l(x_l)
        print('x_a after encoding', x_a.shape)
        print('x_l after encoding', x_l.shape)
        
        if self.visual_modality:
            if self.acoustic_modality:
                if self.lexical_modality:
                    x = torch.cat((x_v, x_a, x_l), 1)
                else:
                    x = torch.cat((x_v, x_a), 1)
            else:
                if self.lexical_modality:
                    x = torch.cat((x_v, x_l), 1)
                else:
                    x = x_v
        else:
            if self.acoustic_modality:
                if self.lexical_modality:
                    x = torch.cat((x_a, x_l), 1)
                else:
                    x = x_a
            else:
                x = x_l
        print('x after concat', x.shape)
        return x

    def recognizer(self, x):
        print(x.shape)
        x = self.relu(self.fc_1(x))
        x = self.fc_2(x)
        return x

    def forward(self, x_v, x_a, x_l):
        x = self.encoder(x_v, x_a, x_l)
        x = self.recognizer(x)
        return x

In [87]:
# Test dummy input
net = MasterNet()
acoustic_features = torch.randn(10, 40, 17) # samples x features (or channels) x N (relative duration)
lexical_features = torch.randn(10, 1, 300)
visual_features = None
output = net(visual_features, acoustic_features, lexical_features)
print(f'Shape of output: {output.shape}')
# assert output.shape[-1] == 16

x_a before encoding torch.Size([10, 40, 17])
x_l before encoding torch.Size([10, 1, 300])
torch.Size([10, 1, 16])
x_a after encoding torch.Size([10, 3, 16])
x_l after encoding torch.Size([10, 1, 16])
x after concat torch.Size([10, 4, 16])
torch.Size([10, 4, 16])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (40x16 and 4x32)