# Gallery Filter Network Demo
This notebook implements a demo showcasing the SeqNeXt person search model and Gallery Filter Network (GFN) scoring process.

The notebook loads images from the web, and you can easily try it out on other image URLs.

All dependencies are imported below, and the model is loaded via TorchScript; our package is deliberately not used, so that the demo is (somewhat) self-contained.

In [None]:
# Torch libs
import torch
## Disable TorchScript JIT fusion for now. All four paths below are turned off:
## the CPU and GPU fusion overrides, the TensorExpr fuser, and nvFuser (not just
## nvFuser). These are private torch._C hooks and may change between PyTorch versions.
## NOTE(review): presumably this works around fuser problems when running the
## scripted model -- confirm the original reason before removing.
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(False)
import torch.nn.functional as F
from tqdm import tqdm

# Libs for data pre-processing
import cv2
import numpy as np
from albumentations.augmentations.geometric import functional as FGeometric
import torchvision.transforms.functional as TF

# Libs for loading images
import os
import urllib.request
from PIL import Image
import ssl
## Avoid SSL certificate-verification errors when downloading the demo images.
## WARNING: this replaces the process-wide default HTTPS context factory with an
## unverified one, so stdlib HTTPS clients that rely on the default context
## (e.g. urllib.request below) no longer verify server certificates.
## Acceptable for a throwaway demo only; do not copy into production code.
ssl._create_default_https_context = ssl._create_unverified_context

# Libs for visualization
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

## Helper Functions

In [None]:
# Download images from the web
def load_web_image(
        demo_image_dir='./demo_images',
        image_url='https://pics.filmaffinity.com/Friends_TV_Series-783758929-large.jpg',
        file_name='demo.png',
        display=True,
    ):
    """Download an image to disk and return it as an RGB PIL image.

    The file is downloaded on every call; an existing file with the same name
    is overwritten rather than reused.

    Args:
        demo_image_dir: Directory the downloaded file is written to (created if
            it does not exist yet).
        image_url: URL of the image to fetch.
        file_name: Name of the file written inside ``demo_image_dir``.
        display: If True, show the image with matplotlib.

    Returns:
        The image converted to RGB mode.
    """
    # Make dir to store the images (no-op if it already exists; avoids the
    # check-then-create race of os.path.exists + os.makedirs)
    os.makedirs(demo_image_dir, exist_ok=True)
    image_path = os.path.join(demo_image_dir, file_name)

    # Download the image URL from the web
    urllib.request.urlretrieve(image_url, image_path)

    # Load the image from disk. convert() returns an independent copy, so the
    # context manager can close the underlying file handle right away.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert('RGB')

    # Plot the image
    if display:
        _, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(img)
        plt.show()

    return img

In [None]:
# Convert PIL image to torch tensor
def to_tensor(image):
    """Turn a PIL image into a normalized, resized, channels-first model input.

    Pipeline: PIL image -> numpy array (rows, cols, 3) -> ``window_resize`` ->
    float tensor -> ImageNet ``normalize`` -> permuted to (3, H, W) -> moved
    to the module-level ``device``.
    """
    resized = window_resize(np.array(image))
    normalized = normalize(torch.FloatTensor(resized))
    return normalized.permute(2, 0, 1).to(device)

# Convert torch tensor to PIL image
def to_image(tensor):
    """Convert a normalized (3, H, W) image tensor back into an RGB PIL image.

    Undoes the ImageNet standardization via ``denormalize``, clips to [0, 1]
    and rescales to 8-bit. Any resizing done by ``to_tensor`` is not undone,
    so the result has the tensor's spatial size.
    """
    tsr_denorm = denormalize(tensor.permute(1, 2, 0).cpu()).clip(min=0, max=1)
    arr = tsr_denorm.numpy()
    # Round instead of truncating: the normalize/denormalize round trip leaves
    # small float error (e.g. 254.9999 instead of 255), which astype alone
    # would floor to the wrong integer. After the clip, values stay in [0, 255].
    arr_uint8 = np.rint(arr * 255.0).astype(np.uint8)
    image = Image.fromarray(arr_uint8)
    return image

