[![AWS Data Wrangler](_static/logo.png "AWS Data Wrangler")](https://github.com/awslabs/aws-data-wrangler)

# PyTorch

## Table of Contents
* [1.Defining Training Function](#1.-Defining-Training-Function)
* [2.Training From Amazon S3](#2.-Traoning-From-Amazon-S3)
	* [2.1 Writing PyTorch Dataset to S3](#2.1-Writing-PyTorch-Dataset-to-S3)
	* [2.2 Training Network](#2.2-Training-Network)
* [3. Training From SQL Query](#3.-Training-From-SQL-Query)
	* [3.1 Writing Data to SQL Database](#3.1-Writing-Data-to-SQL-Database)
	* [3.3 Training Network From SQL](#3.3-Reading-single-JSON-file)
* [4. Creating Custom S3 Dataset](#4.-Creating-Custom-S3-Dataset)
	* [4.1 Creating Custom PyTorch Dataset](#4.1-Creating-Custom-PyTorch-Dataset)
	* [4.2 Writing Data to S3](#4.2-Writing-Data-to-S3)
	* [4.3 Training Network](#4.4-Training-Network)
* [5. Delete objects](#5.-Delete-objects)

In [1]:
import io

import boto3
import torch
import torchvision
import pandas as pd
import awswrangler as wr

from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

In [2]:
import getpass
bucket = getpass.getpass()

········


# 1. Defining Training Function

In [3]:
def train(model, dataset, batch_size=64, epochs=2, device='cpu', num_workers=1):

    criterion = CrossEntropyLoss().to(device)
    opt = SGD(model.parameters(), 0.025)
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

    for epoch in range(epochs):

        correct = 0    
        model.train()
        for i, (inputs, labels) in enumerate(loader):

            # Forward Pass
            outputs = model(inputs)
            
            # Backward Pass
            loss = criterion(outputs, labels)
            loss.backward()
            opt.step()
            opt.zero_grad()
            
            # Accuracy
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
            accuracy = 100 * correct / ((i+1) * batch_size)

            print(f'batch: {i} loss: {loss.mean().item():.4f} acc: {accuracy:.2f}')   

# 2. Training From Amazon S3

## 2.1 Writing PyTorch Dataset to S3

In [4]:
client_s3 = boto3.client("s3")
folder = "tutorial_torch_dataset"

wr.s3.delete_objects(f"s3://{bucket}/{folder}")
for i in range(3):
    batch = (
        torch.randn(100, 3, 32, 32),
        torch.randint(2, size=(100,)),
    )
    buff = io.BytesIO()
    torch.save(batch, buff)
    buff.seek(0)
    client_s3.put_object(
        Body=buff.read(),
        Bucket=bucket,
        Key=f"{folder}/file{i}.pt",
    )

## 2.2 Training Network

In [6]:
train(
    torchvision.models.resnet18(),
    wr.torch.S3IterableDataset(path=f"{bucket}/{folder}")
)

batch: 0 loss: 7.0132 acc: 0.00
batch: 1 loss: 2.8764 acc: 21.09
batch: 2 loss: 0.9600 acc: 32.29
batch: 3 loss: 0.8676 acc: 36.33
batch: 4 loss: 1.1386 acc: 36.88
batch: 0 loss: 1.0754 acc: 51.56
batch: 1 loss: 1.4241 acc: 51.56
batch: 2 loss: 1.3019 acc: 51.04
batch: 3 loss: 0.8631 acc: 53.52
batch: 4 loss: 0.4252 acc: 54.38


# 2. Training Directly From SQL Query

## 2.1 Writing Data to SQL Database

In [7]:
eng = wr.catalog.get_engine("aws-data-wrangler-redshift")
df = pd.DataFrame({
    "height": [2, 1.4, 1.7, 1.8, 1.9, 2.2],
    "weight": [100.0, 50.0, 70.0, 80.0, 90.0, 160.0],
    "target": [1, 0, 0, 1, 1, 1]
})

wr.db.to_sql(
    df,
    eng,
    schema="public",
    name="torch",
    if_exists="replace",
    index=False
)

## 2.2 Training Network From SQL

In [8]:
train(
    torch.nn.Sequential(
        torch.nn.Linear(2, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 2),    
    ),
    wr.torch.SQLDataset(
        sql="SELECT * FROM public.torch",
        con=eng,
        label_col="target",
        chunksize=2
    ),
    num_workers=0,
    batch_size=2,
    epochs=5
)

batch: 0 loss: 8.8708 acc: 50.00
batch: 1 loss: 88.7789 acc: 50.00
batch: 2 loss: 0.8655 acc: 33.33
batch: 0 loss: 0.7036 acc: 50.00
batch: 1 loss: 0.7034 acc: 50.00
batch: 2 loss: 0.8447 acc: 33.33
batch: 0 loss: 0.7012 acc: 50.00
batch: 1 loss: 0.7010 acc: 50.00
batch: 2 loss: 0.8250 acc: 33.33
batch: 0 loss: 0.6992 acc: 50.00
batch: 1 loss: 0.6991 acc: 50.00
batch: 2 loss: 0.8063 acc: 33.33
batch: 0 loss: 0.6975 acc: 50.00
batch: 1 loss: 0.6974 acc: 50.00
batch: 2 loss: 0.7886 acc: 33.33


# 3. Delete Objects

In [9]:
wr.s3.delete_objects(f"s3://{bucket}/{folder}")