# working model for tensorfusion

In [21]:
import pickle
import torch

In [22]:
# data load
with open('./data/lou_dataset_1_3.pkl', 'rb') as f:
    raw_dataset = pickle.load(f)
raw_dataset['Session01'].keys()

dict_keys(['file_names', 'text_embeddings', 'wav_embeddings', 'Emotion', 'Arousal', 'Valence'])

In [23]:
# encoding Emotion
encode_dict = {b:i for i, b in enumerate(raw_dataset['Session01']['Emotion'].unique())}
decode_dict = {i:b for i, b in enumerate(raw_dataset['Session01']['Emotion'].unique())}
encode_dict, decode_dict    

({'neutral': 0,
  'happy': 1,
  'happy;neutral': 2,
  'surprise;neutral': 3,
  'happy;surprise': 4,
  'angry;neutral': 5},
 {0: 'neutral',
  1: 'happy',
  2: 'happy;neutral',
  3: 'surprise;neutral',
  4: 'happy;surprise',
  5: 'angry;neutral'})

In [24]:
raw_dataset['Session01']['Emotion'] = raw_dataset['Session01']['Emotion'].map(encode_dict)

# torch dataset 만들기
- 참고: https://tutorials.pytorch.kr/beginner/basics/data_tutorial.html

In [25]:
import os
import pandas as pd
from datasets import Dataset
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split

In [26]:
class EtriDataset(Dataset):
    def __init__(self, file_names, text_embeddings, wav_embeddings, Emotion, Arousal, Valence):
        self.file_names = file_names
        self.text_embeddings = text_embeddings
        self.wav_embeddings = wav_embeddings
        self.label_emotion = Emotion
        self.label_arousal = Arousal
        self.label_valence = Valence
        
    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        text_embeddings = self.text_embeddings[idx]
        wav_embeddings = self.wav_embeddings[idx]
        label_emotion = self.label_emotion[idx]
        label_arousal = self.label_arousal[idx]
        label_valence = self.label_valence[idx]
        return text_embeddings, wav_embeddings, label_emotion, label_arousal, label_valence

In [27]:
# data load 및 나누기: https://076923.github.io/posts/Python-pytorch-11/

dataset = EtriDataset(raw_dataset['Session01']['file_names'],
                      raw_dataset['Session01']['text_embeddings'],
                      raw_dataset['Session01']['wav_embeddings'],
                      raw_dataset['Session01']['Emotion'],
                      raw_dataset['Session01']['Arousal'],
                      raw_dataset['Session01']['Valence'])


In [28]:
dataset_size = len(dataset)
train_size = int(dataset_size * 0.7)
validation_size = int(dataset_size * 0.15)
test_size = dataset_size - train_size - validation_size

train_dataset, validation_dataset, test_dataset = random_split(dataset, [train_size, validation_size, test_size])

print(train_size, test_size, validation_size)
print(f"Training Data Size : {len(train_dataset)}")
print(f"Validation Data Size : {len(validation_dataset)}")
print(f"Testing Data Size : {len(test_dataset)}")

217 48 46
Training Data Size : 217
Validation Data Size : 46
Testing Data Size : 48


In [29]:
# data size
raw_dataset['Session01']['wav_embeddings'][0].shape

torch.Size([1, 49, 768])

In [30]:
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, drop_last=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=4, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=True, drop_last=True)

# NetWork 만들기

In [31]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [32]:
class MLPNetwork_pre(nn.Module):
    def __init__(self, input_length, input_width):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_length*input_width, 768)
        self.gelu1 = nn.GELU()
        self.bn1 = nn.BatchNorm1d(768)
        self.fc2 = nn.Linear(768, 512)
        self.gelu2 = nn.GELU()
        self.bn2 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 32)
        self.gelu3 = nn.GELU()
        
    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.gelu1(x)
        x = self.bn1(x)
        x = self.fc2(x)
        x = self.gelu2(x)
        x = self.bn2(x)
        x = self.fc3(x)
        output = self.gelu3(x)
        return output

