In [None]:
# Import python packages
import streamlit as st
import pandas as pd

# We can also use Snowpark for our analyses!
from snowflake.snowpark.context import get_active_session
# Obtain the Snowpark session through get_active_session(); every later cell uses it.
session = get_active_session()

# Print the execution context so it is clear where the dataset table will be read/written.
print("----------------------------------------")
# One query returns a single row whose columns follow the SELECT order:
# [0] warehouse, [1] database, [2] schema, [3] Snowflake version.
snowflake_environment = session.sql('select current_warehouse(), current_database(), current_schema(), current_version()').collect()
print('Warehouse                   : {}'.format(snowflake_environment[0][0]))
print('Database                    : {}'.format(snowflake_environment[0][1]))
print('Schema                      : {}'.format(snowflake_environment[0][2]))
print('Snowflake version           : {}'.format(snowflake_environment[0][3]))
print("----------------------------------------")

# Location of the dataset table: the session's current database and schema,
# plus a fixed table name. Later cells read these module-level names.
DATABASE = snowflake_environment[0][1]
SCHEMA = snowflake_environment[0][2]
DATA_TABLE = 'TINY_IMAGENET'

# Dataset setup
* Dataset Description
    * The example uses the Tiny ImageNet dataset, a popular benchmark for image classification tasks.
    * The dataset contains 100,000 color images, each sized 64×64 pixels, and contains 200 distinct classes.
    * Each class contains 500 training images.
* If the Snowflake table is not found, the dataset is automatically downloaded from Hugging Face and saved to a Snowflake table.
    * The images are Base64-encoded and stored in the Snowflake table as a VARCHAR column.

In [None]:
def table_exists(session):
    """Report whether the dataset table is already present.

    Counts the rows of ``INFORMATION_SCHEMA.TABLES`` whose schema and table
    name equal the module-level ``SCHEMA`` and ``DATA_TABLE`` values.

    Args:
        session: Object exposing ``sql(query).collect()``; the first column of
            the first returned row is read as the match count.

    Returns:
        bool: ``True`` when at least one matching table is listed.
    """
    query = f"""
        SELECT COUNT(*)
        FROM INFORMATION_SCHEMA.TABLES
        WHERE TABLE_SCHEMA = '{SCHEMA}'
        AND TABLE_NAME = '{DATA_TABLE}'
    """
    rows = session.sql(query).collect()
    # COUNT(*) always yields exactly one row with one column.
    match_count = rows[0][0]
    return match_count > 0

def create_dataset_if_not_present(session):
    """Create and populate the dataset table unless it already exists.

    When ``table_exists(session)`` is true this is a no-op. Otherwise the
    ``train`` split of ``zh-plus/tiny-imagenet`` is loaded from Hugging Face,
    every image is re-encoded as a Base64 PNG string, and the result is
    written to ``DATABASE.SCHEMA.DATA_TABLE`` with two columns:
    ``label`` and ``image_base64``.

    Args:
        session: Snowpark session used for the existence check and for
            ``write_pandas``.

    Returns:
        None
    """
    # Check first: when the table is already there, neither the `datasets`
    # package nor a download is needed.
    if table_exists(session):
        return

    import base64
    from io import BytesIO

    import pandas as pd
    from datasets import load_dataset

    # Load the Tiny ImageNet training split from Hugging Face
    dataset = load_dataset("zh-plus/tiny-imagenet", split="train")

    def encode_image_base64(image):
        """Serialize an image to PNG in memory and return it as Base64 text."""
        buffered = BytesIO()
        image.save(buffered, format="PNG")  # Save the image to a buffer in PNG format
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Walk the dataset once and fill both columns together; building each
    # column with its own comprehension would iterate the whole dataset twice.
    labels = []
    encoded_images = []
    for entry in dataset:
        labels.append(entry["label"])
        encoded_images.append(encode_image_base64(entry["image"]))

    df = pd.DataFrame({"label": labels, "image_base64": encoded_images})

    session.write_pandas(
        df=df,
        table_name=DATA_TABLE,
        database=DATABASE,
        schema=SCHEMA,
        auto_create_table=True,
        overwrite=True
    )
