### Prepare Repo

In [1]:
# !pip install git+https://github.com/ChuakBlurk/vidaug

In [2]:
# !pip install omegaconf==2.1.1
# !pip install hydra-core==1.1.1
# !pip install -U numpy==1.23.5
# !apt-get update && apt-get install -y python3-opencv
# !pip install opencv-python
# !pip install scikit-image 
# !pip install transformers
# !pip install datasets
# !pip install transformers[torch]
# !pip install accelerate -U
# !pip install wandb
# !pip install scikit-learn

In [3]:
# !git clone https://github.com/facebookresearch/av_hubert.git

# %cd av_hubert
# !git submodule init
# !git submodule update
# !pip install scipy
# !pip install sentencepiece
# !pip install python_speech_features
# !pip install scikit-video

# %cd fairseq
# !pip install ./

In [4]:
import fairseq
from fairseq import checkpoint_utils, options, tasks, utils
import cv2
import tempfile
import torch
from transformers import Trainer, TrainingArguments
import sys
sys.path.append("/home/multi_modal_ser/finetune_encoder/audio_video/av_hubert/avhubert")
%cd /home/multi_modal_ser/finetune_encoder/audio_video/av_hubert/
import utils as avhubert_utils
from argparse import Namespace
from IPython.display import HTML
import numpy as np
import sys
print(sys.version)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import torch.nn as nn
import wandb
from torch.utils.data import Dataset, Subset
import os
import datetime

/home/multi_modal_ser/finetune_encoder/audio_video/av_hubert
3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]


In [5]:
print(device)
!nvidia-smi

cuda
Sat Nov  4 15:28:51 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        On  | 00000000:81:00.0 Off |                  N/A |
| 30%   35C    P8              33W / 350W |      6MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                               

### Download Model

In [6]:
# os.makedirs("/home/check_pts/")
# # !wget https://dl.fbaipublicfiles.com/avhubert/model/lrs3_vox/vsr/base_vox_433h.pt -O /home/check_pts/avhubert.pt
# !wget https://dl.fbaipublicfiles.com/avhubert/model/lrs3_vox/clean-pretrain/base_vox_iter4.pt -O /home/check_pts/avhubert.pt

### Build Model Pipeline

In [7]:
user_dir = "/home/multi_modal_ser/finetune_encoder/audio_video/av_hubert/avhubert"
utils.import_user_module(Namespace(user_dir=user_dir))
ckpt_path = "/home/check_pts/avhubert.pt"
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])  
model = models[0]
if hasattr(models[0], 'decoder'):
    print(f"Checkpoint: fine-tuned")
    model = models[0].encoder.w2v_model
else:
    print(f"Checkpoint: pre-trained w/o fine-tuning")



Checkpoint: pre-trained w/o fine-tuning


### Load Dataset

In [8]:
import torch
print(torch.__version__)

2.1.0


In [9]:
from avhubert_ds import AVHUBERTDataset
mmser_ds = torch.load("/home/avhubert_ds2.pt")
mmser_ds.video_path = "/home/face_raw/"

# outputs = model.extract_finetune(mmser_ds[:2])

In [10]:
mmser_ds.cached = False
mmser_ds.__cache__()

100%|██████████| 5531/5531 [12:51<00:00,  7.17it/s]


### Define the model

In [11]:
from avhubert_classifier import AVHUBERTClassifier

In [12]:
# classifier = AVHUBERTClassifier(model, 768, 256, mmser_ds.df_["emotion_id"].nunique())
# classifier(**mmser_ds[:4])

### Build Train Test DS

In [13]:
meta_df_ = mmser_ds.df_
mmser_ds.df_["bigsess"] = mmser_ds.df_["session"].apply(lambda x: x[:-1])
sess_dict = mmser_ds.df_.groupby("bigsess").groups
all_indices = set(mmser_ds.df_.index.tolist())

# sess_ds = {}
# for i in range(1,6):
#     sess = "Ses0{}".format(i)
#     sess_val = "Ses0{}".format(i%5+1)
#     sess_ds[sess+"_test"] = Subset(mmser_ds, 
#                                     indices=sess_dict[sess])
#     # sess_ds[sess+"_val"] = Subset(mmser_ds, 
#     #                                 indices=sess_dict[sess_val])
#     sess_ds[sess+"_train"] = Subset(mmser_ds, 
#                                     indices=list(all_indices-set(sess_dict[sess])))
    

