# Quickstart Guide for PyTorch Training Using Snowflake's ML Container Runtime
## Introduction
This notebook is a quickstart for training a PyTorch model with Snowflake's ML Container Runtime APIs. It uses Snowflake's `ShardedDataConnector` to split a table across training workers and `PyTorchDistributor` to run distributed data-parallel training on multiple GPUs.

### Steps Covered:
- Load data from a Snowflake table.
- Set up Snowflake's `ShardedDataConnector` for data ingestion.
- Train a PyTorch model on multiple GPUs (or fall back to the CPU if no GPUs are available).
- Report the training loss after each epoch (making predictions and evaluating the model are left as next steps).

### Step 1: Set Up Snowflake Session
Get the active Snowflake session, which the rest of the notebook uses to read data and launch training.

```python
# Initialize Snowflake session
from snowflake.snowpark.context import get_active_session

session = get_active_session()
```

### Step 2: Load Data from Snowflake Table
We load data from the `CR_QUICKSTART.PUBLIC.VEHICLE` table.

```python
from snowflake.ml.data.data_connector import DataConnector

# Load data from the Snowflake table
table_name = 'CR_QUICKSTART.PUBLIC.VEHICLE'
snowpark_df = session.table(table_name)

# Convert the Snowpark DataFrame to a pandas DataFrame using DataConnector
pandas_df = DataConnector.from_dataframe(snowpark_df).to_pandas()

# Drop the 'C2' column (a datetime column) from the dataset
pandas_df = pandas_df.drop(columns=['C2'])

# Split the data into features (X) and target (y); C6 is assumed to be the target column
label_col = 'C6'
X = pandas_df.drop(columns=[label_col])
y = pandas_df[label_col]

# Feature column names, used later by the training function
input_cols = X.columns.tolist()
```
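
`input_cols` is computed from the pandas copy, but training reads the Snowpark DataFrame (which has `C2` dropped separately in Step 3), so both paths must agree on which columns are features and in what order. The sketch below isolates that bookkeeping in a hypothetical helper, `split_columns`, which is not part of any Snowflake API; the column list in the usage example is illustrative, not the real schema of the `VEHICLE` table.

```python
from typing import Iterable, List


def split_columns(all_cols: Iterable[str], label_col: str, drop_cols: Iterable[str] = ()) -> List[str]:
    """Hypothetical helper: return every column except the label and any dropped columns."""
    excluded = set(drop_cols) | {label_col}
    # Keep the table's original column order: the training function concatenates
    # feature columns in exactly this order, so it must be stable.
    return [col for col in all_cols if col not in excluded]


# Illustrative column list (assumed, not the actual VEHICLE schema)
features = split_columns(['C1', 'C2', 'C3', 'C4', 'C5', 'C6'], label_col='C6', drop_cols=['C2'])
print(features)  # ['C1', 'C3', 'C4', 'C5']
```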

### Step 3: Set Up Snowflake ShardedDataConnector
Wrap the Snowpark DataFrame in a `ShardedDataConnector`, which splits the data into shards so that each training worker reads only its own portion of the table.

```python
from snowflake.ml.data.sharded_data_connector import ShardedDataConnector

# Drop the 'C2' datetime column from the Snowpark DataFrame as well,
# so the sharded data has the same columns as pandas_df
snowpark_df = snowpark_df.drop(['C2'])

# Create the ShardedDataConnector; each training worker will read its own shard
data_connector = ShardedDataConnector.from_dataframe(snowpark_df)
```
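
The property that matters for data-parallel training is that the shards are disjoint and together cover every row, so each row is seen by exactly one worker per epoch. The sketch below illustrates this with round-robin assignment in a hypothetical `shard_rows` function; it is not necessarily the partitioning scheme `ShardedDataConnector` uses internally.

```python
from typing import List


def shard_rows(num_rows: int, world_size: int, rank: int) -> List[int]:
    """Hypothetical round-robin sharding: worker `rank` gets rows rank, rank + world_size, ..."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank must be in [0, {world_size}), got {rank}")
    return list(range(rank, num_rows, world_size))


# Two workers, matching the two workers configured for PyTorchDistributor in Step 5
shards = [shard_rows(num_rows=10, world_size=2, rank=r) for r in range(2)]
all_rows = sorted(i for shard in shards for i in shard)
assert all_rows == list(range(10))  # every row belongs to exactly one shard
```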

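### Step 4: Define and Train the PyTorch Model
We define a small feed-forward network and a training function that `PyTorchDistributor` runs on every worker. Inside the function, the model is wrapped in PyTorch's `DistributedDataParallel` (DDP): each worker trains on its own data shard, and gradients are averaged across workers on every backward pass, so all model replicas stay in sync.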

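```python
# Import the necessary PyTorch libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader


# A simple two-layer feed-forward network for regression
class SimpleNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# Training function executed by every worker that PyTorchDistributor launches
def train_func():
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from snowflake.ml.modeling.distributors.pytorch import get_context

    context = get_context()
    rank = context.get_rank()  # on a single node, the rank doubles as the GPU index

    # Use this worker's GPU if one is available, otherwise fall back to the CPU
    use_cuda = torch.cuda.is_available()
    device = torch.device(f'cuda:{rank}') if use_cuda else torch.device('cpu')
    if use_cuda:
        torch.cuda.set_device(device)
    # NCCL is PyTorch's recommended backend for GPU training; Gloo handles CPU training
    dist.init_process_group(backend='nccl' if use_cuda else 'gloo')

    # Initialize the model, loss function, and optimizer
    model = SimpleNet(input_size=len(input_cols), hidden_size=32, output_size=1).to(device)
    model = DDP(model, device_ids=[rank] if use_cuda else None)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Retrieve this worker's shard of the training data
    dataset_map = context.get_dataset_map()
    torch_dataset = dataset_map['train'].get_shard().to_torch_dataset(batch_size=1024)
    # Each dataset item is already a batch of rows, so the DataLoader's default
    # batch_size=1 only adds a leading dimension: each column arrives as shape (1, B)
    dataloader = DataLoader(torch_dataset)

    # Training loop
    num_epochs = 10
    for epoch in range(num_epochs):
        for batch_dict in dataloader:
            # (1, B) -> (B, 1) per column, concatenated into a (B, num_features) matrix
            features = torch.cat([batch_dict[col].T for col in input_cols], dim=1).float().to(device)
            # Labels as (B, 1) to match the model output; a target of any other shape
            # would be broadcast by MSELoss and produce a wrong loss
            labels = batch_dict[label_col].T.float().to(device)
            output = model(features)
            loss = criterion(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Report the loss of the last batch in this epoch
        print(f'Rank {rank}, Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

    dist.destroy_process_group()
    print('Training finished')
```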
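
Wrapping the model in `DistributedDataParallel` means each worker computes gradients on its own shard, and DDP averages those gradients across workers during `loss.backward()`, before `optimizer.step()` runs. The plain-Python sketch below checks why this matches training on all of the data at once: with equal-size shards, the average of the per-shard mean-squared-error gradients equals the full-data gradient. The one-parameter model, the toy data, and `allreduce_mean` (a stand-in for DDP's all-reduce collective) are all hypothetical.

```python
from typing import List, Sequence


def mse_grad(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to the single weight w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def allreduce_mean(values: List[float]) -> float:
    """Hypothetical stand-in for DDP's all-reduce: every worker ends up with the average."""
    return sum(values) / len(values)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5

# Two workers, each holding an equal-size shard of the data
shard_data = [(xs[0::2], ys[0::2]), (xs[1::2], ys[1::2])]
local_grads = [mse_grad(w, sx, sy) for sx, sy in shard_data]  # computed independently per worker
synced = allreduce_mean(local_grads)                           # what every worker applies
full = mse_grad(w, xs, ys)                                     # single-process, full-data gradient
print(local_grads, synced, full)  # [-15.0, -30.0] -22.5 -22.5
```

With unequal shard sizes, the simple average weights each worker's rows slightly differently, so it only approximates the full-data gradient.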

### Step 5: Run Distributed Training
Configure and start the distributed training process using `PyTorchDistributor`.

```python
from snowflake.ml.modeling.distributors.pytorch import PyTorchDistributor, PyTorchScalingConfig, WorkerResourceConfig

# Set up PyTorchDistributor: one node running two workers, each assigned one GPU
pytorch_trainer = PyTorchDistributor(
    train_func=train_func,
    scaling_config=PyTorchScalingConfig(
        num_nodes=1,
        num_workers_per_node=2,
        resource_requirements_per_worker=WorkerResourceConfig(num_cpus=0, num_gpus=1),
    ),
)

# Run the training process; inside train_func, the connector is available as dataset_map['train']
pytorch_trainer.run(dataset_map={'train': data_connector})
```
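
The scaling configuration determines how many training processes run and how many rows each optimizer step effectively sees. A small sketch of that arithmetic, using a hypothetical `ScalingPlan` dataclass (not part of the Snowflake API) and assuming the total number of workers is `num_nodes * num_workers_per_node`, with the 1024-row batches used in `train_func`:

```python
from dataclasses import dataclass


@dataclass
class ScalingPlan:
    """Hypothetical helper mirroring the PyTorchScalingConfig values above."""
    num_nodes: int
    num_workers_per_node: int
    gpus_per_worker: int
    batch_size_per_worker: int

    @property
    def world_size(self) -> int:
        # Total number of training processes (DDP replicas)
        return self.num_nodes * self.num_workers_per_node

    @property
    def total_gpus(self) -> int:
        return self.world_size * self.gpus_per_worker

    @property
    def global_batch_size(self) -> int:
        # Each optimizer step averages gradients over one batch from every worker
        return self.world_size * self.batch_size_per_worker


plan = ScalingPlan(num_nodes=1, num_workers_per_node=2, gpus_per_worker=1, batch_size_per_worker=1024)
print(plan.world_size, plan.total_gpus, plan.global_batch_size)  # 2 2 2048
```

For the configuration above, this gives two workers on two GPUs, and each optimizer step averages over up to 2,048 rows (the last batch of a shard can be smaller).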

## Conclusion
In this notebook, we demonstrated how to:
- Set up a Snowflake session
- Load and prepare data using a Snowflake `ShardedDataConnector`
- Train a PyTorch model using the `PyTorchDistributor` API with multi-GPU support
- Monitor the training loss after each epoch

You can now apply these steps to your own datasets and models!