# Image Classification using Vision Transformers

In [None]:
from pathlib import Path
import numpy as np
import cv2
import PIL.Image as Image
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from matplotlib.ticker import MaxNLocator
from glob import glob
import shutil

import torch, torchvision
from torch import nn, optim
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

import sagemaker


# Single seed for every RNG used below: NumPy (sample selection, shuffling
# the train/val split) and PyTorch.
RANDOM_SEED = 42
np.random.seed(seed=RANDOM_SEED)
torch.manual_seed(seed=RANDOM_SEED)

In [None]:
# S3 key prefix under which this project's data and artifacts are stored.
prefix = "sagemaker/pytorch-vit"

# SageMaker session, its default S3 bucket, and the IAM role the jobs run as.
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
role = sagemaker.get_execution_role()


The dataset used is the `Intel Image Classification` dataset.
It contains around 25k images of size 150x150 distributed across 6 categories, labelled as follows:
```
{'buildings' -> 0,
'forest' -> 1,
'glacier' -> 2,
'mountain' -> 3,
'sea' -> 4,
'street' -> 5 }
```

In [None]:
from zipfile import ZipFile
with ZipFile('data1.zip', 'r') as zipObj:
   
   zipObj.extractall('data1')

In [None]:
# The extracted archive nests each split twice: ./data1/<split>/<split>.
train_set, test_set, pred_set = (
    f'./data1/{split}/{split}' for split in ('seg_train', 'seg_test', 'seg_pred')
)

In [None]:
# Alphabetical class list; each class's label index is its position here.
class_names = ['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street']
class_indices = list(range(len(class_names)))

In [None]:
# One sub-folder per class; sorting gives a deterministic (alphabetical) order.
train_folders = sorted(glob(f'{train_set}/*'))
len(train_folders)

