## Inference

### Config

In [None]:
import os
import clip
import torch
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)
from transformers import Wav2Vec2Processor
from torch import nn
import numpy as np
import librosa
from PIL import Image

# Run on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# PROJECT_PATH is the directory two levels above the current working directory.
# os.path.dirname keeps the root/drive prefix of the path intact, which
# splitting on os.sep and re-joining under a hard-coded '/' does not.
PROJECT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
# EMOTION_FEATURES_SAVE_PATH, the path to save the extracted features, e.g. EPAlign/mmefeature/tmp
EMOTION_FEATURES_SAVE_PATH = os.path.join(PROJECT_PATH, 'EPAlign', 'mmefeature', 'tmp')
# Created eagerly so that every later cell can write below it.
os.makedirs(EMOTION_FEATURES_SAVE_PATH, exist_ok=True)
# PRETRAIN_MODEL is the pretrained CLIP model name, e.g. ViT-B/32
PRETRAIN_MODEL = "ViT-B/32"
# PRETRAIN_WAV2VEC2_PATH is the pretrained wav2vec2 model path, e.g. EPAlign/ckpt/base/wav2vec2
PRETRAIN_WAV2VEC2_PATH = f"{PROJECT_PATH}/EPAlign/ckpt/base/wav2vec2"
# PRETRAIN_MODEL_PATH is the directory passed to clip.load as download_root, e.g. EPAlign/ckpt/base
PRETRAIN_MODEL_PATH = f"{PROJECT_PATH}/EPAlign/ckpt/base"
# PROCESSED_WAV2VEC2_PATH is the path the Wav2Vec2Processor is loaded from
# (currently the same directory as PRETRAIN_WAV2VEC2_PATH)
PROCESSED_WAV2VEC2_PATH = f"{PROJECT_PATH}/EPAlign/ckpt/base/wav2vec2"

### Explicitly specify the speech emotion

In [None]:
# emotiontags, probable emotions, e.g. "angry", "happy", "neutral", "sad", "surprise" (appear in ESD dataset)
emotiontags = ["angry", "happy", "neutral", "sad", "surprise"]

DATASET = "ESD"
# FINETUNE_MODEL is the finetuned model path, e.g. EPAlign/ckpt/ESD/best_model.pt
FINETUNE_MODEL = f"{PROJECT_PATH}/EPAlign/ckpt/{DATASET}/best_model.pt"

model, preprocess = clip.load(PRETRAIN_MODEL, device=device, jit=False, download_root=PRETRAIN_MODEL_PATH)
# map_location lets a checkpoint that was saved on a GPU be restored on a
# CPU-only machine as well.
model.load_state_dict(torch.load(FINETUNE_MODEL, map_location=device))

# One text prompt per emotion tag, in the same order as emotiontags.
test_prompts = [f'A person speaking with a feeling of {emo}' for emo in emotiontags]
emo_prompt = clip.tokenize(test_prompts).to(device)

with torch.no_grad():
    emo_prompt_features = model.encode_text(emo_prompt)
    # Scale every prompt embedding to unit L2 norm.
    emo_prompt_features /= emo_prompt_features.norm(dim=-1, keepdim=True)

# Tuple with one size-1 slice along dim 0 per tag, ordered like emotiontags.
# Later cells read this name, so it stays a module-level variable.
emo_prompt_features = emo_prompt_features.split(1, dim=0)

explicit_dir = f'{EMOTION_FEATURES_SAVE_PATH}/explict'
os.makedirs(explicit_dir, exist_ok=True)
for tag, feature in zip(emotiontags, emo_prompt_features):
    feature_path = f"{explicit_dir}/{tag}.pt"
    torch.save(feature, feature_path)
    print(f"Save {tag} feature in {feature_path}")

### Implicitly infer the speech emotion (from audio)

In [None]:
# Checkpoint of the audio model fine-tuned on DATASET,
# e.g. EPAlign/ckpt/ESD/best_model_proj_logit.pt
DATASET = "ESD"
FINETUNE_MODEL = f"{PROJECT_PATH}/EPAlign/ckpt/{DATASET}/best_model_proj_logit.pt"

# Waveform to classify, e.g. test/wav/test_audio.wav
test_wav_path = f"{PROJECT_PATH}/EPAlign/test/wav/test_audio.wav"

# Candidate emotions (the ones that appear in the ESD dataset) and one text
# prompt per candidate, in the same order.
emotiontags = ["angry", "happy", "neutral", "sad", "surprise"]
test_prompts = [f'A person speaking with a feeling of {tag}' for tag in emotiontags]

