In [33]:
import sagemaker

from sagemaker.pytorch import PyTorch
from sagemaker.session import Session
from sagemaker.experiments.run import Run
from sagemaker.utils import unique_name_from_base

# IAM role that SageMaker assumes when running jobs launched from this notebook.
role = sagemaker.get_execution_role()

# Shared SageMaker session; `bucket` is the *name* (a string) of the
# session's default S3 bucket.
sagemaker_session = Session()
bucket = sagemaker_session.default_bucket()

## Train

In [None]:
# Launch one PyTorch training job inside a SageMaker Experiments run.
# unique_name_from_base appends a unique suffix, so every execution of this
# cell starts a new experiment.
experiment_name = unique_name_from_base("captioner")

# S3 input channels handed to the training job. `bucket` must be the bucket
# name string here, since it is interpolated into the S3 URIs.
channels = {
    'train': f's3://{bucket}/train2017',
    'test': f's3://{bucket}/val2017',
    'text': f's3://{bucket}/annotations',
}

# Reuse the notebook's single `sagemaker_session` for both the run and the
# estimator instead of constructing extra Session objects.
with Run(
    experiment_name=experiment_name,
    run_name="run-01",
    sagemaker_session=sagemaker_session,
) as run:

    estimator = PyTorch(
        role=role,
        entry_point="train.py",
        framework_version="2.0.0",
        py_version="py310",
        instance_count=1,
        instance_type="ml.g4dn.xlarge",
        # Warm-pool retention after the job finishes, in seconds.
        keep_alive_period_in_seconds=30,
        # Extra files shipped alongside the entry point.
        dependencies=['requirements.txt', 'model.py', 'data.py'],
        # 'heads' and 'nlayers' must stay in sync with the CaptioningModel
        # arguments used when the checkpoint is loaded for inference.
        hyperparameters={"epochs": 5, "batch-size": 64, 'lr': 0.0002, 'heads': 8, 'nlayers': 4},
        sagemaker_session=sagemaker_session,
    )

    estimator.fit(channels)

## Inference

In [34]:
import json
import boto3
import torch
import matplotlib.pyplot as plt

from model import CaptioningModel
from data import LocalDataset, collate_fn
from torch.utils.data import DataLoader
from torchvision.models import ResNet50_Weights

In [35]:
# Prepare everything needed for local (CPU) inference: vocabulary, data
# iterator, image preprocessing and the trained model.

# Keep `bucket` as the bucket-name string and bind the boto3 handle to its own
# name. Rebinding `bucket` to the handle made this cell fail when re-run and
# broke the S3 URIs built from `bucket` in the training cell.
s3_bucket = boto3.resource('s3').Bucket(bucket)
vocab = json.load(s3_bucket.Object('annotations/vocab.json').get()['Body'])

device = torch.device('cpu')

dataset = LocalDataset('val2017')
# Preprocessing pipeline bundled with the default ResNet50 weights.
img_transforms = ResNet50_Weights.DEFAULT.transforms()

batch_size = 1
# `loader` is an iterator over the DataLoader: the generation cell pulls one
# sample per execution with next(loader). Re-run this cell to start over.
loader = iter(
    DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )
)

# Architecture arguments must match the hyperparameters used for training.
model = CaptioningModel(vocab_size=len(vocab),
                        d_model=512,
                        nheads=8,
                        nlayers=4)

# download checkpoint with boto3 download_file() or via console
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints produced by your own training job.
checkpoint = torch.load('checkpoint.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

In [None]:
# Greedy caption generation for the next sample from `loader`, plotted with
# the ground-truth caption above the image and the prediction below it.
CAP_LENGTH = 16          # number of tokens to generate
meta_tokens = [0, 1, 2]  # ids dropped from the displayed text (presumably pad/start/end - verify against the vocab)


def _decode(token_ids):
    """Turn a list of token ids into a capitalized sentence.

    Ids listed in `meta_tokens` are skipped; `vocab` is keyed by the
    string form of each id.
    """
    words = [vocab[str(token)] for token in token_ids if token not in meta_tokens]
    return ' '.join(words).capitalize()


image, target = next(loader)

# Drop the first token of the reference caption.
target = target[:, 1:]

batch_size = 1
# Use int64 throughout: argmax returns int64, and concatenating it onto an
# int32 seed silently changed the decoder-input dtype after the first step.
captions = torch.zeros((batch_size, CAP_LENGTH), dtype=torch.long)
_target = torch.ones((batch_size, 1), dtype=torch.long)  # decoding starts from token id 1

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    # Preprocessing does not depend on the decoding step, so do it once.
    source = img_transforms(image)

    for i in range(CAP_LENGTH):
        # Call the module rather than .forward() so registered hooks run.
        scores = model(source, _target)
        scores = scores[:, -1, :]  # scores for the most recent position only

        predictions = torch.argmax(scores, dim=1)
        captions[:, i] = predictions

        # Feed the chosen token back in as input for the next step.
        _target = torch.cat([_target, predictions.unsqueeze(1)], dim=1)

reference = _decode(target.squeeze(0).tolist())
hypothesis = _decode(captions.squeeze(0).tolist())

# Drop the batch axis and move the first remaining axis last for imshow.
image = image.squeeze(0).permute(1, 2, 0)

plt.imshow(image)
plt.text(0.5, 1.05, 'Caption: ' + reference, fontsize=8, ha='center', transform=plt.gca().transAxes)
plt.text(0.5, -.05, 'Predicted: ' + hypothesis, fontsize=8, ha='center', transform=plt.gca().transAxes)
plt.axis('off')
plt.show()