## EPAlign: Prompt and Fused-Feature (Text, Audio, Vision) Fine-tuning

### config

In [None]:
import os
import numpy as np
import clip
import torch
from torch import nn, optim
import logging

# DATASET is the name of the dataset the model is trained on, e.g. MELD
DATASET = "MELD"

# BATCH_SIZE is the number of distinct emotion classes drawn per batch:
# BalancedBatchSampler takes one sample from each of BATCH_SIZE different
# labels, so it must not exceed the number of distinct labels in a split.
BATCH_SIZE = 6
EPOCH = 100
device = "cuda" if torch.cuda.is_available() else "cpu"

# PROJECT_PATH is the directory two levels above the current working directory
# (assumes POSIX paths and that the notebook runs two levels below the project
# root — TODO confirm)
PROJECT_PATH = os.path.join('/', *os.getcwd().split(os.sep)[:-2])
# MELD processed path
PROCESSED_MELD_PATH = f"{PROJECT_PATH}/data/meld/processed"
# PRETRAIN_CLIP_MODEL is the CLIP model name passed to clip.load, e.g. ViT-B/32
PRETRAIN_CLIP_MODEL = "ViT-B/32"
# PRETRAIN_CLIP_MODEL_PATH is the download/cache directory for the pretrained CLIP weights, e.g. EPAlign/ckpt/base
PRETRAIN_CLIP_MODEL_PATH = f"{PROJECT_PATH}/EPAlign/ckpt/base"
# LOG_PATH is the log directory, e.g. EPAlign/log
LOG_PATH = f"{PROJECT_PATH}/EPAlign/log"
# CKPT_PATH is the directory to save checkpoints to, e.g. EPAlign/ckpt/MELD_fused
CKPT_PATH = f"{PROJECT_PATH}/EPAlign/ckpt/{DATASET}_fused"

### Fusion Model