class CLAP(Wav2Vec2PreTrainedModel):
    """Scores audio clips against text prompts with cosine-similarity logits.

    Audio is encoded by a Wav2Vec2 model, mean-pooled and projected to 512
    dimensions with ``proj``; prompts are encoded by a CLIP text encoder.

    Args:
        config: Wav2Vec2 configuration forwarded to the base class and to
            ``Wav2Vec2Model``.
        prompt_pretrain_model: CLIP model name passed to ``clip.load``.
        prompt_pretrain_model_path: ``download_root`` passed to ``clip.load``.
    """

    def __init__(self, config, prompt_pretrain_model, prompt_pretrain_model_path):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.init_weights()
        # The pooled wav2vec2 output is multiplied by `proj` in forward(), so
        # `width` has to equal the encoder's hidden size.
        # NOTE(review): hard-coded to 1024 -- confirm against config.hidden_size.
        width = 1024
        scale = width ** -0.5
        self.proj = nn.Parameter(scale * torch.randn(width, 512))
        # Learnable log-temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.prompt_model, self.prompt_processor = clip.load(prompt_pretrain_model, jit=False, download_root=prompt_pretrain_model_path)
        self.prompt_model.to(device)

    def forward(self, wavs, prompts):
        """Return ``(logits_per_audio, logits_per_text)``.

        Args:
            wavs: iterable of waveform tensors; each item is passed to the
                wav2vec2 encoder on its own.
            prompts: prompt strings to tokenize with ``clip.tokenize``, or a
                tensor of already tokenized prompts.

        Returns:
            ``logits_per_audio`` (audio rows, prompt columns) and its
            transpose ``logits_per_text``.
        """
        # Collect the pooled encodings and concatenate once instead of
        # re-copying a growing tensor on every iteration.
        pooled = []
        for wav in wavs:
            hidden = self.wav2vec2(wav)[0]
            # Average over dim 1 (assumed to be the time axis of the
            # encoder output -- TODO confirm).
            pooled.append(torch.mean(hidden, dim=1))
        audio_features = torch.cat(pooled, dim=0) @ self.proj

        if isinstance(prompts, torch.Tensor):
            prompt_tokens = prompts.to(device)
        else:
            prompt_tokens = clip.tokenize(prompts).to(device)
        prompt_features = self.prompt_model.encode_text(prompt_tokens)

        # normalized features
        audio_features = audio_features / audio_features.norm(dim=1, keepdim=True)
        prompt_features = prompt_features / prompt_features.norm(dim=1, keepdim=True)
        audio_features = audio_features.float()
        prompt_features = prompt_features.float()

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp().float()
        logits_per_audio = logit_scale * audio_features @ prompt_features.t()
        logits_per_text = logits_per_audio.t()
        return logits_per_audio, logits_per_text

model = CLAP.from_pretrained(PRETRAIN_WAV2VEC2_PATH, prompt_pretrain_model=PRETRAIN_MODEL, prompt_pretrain_model_path=PRETRAIN_MODEL_PATH).to(device)
# map_location lets a checkpoint that was saved on a GPU be restored on a
# CPU-only machine as well.
model.load_state_dict(torch.load(FINETUNE_MODEL, map_location=device))


# The waveform is resampled to 16 kHz on load; the same rate is passed to the processor.
test_wav, _ = librosa.load(test_wav_path, sr=16000)
audio_processor = Wav2Vec2Processor.from_pretrained(PROCESSED_WAV2VEC2_PATH)
test_audio = audio_processor(test_wav, sampling_rate=16000)
test_audio = test_audio["input_values"][0]
test_audio = test_audio.reshape(1, -1)
test_audio = torch.from_numpy(test_audio).to(device).float()
with torch.no_grad():
    # The extra leading dimension makes this a one-item collection of waveforms,
    # which is what CLAP.forward iterates over.
    logits_per_audio, _ = model(test_audio.unsqueeze(0), test_prompts)
    probs = logits_per_audio.softmax(dim=-1).cpu().numpy()

predicted_idx = int(np.argmax(probs))
predicted_emotion = emotiontags[predicted_idx]
print(f"Predicted emotion: {predicted_emotion}")

