In [None]:
import sys
# Put the parent of the current working directory on the import path;
# presumably that is where the `src` package imported in the next cell lives.
sys.path.append('..')

In [None]:
import torch
from src.pipe import MultiModalAuthPipeline, ImagePreprocessor, AudioPreprocessor, AdaFace, ReDimNet, ClassifierHead
import json
from pathlib import Path
from synthweave.utils.fusion import get_fusion
from synthweave.utils.tools import read_video

class dotdict(dict):
    """Dictionary whose items can also be read and written as attributes.

    ``d.key`` reads ``d["key"]`` and ``d.key = value`` writes ``d["key"] = value``.
    """

    def __getattr__(self, name):
        """Return ``self[name]``; invoked only when normal attribute lookup fails.

        Raises:
            AttributeError: if ``name`` is not a key. The attribute protocol
                (``hasattr``, ``getattr`` with a default, ``copy``/``pickle``
                probing for optional dunder methods) only understands
                ``AttributeError``, so a raw ``KeyError`` must not escape here.
        """
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        """Store ``value`` under the key ``name``."""
        self[name] = value

### PIPELINE

In [None]:
# Directory holding the fusion module artefacts read below (args.json and the checkpoint).
# NOTE: absolute, machine-specific path; adjust it when running elsewhere.
fusion_module_dir = Path("/home/woleek/SynthWeave/SynthWeave/examples/multimodal_auth/CAFF")

# Arguments stored next to the checkpoint, wrapped so they can be read as attributes.
args = dotdict(json.loads(fusion_module_dir.joinpath("args.json").read_text()))

# Checkpoint file that is loaded into the pipeline further down.
weights_path = fusion_module_dir.joinpath("detection_module.ckpt")

In [None]:
# One preprocessor per modality, keyed by the modality name ("video" / "audio").
# Both are configured with the same window_len and step.
preprocessors = dict(
    video=ImagePreprocessor(
        window_len=4,
        step=1,
        estimate_quality=False,
        models_dir="/home/woleek/SynthWeave/models",
        quality_model_type="ir50",
    ),
    audio=AudioPreprocessor(
        window_len=4,
        step=1,
        use_vad=True,
    ),
)

In [None]:
# One embedding backbone per modality, keyed like the preprocessors.
models = dict(
    video=AdaFace(
        path="/home/woleek/SynthWeave/models",
        model_type="ir50",
    ),
    audio=ReDimNet(),
)

In [None]:
# Fusion module selected and sized from the stored arguments.
fusion = get_fusion(
    fusion_name=args["fusion"],
    output_dim=args["emb_dim"],
    modality_keys=["video", "audio"],
    input_dims=dict(video=512, audio=192),
    out_proj_dim=args["proj_dim"],
)

# Single-output classifier head whose input size matches the fusion output_dim.
detection_head = ClassifierHead(num_classes=1, input_dim=args["emb_dim"])

In [None]:
# Resolve the target device once so construction and final placement always agree
# (the notebook previously called .cuda() unconditionally, which fails on CPU-only hosts).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pipe = MultiModalAuthPipeline(
    processors=preprocessors,
    models=models,
    fusion=fusion,
    detection_head=detection_head,
    freeze_backbone=True,
    iil_mode=args.iil_mode,
    device=device,
)

# NOTE(security): torch.load unpickles the checkpoint and can execute arbitrary code;
# only load checkpoints that come from a trusted source.
state_dict = torch.load(weights_path, map_location="cpu")['state_dict']

# Strip only a *leading* "pipeline." so keys that merely contain that text elsewhere
# are left untouched (str.replace would rewrite every occurrence).
ckpt_prefix = "pipeline."
state_dict = {
    (k[len(ckpt_prefix):] if k.startswith(ckpt_prefix) else k): v
    for k, v in state_dict.items()
}

# strict=False tolerates key mismatches, so surface them instead of discarding the result:
# a checkpoint that matches nothing would otherwise "load" without any error.
load_result = pipe.load_state_dict(state_dict, strict=False)
print(
    f"Checkpoint loaded: {len(load_result.missing_keys)} missing key(s), "
    f"{len(load_result.unexpected_keys)} unexpected key(s)"
)

pipe = pipe.to(device)
pipe.eval();

### SAMPLE

