# An Introduction to SageMaker Neural Topic Model (V3)

***Unsupervised representation learning and topic extraction using Neural Topic Model***

**This notebook has been migrated to SageMaker Python SDK V3**

1. [Introduction](#Introduction)
1. [Data Preparation](#Data-Preparation)
1. [Model Training](#Model-Training)
1. [Model Hosting and Inference](#Model-Hosting-and-Inference)
1. [Model Exploration](#Model-Exploration)


---
# Introduction

Amazon SageMaker Neural Topic Model (NTM) is an unsupervised learning algorithm that attempts to describe a set of observations as a mixture of distinct categories. NTM is most commonly used to discover a user-specified number of topics shared by documents within a text corpus. Here each observation is a document, the features are the presence (or occurrence count) of each word, and the categories are the topics. Since the method is unsupervised, the topics are not specified upfront and are not guaranteed to align with how a human may naturally categorize documents. Each topic is learned as a probability distribution over the words in the vocabulary, and each document, in turn, is described as a mixture of topics.
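To make the mixture idea concrete, here is a minimal NumPy sketch of how a document's expected word distribution follows from its topic proportions. It is not the NTM training algorithm (NTM learns these quantities with a neural network); the vocabulary, the topic-word matrix `beta` and the topic proportions `theta` are made-up toy values.

```python
import numpy as np

# Toy vocabulary and two topics (hypothetical values for illustration only).
vocab = ["engine", "car", "god", "church"]

# beta[k, w]: probability of word w under topic k; each row sums to 1.
beta = np.array([
    [0.50, 0.40, 0.05, 0.05],  # a "cars"-like topic
    [0.05, 0.05, 0.50, 0.40],  # a "religion"-like topic
])

# theta[k]: proportion of topic k in one document; sums to 1.
theta = np.array([0.75, 0.25])

# The document's expected word distribution mixes the rows of beta by theta.
word_dist = theta @ beta
for word, p in zip(vocab, word_dist):
    print(f"{word}: {p:.4f}")
```

Because both `theta` and every row of `beta` sum to 1, the mixed distribution also sums to 1.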

In this notebook, we will use the Amazon SageMaker NTM algorithm to train a model on the [20NewsGroups](https://archive.ics.uci.edu/ml/datasets/Twenty+Newsgroups) data set. This data set has been widely used as a topic modeling benchmark. 

The main goals of this notebook are as follows:

1. learn how to obtain and store data for use in Amazon SageMaker,
2. create an Amazon SageMaker training job on a data set to produce an NTM model,
3. use the model to perform inference with an Amazon SageMaker endpoint, and
4. explore the trained model and visualize the learned topics.

If you would like to know more, please check out the [SageMaker Neural Topic Model Documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/ntm.html).

---
# Data Preparation

The 20Newsgroups data set is a collection of approximately 20,000 newsgroup documents, partitioned (nearly) evenly across 20 different newsgroups. This collection has become a popular data set for experiments in text applications of machine learning techniques, such as text classification and text clustering. Here, we will see what topics we can learn from this set of documents with NTM.

## Fetching Data Set

First, let's define the folder that will hold the data and remove any content left over from previous experiments.

In [None]:
import os
import shutil
data_dir = '20_newsgroups'
if os.path.exists(data_dir):
    shutil.rmtree(data_dir)

In [None]:
!curl -O https://archive.ics.uci.edu/ml/machine-learning-databases/20newsgroups-mld/20_newsgroups.tar.gz

In [None]:
!tar -xzf 20_newsgroups.tar.gz
!ls 20_newsgroups

In [None]:
folders = [os.path.join(data_dir,f) for f in sorted(os.listdir(data_dir)) if os.path.isdir(os.path.join(data_dir, f))]
file_list = [os.path.join(d,f) for d in folders for f in os.listdir(d)]
print('Number of documents:', len(file_list))

In [None]:
from sklearn.datasets._twenty_newsgroups import strip_newsgroup_header, strip_newsgroup_quoting, strip_newsgroup_footer
data = []
for f in file_list:
    with open(f, 'rb') as fin:
        content = fin.read().decode('latin1')        
        content = strip_newsgroup_header(content)
        content = strip_newsgroup_quoting(content)
        content = strip_newsgroup_footer(content)        
        data.append(content)
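The three `strip_newsgroup_*` helpers come from a private scikit-learn module, so their behavior may change between releases. As a quick sanity check, this sketch runs them on a synthetic post (made up for illustration). In the scikit-learn versions we are aware of, header stripping drops everything before the first blank line, quote stripping drops lines such as `> ...` and `... writes:`, and footer stripping drops everything after the last line consisting only of dashes (or blank).

```python
from sklearn.datasets._twenty_newsgroups import (
    strip_newsgroup_footer,
    strip_newsgroup_header,
    strip_newsgroup_quoting,
)

# A synthetic newsgroup post (made up for illustration).
raw = (
    "From: alice@example.com\n"
    "Subject: Re: engines\n"
    "\n"
    "In article <1@example.com> bob@example.com writes:\n"
    "> V8 engines are loud.\n"
    "I think inline-six engines are smoother.\n"
    "--\n"
    "Alice, Example Motors\n"
)

step1 = strip_newsgroup_header(raw)     # drops the header block before the first blank line
step2 = strip_newsgroup_quoting(step1)  # drops quoted lines and "... writes:" attributions
step3 = strip_newsgroup_footer(step2)   # drops the signature after the "--" separator
print(repr(step3))
```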

---
## From Plain Text to Bag-of-Words (BOW)

In [None]:
!pip install nltk
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('wordnet')
from nltk import word_tokenize          
from nltk.stem import WordNetLemmatizer 
import re
token_pattern = re.compile(r"(?u)\b\w\w+\b")
class LemmaTokenizer(object):
    def __init__(self):
        self.wnl = WordNetLemmatizer()
    def __call__(self, doc):
        return [self.wnl.lemmatize(t) for t in word_tokenize(doc) if len(t) >= 2 and re.match("[a-z].*",t) 
                and re.match(token_pattern, t)]
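The filters inside `LemmaTokenizer.__call__` can be checked without NLTK. This sketch applies the same three conditions to a hand-written token list standing in for `word_tokenize` output, and skips lemmatization. The tokens are already lowercase because `CountVectorizer` lowercases text (default `lowercase=True`) before calling the tokenizer. The helper name `keep_token` is hypothetical.

```python
import re

token_pattern = re.compile(r"(?u)\b\w\w+\b")

def keep_token(t):
    """Hypothetical helper: the LemmaTokenizer conditions, minus lemmatization."""
    return (
        len(t) >= 2                                  # at least two characters
        and re.match("[a-z].*", t) is not None       # starts with a lowercase ASCII letter
        and re.match(token_pattern, t) is not None   # begins with 2+ word characters
    )

# Hand-written tokens standing in for word_tokenize output (already lowercased).
tokens = ["the", "x", "3d", "'s", "n't", "e-mail", ".", "gpu"]
kept = [t for t in tokens if keep_token(t)]
print(kept)
```

Note that a token such as `e-mail` is dropped because `token_pattern` must match two word characters at the start, and `the` survives here only because stop words are removed later by `CountVectorizer(stop_words='english')`.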

In [None]:
import time
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
vocab_size = 2000
print('Tokenizing and counting, this may take a few minutes...')
start_time = time.time()
vectorizer = CountVectorizer(input='content', analyzer='word', stop_words='english',
                             tokenizer=LemmaTokenizer(), max_features=vocab_size, max_df=0.95, min_df=2)
vectors = vectorizer.fit_transform(data)
vocab_list = vectorizer.get_feature_names_out()
print('vocab size:', len(vocab_list))
print('vectors shape:', vectors.shape)

idx = np.arange(vectors.shape[0])
np.random.shuffle(idx)
vectors = vectors[idx]

print('Done. Time elapsed: {:.2f}s'.format(time.time() - start_time))


In [None]:
threshold = 25
vectors = vectors[np.array(vectors.sum(axis=1)>threshold).reshape(-1,)]
print('removed short docs (<= {} in-vocabulary words)'.format(threshold))
print(vectors.shape)
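The filter above counts only in-vocabulary tokens (after stop-word removal and the 2,000-word cap), and because it uses `>`, documents with exactly `threshold` tokens are dropped too. A toy sketch of the same row-sum mask on a hand-made sparse matrix:

```python
import numpy as np
import scipy.sparse as sparse

# Toy document-term count matrix: 3 documents x 4 vocabulary words (made-up counts).
X = sparse.csr_matrix(np.array([
    [10, 10, 10, 0],  # 30 in-vocabulary tokens
    [5, 0, 0, 1],     # 6 tokens
    [0, 20, 5, 1],    # 26 tokens
]))
threshold = 25

row_totals = np.asarray(X.sum(axis=1)).ravel()  # X.sum(axis=1) returns an (n, 1) matrix
mask = row_totals > threshold                   # keep rows with MORE than 25 tokens
X_long = X[mask]                                # boolean row indexing on a CSR matrix
print(row_totals, mask, X_long.shape)
```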

In [None]:
import scipy.sparse as sparse
vectors = sparse.csr_matrix(vectors, dtype=np.float32)
print(type(vectors), vectors.dtype)

In [None]:
n_train = int(0.8 * vectors.shape[0])

train_vectors = vectors[:n_train, :]
test_vectors = vectors[n_train:, :]

n_test = test_vectors.shape[0]
val_vectors = test_vectors[:n_test//2, :]
test_vectors = test_vectors[n_test//2:, :]
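The split takes the first 80% of the shuffled rows for training and halves the remainder, so validation gets the rounded-down half and test gets any leftover row. A small sketch of that arithmetic (the function name `split_sizes` is hypothetical):

```python
def split_sizes(n_docs, train_frac=0.8):
    """Hypothetical helper mirroring the notebook's 80/10/10 split."""
    n_train = int(train_frac * n_docs)
    n_rest = n_docs - n_train
    n_val = n_rest // 2      # val_vectors takes the first half (rounded down)
    n_test = n_rest - n_val  # test_vectors gets the remainder
    return n_train, n_val, n_test

print(split_sizes(1000))  # (800, 100, 100)
print(split_sizes(1001))  # (800, 100, 101)
```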

In [None]:
print(train_vectors.shape, test_vectors.shape, val_vectors.shape)

---
## Store Data on S3

**V3 Migration Note**: In V3, we use CSV format instead of RecordIO Protobuf format because the `sagemaker.amazon.common` module is not available.

### Setup AWS Credentials

In [None]:
import boto3
from sagemaker.core.helper.session_helper import Session, get_execution_role

sagemaker_session = Session()
role = get_execution_role()
region = sagemaker_session.boto_region_name
bucket = sagemaker_session.default_bucket()

In [None]:
prefix = '20newsgroups'

train_prefix = os.path.join(prefix, 'train')
val_prefix = os.path.join(prefix, 'val')
output_prefix = os.path.join(prefix, 'output')

s3_train_data = os.path.join('s3://', bucket, train_prefix)
s3_val_data = os.path.join('s3://', bucket, val_prefix)
output_path = os.path.join('s3://', bucket, output_prefix)
print('Training set location', s3_train_data)
print('Validation set location', s3_val_data)
print('Trained model will be saved at', output_path)
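`os.path.join` uses the separator of the operating system running the notebook, which is `/` on the Linux hosts SageMaker notebooks normally use but would be `\` on Windows. If portability matters, a sketch using `posixpath` (always `/`) or plain string formatting builds the same S3 URIs; the bucket name here is a placeholder.

```python
import posixpath

bucket = "my-example-bucket"  # placeholder bucket name
prefix = "20newsgroups"

# posixpath always joins with "/", regardless of the local operating system.
s3_train_data = posixpath.join("s3://", bucket, prefix, "train")

# Equivalent explicit formatting.
s3_val_data = f"s3://{bucket}/{prefix}/val"

print(s3_train_data)
print(s3_val_data)
```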

**V3 Migration**: Convert sparse matrices to CSV format and upload to S3

In [None]:
# Delete files left over from previous runs under the train/val prefixes
s3 = boto3.resource('s3')
bucket_obj = s3.Bucket(bucket)

for p in (train_prefix, val_prefix):
    for obj in bucket_obj.objects.filter(Prefix=p + '/'):
        obj.delete()

print("Deleted old training/validation files. The next cell uploads fresh data.")

In [None]:
# Upload each matrix as dense CSV parts. Every row gets an extra trailing 0 column,
# a workaround this notebook applies for an NTM CSV-input bug.
def upload_csv_with_extra_col(sparray, bucket, prefix, template, n_parts):
    s3_bucket = boto3.resource('s3').Bucket(bucket)
    chunk_size = sparray.shape[0] // n_parts
    for i in range(n_parts):
        # Equal-sized chunks; the last chunk also takes any leftover rows
        start = i * chunk_size
        end = (i + 1) * chunk_size if i + 1 < n_parts else sparray.shape[0]
        
        chunk = sparray[start:end].toarray().astype(int)
        fname = template.format(i)
        
        with open(fname, 'w') as f:
            for row in chunk:
                # Append the extra 0 column (NTM workaround)
                f.write(','.join(map(str, row)) + ',0\n')
        
        s3_key = os.path.join(prefix, fname)
        s3_bucket.upload_file(fname, s3_key)
        print(f'Uploaded: s3://{bucket}/{s3_key}')
        os.remove(fname)  # remove the local temporary file

upload_csv_with_extra_col(train_vectors, bucket, train_prefix, 'train_part{}.csv', 8)
upload_csv_with_extra_col(val_vectors, bucket, val_prefix, 'val_part{}.csv', 1)
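Before uploading, the chunking and row format can be checked locally without S3. This sketch reproduces the row ranges and the CSV line format of `upload_csv_with_extra_col`, including the trailing `,0` workaround column, on a tiny made-up matrix. The helper names `chunk_bounds` and `to_csv_lines` are hypothetical.

```python
import numpy as np
import scipy.sparse as sparse

def chunk_bounds(n_rows, n_parts):
    """Hypothetical helper: the same row ranges as upload_csv_with_extra_col."""
    chunk_size = n_rows // n_parts
    bounds = []
    for i in range(n_parts):
        start = i * chunk_size
        end = (i + 1) * chunk_size if i + 1 < n_parts else n_rows  # last chunk takes the remainder
        bounds.append((start, end))
    return bounds

def to_csv_lines(sparray):
    """Hypothetical helper: render sparse rows as dense integer CSV lines plus the extra 0 column."""
    dense = sparray.toarray().astype(int)
    return [",".join(map(str, row)) + ",0" for row in dense]

print(chunk_bounds(10, 3))  # [(0, 3), (3, 6), (6, 10)]
toy = sparse.csr_matrix(np.array([[0, 2, 1], [3, 0, 0]], dtype=np.float32))
print(to_csv_lines(toy))    # ['0,2,1,0', '3,0,0,0']
```

If `n_parts` exceeds the number of rows, `chunk_size` becomes 0 and every part except the last is empty, so keep `n_parts` well below the row count.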

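---
# Model Training

Retrieve the NTM container image for the current region, then configure a `ModelTrainer` and launch the training job on the uploaded CSV data.

In [None]: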
from sagemaker.core import image_uris

container = image_uris.retrieve(
    framework='ntm',
    region=region
)

In [None]:
from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import InputData, Compute

num_topics = 20

print(f'Training with feature_dim={train_vectors.shape[1]}')
trainer = ModelTrainer(
    training_image=container,
    role=role,
    compute=Compute(
        instance_count=2,
        instance_type='ml.c4.xlarge'
    ),
    hyperparameters={
        'num_topics': str(num_topics),
        'feature_dim': str(train_vectors.shape[1]),
        'mini_batch_size': '128',
        'epochs': '100',
        'num_patience_epochs': '5',
        'tolerance': '0.001'
    },
    sagemaker_session=sagemaker_session
)
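The hyperparameters are passed as strings, and `feature_dim` should track the vocabulary size (`train_vectors.shape[1]`) rather than be typed by hand. A minimal sketch of a helper that builds the dictionary from typed values; the function name `build_ntm_hyperparameters` is hypothetical and the defaults simply repeat the values used above.

```python
def build_ntm_hyperparameters(num_topics, feature_dim, mini_batch_size=128,
                              epochs=100, num_patience_epochs=5, tolerance=0.001):
    """Hypothetical helper: build the NTM hyperparameter dict with string values."""
    params = {
        "num_topics": num_topics,
        "feature_dim": feature_dim,
        "mini_batch_size": mini_batch_size,
        "epochs": epochs,
        "num_patience_epochs": num_patience_epochs,
        "tolerance": tolerance,
    }
    # Convert every value to a string, as in the ModelTrainer call above.
    return {k: str(v) for k, v in params.items()}

hp = build_ntm_hyperparameters(num_topics=20, feature_dim=2000)
print(hp)
```

In the notebook, this would be called as `build_ntm_hyperparameters(num_topics, train_vectors.shape[1])`.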



In [None]:
from sagemaker.core.shapes.shapes import S3DataSource

training_job = trainer.train(
    input_data_config=[
        InputData(
            channel_name='train',
            data_source=S3DataSource(
                s3_data_type='S3Prefix',
                s3_uri=s3_train_data,
                s3_data_distribution_type='ShardedByS3Key'
            ),
            content_type='text/csv'
        ),
        InputData(
            channel_name='validation',
            data_source=s3_val_data,
            content_type='text/csv'
        )
    ],
    wait=True,
    logs=True
)
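With `ShardedByS3Key`, SageMaker gives each training instance roughly 1/n of the S3 objects under the channel prefix instead of a full copy, which is why the training data was split into 8 parts for 2 instances. The exact assignment is internal to SageMaker; this sketch *assumes* a round-robin assignment purely to illustrate the arithmetic (`shard_keys` is a hypothetical helper).

```python
def shard_keys(keys, n_instances):
    """Hypothetical round-robin assignment of S3 object keys to training instances."""
    shards = [[] for _ in range(n_instances)]
    for i, key in enumerate(sorted(keys)):
        shards[i % n_instances].append(key)
    return shards

keys = [f"20newsgroups/train/train_part{i}.csv" for i in range(8)]
shards = shard_keys(keys, n_instances=2)
for host, files in enumerate(shards):
    print(f"instance {host}: {len(files)} files")
```

With a single training file and two instances, one instance would receive no data, so the number of parts should be at least the instance count.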

In [None]:
training_job = trainer._latest_training_job

print('Training job name: {}'.format(training_job.training_job_name))
print('Training job status: {}'.format(training_job.training_job_status))


---
# Model Hosting and Inference

**V3 Migration**: Using resource classes (`Model`, `EndpointConfig`, `Endpoint`) instead of `deploy()`

In [None]:
from sagemaker.core.resources import Model, EndpointConfig, Endpoint
import time

model_name = f"ntm-model-{int(time.time())}"
endpoint_config_name = f"ntm-endpoint-config-{int(time.time())}"
endpoint_name = f"ntm-endpoint-{int(time.time())}"

In [None]:
from sagemaker.core.resources import Model

# Create model from training job artifacts
model = Model.create(
    model_name=model_name,
    execution_role_arn=role,
    primary_container={
        'image': container,
        'model_data_url': training_job.model_artifacts.s3_model_artifacts
    }
)

print(f'Model created: {model.model_name}')


In [None]:
from sagemaker.core.resources import EndpointConfig

# Create endpoint configuration
endpoint_config = EndpointConfig.create(
    endpoint_config_name=endpoint_config_name,
    production_variants=[{
        'variant_name': 'AllTraffic',
        'model_name': model.model_name,
        'initial_instance_count': 1,
        'instance_type': 'ml.m4.xlarge'
    }]
)

print(f'Endpoint config created: {endpoint_config.endpoint_config_name}')


In [None]:
from sagemaker.core.resources import Endpoint

# Create endpoint
endpoint = Endpoint.create(
    endpoint_name=endpoint_name,
    endpoint_config_name=endpoint_config.endpoint_config_name
)

print(f'Endpoint created: {endpoint.endpoint_name}')
endpoint.wait_for_status('InService')
print('Endpoint is ready!')


In [None]:
print('Endpoint name: {}'.format(endpoint.endpoint_name))

---
## Data Serialization/Deserialization

**V3 Migration**: Using `endpoint.invoke()` method

In [None]:
import json
import numpy as np

def np2csv(arr):
    csv = '\n'.join([','.join([str(x) for x in row]) for row in arr])
    return csv

test_data = np.array(test_vectors.todense())
payload = np2csv(test_data[:5])

response = endpoint.invoke(
    body=payload,
    content_type='text/csv'
)

results = json.loads(response.body.read().decode())
print(results)
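The request and response handling can be checked offline. This sketch serializes a tiny float32 batch with the same `np2csv` logic (one CSV line per document) and parses a hand-written JSON body in the `{"predictions": [{"topic_weights": [...]}, ...]}` shape that the next cell reads. The body and its values are made up, not real endpoint output.

```python
import json
import numpy as np

def np2csv(arr):
    # Same logic as the notebook: one CSV line per row.
    return "\n".join(",".join(str(x) for x in row) for row in arr)

batch = np.array([[0.0, 2.0, 1.0], [3.0, 0.0, 0.0]], dtype=np.float32)
payload = np2csv(batch)
print(payload)

# A hand-written JSON body in the shape the notebook parses (values are made up).
mock_body = '{"predictions": [{"topic_weights": [0.7, 0.3]}, {"topic_weights": [0.1, 0.9]}]}'
results = json.loads(mock_body)
weights = np.array([p["topic_weights"] for p in results["predictions"]])
print(weights.shape)  # (number of documents, number of topics)
```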

In [None]:
predictions = np.array([prediction['topic_weights'] for prediction in results['predictions']])
print(predictions)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

fs = 12
df=pd.DataFrame(predictions.T)
df.plot(kind='bar', figsize=(16,4), fontsize=fs)
plt.ylabel('Topic assignment', fontsize=fs+2)
plt.xlabel('Topic ID', fontsize=fs+2)
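With more than a handful of documents, a bar chart per document becomes hard to read. A common summary is each document's dominant topic and how many documents each topic dominates. This sketch uses made-up weights with the same layout as `predictions` (documents x topics).

```python
import numpy as np

# Made-up topic weights for 3 documents x 4 topics (each row sums to 1).
predictions = np.array([
    [0.70, 0.10, 0.10, 0.10],
    [0.05, 0.05, 0.80, 0.10],
    [0.25, 0.25, 0.20, 0.30],
])

dominant = predictions.argmax(axis=1)  # index of the heaviest topic per document
strength = predictions.max(axis=1)     # weight of that topic
counts = np.bincount(dominant, minlength=predictions.shape[1])  # documents per dominant topic
print(dominant, strength, counts)
```

A low dominant weight, like the third document's 0.30, marks a document that genuinely mixes several topics.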

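## Stop / Close the Endpoint

**V3 Migration**: Using `endpoint.delete()` method

The Model Exploration section below sends more requests to this endpoint, so keep it running until you have finished that section. Then uncomment and run the cell below to delete the endpoint and stop incurring charges.

In [None]:
# Uncomment and run after completing the Model Exploration section
# endpoint.delete()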

---
# Model Exploration

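The trained NTM model contains learned topic representations. In this section, we explore them indirectly: we send a sample of test documents to the endpoint, collect their topic weights, and, for each topic, list the words that appear more often in documents that score high on that topic than in documents that score low. This gives an interpretable view of the topics discovered in the 20 Newsgroups data set.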

In [None]:
# Get training job reference
training_job = trainer._latest_training_job
print(f"Training job: {training_job.training_job_name}")
print(f"Status: {training_job.training_job_status}")

In [None]:
# Get topic distributions from endpoint
import numpy as np
import json

# Sample diverse documents from test set
test_data = np.array(test_vectors.todense())
sample_size = min(500, test_data.shape[0])
sample_indices = np.linspace(0, test_data.shape[0]-1, sample_size, dtype=int)
test_sample = test_data[sample_indices]

print(f"Using {sample_size} diverse samples")

# Get predictions
payload = np2csv(test_sample)
response = endpoint.invoke(body=payload, content_type="text/csv")
results = json.loads(response.body.read().decode())

# Extract topic distributions
topic_distributions = np.array([pred["topic_weights"] for pred in results["predictions"]])
print(f"Topic distributions shape: {topic_distributions.shape}")

In [None]:
# Extract distinctive words per topic
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

custom_stops = set([
    "don", "just", "think", "people", "like", "know", "time", "does", "said",
    "did", "way", "say", "good", "right", "ve", "ll", "didn", "doesn", "isn",
    "wasn", "aren", "god", "religion", "believe", "point", "things", "thing",
    "make", "want", "going", "really", "question", "post", "better", "claim"
])
all_stops = ENGLISH_STOP_WORDS.union(custom_stops)

def get_distinctive_words(topic_idx, n_words=20):
    topic_strengths = topic_distributions[:, topic_idx]
    high_threshold = np.percentile(topic_strengths, 85)
    high_mask = topic_strengths > high_threshold
    low_threshold = np.percentile(topic_strengths, 50)
    low_mask = topic_strengths < low_threshold
    
    if high_mask.sum() < 5 or low_mask.sum() < 5:
        return []
    
    high_docs = test_sample[high_mask]
    low_docs = test_sample[low_mask]
    high_freq = (high_docs > 0).sum(axis=0) / high_mask.sum()
    low_freq = (low_docs > 0).sum(axis=0) / low_mask.sum()
    diff = high_freq - low_freq
    
    filtered_words = []
    for idx in np.argsort(diff)[::-1]:
        word = vocab_list[idx]
        if word.lower() not in all_stops and len(word) > 2 and diff[idx] > 0:
            filtered_words.append((word, diff[idx]))
            if len(filtered_words) >= n_words:
                break
    return filtered_words

print("Distinctive Topic Words:")
print("="*60)
for topic_idx in range(min(5, topic_distributions.shape[1])):
    words = get_distinctive_words(topic_idx, 10)
    if words:
        print(f"\nTopic {topic_idx}: " + ", ".join([word for word, _ in words]))
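`get_distinctive_words` is a post-hoc heuristic on endpoint outputs, not a readout of the model's internal topic-word weights. For one topic, it compares the fraction of documents containing each word among the strongest documents (above the 85th percentile) with the fraction among the weakest (below the median). This sketch isolates that scoring step on toy data; `contrast_scores` is a hypothetical name, and the minimum-document check and stop-word filtering are omitted.

```python
import numpy as np

def contrast_scores(doc_word, strengths, high_pct=85, low_pct=50):
    """Document-frequency difference between high- and low-scoring documents for one topic."""
    high = strengths > np.percentile(strengths, high_pct)
    low = strengths < np.percentile(strengths, low_pct)
    high_freq = (doc_word[high] > 0).mean(axis=0)  # fraction of high docs containing each word
    low_freq = (doc_word[low] > 0).mean(axis=0)    # fraction of low docs containing each word
    return high_freq - low_freq

# Toy data: 10 documents x 3 words (made-up counts).
doc_word = np.zeros((10, 3))
doc_word[:, 1] = 1    # word 1 appears in every document
doc_word[8:, 0] = 2   # word 0 appears only in the two strongest documents
doc_word[:5, 2] = 1   # word 2 appears only in the five weakest documents
strengths = np.linspace(0.0, 0.9, 10)  # topic weight per document

scores = contrast_scores(doc_word, strengths)
print(scores)
```

A word common to every document scores 0 regardless of its raw frequency, which is why this contrast surfaces topic-specific words rather than generic ones.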

In [None]:
# Visualize topic distributions
import matplotlib.pyplot as plt

avg_topic_strength = topic_distributions.mean(axis=0)
plt.figure(figsize=(12, 6))
plt.bar(range(len(avg_topic_strength)), avg_topic_strength)
plt.xlabel("Topic Index")
plt.ylabel("Average Strength")
plt.title("Average Topic Strength Across Documents")
plt.tight_layout()
plt.show()

In [None]:
# Create word clouds for all topics
!pip install wordcloud -q
from wordcloud import WordCloud

fig, axes = plt.subplots(5, 4, figsize=(20, 25))
for idx, ax in enumerate(axes.flat):
    if idx < topic_distributions.shape[1]:
        words = get_distinctive_words(idx, 50)
        if words:
            word_freq = {word: score for word, score in words}
            wc = WordCloud(width=600, height=400,
                          background_color="white",
                          colormap="tab20").generate_from_frequencies(word_freq)
            ax.imshow(wc, interpolation="bilinear")
        ax.set_title(f"Topic {idx}", fontsize=14, fontweight="bold")
        ax.axis("off")
    else:
        ax.axis("off")
plt.tight_layout()
plt.savefig("ntm_topics.png", dpi=150, bbox_inches="tight")
plt.show()
print(f"Generated word clouds for {topic_distributions.shape[1]} topics")