In [None]:
# Normalize image tensor using ImageNet stats
def normalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Map a 0-255 channels-last image tensor to standardized values.

    Computes ``(tensor / 255 - mean) / std``. ``mean`` and ``std`` must each
    have 3 entries and are broadcast over the last (channel) axis.
    """
    channel_mean, channel_std = (torch.FloatTensor(stat).view(1, 1, 3) for stat in (mean, std))
    centered = tensor.div(255.0) - channel_mean
    return centered / channel_std

# Denormalize image tensor using ImageNet stats
def denormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo the per-channel standardization of a channels-last tensor.

    Returns ``tensor * std + mean``. ``mean`` and ``std`` must each have 3
    entries. The division by 255 done in ``normalize`` is NOT undone, so
    results are on a [0, 1] scale (up to float error).
    """
    channel_mean, channel_std = (torch.FloatTensor(stat).view(1, 1, 3) for stat in (mean, std))
    return tensor.mul(channel_std) + channel_mean

In [None]:
# Resize image (numpy array) to fit in fixed size window
def window_resize(img, min_size=900, max_size=1500, interpolation=cv2.INTER_LINEAR):
    """Resize ``img`` (array with rows, cols as its first two dims) into a window.

    The preferred result has its shorter side equal to ``min_size``. If that
    scaling would push the longer side past ``max_size``, the image is instead
    scaled so that its longer side equals ``max_size``.
    """
    short_side, long_side = sorted(img.shape[:2])
    scale = min_size / short_side
    if long_side * scale <= max_size:
        return FGeometric.smallest_max_size(img, max_size=min_size, interpolation=interpolation)
    return FGeometric.longest_max_size(img, max_size=max_size, interpolation=interpolation)

In [None]:
# Plot detected boxes on image with matplotlib
def show_detects(image, detect, person_sim=None, show_detect_score=False, ax=None, title=None, xlabel=None):
    """Draw detected boxes, with optional per-box labels, over a normalized image tensor.

    Args:
        image: Normalized (3, H, W) image tensor, as produced by ``to_tensor``.
        detect: Detection dict with ``'det_boxes'`` (one ``(x1, y1, x2, y2)``
            row per box) and ``'det_scores'`` (one score per box).
        person_sim: Optional per-box similarity tensor. When given,
            ``person_sim[i]`` is drawn above box ``i``.
        show_detect_score: If True and ``person_sim`` is None, draw each box's
            detection score above it.
        ax: Matplotlib axes to draw on. A new 6x6 figure is created when None.
        title: Optional bold axes title.
        xlabel: Optional x-axis label.
    """
    # Setup subplot
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 6))
    # Setup labels
    if title is not None:
        ax.set_title(title, fontsize=20, fontweight='bold')
    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=20)
    # Show the image. Clip to [0, 1] (as to_image does): denormalization float
    # error otherwise makes matplotlib clip the RGB data itself and warn.
    ax.imshow(denormalize(image.permute(1, 2, 0).cpu()).clip(min=0, max=1))
    # Plot boxes (and optionally similarity or detection scores)
    boxes = detect['det_boxes'].cpu().tolist()
    scores = detect['det_scores'].cpu().tolist()
    for i, ((x, y, x2, y2), score) in enumerate(zip(boxes, scores)):
        w, h = x2 - x, y2 - y
        ax.add_patch(Rectangle((x, y), w, h, edgecolor='green', lw=4, fill=False, alpha=0.8))
        ax.add_patch(Rectangle((x+2, y+2), w-4, h-4, edgecolor='whitesmoke', lw=1, fill=False, alpha=0.8))
        ## Person similarity takes precedence over the detection score
        if person_sim is not None:
            label, face_color = '{:.2f}'.format(person_sim[i].item()), "whitesmoke"
        elif show_detect_score:
            label, face_color = '{:.2f}'.format(score), "lightblue"
        else:
            continue
        ax.text(x, y, label, ha="left", va="bottom", size=14,
            bbox=dict(boxstyle="square,pad=0.2", fc=face_color, alpha=0.8, ec='black', lw=2.0)
        )
    # Remove ticks and thicken the axes border
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_linewidth(3)

