#Installing required dependencies

In [1]:
import requests
import json
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from itertools import combinations
from collections import Counter

from tensorflow.keras import layers, models, Input
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping

from sklearn.model_selection import train_test_split
from ast import literal_eval
from sklearn.metrics import precision_score, recall_score, f1_score

import nltk
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
import re

In [2]:
nltk.download('stopwords')
nltk.download('wordnet')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...


True

#Downloading Dataset

In [3]:
dataset_url = "https://www.federalregister.gov/api/v1/documents"
params = {"format": "json", "per_page": 2000}

docs = []
total_documents = 10000
downloaded = 0

while downloaded < total_documents:
    response = requests.get(dataset_url, params=params)
    if response.status_code == 200:
        data = response.json()
        results = data.get("results", [])
        docs.extend(results)
        downloaded += len(results)
        dataset_url = data.get("next_page_url")
        if not dataset_url:
            break
    else:
        print(f"Failed to fetch data. HTTP Status Code: {response.status_code}")
        break

total_docs = docs[:total_documents]
with open("federal_register.json", "w", encoding="utf-8") as file:
    json.dump(total_docs, file, ensure_ascii=False)
print(f"Successfully downloaded {len(total_docs)} documents to 'federal_register.json'")

Successfully downloaded 10000 documents to 'federal_register.json'


In [4]:
data_set = pd.read_json('federal_register.json')
data_set.head()

Unnamed: 0,title,type,abstract,document_number,html_url,pdf_url,public_inspection_pdf_url,publication_date,agencies,excerpts
0,Agency Information Collection Activities; Subm...,Notice,"The Department of Commerce, in accordance with...",2024-28463,https://www.federalregister.gov/documents/2024...,https://www.govinfo.gov/content/pkg/FR-2024-12...,https://public-inspection.federalregister.gov/...,2024-12-05,"[{'raw_name': 'DEPARTMENT OF COMMERCE', 'name'...","The Department of Commerce, in accordance with..."
1,"Approval of Subzone Status; Canoo Inc., Pryor,...",Notice,,2024-28472,https://www.federalregister.gov/documents/2024...,https://www.govinfo.gov/content/pkg/FR-2024-12...,https://public-inspection.federalregister.gov/...,2024-12-05,"[{'raw_name': 'DEPARTMENT OF COMMERCE', 'name'...",
2,Foreign-Trade Zone 40; Application for Expansi...,Notice,,2024-28473,https://www.federalregister.gov/documents/2024...,https://www.govinfo.gov/content/pkg/FR-2024-12...,https://public-inspection.federalregister.gov/...,2024-12-05,"[{'raw_name': 'DEPARTMENT OF COMMERCE', 'name'...",
3,"Foreign-Trade Zone 2-New Orleans, Louisiana; A...",Notice,,2024-28471,https://www.federalregister.gov/documents/2024...,https://www.govinfo.gov/content/pkg/FR-2024-12...,https://public-inspection.federalregister.gov/...,2024-12-05,"[{'raw_name': 'DEPARTMENT OF COMMERCE', 'name'...",
4,Agency Information Collection Activities; Subm...,Notice,"The Department of Commerce, in accordance with...",2024-28460,https://www.federalregister.gov/documents/2024...,https://www.govinfo.gov/content/pkg/FR-2024-12...,https://public-inspection.federalregister.gov/...,2024-12-05,"[{'raw_name': 'DEPARTMENT OF COMMERCE', 'name'...","The Department of Commerce, in accordance with..."


#Exploratory Data Analysis

1) Number of documents

2) Number of documents per category

3) Number of categories per document

4) Average word length

5) Average word length per document

6) Agency collaboration

In [5]:
# 1. number of documents
data = json.loads(Path('federal_register.json').read_text(encoding="utf-8"))
docs_count = {doc["document_number"] for doc in data if "document_number" in doc}

print(f'Total number of documents in the dataset : {len(docs_count)}')

