# Loading our Finetuned Model

In our previous notebooks we:

- [Built a simple image search engine for fashion products using Jina's DocArray library](https://colab.research.google.com/github/alexcg1/neural-search-notebooks/blob/main/fashion-search/1_build_basic_search/basic_search.ipynb)
- [Finetuned our model using Jina Finetuner](https://colab.research.google.com/github/alexcg1/neural-search-notebooks/blob/main/fashion-search/2_finetune_model/finetune_model.ipynb)

Now we'll integrate our fine-tuned model into our original search engine and compare the results.

Next time we'll build our fashion search engine into something production-ready using [Jina's neural search framework](https://github.com/jina-ai/jina).

You can download this notebook from [GitHub](https://github.com/alexcg1/neural-search-notebooks). PRs and issues are always welcome!

## Before starting

1. Ensure you've completed the previous two notebooks.
2. Copy the `tuned-model` file from the finetuner tutorial and place it in the same folder as this notebook.
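
Checking for the file up front can save a confusing error later in the notebook. Here's a minimal sketch, assuming the file is named `tuned-model` and sits next to this notebook as described above. `missing_files` is a hypothetical helper written for this example, not part of any library.

```python
from pathlib import Path


def missing_files(filenames, folder="."):
    """Return the names in `filenames` that do not exist as files in `folder`."""
    base = Path(folder)
    return [name for name in filenames if not (base / name).is_file()]


# "tuned-model" is the file produced by the finetuning notebook
missing = missing_files(["tuned-model"])
if missing:
    print(f"Missing files: {missing}. Copy them next to this notebook first.")
else:
    print("All required files found.")
```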

## 📺 Watch the video

Get a guided tour of the notebook and search results with Jack from Jina AI.

In [None]:
from IPython.display import YouTubeVideo
YouTubeVideo("Amo19S1SrhE", width=800, height=450)

## Configuration

In [None]:
# Check if we're running in Google Colab
try:
    import google.colab
    in_colab = True
except ImportError:
    in_colab = False

DATA_DIR = "./data"
DATA_PATH = f"{DATA_DIR}/*.jpg"
MAX_DOCS = 1000
QUERY_IMAGE = "./query.jpg" # image we'll use to search with
PLOT_EMBEDDINGS = False # Really useful but have to manually stop it to progress to next cell

# Toy data - If data dir doesn't exist, we'll get data of ~800 fashion images from here
TOY_DATA_URL = "https://github.com/alexcg1/neural-search-notebooks/raw/main/fashion-search/data.zip?raw=true"

## ⚙️ Setup

In [None]:
!pip install "docarray[full]==0.4.4"

In [None]:
from docarray import Document, DocumentArray

## 🖼️ Load images

In [None]:
# Download images if they don't exist
import os

if not os.path.isdir(DATA_DIR) and not os.path.islink(DATA_DIR):
    print(f"Can't find {DATA_DIR}. Downloading toy dataset")
    !wget "$TOY_DATA_URL" -O data.zip
    !unzip -q data.zip # Don't print out every darn filename
    !rm -f data.zip
else:
    print(f"Nothing to download. Using {DATA_DIR} for data")

docs = DocumentArray.from_files(DATA_PATH, size=MAX_DOCS)
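
The download cell relies on the `wget`, `unzip` and `rm` shell commands, which may not be available outside Colab (on Windows, for example). Here's a rough standard-library alternative. `fetch_and_unzip` is a hypothetical helper name, and the sketch assumes the archive unpacks into a `data` folder, as the cell above implies.

```python
import os
import urllib.request
import zipfile


def fetch_and_unzip(url, dest_dir=".", zip_name="data.zip"):
    """Download a zip archive from `url`, extract it into `dest_dir`, then delete the archive."""
    os.makedirs(dest_dir, exist_ok=True)
    zip_path = os.path.join(dest_dir, zip_name)
    urllib.request.urlretrieve(url, zip_path)  # download to a local file
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest_dir)  # extracts quietly, no per-file output
    os.remove(zip_path)  # clean up the downloaded archive


# Usage, mirroring the cell above (DATA_DIR and TOY_DATA_URL come from the Configuration cell):
# if not os.path.isdir(DATA_DIR):
#     fetch_and_unzip(TOY_DATA_URL)
```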

## 🏭 Apply preprocessing

In [None]:
from docarray import Document

def preproc(d: Document):
    return (d.load_uri_to_image_tensor()  # load
             .set_image_tensor_shape((80, 60))  # ensure all images right size (dataset image size _should_ be (80, 60))
             .set_image_tensor_normalization()  # normalize color 
             .set_image_tensor_channel_axis(-1, 0))  # switch color axis for the PyTorch model later

docs.apply(preproc)

## 🧠 Embed images using original model

In [None]:
!pip install torchvision~=0.11

In [None]:
import torch

# Use GPU if available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

import torchvision
model = torchvision.models.resnet50(pretrained=True)  # load ResNet50

docs.embed(model, device=DEVICE)

## Create query Document

Let's download the query image from the first notebook and use it to search:

In [None]:
# Download query doc
!wget https://github.com/alexcg1/neural-search-notebooks/raw/main/fashion-search/1_build_basic_search/query.jpg -O query.jpg

query_docs = DocumentArray([Document(uri=QUERY_IMAGE)]) # Wrap in a DocumentArray
query_docs[0].display()

query_docs.apply(preproc)

query_docs.embed(model, device=DEVICE)

## Get matches and see results

In [None]:
query_docs.match(docs, limit=9)

(DocumentArray(query_docs[0].matches, copy=True)
    .apply(lambda d: d.set_image_tensor_channel_axis(0, -1)
                      .set_image_tensor_inv_normalization())).plot_image_sprites()

for match in query_docs[0].matches:
    print(match.scores["cosine"].value) # cosine distance to the query (lower = more similar)
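
To make the printed scores easier to read, here's a minimal pure-Python sketch of cosine distance. It assumes that DocArray's `match` reports cosine distance (1 minus cosine similarity), so smaller values mean closer matches; check that against your DocArray version. `cosine_distance` is a name made up for this example.

```python
import math


def cosine_distance(a, b):
    """Cosine distance between two equal-length, non-zero vectors: 1 - cos(angle)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


print(cosine_distance([1.0, 0.0], [1.0, 0.0]))  # same direction -> 0.0
print(cosine_distance([1.0, 0.0], [0.0, 1.0]))  # orthogonal -> 1.0
```

Under that assumption, a match with a score near 0 points in almost the same direction as the query embedding.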

## 🧠 Load new model and embed images

In [None]:
MODEL_FILENAME = "tuned-model"
# map_location lets a model saved on a GPU load on a CPU-only machine
model = torch.load(MODEL_FILENAME, map_location=DEVICE)

docs = DocumentArray.from_files(DATA_PATH, size=MAX_DOCS)
docs.apply(preproc)
docs.embed(model, device=DEVICE)

## Embed query Document

In [None]:
query_docs = DocumentArray([Document(uri=QUERY_IMAGE)])
query_docs.apply(preproc)
query_docs.embed(model, device=DEVICE)

## Get matches and see results

In [None]:
query_docs.match(docs, limit=9)

(DocumentArray(query_docs[0].matches, copy=True)
    .apply(lambda d: d.set_image_tensor_channel_axis(0, -1)
                      .set_image_tensor_inv_normalization())).plot_image_sprites()

for match in query_docs[0].matches:
    print(match.scores["cosine"].value) # cosine distance to the query (lower = more similar)
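
To compare the two models more systematically than by eye, we could save each run's matches as `(uri, score)` pairs before `query_docs` gets overwritten, then diff the two lists. The sketch below is one way to do that. `compare_matches` is a hypothetical helper, and the URIs and scores are placeholder values for illustration, not real results from this notebook.

```python
def compare_matches(before, after):
    """Compare two ranked lists of (uri, score) pairs from two models."""
    before_uris = [uri for uri, _ in before]
    after_uris = [uri for uri, _ in after]
    return {
        "shared": [u for u in after_uris if u in before_uris],
        "only_before": [u for u in before_uris if u not in after_uris],
        "only_after": [u for u in after_uris if u not in before_uris],
        "mean_score_before": sum(s for _, s in before) / len(before),
        "mean_score_after": sum(s for _, s in after) / len(after),
    }


# Placeholder values only. In the notebook, each list could be built right after a match call:
# [(m.uri, m.scores["cosine"].value) for m in query_docs[0].matches]
original = [("a.jpg", 0.10), ("b.jpg", 0.20), ("c.jpg", 0.30)]
finetuned = [("a.jpg", 0.05), ("c.jpg", 0.15), ("d.jpg", 0.25)]
print(compare_matches(original, finetuned))
```

The two models embed images into different spaces, so which images appear in the results is usually more telling than the raw score differences.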