implicit_audio_dir = f'{EMOTION_FEATURES_SAVE_PATH}/implict_audio'
os.makedirs(implicit_audio_dir, exist_ok=True)
# File name of the wav without its extension. The path is assembled outside the
# f-string replacement field because nesting the same quote type inside an
# f-string is a syntax error before Python 3.12.
wav_stem = os.path.splitext(os.path.basename(test_wav_path))[0]
feature_path = f"{implicit_audio_dir}/{wav_stem}.pt"
# emo_prompt_features is not defined in this cell: it is the tuple left behind by
# the explicit-emotion cell, which uses the same tags in the same order.
torch.save(emo_prompt_features[predicted_idx].squeeze(), feature_path)
print(f"Save {predicted_emotion} feature in {feature_path}")

### Implicitly infer the speech emotion (from an image)

In [None]:
# emotiontags, probable emotions, e.g. "Surprise", "Fear", "Disgust", "Happiness", "Sadness", "Anger", "Neutral"  (appear in RAF-DB dataset)
emotiontags = ["Surprise", "Fear", "Disgust", "Happiness", "Sadness", "Anger", "Neutral"]
test_prompts = [f'A person speaking with a feeling of {emo}' for emo in emotiontags]
# test_img_path is the path to the img file, e.g. test/img/test_img.jpg
test_img_path = f"{PROJECT_PATH}/EPAlign/test/img/test_img.jpg"

DATASET = "RAF"
# FINETUNE_MODEL is the finetuned model path, e.g. EPAlign/ckpt/RAF/RAF_ft.pt
FINETUNE_MODEL = f"{PROJECT_PATH}/EPAlign/ckpt/{DATASET}/RAF_ft.pt"
model, preprocess_img = clip.load(PRETRAIN_MODEL, device=device, jit=False, download_root=PRETRAIN_MODEL_PATH)
# map_location lets a checkpoint that was saved on a GPU be restored on a
# CPU-only machine as well.
model.load_state_dict(torch.load(FINETUNE_MODEL, map_location=device))

prompt_tokens = clip.tokenize(test_prompts).to(device)
with torch.no_grad():
    # The context manager closes the image file once it has been preprocessed.
    with Image.open(test_img_path) as img:
        image_input = preprocess_img(img).unsqueeze(0).to(device)
    logits_per_image, _ = model(image_input, prompt_tokens)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    # Prompt embeddings for THIS cell's tags. The emo_prompt_features tuple from
    # the explicit cell holds five ESD-ordered entries, so indexing it with an
    # index into the seven tags above would pick an unrelated emotion or go out
    # of range.
    # NOTE(review): the embeddings come from this cell's fine-tuned CLIP text
    # encoder -- confirm that is the embedding space the consumer expects.
    raf_prompt_features = model.encode_text(prompt_tokens)
    raf_prompt_features /= raf_prompt_features.norm(dim=-1, keepdim=True)

predicted_idx = int(np.argmax(probs))
predicted_emotion = emotiontags[predicted_idx]
print(f"Predicted emotion: {predicted_emotion}")

implicit_img_dir = f'{EMOTION_FEATURES_SAVE_PATH}/implict_img'
os.makedirs(implicit_img_dir, exist_ok=True)
# File name of the image without its extension. The path is assembled outside
# the f-string replacement field because nesting the same quote type inside an
# f-string is a syntax error before Python 3.12.
img_stem = os.path.splitext(os.path.basename(test_img_path))[0]
feature_path = f"{implicit_img_dir}/{img_stem}.pt"
torch.save(raf_prompt_features[predicted_idx], feature_path)
print(f"Save {predicted_emotion} feature in {feature_path}")

### Implicitly infer the speech emotion (from text)

In [None]:
# Checkpoint of the text model fine-tuned on DATASET,
# e.g. EPAlign/ckpt/MELD/MELD_text_ft.pt
DATASET = "MELD"
FINETUNE_MODEL = f"{PROJECT_PATH}/EPAlign/ckpt/{DATASET}/MELD_text_ft.pt"

# Saved text feature to classify, e.g. test/text_f/test_text_f.pt
test_text_f_path = f"{PROJECT_PATH}/EPAlign/test/text_f/test_text_f.pt"

# Candidate emotions (the ones that appear in the MELD dataset) and one text
# prompt per candidate, in the same order.
emotiontags = ["neutral", "joy", "sad", "angry", "surprise", "fearful", "disgust"]
test_prompts = [f'A person speaking with a feeling of {tag}' for tag in emotiontags]