Total number of documents in the dataset : 10000


In [6]:
# 2. number of documents per category/agency

df = pd.read_json('federal_register.json')
#dropped rows which doesn't have abstract
df = df[~pd.isna(df['abstract'])]
#dropping rows which doesnot have name(agency_name) in agencies column
df = df[df['agencies'].apply(lambda x: all('name' in agency for agency in x))]
#creating new column 'agency_names' to store agency names separately
df.loc[:, 'agency_names'] = df['agencies'].apply(lambda x: [agency['name'] for agency in x])

new_df = df.explode('agency_names')
agencies_count = new_df['agency_names'].value_counts()
agencies_count_df = pd.DataFrame({
    'agency_name': agencies_count.index,
    'docs_count': agencies_count.values
})

#top 10 agencies with most doc count
agencies_count_df[:10]

Unnamed: 0,agency_name,docs_count
0,Commerce Department,808
1,Postal Service,799
2,Transportation Department,704
3,Health and Human Services Department,623
4,Interior Department,531
5,Homeland Security Department,521
6,Federal Aviation Administration,395
7,Environmental Protection Agency,383
8,National Oceanic and Atmospheric Administration,371
9,International Trade Administration,326


In [7]:
# 3. number of agencies per document

num_agencies_per_doc = df['agency_names'].apply(len)
doc_agency_count_df = pd.DataFrame({
    'document_number': df['document_number'],
    'num_agencies': num_agencies_per_doc
})

#Finding document with max agencies and printing it's tilte, doc_num and agencies
max_agencies_row = doc_agency_count_df[doc_agency_count_df['num_agencies'] == doc_agency_count_df['num_agencies'].max()]
document_number = max_agencies_row['document_number'].iloc[0]
agencies = df[df['document_number'] == document_number]['agency_names'].iloc[0]
title = df[df['document_number'] == document_number]['title'].iloc[0]
print("----------Document with max agencies---------------")
print(f"Document Number: {document_number}")
print(f"Title: {title}")
print(f"Agencies count: {doc_agency_count_df['num_agencies'].max()}")
print(f"Agencies: {agencies}")

----------Document with max agencies---------------
Document Number: 2024-18415
Title: Financial Data Transparency Act Joint Data Standards
Agencies count: 9
Agencies: ['Treasury Department', 'Comptroller of the Currency', 'Federal Reserve System', 'Federal Deposit Insurance Corporation', 'National Credit Union Administration', 'Consumer Financial Protection Bureau', 'Federal Housing Finance Agency', 'Commodity Futures Trading Commission', 'Securities and Exchange Commission']


In [8]:
# 4. average word length in abstract

all_abstract = df['abstract'].str.cat(sep=' ')

#function to count length of each word
def letterCounter(text):
    words = text.split()
    return [len(word) for word in words]

word_length_counter = letterCounter(all_abstract)
print(f'Average word length in abstract field across whole dataset : {np.mean(word_length_counter)}')

Average word length in abstract field across whole dataset : 5.716986136370888


In [9]:
# 5. average word length per document

abstract_text_df = df[['document_number', 'abstract']].copy()
abstract_text_df['word_length_list'] = abstract_text_df['abstract'].apply(letterCounter)
abstract_text_df['avg_word_length'] = abstract_text_df['word_length_list'].apply(np.mean)
abstract_text_df.sort_values(by='avg_word_length', ascending=False)