In [None]:
# Clip to authenticate (its pipeline output is the one scored and verified below);
# presumably a face-swapped deepfake, judging by the file name.
df_sample = "./samples/john_face_fake.mp4"
# Clip whose embeddings serve as the reference during verification;
# presumably genuine footage, judging by the file name.
bf_sample = "./samples/john_real.mp4"

In [None]:
def load_sample(path):
    """Read one clip with ``read_video`` and pack it into the pipeline input dict.

    Args:
        path: Location of the video file to read.

    Returns:
        dict with three entries:
            "video":    [video data returned by read_video, meta['video_fps']]
            "audio":    [audio data returned by read_video, meta['audio_fps']]
            "metadata": the metadata mapping returned by read_video
    """
    video, audio, meta = read_video(path)
    return {
        "video": [video, meta['video_fps']],
        "audio": [audio, meta['audio_fps']],
        "metadata": meta,
    }


df = load_sample(df_sample)  # clip to authenticate
bf = load_sample(bf_sample)  # reference clip

Run pipeline inference

In [None]:
# Inference only: run both forward passes with gradient tracking disabled.
with torch.no_grad():
    df_out = pipe(df) # Sample to authenticate
    bf_out = pipe(bf) # Reference sample; stands in for embeddings that would come from a database

Inspect how many windows survived preprocessing (quality checks, face/voice detection, etc.); this can be used as a rejection threshold.

In [None]:
# Report the "org_len" and "valid_len" entries of the pipeline output.
# Presumably org_len counts all windows and valid_len those kept after preprocessing (TODO confirm).
print("Original len: ", df_out['org_len'])
print("Processed len: ", df_out['valid_len']) # NOTE: Can threshold based on % of dropped windows

Check DeepFake module prediction

In [None]:
# Decision threshold on the clip-level probability, named like vid_th / aud_th below.
df_th = 0.5 # NOTE: set threshold for DeepFake detection

probs = torch.sigmoid(df_out["logits"]).cpu()

# Fail closed: the mean of an empty tensor is NaN and `NaN >= df_th` is False,
# which would silently label a clip with no scored windows as "Bonafide".
if probs.numel() == 0:
    raise ValueError("No logits to score: the pipeline produced no output windows for this sample.")

# Average the per-window probabilities into one score for the whole clip.
prob_per_clip = probs.mean()
preds_per_clip = (prob_per_clip >= df_th).long()

print("Pred:", "Bonafide" if preds_per_clip.item() == 0 else "DeepFake")

Run verification

In [None]:
# Compare the probe embeddings against the reference embeddings, modality by modality.
with torch.no_grad():
    sim = pipe.verify({
        # Embeddings of the sample being authenticated.
        **{key: df_out[key].cpu() for key in ("video", "audio")},
        # NOTE: put reference embeddings here
        **{f"{key}_ref": bf_out[key].cpu() for key in ("video", "audio")},
    })

Face verification

In [None]:
# NOTE: Select aggregation methods across windows
# torch.as_tensor accepts both tensors and array-likes without the copy-construct
# warning that torch.tensor(existing_tensor) emits.
# vid_sim = torch.as_tensor(sim["video"]).max(dim=1).values  # Max similarity to any reference
vid_sim = torch.as_tensor(sim["video"]).mean(dim=1)          # Mean similarity to reference

vid_th = 60.0 # NOTE: set threshold for video similarity

passed = vid_sim > vid_th
print("Pass status:", passed.tolist())
# Tensor.all() is True for an empty tensor, so require at least one window;
# otherwise a sample with no windows would be "verified" vacuously.
print("Face verified:", passed.numel() > 0 and passed.all().item()) # True if all windows passed the threshold

Audio (voice) verification

In [None]:
# NOTE: Select aggregation methods across windows
# torch.as_tensor accepts both tensors and array-likes without the copy-construct
# warning that torch.tensor(existing_tensor) emits.
# aud_sim = torch.as_tensor(sim["audio"]).max(dim=1).values  # Max similarity to any reference
aud_sim = torch.as_tensor(sim["audio"]).mean(dim=1)          # Mean similarity to reference

aud_th = 60.0 # NOTE: set threshold for audio similarity

passed = aud_sim > aud_th
print("Pass status:", passed.tolist())
# Tensor.all() is True for an empty tensor, so require at least one window;
# otherwise a sample with no windows would be "verified" vacuously.
print("Voice verified:", passed.numel() > 0 and passed.all().item()) # True if all windows passed the threshold