# Quickstart with DaFt

In this Quickstart tutorial, we will use the Fashion MNIST dataset to demonstrate some of DaFt's core functionality.

## Setup

Download and extract required data

In [None]:
import urllib.request
import tarfile

# Remote location of the Fashion MNIST archive and the local filename it is saved under.
URL = "https://dax-cdn.cdn.appdomain.cloud/dax-fashion-mnist/1.0.2/fashion-mnist.tar.gz"
TARFILE_PATH = "fashion-mnist.tar.gz"

# Download the archive into the current working directory.
urllib.request.urlretrieve(URL, TARFILE_PATH)

# Unpack into the current working directory. The archive comes from the network,
# so prefer the "data" extraction filter, which refuses members that would be
# written outside the destination (absolute paths, ".." components, unsafe links).
# Interpreters that predate extraction filters fall back to the plain call.
with tarfile.open(TARFILE_PATH, "r:gz") as tar:
    if hasattr(tarfile, "data_filter"):
        tar.extractall(filter="data")
    else:
        tar.extractall()

In [None]:
# Relative paths of the dataset CSV files (presumably produced by extracting the
# archive in the Setup step — confirm). TRAIN_CSV_PATH is loaded into a DataFrame
# below; TEST_CSV_PATH is not used in the visible part of this tutorial.
TEST_CSV_PATH = "fashion-mnist_test.csv"
TRAIN_CSV_PATH = "fashion-mnist_train.csv"

## Create a DataFrame

In [None]:
from daft import DataFrame, col, udf

# Build a DataFrame from the training CSV.
images_df = DataFrame.from_csv(TRAIN_CSV_PATH)

In [None]:
# Preview the raw CSV data (the argument presumably caps the number of rows shown).
images_df.show(10)

## Create Numpy arrays

In [None]:
import numpy as np

# One column expression per pixel column, "pixel1" through "pixel784"
# (784 = 28 * 28, matching the 28x28 reshape applied later).
columns = [col(f"pixel{i}") for i in range(1, 785)]

@udf(return_type=np.ndarray)
def pixels_to_np_array(*pixels):
    """Combine per-pixel column batches into one array per row.

    The inputs are stacked along a new leading axis and the result is
    transposed, so for 1-D inputs of length N the output has shape
    ``(N, len(pixels))``.
    """
    stacked_by_pixel = np.stack(pixels)
    return stacked_by_pixel.T
    
# Select `label` plus a single `img_array` column built by passing every pixel
# column through the `pixels_to_np_array` UDF.
images_df = images_df.select(col("label"), pixels_to_np_array(*columns).alias("img_array"))

In [None]:
# Preview the result: `label` alongside the new `img_array` column.
images_df.show(10)

In [None]:
# Add a `reshaped_array` column holding each `img_array` reshaped to 28x28.
# NOTE(review): `as_py(np.ndarray)` presumably exposes ndarray methods (here
# `reshape`) on each element of the column — confirm against the Daft API.
images_df = images_df.with_column("reshaped_array", col("img_array").as_py(np.ndarray).reshape(28, 28))

In [None]:
# Preview the DataFrame now that it carries the `reshaped_array` column.
images_df.show(10)

## Create Images

In [None]:
from PIL import Image

@udf(return_type=Image.Image)
def arr_to_img(np_arrs):
    """Convert each array in ``np_arrs`` into a PIL image.

    Every array is cast to ``uint8`` before being handed to
    ``Image.fromarray``; the images are returned as a list in input order.
    """
    as_uint8 = (pixel_grid.astype(np.uint8) for pixel_grid in np_arrs)
    return list(map(Image.fromarray, as_uint8))

In [None]:
# Add an `image` column by running the `arr_to_img` UDF over `reshaped_array`.
images_df = images_df.with_column("image", arr_to_img(col("reshaped_array")))

In [None]:
# Preview the DataFrame now that it carries the `image` column.
images_df.show(10)

## Filtering a DataFrame

In [None]:
# Keep only the rows whose `label` equals 8 and preview them.
images_df.where(col("label") == 8).show(10)

## Running a Model

In [None]:
import torch
from transformers import AutoModelForImageClassification


