In [52]:
from zenml import step, pipeline
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import numpy as np
from scipy.special import softmax
import urllib.request
import csv
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
import mlflow
import pandas as pd
import great_expectations as ge

label_mapping = {'negative': 0, 'neutral': 1, 'positive': 2}


@step
def read_tweets_from_file(file_path: str) -> list:
    """Load tweets from a UTF-8 text file, one tweet per line.

    Args:
        file_path: Path to the tweets file.

    Returns:
        One string per line of the file, in file order, with surrounding
        whitespace (including the trailing newline) stripped.
    """
    with open(file_path, 'r', encoding='utf-8') as fh:
        stripped = [line.strip() for line in fh]
    return stripped

@step
def read_labels_from_file(labels_path: str) -> list:
    """Load integer labels, one per line, from a UTF-8 text file.

    Args:
        labels_path: Path to the labels file.

    Returns:
        List of ints in file order. A blank or non-numeric line raises
        ``ValueError`` (from ``int``).
    """
    with open(labels_path, 'r', encoding='utf-8') as fh:
        return list(map(int, map(str.strip, fh)))

@step
def preprocess_step(texts: list) -> list:
    """Mask user mentions and links in each text.

    Each text is split on single spaces. Any token that starts with ``@`` and
    is longer than one character becomes ``'@user'``. Any token starting with
    ``http`` becomes ``'http'``. Tokens are re-joined with single spaces, so
    the original spacing is preserved.

    Args:
        texts: Raw tweet strings.

    Returns:
        Masked strings, one per input, in input order.
    """
    def _mask(token: str) -> str:
        if token.startswith('@') and len(token) > 1:
            return '@user'
        if token.startswith('http'):
            return 'http'
        return token

    return [" ".join(_mask(tok) for tok in text.split(" ")) for text in texts]


In [53]:
from cassandra.cluster import Cluster
import uuid

@step
def insert_preprocessed_tweets_into_cassandra(processed_texts: list):
    """Insert each preprocessed tweet into Cassandra under a fresh random UUID.

    Connects to keyspace ``mykeyspace`` on ``localhost`` and writes one row per
    tweet into table ``preprocessed_tweets`` (columns ``id``, ``tweet_text``).
    The cluster connection is always shut down, even if an insert fails.

    Args:
        processed_texts: Tweet strings to store, e.g. the output of
            ``preprocess_step``.
    """
    CASSANDRA_CLUSTER = ['localhost']
    KEYSPACE = 'mykeyspace'
    TABLE_NAME = 'preprocessed_tweets'

    cluster = Cluster(CASSANDRA_CLUSTER)
    try:
        session = cluster.connect(KEYSPACE)
        # Prepare once and bind per row, instead of re-sending and re-parsing
        # the CQL text for every tweet.
        insert_stmt = session.prepare(
            f"INSERT INTO {TABLE_NAME} (id, tweet_text) VALUES (?, ?)"
        )
        for tweet_text in processed_texts:
            session.execute(insert_stmt, (uuid.uuid4(), tweet_text))
    finally:
        # Release driver threads/connections even on failure; the original never shut down.
        cluster.shutdown()

    print("All preprocessed tweets have been inserted into Cassandra.")
    


In [54]:
from zenml.materializers.base_materializer import BaseMaterializer
from transformers import AutoModelForSequenceClassification

@step
def model_inference_step(texts: list) -> list:
    """Predict a sentiment label for every text with cardiffnlp/twitter-roberta-base-sentiment.

    Label names are read from the TweetEval ``mapping.txt`` for the task. The
    second tab-separated column of each row is taken as the label name. Each
    prediction is the label at the index of the model's highest-scoring class.

    Args:
        texts: Preprocessed tweet strings.

    Returns:
        One label string per input text, in input order.
    """
    import torch  # local import: only this step needs torch directly

    task = 'sentiment'
    MODEL = f"cardiffnlp/twitter-roberta-base-{task}"
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL)

    mapping_link = f"https://raw.githubusercontent.com/cardiffnlp/tweeteval/main/datasets/{task}/mapping.txt"
    with urllib.request.urlopen(mapping_link) as f:
        rows = csv.reader(f.read().decode('utf-8').split("\n"), delimiter='\t')
        labels = [row[1] for row in rows if len(row) > 1]

    predictions = []
    # Pure inference: skip building the autograd graph (saves memory and time).
    with torch.no_grad():
        for text in texts:
            # Truncate so an unusually long input cannot exceed the model's max sequence length.
            encoded_input = tokenizer(text, return_tensors='pt', truncation=True)
            logits = model(**encoded_input).logits[0].numpy()
            # softmax is monotonic, so the argmax of the raw logits is already the top class.
            predictions.append(labels[int(np.argmax(logits))])
    return predictions