class Concat_text_Model(nn.Module):
    """Scores text features against label features with cosine-similarity logits.

    Text features are mapped to ``fused_dim`` by the ``fuse_proj`` matrix and
    label features by the ``label_linear`` layer; both are L2-normalised along
    dim 1 before the similarity is taken.

    Args:
        input_label_feature_dim: size of the incoming label features.
        input_text_feature_dim: size of the incoming text features.
        fused_dim: size of the shared space both inputs are mapped into.
        num_heads: stored on the instance; no layer defined here uses it.
    """

    def __init__(self,
                 input_label_feature_dim=512,
                 input_text_feature_dim=4096,
                 fused_dim=512,
                 num_heads=8,
                 ):
        super().__init__()
        self.input_label_feature_dim = input_label_feature_dim
        self.input_text_feature_dim = input_text_feature_dim
        self.fused_dim = fused_dim
        self.num_heads = num_heads

        # Random projection for the text branch, scaled by fused_dim ** -0.5.
        init_scale = self.fused_dim ** -0.5
        text_projection = init_scale * torch.randn(self.input_text_feature_dim, self.fused_dim)
        self.fuse_proj = nn.Parameter(text_projection)
        # Learnable log-temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.label_linear = nn.Linear(self.input_label_feature_dim, self.fused_dim)

    def forward(self, text_features, label_features):
        """Return ``(logits_per_fused, logits_per_label)``.

        ``logits_per_fused`` has one row per text feature and one column per
        label feature; ``logits_per_label`` is its transpose. Because ``.t()``
        is applied to the result, both inputs are expected to be 2-D.
        """
        projected_labels = self.label_linear(label_features)
        projected_text = text_features @ self.fuse_proj

        # Unit-length rows, then float32 for the similarity product.
        unit_text = (projected_text / projected_text.norm(dim=1, keepdim=True)).float()
        unit_labels = (projected_labels / projected_labels.norm(dim=1, keepdim=True)).float()

        temperature = self.logit_scale.exp().float()
        logits_per_fused = (temperature * unit_text) @ unit_labels.t()
        return logits_per_fused, logits_per_fused.t()

    def extract_text_feature(self, text_features):
        """Return the text features after the ``fuse_proj`` projection (not normalised)."""
        return torch.matmul(text_features, self.fuse_proj)

model = Concat_text_Model().to(device)
# map_location lets a checkpoint that was saved on a GPU be restored on a
# CPU-only machine as well.
model.load_state_dict(torch.load(FINETUNE_MODEL, map_location=device))

# Concat_text_Model.forward sends label_features through nn.Linear(512, ...), so it
# needs prompt embeddings, not the integer token ids returned by clip.tokenize.
# NOTE(review): the prompts are embedded with the base CLIP text encoder here --
# confirm this matches how the label features were produced when the checkpoint
# was trained (encoder weights, and whether they were normalised beforehand).
prompt_encoder, _ = clip.load(PRETRAIN_MODEL, device=device, jit=False, download_root=PRETRAIN_MODEL_PATH)

with torch.no_grad():
    label_features = prompt_encoder.encode_text(clip.tokenize(test_prompts).to(device)).float()
    text_feature = torch.load(test_text_f_path, map_location=device)
    # Add a leading batch dimension of size 1.
    text_feature = text_feature.unsqueeze(0)
    logits_per_fused, _ = model(text_feature.to(device), label_features)
    probs = logits_per_fused.softmax(dim=-1).cpu().numpy()
    # Unit-norm prompt embeddings for THIS cell's tags. The emo_prompt_features
    # tuple from the explicit cell holds five ESD-ordered entries, so indexing it
    # with an index into the seven MELD tags would pick an unrelated emotion or
    # go out of range.
    meld_prompt_features = label_features / label_features.norm(dim=-1, keepdim=True)

predicted_idx = int(np.argmax(probs))
predicted_emotion = emotiontags[predicted_idx]
print(f"Predicted emotion: {predicted_emotion}")

implicit_text_dir = f'{EMOTION_FEATURES_SAVE_PATH}/implict_text'
os.makedirs(implicit_text_dir, exist_ok=True)
# The output keeps the input's file name. The path is assembled outside the
# f-string replacement field because nesting the same quote type inside an
# f-string is a syntax error before Python 3.12.
feature_path = f"{implicit_text_dir}/{os.path.basename(test_text_f_path)}"
torch.save(meld_prompt_features[predicted_idx], feature_path)
print(f"Save {predicted_emotion} feature in {feature_path}")

### Implicitly infer the speech emotion (from text, video and audio)

