In [None]:
import sys

# Put the parent directory on the import path (presumably where the `src`
# package used by this notebook lives - confirm against the repo layout).
# Guarded so that re-running this cell does not keep appending duplicates.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
import torch
from src.pipe import MultiModalAuthPipeline, ImagePreprocessor, AudioPreprocessor, AdaFace, ReDimNet
from synthweave.utils.datasets import get_datamodule, SWAN_DF_Dataset
from synthweave.utils.fusion import get_fusion
from pathlib import Path
import json

class dotdict(dict):
    """Dictionary whose items can also be read and written as attributes.

    ``d.key`` is equivalent to ``d["key"]`` and ``d.key = value`` is
    equivalent to ``d["key"] = value``.

    Raises:
        AttributeError: on attribute access for a name that is not a key.
            (Item access ``d["key"]`` still raises ``KeyError`` as usual.)
    """

    def __getattr__(self, name):
        # __getattr__ is only consulted after normal attribute lookup fails,
        # so regular dict methods (.items(), .get(), ...) are never shadowed.
        try:
            return self[name]
        except KeyError:
            # The attribute protocol (hasattr, getattr with a default, copy,
            # pickle) expects AttributeError; a KeyError would leak through
            # those helpers and crash them.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def __setattr__(self, name, value):
        # Attribute assignment is stored as a dictionary item.
        self[name] = value

### DATASET

In [None]:
# Keyword arguments forwarded to the dataset class: the two dataset roots
# and the list of resolutions to use.
ds_kwargs = dict(
    root_real="/home/woleek/SynthWeave/data/SWAN-Idiap",
    root_df="/home/woleek/SynthWeave/data/SWAN-DF",
    resolutions=["320x320"],
)

# Build the SWAN-DF data module.
dm = get_datamodule(
    "SWAN_DF",
    dataset_cls=SWAN_DF_Dataset,
    dataset_kwargs=ds_kwargs,
    batch_size=1,  # NOTE: currently single window fusions don't ignore padding
    sample_mode="sequence",  # options: single, sequence
    clip_mode=None,
    pad_mode="zeros",
)

dm.setup()

In [None]:
# Create the test loader, then peek at one batch: show the video tensor's
# shape after dropping the leading batch dimension (batch_size is 1).
test_loader = dm.test_dataloader()
next(iter(test_loader))['video'].squeeze(dim=0).size()

### PIPELINE

In [None]:
# Directory holding the fusion module's saved arguments and checkpoint.
fusion_module_dir = Path("/home/woleek/SynthWeave/SynthWeave/examples/multimodal_auth/CAFF")

# config: parsed from args.json and wrapped for attribute-style access
args = dotdict(json.loads(fusion_module_dir.joinpath("args.json").read_text()))

# weights: checkpoint file loaded further below
weights_path = fusion_module_dir.joinpath("detection_module.ckpt")

In [None]:
# One preprocessor per modality, keyed by modality name. Both use the same
# windowing settings (window_len=4, step=1).
preprocessors = dict(
    video=ImagePreprocessor(
        window_len=4,
        step=1,
        estimate_quality=True,
        models_dir="/home/woleek/SynthWeave/models",
    ),
    audio=AudioPreprocessor(
        window_len=4,
        step=1,
        use_vad=True,
    ),
)

In [None]:
# Embedding backbones, keyed by the modality they consume.
# AdaFace is a face-recognition model and ReDimNet a speaker-recognition
# model, so AdaFace must sit under "video" and ReDimNet under "audio".
# This also agrees with the fusion input dims used in this notebook
# (video: 512, audio: 192).
models = {
    "video": AdaFace(
        path="/home/woleek/SynthWeave/models"
    ),
    "audio": ReDimNet(),
}

In [None]:
# Instantiate the fusion module named in the saved training arguments.
# Per-modality input sizes are fixed here; output sizes come from `args`.
fusion = get_fusion(
    fusion_name=args.fusion,
    modality_keys=["video", "audio"],
    input_dims=dict(video=512, audio=192),
    output_dim=args.emb_dim,
    out_proj_dim=args.proj_dim,
)

# Detection head: a single linear layer mapping the fused embedding
# (args.emb_dim features) to one output value, registered as "classifier".
detection_head = torch.nn.ModuleDict({
    "classifier": torch.nn.Sequential(torch.nn.Linear(args.emb_dim, 1)),
})

In [None]:
# Assemble the full pipeline from the preprocessors, backbones, fusion
# module and detection head defined in the previous cells.
pipe = MultiModalAuthPipeline(
    processors=preprocessors,
    models=models,
    fusion=fusion,
    detection_head=detection_head,
    freeze_backbone=True,
    iil_mode=args.iil_mode,
)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints that come from a trusted source.
state_dict = torch.load(weights_path, map_location="cpu")['state_dict']

# Checkpoint keys carry a leading "pipeline." that the bare pipeline does not
# use. Strip it only where it is a prefix, so a key that merely contains the
# substring further in is left untouched.
ckpt_prefix = "pipeline."
state_dict = {
    (k[len(ckpt_prefix):] if k.startswith(ckpt_prefix) else k): v
    for k, v in state_dict.items()
}

# strict=False tolerates key mismatches, so a naming problem would otherwise
# leave parts of the model randomly initialised without any error. Report
# what was actually matched.
load_result = pipe.load_state_dict(state_dict, strict=False)
n_loaded = len(state_dict) - len(load_result.unexpected_keys)
print(
    f"Loaded {n_loaded}/{len(state_dict)} checkpoint tensors "
    f"({len(load_result.missing_keys)} model keys missing from checkpoint, "
    f"{len(load_result.unexpected_keys)} checkpoint keys unused)"
)
if load_result.unexpected_keys:
    print("Unused checkpoint keys (first 10):", load_result.unexpected_keys[:10])

pipe = pipe.cuda()
pipe.eval();

### EXAMPLE RUN

In [None]:
sample = next(iter(test_loader))

# Place on GPU (dropping the leading batch dimension; batch_size is 1)
sample["video"] = sample["video"].squeeze(0).cuda()
sample["audio"] = sample["audio"].squeeze(0).cuda()

with torch.no_grad():
    out = pipe(sample)

    # The detection head ends in a bare Linear layer, so out["logits"] are
    # unbounded scores; map them to probabilities before applying the 0.5
    # decision threshold (equivalent to logits > 0).
    # NOTE(review): assumes the pipeline returns the head output without an
    # activation already applied - confirm against MultiModalAuthPipeline.
    probs = torch.sigmoid(out["logits"])
    preds = (probs > 0.5).type(torch.int64).cpu()
    final_pred = torch.mode(preds, dim=0).values # NOTE: Majority vote

    print("GT:", sample['metadata']["label"].cpu().item())
    print("Pred:", final_pred.item())

In [None]:
# Attach random stand-in references with the same shape/dtype/device as each
# modality's output, stored under "<modality>_ref".
for modality in ("video", "audio"):
    out[f"{modality}_ref"] = torch.rand_like(out[modality])

# Alternative: use exact copies of the outputs as references instead.
# out['video_ref'] = out['video'].clone()
# out['audio_ref'] = out['audio'].clone()

# Run verification without tracking gradients.
with torch.no_grad():
    sim = pipe.verify(out)

In [None]:
# Show the first verification value for each modality as a (video, audio) pair.
tuple(sim[modality][0].item() for modality in ("video", "audio"))