In [1]:
# Imports 
import pandas as pd
import numpy as np
import pickle 
from sentence_transformers import SentenceTransformer
from typing import Callable, Union, List

import ipywidgets as widgets
from IPython.display import display


import warnings
warnings.filterwarnings('ignore')

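## Demo Setup

- Load the sentence-embedding model from Hugging Face
- Load the trained models: the multilabel classification model (logistic regression) and the KMeans clustering model
- Define the labels generated with the GPT API in Module 2
- Load sample data (here, a tweet entered interactively in the Demo section)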

In [7]:
# Sentence-embedding model from the Hugging Face hub; it encodes each tweet
# before it is passed to a trained model.
embedding_model = SentenceTransformer(
    model_name_or_path='sentence-transformers/all-mpnet-base-v2'
)

In [8]:
# Load in trained models (logistic regression classifier and KMeans clustering).
# Each file is opened in a `with` block so its handle is closed as soon as the
# model has been unpickled.
# NOTE: pickle.load can execute arbitrary code — only load model files you trust.
with open('../trained_models/lr_v2.pkl', 'rb') as model_file:
    lr = pickle.load(model_file)

with open('../trained_models/kmeans.pkl', 'rb') as model_file:
    kmeans = pickle.load(model_file)

In [9]:
# Grab the labels from GPT API from Module 2.
# Cluster id -> descriptive labels for that cluster (used with the KMeans model).
labels_dict = {
    0: ['social issues', 'personal development', 'business and economics', 'community building'],
    1: ['india', 'updates', 'testing', 'fatalities', 'recoveries', 'healthcare'],
    2: ['face masks', 'safety', 'protection', 'public health', 'prevention'],
    3: ['social media', 'resilience', 'community support', 'online events'],
    4: ['global', 'cases', 'deaths', 'statistics'],
    5: ['politics', 'government response', 'public health', 'conspiracy', 'human rights'],
    6: ['health', 'information', 'vaccine', 'public awareness'],
    7: ['layoffs', 'misinformation', 'mental health', 'lockdown', 'access', 'financial impact', 'political response', 'education']
}

# Grab the labels from GPT API from Module 2.
# Flat label vocabulary for the multilabel classifier: entry i names the
# classifier's i-th output, so the order of this list matters.
labels = ['social-issues',
 'personal-development',
 'business-and-economics',
 'community-building',
 'india',
 'updates',
 'testing',
 'fatalities',
 'recoveries',
 'healthcare',
 'face-masks',
 'safety',
 'protection',
 'public-health',
 'prevention',
 'social-media',
 'resilience',
 'community-support',
 'online-events',
 'global',
 'cases',
 'deaths',
 'statistics',
 'politics',
 'government-response',
 'conspiracy',
 'human-rights',
 'health',
 'information',
 'vaccine',
 'public-awareness',
 'layoffs',
 'misinformation',
 'mental-health',
 'lockdown',
 'access',
 'financial-impact',
 'political-response',
 'education']

# Sanity check: `labels` must be the de-duplicated, hyphenated flattening of
# `labels_dict` in the same order, so the two label sets cannot drift apart
# (e.g. a typo fixed in one but not the other).
assert labels == list(dict.fromkeys(
    label.replace(' ', '-')
    for cluster_labels in labels_dict.values()
    for label in cluster_labels
)), "labels and labels_dict are out of sync"

## Demo Functions

In [18]:
def make_prediction(tweet: str, 
                    labels: Union[dict, list],
                    embedding_model: "SentenceTransformer", 
                    classification_model = None,
                    clustering_model = None) -> Union[List[str], None]: 
    """
    Generate a list of predicted labels for an input tweet.

    Exactly one of ``clustering_model`` / ``classification_model`` must be given.

    Parameters
    ----------
    tweet : str
        Raw tweet text to label.
    labels : dict or list
        With ``clustering_model``: mapping from cluster id to that cluster's list
        of labels. With ``classification_model``: sequence whose i-th entry names
        the classifier's i-th output.
    embedding_model : SentenceTransformer
        Encoder whose ``encode`` turns the tweet into a feature vector.
    classification_model : optional
        Classifier whose ``predict`` returns one row per sample; label i is kept
        when entry i of that row equals 1.
    clustering_model : optional
        Clustering model whose ``predict`` returns one cluster id per sample.

    Returns
    -------
    list of str or None
        The predicted labels, or ``None`` when both or neither model is supplied.
    """
    # `is not None` rather than truthiness: a model object that defines
    # __len__/__bool__ must never be mistaken for "not supplied".
    use_clustering = clustering_model is not None
    use_classification = classification_model is not None

    # Both or neither model supplied -> nothing sensible to predict. Checked
    # before encoding so an invalid call does not pay for the embedding.
    if use_clustering == use_classification:
        return None

    # Generate the sentence embedding as a single-row 2-D float batch for `predict`
    embedding = embedding_model.encode(tweet)
    embedding = embedding.reshape(1, -1).astype(float)

    if use_clustering:
        # One cluster id for our single sample -> that cluster's label list
        cluster_id = clustering_model.predict(embedding)[0]
        return labels[cluster_id]

    # One prediction row for our single sample; keep label i where output i == 1
    prediction = classification_model.predict(embedding)[0]
    return [labels[i] for i in range(len(labels)) if prediction[i] == 1]

## Demo

In [19]:
# Demo

# Tweet Input: read once from stdin when this cell runs
# (re-run the cell to label a different tweet)
tweet = input("Enter a tweet: ")

# Select model type: KMeans clusters labelled via the GPT API, or the
# logistic-regression classifier
options = ['KMeans + GPT API', 'Logistic Regression']
dropdown = widgets.Dropdown(options = options,
                            value = options[0],
                            description = 'Select Model Type: ')

# Button that runs inference on the entered tweet with the selected model
button = widgets.Button(description = 'Run Inference', tooltip = 'Run Inference')

# Output area where the predicted labels are printed
output = widgets.Output()

def run_inference(labels, labels_dict, embedding_model, kmeans, lr): 
    """
    Click handler for the "Run Inference" button.

    Labels the global ``tweet`` with the model selected in ``dropdown`` and
    prints the predicted labels into the ``output`` widget.

    labels          : flat label list; entry i names output i of the classifier ``lr``
    labels_dict     : cluster id -> list of labels, used with ``kmeans``
    embedding_model : encoder passed through to ``make_prediction``
    kmeans, lr      : the trained clustering / classification models
    """
    with output: 

        # Replace the previous result instead of appending below it
        output.clear_output()

        # Each model kind needs its own label lookup: cluster id -> label list
        # for KMeans, output position -> label for the classifier
        if dropdown.value == options[0]: 
            model_args = dict(labels = labels_dict,
                              clustering_model = kmeans,
                              classification_model = None)
        else: 
            model_args = dict(labels = labels,
                              clustering_model = None,
                              classification_model = lr)

        # Kept under its own name so the `labels` argument is not overwritten
        predicted_labels = make_prediction(tweet = tweet,
                                           embedding_model = embedding_model,
                                           **model_args)

        # Print output
        print(predicted_labels)

def _on_run_clicked(_button):
    """Button callback: run inference with the notebook's current labels and models."""
    run_inference(labels, labels_dict, embedding_model, kmeans, lr)

button.on_click(_on_run_clicked)

# Render the model selector, the button and the result area, in that order
display(dropdown, button, output)

Enter a tweet: #COVID19 cases are on the rise again. Let's all do our part to stop the spread: get vaccinated, wear a mask, and social distance. Together, we can beat this virus.


Dropdown(description='Select Model Type: ', options=('KMeans + GPT API', 'Logistic Regression'), value='KMeans…

Button(description='Run Inference', style=ButtonStyle(), tooltip='Run Inference')

Output()

In [17]:
# Sample tweet: the same text entered at the prompt in the Demo cell; evaluating
# the bare string only echoes it as this cell's output.
# NOTE(review): this cell ran out of order (In [17] after In [19]) and has no
# effect on the demo — consider removing it or using it as a default input.
"#COVID19 cases are on the rise again. Let's all do our part to stop the spread: get vaccinated, wear a mask, and social distance. Together, we can beat this virus."

"#COVID19 cases are on the rise again. Let's all do our part to stop the spread: get vaccinated, wear a mask, and social distance. Together, we can beat this virus."