@udf(return_type=int)
class ClassifyImages:
    """Class-based UDF that labels a batch of images with a pretrained classifier.

    The model is loaded in ``__init__`` and stored on the instance, so each
    ``__call__`` reuses it instead of loading it again.
    """

    def __init__(self):
        # Weights are loaded as float64 (``torch.double``), matching the float64
        # tensors built in ``__call__``.
        self._model = AutoModelForImageClassification.from_pretrained("arize-ai/resnet-50-fashion-mnist-quality-drift", torch_dtype=torch.double)

    def __call__(self, images):
        """Return the predicted class index for every image in ``images``.

        Args:
            images: Iterable of image objects supporting ``convert('RGB')``
                (PIL images in this tutorial). They are assumed to share one
                size so that they stack into a single 4-D array.

        Returns:
            A 1-D NumPy array holding, per image, the index of the largest logit.
        """
        converted_image_arrays = np.array([np.array(img.convert('RGB')) for img in images])
        converted_image_arrays = np.moveaxis(converted_image_arrays, 3, 1)  # (BATCH, HEIGHT, WIDTH, CHANNEL) -> (BATCH, CHANNEL, HEIGHT, WIDTH)
        converted_image_arrays = converted_image_arrays / 255  # rescale 8-bit channel values to [0, 1]
        # Inference only: disabling autograd avoids building a graph that would be
        # thrown away, and makes the logits directly convertible to NumPy.
        with torch.no_grad():
            classifications = self._model(torch.from_numpy(converted_image_arrays).double()).logits
        return classifications.numpy().argmax(axis=1)
    

In [None]:
# Add a `model_classification` column by applying the ClassifyImages UDF to `image`.
classified_images_df = images_df.with_column("model_classification", ClassifyImages(col("image")))

In [None]:
# Preview the DataFrame with the model's predictions attached.
classified_images_df.show(10)

In [None]:
# Preview only rows whose true `label` is 8, to compare against `model_classification`.
classified_images_df.where(col("label") == 8).show(10)

In [None]:
# Row limit passed to `.limit(...)` when computing the metrics below.
SAMPLE_COUNT = 1000

@udf(return_type=int)
def matched(labels, model_classifications):
    """Flag element-wise whether each prediction equals its label.

    Returns the comparison result cast to ``int``: 1 where the two inputs are
    equal, 0 where they differ.
    """
    is_match = labels == model_classifications
    return is_match.astype(int)

@udf(return_type=float)
def to_float(ints):
    """Cast a batch of values to ``float`` via ``astype``."""
    as_floats = ints.astype(float)
    return as_floats

# Score one shared sample so that both aggregations below are computed over the
# same frame, and the limit/match pipeline is written only once.
# `matched` is 1 where the prediction equals the label and 0 otherwise;
# `not_matched` is its complement, |matched - 1|.
scored_sample_df = (
    classified_images_df
    .limit(SAMPLE_COUNT)
    .with_column("matched", matched(col("label"), col("model_classification")))
    .with_column("not_matched", abs(col("matched") - 1))
)

# Grouped by ground-truth label: hits are true positives, misses are false negatives.
ground_truth_counts = (
    scored_sample_df
    .groupby(col("label"))
    .agg([
        (col("matched").alias("true_positive"), "sum"),
        (col("not_matched").alias("false_negative"), "sum"),
    ])
)

# Grouped by predicted class: hits are true positives, misses are false positives.
prediction_counts = (
    scored_sample_df
    .groupby(col("model_classification"))
    .agg([
        (col("matched").alias("true_positive"), "sum"),
        (col("not_matched").alias("false_positive"), "sum"),
    ])
)

# Pair each class's ground-truth counts with its prediction counts, then derive
#   precision = TP / (TP + FP)
#   recall    = TP / (TP + FN)
# The numerator goes through `to_float` first (presumably so the division is
# carried out in floating point — confirm).
precision_recall = (
    ground_truth_counts
    .join(
        prediction_counts,
        left_on=col("label"),
        right_on=col("model_classification"),
    )
    .with_column(
        "precision",
        to_float(col("true_positive")) / (col("true_positive") + col("false_positive")),
    )
    .with_column(
        "recall",
        to_float(col("true_positive")) / (col("true_positive") + col("false_negative")),
    )
)

In [None]:
# Display the precision and recall computed for each class.
precision_recall.show(10)