# fsgan quick model tests
This notebook contains small, self-contained tests that: load model weights from `fsgan/weights/`,
run a quick forward pass (or the provided inference pipeline) on two images `a.jpg` and `j.jpg` located
in the repository root, and write outputs to an `outputs/` folder.

Cells: 1) env & imports, 2) paths & helpers, 3) segmentation test, 4) landmarks test, 5) reenactment test, 6) full swap test.

In [1]:
# Cell 2: Check and (optionally) install minimal dependencies
import sys, subprocess, importlib
pkgs = ['torch', 'torchvision', 'numpy', 'opencv_python', 'tqdm', 'matplotlib']
missing = []
for p in pkgs:
    try:
        importlib.import_module(p)
    except Exception:
        missing.append(p)

if missing:
    print('Missing packages:', missing)
    print('Installing missing packages into the current Python environment. This may take a few minutes.')
    cmd = [sys.executable, '-m', 'pip', 'install'] + missing
    subprocess.check_call(cmd)
else:
    print('All minimal packages present')

Missing packages: ['opencv_python']
Installing missing packages into the current Python environment. This may take a few minutes.


In [2]:
# Cell 3: imports, paths and helpers
import os, sys, traceback
from pathlib import Path
import torch, cv2, numpy as np
from fsgan.utils.utils import load_model

ROOT = Path('.')
WEIGHTS_DIR = ROOT / 'fsgan' / 'weights'
OUT_DIR = ROOT / 'outputs'
OUT_DIR.mkdir(exist_ok=True)
IMG_A = ROOT / 'j.jpg'
IMG_J = ROOT / 'a.jpg'

print('Weights directory:', WEIGHTS_DIR)
print('Expecting images at:', IMG_A, IMG_J)
print('Outputs will be saved to:', OUT_DIR)

def load_image_as_tensor(p, size=256, device=None):
    im = cv2.imread(str(p))
    if im is None:
        raise FileNotFoundError(f'Image not found: {p}')
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    im = cv2.resize(im, (size, size), interpolation=cv2.INTER_AREA)
    im = im.astype('float32') / 255.0
    t = torch.from_numpy(im.transpose(2,0,1)).unsqueeze(0)
    if device is not None:
        t = t.to(device)
    return t

def save_mask(mask, outpath):
    # mask: HxW uint8 or 2D numpy
    cv2.imwrite(str(outpath), mask)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device', device)

Weights directory: fsgan\weights
Expecting images at: j.jpg a.jpg
Outputs will be saved to: outputs
Using device cuda:0


In [3]:
# Cell 4: Segmentation model quick test
seg_w = WEIGHTS_DIR / 'celeba_unet_256_1_2_segmentation_v2.pth'
out_seg_dir = OUT_DIR / 'segmentation'
out_seg_dir.mkdir(exist_ok=True)
try:
    model_seg = load_model(str(seg_w), 'segmentation', device=device)
    model_seg.eval()
    t = load_image_as_tensor(IMG_A, size=256, device=device)
    with torch.no_grad():
        pred = model_seg(t)
    # handle logits: if shape (B,C,H,W) -> argmax across channel
    pred_np = pred.detach().cpu().numpy()
    if pred_np.ndim == 4:
        mask = pred_np.argmax(1)[0].astype('uint8') * 85  # 3 classes -> scale for visualization
    else:
        # fallback: save raw single-channel output
        mask = (pred_np[0,0] * 255).astype('uint8')
    out_path = out_seg_dir / 'a_seg_mask.png'
    save_mask(mask, out_path)
    print('Segmentation saved to', out_path)
except Exception as e:
    print('Segmentation test failed:', e)
    traceback.print_exc()

=> Loading segmentation model: "celeba_unet_256_1_2_segmentation_v2.pth"...
Segmentation saved to outputs\segmentation\a_seg_mask.png
Segmentation saved to outputs\segmentation\a_seg_mask.png