In [14]:
# def build_ds(sess_id):
#     train_size = int(len(sess_ds[sess_id+"_train"])*0.8)
#     val_size = len(sess_ds[sess_id+"_train"])-train_size
#     train_set, val_set = torch.utils.data.random_split(sess_ds[sess_id+"_train"], [train_size, val_size])
#     test_set = sess_ds[sess_id+"_test"]
#     # train_set = sess_ds[sess_id+"_train"]
#     # val_set = sess_ds[sess_id+"_val"]

#     print("Train Samples:", len(train_set))
#     print("Val Samples:", len(val_set))
#     print("Test Samples:", len(test_set))
    
#     return train_set, val_set, test_set

### Data Augmentation

In [30]:
from torch.utils.data import Dataset, Subset
import os
import pandas as pd
import numpy as np
import torch 
import torch.nn.functional as F
import sys
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import vidaug.augmentors as va
import random

class VidaugDataset(Dataset):
    
    def collate(self, audio, video, max_size=500):
        padded_audio = pad_sequence([torch.tensor(a.squeeze()) for a in audio]+[torch.empty(500, 104)], batch_first=True)[:-1]
        padded_video = pad_sequence([torch.tensor(v.squeeze()) for v in video]+[torch.empty(500,88,88)], batch_first=True)[:-1, np.newaxis, : ,:,:]
        mask = torch.zeros_like(padded_audio)
        mask[padded_audio != 0] = 1
        return padded_audio, padded_video, mask
    
    
    def __init__(self, audio_feats_list, 
                 video_feats_list, 
                 text_list, 
                 labels_list):
        self.audio_feats_list = audio_feats_list
        self.video_feats_list = video_feats_list
        self.text_list = text_list
        self.labels_list = labels_list
    
        print(len(self.audio_feats_list))
        print(len(self.video_feats_list))
        print(len(self.text_list))
        print(len(self.labels_list))
        
        self.origin_len = len(self.labels_list)
        self.aug_len = 0

    
    
    def __aug__(self, niters=2, aug_prob=0.3, k=None):
        sometimes = lambda aug: va.Sometimes(aug_prob, aug) # Used to apply augmentor with 50% probability
        
        transform_list = [
            sometimes(va.InvertColor()),
            sometimes(va.Salt()),
            sometimes(va.Pepper()),
            sometimes(va.RandomShear(0.2, 0.2)),
            sometimes(va.HorizontalFlip()),
            sometimes(va.VerticalFlip()),
            sometimes(va.RandomRotate(30)),
            sometimes(va.GaussianBlur(0.8)),
            sometimes(va.ElasticTransformation(0.2,0.2)),
            sometimes(va.PiecewiseAffineTransform(20,10,0.5)),
        ]
        
        if k is None:
            seq = va.Sequential(random.choices(transform_list, 
                                           k=random.choice(range(len(transform_list)))))
        
        else:
            seq = va.Sequential(random.choices(transform_list, 
                                           k=k))
        
        self.aug_audio_feats_list = []
        self.aug_video_feats_list = []
        self.aug_text_list = []
        self.aug_labels_list = []
        
        for smp_id in tqdm(range(len(self.video_feats_list))):
            for i in range(niters):
                vid = self.video_feats_list[smp_id].squeeze()
                # change to color 
                vid = [cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB).astype(np.uint8) for frame in vid]
                vid = np.stack(vid)
                video_aug = seq(vid)
                video_aug = [cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY).astype(np.uint8) for frame in video_aug]
                video_aug = np.stack(video_aug)
                
                # transform = avhubert_utils.Compose([
                #   avhubert_utils.Normalize(0.0, 255.0),
                #   avhubert_utils.CenterCrop((88, 88)),
                #   avhubert_utils.Normalize(0.421, 0.165)])
                # video_aug = transform(video_aug)[np.newaxis, np.newaxis]

                video_aug = (video_aug/255)[np.newaxis, np.newaxis]

                
                video_aug = torch.tensor(video_aug)
                self.aug_audio_feats_list.append(self.audio_feats_list[smp_id])
                self.aug_video_feats_list.append(video_aug)
                self.aug_text_list.append(self.text_list[smp_id])
                self.aug_labels_list.append(self.labels_list[smp_id])
                
        self.aug_len = len(self.aug_labels_list)
                
        
    def __len__(self):
        return self.origin_len + self.aug_len
    
    def __getitem__(self, idx):
        if idx < self.origin_len:
            audio_feats = self.audio_feats_list[idx]
            video_feats = self.video_feats_list[idx]
            padded_audio, padded_video, padding_mask = self.collate([audio_feats], [video_feats])
            return {
                "padding_mask": padding_mask[0][:500, :].float(),
                "audio": padded_audio[0][:500, :].T.float(),
                "video": padded_video[0][:, :500, :, :].float(),
                "text": self.text_list[idx].float(),
                "labels": self.labels_list[idx].float()
            }
        else:
            idx = idx - self.origin_len
            audio_feats = self.aug_audio_feats_list[idx]
            video_feats = self.aug_video_feats_list[idx]
            padded_audio, padded_video, padding_mask = self.collate([audio_feats], [video_feats])
            return {
                "padding_mask": padding_mask[0][:500, :].float(),
                "audio": padded_audio[0][:500, :].T.float(),
                "video": padded_video[0][:, :500, :, :].float(),
                "text": self.aug_text_list[idx].float(),
                "labels": self.aug_labels_list[idx].float()
            }