Unnamed: 0,document_number,abstract,word_length_list,avg_word_length
3858,2024-24537,The Commission will consider a restricted adju...,"[3, 10, 4, 8, 1, 10, 12, 6, 4, 3, 5, 7, 72]",11.153846
3441,2024-24947,This notice announces the availability of the ...,"[4, 6, 9, 3, 12, 2, 3, 5, 4, 5, 3, 12, 6, 4, 7...",7.360000
7314,2024-20970,The National Toxicology Program (NTP) Interage...,"[3, 8, 10, 7, 5, 11, 6, 3, 3, 10, 2, 11, 13, 7...",7.333333
2521,2024-26051,The Natural Resources Conservation Service (NR...,"[3, 7, 9, 12, 7, 6, 8, 6, 5, 3, 3, 11, 2, 3, 1...",7.329412
5264,2024-23138,The Environmental Protection Agency (EPA) is p...,"[3, 13, 10, 6, 5, 2, 9, 6, 2, 6, 7, 10, 5, 2, ...",7.232558
...,...,...,...,...
7068,2024-21270,The Board updates for 2024 the fees that the p...,"[3, 5, 7, 3, 4, 3, 4, 4, 3, 6, 4, 3, 2, 4, 7, ...",4.452381
893,2024-27567,Notice is given of the names of members of the...,"[6, 2, 5, 2, 3, 5, 2, 7, 2, 3, 4, 11, 6, 5, 3,...",4.428571
9573,2024-18727,This notice publishes the 2024 List of Explosi...,"[4, 6, 9, 3, 4, 4, 2, 9, 10, 2, 8, 2, 4, 3, 4,...",4.423077
1011,2024-27459,Notice is given of the names of members of a P...,"[6, 2, 5, 2, 3, 5, 2, 7, 2, 1, 11, 6, 5, 3, 3,...",4.368421


In [10]:
#6. Agencies collaboration...which agencies frequently co-publish documents

collaborative_docs = df[df['agency_names'].apply(len) > 1].copy()
collaborative_docs['agency_pairs'] = collaborative_docs['agency_names'].apply(
    lambda agencies: list(combinations(agencies, 2))
)

all_agency_pairs = [agency for pair in collaborative_docs['agency_pairs'] for agency in pair]
collaboration_counts = Counter(all_agency_pairs)
top_collaborations = collaboration_counts.most_common(10)
print(f'Number of documents with multiple agencies: {doc_agency_count_df[doc_agency_count_df["num_agencies"]>1].shape[0]}')
print()
print("Top 10 agency collaborations:")
for pair, count in top_collaborations:
    print(f"{pair}: {count} documents")

Number of documents with multiple agencies: 3890

Top 10 agency collaborations:
('Transportation Department', 'Federal Aviation Administration'): 395 documents
('Commerce Department', 'National Oceanic and Atmospheric Administration'): 371 documents
('Commerce Department', 'International Trade Administration'): 326 documents
('Interior Department', 'National Park Service'): 269 documents
('Health and Human Services Department', 'Food and Drug Administration'): 222 documents
('Homeland Security Department', 'Federal Emergency Management Agency'): 207 documents
('Homeland Security Department', 'Coast Guard'): 172 documents
('Treasury Department', 'Internal Revenue Service'): 125 documents
('Health and Human Services Department', 'Centers for Medicare & Medicaid Services'): 93 documents
('Interior Department', 'Land Management Bureau'): 93 documents


In [11]:
data_frame = df.copy()

In [12]:
#prepared abstract by lemmatizing it and removing the stop words.  Lower case and use the same regexp tokenizer with `r'(\b[\w]{2,}\b)'`.
lemmatizer = WordNetLemmatizer()
stop_words_list = set(stopwords.words('english'))

#function for text preparation
def clean_text(text):
    text = text.lower()
    tokens = re.findall(r'(\b[\w]{2,}\b)', text)
    cleaned_text = [lemmatizer.lemmatize(word) for word in tokens if word not in stop_words_list]
    return ' '.join(cleaned_text)

data_frame.loc[:, 'abstract'] = data_frame['abstract'].apply(clean_text)

#Multi-Label Classification

In [55]:
#splitting data

test_split = 0.2
train_df, test_df = train_test_split(
    data_frame,
    test_size=test_split,
)

val_df = test_df.sample(frac=0.5)
test_df.drop(val_df.index, inplace=True)

print(f"Number of rows in training set: {len(train_df)}")
print(f"Number of rows in validation set: {len(val_df)}")
print(f"Number of rows in test set: {len(test_df)}")

Number of rows in training set: 5726
Number of rows in validation set: 716
Number of rows in test set: 716


In [56]:
#Multi-label binarization.....Now we preprocess our labels using the StringLookup layer.

terms = tf.ragged.constant(train_df["agency_names"].values)
lookup = tf.keras.layers.StringLookup(output_mode="multi_hot")
lookup.adapt(terms)
vocab = lookup.get_vocabulary()

def invert_multi_hot(encoded_labels):
    hot_indices = np.argwhere(encoded_labels == 1.0)[..., 0]
    return np.take(vocab, hot_indices)

print("Vocabulary:\n")
print(vocab)

Vocabulary:

['[UNK]', 'Commerce Department', 'Postal Service', 'Transportation Department', 'Health and Human Services Department', 'Interior Department', 'Homeland Security Department', 'Federal Aviation Administration', 'Environmental Protection Agency', 'National Oceanic and Atmospheric Administration', 'International Trade Administration', 'Treasury Department', 'National Park Service', 'Food and Drug Administration', 'Small Business Administration', 'Agriculture Department', 'Labor Department', 'Federal Emergency Management Agency', 'Coast Guard', 'Justice Department', 'Energy Department', 'Federal Communications Commission', 'Internal Revenue Service', 'International Trade Commission', 'Veterans Affairs Department', 'Education Department', 'Land Management Bureau', 'Postal Regulatory Commission', 'Centers for Medicare & Medicaid Services', 'Defense Department', 'State Department', 'Nuclear Regulatory Commission', 'Foreign Assets Control Office', 'Centers for Disease Control and 

In [57]:
#Here we are separating the individual unique classes available from the label pool and then using this information to represent a given label set with 0's and 1's.

sample_label = train_df["agency_names"].iloc[0]
print(f"Original label: {sample_label}")

label_binarized = lookup([sample_label])
print(f"Label-binarized representation: {label_binarized}")

Original label: ['Environmental Protection Agency']
Label-binarized representation: [[0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]


In [58]:
batch_size = 32

def make_dataset(dataframe, is_train=True):
    agency_labels = tf.ragged.constant(dataframe["agency_names"].values)
    type_labels = pd.get_dummies(dataframe["type"]).values  # Convert type column to one-hot encoding
    label_binarized = lookup(agency_labels).numpy()

    # Ensure the text data is in a tensor of shape (batch_size,)
    dataset = tf.data.Dataset.from_tensor_slices(
        (dataframe["abstract"].values, {"agency_names": label_binarized, "type": type_labels})
    )

    dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
    return dataset.batch(batch_size)

In [59]:
# prepare datasets
train_dataset = make_dataset(train_df, is_train=True)
validation_dataset = make_dataset(val_df, is_train=False)
test_dataset = make_dataset(test_df, is_train=False)

In [60]:
for text_batch, label_batch in train_dataset.take(1):
       print(text_batch.shape)

(32,)


In [61]:
# Extract a batch from the train dataset
text_batch, label_batch = next(iter(train_dataset))

# Loop through the first 10 examples in the batch
for i, text in enumerate(text_batch[:10]):
    agency_label = label_batch['agency_names'][i].numpy()[None, ...]  # Get the agency label
    type_label = label_batch['type'][i].numpy()  # Get the type label

    # Invert multi-hot encoding for agency names
    agency_names = invert_multi_hot(agency_label[0])

    # Extract the index of the active class from one-hot encoded `type_label`
    type_label_scalar = type_label.argmax()

    # Determine the type (notice, rule, or proposed rule)
    if type_label_scalar == 0:
        type_pred = "rule"
    elif type_label_scalar == 1:
        type_pred = "notice"
    else:
        type_pred = "proposed rule"

    print(f"Abstract: {text}")
    print(f"Agency Labels: {agency_names}")
    print(f"Type: {type_pred}")
    print(" ")

Abstract: b'purpose federal register notice provide public notice secretary department veteran affair va intends enter enhanced use lease eul building 10 approximately acre underutilized land prescott campus northern arizona va health care system vahcs'
Agency Labels: ['Veterans Affairs Department']
Type: rule
 
Abstract: b'accordance paperwork reduction act 1995 pra onrr proposing renew information collection information collection request icr onrr seek renewed authority collect information related paperwork requirement chief financial officer act 1990 cfo act covering collection royalty mineral revenue due obligation accounted account receivables'
Agency Labels: ['Interior Department' 'Natural Resources Revenue Office']
Type: rule
 
Abstract: b'environmental protection agency epa correcting error found vessel incidental discharge national standard performance final rule final rule appeared federal register october 2024 correction remove footnote superscript number included error acco

In [62]:
# Prepare the vocabulary
vocabulary = set()
train_df["abstract"].str.lower().str.split().apply(vocabulary.update)
vocabulary_size = len(vocabulary)
print(f"Vocabulary size: {vocabulary_size}")

Vocabulary size: 12915


In [63]:
# Text Vectorization
text_vectorizer = layers.TextVectorization(
    max_tokens=vocabulary_size, ngrams=2, output_mode="tf_idf"
)

# Adapt the TextVectorization layer based on the training dataset
with tf.device("/CPU:0"):
    text_vectorizer.adapt(train_df["abstract"].values)

In [64]:
# Apply vectorization to datasets
train_dataset = train_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=tf.data.experimental.AUTOTUNE
).prefetch(tf.data.experimental.AUTOTUNE)

validation_dataset = validation_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=tf.data.experimental.AUTOTUNE
).prefetch(tf.data.experimental.AUTOTUNE)

test_dataset = test_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=tf.data.experimental.AUTOTUNE
).prefetch(tf.data.experimental.AUTOTUNE)


