In [None]:
import io
import os
import torch
import numpy as np
import torchvision.transforms as T
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm

# Install an older, pinned version of bing_image_downloader:
# the latest version is broken and returns unrelated images
!pip install bing_image_downloader==1.0.4
from bing_image_downloader import downloader

### Configuration

In [None]:
# Number of images to download for each class (i.e. for each search query)
n_images_per_class = 4
# Name of the directory where the downloaded images are stored
data_path= 'image_data'
print('Saving images to: {:s}'.format(data_path))

# Search queries, one per class of images to collect.
# The variable name `quaries` (sic) is referenced by the download and
# visualization cells, so it must keep this spelling.
quaries = ['trains', 'airplanes', 'cars road', 'ships']

### Download images

In [None]:
# ONLY RUN THIS CELL IF YOU WANT TO DOWNLOAD NEW IMAGES!!!

# Create the data directory. `exist_ok=True` makes this a no-op when the
# directory already exists and avoids a separate check-then-create step.
os.makedirs(data_path, exist_ok=True)

# Download images for each query. The visualization cell later reads each
# class from a sub-directory of `data_path` named after its query.
for quary in quaries:
  downloader.download(quary,
                      limit=n_images_per_class,
                      output_dir=data_path,
                      # NOTE(review): True presumably disables the safe-search
                      # filter -- confirm this is what is wanted
                      adult_filter_off=True,
                      force_replace=False,
                      timeout=5)

### Visualize the downloaded dataset

In [None]:
# Define the number of classes based on the number of quaries
n_classes = len(quaries)
# Maximum number of example images shown per class
max_cols = 4

# List the files of every class up front (sorted, so the same images are
# shown on every run) in order to pick ONE column count for the whole grid.
class_dir_paths = [os.path.join(data_path, quary) for quary in quaries]
class_files = [sorted(os.listdir(class_dir_path))
               for class_dir_path in class_dir_paths]
# All rows must share the same grid; using a per-class column count would
# make subplots of different rows overlap or be misplaced.
n_cols = min(max_cols, max((len(files) for files in class_files), default=0))

# Create a matplotlib figure window
fig = plt.figure(figsize=[15, 10])
# Loop over all classes (one row per class)
for row, (class_dir_path, files) in enumerate(zip(class_dir_paths, class_files)):
  for col, file_name in enumerate(files[:n_cols]):
    # Create a subplot axes in the shared n_classes x n_cols grid
    ax = fig.add_subplot(n_classes, n_cols, row*n_cols + col + 1)
    # Open the image; the context manager closes the file handle afterwards
    with Image.open(os.path.join(class_dir_path, file_name)) as img:
      # Resize the image in place (keeps the aspect ratio)
      img.thumbnail([200, 200])
      # Plot the image
      ax.imshow(img)
    # Remove ticks
    ax.set(xticks=[], yticks=[])

### Download the DINOv2 model

In [None]:
# Download the DINOv2 embedding model
# NOTE(review): the '_lc' hub entry appears to be the variant with a linear
# classification head, so the outputs may be class scores rather than raw
# backbone features -- confirm this is the intended model.
embedding_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_lc')
# Check if we have a GPU available otherwise use the CPU
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
# Move the model to the GPU if we have one
embedding_model.to(device)
# The model is only used for inference: switch to evaluation mode so that
# any train-time-only layers behave deterministically.
embedding_model.eval();

### Define some useful functions for computing embeddings

