In [3]:
import sys
from pathlib import Path
# Make the project root importable so that `src.*` resolves.
# `__file__` is normally undefined inside a Jupyter kernel, so fall back to the
# current working directory when it is missing.
# NOTE(review): the fallback assumes the kernel's working directory is the folder
# containing this notebook (one level below the project root) — TODO confirm.
try:
    PROJECT_ROOT = Path(__file__).resolve().parent.parent
except NameError:
    PROJECT_ROOT = Path.cwd().resolve().parent

# Guard against duplicate entries when the cell is re-run.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

In [5]:
import torch
from src.pipe import MultiModalAuthPipeline, ImagePreprocessor, AudioPreprocessor, AdaFace, ReDimNet
from synthweave.utils.datasets import get_datamodule
from synthweave.utils.fusion import get_fusion
from pathlib import Path
import json
from tqdm.auto import tqdm
from torchmetrics.classification import F1Score, Accuracy

class dotdict(dict):
    """Dictionary whose items can also be read and written as attributes.

    ``d.key`` reads ``d["key"]`` and ``d.key = value`` stores ``d["key"] = value``.

    Raises:
        AttributeError: when reading an attribute that is not a key of the dict.
    """

    def __getattr__(self, name):
        # Only called when regular attribute lookup fails. Translate the
        # KeyError into AttributeError so that hasattr(), getattr(obj, name,
        # default) and protocol probes (copy, pickle, ...) behave as expected.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        # Attribute assignment is stored as a dictionary item.
        self[name] = value

### DATASET

In [7]:
# Options handed to `get_datamodule` through its `dataset_kwargs` argument.
ds_kwargs = dict(
    data_dir="../encoded_data/DeepSpeak_v1_1",
    preprocessed=True,
    sample_mode="sequence",
)

# Build the DeepSpeak v1.1 datamodule and prepare its splits.
dm = get_datamodule(
    "DeepSpeak_v1_1",
    dataset_kwargs=ds_kwargs,
    batch_size=1,  # NOTE: currently single window fusions don't ignore padding
    sample_mode="sequence",  # single, sequence
    clip_mode=None,
    pad_mode="zeros",
)
dm.setup()

In [8]:
# Pull one batch from the test split and show the shape of its video tensor
# once the leading (size-1) batch axis is removed.
test_loader = dm.test_dataloader()
first_batch = next(iter(test_loader))
first_batch["video"].squeeze(0).shape

torch.Size([22, 512])

### PIPELINE

In [9]:
FUSION = "MMD"
TASK = "binary"


def _first_int(text, default=-1):
    """Return the first run of ASCII digits in ``text`` as an int (``default`` if none)."""
    digits = ""
    for ch in text:
        if ch in "0123456789":
            digits += ch
        elif digits:
            break
    return int(digits) if digits else default


def _latest(candidates, description):
    """Return the path whose name carries the highest number.

    Ordering is numeric on the first integer found in the file name (ties and
    number-less names fall back to the plain name), so ``version_10`` ranks
    above ``version_9`` — a plain lexicographic sort would get this wrong.

    Raises:
        FileNotFoundError: if ``candidates`` is empty.
    """
    candidates = list(candidates)
    if not candidates:
        raise FileNotFoundError(f"No {description} found")
    return max(candidates, key=lambda p: (_first_int(p.name), p.name))


# Most recent training run for this task/fusion combination.
run_root = Path("logs") / TASK / FUSION
path = _latest(run_root.glob("version_*"), f"'version_*' directory under '{run_root}'")

# config
args = dotdict(json.loads((path / "args.json").read_text()))

# Checkpoint with the highest epoch number.
# NOTE(review): treated as the "best" checkpoint, which presumes the run only
# keeps its best one on disk — TODO confirm against the training script.
ckpt_dir = path / "checkpoints"
ckpt = _latest(ckpt_dir.glob("epoch=*.ckpt"), f"'epoch=*.ckpt' file under '{ckpt_dir}'")

In [16]:
# Stand-alone fusion module with hard-coded dimensions (independent of the
# saved run configuration).
fusion_kwargs = dict(
    fusion_name=FUSION,
    output_dim=512,
    modality_keys=["video", "audio"],
    out_proj_dim=1024,
    num_att_heads=4,  # only for attention-based fusions
    n_layers=3,
    dropout=0.1,
)
fusion = get_fusion(**fusion_kwargs)

# Evaluation mode; the trailing semicolon hides the module repr.
fusion.eval();

In [18]:
# Smoke test: pass one random audio tensor of shape (1, 128) and one random
# video tensor of shape (1, 512) through `fusion` and display the output shape.
# NOTE(review): the recorded run of this cell ended in `KeyError: 0`; the cause
# is not visible from this notebook — TODO confirm the input format the fusion
# expects, or remove this cell so the notebook survives Restart & Run All.
fusion({"audio": torch.randn(1, 128), "video": torch.randn(1, 512)}).shape

KeyError: 0

In [143]:
# Identity "backbones": inputs are passed through unchanged.
# NOTE(review): presumably because the datamodule already serves pre-computed
# embeddings (`preprocessed=True`) — TODO confirm.
models = {"audio": torch.nn.Identity(), "video": torch.nn.Identity()}

EMB_DIM = args.emb_dim

# Fusion module configured from the saved run arguments.
fusion = get_fusion(
    fusion_name=FUSION,
    output_dim=EMB_DIM,
    modality_keys=["video", "audio"],
    out_proj_dim=args.proj_dim,
    num_att_heads=4,  # only for attention-based fusions
    dropout=args.dropout,
)

# Classification head: one sigmoid unit for the binary task, a 4-way softmax
# for the fine-grained task.
if args.task == "binary":
    detection_head = torch.nn.Sequential(
        torch.nn.Linear(EMB_DIM, 1), torch.nn.Sigmoid()
    )