In [None]:
# Checkpoint of the fusion model fine-tuned on DATASET,
# e.g. EPAlign/ckpt/MELD/MELD_fuse_ft.pt
DATASET = "MELD"
FINETUNE_MODEL = f"{PROJECT_PATH}/EPAlign/ckpt/{DATASET}/MELD_fuse_ft.pt"

# Saved per-modality features to classify, e.g. test/text_f/test_text_f.pt
test_text_f_path = f"{PROJECT_PATH}/EPAlign/test/text_f/test_text_f.pt"
test_visual_f_path = f"{PROJECT_PATH}/EPAlign/test/visual_f/test_visual_f.pt"
test_audio_f_path = f"{PROJECT_PATH}/EPAlign/test/audio_f/test_audio_f.pt"

# Candidate emotions (the ones that appear in the MELD dataset) and one text
# prompt per candidate, in the same order.
emotiontags = ["neutral", "joy", "sad", "angry", "surprise", "fearful", "disgust"]
test_prompts = [f'A person speaking with a feeling of {tag}' for tag in emotiontags]

class Concat_SelfAttention_Model(nn.Module):
    """Fuses text, visual and audio features and scores them against text prompts.

    The three modality features are optionally passed through a linear layer,
    concatenated, run through one multi-head self-attention layer and projected
    to ``fused_dim`` with ``fuse_proj``. The attention layer is built for an
    embedding size of ``3 * fused_dim``, so each modality must be ``fused_dim``
    wide by the time it is concatenated.

    Args:
        input_text_feature_dim: size of the incoming text features.
        input_visual_feature_dim: size of the incoming visual features.
        input_audio_feature_dim: size of the incoming audio features.
        fused_dim: size of the fused feature and of each concatenated part.
        is_prompt_linear: add a linear layer (512 -> fused_dim) on the prompt features.
        is_text_linear: add a linear layer on the text features.
        is_visual_linear: add a linear layer on the visual features.
        is_audio_linear: add a linear layer on the audio features.
        num_heads: number of attention heads.
        prompt_pretrain_model: CLIP model name passed to ``clip.load``.
        prompt_pretrain_model_path: ``download_root`` passed to ``clip.load``.
    """

    def __init__(self,
                 input_text_feature_dim=4096,
                 input_visual_feature_dim=512,
                 input_audio_feature_dim=512,
                 fused_dim=512,
                 is_prompt_linear=False,
                 is_text_linear=True,
                 is_visual_linear=False,
                 is_audio_linear=False,
                 num_heads=8,
                 prompt_pretrain_model="",
                 prompt_pretrain_model_path=""
                 ):
        super().__init__()
        self.input_text_feature_dim = input_text_feature_dim
        self.input_visual_feature_dim = input_visual_feature_dim
        self.input_audio_feature_dim = input_audio_feature_dim
        self.fused_dim = fused_dim
        self.is_prompt_linear = is_prompt_linear
        self.is_text_linear = is_text_linear
        self.is_visual_linear = is_visual_linear
        self.is_audio_linear = is_audio_linear
        self.num_heads = num_heads

        if self.is_text_linear:
            self.text_linear = nn.Linear(self.input_text_feature_dim, self.fused_dim)
        if self.is_visual_linear:
            self.visual_linear = nn.Linear(self.input_visual_feature_dim, self.fused_dim)
        if self.is_audio_linear:
            self.audio_linear = nn.Linear(self.input_audio_feature_dim, self.fused_dim)

        self.atten = nn.MultiheadAttention(3 * self.fused_dim, self.num_heads)

        scale = (3 * self.fused_dim) ** -0.5
        self.fuse_proj = nn.Parameter(scale * torch.randn(3 * self.fused_dim, self.fused_dim))
        # Learnable log-temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.prompt_model, self.prompt_processor = clip.load(prompt_pretrain_model, jit=False, download_root=prompt_pretrain_model_path)
        self.prompt_model.to(device)
        if self.is_prompt_linear:
            self.prompt_linear = nn.Linear(512, self.fused_dim)

    def forward(self, text_features, visual_features, audio_features, prompts):
        """Return ``(logits_per_fused, logits_per_label)``.

        Args:
            text_features, visual_features, audio_features: 2-D per-modality
                features (see ``extract_fused_feature``).
            prompts: prompt strings to tokenize with ``clip.tokenize``, or a
                tensor of already tokenized prompts.

        Returns:
            ``logits_per_fused`` (fused rows, prompt columns) and its transpose
            ``logits_per_label``.
        """
        # Same fusion pipeline as extract_fused_feature, so it is reused
        # instead of being repeated here.
        fused_features = self.extract_fused_feature(text_features, visual_features, audio_features)

        # Accept pre-tokenized prompts as well; clip.tokenize only handles strings.
        if isinstance(prompts, torch.Tensor):
            prompt_tokens = prompts.to(device)
        else:
            prompt_tokens = clip.tokenize(prompts).to(device)
        prompt_features = self.prompt_model.encode_text(prompt_tokens)
        if self.is_prompt_linear:
            prompt_features = self.prompt_linear(prompt_features)

        fused_features = fused_features / fused_features.norm(dim=1, keepdim=True)
        prompt_features = prompt_features / prompt_features.norm(dim=1, keepdim=True)
        fused_features = fused_features.float()
        prompt_features = prompt_features.float()

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp().float()
        logits_per_fused = logit_scale * fused_features @ prompt_features.t()
        logits_per_label = logits_per_fused.t()

        return logits_per_fused, logits_per_label

    def extract_fused_feature(self, text_features, visual_features, audio_features):
        """Return the fused feature (after ``fuse_proj``, not normalised).

        The inputs must be 2-D: the concatenation gets one extra dimension and
        is then permuted as a 3-D tensor.
        """
        if self.is_text_linear:
            text_features = self.text_linear(text_features)
        if self.is_visual_linear:
            visual_features = self.visual_linear(visual_features)
        if self.is_audio_linear:
            audio_features = self.audio_linear(audio_features)

        x = torch.cat([text_features, visual_features, audio_features], dim=-1)
        x = x.unsqueeze(1)

        # nn.MultiheadAttention is sequence-first by default, so the size-1
        # dimension is moved to the front: every sample is a length-1 sequence.
        x = x.permute(1, 0, 2)
        x, _ = self.atten(x, x, x)
        x = x.permute(1, 0, 2)

        x = x.squeeze(1)
        fused_features = x @ self.fuse_proj

        return fused_features