# Return list of detected (PIL) image crops
def get_crops(tensor, detect, ax=None):
    """Cut one PIL crop per detected box out of a normalized image tensor.

    Args:
        tensor: Normalized image tensor, converted back with ``to_image``.
        detect: Detection dict whose ``'det_boxes'`` rows are ``(x1, y1, x2, y2)``.
        ax: Accepted for call compatibility; not used.

    Returns:
        List of PIL crops, in the same order as ``detect['det_boxes']``.
    """
    full_image = to_image(tensor)
    return [
        full_image.crop((x1, y1, x2, y2))
        for x1, y1, x2, y2 in detect['det_boxes'].cpu().tolist()
    ]

# Show ranked re-id crops
def show_reid(crop_list, score_list, plot_width=0.75):
    """Plot the query crop followed by gallery crops ranked by decreasing score.

    Args:
        crop_list: PIL crops. ``crop_list[0]`` is the query crop and
            ``crop_list[1:]`` are gallery crops, where ``crop_list[1 + k]`` is
            scored by ``score_list[k]``.
        score_list: One similarity score per gallery crop.
        plot_width: Width in inches of each crop panel. Panels are 3x as tall.

    Returns:
        The matplotlib Figure.
    """
    # Sort gallery crops by decreasing score. Slice the gallery crops once
    # instead of re-slicing crop_list for every element.
    gallery_crops = crop_list[1:]
    sorted_score_idx = np.argsort(score_list)[::-1]
    sorted_score_list = [score_list[i] for i in sorted_score_idx]
    sorted_crop_list = [gallery_crops[i] for i in sorted_score_idx]

    # Plotting helper: title a panel, strip its ticks, thicken its border
    def _plot_subplot(_ax, title='', fw=None):
        _ax.set_title(title, fontweight=fw, fontsize=11)
        _ax.set_xticks([])
        _ax.set_yticks([])
        for spine in _ax.spines.values():
            spine.set_linewidth(2)

    # One panel per crop. squeeze=False keeps the axes array indexable even
    # when there is only a single panel (a query with no gallery crops).
    num_panels = len(crop_list)
    fig, axes = plt.subplots(nrows=1, ncols=num_panels, squeeze=False,
                             figsize=(plot_width*num_panels, plot_width*3))
    ax = axes[0]
    # Plot query crop
    ax[0].imshow(crop_list[0].resize((100, 300)))
    _plot_subplot(ax[0], title='Query', fw='bold')
    # Plot gallery crops in ranked order
    for i, (crop, score) in enumerate(zip(sorted_crop_list, sorted_score_list), 1):
        ax[i].imshow(crop.resize((100, 300)))
        _plot_subplot(ax[i], title='s={:.2f}'.format(score))
    # Return fig
    return fig

## User Inputs

In [None]:
# Parameters
torchscript_path = '../torchscript/cuhk_final_convnext-base_e30.torchscript.pt'

# Device: use the GPU when one is available, otherwise fall back to the CPU so
# the demo can still run (slowly) on CPU-only machines.
# NOTE(review): the TorchScript file must also be loadable on the chosen device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Query data: one scene per character to search for
query_list = [
    dict(
        # Chandler
        url='https://i.redd.it/uqrvlpf667j51.jpg',
        file='chandler_query.jpg',
    ),
    dict(
        # Ross
        url='https://metro.co.uk/wp-content/uploads/2019/04/dr3-f247.jpg?quality=90&strip=all&zoom=1&resize=644%2C367',
        file='ross_query.jpg',
    ),
]

