In [1]:
import torch 
import os
import numpy as np 
import pandas as pd
from transformers import AutoFeatureExtractor, ASTForAudioClassification
from datasets import load_dataset

from torch.utils.data import DataLoader, Dataset
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torchaudio
from transformers import AutoModelForAudioClassification
from torchaudio.transforms import Resample
from transformers import ASTFeatureExtractor

import pickle

import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# model path for video data
mod_path = 'C:\\Users\\zhuld\\Dropbox\\Projects\\EAV_codes\\facial_emotions_image_detection'

# class for downloading audio
class DataLoadAudio:
    def __init__(self, subject='all', parent_directory=r'C:\\Users\\zhuld\\Dropbox\\DATASETS\\EAV', target_sampling_rate=16000):
        self.parent_directory = parent_directory
        self.original_sampling_rate = int()
        self.target_sampling_rate = target_sampling_rate
        self.subject = subject
        self.file_path = list()
        self.file_emotion = list()

        self.seg_length = 5  # 5s
        self.feature = None
        self.label = None
        self.label_indexes = None
        self.test_prediction = list()

    def data_files(self):
        subject = f'subject{self.subject:02d}'
        file_emotion = []
        subjects = []
        path = os.path.join(self.parent_directory, subject, 'Audio')
        for i in os.listdir(path):
            emotion = i.split('_')[4]
            self.file_emotion.append(emotion)
            self.file_path.append(os.path.join(path, i))

    def feature_extraction(self):
        x = []
        y = []
        feature_extractor = ASTFeatureExtractor()
        for idx, path in enumerate(self.file_path):
            waveform, sampling_rate = torchaudio.load(path)
            self.original_sampling_rate = sampling_rate
            if self.original_sampling_rate is not self.target_sampling_rate:
                resampler = Resample(orig_freq=sampling_rate, new_freq=self.target_sampling_rate)
                resampled_waveform = resampler(waveform)
                resampled_waveform = resampled_waveform.squeeze().numpy()
            else:
                resampled_waveform = waveform

            segment_length = self.target_sampling_rate * self.seg_length
            num_sections = int(np.floor(len(resampled_waveform) / segment_length))

            for i in range(num_sections):
                t = resampled_waveform[i * segment_length: (i + 1) * segment_length]
                x.append(t)
                y.append(self.file_emotion[idx])
        print(f"Original sf: {self.original_sampling_rate}, resampled into {self.target_sampling_rate}")

        emotion_to_index = {
            'Neutral': 0,
            'Happiness': 3,
            'Sadness': 1,
            'Anger': 2,
            'Calmness': 4
        }
        y_idx = [emotion_to_index[emotion] for emotion in y]
        #print('shape of x: ', np.array(x).shape)
        #print(feature_extractor(np.array(x), sampling_rate=16000, padding="max_length", return_tensors="pt").input_values.shape)
        self.feature = np.squeeze(np.array(x))
        self.label_indexes = np.array(y_idx)
        self.label = np.array(y)

    def process(self):
        self.data_files()
        feature_extractor = ASTFeatureExtractor()
        self.feature_extraction()
        return self.feature, self.label_indexes

    def label_emotion(self):
        self.data_files()
        self.feature_extraction()
        return self.label

class EAVDataSplit:
    def __init__(self, x, y, batch_size=32):
        self.x = np.array(x)
        self.y = np.array(y)
        self.batch_size = batch_size


    def _split_features_labels(self):
        # Splitting features and labels based on class, select each 80 samples per class in order
        features = []
        labels = []
        for class_idx in range(5):  # Assuming there are 5 classes
            class_mask = np.where(self.y == class_idx)
            class_features = self.x[class_mask]
            class_labels = self.y[class_mask]

            features.append(class_features)
            labels.append(class_labels)

        return features, labels

    def get_split(self, h_idx=40): # update it if you want to use different ratio here we have 50/50
        [features, labels] = self._split_features_labels()
        # Splitting into training and testing
        train_features = np.concatenate([cls_features[:h_idx] for cls_features in features], axis=0)
        test_features = np.concatenate([cls_features[h_idx:] for cls_features in features], axis=0)
        train_labels = np.concatenate([cls_labels[:h_idx] for cls_labels in labels], axis=0)
        test_labels = np.concatenate([cls_labels[h_idx:] for cls_labels in labels], axis=0)

        #
        train_features = np.squeeze(train_features)
        test_features = np.squeeze(test_features)
        train_labels = train_labels
        test_labels = test_labels

        return train_features, train_labels, test_features, test_labels

    def get_loaders(self):
        self._split_features_labels()
        train_features, train_labels, test_features, test_labels = self.get_split()

        train_features = torch.Tensor(np.squeeze(train_features))
        test_features = torch.Tensor(np.squeeze(test_features))
        train_labels = torch.Tensor(train_labels).long()  # Using .long() for labels
        test_labels = torch.Tensor(test_labels).long()

        # Creating TensorDatasets and DataLoaders
        train_dataset = TensorDataset(train_features, train_labels)
        test_dataset = TensorDataset(test_features, test_labels)

        loader_train = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        loader_test = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        return loader_train, loader_test