In [4]:
# Cell 5: Landmarks model quick test (shape/forward check)
lms_w = WEIGHTS_DIR / 'hr18_wflw_landmarks.pth'
out_lms_dir = OUT_DIR / 'landmarks'
out_lms_dir.mkdir(exist_ok=True)
try:
    model_lms = load_model(str(lms_w), 'landmarks', device=device)
    model_lms.eval()
    t = load_image_as_tensor(IMG_A, size=256, device=device)
    with torch.no_grad():
        out = model_lms(t)
    print('Landmarks model forward output shape:', getattr(out, 'shape', None))
    # If output looks like heatmaps (B,C,H,W), save the first channel as an example heatmap
    try:
        out_np = out.detach().cpu().numpy()
        if out_np.ndim == 4:
            hm = out_np[0,0]
            hm = (255 * (hm - hm.min()) / (hm.max() - hm.min() + 1e-8)).astype('uint8')
            cv2.imwrite(str(out_lms_dir / 'a_landmark_heatmap_ch0.png'), hm)
            print('Saved example landmark heatmap to', out_lms_dir / 'a_landmark_heatmap_ch0.png')
    except Exception as e:
        print('Could not save landmark heatmap:', e)
except Exception as e:
    print('Landmarks test failed:', e)
    traceback.print_exc()

=> Loading landmarks model: "hr18_wflw_landmarks.pth"...
Landmarks model forward output shape: torch.Size([1, 98, 64, 64])
Saved example landmark heatmap to outputs\landmarks\a_landmark_heatmap_ch0.png
Landmarks model forward output shape: torch.Size([1, 98, 64, 64])
Saved example landmark heatmap to outputs\landmarks\a_landmark_heatmap_ch0.png


In [5]:
# Cell 6 (updated): Robust minimal reenactment forward (single frame)
import torch, cv2, torch.nn.functional as F, traceback
from fsgan.utils import utils, img_utils, landmarks_utils
from fsgan.models.hrnet import hrnet_wlfw
from pathlib import Path

# Use device already defined in cell 3 if present, else set one
try:
    device
except NameError:
    device, _ = utils.set_device(None)

# Paths (prefer the variables defined in cell 3 if available)
ROOT = Path('.')
reenact_w = ROOT / 'fsgan' / 'weights' / 'nfv_msrunet_256_1_2_reenactment_v2.1.pth'
lms_w = ROOT / 'fsgan' / 'weights' / 'hr18_wflw_landmarks.pth'
src_path = ROOT / 'a.jpg'
tgt_path = ROOT / 'j.jpg'
out_path = ROOT / 'outputs' / 'reenact_a_to_j.png'
out_path.parent.mkdir(parents=True, exist_ok=True)

print('Device:', device)
print('Reenact model:', reenact_w)
print('Landmarks model:', lms_w)
print('Source:', src_path, 'Target:', tgt_path)

# Helper: read BGR image and convert to float tensor in [0,1]
def read_bgr_tensor(p, device=device, size=256):
    im = cv2.imread(str(p))
    if im is None:
        raise FileNotFoundError(f'Image not found: {p}')
    # Resize to a consistent base resolution (models trained on 256x256)
    if size is not None:
        im = cv2.resize(im, (size, size), interpolation=cv2.INTER_AREA)
    return img_utils.bgr2tensor(im, normalize=False).to(device)

# Load reenactment generator (this checkpoint should contain 'arch' + 'state_dict')
try:
    Gr, ckpt = utils.load_model(str(reenact_w), 'reenactment', device=device, return_checkpoint=True)
    Gr.eval()
    # print architecture if available
    try:
        print('Checkpoint arch:', ckpt.get('arch', None))
    except Exception:
        pass
except Exception as e:
    print('Failed loading reenactment model:', e)
    traceback.print_exc()
    raise

# Load landmarks model: try load_model first, fallback to hrnet_wlfw + raw state_dict
try:
    L = None
    try:
        L = utils.load_model(str(lms_w), 'landmarks', device=device)
        print('Loaded landmarks using utils.load_model()')
    except AssertionError:
        # checkpoint is probably a raw state_dict -> fallback
        print('Landmarks file looks like raw state_dict, falling back to hrnet factory')
    if L is None:
        L = hrnet_wlfw().to(device)
        state_dict = torch.load(str(lms_w), map_location=device)
        L.load_state_dict(state_dict)
    L.eval()
except Exception as e:
    print('Failed loading landmarks model:', e)
    traceback.print_exc()
    raise

# Read images as tensors [0,1] and resize to model base resolution
base_res = 256
try:
    src_t = read_bgr_tensor(src_path, device=device, size=base_res)
    tgt_t = read_bgr_tensor(tgt_path, device=device, size=base_res)
except Exception as e:
    print('Image loading error:', e)
    raise