model = Concat_SelfAttention_Model(prompt_pretrain_model=PRETRAIN_MODEL, prompt_pretrain_model_path=PRETRAIN_MODEL_PATH).to(device)
# map_location lets a checkpoint that was saved on a GPU be restored on a
# CPU-only machine as well.
model.load_state_dict(torch.load(FINETUNE_MODEL, map_location=device))

with torch.no_grad():
    text_feature = torch.load(test_text_f_path, map_location=device)
    visual_feature = torch.load(test_visual_f_path, map_location=device)
    audio_feature = torch.load(test_audio_f_path, map_location=device)
    # Add a leading batch dimension of size 1 to every modality.
    text_feature = text_feature.unsqueeze(0)
    visual_feature = visual_feature.unsqueeze(0)
    audio_feature = audio_feature.unsqueeze(0)
    # forward() tokenizes the prompts itself, so it is given the raw strings;
    # handing it clip.tokenize output would tokenize twice.
    logits_per_fused, _ = model(text_feature.to(device), visual_feature.to(device), audio_feature.to(device), test_prompts)
    probs = logits_per_fused.softmax(dim=-1).cpu().numpy()
    # Unit-norm prompt embeddings for THIS cell's tags, from the model's own
    # prompt encoder. The emo_prompt_features tuple from the explicit cell holds
    # five ESD-ordered entries, so indexing it with an index into the seven MELD
    # tags would pick an unrelated emotion or go out of range.
    # NOTE(review): confirm this is the embedding space the consumer expects.
    fused_prompt_features = model.prompt_model.encode_text(clip.tokenize(test_prompts).to(device))
    fused_prompt_features = fused_prompt_features / fused_prompt_features.norm(dim=-1, keepdim=True)

predicted_idx = int(np.argmax(probs))
predicted_emotion = emotiontags[predicted_idx]
print(f"Predicted emotion: {predicted_emotion}")

implicit_fused_dir = f'{EMOTION_FEATURES_SAVE_PATH}/implict_fused'
os.makedirs(implicit_fused_dir, exist_ok=True)
# The output keeps the text feature's file name. The path is assembled outside
# the f-string replacement field because nesting the same quote type inside an
# f-string is a syntax error before Python 3.12.
feature_path = f"{implicit_fused_dir}/{os.path.basename(test_text_f_path)}"
torch.save(fused_prompt_features[predicted_idx], feature_path)
print(f"Save {predicted_emotion} feature in {feature_path}")