class MLPNetwork_final(nn.Module):
    def __init__(self, input_length, input_width):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_length*input_width, 256)
        self.gelu1 = nn.GELU()
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 64)
        self.gelu2 = nn.GELU()
        self.bn2 = nn.BatchNorm1d(64)
        self.fc3 = nn.Linear(64, 6)
        
    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.gelu1(x)
        x = self.bn1(x)
        x = self.fc2(x)
        x = self.gelu2(x)
        x = self.bn2(x)
        output = self.fc3(x)
        
        return output


In [33]:
class TensorFusionMixer(nn.Module):
    def __init__(self, ModelA, ModelB):
        super().__init__()
        self.ModelA = ModelA
        self.ModelB = ModelB
        self.Model_mlp_final = MLPNetwork_final(32,32).to(device)
        self.softmax = nn.Softmax(dim=1)
        
    def tensor_fusion(self, batch_arr1, batch_arr2):
        fusion_matrix_lst = []
        for i, (arr1, arr2) in enumerate(zip(batch_arr1, batch_arr2)):
            outer_matrix = torch.outer(arr1, arr2)
            l, w = outer_matrix.shape
            outer_matrix = outer_matrix.view(1, l, w)
            fusion_matrix_lst.append(outer_matrix)
        fusion_matrix = torch.concat(fusion_matrix_lst)
        # print(fusion_matrix.shape)
        return fusion_matrix
        
    def forward(self, x1, x2):
        x1 = self.ModelA(x1)
        x2 = self.ModelB(x2)
        fusion_matrix = self.tensor_fusion(x1, x2) 
        x = self.Model_mlp_final(fusion_matrix)
        output = self.softmax(x)
        return output     

txt_input_length, txt_input_width = raw_dataset['Session01']['text_embeddings'][0].shape
_, wav_input_length, wav_input_width = raw_dataset['Session01']['wav_embeddings'][0].shape

# tf_mixer에 들어갈 wav mlp, txt mlp 선언
model_mlp_txt = MLPNetwork_pre(txt_input_length,txt_input_width).to(device)
model_mlp_wav = MLPNetwork_pre(wav_input_length,wav_input_width).to(device)

# 최종 모델 선언
model_tf_mixer = TensorFusionMixer(ModelA = model_mlp_txt, ModelB = model_mlp_wav).to(device)

print(model_tf_mixer)

TensorFusionMixer(
  (ModelA): MLPNetwork_pre(
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (fc1): Linear(in_features=62976, out_features=768, bias=True)
    (gelu1): GELU(approximate='none')
    (bn1): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc2): Linear(in_features=768, out_features=512, bias=True)
    (gelu2): GELU(approximate='none')
    (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc3): Linear(in_features=512, out_features=32, bias=True)
    (gelu3): GELU(approximate='none')
  )
  (ModelB): MLPNetwork_pre(
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (fc1): Linear(in_features=37632, out_features=768, bias=True)
    (gelu1): GELU(approximate='none')
    (bn1): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc2): Linear(in_features=768, out_features=512, bias=True)
    (gelu2): GELU(approximate='none')
    (bn2): BatchNorm1d(512, eps=1e-0

# 학습을 위한 train, test method 만들기

In [34]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X_txt, X_wav, y, _, _) in enumerate(dataloader): # data 순서: file_names, text_embeddings, wav_embeddings, label_emotion, label_arousal, label_valence
        
        # 예측 오류 계산
        X_txt, X_wav, y = X_txt.to(device), X_wav.to(device),y.type(torch.LongTensor).to(device)
        pred = model(X_txt, X_wav)
        loss = loss_fn(pred, y)

        # 역전파
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X_txt)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [35]:
def test(dataloader, model, loss_fn, mode = 'test'):
       
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch, (X_txt, X_wav, y, _, _) in enumerate(dataloader): # text_embeddings, wav_embeddings, label_emotion, label_arousal, label_valence
            # 예측 오류 계산
            X_txt, X_wav, y = X_txt.to(device), X_wav.to(device),y.type(torch.LongTensor).to(device)
            pred = model(X_txt, X_wav)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            
    test_loss /= num_batches
    correct /= size
    if mode == 'test':
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    elif mode == 'val':
        print(f"Validation Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

# 학습시키기

In [36]:
# Set the Training Parameters
lr = 1e-3
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model_tf_mixer.parameters(), lr=lr)

