# Finetuning our model

In a previous notebook, we [built a simple fashion search engine using DocArray](https://colab.research.google.com/github/alexcg1/neural-search-notebooks/blob/main/docarray/fashion-search/fashion_image_search.ipynb).

Now we'll finetune our model to deliver better results!

## Important note

This code won't run in a notebook, since Finetuner's GUI needs to run on your local machine. Please:

- Download this notebook as a Python file to your local machine (*File* > *Download* > *Download .py*)
- Install Finetuner in a virtual environment (`pip install finetuner torchvision`)
- Run the script with that virtual environment activated

If you don't follow these instructions, the script will fail, since it can't run the Finetuner GUI from within a notebook.
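If you'd rather have the script stop with a clear message than fail partway through, a check like the one below can help. This is a minimal sketch: `running_in_notebook` is a hypothetical helper (not part of Finetuner), and it relies on the heuristic that Jupyter and Colab kernels import the `ipykernel` module, which is not an official API.

```python
import sys


def running_in_notebook() -> bool:
    """Heuristic check for a Jupyter/Colab kernel."""
    # Hypothetical helper: notebook kernels run on top of ipykernel,
    # so its presence in sys.modules suggests we're inside a notebook
    return "ipykernel" in sys.modules


if running_in_notebook():
    print(
        "Warning: the Finetuner GUI can't run inside a notebook. "
        "Download this notebook as a .py file and run it locally."
    )
```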

## Configuration

We'll set up some basic variables. Feel free to adapt these for your own project!

In [1]:
DATA_DIR = "./data"
DATA_PATH = f"{DATA_DIR}/images/*.jpg"
MAX_DOCS = 1000
QUERY_IMAGE = "./query.jpg" # image we'll use to search with
PLOT_EMBEDDINGS = False # Really useful, but you have to stop it manually before moving on to the next cell

# Toy data: if DATA_DIR doesn't exist, we'll download ~800 fashion images from here
TOY_DATA_URL = "https://github.com/alexcg1/neural-search-notebooks/blob/main/docarray/fashion-search/data.zip?raw=true"

## Setup

In [5]:
# Install Finetuner, plus torchvision for the pretrained model we load later
!pip install finetuner torchvision


In [6]:
from docarray import Document, DocumentArray

## Load images

In [7]:
# Download images if they don't exist
import os

if not os.path.isdir(DATA_DIR) and not os.path.islink(DATA_DIR):
    print(f"Can't find {DATA_DIR}. Downloading toy dataset")
    !wget "$TOY_DATA_URL" -O data.zip
    !unzip -q data.zip # Don't print out every darn filename
    !rm -f data.zip
else:
    print(f"Nothing to download. Using {DATA_DIR} for data")

Nothing to download. Using ./data for data
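The `!wget`, `!unzip` and `!rm` lines are IPython shell escapes, which plain Python can't execute, so they won't work when you run the exported `.py` script locally. A standard-library equivalent might look like the sketch below. `download_toy_data` is a hypothetical helper, and it assumes the archive unpacks into a `data/` directory at its top level, as the notebook's own `unzip` step implies.

```python
import os
import tempfile
import urllib.request
import zipfile


def download_toy_data(url: str, dest_dir: str = ".") -> None:
    """Download a zip archive from `url` and extract it into `dest_dir`."""
    # Hypothetical helper mirroring the wget/unzip/rm shell commands
    with tempfile.TemporaryDirectory() as tmp:
        zip_path = os.path.join(tmp, "data.zip")
        urllib.request.urlretrieve(url, zip_path)  # like `wget ... -O data.zip`
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(dest_dir)  # like `unzip -q data.zip`
    # Leaving the `with` block deletes the temp dir, like `rm -f data.zip`


# Usage, with the variables from the Configuration cell:
# if not os.path.isdir(DATA_DIR) and not os.path.islink(DATA_DIR):
#     download_toy_data(TOY_DATA_URL)
```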


In [8]:
# Use `.from_files` to quickly load them into a `DocumentArray`
docs = DocumentArray.from_files(DATA_PATH, size=MAX_DOCS)
print(f"{len(docs)} Documents in DocumentArray")
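`DATA_PATH` is a glob pattern, and `size=MAX_DOCS` caps how many matching files get loaded. Conceptually, that's similar to the standard-library sketch below. `list_image_files` is a hypothetical helper that sorts matches for reproducibility; we haven't confirmed that DocArray's `from_files` orders files the same way, so treat this as an illustration rather than a drop-in replacement.

```python
import glob
from typing import List, Optional


def list_image_files(pattern: str, size: Optional[int] = None) -> List[str]:
    """Return paths matching a glob pattern, capped at `size` entries."""
    # Sorting makes the selection deterministic across runs
    paths = sorted(glob.glob(pattern, recursive=True))
    return paths if size is None else paths[:size]


# Usage: list_image_files("./data/images/*.jpg", size=1000)
```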

In [None]:
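# Load each image as an 80x60 (height x width) array, normalize it,
# then move the channel axis to the front, as PyTorch models expect
for doc in docs: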
    doc.load_uri_to_image_blob(
        height=80, width=60
    ).set_image_blob_normalization().set_image_blob_channel_axis(-1, 0)
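Each image is loaded as a height × width × channel array. Normalization rescales the pixel values, and `set_image_blob_channel_axis(-1, 0)` moves the channel axis to the front, giving the `(3, 80, 60)` shape that `input_size` refers to in the finetuning step. The pure-Python sketch below spells out those two steps for a tiny image. It assumes the ImageNet mean and standard deviation that torchvision's pretrained models expect; we believe DocArray's normalization uses the same defaults, but check its documentation before relying on that.

```python
from typing import List, Sequence

# ImageNet statistics used by torchvision's pretrained models (RGB order)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

Image = List[List[List[float]]]


def normalize_hwc_to_chw(
    img: Sequence[Sequence[Sequence[int]]],
    mean: Sequence[float] = IMAGENET_MEAN,
    std: Sequence[float] = IMAGENET_STD,
) -> Image:
    """Normalize an HxWxC uint8 image and return it as a CxHxW nested list."""
    height, width, channels = len(img), len(img[0]), len(img[0][0])
    out: Image = []
    for c in range(channels):  # the channel axis moves from last to first
        plane = []
        for y in range(height):
            row = []
            for x in range(width):
                value = img[y][x][c] / 255.0  # scale to [0, 1]
                row.append((value - mean[c]) / std[c])  # standardize
            plane.append(row)
        out.append(plane)
    return out
```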

In [None]:
docs.plot_image_sprites() # Preview the images

## Load model

In [None]:
import torchvision

model = torchvision.models.resnet50(pretrained=True) # ResNet-50 pretrained on ImageNet
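ResNet-50 was pretrained on 224 × 224 images, but our inputs are only 80 × 60. That still works because the network ends in an adaptive average-pooling layer that collapses whatever spatial size remains. The sketch below applies the standard output-size formula, floor((n + 2p - k) / s) + 1, to the stride-2 layers of torchvision's ResNet-50 to show how small the final feature map gets. The layer list is our own summary of the architecture, not something read from the model object.

```python
from typing import List, Tuple

# (kernel, stride, padding) for each spatial downsampling step in ResNet-50:
# conv1, the max-pool, then the strided 3x3 conv opening layer2, layer3, layer4
# (layer1 keeps the resolution, so it's omitted)
RESNET50_DOWNSAMPLING: List[Tuple[int, int, int]] = [
    (7, 2, 3),  # conv1
    (3, 2, 1),  # maxpool
    (3, 2, 1),  # layer2
    (3, 2, 1),  # layer3
    (3, 2, 1),  # layer4
]


def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a conv/pool layer along one spatial dimension."""
    return (n + 2 * padding - kernel) // stride + 1


def final_feature_map(height: int, width: int) -> Tuple[int, int]:
    """Spatial size of ResNet-50's last feature map, before average pooling."""
    for kernel, stride, padding in RESNET50_DOWNSAMPLING:
        height = conv_out(height, kernel, stride, padding)
        width = conv_out(width, kernel, stride, padding)
    return height, width


print(final_feature_map(224, 224))  # (7, 7)
print(final_feature_map(80, 60))    # (3, 2)
```

Adaptive average pooling then turns the 2048-channel 3 × 2 map into a single 2048-dimensional vector, the same size you'd get from a 224 × 224 input.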

## Finetune model

⚠️ **Warning**: As stated previously, **this part won't run in a notebook**. Please check the introduction for instructions.

In [None]:
import finetuner

finetuner.fit(
    model,
    train_data=docs,
    interactive=True,
    to_embedding_model=True,
    freeze=False,
    input_size=(3, 80, 60),
)