# Develop training and inference scripts for Script Mode

## Overview
In this notebook, we will learn how to develop training and inference scripts using the Hugging Face framework. We will use the SageMaker prebuilt containers for Hugging Face (with a PyTorch backend).

We chose a typical NLP task: text classification. We will use the `20 Newsgroups` dataset, which contains roughly 20,000 newsgroup documents spread across 20 different newsgroups.

By the end of this notebook, you will learn how to:
- prepare a text corpus for training and inference with Amazon SageMaker;
- develop a training script that runs in the prebuilt Hugging Face container;
- configure and schedule a training job;
- develop inference code;
- configure and deploy a real-time inference endpoint;
- test the SageMaker endpoint.

## Preparing Dataset
First, we will download the dataset using the `fetch_20newsgroups` utility from `sklearn`.

In [9]:
from sklearn.datasets import fetch_20newsgroups

# We select 6 out of 20 diverse newsgroups
categories = [
    "comp.windows.x",
    "rec.autos",
    "sci.electronics",
    "misc.forsale",
    "talk.politics.misc",
    "alt.atheism"
]

train_dataset = fetch_20newsgroups(subset="train",
                                   categories=categories,
                                   shuffle=True,
                                   random_state=42)
test_dataset = fetch_20newsgroups(subset="test",
                                  categories=categories,
                                  shuffle=True,
                                  random_state=42)

print(f"Number of training samples: {len(train_dataset['data'])}")
print(f"Number of test samples: {len(test_dataset['data'])}")

print("=========================")
print(f"Sample news article: {train_dataset['data'][1]}")


Number of training samples: 3308
Number of test samples: 2203
Sample news article: From: sjp@hpuerca.atl.hp.com (Steve Phillips)
Subject: Re: SUPER MEGA AUTOMOBILE SIGHTING(s)!!!!! Exotics together!
Organization: Hewlett-Packard NARC Atlanta
X-Newsreader: Tin 1.1.3 PL5
Lines: 8

Give out the address, I'll drive by and take a look myself, then post.


--
Stephen Phillips
Atlanta Response Center
Atlanta, Ga.
Home of the Braves!
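The sample article above starts with an e-mail style header (`From:`, `Subject:`, `Organization:`), which can let a model key on metadata rather than on the message content. `fetch_20newsgroups` can drop it for you via its `remove` parameter (for example `remove=("headers", "footers", "quotes")`). If you prefer to keep the raw corpus and strip headers yourself, here is a minimal sketch, assuming the header always ends at the first blank line as it does in the sample; `strip_header` is a hypothetical helper name, not part of scikit-learn's public API.

```python
def strip_header(post: str) -> str:
    """Return the body of a newsgroup post, dropping the header block.

    Assumption: the header is separated from the body by the first blank
    line ("\n\n"), as in the sample article. Posts without a blank line
    are returned unchanged.
    """
    _header, sep, body = post.partition("\n\n")
    return body if sep else post


sample = (
    "From: sjp@hpuerca.atl.hp.com (Steve Phillips)\n"
    "Subject: Re: SUPER MEGA AUTOMOBILE SIGHTING(s)!!!!! Exotics together!\n"
    "Lines: 8\n"
    "\n"
    "Give out the address, I'll drive by and take a look myself, then post.\n"
)
print(strip_header(sample))
```

You could apply it to a whole split with a list comprehension such as `[strip_header(p) for p in train_dataset.data]`.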



Next, we save the selected datasets to files and upload them to Amazon S3. SageMaker will download them into the training container at training time.

In [12]:
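from pickle import dump

# Serialize each split to its own file (the test split must not overwrite train.pkl)
with open("train.pkl", "wb") as f:
    dump(train_dataset, f)
with open("test.pkl", "wb") as f:
    dump(test_dataset, f)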
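The upload step itself is not shown in the cell above. A minimal sketch, assuming you use the SageMaker Python SDK, relies on `sagemaker.Session.upload_data(path, bucket, key_prefix)`, which uploads a local file and returns its S3 URI. The helper name `upload_splits` and the way it is parameterized are hypothetical; the function only assumes that `session` has an `upload_data` method with that signature, so it can be exercised without AWS credentials.

```python
from typing import Dict, Iterable


def upload_splits(session, bucket: str, prefix: str,
                  paths: Iterable[str]) -> Dict[str, str]:
    """Upload local files to S3 and return a {local_path: s3_uri} mapping.

    `session` is expected to behave like `sagemaker.Session`, whose
    `upload_data(path=..., bucket=..., key_prefix=...)` returns the S3 URI
    of the uploaded object.
    """
    uris = {}
    for path in paths:
        uris[path] = session.upload_data(path=path, bucket=bucket,
                                         key_prefix=prefix)
    return uris
```

In the notebook you might call it as `upload_splits(sagemaker.Session(), sess.default_bucket(), "newsgroups", [...])`, passing the pickle files written above; the `"newsgroups"` prefix is only a placeholder. The returned URIs are what you would later hand to the training job as input channels.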

In [15]:
import pickle

# Reload the training split to verify that the pickle round trip works
with open("train.pkl", "rb") as f:
    re_train = pickle.load(f)
print(re_train.data[1])

From: sjp@hpuerca.atl.hp.com (Steve Phillips)
Subject: Re: SUPER MEGA AUTOMOBILE SIGHTING(s)!!!!! Exotics together!
Organization: Hewlett-Packard NARC Atlanta
X-Newsreader: Tin 1.1.3 PL5
Lines: 8

Give out the address, I'll drive by and take a look myself, then post.


--
Stephen Phillips
Atlanta Response Center
Atlanta, Ga.
Home of the Braves!
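Pickling the scikit-learn `Bunch` object is convenient, but unpickling it requires scikit-learn to be importable, which may not be the case inside the Hugging Face training container. A hedged alternative, assuming your training script is written to read plain records, is to store each split as JSON Lines with one `{"text": ..., "label": ...}` object per line. The helper `to_jsonl` and the field names are hypothetical choices, not a SageMaker requirement.

```python
import json
from typing import Sequence


def to_jsonl(texts: Sequence[str], labels: Sequence[int], path: str) -> int:
    """Write parallel text/label sequences to a JSON Lines file.

    Returns the number of records written. Labels are cast to int so that
    NumPy integer types (such as the values in `Bunch.target`) serialize
    cleanly. json.dumps escapes newlines inside the text, so every record
    stays on a single line.
    """
    if len(texts) != len(labels):
        raise ValueError("texts and labels must have the same length")
    with open(path, "w", encoding="utf-8") as f:
        for text, label in zip(texts, labels):
            f.write(json.dumps({"text": text, "label": int(label)}) + "\n")
    return len(texts)
```

For example, `to_jsonl(train_dataset.data, train_dataset.target, "train.jsonl")`. Files in this format can be read back with the Hugging Face `datasets` library via `load_dataset("json", data_files="train.jsonl")`.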



In [17]:
! pip install transformers datasets

Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com
Collecting transformers
  Downloading transformers-4.10.3-py3-none-any.whl (2.8 MB)
Collecting datasets
  Downloading datasets-1.12.1-py3-none-any.whl (270 kB)
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
Collecting sacremoses
  Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)
Collecting huggingface-hub>=0.0.12
  Downloading huggingface_hub-0.0.17-py3-none-any.whl (52 kB)
Collecting tqdm>=4.27
  Downloading tqdm-

In [None]:
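from datasets import load_dataset

# Alternative: load 20 Newsgroups from the Hugging Face Hub instead of sklearn.
# Note: this Hub dataset may require a configuration name as a second argument.
dataset = load_dataset("newsgroup")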
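Whichever loader you use, the classifier needs a mapping between integer class ids and newsgroup names; with scikit-learn, `target_names[i]` is the name of class `i`. A minimal sketch of building both directions follows. Passing the result as the `id2label`/`label2id` arguments of `AutoModelForSequenceClassification.from_pretrained` is an assumption about how the training script will be written; `build_label_maps` is a hypothetical helper.

```python
from typing import Dict, List, Tuple


def build_label_maps(target_names: List[str]) -> Tuple[Dict[int, str], Dict[str, int]]:
    """Return (id2label, label2id) for class names indexed by class id."""
    id2label = {i: name for i, name in enumerate(target_names)}
    label2id = {name: i for i, name in id2label.items()}
    return id2label, label2id


id2label, label2id = build_label_maps(["alt.atheism", "comp.windows.x", "misc.forsale"])
print(id2label[1])  # comp.windows.x
print(label2id["misc.forsale"])  # 2
```

With the scikit-learn splits you would call `build_label_maps(list(train_dataset.target_names))`, so the ids match the integers stored in `train_dataset.target`.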