Skip to content

Commit

Permalink
merge emotion task for celebvhq
Browse files Browse the repository at this point in the history
  • Loading branch information
ControlNet committed Feb 9, 2024
1 parent 3f84ba0 commit eef609e
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
model_name: "celebvhq_marlin_appearance_ft"
model_name: "celebvhq_marlin_emotion_ft"
backbone: "marlin_vit_base_ytf"
dataset: "celebvhq"
task: "emotion"
Expand Down
27 changes: 19 additions & 8 deletions dataset/celebv_hq.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from abc import ABC, abstractmethod
from itertools import islice
from typing import Optional
from typing import Optional, List

import ffmpeg
import numpy as np
Expand All @@ -15,6 +15,7 @@


class CelebvHqBase(LightningDataModule, ABC):
emotions = ["neutral", "happy", "sadness", "anger", "fear", "surprise", "contempt", "disgust"]

def __init__(self, data_root: str, split: str, task: str, data_ratio: float = 1.0, take_num: int = None):
super().__init__()
Expand All @@ -26,7 +27,7 @@ def __init__(self, data_root: str, split: str, task: str, data_ratio: float = 1.

self.name_list = list(
filter(lambda x: x != "", read_text(os.path.join(data_root, f"{self.split}.txt")).split("\n")))
# self.metadata = read_json(os.path.join(data_root, "celebvhq_info.json"))
self.metadata = read_json(os.path.join(data_root, "celebvhq_info.json"))

if data_ratio < 1.0:
self.name_list = self.name_list[:int(len(self.name_list) * data_ratio)]
Expand All @@ -42,6 +43,16 @@ def __getitem__(self, index: int):
def __len__(self) -> int:
    """Return the number of samples, one per clip name in the split list."""
    sample_count = len(self.name_list)
    return sample_count

@classmethod
def parse_emotion_label(cls, emotion_annotation: dict) -> List[int]:
    """Convert a CelebV-HQ emotion annotation into a multi-hot label vector.

    Args:
        emotion_annotation: Annotation dict with a ``"sep_flag"`` bool and a
            ``"labels"`` field. When ``sep_flag`` is true, ``labels`` is a
            list of per-segment dicts each carrying an ``"emotion"`` name;
            otherwise ``labels`` is a single emotion name string.
            (Schema inferred from this method's accesses — confirm against
            ``celebvhq_info.json``.)

    Returns:
        A list of 0/1 ints with one slot per entry in ``cls.emotions``,
        set to 1 for every emotion present in the annotation.

    Raises:
        ValueError: If an emotion name is not present in ``cls.emotions``.
    """
    # Size the vector from the class-level vocabulary instead of a
    # hard-coded 8, so the two cannot drift apart if ``emotions`` changes.
    labels = [0] * len(cls.emotions)
    if emotion_annotation["sep_flag"]:
        # Segmented annotation: mark every emotion that occurs in any segment.
        for emo in emotion_annotation["labels"]:
            labels[cls.emotions.index(emo["emotion"])] = 1
    else:
        # Single emotion label for the whole clip.
        labels[cls.emotions.index(emotion_annotation["labels"])] = 1
    return labels

# for fine-tuning
class CelebvHq(CelebvHqBase):
Expand All @@ -60,9 +71,9 @@ def __init__(self,
self.temporal_sample_rate = temporal_sample_rate

def __getitem__(self, index: int):
#y = self.metadata["clips"][self.name_list[index]]["attributes"][self.task]
class_idx = int(self.name_list[index].split("-")[2])
y = torch.eye(8)[class_idx-1] # one-hot encoding for the emotion class
y = self.metadata["clips"][self.name_list[index]]["attributes"][self.task]
if self.task == "emotion":
y = self.parse_emotion_label(y)
video_path = os.path.join(self.data_root, "cropped", self.name_list[index] + ".mp4")

probe = ffmpeg.probe(video_path)["streams"][0]
Expand Down Expand Up @@ -125,9 +136,9 @@ def __getitem__(self, index: int):
else:
raise ValueError(self.temporal_reduction)

#y = self.metadata["clips"][self.name_list[index]]["attributes"][self.task]
class_idx = int(self.name_list[index].split("-")[2])
y = torch.eye(8)[class_idx-1] # one-hot emotion for the emotion class
y = self.metadata["clips"][self.name_list[index]]["attributes"][self.task]
if self.task == "emotion":
y = CelebvHq.parse_emotion_label(y)

return x, torch.tensor(y, dtype=torch.long).bool()

Expand Down
11 changes: 5 additions & 6 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def train_celebvhq(args, config):
num_classes = 40
elif task == "action":
num_classes = 35
elif task == "emotion": # [asroman]
elif task == "emotion":
num_classes = 8
else:
raise ValueError(f"Unknown task {task}")
Expand Down Expand Up @@ -68,7 +68,7 @@ def train_celebvhq(args, config):
dm.setup()
return resume_ckpt, dm

strategy = "ddp" #None if n_gpus <= 1 else "ddp"
strategy = None if n_gpus <= 1 else "ddp"
accelerator = "cpu" if n_gpus == 0 else "gpu"

ckpt_filename = config["model_name"] + "-{epoch}-{val_auc:.3f}"
Expand All @@ -84,10 +84,9 @@ def train_celebvhq(args, config):
monitor=ckpt_monitor,
mode="max")

print("resume_ckpt", resume_ckpt)
trainer = Trainer(log_every_n_steps=1, devices=n_gpus, accelerator=accelerator, benchmark=True,
logger=True, precision=precision, max_epochs=max_epochs,
strategy=strategy, #resume_from_checkpoint=resume_ckpt,
strategy=strategy, resume_from_checkpoint=resume_ckpt,
callbacks=[ckpt_callback, LrLogger(), EarlyStoppingLR(1e-6), SystemStatsLogger()])

trainer.fit(model, dm)
Expand Down Expand Up @@ -140,11 +139,11 @@ def evaluate(args):
parser.add_argument("--data_path", type=str, help="Path to CelebV-HQ dataset.")
parser.add_argument("--marlin_ckpt", type=str, default=None,
help="Path to MARLIN checkpoint. Default: None, load from online.")
parser.add_argument("--n_gpus", type=int, default=2)
parser.add_argument("--n_gpus", type=int, default=1)
parser.add_argument("--precision", type=str, default="32")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--epochs", type=int, default=200, help="Max epochs to train.")
parser.add_argument("--epochs", type=int, default=2000, help="Max epochs to train.")
parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume training.")
parser.add_argument("--skip_train", action="store_true", default=False,
help="Skip training and evaluate only.")
Expand Down
3 changes: 1 addition & 2 deletions model/classifier.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Optional, Union, Sequence, Dict, Literal, Any

#from memory_profiler import profile [asroman]
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import CrossEntropyLoss, Linear, Identity, BCEWithLogitsLoss
from torch.nn import CrossEntropyLoss, Linear, BCEWithLogitsLoss
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import Accuracy, AUROC
Expand Down
2 changes: 1 addition & 1 deletion preprocess/celebvhq_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
parser.add_argument("--data_dir", type=str)
args = parser.parse_args()

model = Marlin.from_file("marlin_vit_base_ytf", "pretrained/marlin_vit_base_ytf.encoder.pt") #.from_online(args.backbone)
model = Marlin.from_online(args.backbone)
config = resolve_config(args.backbone)
feat_dir = args.backbone

Expand Down

0 comments on commit eef609e

Please sign in to comment.