In [55]:
@step
def evaluate_predictions(predictions: list, true_labels: list) -> dict:
    """Score predicted label names against integer ground-truth labels.

    Predicted names are converted to ints through the module-level
    ``label_mapping``. An unknown name raises ``KeyError``. Precision, recall
    and F1 use weighted averaging, with undefined values reported as 0. All
    four metrics are logged to mlflow.

    Args:
        predictions: Label names such as ``'negative'``/``'neutral'``/``'positive'``.
        true_labels: Integer labels aligned positionally with ``predictions``.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    y_pred = [label_mapping[name] for name in predictions]
    weighted = dict(average='weighted', zero_division=0)

    accuracy = accuracy_score(true_labels, y_pred)
    precision = precision_score(true_labels, y_pred, **weighted)
    recall = recall_score(true_labels, y_pred, **weighted)
    f1 = f1_score(true_labels, y_pred, **weighted)

    # Note: F1 is logged to mlflow as "f1_score" but returned under "f1".
    logged = (
        ("accuracy", accuracy),
        ("precision", precision),
        ("recall", recall),
        ("f1_score", f1),
    )
    for metric_name, metric_value in logged:
        mlflow.log_metric(metric_name, metric_value)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

In [56]:
import matplotlib.pyplot as plt

@step
def visualize_metrics(metrics: dict, figure_path: str = 'metrics_figure.png') -> str:
    """Save a bar chart of the evaluation metrics and return the image path.

    Args:
        metrics: Mapping of metric name to score, e.g. the dict returned by
            ``evaluate_predictions``.
        figure_path: Where to write the PNG. Relative paths resolve against the
            current working directory. Defaults to ``'metrics_figure.png'``.

    Returns:
        ``figure_path``, the location of the saved figure.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.bar(list(metrics.keys()), list(metrics.values()))
        ax.set_ylabel('Score')
        ax.set_title('Model Evaluation Metrics')
        fig.savefig(figure_path)
    finally:
        # Always release the figure, even if saving fails, so repeated runs don't accumulate figures.
        plt.close(fig)
    return figure_path

In [None]:
@pipeline
def sentiment_analysis_pipeline_with_evaluation(file_path: str, labels_path: str):
    """Read, preprocess, persist, classify and evaluate a set of tweets.

    Args:
        file_path: Text file with one tweet per line.
        labels_path: Text file with one integer label per line, aligned
            positionally with the tweets in ``file_path``.
    """
    raw_tweets = read_tweets_from_file(file_path)
    gold_labels = read_labels_from_file(labels_path)
    clean_tweets = preprocess_step(raw_tweets)
    # Side branch: persist the cleaned text; no downstream step consumes its output.
    insert_preprocessed_tweets_into_cassandra(clean_tweets)
    predicted = model_inference_step(clean_tweets)
    metric_scores = evaluate_predictions(predicted, gold_labels)
    visualize_metrics(metric_scores)
    

    

if __name__ == "__main__":
    # Input files, relative to the notebook's working directory
    # (presumably the validation split, judging by the names — confirm).
    file_path, labels_path = '../val_text.txt', '../val_labels.txt'
    sentiment_analysis_pipeline_with_evaluation(file_path, labels_path)
    

In [59]:

CASSANDRA_CLUSTER = ['localhost']
KEYSPACE = 'mykeyspace'
TABLE_NAME = 'preprocessed_tweets'


def fetch_and_print_preprocessed_tweets(session=None):
    """Print every stored (id, tweet_text) row from the preprocessed-tweets table.

    The original version read a global ``session`` that no cell defines at
    module level (it was local to the insert step). It therefore failed on a
    fresh kernel. This version opens its own connection when none is given.

    Args:
        session: Optional connected Cassandra session to reuse. If omitted, a
            connection to ``CASSANDRA_CLUSTER`` / ``KEYSPACE`` is opened here
            and shut down after printing.
    """
    cluster = None
    if session is None:
        cluster = Cluster(CASSANDRA_CLUSTER)
    try:
        if cluster is not None:
            session = cluster.connect(KEYSPACE)
        rows = session.execute(f"SELECT id, tweet_text FROM {TABLE_NAME}")
        for row in rows:
            print(f"ID: {row.id}, Tweet: {row.tweet_text}")
    finally:
        # Only shut down a cluster this function created; a caller-supplied session stays open.
        if cluster is not None:
            cluster.shutdown()


fetch_and_print_preprocessed_tweets()

ID: 6bcbafb0-e051-42e7-ba2a-7c6d3ce27431, Tweet: "Our Holiday Open House is in full swing!!! Please stop by if you can. Noon to 3 p.m. at 724 W. 2nd Avenue, Milan,...
ID: 7835b5bb-1d60-47b6-9f8e-9ac3a4f223fe, Tweet: holy shit holy shit holy shit just found out I'm gonna see Paul McCartney tomorrow!!!!!!!!!
ID: 3c508538-e3c4-4712-a999-49ef2c6f7add, Tweet: "Steve Jobs: Source: www.quotationspage.com --- Sunday, August 14, 2011\""You can't just ask customers what they wa...
ID: 500547fe-5824-4e85-907b-658998738635, Tweet: Any one want two Jason Aldean tickets for September 11th? Text or DM me
ID: 56edda71-1f96-49a2-98a0-3e7d3f0f9123, Tweet: Day 2 stuck in the Taipei airport. Possibly going to LA to wait out the storm. Didn't think I'd make it to Cali for the 1st time this way
ID: 93e659f8-2edc-4fe5-9306-13db1b10226a, Tweet: @user Was having trouble yesterday downloading Septembers edition using Google? Loads the 1st page then crashes?
ID: d3868151-8bd1-449a-945d-596b0ccc1573, Tweet: Sandy