class VideoAudioDataset(Dataset):
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        rgb = self.inputs['rgb'][idx]
        spec = self.inputs['spectrogram'][idx]
        label = self.labels[idx]
        return {'rgb': rgb, 'spectrogram': spec}, label

def prepare_dataloader(x, y, batch_size, shuffle=False):
    dataset = VideoAudioDataset(x, y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataloader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
modalities = ['rgb', 'spectrogram']

In [3]:
class TransformerEncoder(nn.Module):
    def __init__(self, emb_dim, num_layers, n_bottleneck_tokens, fusion_layer):
        super().__init__()
        
        self.emb_dim = emb_dim 
        self.num_layers = num_layers
        self.fusion_layer = fusion_layer 
        self.n_bottleneck_tokens = n_bottleneck_tokens
        self.encoders = nn.ModuleDict()
        
        for modality in modalities: 
            if modality == 'rgb':
                self.encoders[modality] = AutoModelForImageClassification.from_pretrained(mod_path).vit.encoder  
            elif modality == 'spectrogram': 
                self.encoders[modality] = AutoModelForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593").audio_spectrogram_transformer.encoder
                
        self.norm_rgb = nn.LayerNorm(emb_dim)
        self.norm_spec = nn.LayerNorm(emb_dim)

    def forward(self, inputs, bottleneck):
        x = inputs 
        # assume rgb inputs of shape (batch x frames) x seq_len x emb_dim
        # audio inputs of shape batch x seq_len x emb_dim 
        encoders = {}
        x_combined = []
        fusion_layer = self.fusion_layer  

        for layer in range(self.num_layers):
            if layer < fusion_layer: 
                for modality in modalities: 
                    x[modality] = self.encoders[modality].layer[layer](x[modality])[0]
            else: 
                bottle = []
                for modality in modalities:  
                    if modality == 'rgb':
                        bottleneck_expanded = bottleneck.unsqueeze(1).expand(-1, 25, -1, -1).reshape(-1, self.n_bottleneck_tokens, self.emb_dim)
                        
                    else:
                        bottleneck_expanded = bottleneck 
                    t_mod = x[modality].shape[1]
                    in_mod = torch.cat([x[modality], bottleneck_expanded], dim=1)
                    out_mod = self.encoders[modality].layer[layer](in_mod)[0]
                    x[modality] = out_mod[:, :t_mod]
                    if modality == 'rgb':
                        bottle.append(torch.mean(out_mod[:, t_mod:].view(-1, 25, self.n_bottleneck_tokens, self.emb_dim), dim=1)) # average accross frames 
                    else: 
                        bottle.append(out_mod[:, t_mod:])
                bottleneck = torch.mean(torch.stack(bottle, dim=-1), dim=-1)
          
        encoded_rgb = self.norm_rgb(x['rgb'])
        encoded_spec = self.norm_spec(x['spectrogram'])
        encoded = {
            'rgb': encoded_rgb,
            'spectrogram': encoded_spec
        }
        return encoded

class MBT(nn.Module):
    def __init__(self, mlp_dim, num_classes, num_layers, 
                 hidden_size, fusion_layer, 
                 representation_size=None, 
                 return_prelogits=False, return_preclassifier=False,
                 ):
        super(MBT, self).__init__()

        self.mlp_dim = mlp_dim
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.num_layers = num_layers 
        self.representation_size = representation_size
        self.return_prelogits = return_prelogits
        self.return_preclassifier = return_preclassifier

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))

        self.temporal_encoder_audio = AutoModelForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593").audio_spectrogram_transformer.embeddings
        self.temporal_encoder_rgb = AutoModelForImageClassification.from_pretrained(mod_path).vit.embeddings

        self.n_bottlenecks = 4  
        self.bottleneck = nn.Parameter(torch.randn(1, self.n_bottlenecks, self.hidden_size) * 0.02)

        self.fusion_layer = fusion_layer 
        
        self.encoder = TransformerEncoder(
            self.hidden_size, self.num_layers,  
            self.n_bottlenecks, self.fusion_layer 
        )

        if self.representation_size is not None:
            self.pre_logits = nn.Linear(hidden_size, self.mlp_dim)
            self.activation = nn.Tanh()
            #self.activation = nn.Relu()
            self.dropout = nn.Dropout(0.1)
        else:
            self.pre_logits = None

        self.output_projection = nn.Linear(self.mlp_dim, num_classes)

    def forward(self, x):
        # x should be a dict with rbg and spec as modalities (keys)
        # rgb input of shape (batch x channels x frames x image_size x image_size)
        # spectrogram input of shape (batch x channels x image_size x image_size)
        for modality in modalities: 
            if modality == 'spectrogram':
                x[modality] = self.temporal_encoder_audio(x[modality])
            if modality == 'rgb':
                x[modality] = self.temporal_encoder_rgb(x[modality])

        batch, _, _ = x['spectrogram'].shape
        bottleneck_expanded = self.bottleneck.expand(batch, -1, -1)
        
        # now add cls token 
        temporal_dims = {}
        
        encoded = self.encoder(x, bottleneck=bottleneck_expanded)
        
        if self.return_preclassifier:
            return encoded

        x_out = {}
        counter = 0 

        for modality in modalities: 
            if modality == 'rgb':
                batch_size, _, _ = encoded['spectrogram'].size()
                cls_tok = encoded[modality][:, 0]
                cls_tok = cls_tok.view(batch_size, 25, -1)
                features = cls_tok.mean(dim=1)
                x_out[modality] = features
            if modality == 'spectrogram':
                x_out[modality] = encoded[modality][:, 0]
    
        
        if self.representation_size is not None:
            for modality in x_out: 
                x_out[modality] = self.pre_logits(x_out[modality])
                x_out[modality] = self.activation(x_out[modality])
                x_out[modality] = self.dropout(x_out[modality])

        if self.return_prelogits:
            return x

        x_pool = 0 
        for modality in x_out: 
            x_out[modality] = self.output_projection(x_out[modality]) 
            x_pool += x_out[modality]
        x_pool /= len(x_out)
        if not self.training: 
            return x_pool 
        return x_out 