# Determine the correct pyramid levels to match the generator's n_local_enhancers
# The model expects len(pyd) == n_local_enhancers + 1
try:
    n_local = getattr(Gr, 'n_local_enhancers', None)
    if n_local is None and hasattr(Gr, 'module'):
        n_local = getattr(Gr.module, 'n_local_enhancers', None)
    if n_local is None:
        # fallback to 1 (common default)
        n_local = 1
    required_levels = int(n_local) + 1
    print('Generator n_local_enhancers =', n_local, '; building pyramid with', required_levels, 'levels')
except Exception as e:
    print('Could not determine generator n_local_enhancers:', e)
    required_levels = 3

# Create pyramid with the required number of levels
src_pyd = img_utils.create_pyramid(src_t, n=required_levels)
print('Pyramid levels:', len(src_pyd), 'shapes:', [p.shape for p in src_pyd])

# Normalization tensors (training uses these values)
img_mean = torch.as_tensor([0.5,0.5,0.5], device=device).view(1,3,1,1)
img_std  = torch.as_tensor([0.5,0.5,0.5], device=device).view(1,3,1,1)
context_mean = torch.as_tensor([0.485,0.456,0.406], device=device).view(1,3,1,1)
context_std  = torch.as_tensor([0.229,0.224,0.225], device=device).view(1,3,1,1)

# Compute landmarks context from the target (hrnet expects ImageNet-normalized input)
with torch.no_grad():
    tgt_for_lms = (tgt_t - context_mean) / context_std
    context = L(tgt_for_lms)
    context = landmarks_utils.filter_landmarks(context)

# Normalize source pyramid to [-1,1] as training did
for i in range(len(src_pyd)):
    src_pyd[i] = (src_pyd[i] - img_mean) / img_std

# Build generator input: for each pyramid level concatenate image + resized context (training order preserved)
inp = []
for p in range(len(src_pyd)-1, -1, -1):
    c = F.interpolate(context, size=src_pyd[p].shape[2:], mode='bicubic', align_corners=False)
    inp.insert(0, torch.cat((src_pyd[p], c), dim=1))

print('Input pyramid shapes for generator:', [x.shape for x in inp])

# Run generator and save output (handle single-tensor and list outputs)
with torch.no_grad():
    try:
        reenact_img = Gr(inp)
    except Exception as e:
        print('Generator forward failed:', e)
        traceback.print_exc()
        raise

# If model returns a list/pyramid, take the last (highest res) or the tensor itself
if isinstance(reenact_img, (list, tuple)):
    out_tensor = reenact_img[-1]
else:
    out_tensor = reenact_img

# Move to CPU and convert to BGR uint8 for saving
out_bgr = img_utils.tensor2bgr(out_tensor[0].cpu())
cv2.imwrite(str(out_path), out_bgr)
print('Saved reenactment to:', out_path)


