# PyTorch: Training ResNet50 using the AISShardReader and WebDataset

We train the ResNet50 model using dummy ImageNet data sharded in the WebDataset format. Note that you can download the actual ImageNet dataset and use that instead if you would like.

#### 1) Import necessary packages, define constants, and create AIS Client.

In [2]:
try:
    from aistore.sdk import Client
    from aistore.pytorch.shard_reader import AISShardReader
except ImportError:
    # Use local version of aistore if pip version is too old or aistore not installed.
    # Only ImportError is caught so unrelated failures (e.g. KeyboardInterrupt)
    # are not silently redirected to the fallback.
    import sys

    # Put the local checkout first so it wins over any installed package.
    sys.path.insert(0, "../../")

    # A too-old installed aistore may already be cached from the failed attempt
    # above; drop it so the imports below resolve against the local copy.
    for _cached in [
        _name
        for _name in sys.modules
        if _name == "aistore" or _name.startswith("aistore.")
    ]:
        del sys.modules[_cached]

    from aistore.sdk import Client
    from aistore.pytorch.shard_reader import AISShardReader

import torchvision.transforms as transforms
from torchvision.models import resnet50
from torch.utils.data import DataLoader
from torch import nn, optim, no_grad, max, stack, tensor
from random import shuffle
from PIL import Image
from io import BytesIO

import requests

In [3]:
# Bucket that will hold the WebDataset shards, and where the shards come from.
BCK_NAME = "fake-imagenet"
DATASET_URL = "https://storage.googleapis.com/webdataset/fake-imagenet/imagenet-train-{000000..001281}.tar"

# AIS cluster endpoint and the provider of the bucket.
AIS_ENDPOINT = "http://localhost:8080"
AIS_PROVIDER = "ais"

client = Client(endpoint=AIS_ENDPOINT)

# Create the bucket if it does not exist yet; an existing bucket is reused.
bucket = (
    client
    .bucket(BCK_NAME, AIS_PROVIDER)
    .create(exist_ok=True)
)

#### 2) Populate the bucket with WebDataset formatted shards using the Python SDK or the AIS CLI.

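Run the cell below to download the shards and put them into the bucket with the Python SDK, or use the AIS CLI instead: `ais start download "https://storage.googleapis.com/webdataset/fake-imagenet/imagenet-train-{000000..001281}.tar" ais://fake-imagenet`.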

The dataset comes from the official WebDataset example notebooks: https://github.com/webdataset/webdataset/blob/main/examples/train-resnet50-wds.ipynb.

In [6]:
# Populate the bucket with the dataset shards unless they are already there.
bucket = client.bucket(BCK_NAME, AIS_PROVIDER)
bucket.create(exist_ok=True)

# Decide by bucket contents rather than by whether create() fails: the bucket
# may already exist but be empty (the setup cell creates it), and a failing
# download must not be reported as "bucket already has dataset".
if bucket.list_all_objects(prefix="imagenet-train-"):
    print("Bucket already has dataset! Nothing will be done.")
else:
    print("Downloading dataset...")

    headers = {"User-Agent": "Mozilla/5.0"}

    # Shards are numbered 000000..001281 (1282 in total).
    for i in range(1282):
        tar_url = DATASET_URL.replace("{000000..001281}", f"{i:06}")
        name = tar_url.split("/")[-1]

        # The timeout keeps a stalled connection from hanging the notebook forever.
        response = requests.get(tar_url, headers=headers, timeout=60)
        response.raise_for_status()

        # response.content is already the full payload as bytes; upload it directly.
        bucket.object(name).put_content(response.content)

    print("Done putting files into buckets.")

    print("Cleaned up downloaded dataset.")

Downloading dataset...
Done putting files into buckets.
Cleaned up downloaded dataset.


#### 3) Generate random split of indices and pass to ShardReader.

In [None]:
TRAIN_SPLIT = 0.80  # fraction of the selected shards used for training
NUM_SHARDS = 100  # dataset has 1282 shards (000000..001281); take a subset to save time

# Shuffle the shard indices so the train/validation split is random.
shard_indices = list(range(NUM_SHARDS))
shuffle(shard_indices)

# Derive the boundary from TRAIN_SPLIT so changing the constant actually
# changes the split (previously the literal 0.8 was repeated here).
train_boundary = int(len(shard_indices) * TRAIN_SPLIT)

train_indices = shard_indices[:train_boundary]
validation_indices = shard_indices[train_boundary:]

In [None]:
def _shard_names(indices):
    """Return the shard object names for the given shard indices."""
    # :06 zero-pads each index to six digits, e.g. 7 -> "imagenet-train-000007.tar"
    return [f"imagenet-train-{index:06}.tar" for index in indices]


# One reader per split; each is restricted to its own shards via prefix_map.
train_shards = AISShardReader(
    bucket_list=bucket,
    prefix_map={bucket: _shard_names(train_indices)},
)

validation_shards = AISShardReader(
    bucket_list=bucket,
    prefix_map={bucket: _shard_names(validation_indices)},
)

#### 4) Create DataLoader and pass in parameters.

In [None]:
BATCH_SIZE = 100
NUM_WORKERS = 16

# Both loaders use the same batching and worker configuration.
loader_options = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS}

train_loader = DataLoader(train_shards, **loader_options)
validation_loader = DataLoader(validation_shards, **loader_options)

#### 5) Define model, hyperparameters, optimizer, transforms, and loss function for training.

In [None]:
LEARNING_RATE = 0.1
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4

# ResNet50 with default arguments to resnet50().
resnet_model = resnet50()

criterion = nn.CrossEntropyLoss()

# NOTE(review): the previous code passed LEARNING_RATE as `momentum` and
# hard-coded lr=0.01. Each hyperparameter now goes to its own argument;
# confirm that lr=0.1 / momentum=0.9 are the intended values.
optimizer = optim.SGD(
    resnet_model.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)

# Resize, center-crop to 224x224, convert to a tensor and normalize per channel.
transform_train = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

#### 6) Train the model for a number of epochs and validate accuracy.

In [None]:
NUM_EPOCHS = 1


def _decode_batch(contents):
    """Turn one batch from the shard loader into model-ready tensors.

    Args:
        contents: Mapping indexed by file extension; the "jpg" entry is read
            as a batch of encoded image bytes and the "cls" entry as a batch
            of UTF-8 encoded integer class labels.

    Returns:
        Tuple ``(images, labels)``: the stacked, transformed images and the
        tensor of integer labels.
    """
    images = stack(
        [
            # convert("RGB") guarantees three channels, so grayscale or CMYK
            # JPEGs do not break Normalize/stack.
            transform_train(Image.open(BytesIO(image_bytes)).convert("RGB"))
            for image_bytes in contents["jpg"]
        ]
    )
    labels = tensor(
        [int(label_bytes.decode("utf-8")) for label_bytes in contents["cls"]]
    )
    return images, labels


for epoch in range(NUM_EPOCHS):
    print(f"EPOCH {epoch + 1}\n-----------")

    # Validation below switches to eval mode; switch back for every epoch.
    resnet_model.train()

    running_loss = 0.0  # kept separate from the per-batch `loss` tensor
    num_batches = 0
    samples_seen = 0
    for num_batches, (_, contents) in enumerate(train_loader, start=1):
        images, labels = _decode_batch(contents)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = resnet_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # accumulate statistics; count real samples since a batch may be partial
        running_loss += loss.item()
        samples_seen += len(labels)

        print(f"Batch: {num_batches}, Samples Processed: {samples_seen}")

    # Mean loss over the batches actually seen (not a hard-coded 100).
    mean_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Samples Processed: {samples_seen}, Loss: {mean_loss}")

    # Validation
    resnet_model.eval()
    correct = 0
    total = 0
    with no_grad():
        for _, contents in validation_loader:
            images, labels = _decode_batch(contents)

            outputs = resnet_model(images)
            # Index of the highest-scoring class for each sample.
            predicted = outputs.argmax(dim=1)

            correct += (predicted == labels).sum().item()
            total += len(labels)

    # Guard against an empty validation set instead of dividing by zero.
    if total:
        print(f"Accuracy: {100 * correct / total}%\n")
    else:
        print("Accuracy: n/a (no validation samples)\n")

print("-----------\nFinished Training")