In [31]:
audio_feats_list = mmser_ds.audio_feats_list
video_feats_list = mmser_ds.video_feats_list
text_list = list(meta_df_["transcript"])
labels_list = list(meta_df_["emotion_id"])

del mmser_ds

NameError: name 'mmser_ds' is not defined

In [32]:
val_indices = sess_dict['Ses01']
test_indices = sess_dict['Ses02']
train_indices = list(all_indices-set(val_indices)-set(test_indices))

In [33]:
train_ds = VidaugDataset(
    [item.detach().numpy() for i, item in enumerate(audio_feats_list) if i in train_indices],
    [item.detach().numpy() for i, item in enumerate(video_feats_list) if i in train_indices],
    [item for i, item in enumerate(text_list) if i in train_indices],
    [item for i, item in enumerate(labels_list) if i in train_indices]
)

val_ds = VidaugDataset(
    [item.detach().numpy() for i, item in enumerate(audio_feats_list) if i in val_indices],
    [item.detach().numpy() for i, item in enumerate(video_feats_list) if i in val_indices],
    [item for i, item in enumerate(text_list) if i in val_indices],
    [item for i, item in enumerate(labels_list) if i in val_indices]
)

test_ds = VidaugDataset(
    [item.detach().numpy() for i, item in enumerate(audio_feats_list) if i in test_indices],
    [item.detach().numpy() for i, item in enumerate(video_feats_list) if i in test_indices],
    [item for i, item in enumerate(text_list) if i in test_indices],
    [item for i, item in enumerate(labels_list) if i in test_indices]
)

3423
3423
3423
3423
1085
1085
1085
1085
1023
1023
1023
1023


In [None]:
train_ds.__aug__(2, k=1)

 79%|███████▉  | 2696/3423 [01:24<00:31, 23.15it/s]

In [None]:
train_ds[0]["video"].shape

### Run Pipeline

API: 2999b8f99f0f62b4f64c48a1c8be9a16945183e9

In [None]:
user_dir = "/home/multi_modal_ser/finetune_encoder/audio_video/av_hubert/avhubert"
utils.import_user_module(Namespace(user_dir=user_dir))
ckpt_path = "/home/check_pts/avhubert.pt"
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])  
model = models[0]
if hasattr(models[0], 'decoder'):
    print(f"Checkpoint: fine-tuned")
    model = models[0].encoder.w2v_model
else:
    print(f"Checkpoint: pre-trained w/o fine-tuning")

In [None]:
import json
sess_id = list(sess_dict.keys())[0]
print("="*10, sess_id, "="*10)

avhubert_classifier = AVHUBERTClassifier(model, 768, 256, 5)
for param in avhubert_classifier.parameters():
    param.requires_grad = True