# Gallery data: scenes searched for every query
gallery_list = [
    dict(
        url='https://deadline.com/wp-content/uploads/2022/10/Screenshot-2022-10-22-at-09.45.12.png?w=681&h=383&crop=1',
        file='friends_gallery3.jpg',
    ),
    dict(
        url='https://static.wikia.nocookie.net/friends/images/0/03/TOWChandler%27sWorkLaugh.png/revision/latest/scale-to-width-down/1000?cb=20180307145208',
        file='friends_gallery4.jpg',
    ),
]

## Image Loading

In [None]:
# Load queries: download each query scene, then center-crop it to a square
# (side = shorter image edge) so the query panels display nicely
print('Loading query images:')
query_image_list = []
for query_entry in query_list:
    downloaded = load_web_image(image_url=query_entry['url'], file_name=query_entry['file'])
    query_image_list.append(TF.center_crop(downloaded, min(downloaded.size)))

# Load gallery scenes as-is
print('Loading gallery images:')
gallery_image_list = [
    load_web_image(image_url=entry['url'], file_name=entry['file'])
    for entry in gallery_list
]

## Model Loading

Loading the TorchScript model takes just a single function call: no extra libraries or config files are needed.

In [None]:
# Load torchscript version of model. map_location remaps the saved storages
# onto the selected device, so a checkpoint saved from a GPU can still be
# loaded when `device` is the CPU.
model = torch.jit.load(torchscript_path, map_location=device)

## Run Model on Query Scenes: Get Query Person and Scene Embeddings

There are two ways to get query embeddings:
1. Let the model try to detect the query bounding box, and use the detected embedding.
2. Input the query bounding box explicitly. If the query scene just has one person, you can also supply the full scene extent as the bounding box.

Both methods 1. and 2. are run below as an example, but only the detected embedding is used later in the demo.

In [None]:
# Put query samples through the model
query_output_list = []
with torch.no_grad():
    query_tensor_list = [to_tensor(query_image) for query_image in query_image_list]
    for query_tensor in query_tensor_list:
        ## Use the full scene extent as the query bounding box. Boxes are
        ## (x1, y1, x2, y2) -- the layout show_detects unpacks from 'det_boxes' --
        ## so the box is (0, 0, width, height). The tensor is (C, H, W), so
        ## width is shape[2] and height is shape[1].
        _, scene_height, scene_width = query_tensor.shape
        query_box = torch.FloatTensor([0, 0, scene_width, scene_height]).unsqueeze(0).to(device)
        query_targets = [{'boxes': query_box}]
        ## Run query scene through the model. Per the notes above,
        ## inference_mode='both' yields both detected and supplied-box embeddings.
        detections = model([query_tensor], query_targets, inference_mode='both')
        ## One input scene -> keep its (single) corresponding output
        for query_detect in detections[:1]:
            ## Show detected boxes with their scores
            show_detects(query_tensor, query_detect, show_detect_score=True)
            query_output_list.append(query_detect)

## Run Model on Gallery Scenes: Get Gallery Person and Scene Embeddings

In [None]:
# Put gallery samples through the model
gallery_output_list = []
with torch.no_grad():
    gallery_tensor_list = list(map(to_tensor, gallery_image_list))
    ## Run all gallery scenes through the model in one batched call
    detections = model(gallery_tensor_list, inference_mode='det')
    ## Show each scene's detected boxes with their scores and keep the outputs
    for gallery_tensor, gallery_detect in zip(gallery_tensor_list, detections):
        show_detects(gallery_tensor, gallery_detect, show_detect_score=True)
        gallery_output_list.append(gallery_detect)

## Get Re-ID Scores: Compare Person Embeddings 

In [None]:
# Re-ID scores: for every (query, gallery scene) pair, take the cosine similarity
# between each query detection embedding and each gallery detection embedding,
# flatten it, and append it under gallery_output_dict['person_sim'] (one entry
# per query, in query order).
for query_detect in query_output_list:
    query_unit_emb = F.normalize(query_detect['det_emb'], dim=1)
    for gallery_output_dict in gallery_output_list:
        gallery_unit_emb = F.normalize(gallery_output_dict['det_emb'], dim=1)
        ## (num query dets x num gallery dets) similarity matrix, flattened row-major
        person_sim = (query_unit_emb @ gallery_unit_emb.T).flatten()
        gallery_output_dict.setdefault('person_sim', []).append(person_sim)