In [None]:
# Define image transforms: convert to a float tensor, resize the shorter side
# to 224, crop the central 224x224 patch and normalize each RGB channel.
# The normalization uses the ImageNet mean/std, which are the statistics
# DINOv2's reference evaluation transform uses.
transform_image = T.Compose([
    T.ToTensor(),
    T.Resize(224, antialias=True),
    T.CenterCrop(224),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def get_img_paths(data_path):
  """
  Get image paths to all images that were downloaded.

  Scans every sub-directory of `data_path` (one level deep) and collects the
  files whose extension is .jpg, .jpeg or .png (case-insensitive).

  Args:
    data_path: Directory containing one sub-directory per class.

  Returns:
    A list of paths `data_path/<folder>/<file>`, sorted by folder and then by
    file name so that the result is the same on every run.
  """
  image_extensions = {'.jpg', '.jpeg', '.png'}

  img_paths = []
  for folder in sorted(os.listdir(data_path)):
    folder_path = os.path.join(data_path, folder)
    # Skip stray regular files (e.g. .DS_Store); listing them would raise
    if not os.path.isdir(folder_path):
      continue
    for file in sorted(os.listdir(folder_path)):
      # Compare the real extension so names like "notjpeg" are rejected and
      # upper-case variants such as ".PNG" or ".JPEG" are accepted
      if os.path.splitext(file)[1].lower() in image_extensions:
        img_paths.append(os.path.join(folder_path, file))

  return img_paths

def load_image(img: "str | os.PathLike | bytes") -> torch.Tensor:
  """
  Load an image and return a tensor that can be used as an input to DINOv2.

  Args:
    img: Either a path to an image file, or the raw encoded bytes of an image
      (for example the content of an uploaded file).

  Returns:
    The image converted to RGB and passed through `transform_image`, with a
    leading batch dimension of size 1 added.
  """
  # Paths are opened directly; anything else is treated as raw image bytes
  source = img if isinstance(img, (str, os.PathLike)) else io.BytesIO(img)
  # `convert` returns a fully loaded copy, so the underlying file can be
  # closed as soon as the conversion is done
  with Image.open(source) as raw_img:
    rgb_img = raw_img.convert('RGB')
  return transform_image(rgb_img).unsqueeze(0)

def compute_embedding(img):
  """
  Run a single image through the embedding model.

  `img` is forwarded unchanged to `load_image`; the resulting tensor is moved
  to `device`, passed through `embedding_model` without gradient tracking and
  returned as a NumPy array on the CPU.
  """
  with torch.no_grad():
    model_input = load_image(img).to(device)
    model_output = embedding_model(model_input)
  return model_output.cpu().numpy()

def create_dataset(img_paths: list) -> tuple:
  """
  Get embeddings and labels for a list of image paths.

  The label of an image is the first space-separated word of the name of the
  directory that contains it (e.g. '.../cars road/img.jpg' -> 'cars').

  Args:
    img_paths: Paths of the images to embed.

  Returns:
    A tuple `(all_embeddings, all_labels)`: the per-image embeddings stacked
    vertically into one NumPy array, and the list of labels in the same order.

  Raises:
    ValueError: If `img_paths` is empty (there is nothing to stack).
  """
  img_paths = list(img_paths)
  if not img_paths:
    raise ValueError('img_paths is empty: there are no images to embed')

  all_embeddings = []
  all_labels = []

  for img_path in tqdm(img_paths):
    # Name of the directory that directly contains the image
    class_dir_name = os.path.basename(os.path.dirname(img_path))
    all_labels.append(class_dir_name.split(' ')[0])
    all_embeddings.append(compute_embedding(img_path))

  return np.vstack(all_embeddings), all_labels


### Compute the embeddings for our downloaded image dataset

In [None]:
# Collect the paths of the downloaded images found under `data_path`
img_paths = get_img_paths(data_path)
# Compute the embeddings and labels of those images; both are used by the
# nearest-neighbor classification cell further down
embeddings, labels = create_dataset(img_paths)

### Test how well a simple nearest neighbor classifier works

In [None]:
import ipywidgets as widgets
# Use a widget to upload a test image from your computer
uploader = widgets.FileUpload()
# Last expression of the cell: displays the upload widget.
# Upload a file with it before running the next cell, which reads the upload.
uploader

In [None]:
# Fetch the raw bytes of the uploaded file(s).
# ipywidgets 7 exposes them as `uploader.data` (a list of bytes objects);
# ipywidgets 8 removed that attribute and keeps the content of each file
# inside the entries of `uploader.value` instead.
if hasattr(uploader, 'data'):
  uploaded_files = list(uploader.data)
else:
  uploaded_files = [bytes(item['content']) for item in uploader.value]
if not uploaded_files:
  raise ValueError('No image uploaded: use the upload widget above first')
test_image_bytes = uploaded_files[0]

# Compute the embedding for the test image
query_embedding = compute_embedding(test_image_bytes)
# Euclidean distance from the test image to every image in the dataset
distances = np.linalg.norm(embeddings - query_embedding, axis=1)
# Assign the label based on the label of most similar image in the dataset
predicted_class = labels[np.argmin(distances)]

# Print the results
print('The predicted label is: {:s}'.format(predicted_class))
image = Image.open(io.BytesIO(test_image_bytes))
image