#Model Preparation & Training

In [65]:
#layers preparation

# input layer
text_input = Input(shape=(vocabulary_size,), dtype="float32") # Adjust shape value according to Vocabulary size

#output layers
agency_names_output = layers.Dense(512, activation='relu')(text_input)
agency_names_output = layers.Dense(256, activation='relu')(agency_names_output)
agency_names_output = layers.Dense(128, activation='relu')(agency_names_output)
agency_names_output = layers.Dense(lookup.vocabulary_size(), activation='sigmoid', name="agency_names")(agency_names_output) #multi-label classification

type_output = layers.Dense(512, activation='relu')(text_input)
type_output = layers.Dense(256, activation='relu')(type_output)
type_output = layers.Dense(128, activation='relu')(type_output)
type_output = layers.Dense(3, activation='softmax', name="type")(type_output)  # Multi-class classification

In [66]:
model = models.Model(inputs=text_input, outputs={"agency_names": agency_names_output, "type": type_output})
model.compile(
    loss={"agency_names": "binary_crossentropy", "type": "categorical_crossentropy"},
    optimizer="adam",
    metrics={"agency_names": "binary_accuracy", "type": "categorical_accuracy"}
)

In [67]:
model.summary()

In [68]:
early_stopping = EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True)

#model training
epochs = 10
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=epochs,
    callbacks=[early_stopping]
)