In [None]:
class Concat_SelfAttention_Model(nn.Module):
    """Fuse text/visual/audio features and align them with CLIP-encoded prompts.

    The three modality features are (optionally) projected to ``fused_dim``,
    concatenated, passed through a multi-head self-attention layer that treats
    each sample as a length-1 sequence, and projected back to ``fused_dim`` by
    the learned matrix ``fuse_proj``. ``forward`` scores the L2-normalised
    fused features against L2-normalised CLIP text embeddings of the prompts,
    scaled by the learned temperature ``exp(logit_scale)``.

    Modalities whose ``is_*_linear`` flag is False are used as-is and must
    already have ``fused_dim`` features, because the attention width is
    ``3 * fused_dim``.
    """

    def __init__(self,
                 input_text_feature_dim=4096,
                 input_visual_feature_dim=512,
                 input_audio_feature_dim=512,
                 fused_dim=512,
                 is_prompt_linear=False,
                 is_text_linear=True,
                 is_visual_linear=False,
                 is_audio_linear=False,
                 num_heads=8,
                 prompt_pretrain_model="",
                 prompt_pretrain_model_path=""
                 ):
        """Build the projections, the attention layer and the CLIP prompt encoder.

        Args:
            input_text_feature_dim: raw text feature size (used by ``text_linear``).
            input_visual_feature_dim: raw visual feature size (used by ``visual_linear``).
            input_audio_feature_dim: raw audio feature size (used by ``audio_linear``).
            fused_dim: common per-modality size and size of the fused output.
            is_prompt_linear: add a 512 -> ``fused_dim`` projection on the CLIP
                prompt embedding.
            is_text_linear / is_visual_linear / is_audio_linear: add a
                projection from the raw modality size to ``fused_dim``.
            num_heads: attention heads; must divide ``3 * fused_dim``.
            prompt_pretrain_model: CLIP model name passed to ``clip.load``.
            prompt_pretrain_model_path: ``download_root`` for ``clip.load``.
        """
        super().__init__()
        self.input_text_feature_dim = input_text_feature_dim
        self.input_visual_feature_dim = input_visual_feature_dim
        self.input_audio_feature_dim = input_audio_feature_dim
        self.fused_dim = fused_dim
        self.is_prompt_linear = is_prompt_linear
        self.is_text_linear = is_text_linear
        self.is_visual_linear = is_visual_linear
        self.is_audio_linear = is_audio_linear
        self.num_heads = num_heads

        if self.is_text_linear:
            self.text_linear = nn.Linear(self.input_text_feature_dim, self.fused_dim)
        if self.is_visual_linear:
            self.visual_linear = nn.Linear(self.input_visual_feature_dim, self.fused_dim)
        if self.is_audio_linear:
            self.audio_linear = nn.Linear(self.input_audio_feature_dim, self.fused_dim)

        self.atten = nn.MultiheadAttention(3 * self.fused_dim, self.num_heads)

        scale = (3 * self.fused_dim) ** -0.5
        self.fuse_proj = nn.Parameter(scale * torch.randn(3 * self.fused_dim, self.fused_dim))
        # Temperature initialised to 1/0.07, stored in log space.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.prompt_model, self.prompt_processor = clip.load(prompt_pretrain_model, jit=False, download_root=prompt_pretrain_model_path)
        self.prompt_model.to(device)
        if self.is_prompt_linear:
            self.prompt_linear = nn.Linear(512, self.fused_dim)

    def forward(self, text_features, visual_features, audio_features, prompts):
        """Return ``(logits_per_fused, logits_per_label)`` for one batch.

        Args:
            text_features, visual_features, audio_features: per-sample feature
                tensors, expected 2-D ``(batch, dim)``. The fusion adds its own
                sequence axis, so a 3-D input would not fit the attention layer.
            prompts: list of prompt strings, tokenised with ``clip.tokenize``.

        Returns:
            ``logits_per_fused`` of shape ``(batch, len(prompts))`` holding the
            scaled cosine similarities (float32), and its transpose
            ``logits_per_label``.
        """
        fused_features = self.extract_fused_feature(text_features, visual_features, audio_features)

        # Use the module's own device rather than a global setting.
        tokens = clip.tokenize(prompts).to(self.logit_scale.device)
        prompt_features = self.prompt_model.encode_text(tokens)
        if self.is_prompt_linear:
            # The CLIP encoder may emit fp16 embeddings (its GPU weights are
            # kept in half precision), while prompt_linear is float32.
            prompt_features = self.prompt_linear(prompt_features.float())

        fused_features = fused_features / fused_features.norm(dim=1, keepdim=True)
        prompt_features = prompt_features / prompt_features.norm(dim=1, keepdim=True)
        fused_features = fused_features.float()
        prompt_features = prompt_features.float()

        logit_scale = self.logit_scale.exp().float()
        logits_per_fused = logit_scale * fused_features @ prompt_features.t()
        logits_per_label = logits_per_fused.t()

        return logits_per_fused, logits_per_label

    def extract_fused_feature(self, text_features, visual_features, audio_features):
        """Return the un-normalised fused features, shape ``(batch, fused_dim)``.

        Inputs are expected 2-D ``(batch, dim)``. Each enabled ``*_linear``
        projection is applied first. The three features are then concatenated
        to ``3 * fused_dim``, passed through self-attention and multiplied by
        ``fuse_proj``.
        """
        if self.is_text_linear:
            text_features = self.text_linear(text_features)
        if self.is_visual_linear:
            visual_features = self.visual_linear(visual_features)
        if self.is_audio_linear:
            audio_features = self.audio_linear(audio_features)

        x = torch.cat([text_features, visual_features, audio_features], dim=-1)
        # nn.MultiheadAttention defaults to (seq, batch, embed) inputs; every
        # sample becomes a length-1 sequence, giving shape (1, batch, 3*fused_dim).
        x = x.unsqueeze(0)
        x, _ = self.atten(x, x, x)
        x = x.squeeze(0)

        return x @ self.fuse_proj
# Instantiate the fusion model with the configured CLIP prompt encoder and move it to `device`.
model = Concat_SelfAttention_Model(
    prompt_pretrain_model=PRETRAIN_CLIP_MODEL,
    prompt_pretrain_model_path=PRETRAIN_CLIP_MODEL_PATH,
)
model = model.to(device)

### MELD Multimodal Dataset

In [None]:
from torch.utils.data import Dataset
import torch