In [4]:
device

device(type='cuda')

In [5]:
accuracies = []
subjects = []

for sub in range(12, 43):
    # load vision data 
    file_path = r'C:\Users\zhuld\Dropbox\DATASETS\EAV\Input_images\Vision'
    file_name = f"subject_{sub:02d}_vis.pkl"
    file_ = os.path.join(file_path, file_name)
    if os.path.exists(file_):
        with open(file_, 'rb') as f: 
            vis_list = pickle.load(f)
        tr_x_vis, tr_y_vis, te_x_vis, te_y_vis = vis_list 
        data_video = [tr_x_vis, tr_y_vis, te_x_vis, te_y_vis] 
    else:
        print('Does not exist')

    processor = AutoImageProcessor.from_pretrained(mod_path)
    pixel_values_list = []
    for img_set in tr_x_vis:
        for img in img_set:
            processed = processor(images=img, return_tensors="pt")
            pixel_values = processed.pixel_values.squeeze()
            pixel_values_list.append(pixel_values)
    
    vals = torch.stack(pixel_values_list)
    
    pixel_values_list = []
    for img_set in te_x_vis:
        for img in img_set:
            processed = processor(images=img, return_tensors="pt")
            pixel_values = processed.pixel_values.squeeze()
            pixel_values_list.append(pixel_values)
    
    vals_test = torch.stack(pixel_values_list)

    # load audio data
    file_path = r'C:\Users\zhuld\Dropbox\DATASETS\EAV\Input_images\Audio'
    file_name = f"subject_{sub:02d}_aud.pkl"
    file_ = os.path.join(file_path, file_name)
    if os.path.exists(file_):
        with open(file_, 'rb') as f: 
            aud_list = pickle.load(f)
        tr_x_aud, tr_y_aud, te_x_aud, te_y_aud = aud_list 
        data_audio = [tr_x_aud, tr_y_aud, te_x_aud, te_y_aud] 
    else:
        print('Does not exist')

    feature_extractor = ASTFeatureExtractor()
    resampler = Resample(orig_freq=44100, new_freq=16000)
    
    tr_x_aud = feature_extractor(tr_x_aud, sampling_rate=16000, padding="max_length", return_tensors="pt")
    tr_x_aud = tr_x_aud.input_values.permute(0, 2, 1)
    
    te_x_aud = feature_extractor(te_x_aud, sampling_rate=16000, padding="max_length", return_tensors="pt")
    te_x_aud = te_x_aud.input_values.permute(0, 2, 1)

    data = [
    {'rgb': vals.view(-1, 25, 3, 224, 224), 
    'spectrogram': tr_x_aud},
    torch.from_numpy(tr_y_vis).long(), 
    {'rgb': vals_test.view(-1, 25, 3, 224, 224), 
    'spectrogram': te_x_aud},
    torch.from_numpy(te_y_vis).long(), 
]

    tr_x, tr_y, te_x, te_y = data
    
    train_dataloader = prepare_dataloader(tr_x, tr_y, batch_size=4, shuffle=True)
    test_dataloader = prepare_dataloader(te_x, te_y, batch_size=4, shuffle=False)

    # initialize the model
    mbt = MBT(
        mlp_dim=3072, num_classes=5, num_layers=12, 
        hidden_size=768, fusion_layer=8, representation_size=256,
        return_prelogits=False, return_preclassifier=False
    )

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(mbt.parameters(), lr=0.00001)
    mbt = mbt.to(device)

    epochs = 10
    best_accuracy, best_epoch = 0.0, 0 
    
    for epoch in range(1, epochs+1): 
        mbt.train()
        train_correct, train_total = 0, 0
        running_loss = 0.0 
        
        for i, (inputs, labels) in enumerate(train_dataloader):
            optimizer.zero_grad()
            inputs['rgb'] = inputs['rgb'].view(-1, 3, 224, 224)
            inputs = {k: v.float().to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            outputs = mbt(inputs)
            if isinstance(outputs, dict):
                ce_loss = []
                for mod in outputs:
                    ce_loss.append(criterion(outputs[mod], labels))
            else:
                break         
                
            loss = torch.mean(torch.stack(ce_loss))
            loss.backward()
            optimizer.step()
    
            running_loss += loss.item()
    
            if i % 1 == 0: 
                print(f'Epoch {epoch}, loss: {running_loss / 10:.4f}')
                running_loss = 0 
    
        mbt.eval()
        correct, total = 0, 0
        outputs_batch = []
        with torch.no_grad():
            for inputs, labels in test_dataloader:
                inputs['rgb'] = inputs['rgb'].view(-1, 3, 224, 224)
                inputs = {k: v.float().to(device) for k, v in inputs.items()}
                labels = labels.to(device)
                outputs = mbt(inputs)
                correct += (outputs.argmax(dim=-1) == labels).sum().item()
                total += labels.size(0)
            
    
        test_accuracy = correct / total
        print(f'Test accuracy after epoch {epoch}: {test_accuracy:.4f}')
    
        if test_accuracy > best_accuracy: 
            best_accuracy = test_accuracy 
            best_epoch = epoch 
            #torch.save(mbt, f'sub{sub}_mbt{epoch}epoch')
    
        
    print(f'finished training. Best test set accuracy for subject {sub} is: {best_accuracy:.4f} at epoch {best_epoch}')
    subjects.append(sub)
    accuracies.append(best_accuracy)

    del tr_x_vis, tr_y_vis, te_x_vis, te_y_vis, data_video
    del tr_x_aud, tr_y_aud, te_x_aud, te_y_aud, data_audio
    del tr_x, tr_y, te_x, te_y, data
    del train_dataloader, test_dataloader, mbt, criterion, optimizer
    torch.cuda.empty_cache()  

performance = {'Subject': subjects, 'Accuracy': accuracies}
perf = pd.DataFrame(performance)

print(perf)
perf.to_csv('bottleneck_subject_accuracies10epochs.csv', index=False)
print('done')

Epoch 1, loss: 0.1699
Epoch 1, loss: 0.1574
Epoch 1, loss: 0.1625
Epoch 1, loss: 0.1818
Epoch 1, loss: 0.1701
Epoch 1, loss: 0.1606
Epoch 1, loss: 0.1803
Epoch 1, loss: 0.1650
Epoch 1, loss: 0.1609
Epoch 1, loss: 0.1497
Epoch 1, loss: 0.1811
Epoch 1, loss: 0.1631
Epoch 1, loss: 0.1456
Epoch 1, loss: 0.1566
Epoch 1, loss: 0.1460
Epoch 1, loss: 0.1496
Epoch 1, loss: 0.1376
Epoch 1, loss: 0.1486
Epoch 1, loss: 0.1476
Epoch 1, loss: 0.1527
Epoch 1, loss: 0.1472
Epoch 1, loss: 0.1565
Epoch 1, loss: 0.1372
Epoch 1, loss: 0.1468
Epoch 1, loss: 0.1462
Epoch 1, loss: 0.1515
Epoch 1, loss: 0.1341
Epoch 1, loss: 0.1506
Epoch 1, loss: 0.1449
Epoch 1, loss: 0.1653
Epoch 1, loss: 0.1359
Epoch 1, loss: 0.1322
Epoch 1, loss: 0.1072
Epoch 1, loss: 0.1315
Epoch 1, loss: 0.1614
Epoch 1, loss: 0.1260
Epoch 1, loss: 0.1495
Epoch 1, loss: 0.1251
Epoch 1, loss: 0.1146
Epoch 1, loss: 0.1299
Epoch 1, loss: 0.1363
Epoch 1, loss: 0.1165
Epoch 1, loss: 0.1383
Epoch 1, loss: 0.1110
Epoch 1, loss: 0.1156
Epoch 1, l