In [None]:
def load_image(img_path, resize=True, size=(64, 64)):
    """Read an image file as a 3-channel RGB NumPy array.

    Args:
        img_path: Path of the image file (str or path-like).
        resize: When True, resize the image to ``size``.
        size: Target ``(width, height)`` passed to ``cv2.resize``.
            Defaults to 64x64, the size used throughout this notebook.

    Returns:
        The image as a NumPy array in RGB channel order.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(str(img_path))
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        # Without this check the error surfaces later as an opaque cv2.error
        # from cvtColor.
        raise FileNotFoundError(f'Could not read image: {img_path}')

    # OpenCV loads images as BGR; convert so matplotlib/torch see RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    if resize:
        # INTER_AREA is the interpolation OpenCV recommends for shrinking.
        img = cv2.resize(img, size, interpolation=cv2.INTER_AREA)

    return img

def show_image(img_path):
    """Plot one image from disk (loaded via ``load_image``) with axes hidden."""
    plt.imshow(load_image(img_path))
    plt.axis('off')
    
def show_sign_grid(image_paths, nrow=11):
    """Plot the images at ``image_paths`` as one grid.

    Each image is loaded with ``load_image``, which resizes them all to the
    same size. The images are stacked into an (N, H, W, C) batch, converted
    to (N, C, H, W) for ``torchvision.utils.make_grid``, and drawn with
    matplotlib.

    Args:
        image_paths: Iterable of image file paths.
        nrow: Number of images per grid row (default 11, as before).
    """
    # np.stack builds one contiguous array up front. torch.as_tensor on a
    # Python list of ndarrays is very slow and emits a PyTorch UserWarning.
    images = torch.from_numpy(np.stack([load_image(path) for path in image_paths]))
    images = images.permute(0, 3, 1, 2)  # NHWC -> NCHW, as make_grid expects
    grid_img = torchvision.utils.make_grid(images, nrow=nrow)
    plt.figure(figsize=(24, 12))
    plt.imshow(grid_img.permute(1, 2, 0))  # CHW -> HWC for matplotlib
    plt.axis('off')

In [None]:
# Pick one random training image per class folder and show them in a row.
sample_images = [
    np.random.choice(glob(f'{class_dir}/*jpg'))
    for class_dir in train_folders
]
show_sign_grid(sample_images)

In [None]:
# Start from an empty ./data tree so re-running the notebook does not mix in
# files from a previous split. (`!` runs a shell command in IPython/Jupyter.)
!rm -rf data

DATA_DIR = Path('data')

# Splits to build locally; each gets one sub-folder per class:
# data/<split>/<class_name>/
DATASETS = ['train', 'val']

for ds in DATASETS:
    for cls in class_names:
        (DATA_DIR / ds / cls).mkdir(parents=True, exist_ok=True)

In [None]:
# Copy each class's training images into data/train/<class> and data/val/<class>.
#
# The original code split at [80%, 90%], which produces THREE chunks. Zipping
# them with the two-entry DATASETS silently dropped the last 10% of every class.
# Split only once, at TRAIN_FRACTION, so every image lands in exactly one of
# the two datasets (80% train / 20% val).
TRAIN_FRACTION = 0.8

for i, cls_index in enumerate(class_indices):
    class_folder = train_folders[cls_index]
    class_name = class_names[i]
    # The destination folders are created from class_names, so the source
    # folder must actually hold that class. Fail loudly rather than mislabel.
    if Path(class_folder).name != class_name:
        raise ValueError(
            f'Folder {class_folder!r} does not match class name {class_name!r}'
        )

    # sorted() makes the pre-shuffle order independent of filesystem listing
    # order, so the seeded shuffle gives the same split on every machine.
    image_paths = np.array(sorted(glob(f'{class_folder}/*jpg')))
    print(f'{class_name}: {len(image_paths)}')
    np.random.shuffle(image_paths)

    ds_split = np.split(image_paths, [int(TRAIN_FRACTION * len(image_paths))])

    for ds, images in zip(DATASETS, ds_split):
        for img_path in images:
            shutil.copy(img_path, f'{DATA_DIR}/{ds}/{class_name}/')

In [None]:
# Per-channel (R, G, B) mean and std used by Normalize. These are the commonly
# used ImageNet statistics.
mean_nums = [0.485, 0.456, 0.406]
std_nums = [0.229, 0.224, 0.225]

# train: random crop / rotation / flip augmentation, then tensor + normalize.
# val:   deterministic resize + center crop, so evaluation is repeatable.
transforms = dict(
    train=T.Compose([
        T.RandomResizedCrop(size=224),
        T.RandomRotation(degrees=15),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean_nums, std_nums),
    ]),
    val=T.Compose([
        T.Resize(size=224),
        T.CenterCrop(size=224),
        T.ToTensor(),
        T.Normalize(mean_nums, std_nums),
    ]),
)

In [None]:
# One ImageFolder per split (data/train, data/val), each with its own
# transform pipeline from `transforms`.
image_datasets = {
    split: ImageFolder(f'{DATA_DIR}/{split}', transforms[split])
    for split in DATASETS
}

# Shuffled batches of 16, loaded by 4 worker processes, for every split.
data_loaders = {
    split: DataLoader(
        image_datasets[split],
        batch_size=16,
        shuffle=True,
        num_workers=4,
    )
    for split in DATASETS
}

In [None]:
# Number of images per split, keyed in DATASETS order.
dataset_sizes = {split: len(ds) for split, ds in image_datasets.items()}
# Replace the hand-written list with the class order ImageFolder derived from
# the data/train sub-folders.
class_names = image_datasets['train'].classes
dataset_sizes

In [None]:
def imshow(inp, title=None):
    """Display a normalized image tensor with matplotlib.

    Undoes ``Normalize(mean_nums, std_nums)``, clips values to [0, 1], hides
    the axes, and sets ``title`` if given. Assumes ``inp`` is a channels-first
    (C, H, W) CPU tensor, e.g. the output of ``make_grid``.
    """
    img = inp.numpy().transpose((1, 2, 0))  # CHW -> HWC for matplotlib
    img = np.clip(np.array([std_nums]) * img + np.array([mean_nums]), 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.axis('off')
    
# Preview one augmented training batch as a grid, titled with each image's class.
inputs, classes = next(iter(data_loaders['train']))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[label] for label in classes])

In [None]:
# upload data to S3
input_path = sagemaker_session.upload_data(path='data', bucket=bucket, key_prefix=prefix)
print('input specification (in this case, just an S3 path): {}'.format(input_path))

## Training job on SageMaker

In [None]:
from datetime import datetime
from sagemaker.pytorch import PyTorch

# Timestamp shared by the training-job name here and the endpoint name later,
# so each run gets uniquely named resources.
now = datetime.now()
timestr = now.strftime("%m-%d-%Y-%H-%M-%S")
vt_training_job_name = f"vt-training-{timestr}"
print(vt_training_job_name)

# The metric regexes scrape values from the training script's log lines. They
# assume it prints e.g. "Valid_loss = 0.123;" — TODO confirm against
# code/vit-job.py.
estimator = PyTorch(
    entry_point="vit-job.py",
    source_dir="code",
    role=role,
    framework_version="1.6.0",
    py_version="py3",
    instance_count=1,
    instance_type="ml.p3.16xlarge",
    use_spot_instances=False,
    debugger_hook_config=False,
    hyperparameters={
        "epochs": 5,
        "num_classes": 6,
        "batch-size": 256,
    },
    metric_definitions=[
        {'Name': 'validation:loss', 'Regex': r'Valid_loss = ([0-9\.]+);'},
        {'Name': 'validation:accuracy', 'Regex': r'Valid_accuracy = ([0-9\.]+);'},
        {'Name': 'train:accuracy', 'Regex': r'Train_accuracy = ([0-9\.]+);'},
        {'Name': 'train:loss', 'Regex': r'Train_loss = ([0-9\.]+);'},
    ],
)
# Launch the job on the uploaded data (channel "training") and block until done.
estimator.fit({"training": input_path}, wait=True, job_name=vt_training_job_name)

In [None]:
# Re-read the job name from the estimator (it should match the job_name passed
# to fit()), e.g. when this cell runs after re-attaching to a finished job.
vt_training_job_name = estimator.latest_training_job.name
print("Vision Transformer training job name: ", vt_training_job_name)

## Deploy to a SageMaker endpoint


In [None]:
from sagemaker import get_execution_role

# Host the trained model on one ml.p3.2xlarge instance. Reusing the training
# timestamp keeps the endpoint name unique per run.
ENDPOINT_NAME = f'pytorch-inference-{timestr}'
predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type='ml.p3.2xlarge',
    endpoint_name=ENDPOINT_NAME,
)

In [None]:
import json

import boto3
import numpy as np
import requests
# NOTE: this rebinds the name `Image` (previously PIL.Image) to IPython's
# display class.
from IPython.display import Image

# `runtime` sends inference requests to the endpoint; `client` calls the
# SageMaker control-plane API (used here to describe the endpoint).
runtime = boto3.client('runtime.sagemaker')
client = boto3.client('sagemaker')

endpoint_desc = client.describe_endpoint(EndpointName=ENDPOINT_NAME)
print(endpoint_desc)
print('-' * 180)

## Predictions from the SageMaker endpoint

In [None]:
import io
import json

import requests
from IPython.display import display
from PIL import Image as PILImage

# Sample images (public multimedia-commons S3 bucket) to classify.
IMAGE_URLS = [
    "https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/019/1390196df443f2cf614f2255ae75fcf8.jpg",
    "https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/015/1390157d4caaf290962de5c5fb4c42.jpg",
    "https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/020/1390207be327f4c4df1259c7266473.jpg",
    "https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/021/139021f9aed9896831bf88f349fcec.jpg",
    "https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/028/139028d865bafa3de66568eeb499f4a6.jpg",
    "https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/030/13903090f3c8c7a708ca69c8d5d68b2.jpg",
    "https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/002/010/00201099c5bf0d794c9a951b74390.jpg",
    "https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/136/139136bb43e41df8949f873fb44af.jpg",
    "https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/145/1391457e4a2e25557cbf956aaee4345.jpg",
]

# One request body per image: a JSON object {"url": ...}. This is the same
# format the original string-built payload produced.
payload = [{"url": url} for url in IMAGE_URLS]

for item in payload:
    response = runtime.invoke_endpoint(
        EndpointName=ENDPOINT_NAME,
        ContentType='application/json',
        Body=json.dumps(item),
    )
    result = json.loads(response['Body'].read())
    # Assumes the inference script returns a list whose first element has a
    # 'prediction' key — TODO confirm against code/ inference handler.
    print('predicted:', result[0]['prediction'])

    # Fetch the image locally to show it next to its prediction. Use a timeout
    # so a stalled download cannot hang the notebook, and decode from .content
    # (which applies Content-Encoding) rather than the raw socket stream.
    img_response = requests.get(item['url'], timeout=30)
    img_response.raise_for_status()
    im = PILImage.open(io.BytesIO(img_response.content))
    display(im.resize((250, 250)))

In [None]:
# Tear down the endpoint created by estimator.deploy() above.
predictor.delete_endpoint()