# Populate the table on first run; returns immediately when the table already exists.
create_dataset_if_not_present(session)

# Scaling ContainerRuntime Cluster

In [None]:
from snowflake.ml.runtime_cluster import scale_cluster, get_nodes
# Presumably requests a cluster of 3 nodes; this matches num_nodes=3 passed to
# PyTorchScalingConfig later in this notebook.
# NOTE(review): the first argument looks like a notebook / compute target name
# specific to the author's environment — confirm and adjust before running.
scale_cluster("DISTRIBUTED_PYTORCH_QUICKSTART_2", 3)
get_nodes()

In [None]:
# Display the node list again in its own cell (presumably to re-check the
# cluster after scaling — confirm).
get_nodes()

# Training
## Model
* Fine-tune a resnet18 model by replacing its fully connected layer with a new multi-layer DNN.
* More details about resnet18 model: https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html#torchvision.models.resnet18 
* All the base layers of the resnet18 model are frozen, and only the newly added fully connected layers are trained in this example to keep it simple.

## Decoding Data
* The images are stored as Base64-encoded strings in Snowflake tables. To train a CNN model, these encoded strings must be decoded and converted into tensors of shape 3 × 64 × 64 (3 color channels for RGB and 64×64 resolution).
* DecodedDataset Class: The DecodedDataset class acts as a wrapper around a source [dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) and transforms each row of the dataset, enabling seamless decoding and preprocessing.

* A similar pattern can be followed to chain together any number of row-level or batch-level transforms to preprocess the data for training efficiently.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18, ResNet18_Weights

def get_resent_model():
    """Build a pretrained ResNet-18 whose only trainable part is a new head.

    The network is created with ``ResNet18_Weights.IMAGENET1K_V1``. Its ``fc``
    layer is replaced by a three-stage MLP (512 -> 256 -> 200 outputs) with
    batch norm, LeakyReLU and 50% dropout between stages. The Linear layers
    of the new head get Xavier-uniform weights and zero biases. Finally every
    parameter outside ``fc`` has ``requires_grad`` turned off.

    Returns:
        The modified ResNet-18 module.
    """
    backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

    # New classification head sized for 200 output classes.
    in_dim = backbone.fc.in_features
    head = nn.Sequential(
        nn.Linear(in_dim, 512),
        nn.BatchNorm1d(512),
        nn.LeakyReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, 256),
        nn.BatchNorm1d(256),
        nn.LeakyReLU(),
        nn.Dropout(0.5),
        nn.Linear(256, 200),
    )
    backbone.fc = head

    # Re-initialise only the Linear layers of the head, in definition order.
    for layer in head.modules():
        if isinstance(layer, nn.Linear):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    # Train the head only: parameters registered under "fc." stay trainable,
    # everything else is frozen.
    for name, param in backbone.named_parameters():
        param.requires_grad = name.startswith("fc.")

    return backbone

In [None]:
import io
import base64
from torchvision import transforms
from torch.utils.data import IterableDataset
from PIL import Image

class DecodedDataset(IterableDataset):
    """Iterable wrapper that turns Base64 image rows into training samples.

    Each row of ``source_dataset`` is indexed with the keys ``'"image_base64"'``
    and ``'"label"'`` (the double quotes are part of the key). The image is
    decoded, converted to a tensor, forced to shape 3x64x64, normalised, and
    yielded together with the label as ``(tensor, int)``.
    """

    def __init__(self, source_dataset):
        """Store the wrapped dataset and build the per-image transforms.

        Args:
            source_dataset: Iterable of rows supporting the two key lookups
                described on the class.
        """
        self.source_dataset = source_dataset
        self.transforms = transforms.Compose([transforms.ToTensor()])
        # Per-channel mean/std; presumably the ImageNet statistics expected by
        # the pretrained backbone — confirm.
        self.normlize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    def __iter__(self):
        """Yield ``(image_tensor, label)`` for every row of the source.

        Raises:
            RuntimeError: If a decoded image is neither 3x64x64 nor 1x64x64.
        """
        rgb_shape = torch.Size([3, 64, 64])
        grey_shape = torch.Size([1, 64, 64])

        for record in self.source_dataset:
            raw_bytes = base64.b64decode(record['"image_base64"'])
            # PIL image -> tensor
            tensor = self.transforms(Image.open(io.BytesIO(raw_bytes)))

            if tensor.size() == grey_shape:
                # Single-channel image: replicate the channel three times.
                tensor = tensor.repeat(3, 1, 1)
            elif tensor.size() != rgb_shape:
                raise RuntimeError(f"Unsupported image of dimentions {tensor.size()}")

            yield self.normlize(tensor), int(record['"label"'])

