# FSGAN Pipeline Setup & Testing

This notebook sets up and tests the pretrained FSGAN models for face swapping.

**Contents:**
1. Environment setup, imports, and helper functions
2. Load the pretrained segmentation and landmarks models and test them individually
3. Run the full face reenactment / face swap pipeline
4. Visualize the pipeline intermediates to inspect output quality

## 1. Environment Setup

In [None]:
!pip install ffmpeg-python

In [None]:
import sys, subprocess, importlib

# Map each importable module name to the pip distribution that provides it.
# The two differ for OpenCV: it is installed as `opencv-python` but imported as
# `cv2`, so probing for a module called `opencv_python` could never succeed.
pkgs = {
    'torch': 'torch',
    'torchvision': 'torchvision',
    'numpy': 'numpy',
    'cv2': 'opencv-python',
    'tqdm': 'tqdm',
    'matplotlib': 'matplotlib',
}
missing = []
for module_name, pip_name in pkgs.items():
    try:
        importlib.import_module(module_name)
    except Exception:  # broad on purpose: a broken install is treated like a missing one
        missing.append(pip_name)

if missing:
    print('Missing packages:', missing)
    print('Installing missing packages into the current Python environment. This may take a few minutes.')
    # `sys.executable -m pip` targets the interpreter that runs this kernel.
    cmd = [sys.executable, '-m', 'pip', 'install'] + missing
    subprocess.check_call(cmd)
    # Packages installed after interpreter start-up may be invisible to the
    # import system's cached directory listings until the caches are cleared;
    # the imports below run in this same process.
    importlib.invalidate_caches()
else:
    print('All minimal packages present')


import os
import traceback
from pathlib import Path
import torch
import torchvision
import numpy as np
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F


from fsgan.utils.utils import load_model
from fsgan.utils import img_utils, landmarks_utils
from fsgan.models.hrnet import hrnet_wlfw
from fsgan.inference import swap as swap_mod

from fsgan.notebook_helpers.reenact_preprocess import run_full_pipeline

import traceback
from pathlib import Path


# Paths used throughout the notebook (relative to the notebook's working directory).
ROOT = Path('.')
WEIGHTS_DIR = ROOT / 'fsgan' / 'weights'
OUT_DIR = ROOT / 'outputs'
OUT_DIR.mkdir(exist_ok=True)
# Sample images: each constant is named after the file it points to
# (outputs derived from IMG_A are saved with an `a_` prefix).
IMG_A = ROOT / 'a.jpg'
IMG_J = ROOT / 'j.jpg'
print('Weights directory:', WEIGHTS_DIR)
print('Expecting images at:', IMG_A, IMG_J)
print('Outputs will be saved to:', OUT_DIR)
# Warn early instead of failing later inside a model test cell.
for sample_img in (IMG_A, IMG_J):
    if not sample_img.is_file():
        print('Warning: sample image not found:', sample_img)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device', device)

def load_image_as_tensor(p, size=256, device=None):
    """Read an image file into a float RGB tensor of shape (1, 3, size, size).

    Pixel values are scaled to [0, 1]. The tensor is moved to ``device`` when
    one is given; otherwise it stays on the CPU.

    NOTE(review): FSGAN's own helpers (``img_utils.bgr2tensor``) offer a
    ``normalize`` option; confirm the models tested here accept [0, 1] input.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``p``.
    """
    bgr = cv2.imread(str(p))
    if bgr is None:
        raise FileNotFoundError(f'Image not found: {p}')
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    rgb_resized = cv2.resize(rgb, (size, size), interpolation=cv2.INTER_AREA)
    scaled = rgb_resized.astype('float32') / 255.0
    # HWC -> CHW, then add a leading batch axis.
    batch = torch.from_numpy(scaled).permute(2, 0, 1)[None]
    return batch if device is None else batch.to(device)

def read_bgr_tensor(p, device=device, size=256):
    """Read an image with OpenCV and convert it with ``img_utils.bgr2tensor(normalize=False)``.

    The image is resized to ``size`` x ``size`` unless ``size`` is None, and the
    resulting tensor is moved to ``device`` (the notebook-level device by default).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``p``.
    """
    image = cv2.imread(str(p))
    if image is None:
        raise FileNotFoundError(f'Image not found: {p}')
    resized = image if size is None else cv2.resize(image, (size, size), interpolation=cv2.INTER_AREA)
    tensor = img_utils.bgr2tensor(resized, normalize=False)
    return tensor.to(device)

def save_mask(mask, outpath):
    """Write ``mask`` to ``outpath`` with ``cv2.imwrite``.

    Raises:
        IOError: if OpenCV reports that the write failed (e.g. missing
            directory or unsupported extension). ``cv2.imwrite`` signals this
            only through its boolean return value, so it must be checked.
    """
    if not cv2.imwrite(str(outpath), mask):
        raise IOError(f'Failed to write mask to {outpath}')

print('Top-level imports and helpers are ready')

## 2. Test Individual Components

### Segmentation Model

In [None]:
seg_w = WEIGHTS_DIR / 'celeba_unet_256_1_2_segmentation_v2.pth'
out_seg_dir = OUT_DIR / 'segmentation'
out_seg_dir.mkdir(exist_ok=True)


def seg_prediction_to_mask(pred_np):
    """Convert a segmentation model output array into a 2-D uint8 grayscale mask.

    Multi-channel 4-D outputs (batch, classes, H, W) are reduced with argmax
    over the class axis for the first batch item. The class labels are then
    spread evenly over 0..255 so that every class gets a distinct gray level.

    Single-channel outputs are treated as a per-pixel probability map, clipped
    to [0, 1] and scaled by 255. This covers a 4-D tensor with one channel,
    a 3-D tensor (assumed to be (batch, H, W) — TODO confirm), or a 2-D map.

    Args:
        pred_np: numpy array produced by the segmentation model.

    Returns:
        2-D uint8 numpy array suitable for ``cv2.imwrite``.
    """
    if pred_np.ndim == 4 and pred_np.shape[1] > 1:
        n_classes = pred_np.shape[1]
        labels = pred_np[0].argmax(0)
        # Multiply in the integer label dtype (not uint8) so the result never
        # wraps around; the step keeps the largest label at or below 255.
        return (labels * (255 // (n_classes - 1))).astype('uint8')
    if pred_np.ndim == 4:
        prob = pred_np[0, 0]
    elif pred_np.ndim == 3:
        prob = pred_np[0]
    else:
        prob = pred_np
    # NOTE(review): assumes single-channel output is already a probability
    # (e.g. after sigmoid); raw logits would be clipped — confirm.
    return (np.clip(prob, 0.0, 1.0) * 255).astype('uint8')


try:
    model_seg = load_model(str(seg_w), 'segmentation', device=device)
    model_seg.eval()
    # NOTE(review): load_image_as_tensor yields RGB in [0, 1]; confirm this
    # matches the input normalization the segmentation model expects.
    t = load_image_as_tensor(IMG_A, size=256, device=device)
    with torch.no_grad():
        pred = model_seg(t)

    mask = seg_prediction_to_mask(pred.detach().cpu().numpy())
    out_path = out_seg_dir / 'a_seg_mask.png'
    save_mask(mask, out_path)
    print('✓ Segmentation saved to', out_path)
except Exception as e:
    print('✗ Segmentation test failed:', e)
    traceback.print_exc()

### Landmarks Model

In [None]:
lms_w = WEIGHTS_DIR / 'hr18_wflw_landmarks.pth'
out_lms_dir = OUT_DIR / 'landmarks'
out_lms_dir.mkdir(exist_ok=True)
try:
    model_lms = load_model(str(lms_w), 'landmarks', device=device)
    model_lms.eval()
    t = load_image_as_tensor(IMG_A, size=256, device=device)
    with torch.no_grad():
        out = model_lms(t)
    print('Landmarks model forward output shape:', getattr(out, 'shape', None))

    # Exporting a heatmap is best-effort: a failure here is reported but does
    # not count as a failure of the model test itself.
    try:
        out_np = out.detach().cpu().numpy()
        if out_np.ndim == 4:
            # First channel of the first batch item, min-max scaled to 0..255
            # (the epsilon guards against a constant map).
            hm = out_np[0, 0]
            hm = (255 * (hm - hm.min()) / (hm.max() - hm.min() + 1e-8)).astype('uint8')
            hm_path = out_lms_dir / 'a_landmark_heatmap_ch0.png'
            # cv2.imwrite reports failure only through its return value.
            if cv2.imwrite(str(hm_path), hm):
                print('✓ Saved example landmark heatmap to', hm_path)
            else:
                print('Could not save landmark heatmap: cv2.imwrite failed for', hm_path)
        else:
            print(f'Skipping heatmap export: expected a 4-D output, got {out_np.ndim}-D')
    except Exception as e:
        print('Could not save landmark heatmap:', e)
except Exception as e:
    print('✗ Landmarks test failed:', e)
    traceback.print_exc()

## 3. Full Reenactment Pipeline Test

Test the complete face reenactment pipeline with visualization.

In [None]:
%matplotlib inline

OUT_DIR = Path('outputs')
OUT_DIR.mkdir(exist_ok=True)
src = str(Path('input/tim.jpg'))
tgt = str(Path('input/tom.jpg'))
out_file = str(OUT_DIR / 'reenact_full_composited.png')

print(f"Running full pipeline: {src} → {tgt}")
result_bgr, intermediates, src_crop, tgt_crop = run_full_pipeline(
    src, tgt, 
    out_path=out_file, 
    reenact=True, 
    use_detector=True, 
    device=device,
    crop_scale=1.5, 
    resolution=256
)

print('✓ Saved full result to', out_file)

try:
    img = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(6,6))
    plt.axis('off')
    plt.imshow(img)
    plt.title('Face Swap Result')
    plt.show()
except Exception as e:
    print('Could not display image inline:', e)

## 4. Visualize Pipeline Intermediates

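Show the intermediate outputs collected from the pipeline run: source/target crops, the final composite, the raw reenactment output, and segmentation maps. Each panel is shown only when it is available.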

In [None]:
import math
from fsgan.utils.img_utils import tensor2bgr

def tensor_to_bgr_uint8(x):
    """Convert an image tensor or array to a uint8 array for display (BGR for color).

    Tensors:
      * a leading batch dimension is dropped (the first item is kept);
      * values within [-1.1, 1.1] are treated as normalized and converted with
        fsgan's ``tensor2bgr``;
      * otherwise the values are assumed to already be on a 0-255 scale in
        CHW / RGB order. They are clipped to [0, 255] and flipped to BGR.
        They are NOT multiplied by 255: any [0, 1] image already takes the
        branch above, so rescaling here would only saturate the image.

    Arrays (anything not a torch tensor):
      * uint8 arrays are returned as a copy;
      * float arrays whose values all lie in [0, 1] are scaled by 255 first;
      * values are clipped to [0, 255] before casting, so they never wrap around.
    """
    import numpy as _np
    import torch as _torch
    if isinstance(x, _torch.Tensor):
        t = x.detach().cpu()
        if t.ndim == 4:
            t = t[0]
        if t.min() >= -1.1 and t.max() <= 1.1:
            return tensor2bgr(t).astype('uint8')
        arr = t.float().numpy().transpose(1, 2, 0).clip(0, 255).astype('uint8')
        return arr[:, :, ::-1]

    arr = _np.array(x)  # always a copy, so the caller's data is never aliased
    if arr.dtype == _np.uint8:
        return arr
    is_float = _np.issubdtype(arr.dtype, _np.floating)
    arr = arr.astype('float64')
    if is_float and arr.size and arr.min() >= 0.0 and arr.max() <= 1.0:
        # NOTE(review): heuristic — a float image this dark on a 0-255 scale
        # would also be rescaled; acceptable for display purposes.
        arr = arr * 255.0
    return arr.clip(0, 255).astype('uint8')

def seg_to_color(seg_arr):
    """Render a 2-D label image as a (H, W, 3) uint8 color image.

    Labels 0-4 each get one row of the palette below (label 0 is black).
    Values outside that range are clipped to the nearest valid label.
    """
    import numpy as _np
    palette = _np.array(
        [[0, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 0], [255, 255, 0]],
        dtype='uint8',
    )
    height, width = seg_arr.shape
    clipped = _np.clip(seg_arr, 0, len(palette) - 1)
    colored = _np.zeros((height, width, 3), dtype='uint8')
    for label, color in enumerate(palette):
        colored[clipped == label] = color
    return colored

# Gather images to display
imgs = []
titles = []


def seg_to_label_map(seg):
    """Reduce a segmentation output (tensor or array) to a 2-D integer label map.

    4-D inputs are treated as (batch, classes, H, W) scores and reduced with
    argmax over the class axis for the first batch item. 2-D inputs are used
    as labels directly.

    Raises:
        ValueError: for any other rank, whose layout cannot be inferred here.
    """
    seg_np = seg.detach().cpu().numpy() if hasattr(seg, 'detach') else np.asarray(seg)
    if seg_np.ndim == 4:
        return seg_np.argmax(1)[0].astype('int')
    if seg_np.ndim == 2:
        return seg_np.astype('int')
    raise ValueError(f'Unsupported segmentation shape {seg_np.shape}')


def add_panel(image, title):
    """Queue ``image`` for the grid below; ``None`` images are reported and skipped."""
    if image is None:
        print(f'Skipping {title!r}: no image available')
        return
    imgs.append(image)
    titles.append(title)


# The crops, the final result and the intermediates all come from one call in
# the pipeline cell. A NameError means that cell did not run (or failed).
try:
    pipeline_panels = [
        (src_crop, 'Source crop'),
        (tgt_crop, 'Target crop'),
        (result_bgr, 'Final composited'),
    ]
    pipeline_intermediates = intermediates or {}
except NameError as e:
    print('Pipeline outputs unavailable (did the pipeline cell run?):', e)
    pipeline_panels, pipeline_intermediates = [], {}

for panel_image, panel_title in pipeline_panels:
    add_panel(panel_image, panel_title)

# Reenactment-only (generator output) if available
if 'reenact_tensor' in pipeline_intermediates:
    try:
        add_panel(tensor_to_bgr_uint8(pipeline_intermediates['reenact_tensor']), 'Reenact')
    except Exception as e:
        print('Could not convert reenactment output:', e)

# Segmentation maps, flipped to BGR like every other panel.
for seg_key, seg_title in (('reenact_seg', 'Reenact seg'), ('tgt_seg', 'Target seg')):
    if seg_key in pipeline_intermediates:
        try:
            seg_labels = seg_to_label_map(pipeline_intermediates[seg_key])
            add_panel(seg_to_color(seg_labels)[:, :, ::-1], seg_title)
        except Exception as e:
            print(f'Could not render {seg_key!r}:', e)

# Plot grid: up to four panels per row. Color panels are stored BGR and are
# flipped to RGB for matplotlib; 2-D panels are shown in grayscale.
panel_count = len(imgs)
if not panel_count:
    print('No intermediates available to display')
else:
    n_cols = min(4, panel_count)
    n_rows = math.ceil(panel_count / n_cols)
    fig = plt.figure(figsize=(4 * n_cols, 3 * n_rows))
    for idx, (panel, label) in enumerate(zip(imgs, titles), start=1):
        ax = fig.add_subplot(n_rows, n_cols, idx)
        ax.axis('off')
        if panel.ndim == 2:
            ax.imshow(panel, cmap='gray')
        else:
            ax.imshow(panel[:, :, ::-1])
        ax.set_title(label)
    fig.tight_layout()
    plt.show()

## Summary

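This notebook exercised the pretrained FSGAN pipeline. Check the ✓/✗ messages above and the files saved under `outputs/` to confirm that each step succeeded:
- Segmentation model (mask written to `outputs/segmentation/`)
- Landmarks model (example heatmap written to `outputs/landmarks/`)
- Full reenactment pipeline (result written to `outputs/reenact_full_composited.png`)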

**Next steps:**
- For multi-subject finetuning: Open `02_multisubject_finetuning.ipynb`
- For per-subject finetuning: Open `03_per_subject_finetuning.ipynb`