wandb.init()
print(sess_id)
# train_set, val_set, test_set = build_ds(sess_id)
train_set = train_ds
val_set = val_ds
test_set = test_ds


In [None]:
output_dir=os.path.join("check_pts", "AVHUBERT", sess_id, datetime.datetime.now().date().strftime(format="%Y-%m-%d"))

training_args = TrainingArguments(output_dir,report_to="wandb")
training_args.remove_unused_columns=False
training_args.per_device_train_batch_size=6
training_args.per_device_eval_batch_size=6
training_args.logging_steps = int(1000/training_args.per_device_train_batch_size)
training_args.eval_steps = int(1000/training_args.per_device_train_batch_size)
training_args.evaluation_strategy="steps" 
training_args.logging_strategy="steps"
training_args.load_best_model_at_end=True,
training_args.save_strategy = "no"
training_args.learning_rate=5e-3
training_args.num_train_epochs=15
training_args.metric_for_best_model = 'loss'

In [None]:
from avhubert_trainer import CustomTrainer , compute_metrics
from transformers import EarlyStoppingCallback, TrainerCallback, TrainerState

avhubert_classifier = avhubert_classifier.to(device)
trainer = CustomTrainer(
    model=avhubert_classifier,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=val_set,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
)


##### Gradual Freezing

In [None]:
class FreezingCallback(TrainerCallback):
    
    def __init__(self, freeze_encoder_epochs: int):
        self.freeze_encoder_epochs = freeze_encoder_epochs

    def on_epoch_begin(self, args, state, control, **kwargs):
        print(state.epoch, self.freeze_encoder_epochs)
        model = kwargs["model"]
        if state.epoch >= self.freeze_encoder_epochs:
            print("="*10, "Freezing", "="*10)
            for param in model.encoder.feature_extractor_video.parameters():
                param.requires_grad = False

    def on_save(self, args, state, control, **kwargs):
        model = kwargs["model"]
        for name, param in model.named_parameters():
            param.requires_grad = True

In [None]:
freezing_callback = FreezingCallback(5)
trainer.add_callback(freezing_callback)

In [None]:
trainer.train()

In [None]:
del trainer

In [None]:
val_preds = trainer.predict(val_set)

In [None]:
import pandas as pd
pred_labels = val_preds.predictions.argmax(axis=1)
true_labels = val_preds.label_ids

In [None]:
print(pred_labels[10:15])
print(true_labels[10:15])

In [None]:
from sklearn.metrics import f1_score

In [None]:
f1_score(true_labels, pred_labels, average=None)

In [None]:
train_ids = [fn["fn"] for fn in train_set]    
val_ids = [fn["fn"] for fn in val_set]    

In [None]:
len(set(train_ids).intersection(set(val_ids)))

In [None]:
eval_result = trainer.evaluate()
test_result = trainer.predict(test_set).metrics

In [None]:
test_result

In [None]:
FREEZE_PROJ_PATH = "/home/freeze/{}/projector".format(sess_id)
FREEZE_CLAS_PATH = "/home/freeze/{}/classifier".format(sess_id)
os.makedirs(FREEZE_PROJ_PATH, exist_ok=True)
os.makedirs(FREEZE_CLAS_PATH, exist_ok=True)

FREEZE_PROJ = os.path.join(FREEZE_PROJ_PATH, datetime.datetime.now().date().strftime(format="%Y-%m-%d")+".pt")
FREEZE_CLAS = os.path.join(FREEZE_CLAS_PATH, datetime.datetime.now().date().strftime(format="%Y-%m-%d")+".pt")

torch.save(avhubert_classifier.projector.state_dict(), FREEZE_PROJ)
torch.save(avhubert_classifier.classifier.state_dict(), FREEZE_CLAS)

avhubert_classifier.projector.load_state_dict(torch.load(FREEZE_PROJ))
avhubert_classifier.classifier.load_state_dict(torch.load(FREEZE_CLAS))

print(eval_result)
print(test_result)

json_test = json.dumps(test_result, indent=4)
json_eval = json.dumps(eval_result, indent=4)

# Writing to sample.json
with open("{}_eval.json".format(sess_id), "w") as outfile:
    outfile.write(json_eval)
with open("{}_test.json".format(sess_id), "w") as outfile:
    outfile.write(json_test)