class MELD_Multimodal_Dataset(Dataset):
    """MELD utterances with pre-extracted text, visual and audio features.

    The sample list is read from
    ``{filelist_path}/meld_name_emotion_{mode}_filelist.txt``, one
    ``name|label`` row per line. Features are loaded lazily in ``__getitem__``
    from ``{feature_path}/{text_feature,visual_feature,wav_feature}/
    {mode}_Dialogue_ID_Utterance_ID/{name}.pt``.

    Items are ``(text_feature, visual_feature, audio_feature, prompt, label)``,
    plus the sample name when ``is_return_sample_name`` is set.
    """

    # One known-abnormal sample per split, dropped unless
    # ``is_with_unnomal_sampels`` is set (the reason for excluding them is not
    # recorded here).
    _ABNORMAL_SAMPLES = {
        'train': 'dia125_utt3',
        'val': 'dia110_utt7',
        'test': 'dia38_utt4',
    }

    def __init__(self,
                 mode="train",
                 filelist_path="EMITTS/filelist",
                 feature_path = "data/meld/processed",
                 is_with_unnomal_sampels=False,
                 is_return_sample_name=False):
        """Args:
            mode: split name used in all file paths ('train', 'val' or 'test').
            filelist_path: directory containing the split file lists.
            feature_path: root directory of the processed feature files.
            is_with_unnomal_sampels: keep the split's known-abnormal sample.
            is_return_sample_name: also return the sample name from ``__getitem__``.
        """
        self.mode = mode
        self.datalist_path = f'{filelist_path}/meld_name_emotion_{self.mode}_filelist.txt'
        self.text_feature_path = f'{feature_path}/text_feature/{self.mode}_Dialogue_ID_Utterance_ID'
        self.visual_feature_path = f'{feature_path}/visual_feature/{self.mode}_Dialogue_ID_Utterance_ID'
        self.audio_feature_path = f'{feature_path}/wav_feature/{self.mode}_Dialogue_ID_Utterance_ID'
        self.is_with_unnomal_sampels = is_with_unnomal_sampels
        self.is_return_sample_name = is_return_sample_name
        self.data = self.load_data()
        # File-list label id -> emotion word used in the text prompt.
        self.label2text = {
            '1': "neutral",
            '2': "joy",
            '3': "sad",
            '4': "angry",
            '5': "surprise",
            '6': "fearful",
            '7': "disgust",
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Load the three feature tensors (on CPU) and build the emotion prompt."""
        data = self.data[index]
        label = data[1]
        emotiontag = self.label2text[label]
        prompt = f"A person speaking with a feeling of {emotiontag}"
        text_feature = torch.load(f'{self.text_feature_path}/{data[0]}.pt', map_location='cpu')
        visual_feature = torch.load(f'{self.visual_feature_path}/{data[0]}.pt', map_location='cpu')
        audio_feature = torch.load(f'{self.audio_feature_path}/{data[0]}.pt', map_location='cpu')
        if self.is_return_sample_name:
            return text_feature, visual_feature, audio_feature, prompt, int(label), data[0]
        return text_feature, visual_feature, audio_feature, prompt, int(label)

    def load_data(self):
        """Return the split's ``[name, label]`` rows from its file list.

        Blank lines are skipped; without this they would become ``['']`` rows
        that fail in ``__getitem__``. Unless ``is_with_unnomal_sampels`` is
        set, the split's entry in ``_ABNORMAL_SAMPLES`` is removed.
        """
        with open(self.datalist_path, encoding='utf-8') as f:
            data = [line.strip().split("|") for line in f if line.strip()]
        if not self.is_with_unnomal_sampels:
            excluded = self._ABNORMAL_SAMPLES.get(self.mode)
            data = [d for d in data if d[0] != excluded]
        return data
    
# One dataset per split, all reading features from the processed MELD directory.
train_dataset = MELD_Multimodal_Dataset(mode="train", feature_path=PROCESSED_MELD_PATH)
val_dataset = MELD_Multimodal_Dataset(mode="val", feature_path=PROCESSED_MELD_PATH)
test_dataset = MELD_Multimodal_Dataset(mode="test", feature_path=PROCESSED_MELD_PATH)

# Sanity check: expected (train, val, test) sizes after each split's abnormal sample is dropped.
assert (len(train_dataset), len(val_dataset), len(test_dataset)) == (9988, 1109, 2609)

### Define Batch Sampler (ensures no two samples in a batch share a class)

In [None]:
from torch.utils.data import BatchSampler, DataLoader
class BalancedBatchSampler(BatchSampler):
    """Yield batches of ``n_classes`` distinct classes with ``n_samples`` each.

    From a MELD-Multimodal-like dataset, every batch picks ``n_classes``
    different labels at random. For each, it takes the next ``n_samples``
    indices from that class's shuffled pool, which is reshuffled and restarted
    once it cannot supply another ``n_samples``. Batches have size
    ``n_classes * n_samples``; with ``n_samples == 1`` no two samples in a
    batch share a class. One pass yields exactly ``len(self)`` batches.

    Args:
        labels: 1-D CPU tensor of integer class labels, one per dataset index.
        n_classes: number of distinct classes per batch.
        n_samples: number of samples drawn per class.

    Raises:
        ValueError: if ``n_classes`` exceeds the number of distinct labels.
    """

    def __init__(self, labels, n_classes, n_samples):
        self.labels = labels
        label_array = self.labels.numpy()
        self.labels_set = list(set(label_array))
        if n_classes > len(self.labels_set):
            # np.random.choice(..., replace=False) would fail later anyway;
            # fail here with a clearer message.
            raise ValueError(
                f"n_classes={n_classes} exceeds the {len(self.labels_set)} distinct labels available"
            )
        # Per-class pool of dataset indices, shuffled once up front.
        self.label_to_indices = {label: np.where(label_array == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        # Position of the next unused index in each class pool.
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.n_dataset = len(self.labels)
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        # `<=` so that exactly n_dataset // batch_size batches are produced,
        # matching __len__ (which e.g. sizes the LR scheduler); `<` dropped
        # one batch whenever n_dataset was a multiple of batch_size.
        while self.count + self.batch_size <= self.n_dataset:
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                start = self.used_label_indices_count[class_]
                indices.extend(self.label_to_indices[class_][start:start + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                # Pool cannot supply another full draw: reshuffle and restart it.
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.batch_size

    def __len__(self):
        return self.n_dataset // self.batch_size

def collate_fn(batch):
    """Collate dataset items into device tensors plus a list of prompt strings.

    Items are indexed positionally. Fields 0-2 (text/visual/audio features) are
    stacked along a new leading batch dimension. Field 3 (the prompt) is kept
    as a plain list. Field 4 (the integer label) becomes a tensor. Any further
    fields, such as a sample name, are ignored. Tensors are moved to the
    global ``device``.
    """
    def stack_field(position):
        return torch.stack([item[position] for item in batch]).to(device)

    text_batch = stack_field(0)
    visual_batch = stack_field(1)
    audio_batch = stack_field(2)
    prompt_list = [item[3] for item in batch]
    label_batch = torch.tensor([item[4] for item in batch]).to(device)
    return text_batch, visual_batch, audio_batch, prompt_list, label_batch

# Labels come from the parsed file list (each ``dataset.data`` row is
# ``[name, label]``) instead of ``dataset[i][4]``. Indexing the dataset would
# load all three feature files of every sample from disk just to read its label.
train_labels = torch.tensor([int(row[1]) for row in train_dataset.data])
train_sampler = BalancedBatchSampler(train_labels, BATCH_SIZE, 1)
train_dataloader = DataLoader(train_dataset, batch_sampler=train_sampler, collate_fn=collate_fn)

dev_labels = torch.tensor([int(row[1]) for row in val_dataset.data])
dev_sampler = BalancedBatchSampler(dev_labels, BATCH_SIZE, 1)
dev_dataloader = DataLoader(val_dataset, batch_sampler=dev_sampler, collate_fn=collate_fn)

### Train Config

In [None]:
def convert_models_to_fp32(model):
    """Cast every parameter of *model*, and its gradient when present, to float32.

    Workaround from https://github.com/openai/CLIP/issues/57 for stepping an
    optimizer over a CLIP model whose weights are stored in fp16. Parameters
    with no gradient (``p.grad is None``, e.g. frozen or not yet
    back-propagated) are cast without touching ``grad``. Accessing
    ``p.grad.data`` on them would raise AttributeError.
    """
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

loss_fused = nn.CrossEntropyLoss()
# NOTE(review): loss_txt is not used by the training loop in this notebook.
loss_txt = nn.CrossEntropyLoss()

import itertools
# Parameters updated during fine-tuning: every modality/prompt projection the
# model actually has (only text_linear with the default flags), plus the
# fusion projection and the temperature. Looking the layers up by name avoids
# an AttributeError when is_text_linear=False. It also ensures enabled
# visual/audio/prompt projections are trained instead of silently frozen.
# NOTE(review): model.atten (self-attention) and the CLIP prompt encoder are not
# given to the optimizer, so they keep their initial weights — confirm intended.
_projection_layers = [
    getattr(model, name)
    for name in ("text_linear", "visual_linear", "audio_linear", "prompt_linear")
    if hasattr(model, name)
]
parameters = itertools.chain(
    *(layer.parameters() for layer in _projection_layers),
    [model.fuse_proj, model.logit_scale],
)
optimizer = optim.Adam(parameters, lr=1e-4)
# One scheduler step per training batch, annealed over the whole run.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_dataloader)*EPOCH)

### Train Log

In [None]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# FileHandler cannot create missing directories, so make sure LOG_PATH exists.
os.makedirs(LOG_PATH, exist_ok=True)
log_file = f"{LOG_PATH}/log_prompt_audio_{DATASET}.txt"

log = logging.getLogger('')
# Re-running this notebook cell would otherwise attach another handler for the
# same file and write every record twice; drop (and close) any earlier one.
for old_handler in [h for h in log.handlers
                    if isinstance(h, logging.FileHandler)
                    and h.baseFilename == os.path.abspath(log_file)]:
    log.removeHandler(old_handler)
    old_handler.close()

file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))

log.addHandler(file_handler)
log.info('finetune start...')

### Train

In [None]:
from tqdm import tqdm

def _batch_loss(batch):
    """Return the symmetric contrastive loss for one batch from ``collate_fn``.

    The samplers above draw one sample per distinct class, so the i-th fused
    feature belongs with the i-th prompt and the targets are ``0..batch-1``.
    The collated labels are therefore not needed here.
    """
    text_features, visual_features, audio_features, prompts, _labels = batch
    # Features may be stored in reduced precision; the fusion layers are float32.
    logits_per_fused, logits_per_label = model(
        text_features.float(), visual_features.float(), audio_features.float(), prompts
    )
    ground_truth = torch.arange(len(prompts), device=logits_per_fused.device)
    return (loss_fused(logits_per_fused, ground_truth) + loss_fused(logits_per_label, ground_truth)) / 2


best_dev_loss = 1e5
best_epoch = -1
os.makedirs(CKPT_PATH, exist_ok=True)
for epoch in range(EPOCH):
    logging.info(f"running epoch {epoch}, best dev loss {best_dev_loss} after epoch {best_epoch}")

    # ---- training ----
    model.train()
    train_loss, train_steps = 0.0, 0
    train_pbar = tqdm(train_dataloader, leave=False)
    for batch in train_pbar:
        optimizer.zero_grad()
        total_loss = _batch_loss(batch)
        total_loss.backward()
        optimizer.step()
        scheduler.step()
        train_loss += total_loss.item()
        train_steps += 1
        train_pbar.set_description(f"train batchCE: {total_loss.item()}", refresh=True)
    train_loss /= max(train_steps, 1)

    # ---- validation ----
    model.eval()
    dev_loss, dev_steps = 0.0, 0
    with torch.no_grad():
        dev_pbar = tqdm(dev_dataloader, leave=False)
        for batch in dev_pbar:
            total_loss = _batch_loss(batch)
            dev_loss += total_loss.item()
            dev_steps += 1
            dev_pbar.set_description(f"dev batchCE: {total_loss.item()}", refresh=True)
    # Average before comparing so best_dev_loss is on the same scale as the logged dev loss.
    dev_loss /= max(dev_steps, 1)

    # Checkpoint only when the dev loss improves, so best_model.pt really is the best epoch.
    if dev_loss < best_dev_loss:
        best_dev_loss = dev_loss
        best_epoch = epoch
        torch.save(model, f"{CKPT_PATH}/best_model.pt")
    logging.info(f"epoch {epoch}, train loss {train_loss}, dev loss {dev_loss}")
        