# Image Classification using PyTorch

## Overview

In this notebook, we will review how to build an image recognition application in Snowflake using Snowpark for Python, PyTorch, and Streamlit.

### What Is Snowpark?

Snowpark is the set of libraries and runtimes in Snowflake for securely deploying and processing non-SQL code, including Python, Java, and Scala.

Familiar Client-Side Libraries: Snowpark brings deeply integrated, DataFrame-style programming and OSS-compatible APIs to the languages data practitioners like to use. It also includes the Snowpark ML API for more efficient ML modeling (public preview) and ML operations (private preview).

Flexible Runtime Constructs: Snowpark provides flexible runtime constructs that allow users to bring in and run custom logic. Developers can seamlessly build data pipelines, ML models, and data applications with User-Defined Functions and Stored Procedures.

### What is PyTorch?

PyTorch is one of the most popular open-source machine learning frameworks, and it is available to developers in Snowpark through the Snowflake Anaconda channel. This means you can load pre-trained PyTorch models in Snowpark for Python without manually installing the library and managing all of its dependencies.

For this application, we will use the [PyTorch implementation of MobileNet V3](https://github.com/d-li14/mobilenetv3.pytorch). *Note: A huge thank you to the [authors](https://github.com/d-li14/mobilenetv3.pytorch#citation) for the research and for making the pre-trained models available under the [MIT License](https://github.com/d-li14/mobilenetv3.pytorch/blob/master/LICENSE).*

### Prerequisites

* Install `cachetools`, `pandas`, `streamlit` and `snowflake-snowpark-python` packages. [Learn how.](https://docs.snowflake.com/en/user-guide/ui-snowsight/notebooks-import-packages)
* Download files:
    * https://sfquickstarts.s3.us-west-1.amazonaws.com/misc/pytorch/imagenet1000_clsidx_to_labels.txt
    * https://sfquickstarts.s3.us-west-1.amazonaws.com/misc/pytorch/mobilenetv3-large-1cd25616.pth
    * https://sfquickstarts.s3.us-west-1.amazonaws.com/misc/pytorch/mobilenetv3.py
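You can script the downloads with the Python standard library instead of using a browser. The sketch below does this. `FILE_URLS`, `local_filename`, and `download_all` are hypothetical names. The one detail that matters is that each file keeps its original name, because the UDF code later imports the files from the stage under exactly these names.

```python
import os
from urllib.parse import urlparse
from urllib.request import urlretrieve

# The three files listed in the prerequisites
FILE_URLS = [
    "https://sfquickstarts.s3.us-west-1.amazonaws.com/misc/pytorch/imagenet1000_clsidx_to_labels.txt",
    "https://sfquickstarts.s3.us-west-1.amazonaws.com/misc/pytorch/mobilenetv3-large-1cd25616.pth",
    "https://sfquickstarts.s3.us-west-1.amazonaws.com/misc/pytorch/mobilenetv3.py",
]


def local_filename(url: str) -> str:
    """Return the last path segment of a URL, e.g. 'mobilenetv3.py'."""
    return os.path.basename(urlparse(url).path)


def download_all(urls, target_dir: str = ".") -> list:
    """Download each URL into target_dir, keeping the original file name."""
    os.makedirs(target_dir, exist_ok=True)
    paths = []
    for url in urls:
        path = os.path.join(target_dir, local_filename(url))
        urlretrieve(url, path)  # network call; raises on HTTP errors
        paths.append(path)
    return paths
```

For example, `download_all(FILE_URLS, "downloads")` writes the three files into a local `downloads` folder.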


```sql
-- Create an internal stage to host the PyTorch model files downloaded in the previous step and the User-Defined Functions
CREATE STAGE IF NOT EXISTS DASH_FILES DIRECTORY = ( ENABLE = true );
```

```sql
-- Create a Network Rule object for the AWS S3 bucket where the images for this demo are stored
CREATE OR REPLACE NETWORK RULE sfquickstarts_s3_network_rule
  MODE = EGRESS
  TYPE = HOST_PORT
  VALUE_LIST = ('sfquickstarts.s3.us-west-1.amazonaws.com');
```

```sql
-- Create an External Access Integration object for the Network Rule created above so the User-Defined Function can access images stored on AWS S3 for this demo
CREATE OR REPLACE EXTERNAL ACCESS INTEGRATION sfquickstarts_s3_access_integration
  ALLOWED_NETWORK_RULES = (sfquickstarts_s3_network_rule)
  ENABLED = true;
```
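The network rule allows egress only to `sfquickstarts.s3.us-west-1.amazonaws.com`, so `image_recognition` can fetch images only from that host. A URL on any other host is expected to fail inside the UDF. The hypothetical helper below mirrors `VALUE_LIST` on the client side, so an app can reject such URLs before sending them to Snowflake. It is only a convenience and does not replace the server-side enforcement.

```python
from urllib.parse import urlparse

# Hypothetical client-side copy of VALUE_LIST in sfquickstarts_s3_network_rule
ALLOWED_HOSTS = frozenset({"sfquickstarts.s3.us-west-1.amazonaws.com"})


def is_allowed_by_network_rule(url: str, allowed_hosts=ALLOWED_HOSTS) -> bool:
    """Return True if the URL's host name is one the network rule allows."""
    host = urlparse(url).hostname  # lower-cased, without port or credentials
    return host is not None and host in allowed_hosts
```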

### *TODO: Before proceeding, use Snowsight to upload the downloaded files to stage `DASH_FILES`. [Learn how](https://docs.snowflake.com/en/user-guide/data-load-local-file-system-stage-ui#uploading-files-onto-a-stage).*
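Instead of uploading through Snowsight, you can PUT the files from a local Python environment that has a Snowpark session. The sketch below assumes the files are already in a local folder, and `upload_to_stage` is a hypothetical helper. Passing `auto_compress=False` matters here. PUT gzip-compresses files by default, which would add a `.gz` suffix to the staged names. The `session.add_import('@dash_files/...')` calls later would then not find the files.

```python
import os

STAGE = "@dash_files"
FILES = (
    "imagenet1000_clsidx_to_labels.txt",
    "mobilenetv3-large-1cd25616.pth",
    "mobilenetv3.py",
)


def upload_to_stage(session, local_dir: str, stage: str = STAGE, files=FILES) -> list:
    """PUT each local file onto the stage without compression, overwriting old copies."""
    results = []
    for name in files:
        local_path = os.path.join(local_dir, name)
        # session.file.put returns a list of PutResult objects
        results.extend(
            session.file.put(local_path, stage, auto_compress=False, overwrite=True)
        )
    return results
```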

## Import libraries

```python
# Snowpark
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.functions import col, udf
import streamlit as st

# Misc
import pandas as pd
import cachetools

session = get_active_session()
```
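Before you register the UDFs, it helps to confirm that all three files are on the stage. A missing import otherwise surfaces later as a less obvious registration error. The sketch below assumes that `LIST` returns the staged path (for example `dash_files/mobilenetv3.py`) in its first column. `missing_stage_files` is a hypothetical helper.

```python
REQUIRED_FILES = (
    "imagenet1000_clsidx_to_labels.txt",
    "mobilenetv3-large-1cd25616.pth",
    "mobilenetv3.py",
)


def missing_stage_files(session, stage: str = "@dash_files", required=REQUIRED_FILES) -> list:
    """Return the required file names that LIST does not report on the stage."""
    rows = session.sql(f"LIST {stage}").collect()
    # Keep only the file name after the last '/' of each staged path
    staged = {row[0].rsplit("/", 1)[-1] for row in rows}
    return [name for name in required if name not in staged]
```

In the notebook, `missing_stage_files(session)` should return an empty list once the upload step is done.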

## Create and register User-Defined Functions

To deploy the pre-trained model for inference, we will **create and register Snowpark Python UDFs and add the model files as dependencies**. Once registered, getting new predictions is as simple as calling the function with the input data. For more information on Snowpark Python User-Defined Functions, refer to the [docs](https://docs.snowflake.com/en/developer-guide/snowpark/python/creating-udfs.html).
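Both UDFs below share the same post-processing. They apply softmax to the model's 1,000 output logits, take the top-1 index with `torch.topk`, and look that index up in the class mapping. The pure-Python sketch below reproduces that step. The three-class mapping and the logits are hypothetical stand-ins for the real mapping file and model output.

```python
import ast
import math

# Hypothetical three-class stand-in for imagenet1000_clsidx_to_labels.txt, which the
# UDF parses with ast.literal_eval as a {class index: label} dict literal.
SAMPLE_MAPPING_TEXT = "{0: 'tabby cat', 1: 'golden retriever', 2: 'beagle'}"


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_labels(logits, cls_idx, k=1):
    """Return (label, probability) pairs for the k highest-probability classes."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(cls_idx[i], probs[i]) for i in ranked]


cls_idx = ast.literal_eval(SAMPLE_MAPPING_TEXT)
print(top_k_labels([0.1, 2.5, -1.0], cls_idx, k=2))
```

Softmax is monotonic, so the top-1 label is the same as the argmax of the raw logits. The softmax step only adds interpretable probabilities, which become useful if you return more than one candidate label.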

```python
session.clear_packages()
session.clear_imports()

# Add the model files and the class-label mapping as dependencies of the UDFs
session.add_import('@dash_files/imagenet1000_clsidx_to_labels.txt')
session.add_import('@dash_files/mobilenetv3.py')
session.add_import('@dash_files/mobilenetv3-large-1cd25616.pth')

# Add Python packages from the Snowflake Anaconda channel
session.add_packages('snowflake-snowpark-python', 'torchvision', 'joblib', 'cachetools', 'requests')

@cachetools.cached(cache={})
def load_class_mapping(filename):
    # Read the raw text of the {class index: label} mapping file
    with open(filename, "r") as f:
        return f.read()

@cachetools.cached(cache={})
def load_model():
    # Cached so the model is loaded once per Python process rather than once per row
    import sys
    import ast
    import torch
    from torchvision import transforms
    from mobilenetv3 import mobilenetv3_large

    # Files added with session.add_import() are available in this directory
    IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
    import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]

    model_file = import_dir + 'mobilenetv3-large-1cd25616.pth'
    imgnet_class_mapping_file = import_dir + 'imagenet1000_clsidx_to_labels.txt'

    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    # ImageNet preprocessing: resize, center-crop to 224x224, convert to tensor, normalize
    transform = transforms.Compose([
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
    ])

    # Load the ImageNet {class index: label} mapping (a Python dict literal)
    cls_idx = ast.literal_eval(load_class_mapping(imgnet_class_mapping_file))

    # Load the pre-trained image recognition model onto the CPU
    model = mobilenetv3_large()
    model.load_state_dict(torch.load(model_file, map_location='cpu'))

    # Configure the pre-trained model for inference
    model.eval().requires_grad_(False)

    return model, transform, cls_idx

@udf(name='image_recognition_using_bytes', session=session, replace=True, is_permanent=True, stage_location='@dash_files')
def image_recognition_using_bytes(image_bytes_in_str: str) -> str:
    # Classify an image passed in as a hex-encoded string of its raw bytes
    from io import BytesIO
    import torch
    from PIL import Image

    image_bytes = bytes.fromhex(image_bytes_in_str)

    model, transform, cls_idx = load_model()
    img = Image.open(BytesIO(image_bytes)).convert('RGB')
    img = transform(img).unsqueeze(0)

    # Get the model output and the human-readable prediction
    logits = model(img)
    outp = torch.nn.functional.softmax(logits, dim=1)
    _, idx = torch.topk(outp, 1)
    idx.squeeze_()
    predicted_label = cls_idx[idx.item()]

    return f"{predicted_label}"

@udf(name='image_recognition',
     session=session,
     is_permanent=True,
     stage_location='@dash_files',
     replace=True,
     external_access_integrations=['sfquickstarts_s3_access_integration'])
def image_recognition(image_url: str) -> str:
    # Download an image through the external access integration and classify it
    import requests
    import torch
    from PIL import Image
    from io import BytesIO

    response = requests.get(image_url, timeout=30)
    if response.status_code != 200:
        # Return a string (not a tuple) to match the declared return type
        return f"Failed to fetch the image. HTTP Status: {response.status_code}"

    model, transform, cls_idx = load_model()

    # Decode the downloaded bytes directly; convert('RGB') also handles PNG and grayscale input
    img = Image.open(BytesIO(response.content)).convert('RGB')
    img = transform(img).unsqueeze(0)

    # Get the model output and the human-readable prediction
    logits = model(img)
    outp = torch.nn.functional.softmax(logits, dim=1)
    _, idx = torch.topk(outp, 1)
    idx.squeeze_()
    predicted_label = cls_idx[idx.item()]

    return f"{predicted_label}"
```
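`image_recognition_using_bytes` decodes its argument with `bytes.fromhex`, so callers must pass the image as a hex string. The caller-side sketch below shows this. `to_udf_hex_argument` and `build_bytes_query` are hypothetical helpers. Hex digits never contain quotes, so the literal needs no SQL escaping. The hex string is twice the size of the raw bytes, so for large images the URL-based `image_recognition` UDF may be the more practical choice.

```python
def to_udf_hex_argument(image_bytes: bytes) -> str:
    """Encode raw image bytes as the hex string image_recognition_using_bytes expects."""
    return image_bytes.hex()


def build_bytes_query(image_bytes: bytes) -> str:
    """Build a SELECT that passes the hex-encoded image to the UDF as a string literal."""
    return f"select image_recognition_using_bytes('{to_udf_hex_argument(image_bytes)}') as label"
```

For example, with a placeholder local file: `label = session.sql(build_bytes_query(open("cat.jpg", "rb").read())).collect()[0][0]`.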

## Streamlit Application

Let's use 5 images of dogs and cats stored on AWS S3 to see how the pre-trained PyTorch model loaded as part of the User-Defined Function classifies them.

```python
base_s3_url = 'https://sfquickstarts.s3.us-west-1.amazonaws.com/misc/images'
images = ['dogs/001.jpg', 'dogs/002.jpg', 'cats/001.jpg', 'cats/003.jpg', 'dogs/003.jpg']
num_columns = 4

with st.status("Breed classification in progress...") as status:
    # Lay the images out in rows of up to four columns, starting with the first image
    for row_start in range(0, len(images), num_columns):
        columns = st.columns(num_columns, gap='small')
        for column, image_path in zip(columns, images[row_start:row_start + num_columns]):
            img_url = f"{base_s3_url}/{image_path}"
            # Call the UDF for this image and read the single result value
            sql = f"select image_recognition('{img_url}') as classified_breed"
            classified_breed = session.sql(sql).to_pandas()['CLASSIFIED_BREED'].iloc[0]
            with column:
                st.image(img_url, caption=classified_breed, use_column_width=True)
    status.update(label="Done!", state="complete", expanded=True)
```
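The cell above issues one query per image. An alternative is to classify every URL in a single statement by selecting from a `VALUES` list, which cuts five round trips down to one. Snowflake names the `VALUES` columns `column1`, `column2`, and so on. The sketch below is built on that assumption, and `sql_string_literal` and `build_batch_query` are hypothetical helpers. Single quotes are doubled and backslashes escaped, because Snowflake treats backslash as an escape character in single-quoted strings.

```python
def sql_string_literal(value: str) -> str:
    """Quote a Python string as a Snowflake single-quoted string literal."""
    escaped = value.replace("\\", "\\\\").replace("'", "''")
    return f"'{escaped}'"


def build_batch_query(urls) -> str:
    """Build one SELECT that calls image_recognition for every URL in the list."""
    if not urls:
        raise ValueError("at least one URL is required")
    values = ", ".join(f"({sql_string_literal(u)})" for u in urls)
    # VALUES columns are named column1, column2, ... by default
    return (
        "select column1 as url, image_recognition(column1) as classified_breed "
        f"from (values {values})"
    )
```

For example, `session.sql(build_batch_query([f"{base_s3_url}/{p}" for p in images])).to_pandas()` returns `URL` and `CLASSIFIED_BREED` columns. Matching results back to images by URL, rather than by row position, avoids depending on row order.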