In [1]:
import pickle
from collections import defaultdict
from pathlib import Path

import torch
import matplotlib.pyplot as plt
from einops import rearrange
from torch import Tensor
from torch.nn.functional import normalize, softmax, cosine_similarity
from torchvision.transforms.functional import resize, rgb_to_grayscale
from torchvision.utils import save_image
from tqdm.auto import tqdm

from src.dataset import VoiceToFaceDataset
from src.model.eigenface import Eigenface
from src.model.mlp import MLP
from src.model.voice_embedder import forge_voice_embedder_with_parameters, DEFAULT_OUTPUT_FEATURE_NUM, VoiceEmbedNet
from src.model.generator import forge_generator_with_parameters, Generator
from src.config import TrainingConfig
from src.utils import get_tensor_device
from src.model.resnet import resnet50

# Testing: voice-to-face reconstruction

In [2]:
# Device on which every model and tensor in this notebook is placed.
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Locations of the held-out split: voice features, face images and the
# metadata CSV that the dataset is built from.
voices_dir = Path('datasets/voices/')
images_dir = Path('datasets/images/')
test_metadata_csv = Path('datasets/metadata-test.csv')

test_dataset = VoiceToFaceDataset(voices_dir, images_dir, test_metadata_csv)

Creating dataset from datasets/metadata-test.csv.


Loading dataset...:   0%|          | 0/124 [00:00<?, ?it/s]

In [4]:
# Training run (checkpoint directory) and epoch whose MLP weights are evaluated.
# Alternative run kept for reference:
#   TARGET_PATH = Path('checkpoints/training/2022-12-9-2-15-21'), TARGET_EPOCH = 25
TARGET_EPOCH = 12
TARGET_PATH = Path('checkpoints/training/2022-12-9-15-31-52')

# The run's saved configuration supplies the MLP hyper-parameters used below.
config_json = TARGET_PATH / 'config.json'
training_config = TrainingConfig.from_json(config_json)

In [5]:
# Voice embedder built from its checkpoint, moved to `device`, in eval mode.
voice_embedder = forge_voice_embedder_with_parameters(Path('checkpoints/voice_embedding.pth'))
voice_embedder = voice_embedder.to(device).eval()

# Generator built from its checkpoint (used by `reconstruct_yans`).
generator = forge_generator_with_parameters(Path('checkpoints/generator.pth'))
generator = generator.to(device).eval()

# Eigenface converter; its component count fixes the MLP output width.
eigenface_converter = Eigenface(Path('checkpoints/input-15k-pc-5k.npy'))
eigenface_converter = eigenface_converter.to(device)

# MLP from voice-embedding features to eigenface coefficients, shaped by the
# hyper-parameters recorded in the training run's config.
mlp = MLP(
    DEFAULT_OUTPUT_FEATURE_NUM,
    eigenface_converter.eigenface_components,
    training_config.mlp_hidden_size,
    training_config.mlp_hidder_layer_num,
    training_config.mlp_dropout_probability,
)
mlp = mlp.to(device).eval()

# Restore the weights saved at the selected epoch; kept as the last
# expression so the key-matching report is shown as the cell output.
mlp_checkpoint = TARGET_PATH / f'mlp-{TARGET_EPOCH}.pth'
mlp.load_state_dict(torch.load(mlp_checkpoint, map_location=device))

Initializing Eigenface converter with eigenface_components = 5000.


<All keys matched successfully>

In [6]:
# ResNet-50 used to embed faces for the similarity comparison below.
resnet = resnet50(num_classes=8631)
resnet = resnet.to(device).eval()

# SECURITY NOTE: pickle.load can execute arbitrary code; only load this
# weight file if it comes from a trusted source.
with open('checkpoints/resnet50_scratch_weight.pkl', 'rb') as f:
    ckpt = pickle.load(f)

# The pickle stores NumPy arrays; convert each one to a tensor on `device`
# so the mapping can be used as a state dict.
ckpt = {param_name: torch.from_numpy(array).to(device) for param_name, array in ckpt.items()}
resnet.load_state_dict(ckpt)

<All keys matched successfully>

## Image Reconstruction

In [7]:
@torch.inference_mode()
def reconstruct_ours(feature: Tensor, v: VoiceEmbedNet, m: MLP, e: Eigenface):
    """Reconstruct a face from a voice feature via embedder -> MLP -> eigenface.

    Args:
        feature: Voice feature tensor. A 2-D tensor is treated as a single
            unbatched sample and gets a leading batch axis added.
        v: Voice embedding network applied to ``feature``.
        m: MLP mapping the flattened voice embedding to eigenface coefficients.
        e: Eigenface converter that turns the coefficients back into a face.

    Returns:
        The reconstructed face as a tensor of shape ``(1, 64, 64)``.

    Note:
        The fixed ``reshape(128, 128)`` means only one sample can be
        reconstructed per call.
    """
    if feature.dim() == 2:
        feature = feature.unsqueeze(0)
    voice_embedding = v(feature)
    # Drop the two trailing singleton axes of the embedder output for the MLP.
    voice_embedding = rearrange(voice_embedding, 'N C 1 1 -> N C')
    eigenface = m(voice_embedding)
    # Use the converter passed in as `e`, not the notebook-global
    # `eigenface_converter`, so the function honours its own arguments.
    reconstructed_face = e.eigenface_to_face(eigenface).reshape(128, 128)
    # Add a leading axis so `resize` sees (1, H, W), then downscale to 64x64.
    reconstructed_face = resize(reconstructed_face.unsqueeze(0), [64, 64])
    return reconstructed_face