In [None]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one pass of gradient updates over ``train_loader``.

    Args:
        model: Module to optimise; it is switched to training mode.
        device: Device that every batch is moved to before the forward pass.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the updates.
        epoch: Epoch number, used only in the progress message.

    Every 100th batch prints the epoch, the batch index multiplied by the size
    of the current batch, and the current cross-entropy loss.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader, start=1):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)
        loss.backward()
        optimizer.step()

        if step % 100 != 0:
            continue
        print('Train Epoch: {} [Processed {} images]\tLoss: {:.6f}'.format(
            epoch, step * len(inputs), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    total_images = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target).item()  # sum up batch loss
            _, preds = torch.max(output, 1)
            correct += (preds == target).sum().item()
            total_images += len(data)

    test_loss /= total_images

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, total_images,
        100. * correct / total_images))
    return (100. * correct / total_images)

# Training loop definition
* The provided example implements a straightforward training loop where the model is trained for a fixed number of epochs (num_epochs).
    * Users can leverage the HuggingFace Trainer or PyTorch Lightning Trainer to take advantage of fully featured training loops with advanced features like early stopping, checkpointing, and more.
* This example demonstrates how datasets and hyperparameters can be passed dynamically as arguments:
    * The dataset_map parameter allows a map of datasets to be passed to the train_func, enabling flexibility to use different datasets without modifying the core logic.
    * The hyper_params parameter allows variables such as num_epochs, learning_rate, and batch_size to be defined externally and passed during trainer invocation, making it easy to adjust or tune these parameters without hardcoding them in the training function.
* The training loop also highlights the use of context APIs to retrieve metadata like rank and world_size, which can then be used to customize the training loop.

In [None]:
def train_func():
    """Per-worker training entry point handed to ``PyTorchDistributor``.

    Steps performed by each worker process:

    1. Read rank information from ``get_context()`` and, when the world size
       is greater than 1, initialise an NCCL process group.
    2. Wrap this worker's shard of the ``"train"`` dataset and the whole
       ``"test"`` dataset in ``DecodedDataset`` and build data loaders.
    3. Build the model with ``get_resent_model()``, move it to
       ``cuda:<local_rank>`` and wrap it in DDP when distributed.
    4. Train for ``num_epochs`` epochs, evaluating after each one.
    5. Report the elapsed time and per-epoch test accuracy via the metrics
       reporter, and save the model weights from local rank 0.

    Hyper-parameters (read from ``context.get_hyper_params()``):
        num_epochs: required; number of training epochs.
        batch_size: optional; defaults to 64.
        learning_rate: optional base learning rate; defaults to 0.001. The
            effective learning rate is this value times the world size.
    """
    import os
    import time
    import torch
    import torch.optim as optim
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.optim.lr_scheduler import StepLR
    from snowflake.ml.modeling.distributors.pytorch import get_context

    start_time = time.time()
    context = get_context()
    local_rank = context.get_local_rank()
    device = f"cuda:{local_rank}"
    is_distributed = context.get_world_size() > 1
    if is_distributed:
        dist.init_process_group(backend="nccl")
    print(
        f"rank : {context.get_rank()}, "
        f"world_size: {context.get_world_size()}, "
        f"local_rank : {context.get_local_rank()}, "
        f"local_world_size: {context.get_local_world_size()}, "
        f"node_rank: {context.get_node_rank()}"
    )

    # Hyper-parameter values may arrive as strings, hence the explicit casts.
    # The defaults keep the previous hard-coded behaviour when a key is absent.
    hyper_parms = context.get_hyper_params()
    num_epochs = int(hyper_parms['num_epochs'])
    batch_size = int(hyper_parms.get('batch_size', 64))
    base_lr = float(hyper_parms.get('learning_rate', 0.001))

    dataset_map = context.get_dataset_map()
    # Training data: only this worker's shard. Test data: the full dataset.
    train_dataset = DecodedDataset(dataset_map["train"].get_shard().to_torch_dataset())
    test_dataset = DecodedDataset(dataset_map["test"].to_torch_dataset())

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        pin_memory=True,
        pin_memory_device=device
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        pin_memory=True,
        pin_memory_device=device
    )

    model = get_resent_model()
    model = model.to(device)

    if is_distributed:
        model = DDP(model)

    # Scale the learning rate linearly with the number of workers.
    lr = base_lr * context.get_world_size()
    optimizer = optim.AdamW(model.parameters(), lr = lr)
    scheduler = StepLR(optimizer, step_size=10)

    accuracy = []
    for epoch in range(0, num_epochs):
        train(model, device, train_loader, optimizer, epoch+1)
        res = test(model, device, test_loader)
        accuracy.append(res)
        scheduler.step()

    now = time.time()
    context.get_metrics_reporter().log_metrics({
        "train_func_train_time": int(now-start_time),
        "test_accuracy": accuracy
    })

    # NOTE(review): local rank 0 exists on every node, so with several nodes
    # more than one process writes model.pt — confirm whether the model dir is
    # node-local or shared.
    if local_rank == 0:
        # Only a DDP-wrapped model has `.module`; a single-process run keeps
        # the bare model, for which `.module` would raise AttributeError.
        model_to_save = model.module if isinstance(model, DDP) else model
        torch.save(
            model_to_save.state_dict(), os.path.join(context.get_model_dir(), "model.pt")
        )

    # Release the process group created above once this worker is done.
    if is_distributed and dist.is_initialized():
        dist.destroy_process_group()
    

