## Label Dataset 

In [2]:
%%capture
pip install -r ../../requirements.txt

In [3]:
import sys

# Put the shared scripts directory on the import path so that modules living there
# can be imported from this notebook (presumably `data_collection`, used below).
# The membership check keeps the cell idempotent: re-running it does not append
# the same entry to sys.path again.
SCRIPTS_DIR = "../../scripts/"
if SCRIPTS_DIR not in sys.path:
    sys.path.append(SCRIPTS_DIR)

In [4]:
from data_collection import authenticate_google_drive, grab_google_drive_folder_data

# Build a Google Drive handle from the client-secret file, then ask the helper to
# fetch "reddit_filtered_data.csv" from the folder described by the folder-id file.
# NOTE(review): both helpers come from the project-local `data_collection` module;
# the result is presumably a pandas DataFrame (the cell output says so) -- confirm.
drive = authenticate_google_drive(
    '../0_data_collection/credentials/google_drive_client_secret.json'
)
df = grab_google_drive_folder_data(
    drive=drive,
    credential_file="../0_data_collection/credentials/google_drive_folder_id.json",
    filename="reddit_filtered_data.csv",
)

Successfully loaded 'reddit_filtered_data.csv' into a DataFrame!


In [5]:
import html
from pathlib import Path

import ipywidgets as widgets
import pandas as pd
from IPython.display import display, clear_output

class show_visual():
    """Widget-based helper for manually labelling the rows of a DataFrame in batches.

    Each batch shows, per row, the row's ``search_query`` and ``combine_text`` plus a
    dropdown (described by the row's ``url``) from which a label is picked.
    "Next Batch" records the current selections and moves on; "Save & Exit" records
    them and writes every labelled row to ``self.file_name`` as CSV.

    Labels are stored in ``self.labels`` keyed by the DataFrame index label, so the
    index of ``df`` should be unique; duplicate index labels would overwrite each other.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to label. ``show_batch`` reads the ``url``, ``search_query`` and
        ``combine_text`` columns; ``save_data`` additionally needs every column in
        ``self.selected_columns`` other than ``label`` and ``reviewer``, which it adds.
    file_number : int, default 1
        Suffix used in the output file name.
    reviewer : str, default "aserban"
        Written to the ``reviewer`` column and used in the output file name.
    batch_size : int, default 5
        Number of rows displayed per batch. Must be at least 1.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """

    def __init__(self, df, file_number=1, reviewer="aserban", batch_size=5):
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        self.df = df
        self.labels = {}  # index label -> chosen label
        self.batch_size = batch_size
        self.current_index = 0  # positional offset of the next batch in self.df
        # The first option is what a dropdown shows (and records) when left untouched.
        self.label_options = ['Unknown', 'Positive', 'Negative', 'Neutral']
        self.label_widgets = []  # (index label, dropdown) pairs of the displayed batch

        self.next_button = widgets.Button(description="Next Batch", button_style='success', layout=widgets.Layout(width='200px', height='50px'))
        self.save_button = widgets.Button(description="Save & Exit", button_style='danger', layout=widgets.Layout(width='200px', height='50px'))
        self.reviewer = reviewer
        self.file_name = f"labeled_data/{reviewer}_labeled_data_{file_number}.csv"

        # Columns (and their order) written to the CSV by save_data.
        self.selected_columns = ['submission_id', 'subredit_topic', 'search_query', 'combine_text', 'url', 'label', 'reviewer']

        # Guards against registering the click handlers more than once.
        self._handlers_bound = False

    def show_batch(self):
        """Render the batch starting at ``self.current_index``.

        When every row has already been shown, a completion message is printed and
        the "Save & Exit" button is displayed on its own, so the collected labels can
        still be written to disk after the final batch.
        """
        clear_output(wait=True)
        if self.current_index >= len(self.df):
            # No dropdowns are on screen any more; drop the stale references.
            self.label_widgets.clear()
            print("✅ All samples labeled!")
            # Keep the save button reachable -- otherwise the labels gathered so far
            # could only be persisted by calling save_data() by hand.
            display(self.save_button)
            return

        batch_end = min(self.current_index + self.batch_size, len(self.df))
        current_batch = self.df.iloc[self.current_index:batch_end]

        self.label_widgets.clear()
        for i, row in current_batch.iterrows():
            dropdown = widgets.Dropdown(
                options=self.label_options,
                description=f"{row['url']}",
                style={'description_width': 'initial'},
                layout=widgets.Layout(width='500px')
            )
            self.label_widgets.append((i, dropdown))
            print(f"Post Number Index: {i}")

            # The query is data, not markup: escape it before embedding it in HTML.
            brand = html.escape(str(row['search_query']))
            display(widgets.VBox([
                widgets.HTML(f"Brand:<b><font color='red'>{brand}</font></b>"),
                # NOTE(review): `word_wrap` does not look like a documented Layout
                # trait and may be ignored -- confirm whether long text wraps.
                widgets.Label(f"{row['combine_text']}", layout=widgets.Layout(width='1000px', word_wrap='break-word')),
                dropdown
            ]))

        display(widgets.HBox([self.next_button, self.save_button]))

    def save_labels(self, continue_labeling=True):
        """Record the selections of the displayed batch, then continue or save.

        Parameters
        ----------
        continue_labeling : bool, default True
            If true, show the next batch; otherwise write the labelled rows to disk
            via ``save_data``.
        """
        for index, dropdown in self.label_widgets:
            self.labels[index] = dropdown.value

        self.current_index += self.batch_size
        if continue_labeling:
            self.show_batch()
        else:
            self.save_data()

    def save_data(self):
        """Write all rows that have a recorded label to ``self.file_name`` as CSV.

        A ``label`` column (looked up from ``self.labels`` by index label) and a
        ``reviewer`` column are added to a copy of ``self.df``; rows without a label
        are dropped and only ``self.selected_columns`` are written. The output
        directory is created if it does not exist yet.
        """
        labeled_df = self.df.copy()
        labeled_df['label'] = labeled_df.index.map(self.labels)
        labeled_df['reviewer'] = self.reviewer

        # Keep only rows that were actually labelled, in the configured column order.
        has_label = labeled_df['label'].notna()
        selected_df = labeled_df[has_label].reset_index(drop=True)
        selected_df = selected_df[self.selected_columns]

        # to_csv does not create missing directories on its own.
        Path(self.file_name).parent.mkdir(parents=True, exist_ok=True)
        selected_df.to_csv(self.file_name, index=False)
        print(f"📂 Data saved as '{self.file_name}'")

    def _on_next_clicked(self, _button):
        """Click handler for "Next Batch": record selections and show the next batch."""
        self.save_labels(True)

    def _on_save_clicked(self, _button):
        """Click handler for "Save & Exit": record selections and write the CSV."""
        self.save_labels(False)

    def start_manually_labelling(self):
        """Hook up the buttons (once per instance) and display the current batch."""
        # Registering the handlers again on a second call would make every click
        # fire twice, skipping a batch and recording its default labels.
        if not self._handlers_bound:
            self.next_button.on_click(self._on_next_clicked)
            self.save_button.on_click(self._on_save_clicked)
            self._handlers_bound = True

        # Start labeling
        self.show_batch()

In [6]:
# sv = show_visual(df=df, file_number=1, reviewer="aserban")
# sv.start_manually_labelling()

In [10]:
# sv = show_visual(df=df.sample(20, random_state=42), file_number=2, reviewer="aserban")
# sv.start_manually_labelling()

In [11]:
# file_path_check = "labeled_data/aserban_labeled_data_2.csv"
# check_df = pd.read_csv(file_path_check)

In [13]:
# check_df.shape