epochs = 30
for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train(train_dataloader, model_tf_mixer, loss_fn, optimizer)
    test(validation_dataloader, model_tf_mixer, loss_fn, mode = 'val')
print("Done!")

Epoch 1
-------------------------------
loss: 1.805170  [    0/  217]
Validation Error: 
 Accuracy: 0.0%, Avg loss: 1.784804 

Epoch 2
-------------------------------
loss: 1.784523  [    0/  217]
Validation Error: 
 Accuracy: 0.0%, Avg loss: 1.784104 

Epoch 3
-------------------------------
loss: 1.785294  [    0/  217]
Validation Error: 
 Accuracy: 0.0%, Avg loss: 1.783390 

Epoch 4
-------------------------------
loss: 1.780635  [    0/  217]
Validation Error: 
 Accuracy: 0.0%, Avg loss: 1.782096 

Epoch 5
-------------------------------
loss: 1.788549  [    0/  217]
Validation Error: 
 Accuracy: 0.0%, Avg loss: 1.781471 

Epoch 6
-------------------------------
loss: 1.781594  [    0/  217]
Validation Error: 
 Accuracy: 0.0%, Avg loss: 1.781223 

Epoch 7
-------------------------------
loss: 1.779850  [    0/  217]
Validation Error: 
 Accuracy: 0.0%, Avg loss: 1.779849 

Epoch 8
-------------------------------
loss: 1.780122  [    0/  217]
Validation Error: 
 Accuracy: 0.0%, Avg l

## 검증

In [38]:
test(test_dataloader, model_tf_mixer, loss_fn, mode = 'test')

Test Error: 
 Accuracy: 83.3%, Avg loss: 1.764071 



In [39]:
 for batch, (a,b,c,d,e) in list(enumerate(test_dataloader))[:3]:
     print(a,b,c,d,e)

tensor([[[-1.4036, -0.8254,  0.3374,  ..., -0.4919,  0.1812,  0.9015],
         [-1.5124, -1.0119,  1.0906,  ..., -0.4881,  1.2984,  0.6860],
         [-2.3484, -1.5205,  0.8371,  ...,  0.4671,  0.7323,  1.5519],
         ...,
         [-1.3646, -0.5572,  0.7691,  ..., -0.4963,  0.5333,  0.3470],
         [-1.3499, -0.6444,  0.7400,  ..., -0.7245,  0.5905,  0.6070],
         [-1.2281, -0.1519,  0.4446,  ..., -0.4954,  0.4255,  0.4170]],

        [[ 0.0683, -1.3784,  2.1355,  ..., -0.5110,  0.3518,  0.3829],
         [-0.1755, -0.6976,  0.9760,  ..., -0.2128,  0.3653,  0.1003],
         [-0.7046, -0.6705,  1.0717,  ..., -0.2769,  0.1023,  0.6022],
         ...,
         [ 0.4858, -0.2909,  1.9416,  ..., -0.2216,  1.1318,  0.3733],
         [ 0.2900,  0.4255,  2.0321,  ..., -0.2868,  1.1614,  0.6492],
         [ 0.3786, -0.3927,  1.8511,  ..., -0.0986,  1.4052,  0.5346]],

        [[-0.3103, -0.6272,  0.8433,  ..., -1.4785, -0.8344,  0.9614],
         [ 0.8770, -0.1079,  1.3329,  ..., -0

In [48]:
probs = model_tf_mixer(a.to(device), b.to(device))
print(probs)
for i in torch.argmax(probs, dim=1):
    print(decode_dict[int(i)])

tensor([[0.2029, 0.1564, 0.1491, 0.1530, 0.1398, 0.1988],
        [0.2029, 0.1564, 0.1493, 0.1531, 0.1397, 0.1987],
        [0.2031, 0.1564, 0.1492, 0.1530, 0.1398, 0.1986],
        [0.2029, 0.1564, 0.1493, 0.1531, 0.1398, 0.1987]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)
neutral
neutral
neutral
neutral