## Get GFN Scores: Compare Query-Gated Scene Embeddings

In [None]:
# GFN scores: for every (query, gallery scene) pair, score the gallery scene
# embedding against the query's person and scene embeddings with the model's
# GFN head, and append the resulting float under gallery_output_dict['gfn_sim']
# (one entry per query, in query order).
with torch.no_grad():
    for query_detect in query_output_list:
        query_person_emb, query_scene_emb = query_detect['det_emb'], query_detect['scene_emb']
        for gallery_output_dict in gallery_output_list:
            gfn_scores = model.gfn.get_scores(query_person_emb, query_scene_emb, gallery_output_dict['scene_emb'])
            ## .item() requires the GFN output to hold exactly one value
            gallery_output_dict.setdefault('gfn_sim', []).append(gfn_scores.flatten().item())

## Display Search and Re-ID Results

In [None]:
# Each detection figure shows the query scene followed by every gallery scene,
# with panel widths proportional to the (resized) image widths. Gallery widths
# do not depend on the query, so collect them once.
gallery_width_list = [g.shape[2] for g in gallery_tensor_list]

# For each query image
for query_idx, (query_tensor, query_detect) in enumerate(zip(query_tensor_list, query_output_list)):
    # Width ratios use this query's own width, so queries of different sizes are
    # laid out correctly (matplotlib treats width_ratios as relative weights).
    width_list = [query_tensor.shape[2]] + gallery_width_list
    kw = dict(width_ratios=width_list)

    # Initialize subplots. squeeze=False keeps the axes array indexable even
    # when there is only one panel.
    detect_fig, axes = plt.subplots(nrows=1, ncols=len(width_list), figsize=(20, 12), gridspec_kw=kw, squeeze=False)
    ax = axes[0]

    # Show query detects
    show_detects(query_tensor, query_detect, ax=ax[0], title='Query')

    # Get query crops. Copy the list so that extending it with gallery crops
    # below does not mutate query_crop_list.
    query_crop_list = get_crops(query_tensor, query_detect)
    full_crop_list, full_score_list = list(query_crop_list), []

    # For each gallery image
    for i, (gallery_tensor, gallery_output_dict) in enumerate(zip(gallery_tensor_list, gallery_output_list)):
        ## Person similarity of each gallery detection to this query
        person_sim = gallery_output_dict['person_sim'][query_idx]

        ## GFN (query-gated scene) score of this gallery scene for this query
        qg_scene_sim = gallery_output_dict['gfn_sim'][query_idx]

        ## Show gallery detects, labelling each box with its person similarity
        show_detects(gallery_tensor, gallery_output_dict, ax=ax[i+1], person_sim=person_sim,
            title='Gallery {}'.format(i+1), xlabel='GFN Score: {:.2f}'.format(qg_scene_sim))

        ## Collect gallery crops with their person-similarity scores.
        ## NOTE(review): show_reid pairs crop_list[1:] with these scores, and
        ## person_sim is flattened over all query detections, so this assumes
        ## the query scene yields exactly one detection -- confirm for new queries.
        full_crop_list.extend(get_crops(gallery_tensor, gallery_output_dict))
        full_score_list.extend(person_sim.tolist())

    # Save detect fig
    detect_fig.tight_layout()
    detect_fig.savefig(f'query_detect{query_idx}.png', bbox_inches='tight', dpi=100)

    # Show ranked re-id crops
    reid_fig = show_reid(full_crop_list, full_score_list)

    # Save re-id fig with no horizontal gap between crop panels
    reid_fig.tight_layout()
    reid_fig.subplots_adjust(wspace=0.0)
    reid_fig.savefig(f'query_reid{query_idx}.png', bbox_inches='tight', dpi=100)

# Display all figures created above
plt.show()