# Training
* This example shows how to run the training function on 3 nodes, 4 workers on each node (because each node has 4 GPUs).

In [None]:
# Set up PyTorchDistributor
from snowflake.ml.modeling.distributors.pytorch import PyTorchDistributor, PyTorchScalingConfig, WorkerResourceConfig  
from snowflake.ml.data.sharded_data_connector import ShardedDataConnector
from snowflake.ml.data.data_connector import DataConnector
from snowflake.snowpark.functions import random

# Read the dataset table named by the DATABASE / SCHEMA / DATA_TABLE constants.
df = session.table(f"{DATABASE}.{SCHEMA}.{DATA_TABLE}")

# Shuffle rows: attach a random value, sort by it, then drop the helper column.
# random() is called without a seed here, so the row order differs between runs.
shuffled_df = df.with_column("random_order", random()).sort("random_order").drop("random_order")

# 95% / 5% train/test split with a fixed split seed.
train_df, test_df = shuffled_df.random_split(weights = [0.95, 0.05], seed = 99)
# Create sharded data connector.
# train_func calls get_shard() on the "train" entry, so it must be sharded;
# the "test" entry is read in full by every worker, so a plain connector is used.
train_data = ShardedDataConnector.from_dataframe(train_df)
test_data = DataConnector.from_dataframe(test_df)

# 3 nodes x 4 workers per node, each worker requesting one GPU and no CPUs.
pytorch_trainer = PyTorchDistributor(  
    train_func=train_func,
    scaling_config=PyTorchScalingConfig(  
        num_nodes=3,  
        num_workers_per_node=4,  
        resource_requirements_per_worker=WorkerResourceConfig(num_cpus=0, num_gpus=1),  
    )  
)  

# Run the trainer.
# The keys of dataset_map are the names train_func looks up ("train", "test").
# Hyper-parameters are passed as strings; train_func casts num_epochs with int().
# NOTE(review): "warm_up_num_epochs" is not read by train_func in this notebook.
results = pytorch_trainer.run(
    dataset_map={"train": train_data, "test": test_data},
    hyper_params={"num_epochs": "10", "warm_up_num_epochs": "3"}
)
results.get_metrics()

In [None]:
# Display the run's metrics again in a separate cell. train_func logs
# "train_func_train_time" and "test_accuracy" through the metrics reporter;
# presumably those are what appear here — confirm.
results.get_metrics()