# Introduction
Examining shoeprint evidence found at a crime scene assists investigators in identifying suspects of the crime. We introduce a tool to retrieve the closest-matching shoe models for query crime-scene prints. The tool uses CriSp in its backend to match the crime-scene print to a large-scale database of tread depth maps. The details of this method can be found on its [project page](https://github.com/Samia067/Crisp) and in the paper [CriSp: Leveraging Tread Depth Maps for Enhanced Crime-Scene Shoeprint Matching](https://arxiv.org/abs/2404.16972), currently under review at [ECCV 2024](https://eccv.ecva.net/).

**How the retrieval works**:
We precompute and store features for the depth maps in the reference database.
When a query crime-scene shoeprint and a corresponding mask are uploaded to the system, we compute the query's features using CriSp. We mask out irrelevant portions of both the query and database features with the provided mask. Similarity between the query and database images is computed using cosine similarity, and a ranked list is generated. 

**Reference Database statistics**:
The reference database used contains depth maps from 56,847 shoe tread instances and 24,766 shoe models. 

# How to use this tool
Steps to make a query:
1. Run the code. On the first run, wait for up to 5 minutes while the model and dataset are downloaded. 
2. After the download is complete, you will see a box where you can specify the number of results you want to see for each query, a file uploader for the query print, and another for the mask. If you do not have a preference for the number of results to view, you can leave it at the default value of 10. 
3. Select a query print using the button labeled 'Upload Print'. The print image should be 384 pixels wide and 192 pixels tall; an image of any other size will be resized to this shape. Before uploading, align the query print so that it lies in the correct position within the yellow shoe outline that is drawn over it in the results. 
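4. Next, upload a binary mask for the print (same size as the print). Only the visible portions of the print should be white and the rest black (0). In an 8-bit grayscale image, white means a pixel value of 255, which corresponds to 1 after normalization.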
5. Once you select the query print and mask, you can hit 'Search' and you will see a ranked list of results. 

Example query prints

# Contact

Please feel free to email me at sshafiqu [at] uci [dot] edu if you have any questions. Author list:

<table>
<tr>
<td> <img src="git/figures/samia.jpg" alt="Drawing" style="width: 145px;"/> </td>
<td>
    <p>Samia Shafique</p>
    <p>PhD Student</p>
    <p>University of California, Irvine</p>
</td>

</tr>
<tr>
<td> <img src="git/figures/charless.jpg" alt="Drawing" style="width: 145px;"/> </td>
<td>
    <p>Charless Fowlkes</p>
    <p>Professor</p>
    <p>University of California, Irvine</p>
</td>
</tr>
</table>



In [4]:
#@title ##← Click on the circled arrow and wait for up to 5 minutes

import sys
# !{sys.executable} -m pip install torch==1.11.0+cu102  --extra-index-url https://download.pytorch.org/whl/cu102



import torch
import torch.nn as nn
import torch.nn.functional as F

import requests
import numpy as np
import os
import urllib.request
import numpy as np

from IPython.display import display, Markdown, HTML, clear_output
import ipywidgets as widgets
import io
from PIL import Image
import zipfile
import csv

class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand.

    The block maps ``in_planes`` channels to ``planes * expansion`` channels.
    When ``is_last`` is True, ``forward`` returns ``(output, preact)``, where
    ``preact`` is the residual sum before the final ReLU. Otherwise it returns
    only the output.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super().__init__()
        expanded = self.expansion * planes
        self.is_last = is_last
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)

        # Identity shortcut, unless the block changes resolution or width.
        # In that case a strided 1x1 projection makes the shapes match.
        if stride == 1 and in_planes == expanded:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, expanded, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(expanded),
            )

    def forward(self, x):
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = F.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        preact = branch + self.shortcut(x)
        activated = F.relu(preact)
        if not self.is_last:
            return activated
        return activated, preact

class ResNet(nn.Module):
    """ResNet backbone with a configurable number of input channels.

    The stem conv and all four stages (including ``layer1``) use stride 2.
    The stem ``maxpool`` is defined but never applied in ``forward``. The
    overall downsampling factor is therefore 32.

    Args:
        block: residual block class (e.g. ``Bottleneck``). It must expose an
            ``expansion`` attribute and accept ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        in_channel: number of channels of the input tensor.
        zero_init_residual: if True, zero the weight of the last BatchNorm
            (``bn3``) in every ``block`` instance, so each residual branch
            starts at zero.
        final_pool: if True, ``forward`` returns globally average-pooled,
            flattened features. Otherwise it returns the layer4 feature map.
    """

    def __init__(self, block, num_blocks, in_channel=2, zero_init_residual=False, final_pool=False):
        super().__init__()
        self.in_planes = 64
        self.final_pool = final_pool

        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # Kept for checkpoint/attribute compatibility; not used in forward().
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        if self.final_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to:
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                # Only residual blocks own a `bn3`. Iterating all modules
                # without this check hit the ResNet itself, convs, etc. and
                # raised AttributeError.
                if isinstance(m, block) and hasattr(m, 'bn3'):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride, maxpool=False, maxpool2=False):
        """Build one stage of ``num_blocks`` blocks.

        Only the first block uses ``stride``. ``maxpool2`` appends a
        max-pool after every block; ``maxpool`` appends one after the stage.
        Updates ``self.in_planes`` to the stage's output width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i in range(num_blocks):
            stride = strides[i]
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
            if maxpool2:
                layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        if maxpool:
            layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        return nn.Sequential(*layers)

    def forward(self, x, mask=None):
        """Encode ``x``.

        Args:
            x: input batch with ``in_channel`` channels.
            mask: optional boolean tensor that ``expand_as`` can broadcast to
                the layer4 feature map (e.g. ``(B, 1, h, w)``). It is only
                used when ``final_pool`` is True. Positions where it is False
                are zeroed before pooling.

        Returns:
            The layer4 feature map, or ``(B, C)`` pooled features when
            ``final_pool`` is True.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out1 = self.layer4(out)
        if self.final_pool:
            if mask is not None:
                # NOTE: zeroed positions still count in the average's
                # denominator; AdaptiveAvgPool2d averages over the full grid.
                out1[~mask.expand_as(out1)] = 0

            out = self.avgpool(out1)
            out = torch.flatten(out, 1)
            return out
        return out1

    
class MatchingModel(nn.Module):
    """CriSp matching network: ResNet-50 encoder plus a 1x1-conv projection head.

    The encoder is a ``ResNet(Bottleneck, [3, 4, 6, 3])`` without final
    pooling. Its 2048-channel feature map is projected to ``feat_dim``
    channels by ``head``.
    """

    def __init__(self, feat_dim=128, in_channel=2, feature_dim=(1, 1)):
        super().__init__()
        encoder_width = 2048  # layer4 output: 512 planes * Bottleneck.expansion (4)
        self.encoder = ResNet(Bottleneck, [3, 4, 6, 3], in_channel=in_channel, final_pool=False)
        self.head = nn.Sequential(
            nn.Conv2d(encoder_width, encoder_width, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(encoder_width, feat_dim, 1),
        )
        self.feature_dim = feature_dim
        self.avgpool = nn.AdaptiveAvgPool2d(self.feature_dim)

    def forward(self, x, spatial_feat=False, mask=None):
        """Embed ``x``.

        The result is the spatial head output when ``spatial_feat`` is True,
        otherwise a pooled, L2-normalised vector (see ``vectorize``).
        ``mask`` (boolean, ``expand_as``-broadcastable to the head output)
        zeroes masked-out locations. It is also forwarded to the encoder,
        which ignores it because ``final_pool`` is False.
        """
        projected = self.head(self.encoder(x, mask=mask))
        if mask is not None:
            projected[~mask.expand_as(projected)] = 0
        if spatial_feat:
            return projected
        return self.vectorize(projected)

    def vectorize(self, x, spatial=False):
        """L2-normalise ``x`` along channels and flatten it per sample.

        Pools to ``feature_dim`` first, except when ``spatial`` is True and
        ``feature_dim`` is ``(1, 1)``. In that case every spatial location is
        normalised independently and kept.
        """
        keep_grid = spatial and self.feature_dim == (1, 1)
        if not keep_grid:
            x = self.avgpool(x)
        return F.normalize(x, dim=1).flatten(1)

    

def download(id, destination, download_name=None, timeout=60):
    """Download a Google Drive file to ``destination`` unless it already exists.

    Args:
        id: Google Drive file id (the name shadows the builtin; it is kept
            for caller compatibility).
        destination: path the downloaded bytes are written to.
        download_name: path whose existence means "already downloaded".
            Defaults to ``destination``.
        timeout: seconds passed to ``requests`` as the connect/read timeout,
            so a stalled connection does not hang the notebook forever.

    Raises:
        requests.HTTPError: if Google Drive answers with an error status.
            Previously the error page was silently saved as the file.
    """
    download_name = download_name if download_name else destination
    if os.path.exists(download_name):
        return

    with requests.Session() as session:
        response = session.get("https://docs.google.com/uc?export=download",
                               params={'id': id}, stream=True, timeout=timeout)
        response.raise_for_status()

        if get_confirm_token(response):
            # Large files first return a "can't scan for viruses" HTML page.
            # Request the confirmed download endpoint instead.
            response = session.get('https://drive.usercontent.google.com/download?export=download&confirm=t',
                                   params={'id': id}, stream=True, timeout=timeout)
            response.raise_for_status()

        # Consume the stream while the session (and its connection) is open.
        save_response_content(response, destination)

def get_confirm_token(response):
    """Return True if ``response`` is Google Drive's large-file virus-scan warning page.

    The warning is an HTML page, so a response with a non-HTML Content-Type
    (e.g. the actual binary file) is rejected without touching
    ``response.text``. Reading ``.text`` on a large streamed binary download
    would load the whole payload into memory and run charset detection on it.
    If no Content-Type header is present, the body is inspected as before.
    """
    content_type = response.headers.get('Content-Type', '')
    if content_type and 'text/html' not in content_type.lower():
        return False
    return 'too large for Google to scan for viruses. Would you still like to download this file?' in response.text

def save_response_content(response, destination, chunk_size=32768):
    """Stream the body of ``response`` into the file ``destination``.

    Data is written to ``<destination>.part`` and moved into place with
    ``os.replace`` only once the stream has been fully read. ``download``
    skips any file that already exists, so writing straight to
    ``destination`` let an interrupted download leave a truncated file that
    was never re-fetched.

    Args:
        response: object with ``iter_content(chunk_size)`` (e.g. a streamed
            ``requests.Response``).
        destination: final path of the file.
        chunk_size: bytes requested per ``iter_content`` chunk.
    """
    partial = os.fspath(destination) + '.part'
    try:
        with open(partial, "wb") as f:
            for chunk in response.iter_content(chunk_size):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
    except BaseException:
        # Best-effort cleanup of the incomplete file; the original error is re-raised.
        try:
            os.remove(partial)
        except OSError:
            pass
        raise
    os.replace(partial, destination)


def load_model(weights, file_id):
    """Build a ``MatchingModel`` on ``device`` and load the checkpoint into it.

    Args:
        weights: local path of the checkpoint. It is fetched from Google
            Drive (``file_id``) via ``download`` unless it already exists.
        file_id: Google Drive id of the checkpoint.

    Returns:
        The model with the checkpoint's state dict loaded.
    """
    download(file_id, weights)
    model = MatchingModel().to(device)
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    state = torch.load(weights, map_location=device)
    model.load_state_dict(state)
    return model


def display_results(query_image, match_names, match_values):
    """Redraw the cell output: search bar, query image, then one row per match.

    Each row shows the rank, the database image and print for that match
    (read from ``<database_name>/image`` and ``<database_name>/print``), any
    metadata for the shoe, and a "More like this" button. The button re-runs
    the search with that print.

    Args:
        query_image: image displayed above the results.
        match_names: ranked database file names. Their first 6 characters are
            parsed as the integer shoe id used to look up ``metadata``.
        match_values: similarity scores, parallel to ``match_names``
            (currently not displayed).
    """

    def _read_bytes(path):
        # Close each file promptly; `open(path).read()` leaked a handle per image.
        with open(path, 'rb') as f:
            return f.read()

    clear_output()
    display(search_widget)
    display(query_image)
    for rank, match_name in enumerate(match_names, start=1):
        wi_rank = widgets.HTML(value=str(rank) + "          ")
        wi1 = widgets.Image(value=_read_bytes(os.path.join(database_name, 'image', match_name)))
        wi2 = widgets.Image(value=_read_bytes(os.path.join(database_name, 'print', match_name)))

        # metadata information
        shoe_info = metadata.get(int(match_name[:6]))
        metadata_textboxes = []
        if shoe_info is not None:
            metadata_textboxes = [
                widgets.Text(value=header + ': ' + shoe_info[header], disabled=True)
                for header in ['brand', 'product', 'gender']
            ]

        button = widgets.Button(description="More like this")
        button.filename = match_name
        button.on_click(on_show_more_button_clicked)
        metadata_textboxes.append(button)

        box_layout = widgets.Layout(display='flex',
            flex_flow='row',
            align_items='center')
        metadata_widget = widgets.VBox(metadata_textboxes, layout=widgets.Layout(justify_content='center'))

        display(widgets.HBox([wi_rank, wi1, wi2, metadata_widget], layout=box_layout))
    

def outline_image(image, outline):
    """Return an RGB copy of ``image`` with the shoe outline painted on top.

    Args:
        image: PIL image or array of the print. Grayscale, gray+alpha, RGB,
            RGBA and boolean inputs are accepted; alpha is dropped. If its
            maximum value is exactly 1, it is rescaled to 0..255.
        outline: RGBA image with the same height and width as ``image``.
            Every pixel with non-zero alpha is painted with the colour of the
            outline's top-left pixel.

    Returns:
        A ``PIL.Image`` (uint8 RGB).
    """
    canvas = np.array(image)
    if canvas.ndim == 2:
        canvas = canvas[:, :, np.newaxis]
    if canvas.shape[2] in (1, 2):
        # Gray (optionally with alpha): replicate the gray channel to RGB.
        canvas = np.repeat(canvas[:, :, :1], 3, axis=2)
    else:
        # RGB or RGBA: keep the colour channels so a 3-value colour can be
        # assigned below (RGBA previously failed to broadcast).
        canvas = canvas[:, :, :3]
    if np.max(canvas) == 1:
        canvas = canvas * 255
    if canvas.dtype != np.uint8:
        # Image.fromarray rejects e.g. int64/float RGB arrays, which the
        # rescale above produces for bool inputs.
        canvas = np.clip(canvas.astype(np.float64), 0, 255).astype(np.uint8)

    outline_rgba = np.array(outline)
    outline_mask = outline_rgba[:, :, 3] != 0
    canvas[outline_mask] = outline_rgba[0, 0, 0:3]
    return Image.fromarray(canvas)
    

def _uploaded_contents(uploader):
    """Return the raw bytes of every file held by a ``FileUpload`` widget.

    Handles both value layouts: a dict ``{name: {'content': ...}}`` (the
    layout this notebook was written against) and a tuple of dicts with a
    ``'content'`` entry (newer ipywidgets).
    """
    value = uploader.value
    files = value.values() if isinstance(value, dict) else value
    return [bytes(f['content']) for f in files]


def on_button_clicked(b):
    """'Search' handler: run the query for the uploaded print and mask, then show results.

    Both images are resized to 384x192 (width x height) before searching.
    If either upload is missing, a message is printed instead of doing
    nothing silently.
    """
    print_files = _uploaded_contents(print_uploader)
    mask_files = _uploaded_contents(mask_uploader)
    if not print_files or not mask_files:
        print('Please upload both a print and a mask before searching.')
        return

    for print_bytes, mask_bytes in zip(print_files, mask_files):
        query_image = Image.open(io.BytesIO(print_bytes)).resize((384, 192))
        mask = Image.open(io.BytesIO(mask_bytes)).resize((384, 192))
        match_names, match_values = image_search(query_image, mask, result_count.value)
        outlined_query = outline_image(query_image, outline)
        display_results(outlined_query, match_names, match_values)
        

def on_show_more_button_clicked(b):
    """'More like this' handler: search again using the clicked result's print as the query.

    Fixes a crash: the count used to be passed as ``image_search``'s ``mask``
    argument. Database prints are used whole, so the search now gets an
    all-visible mask (255 everywhere). The print is resized to 384x192, the
    same size the upload path uses, so it matches the feature grid the
    search expects.
    """
    path = os.path.join(database_name, 'print', b.filename)
    with Image.open(path) as stored_print:
        query = stored_print.resize((384, 192))
    full_mask = Image.new('L', query.size, 255)
    match_names, match_values = image_search(query, full_mask, result_count.value)
    outlined_query = outline_image(query, outline)
    display_results(outlined_query, match_names, match_values)
    
def image_search(query_image, mask, n_results=24):
    """Rank reference-database prints by similarity to a masked query print.

    Args:
        query_image: PIL image (or array) of the query print. Multi-channel
            images are averaged to one channel; values are divided by 255.
        mask: PIL image (or array) with the same height/width as the query.
            Visible pixels are white. Both 0/1 masks (including boolean, as
            the usage notes describe) and 0/255 masks are accepted.
            Multi-channel masks are channel-averaged.
        n_results: number of top matches to return. Negative values are
            treated as 0.

    Returns:
        ``(match_names, match_values)``: the top ``n_results`` entries of the
        global ``database_names`` and their scores, highest score first.
    """
    img = np.array(query_image)
    if img.ndim == 3:
        img = np.mean(img, axis=2)
    img = torch.tensor(img / 255.0).to(device).unsqueeze(0).unsqueeze(0)
    # The network takes two input channels; the query fills only the first.
    query = torch.cat((img, torch.zeros_like(img)), dim=1).float()

    mask = np.array(mask, dtype=np.float64)
    if mask.ndim == 3:
        mask = np.mean(mask, axis=2)
    # Dividing a 0/1 mask by 255 made every pixel "hidden"; only rescale
    # masks that are actually stored as 0..255.
    if mask.max() > 1:
        mask = mask / 255.0
    mask = torch.from_numpy(mask > 0.5).to(device).unsqueeze(0).unsqueeze(0)

    n_results = max(int(n_results), 0)
    net.eval()
    with torch.no_grad():
        spatial = net(query, spatial_feat=True)
        # Resample the pixel mask onto the feature grid (6x12 for a 192x384
        # query). A cell is kept when the bilinearly sampled value is > 0.
        feat_mask = F.interpolate(mask.float(), size=tuple(spatial.shape[-2:]),
                                  mode='bilinear', align_corners=False) > 0
        query_features = net.vectorize(spatial * feat_mask, spatial=True)
        # NOTE(review): assumes database_features is (N, D) and normalised like
        # the query, so this product is the cosine similarity -- confirm.
        # [0] instead of squeeze() keeps a 1-D result even for a 1-entry database.
        scores = torch.matmul(query_features, database_features.T)[0]
        values, indices = torch.sort(scores, descending=True)
        top = indices[:n_results].cpu().numpy()
        match_names = database_names[top]
        match_values = values[:n_results].cpu().numpy()
    return match_names, match_values

# ---- Runtime setup: device, model, reference database, metadata, UI ----
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

print('==> Loading model')
# Google Drive id of the CriSp checkpoint:
# https://drive.google.com/file/d/1FFSh0uSUr5c4ebaCkzIELv98wRizSvrM/view?usp=sharing
file_id = '1FFSh0uSUr5c4ebaCkzIELv98wRizSvrM'
net = load_model('model.pth', file_id)

print('==> Loading database')
database_name = 'database'
file_id = '1v4DYJ1hF4UadsMOa3yachsORPxwpHQpz'
download(file_id, database_name + '.zip', download_name='database.zip')
if not os.path.exists(database_name):
    # Extracted into the working directory; presumably the archive contains
    # a top-level `database/` folder (the paths below assume so) -- confirm.
    with zipfile.ZipFile(database_name + '.zip', 'r') as zip_ref:
        zip_ref.extractall()
database_feat_name = os.path.join(database_name, 'features.pth')
# NOTE(review): torch.load unpickles the file; only acceptable for this trusted download.
database_features = torch.load(database_feat_name, map_location=device).to(device)
database_names = np.load(os.path.join(database_name, 'names.npy'))

# Shoe outline overlay drawn over query prints (read as RGBA by outline_image).
outline_name = 'outline.png'
file_id = '1_34Waq_gd7o2KzOjVgUMtVS_3B-KNmSb'
download(file_id, outline_name)
with Image.open(outline_name) as outline_file:
    outline = outline_file.copy()  # load the pixels and release the file handle

# Per-shoe metadata keyed by integer shoe id. It is loaded before the UI is
# shown, so every callback can rely on it.
metadata_headers = ['shoeid', 'gender', 'brand', 'product', 'category', 'prints']
metadata = {}
with open(os.path.join(database_name, 'metadata.csv'), newline='') as csvfile:
    for row in csv.DictReader(csvfile):
        metadata[int(row['shoeid'])] = {header: row[header] for header in metadata_headers[1:]}

clear_output()
search_button = widgets.Button(description="Search")
print_uploader = widgets.FileUpload(multiple=False, description='Upload Print')
mask_uploader = widgets.FileUpload(multiple=False, description='Upload Mask')
result_count = widgets.IntText(
    value=10,
    description='Count',
    disabled=False
)
search_widget = widgets.HBox([result_count, print_uploader, mask_uploader, search_button], layout=widgets.Layout(justify_content='center'))
display(search_widget)
search_button.on_click(on_button_clicked)



HBox(children=(IntText(value=10, description='Count'), FileUpload(value={}, description='Upload Print'), FileU…