[![AWS Data Wrangler](_static/logo.png "AWS Data Wrangler")](https://github.com/awslabs/aws-data-wrangler)

# PyTorch

## Table of Contents
* [1. Defining Training Function](#1.-Defining-Training-Function)
* [2. Training From Amazon S3](#2.-Training-From-Amazon-S3)
	* [2.1 Writing PyTorch Dataset to S3](#2.1-Writing-PyTorch-Dataset-to-S3)
	* [2.2 Training Network](#2.2-Training-Network)
* [3. Training From SQL Query](#3.-Training-From-SQL-Query)
	* [3.1 Writing Data to SQL Database](#3.1-Writing-Data-to-SQL-Database)
	* [3.2 Training Network From SQL](#3.2-Training-Network-From-SQL)
* [4. Creating Custom S3 Dataset](#4.-Creating-Custom-S3-Dataset)
	* [4.1 Creating Custom PyTorch Dataset](#4.1-Creating-Custom-PyTorch-Dataset)
	* [4.2 Writing Data to S3](#4.2-Writing-Data-to-S3)
	* [4.3 Training Network](#4.3-Training-Network)
* [5. Delete Objects](#5.-Delete-Objects)

In [None]:
import io
import boto3
import torch
import torchvision
import awswrangler as wr

def accuracy(o, l):  # noqa: E741 - parameter names kept for existing callers
    """Return the percentage of rows in ``o`` whose arg-max equals ``l``.

    Args:
        o: Tensor of model outputs; for each row, the index of the largest
            value along dim 1 is taken as the predicted class.
        l: Tensor of target class indices, compared element-wise with the
            predicted indices.

    Returns:
        float: Percentage (0-100) of rows whose prediction matches the label.

    Raises:
        ZeroDivisionError: If ``o`` has no rows.
    """
    # torch.max(t, 1) returns (values, indices); the indices are the classes.
    predicted = torch.max(o.data, 1)[1]
    return 100 / o.size(0) * (predicted == l).sum().item()

In [None]:
import getpass

# Ask for the bucket name interactively so it is not hard-coded in the
# notebook. An explicit prompt replaces getpass's default "Password: ",
# which is misleading when the expected input is a bucket name.
bucket = getpass.getpass("S3 bucket name: ")

# 1. Defining Training Function

In [None]:
def train(model, dataset, batch_size=64, epochs=2, lr=0.025, num_workers=2):
    """Train ``model`` on ``dataset`` with SGD and cross-entropy loss.

    Prints the batch index, loss and batch accuracy after every step.

    Args:
        model: ``torch.nn.Module`` whose parameters are optimised in place.
        dataset: Dataset accepted by ``torch.utils.data.DataLoader``; each
            batch must unpack into ``(inputs, labels)``.
        batch_size: Number of samples per batch.
        epochs: Number of full passes over ``dataset``.
        lr: Learning rate handed to ``torch.optim.SGD``.
        num_workers: Number of ``DataLoader`` worker processes.

    Returns:
        None. The model is updated in place.
    """
    criterion = torch.nn.CrossEntropyLoss()
    opt = torch.optim.SGD(model.parameters(), lr)
    # Loop-invariant: build the loader once and re-iterate it every epoch.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    for _ in range(epochs):
        model.train()
        # enumerate supplies the batch index that the progress line reports.
        for i, (inputs, labels) in enumerate(loader):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            opt.step()
            # Clear gradients so the next backward() does not accumulate.
            opt.zero_grad()

            acc = accuracy(outputs, labels)
            print(f'batch: {i} loss: {loss.mean().item():.4f} batch_acc: {acc:.2f}')

# 2. Training From Amazon S3

## 2.1 Writing PyTorch Dataset to S3

In [None]:
client_s3 = boto3.client("s3")
folder = "tutorial_torch_dataset"
# Upper bound for the random labels. torch.randint's bound is exclusive, so
# the previous bound of 1 produced only the label 0.
num_classes = 10

for i in range(3):
    # One (inputs, labels) pair per object: 100 random 3x32x32 tensors and
    # 100 integer labels in [0, num_classes).
    batch = (
        torch.randn(100, 3, 32, 32),
        torch.randint(num_classes, size=(100,)),
    )
    buff = io.BytesIO()
    torch.save(batch, buff)
    client_s3.put_object(
        # getvalue() returns the whole buffer regardless of stream position.
        Body=buff.getvalue(),
        Bucket=bucket,
        Key=f"{folder}/file{i}.pt",
    )

## 2.2 Training Network

In [None]:
# Fit a default-constructed ResNet-18 on the samples that S3IterableDataset
# yields from the tutorial prefix.
train(
    model=torchvision.models.resnet18(),
    dataset=wr.torch.S3IterableDataset(path=f"s3://{bucket}/{folder}"),
)

# 3. Training From SQL Query

## 3.1 Writing Data to SQL Database

In [None]:
import pandas as pd

eng = wr.catalog.get_engine("aws-data-wrangler-redshift")

# All columns must have the same length, otherwise pandas raises ValueError.
# Both feature columns are numeric and the labels are limited to {0, 1},
# because the network trained on this table is built from Linear layers and
# ends in a 2-unit output.
df = pd.DataFrame({
    "height": [2, 1.4, 1.7, 1.8, 1.9, 2.2],
    "weight": [100.0, 50.0, 70.0, 80.0, 90.0, 160.0],
    "target": [1, 0, 0, 1, 1, 1],
})

# Write through the engine created above; if_exists="replace" makes the cell
# safe to re-run.
wr.db.to_sql(
    df,
    eng,
    schema="public",
    name="torch",
    if_exists="replace",
    index=False,
)

## 3.2 Training Network From SQL

In [None]:
# Both arguments are positional: a positional argument may not follow a
# keyword argument in a call.
train(
    torch.nn.Sequential(
        # in_features=2 assumes the table has two feature columns besides
        # the label column - TODO confirm against the table schema.
        torch.nn.Linear(2, 20),
        torch.nn.ReLU(),
        torch.nn.Linear(20, 2),
    ),
    wr.torch.SQLDataset(
        sql="SELECT * FROM public.torch",
        con=eng,
        label_col="target",
        chunksize=100,
    ),
)

# 5. Delete Objects

In [None]:
# Remove only the objects this tutorial wrote. Passing the bare bucket root
# would delete every object in the bucket. The trailing slash keeps sibling
# prefixes that merely start with the same name from matching.
wr.s3.delete_objects(f"s3://{bucket}/{folder}/")