elif args.task == "fine-grained":
    detection_head = torch.nn.Sequential(
        torch.nn.Linear(EMB_DIM, 4), torch.nn.Softmax(dim=1)
    )
else:
    # Fail loudly instead of leaving `detection_head` undefined (or stale from
    # a previous execution of this cell).
    raise ValueError(
        f"Unsupported task {args.task!r}; expected 'binary' or 'fine-grained'"
    )

pipe = MultiModalAuthPipeline(
    models=models,
    fusion=fusion,
    detection_head=detection_head,
    freeze_backbone=True,
)

# SECURITY NOTE: torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(ckpt, map_location="cpu")['state_dict']
# Drop the "pipeline." qualifier so the keys line up with `pipe`'s own names.
state_dict = {k.replace("pipeline.", ""): v for k, v in state_dict.items()}

# strict=False tolerates key mismatches, so report them explicitly: a silent
# mismatch would mean evaluating randomly initialised weights.
load_result = pipe.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"[WARN] {len(load_result.missing_keys)} model key(s) not found in checkpoint, "
        f"e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"[WARN] {len(load_result.unexpected_keys)} checkpoint key(s) not used by the model, "
        f"e.g. {load_result.unexpected_keys[:5]}"
    )

pipe = pipe.cuda()
pipe.eval();

[INFO] This fusion expects embeddings of shape (batch_size, embed_dim).


EXAMPLE EVAL RUN

In [188]:
sample = next(iter(test_loader))

# Place on GPU, dropping the size-1 batch axis added by the dataloader
for modality in ("video", "audio"):
    sample[modality] = sample[modality].squeeze(0).cuda()

# Forward pass without gradient tracking
with torch.no_grad():
    out = pipe(sample)

# Threshold each row of the output at 0.5, then take the most frequent decision
preds = (out["logits"] > 0.5).to(torch.int64).cpu()
final_pred = torch.mode(preds, dim=0).values  # NOTE: Majority vote

print("GT:", sample['metadata']["label"].cpu().item())
print("Pred:", final_pred.item())

GT: 1
Pred: 1


EVALUATION

In [189]:
# Binary F1 and accuracy metrics, reused (after a reset) for every split below.
f1, acc = F1Score(task="binary"), Accuracy(task="binary")

In [190]:
# Start from clean metric state for the training split
for metric in (f1, acc):
    metric.reset()

train_loader = dm.train_dataloader()
for sample in tqdm(train_loader):
    # Drop the size-1 batch axis and move both modalities to the GPU
    for modality in ("video", "audio"):
        sample[modality] = sample[modality].squeeze(0).cuda()

    with torch.no_grad():
        out = pipe(sample)

    # Threshold at 0.5, then majority vote over the first axis
    preds = (out["logits"] > 0.5).to(torch.int64).cpu()
    final_pred = torch.mode(preds, dim=0).values
    gt = sample['metadata']["label"].cpu()

    for metric in (f1, acc):
        metric.update(final_pred, gt)

print(f"Train F1:  {f1.compute().item(): .3f}")
print(f"Train Acc: {acc.compute().item(): .3f}")

  0%|          | 0/9463 [00:00<?, ?it/s]

Train F1:   0.995
Train Acc:  0.995


In [191]:
# Start from clean metric state for the validation split
for metric in (f1, acc):
    metric.reset()

dev_loader = dm.val_dataloader()
for sample in tqdm(dev_loader):
    # Drop the size-1 batch axis and move both modalities to the GPU
    for modality in ("video", "audio"):
        sample[modality] = sample[modality].squeeze(0).cuda()

    with torch.no_grad():
        out = pipe(sample)

    # Threshold at 0.5, then majority vote over the first axis
    preds = (out["logits"] > 0.5).to(torch.int64).cpu()
    final_pred = torch.mode(preds, dim=0).values
    gt = sample['metadata']["label"].cpu()

    for metric in (f1, acc):
        metric.update(final_pred, gt)

print(f"Dev F1:  {f1.compute().item(): .3f}")
print(f"Dev Acc: {acc.compute().item(): .3f}")

  0%|          | 0/1047 [00:00<?, ?it/s]

Dev F1:   0.981
Dev Acc:  0.981


In [192]:
# Start from clean metric state for the test split
for metric in (f1, acc):
    metric.reset()

test_loader = dm.test_dataloader()
for sample in tqdm(test_loader):
    # Drop the size-1 batch axis and move both modalities to the GPU
    for modality in ("video", "audio"):
        sample[modality] = sample[modality].squeeze(0).cuda()

    with torch.no_grad():
        out = pipe(sample)

    # Threshold at 0.5, then majority vote over the first axis
    preds = (out["logits"] > 0.5).to(torch.int64).cpu()
    final_pred = torch.mode(preds, dim=0).values
    gt = sample['metadata']["label"].cpu()

    for metric in (f1, acc):
        metric.update(final_pred, gt)

print(f"Test F1:  {f1.compute().item(): .3f}")
print(f"Test Acc: {acc.compute().item(): .3f}")

  0%|          | 0/2911 [00:00<?, ?it/s]

Test F1:   0.732
Test Acc:  0.665


In [22]:
# out["video_ref"] = torch.rand_like(out["video"])
# out["audio_ref"] = torch.rand_like(out["audio"])

# # out['video_ref'] = out['video'].clone()
# # out['audio_ref'] = out['audio'].clone()

# with torch.no_grad():
#     sim = pipe.verify(out)

In [23]:
# sim["video"][0].item(), sim["audio"][0].item()

(-4.457475662231445, 1.972063422203064)