Epoch 1/10
[1m179/179[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 210ms/step - agency_names_binary_accuracy: 0.9407 - agency_names_loss: 0.1503 - loss: 0.6473 - type_categorical_accuracy: 0.8474 - type_loss: 0.4970 - val_agency_names_binary_accuracy: 0.9984 - val_agency_names_loss: 0.0083 - val_loss: 0.1052 - val_type_categorical_accuracy: 0.9679 - val_type_loss: 0.0962
Epoch 2/10
[1m179/179[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 183ms/step - agency_names_binary_accuracy: 0.9989 - agency_names_loss: 0.0052 - loss: 0.0466 - type_categorical_accuracy: 0.9894 - type_loss: 0.0415 - val_agency_names_binary_accuracy: 0.9990 - val_agency_names_loss: 0.0055 - val_loss: 0.1093 - val_type_categorical_accuracy: 0.9665 - val_type_loss: 0.1026
Epoch 3/10
[1m179/179[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 196ms/step - agency_names_binary_accuracy: 0.9996 - agency_names_loss: 0.0015 - loss: 0.0301 - type_categorical_accuracy: 0.9913 - type_loss: 0.0286 - v

In [69]:
final_epoch_index = len(history.history['loss']) - 1

print("----------- Training and validation Metrics ---------------")
print(f" Agency prediction accuracy on the train set: {history.history['agency_names_binary_accuracy'][final_epoch_index]}")
print(f" Agency prediction accuracy on the validation set: {history.history['val_agency_names_binary_accuracy'][final_epoch_index]}")
print(f" Type prediction accuracy on the train set: {history.history['type_categorical_accuracy'][final_epoch_index]}")
print(f" Type prediction accuracy on the validation set: {history.history['val_type_categorical_accuracy'][final_epoch_index]}")

----------- Training and validation Metrics ---------------
 Agency prediction accuracy on the train set: 0.9998420476913452
 Agency prediction accuracy on the validation set: 0.9990457892417908
 Type prediction accuracy on the train set: 0.9970310926437378
 Type prediction accuracy on the validation set: 0.9553072452545166


#Model Evaluation

In [70]:
loss, agency_loss, type_loss, agency_acc, type_acc = model.evaluate(test_dataset)
print("----- Test Metrics ------")
print(f"Agency prediction accuracy on the test set: {round(agency_acc * 100, 2)}%")
print(f"Type prediction accuracy on the test set: {round(type_acc * 100, 2)}%")

[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 134ms/step - agency_names_binary_accuracy: 0.9982 - agency_names_loss: 0.0091 - loss: 0.1315 - type_categorical_accuracy: 0.9649 - type_loss: 0.1226
----- Test Metrics ------
Agency prediction accuracy on the test set: 99.83%
Type prediction accuracy on the test set: 96.51%


In [71]:
predictions = model.predict(test_dataset)

#Extract agency predictions and true labels
agency_predictions = predictions['agency_names']
agency_true_labels = np.concatenate([labels['agency_names'] for _, labels in test_dataset], axis=0)
agency_pred_binary = (agency_predictions > 0.5).astype(int) # Convert predictions to binary (0 or 1)

#metrics calculation
precision_agency = precision_score(agency_true_labels, agency_pred_binary, average='micro')
recall_agency = recall_score(agency_true_labels, agency_pred_binary, average='micro')
f1_agency = f1_score(agency_true_labels, agency_pred_binary, average='micro')

print(f"Agency prediction Precision: {precision_agency:.4f}")
print(f"Agency prediction Recall: {recall_agency:.4f}")
print(f"Agency prediction F1 Score: {f1_agency:.4f}")



[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 144ms/step
Agency prediction Precision: 0.9825
Agency prediction Recall: 0.8007
Agency prediction F1 Score: 0.8823


In [72]:
#Extract type predictions and true labels
type_predictions = predictions['type']
type_true_labels = np.concatenate([labels['type'] for _, labels in test_dataset], axis=0)
type_pred_labels = np.argmax(type_predictions, axis=1)
type_true_labels = np.argmax(type_true_labels, axis = 1)

#metrics calculation
precision_type = precision_score(type_true_labels, type_pred_labels, average='macro')
recall_type = recall_score(type_true_labels, type_pred_labels, average='macro')
f1_type = f1_score(type_true_labels, type_pred_labels, average='macro')

print(f"Type prediction Precision: {precision_type:.4f}")
print(f"Type prediction Recall: {recall_type:.4f}")
print(f"Type prediction F1 Score: {f1_type:.4f}")

Type prediction Precision: 0.9453
Type prediction Recall: 0.8901
Type prediction F1 Score: 0.9159


#Predictions on test set

In [73]:
# Combine TextVectorization and trained model into a single inference model
model_for_inference = models.Sequential([text_vectorizer, model])

inference_dataset = make_dataset(test_df.sample(20), is_train=False) #taking a sample of 20 abstracts
text_batch, label_batch = next(iter(inference_dataset))

# Perform inference on the text batch
predicted_probabilities = model_for_inference.predict(text_batch)

for i, text in enumerate(text_batch[:len(predicted_probabilities["agency_names"])]):
    type_label = label_batch['type'][i].numpy()

    # Invert multi-hot encoding for agency names
    true_agency_names = label_batch['agency_names'][i].numpy()
    true_agency_names = [lookup.get_vocabulary()[j] for j in np.where(true_agency_names == 1)[0]]
    true_type = ["rule", "proposed rule", "notice"][np.where(type_label)[0][0]]

    # Predicted probabilities for agency names and type
    agency_proba = predicted_probabilities["agency_names"][i]
    type_proba = predicted_probabilities["type"][i]  # Vector of probabilities for three classes
    predicted_type = ["rule", "proposed rule", "notice"][np.argmax(type_proba)]  # Decode the class

    top_3_agencies = [
        name
        for _, name in sorted(
            zip(agency_proba, lookup.get_vocabulary()),
            key=lambda pair: pair[0],
            reverse=True,
        )
    ][:3]

    print(f"Abstract: {text.numpy().decode('utf-8')}")
    print(f"True Agency Labels: {', '.join(true_agency_names)}")
    print(f"Predicted Agencies: {', '.join(top_3_agencies)}")
    print(f"True Type: {true_type}")
    print(f"Predicted Type: {predicted_type}")
    print(" ")

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 216ms/step
Abstract: faa proposes adopt new airworthiness directive ad certain airbus sa model a319 111 112 113 114 115 131 132 133 airplane a320 211 212 214 216 231 232 233 airplane a321 111 112 131 211 212 213 231 232 airplane proposed ad prompted full scale fatigue test found crack main landing gear mlg bay rear skin panel stringer run frame 46 stringer 32 left hand right hand side proposed ad would require repetitive special detailed inspection sdis affected area cracking applicable corrective action specified european union aviation safety agency easa ad proposed incorporation reference ibr faa proposing ad address unsafe condition product
True Agency Labels: Transportation Department, Federal Aviation Administration
Predicted Agencies: Transportation Department, Federal Aviation Administration, Federal Motor Carrier Safety Administration
True Type: proposed rule
Predicted Type: proposed rule
 
Abstract: environmental p

#Predictions on unseen data

In [74]:
def infer_model(model, text_vectorizer, raw_text):
    preprocessed_text = text_vectorizer(tf.constant([raw_text]))
    predictions = model(preprocessed_text)
    agency_preds = predictions["agency_names"].numpy()[0]
    type_pred = predictions["type"].numpy()[0]
    predicted_agencies = invert_multi_hot(agency_preds > 0.5)
    predicted_type_index = np.argmax(type_pred)
    predicted_type = ["rule", "proposed rule", "notice"][predicted_type_index]
    return predicted_agencies, predicted_type

#example
#test_text = "The Food and Drug Administration (FDA or the Agency) has det"
#test_text = "The Environmental Protection Agency announces a public hearing."
test_text = "This action establishes new safety standards for motor vehicles."
predicted_agencies, predicted_type = infer_model(model, text_vectorizer, test_text)
print(f"text: {test_text}")
print(f"Predicted Agencies: {', '.join(predicted_agencies)}")
print(f"Predicted Type: {predicted_type}")

text: This action establishes new safety standards for motor vehicles.
Predicted Agencies: Transportation Department
Predicted Type: notice