@torch.inference_mode()
def reconstruct_yans(feature: Tensor, v: VoiceEmbedNet, g: Generator):
    """Reconstruct a grayscale face from a voice feature via embedder -> generator.

    Args:
        feature: Voice feature tensor. A 2-D tensor is treated as a single
            unbatched sample and gets a leading batch axis added.
        v: Voice embedding network applied to ``feature``.
        g: Generator that maps the voice embedding to a face image
            (presumably RGB, since it is converted with ``rgb_to_grayscale``).

    Returns:
        The generated face converted to grayscale, with the batch axis removed
        and a single leading channel axis kept (3-D tensor).

    Note:
        The final ``rearrange`` pattern requires a batch of exactly one sample.
    """
    if feature.dim() == 2:
        feature = feature.unsqueeze(0)
    voice_embedding = v(feature)
    # Use the generator passed in as `g`, not the notebook-global `generator`,
    # so the function honours its own arguments.
    reconstructed_face: Tensor = g(voice_embedding)
    reconstructed_face = rgb_to_grayscale(reconstructed_face)
    # Collapse the size-1 batch axis, keeping one channel axis.
    reconstructed_face = rearrange(reconstructed_face, '1 1 W H -> 1 W H')
    return reconstructed_face

## Evaluation: cosine similarity to ground-truth faces

In [8]:
def _gray_to_three_channels(image: Tensor) -> Tensor:
    """Stack a tensor with itself three times along axis 1 (the channel axis)."""
    return torch.cat((image, image, image), dim=1)


# Per-name lists of mean cosine similarities, one entry per test sample.
ours_cosine_similarities = defaultdict(list)
yans_cosine_similarities = defaultdict(list)

with torch.inference_mode():
    for i in tqdm(range(len(test_dataset))):
        if i % 2000 == 0:
            print(f'Current {i = }.')

        voice_feature, voice_id, ground_truth_images, name = test_dataset[i]
        voice_feature = voice_feature.to(device).unsqueeze(0)

        # Ground truth: un-flatten the 128-pixel-wide images, downscale to
        # 64x64 and replicate the single channel three times for the ResNet.
        ground_truth_images = rearrange(
            ground_truth_images.to(device), 'N (H W) -> N 1 H W', W=128
        )
        ground_truth_images = _gray_to_three_channels(resize(ground_truth_images, [64, 64]))

        # Our pipeline: voice -> MLP -> eigenface -> face.
        our_reconstruction = reconstruct_ours(voice_feature, voice_embedder, mlp, eigenface_converter)
        our_reconstruction = _gray_to_three_channels(
            rearrange(our_reconstruction, '1 H W -> 1 1 H W')
        )

        # Generator pipeline: voice -> generator -> grayscale face.
        yan_reconstruction = reconstruct_yans(voice_feature, voice_embedder, generator)
        yan_reconstruction = _gray_to_three_channels(
            rearrange(yan_reconstruction, '1 H W -> 1 1 H W')
        )

        # Embed both reconstructions and all ground-truth images with the ResNet.
        ours_embedding = resnet(our_reconstruction)
        yans_embedding = resnet(yan_reconstruction)
        ground_truth_embeddings = resnet(ground_truth_images)

        # Average the similarity of each reconstruction against every
        # ground-truth image of this sample, then file it under the name.
        ours_cosine_similarity = cosine_similarity(ours_embedding, ground_truth_embeddings).mean().item()
        ours_cosine_similarities[name].append(ours_cosine_similarity)

        yans_cosine_similarity = cosine_similarity(yans_embedding, ground_truth_embeddings).mean().item()
        yans_cosine_similarities[name].append(yans_cosine_similarity)

  0%|          | 0/14345 [00:00<?, ?it/s]

In [9]:
def _mean_of_per_name_means(scores_by_name) -> float:
    """Average each name's score list, then average those per-name means."""
    per_name_means = [
        torch.tensor(scores).mean().item() for scores in scores_by_name.values()
    ]
    return torch.tensor(per_name_means).mean().item()


# Every name contributes equally, regardless of how many samples it has.
ours_cosine_similarities_mean = _mean_of_per_name_means(ours_cosine_similarities)
yans_cosine_similarities_mean = _mean_of_per_name_means(yans_cosine_similarities)

In [10]:
# Show the final (ours, yans) mean cosine similarities as the cell output.
ours_cosine_similarities_mean, yans_cosine_similarities_mean

(0.5907018184661865, 0.41483274102211)