Device: cuda:0
Reenact model: fsgan\weights\nfv_msrunet_256_1_2_reenactment_v2.1.pth
Landmarks model: fsgan\weights\hr18_wflw_landmarks.pth
Source: a.jpg Target: j.jpg
=> Loading reenactment model: "nfv_msrunet_256_1_2_reenactment_v2.1.pth"...
Checkpoint arch: res_unet.MultiScaleResUNet(in_nc=101,out_nc=3,flat_layers=(2,2,2,2),ngf=128)
=> Loading landmarks model: "hr18_wflw_landmarks.pth"...
Checkpoint arch: res_unet.MultiScaleResUNet(in_nc=101,out_nc=3,flat_layers=(2,2,2,2),ngf=128)
=> Loading landmarks model: "hr18_wflw_landmarks.pth"...
Loaded landmarks using utils.load_model()
Generator n_local_enhancers = 1 ; building pyramid with 2 levels
Pyramid levels: 2 shapes: [torch.Size([1, 3, 256, 256]), torch.Size([1, 3, 128, 128])]
Input pyramid shapes for generator: [torch.Size([1, 101, 256, 256]), torch.Size([1, 101, 128, 128])]
Loaded landmarks using utils.load_model()
Generator n_local_enhancers = 1 ; building pyramid with 2 levels
Pyramid levels: 2 shapes: [torch.Size([1, 3, 256, 25

In [6]:
# Cell 6b: Full pipeline inline (preprocessing + segmentation + inpainting + blending)
from pathlib import Path
from fsgan.notebook_helpers.reenact_preprocess import run_full_pipeline

OUT_DIR = Path('outputs')
OUT_DIR.mkdir(exist_ok=True)
src = str(Path('a.jpg'))
tgt = str(Path('j.jpg'))
out_file = str(OUT_DIR / 'reenact_full_composited.png')
print('Running full pipeline:', src, '->', tgt)
# This may take some time while models are loaded onto the device
result_bgr, intermediates, src_crop, tgt_crop = run_full_pipeline(src, tgt, out_path=out_file)
print('Saved full composited result to', out_file)
# If running inside notebook show via matplotlib (optional)
try:
    import matplotlib.pyplot as plt, cv2
    img = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(6,6)); plt.axis('off'); plt.imshow(img); plt.show()
except Exception as e:
    print('Could not display image inline:', e)

Running full pipeline: a.jpg -> j.jpg
=> Loading reenactment model: "nfv_msrunet_256_1_2_reenactment_v2.1.pth"...
=> Loading landmarks model: "hr18_wflw_landmarks.pth"...
=> Loading landmarks model: "hr18_wflw_landmarks.pth"...




=> Loading segmentation model: "celeba_unet_256_1_2_segmentation_v2.pth"...
=> Loading completion model: "ijbc_msrunet_256_1_2_inpainting_v2.pth"...
=> Loading completion model: "ijbc_msrunet_256_1_2_inpainting_v2.pth"...
=> Loading blending model: "ijbc_msrunet_256_1_2_blending_v2.pth"...
=> Loading blending model: "ijbc_msrunet_256_1_2_blending_v2.pth"...


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Saved full composited result to outputs\reenact_full_composited.png


  plt.figure(figsize=(6,6)); plt.axis('off'); plt.imshow(img); plt.show()


In [7]:
# Cell 7: Full swap pipeline quick test (this will run reenact+segmentation+inpainting+blending)
from fsgan.inference import swap as swap_mod
out_swap = OUT_DIR / 'swap'
out_swap.mkdir(exist_ok=True)
try:
    swap_mod.main([str(IMG_A)], [str(IMG_J)], output=str(out_swap))
    print('Swap pipeline completed. Check', out_swap)
except Exception as e:
    print('Swap pipeline failed:', e)
    traceback.print_exc()



=> using GPU devices: 0
=> Loading face pose model: "hopenet_robust_alpha1.pth"...
=> Loading face landmarks model: "hr18_wflw_landmarks.pth"...
=> Loading face landmarks model: "hr18_wflw_landmarks.pth"...
=> Loading face segmentation model: "celeba_unet_256_1_2_segmentation_v2.pth"...
=> Loading face segmentation model: "celeba_unet_256_1_2_segmentation_v2.pth"...
=> Loading face reenactment model: "nfv_msrunet_256_1_2_reenactment_v2.1.pth"...
=> Loading face reenactment model: "nfv_msrunet_256_1_2_reenactment_v2.1.pth"...
=> Loading face completion model: "ijbc_msrunet_256_1_2_inpainting_v2.pth"...
=> Loading face completion model: "ijbc_msrunet_256_1_2_inpainting_v2.pth"...
=> Loading face blending model: "ijbc_msrunet_256_1_2_blending_v2.pth"...
=> Loading face blending model: "ijbc_msrunet_256_1_2_blending_v2.pth"...


  _C._set_default_tensor_type(t)


=> Detecting faces in video: "j.jpg..."


100%|██████████| 1/1 [00:00<00:00,  4.18frames/s]



=> Extracting sequences from detections in video: "j.jpg"...


100%|██████████| 2/2 [00:00<00:00, 13189.64it/s]

=> Cropping image sequences from image: "j.jpg"...
=> Computing face poses for video: "j_seq00.jpg"...
Swap pipeline failed: [WinError 2] Le fichier spécifié est introuvable



Traceback (most recent call last):
  File "C:\Users\Arthur\AppData\Local\Temp\ipykernel_16692\2496922654.py", line 6, in <module>
    swap_mod.main([str(IMG_A)], [str(IMG_J)], output=str(out_swap))
    ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Arthur\Documents\Github\fsgan\fsgan\inference\swap.py", line 498, in main
    face_swapping(source[0], target[0], output, select_source, select_target)
    ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Arthur\Documents\Github\fsgan\fsgan\inference\swap.py", line 239, in __call__
    source_cache_dir, source_seq_file_path, _ = self.cache(source_path)
                                                ~~~~~~~~~~^^^^^^^^^^^^^
  File "c:\Users\Arthur\Documents\Github\fsgan\fsgan\preprocess\preprocess_video.py", line 469, in cache
    self.process_pose(input_path, output_dir, seq_file_path)
    ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Art