In [None]:
# Import python packages
import streamlit as st
import pandas as pd

# We can also use Snowpark for our analyses!
from snowflake.snowpark.context import get_active_session
# Reuse the currently active Snowpark session; no connection parameters are supplied here.
session = get_active_session()


### Create a Data Source to read unstructured data

In [None]:
from snowflake.ml.ray.datasource import SFStageImageDataSource, SFStageTextDataSource

# Datasource over the image files under the stage's images/ prefix.
# NOTE(review): image_size presumably resizes every image to 256x256 on read —
# confirm against the datasource documentation.
image_source = SFStageImageDataSource(
    stage_location="@DATA_STAGE_RAY/images/",
    database="ST_DB",
    schema="ST_SCHEMA",
    image_size=(256, 256),
)

In [None]:
# Datasource over the label text files under the stage's labels/ prefix
# (same database and schema as the image datasource).
label_source = SFStageTextDataSource(
    stage_location="@DATA_STAGE_RAY/labels/",
    database="ST_DB",
    schema="ST_SCHEMA",
)

### Load into a Ray dataset

In [None]:
import ray

# Build a Ray dataset from the image datasource.
image_ds = ray.data.read_datasource(image_source)

In [None]:
# Report how many images were loaded, then preview the first two rows.
print(f'Total load {image_ds.count()} images')
image_ds.show(2)

In [None]:
# Build a Ray dataset from the label datasource, with read concurrency set to 8.
label_ds = ray.data.read_datasource(label_source, concurrency=8)

In [None]:
# Preview the first five label rows.
label_ds.show(5)

### Process both datasets to include additional columns
**Image Dataset**: add a join key, encode the images, standardize the images

**Label Dataset**: add a join key, interpret the labels

In [None]:
import numpy as np
from typing import Dict
import base64
import os
import torch

def process_image(row):
    """Normalize one image row and attach a base64 payload plus a join key.

    Args:
        row: Mapping holding an ``'image'`` NumPy array and a ``'file_name'``
            path string.

    Returns:
        The same ``row`` object, updated in place with:

        * ``'image'`` - expanded to three channels when the input was 2-D.
        * ``'encoded_image'`` - base64 ``bytes`` of the array's C-order buffer.
          The payload carries no shape or dtype information.
        * ``'join_id'`` - the file name without its directory part and without
          its final extension.
    """
    img = row['image']
    if img.ndim == 2:
        # Grayscale input: replicate the single channel three times along a new last axis.
        img = np.stack([img] * 3, axis=-1)
        row['image'] = img

    # b64encode requires a C-contiguous buffer and raises on transposed or
    # sliced arrays. ascontiguousarray is a no-op for contiguous input and
    # makes a contiguous copy otherwise.
    row['encoded_image'] = base64.b64encode(np.ascontiguousarray(img))

    # Strip the final extension first, then keep only the last '/'-separated component.
    stem, _ext = os.path.splitext(row['file_name'])
    row['join_id'] = stem.split('/')[-1]
    return row

# Apply process_image row by row to add the encoded_image and join_id columns.
processed_image_ds = image_ds.map(process_image)

In [None]:
# Inspect one processed image row.
processed_image_ds.show(1)

In [None]:
import os

def expand_label_column(batch: pd.DataFrame) -> pd.DataFrame:
    """Parse whitespace-separated label text into typed bounding-box columns.

    Every ``text`` value must hold exactly five whitespace-separated tokens,
    read in order as ``xmin ymin xmax ymax class``. The first four are parsed
    as ``float`` and the last as ``int``.

    Args:
        batch: DataFrame with a ``'text'`` column (raw label string) and a
            ``'file_name'`` column (path of the label file).

    Returns:
        A new DataFrame with one row per input row and the columns
        ``join_id, file_name, xmin, ymin, xmax, ymax, class``.

    Raises:
        ValueError: If a ``text`` value does not contain exactly five tokens,
            or a token cannot be parsed as a number.
    """
    columns = ['join_id', 'file_name', 'xmin', 'ymin', 'xmax', 'ymax', 'class']
    records = []

    # Iterate over just the two columns that are needed; iterrows() would build
    # a full Series object for every row.
    for text, file_name in zip(batch['text'], batch['file_name']):
        values = text.strip().split()

        if len(values) != 5:
            # Name the offending file so a bad label can be located on the stage.
            raise ValueError(
                f"Expected 5 values in text, but got {len(values)} values (file: {file_name})"
            )

        xmin, ymin, xmax, ymax = (float(v) for v in values[:4])
        stem = os.path.splitext(file_name)[0].split('/')[-1]
        records.append({
            # NOTE(review): the '_test' suffix means this key only matches image
            # join_ids whose file stem already ends in '_test' — confirm this
            # agrees with the file naming on the stage.
            'join_id': stem + '_test',
            'file_name': file_name,
            'xmin': xmin,
            'ymin': ymin,
            'xmax': xmax,
            'ymax': ymax,
            'class': int(values[4]),
        })

    # Passing columns explicitly keeps the schema stable even for an empty batch.
    return pd.DataFrame(records, columns=columns)

# Run the label parser over pandas-formatted batches, with concurrency set to 20.
processed_label_ds = label_ds.map_batches(expand_label_column, concurrency=20, batch_format='pandas')

In [None]:
# Inspect one parsed label row.
processed_label_ds.show(1)

### Merge image source and label source into a single dataset
We have two ways of achieving this: 1) if the customer is more familiar with `pandas.DataFrame` and the data fits into memory, we can convert all of the data to pandas (or write it to Snowflake) and do the rest of the operations there. 2) If the data does not fit into memory, we can use the Ray dataset directly to do the processing. 

**Note**: Ray datasets are not naturally architected to support join operations, so it is better to use another method (in memory / Snowflake) to perform joins.

#### Convert both datasets into pandas and perform joins

In [None]:
# Drop the raw 'image' column before converting to pandas; the base64 'encoded_image' column is kept.
image_df = processed_image_ds.drop_columns(cols=['image']).to_pandas()

In [None]:
# Preview the first rows of the image frame.
image_df.head()

In [None]:
# Convert the parsed label dataset to a pandas DataFrame.
label_df = processed_label_ds.to_pandas()

In [None]:
# Preview the first rows of the label frame.
label_df.head()

In [None]:
# Inner-join images and labels on join_id; rows without a match on both sides are dropped.
# Both frames carry a file_name column, so pandas' default suffixes yield file_name_x / file_name_y.
merged_train_df = pd.merge(image_df, label_df, how='inner', on='join_id')


In [None]:
# Preview the first rows of the merged frame.
merged_train_df.head()

## Save the Transformed Dataset to a Snowflake table
Customers can also easily save the processed image dataset and label dataset to Snowflake.

In [None]:

from snowflake.ml.ray.datasink import SnowflakeTableDatasink

# Destination table for the processed image dataset.
# NOTE(review): auto_create_table / override presumably create the table when
# missing and replace existing contents — confirm against the datasink docs.
table_to_save = "RAY_DEMO_JAN21_IMAGE_DS"
datasink = SnowflakeTableDatasink(
    table_name=table_to_save,
    database="ST_DB",
    schema="ST_SCHEMA",
    auto_create_table=True,
    override=True,
)

In [None]:
# Write the processed image rows to the datasink, without the raw 'image' column.
processed_image_ds.drop_columns(cols=['image']).write_datasink(datasink, concurrency=4)

In [None]:
# sql cell

# SELECT * FROM RAY_DEMO_JAN21_IMAGE_DS;

In [None]:
# Destination table for the parsed label dataset, configured like the image table.
table_to_save = "RAY_DEMO_JAN21_LABEL_DS"
datasink = SnowflakeTableDatasink(
    table_name=table_to_save,
    database="ST_DB",
    schema="ST_SCHEMA",
    auto_create_table=True,
    override=True,
)

# Write the parsed label rows to the table.
processed_label_ds.write_datasink(datasink, concurrency=4)

In [None]:
# sql cell

#SELECT * FROM RAY_DEMO_JAN21_LABEL_DS;

In [None]:
# Destination table for the joined image + label data.
# NOTE(review): the table name keeps its existing spelling ("COMINED") so that
# anything already querying it keeps working — rename deliberately if desired.
table_to_save = "RAY_DEMO_JAN21_COMINED_DS"
datasink = SnowflakeTableDatasink(
    table_name=table_to_save,
    database="ST_DB",
    schema="ST_SCHEMA",
    auto_create_table=True,
    override=True,
)

# Write the merged rows, not the label-only dataset. merged_train_df is a pandas
# DataFrame, so it is wrapped in a Ray dataset before going to the datasink.
ray.data.from_pandas(merged_train_df).write_datasink(datasink, concurrency=4)