In [10]:
import warnings
warnings.filterwarnings('ignore')
import ipywidgets as widgets
import anywidget
import traitlets
import jupyter
import tweet_browser as tb
import voila
from matplotlib import pyplot as plt
from IPython.display import display, Javascript
import pandas as pd
import io
import math
import time
import json
import time
from custom_widgets import *

# TWEETS_PER_PAGE = 20
# DEBUG_MODE = True
# # JUPYTER_FILE_PATH = "../tree/images/"
# JUPYTER_FILE_PATH = "images/"

# Single output area that every browser screen is rendered into.
out = widgets.Output()

# Upload control placed in the browser's top bar; selecting a file replaces
# the dataset being browsed (see fileHandler below).
fileUp = widgets.FileUpload(
    description='Change Dataset',
    multiple=False,
    accept='.csv, .txt, .xls, .tsv',
)

# Shared front-end bridge widget from custom_widgets; Browser instances read
# and write its traits (size, fileName, calendar bounds, alert signals).
dummyEl = DummyElement()
def startSession(file):
    """Start a new browsing session from a file chosen in the upload widget.

    file: one entry of ``fileUp.value``. It is read both by key
          (``file['name']``) and by attribute (``file.content``, the raw bytes).

    The format is chosen from the file-name extension:
    .xls/.xlsx -> ``pd.read_excel``, .tsv -> tab-separated text,
    anything else -> comma-separated text.
    """
    name = file['name']
    extension = name.rsplit('.', 1)[-1].lower() if '.' in name else ''
    # file['type'] is the MIME type reported by the browser (e.g. "text/csv"),
    # never a bare extension such as 'xls', so the name is what identifies
    # the format.
    buffer = io.BytesIO(file.content)
    if extension in ('xls', 'xlsx'):
        data = pd.read_excel(buffer)
    elif extension == 'tsv':
        data = pd.read_csv(buffer, sep='\t')
    else:
        data = pd.read_csv(buffer)
    s = tb.Session(data, False)
    dummyEl.fileName = name
    # The Browser registers itself as an observer on dummyEl and renders into
    # `out`, so the instance does not need to be kept in a local.
    Browser(s, out)

def autoStartSession(fileName, embeddingsFile="allCensus_sample_embeddings.csv"):
    """Start a browsing session from a data file on disk.

    fileName:       path handed to ``tb.parse_data``; also recorded on dummyEl.
    embeddingsFile: CSV of precomputed embeddings passed to ``tb.Session``.
                    It is read as UTF-8 with the first column as the index.
                    The default keeps the previous hard-coded file, so
                    existing callers are unaffected.
    """
    data = tb.parse_data(fileName)
    embeddings = pd.read_csv(embeddingsFile, encoding="utf-8", index_col=0)
    s = tb.Session(data, False, embeddings=embeddings)
    dummyEl.fileName = fileName
    # The Browser stays reachable through its dummyEl observer.
    Browser(s, out)

def selectColumns(row, colHeaders: list, headerDict: dict = None):
    """Return the values of ``row`` for the given column headers, in order.

    row:        an indexable record (e.g. a list or tuple of field values).
    colHeaders: header names to pick.
    headerDict: maps each header name to its position in ``row``.
                When omitted, the code falls back to ``s.headerDict`` from a
                module-level ``s``, as before. No such global is defined in
                this file (sessions are locals of the start functions), so
                callers should pass ``headerDict`` explicitly.
    """
    if headerDict is None:
        headerDict = s.headerDict
    return [row[headerDict[header]] for header in colHeaders]
    
class Browser:
    """Interactive post browser built from ipywidgets on top of a tb.Session.

    Builds the main page (filter sidebar plus Random Posts / Typical Posts /
    AI Summary / Stance Analysis tabs) and the advanced search page. Whichever
    screen is active is re-rendered into the output widget given at
    construction.
    """

    def __init__(self, s, out):
        """Build the UI for session ``s`` and render it into ``out``.

        s:   the tb.Session whose data is searched, filtered and displayed.
        out: the widgets.Output area every screen is drawn into.
        """
        self.s = s
        self.out = out
        # True while the filter sidebar holds edits that have not been applied
        # yet. Initialised before any widget callback can run, because
        # loadTab reads it.
        self.hasChanges = False
        dummyEl.size = s.length
        dummyEl.observe(self.alertHandler, ["changeSignal"])
        self.screen = "main"
        self.colHeaders = list(s.headerDict.keys())
        self.createWidgets()
        self.search(None)
        self.resetDisplay()

    def resetDisplay(self, b=None):
        """Clear the output area and redraw the active screen.

        b: unused; accepted so this can be wired up as a button callback.
        """
        self.out.clear_output(True)
        with self.out:
            if self.screen == "main":
                display(self.mainPage)
                self.loadTab(None, self.tabs.selected_index)
                self.getTweets()
                if DEBUG_MODE:
                    # Exercise s.back() from a derived set, then restore the
                    # current set so the displayed state is not affected.
                    temp = self.s.currentSet
                    if self.s.currentSet != self.s.base:
                        self.s.back()
                    self.s.currentSet = temp
                    assert self.s.currentSet.size >= len(self.tweetDisplay.value)
            elif self.screen == "advanced":
                display(self.advancedPage)
            elif self.screen == "summary":
                display(self.aiSummary)
            display(dummyEl)

    def _searchTerms(self, bar):
        """Return the first ``bar.count`` terms of a SearchBar's JSON list.

        Returns [] without parsing when the bar has no terms, so an empty or
        unset ``bar.value`` is never handed to json.loads.
        """
        if bar.count <= 0:
            return []
        return json.loads(bar.value)[:bar.count]

    def search(self, b=None):
        """Rebuild the current set from the base data and refresh the views.

        Applies every active filter and search term, in this order: state,
        username, date range, reposts, semantic search, then keywords.

        b: unused button/change argument, forwarded to updateSearchParams.
        """
        self.s.currentSet = self.s.base
        self.displaySimilarityScore.hidden = 1
        for state in self._searchTerms(self.geography):
            self.s.filterBy("State", state.lower().capitalize())
        for user in self._searchTerms(self.userName):
            self.s.filterBy("SenderScreenName", user)
        if self.fromDate.value is not None and self.toDate.value is not None:
            self.s.filterDate(self.fromDate.value.strftime('%Y-%m-%d'), self.toDate.value.strftime('%Y-%m-%d'))
        if self.allowRetweets.value == 0:
            self.s.removeRetweets()
        if self.semanticSearch.value != "":
            # Only a semantic search produces the "Relevance" score column.
            self.displaySimilarityScore.hidden = 0
            self.s.semanticSearch(self.semanticSearch.value, float(self.semanticSearch.filterPercent / 100))
        if self.exclude.count > 0:
            self.s.exclude(json.loads(self.exclude.value))
        # searchKeyword's second argument: False for the "all of" (AND) bar,
        # True for the "one of" (OR) bar.
        if self.mustInclude.count > 0:
            self.s.searchKeyword(json.loads(self.mustInclude.value), False)
        if self.containOneOf.count > 0:
            self.s.searchKeyword(json.loads(self.containOneOf.value), True)
        self.toggleSimilarityScore()
        self.filterBox.children = [self.mainFilters]
        self.hasChanges = False
        self.sampleSelector.total = self.s.currentSet.size
        self.tweetDisplay.pageNum = 1
        # Remember the search result so samples and summaries can restart from it.
        self.currentWorkingSet = self.s.currentSet
        self.aiSummary.rerender = 1
        self.updateSearchParams(b)

    def tryGetNewSample(self):
        """Sample ``sampleSelector.value`` posts from the current set and show page 1.

        The sample is weighted when a weightBy column is selected.
        """
        self.tweetDisplay.pageNum = 1
        sampleSize = self.sampleSelector.value
        if self.s.currentSet.size < self.sampleSelector.value:
            self.sampleSelector.value = -1
        # -1 means "take every post in the current set".
        if self.sampleSelector.value == -1:
            sampleSize = self.s.currentSet.size
        if DEBUG_MODE:
            assert sampleSize >= 0
        if self.weightBy.value == "None":
            self.s.simpleRandomSample(sampleSize)
        else:
            self.s.weightedSample(sampleSize, self.weightBy.value)
        self.getTweets()

    def generateNewSample(self, b):
        """Resample from the last search result rather than from a previous sample."""
        self.s.currentSet = self.currentWorkingSet
        self.tryGetNewSample()

    def getTweets(self, change=None):
        """Refresh the Random Posts list for the current page and sort order.

        change: traitlets change notification when used as an observer (unused).
        """
        pageNum = self.tweetDisplay.pageNum
        page = self.getSortedTweets(pageNum)
        if DEBUG_MODE:
            assert len(page) <= 2 * TWEETS_PER_PAGE
        self.tweetDisplay.maxPage = math.ceil(self.s.currentSet.size / TWEETS_PER_PAGE)
        self.sampleTitle.value = "Displaying " + format(self.s.currentSet.size, ',d') + " posts from " + format(self.sampleSelector.total, ',d') + " results"
        self.tweetDisplay.value = [page.iloc[i].to_json() for i in range(len(page))]

    def getSortedTweets(self, pageNum):
        """Return the sorted rows to show for 1-based page ``pageNum``.

        The slice runs from ``(pageNum - 2) * TWEETS_PER_PAGE`` (clamped at 0)
        to ``pageNum * TWEETS_PER_PAGE``, so it covers the previous page and
        the requested one: at most 2 * TWEETS_PER_PAGE rows.
        """
        ans = self.s.getCurrentSubset()
        column = self.sortBar.sortColumn
        descending = self.sortBar.sortOrder == "DESC"
        # Username columns are compared case-insensitively.
        keyFunc = userNameToLower if column in ("Username", "SenderScreenName") else None
        if column != "None":
            # Missing values come first when ascending and last when descending.
            ans = ans.sort_values(
                by=[column],
                ascending=not descending,
                na_position="last" if descending else "first",
                key=keyFunc,
            )
        start = max((pageNum - 2) * TWEETS_PER_PAGE, 0)
        end = min(pageNum * TWEETS_PER_PAGE, len(ans))
        return ans.iloc[start:end]

    def createWidgets(self):
        """Create all widgets of the main page and then the advanced search page."""
        self.advancedButton = widgets.Button(description='Click here to enter search query').add_class("long-button")
        self.advancedButton.on_click(self.openSearchMenu)
        self.searchedCriteria = widgets.HTML("SEARCHED CRITERIA", layout=widgets.Layout(padding="0px 32px")).add_class("heading4")
        self.tweetDisplay = TweetDisplay(height="60vh", displayAddOn=0, addOnColumnName="SimilarityScore")
        self.tweetDisplay.observe(self.getTweets, names=["pageNum"])
        self.datasetDisplay = fileUp
        self.sortBar = SortBar()
        self.sortBar.observe(self.getTweets, names=["sortScope", "sortColumn", "sortOrder"])
        self.displaySimilarityScore = ToggleSwitch(label="Relevance", hidden=1, value=0).add_class("tweet-display-add-on")
        self.displaySimilarityScore.observe(self.toggleSimilarityScore, ["value"])
        self.advancedBar = widgets.VBox([self.searchedCriteria, self.advancedButton]).add_class("advanced-bar")
        self.sampleTitle = widgets.HTML().add_class("display-count")
        self.sampleSelector = SampleSelector(label="Generate New Sample >")
        self.sampleSelector.observe(self.generateNewSample, names=["changeSignal"])

        # These builders create widgets (summaryTab, filterBox, stanceAnalysis)
        # that the tab layout below references.
        self.makeAiSummaryPage()
        self.makeFilterBar()
        self.makeStanceAnalysisPage()

        self.typicalSampleTitle = widgets.HTML().add_class("display-count")
        self.displayCentralityScore = ToggleSwitch(label="Typicality", value=0).add_class("tweet-display-add-on")
        self.displayCentralityScore.observe(self.toggleTypicalityScore, ["value"])
        sampleTopBar = widgets.HBox([self.sampleTitle, self.sampleSelector], layout=widgets.Layout(justify_content="space-between", flex="0 0"))
        sortingBar = widgets.HBox([self.sortBar, self.displaySimilarityScore], layout=widgets.Layout(flex="0 0"))
        self.randomSelection = widgets.VBox([sampleTopBar, sortingBar, self.tweetDisplay], layout=widgets.Layout(max_height="100%"))
        self.centralTweets = TweetDisplay(height="60vh", displayAddOn=0, addOnColumnName="centrality")
        self.centralTweetBox = widgets.VBox([widgets.HBox([self.typicalSampleTitle, self.displayCentralityScore], layout=widgets.Layout(flex="0 0", margin="0px 0px 16px 0px")), self.centralTweets], layout=widgets.Layout(max_height="100%"))
        self.tabs = widgets.Tab(children=[self.randomSelection, self.centralTweetBox, self.summaryTab, self.stanceAnalysis], titles=["Random Posts", "Typical Posts", "AI Summary", "Stance Analysis"])
        self.tabs.observe(self.loadTab, names=["selected_index"])
        self.topBar = widgets.HBox([widgets.HTML("Tweet Browser").add_class("title"), self.datasetDisplay]).add_class("top-bar")
        self.mainPage = widgets.VBox([self.topBar, self.advancedBar, widgets.HBox([self.filterBox, self.tabs])])
        self.debugText = widgets.HTML("test")

        self.makeAdvancedPage()

    def makeFilterBar(self):
        """Build the "Refine Results" sidebar (date, reposts, geography, username, weighting)."""
        self.filterBy = widgets.HTML(value="Refine Results").add_class("heading5").add_class("medium")
        dateRange = widgets.HTML(value="Date").add_class("body2").add_class("medium")
        self.fromDate = widgets.DatePicker(description="From")
        self.toDate = widgets.DatePicker(description="To")
        minDate = self.s.findMinDate().strftime("%Y-%m-%d")
        maxDate = self.s.findMaxDate().strftime("%Y-%m-%d")
        dummyEl.calendarStart = minDate
        dummyEl.calendarEnd = maxDate
        # NOTE: the front-end script that applies the min/max bounds to
        # "date-constraint" elements lives on the toggleSwitch widget. That
        # was done for convenience and should be moved later.
        self.fromDate.add_class("date-constraint")
        self.toDate.add_class("date-constraint")
        self.weightBy = WeightBy()
        self.dateBox = widgets.VBox([dateRange, widgets.HBox([self.fromDate, self.toDate])]).add_class("date-bar")
        self.allowRetweets = ToggleSwitch(label="Include reposts")
        self.retweets = widgets.VBox([widgets.HTML(value="Retweets").add_class("body2").add_class("medium"), self.allowRetweets])
        self.geography = SearchBar(header="Geography", placeholder="Search")
        self.userName = SearchBar(header="Username", placeholder="Search")
        self.mainPageClear = widgets.Button(description='Clear All').add_class("clear-button")
        self.mainPageClear.on_click(self.clearFilters)
        self.mainPageSearch = widgets.Button(description='Apply').add_class("generic-button")
        self.mainPageSearch.on_click(self.search)
        self.popUpOptions = widgets.HBox([self.mainPageClear, self.mainPageSearch]).add_class("pop-up-options")
        # Any filter edit reveals the Clear All / Apply buttons.
        self.userName.observe(self.displayPopUp, names=["value"])
        self.geography.observe(self.displayPopUp, names=["value"])
        self.allowRetweets.observe(self.displayPopUp, names=["value"])
        self.weightBy.observe(self.displayPopUp, names=["value"])
        self.fromDate.observe(self.displayPopUp, names=["value"])
        self.toDate.observe(self.displayPopUp, names=["value"])
        self.mainFilters = widgets.VBox([self.filterBy, self.dateBox, self.retweets, self.geography, self.userName, self.weightBy]).add_class("main-filters")
        self.filterBox = widgets.VBox([self.mainFilters]).add_class("filter-box")

    def makeStanceAnalysisPage(self):
        """Create the Stance Analysis tab content."""
        self.stanceAnalysis = StanceAnalysis()
        self.stanceAnalysisPage = self.stanceAnalysis

    def makeAdvancedPage(self):
        """Build the advanced search page (semantic and exact-match keyword search)."""
        self.searchButton = widgets.Button(description='Search', icon="search").add_class("generic-button").add_class("search-button")
        self.hiddenButton = widgets.Button()
        # Workaround for syncing the search when the user still has
        # un-submitted input in the search bars.
        self.hiddenButton.add_class("hidden-button")
        self.hiddenButton.on_click(self.search)
        self.clearButton = widgets.Button(description='Clear All').add_class("clear-button")
        self.clearButton.on_click(self.clearSearch)
        self.bottomBar = widgets.HBox([self.clearButton, self.searchButton, self.hiddenButton], layout=widgets.Layout(justify_content="flex-end"))
        keyWordSearch = widgets.HTML(value="Exact Match", layout=widgets.Layout(margin="0px 0px -8px 0px")).add_class("heading5").add_class("medium")
        self.mustInclude = SearchBar(header="Must include all", header2="(AND)", placeholder='e.g. “civil null” means each post in the result must contain the word “civil” and “null”')
        self.containOneOf = SearchBar(header="Must include one of", header2="(OR)", placeholder='e.g. “census penny” means each post in the result must contain either “census” or “penny” or both')
        self.exclude = SearchBar(header="Must not include", header2="(NOT)", placeholder='e.g. “toxic ban” means none of the posts in the result contains the word “toxic” and “ban”')
        self.semanticSearch = SemanticSearch(placeholder="e.g. misinformation and miscommunication")
        self.searches = widgets.VBox([widgets.HTML(value="<b>Search Criteria</b>"), self.semanticSearch, keyWordSearch, self.mustInclude, self.containOneOf, self.exclude])
        self.searches.add_class("search-box")
        self.closeButton = widgets.Button(description='X')
        self.closeButton.add_class("close-button")
        self.closeButton.on_click(self.closeSearchMenu)
        self.advancedPage = widgets.VBox([self.closeButton, self.searches, self.bottomBar]).add_class("advanced-page")

    def makeAiSummaryPage(self):
        """Build the AI Summary tab: summary sentences plus their contributing posts."""
        self.loadingPage = LoadingPage(text="Generating AI Summary")
        self.aiSummary = AiSummary()
        self.aiSummary.observe(self.updateAiPageSelect, names=["changeSignal"])
        self.aiTitle = widgets.HTML().add_class("display-count")
        self.pageSelectAi = PageSelect()
        self.pageSelectAi.observe(self.getSummaryTweets, names=["value"])
        self.summaryDisplay = TweetDisplay(height="60vh")
        leftBar = widgets.VBox([widgets.HTML("AI Generated Summary").add_class("heading4").add_class("medium"), self.aiSummary, self.pageSelectAi]).add_class("left-bar")
        self.newSummaryButton = widgets.Button(description="Generate Another Summary").add_class("generic-button").add_class("summary-button")
        self.newSummaryButton.on_click(self.generateNewSummary)
        summaryContent = widgets.HBox([leftBar, widgets.VBox([widgets.HTML("Contributing Posts").add_class("heading4").add_class("medium"), self.summaryDisplay]).add_class("right-bar")]).add_class("summary-tab")
        self.summaryTab = widgets.VBox([self.aiTitle, summaryContent, self.newSummaryButton], layout=widgets.Layout(max_height="100%"))

    def openSearchMenu(self, change):
        """Switch to the advanced search screen."""
        self.screen = "advanced"
        self.resetDisplay()

    def toggleSimilarityScore(self, change=None):
        """Show the relevance column only when the toggle is on and not hidden."""
        self.tweetDisplay.displayAddOn = self.displaySimilarityScore.value
        if self.displaySimilarityScore.hidden > 0:
            self.tweetDisplay.displayAddOn = 0

    def toggleTypicalityScore(self, change):
        """Mirror the Typicality toggle onto the Typical Posts display."""
        self.centralTweets.displayAddOn = self.displayCentralityScore.value

    def displayPopUp(self, change):
        """Mark the filters as edited and reveal the Clear All / Apply buttons."""
        self.hasChanges = True
        self.filterBox.children = [self.mainFilters, self.popUpOptions]

    def clearSearch(self, change):
        """Empty every search bar on the advanced page and redraw.

        The new search is not run here; it takes effect on the next search().
        """
        self.semanticSearch.value = ""
        self.exclude.value = "[]"
        self.exclude.count = 0
        self.mustInclude.value = "[]"
        self.mustInclude.count = 0
        self.containOneOf.value = "[]"
        self.containOneOf.count = 0
        self.resetDisplay()

    def clearFilters(self, change):
        """Reset every sidebar filter to its default and re-run the search."""
        self.geography.value = "[]"
        self.geography.count = 0
        self.userName.value = "[]"
        self.userName.count = 0
        self.fromDate.value = None
        self.toDate.value = None
        self.allowRetweets.value = True
        self.weightBy.value = "None"
        self.filterBox.children = [self.mainFilters]
        self.search()

    def getTypicalPosts(self, change=None):
        """Show the first (up to) 5 posts returned by s.getCentral() for the last search result."""
        self.s.currentSet = self.currentWorkingSet
        result = self.s.getCentral()
        count = min(5, len(result))
        # widgets.HTML has no `visible` trait, so visibility goes through its layout.
        if count == 0:
            self.typicalSampleTitle.layout.visibility = "hidden"
            self.typicalSampleTitle.value = ""
        else:
            self.typicalSampleTitle.layout.visibility = "visible"
            self.typicalSampleTitle.value = "Displaying " + format(count, ",d") + " typical posts from " + format(len(result), ',d') + " results"
        self.centralTweets.value = [result.iloc[i].to_json() for i in range(count)]

    def alertHandler(self, change):
        """React to the front end's answer to the unapplied-changes prompt.

        If dummyEl.userResponse is 1, the pending filters are applied and the
        selected tab is reloaded. In every case the response is reset to 2.
        """
        if dummyEl.userResponse == 1:
            self.search()
            self.loadTab(None, self.tabs.selected_index)
        dummyEl.userResponse = 2

    def loadTab(self, change, tabNum=None):
        """Populate the tab being shown: Typical Posts (1) or AI Summary (2).

        change: traitlets change dict when called as the Tab observer, in
                which case the index is taken from change["new"]. None when
                called directly with ``tabNum``.
        """
        if change is not None:
            tabNum = change["new"]
            if self.hasChanges:
                # Ask the front end to warn about unapplied filter edits;
                # alertHandler reacts to the resulting dummyEl.userResponse.
                dummyEl.alertTrigger += 1
        if tabNum == 1:
            self.getTypicalPosts()
        elif tabNum == 2:
            self.generateSummary()

    def updateAiPageSelect(self, b):
        """Sync the page selector with the summary sentence selected in aiSummary."""
        self.pageSelectAi.value = self.aiSummary.selected + 1
        self.pageSelectAi.changeSignal += 1
        self.getSummaryTweets(b)

    def generateNewSummary(self, b):
        """Force a fresh summary of the last search result."""
        self.s.currentSet = self.currentWorkingSet
        self.aiSummary.rerender = 1
        self.generateSummary(b)

    def getSummaryTweets(self, b):
        """Show the contributing posts for the selected summary sentence."""
        pageNum = self.pageSelectAi.value - 1  # PageSelect is 1-based
        self.aiSummary.selected = pageNum
        tweets = self.aiSummary.sentenceNums[pageNum]
        ans = self.s.allData.iloc[tweets]
        self.summaryDisplay.value = [ans.iloc[i].to_json() for i in range(len(ans))]
        self.pageSelectAi.changeSignal += 1

    def generateSummary(self, change=None):
        """Summarize (a sample of at most 50 posts from) the current set.

        Does nothing unless aiSummary.rerender is set, so the summary is only
        regenerated after a new search or an explicit request.
        """
        if self.aiSummary.rerender == 0:
            return
        if self.s.currentSet.size > 50:
            self.aiTitle.value = "Summarizing 50 posts sampled from " + format(self.s.currentSet.size, ',d') + " results"
            self.s.simpleRandomSample(50)
        else:
            self.aiTitle.value = "Summarizing " + format(self.s.currentSet.size, ',d') + " posts sampled from " + format(self.s.currentSet.size, ',d') + " results"
        # Swap in the loading page while the summary is produced.
        self.tabs.children = [self.randomSelection, self.centralTweetBox, self.loadingPage, self.stanceAnalysisPage]
        summary = self.s.summarize()
        self.tabs.children = [self.randomSelection, self.centralTweetBox, self.summaryTab, self.stanceAnalysisPage]
        strings, tweets, unused = self.s.parseSummary(summary)
        self.aiSummary.value = strings
        self.aiSummary.sentenceNums = tweets
        self.aiSummary.unused = unused
        self.aiSummary.rerender = 0
        self.pageSelectAi.maxPage = len(tweets)
        self.aiSummary.selected = 0
        self.pageSelectAi.value = 1
        self.getSummaryTweets(None)

    def updateSearchParams(self, change):
        """Collect display strings for the active search parameters, then return to the main screen.

        NOTE(review): the strings computed here are currently unused. The
        keyword/filter summary widgets that consumed them are disabled.
        """
        val1 = val2 = val3 = ""
        if self.mustInclude.count > 0:
            val1 = ', '.join(json.loads(self.mustInclude.value))
        if self.containOneOf.count > 0:
            val2 = ', '.join(json.loads(self.containOneOf.value))
        if self.exclude.count > 0:
            val3 = ', '.join(json.loads(self.exclude.value))
        selectedDates = ''
        if self.fromDate.value is not None and self.toDate.value is not None:
            selectedDates = str(self.fromDate.value) + " to " + str(self.toDate.value)
        geo = usrname = ""
        if self.geography.count > 0:
            geo = ', '.join(json.loads(self.geography.value))
        if self.userName.count > 0:
            usrname = ', '.join(json.loads(self.userName.value))
        retweets = "" if self.allowRetweets.value == 2 else ("yes" if self.allowRetweets.value > 0 else "no")
        weightBy = ""
        if self.weightBy.value != "None":
            weightBy = self.weightBy.value
            if self.weightBy.value == "SenderInfluencerScore":
                weightBy = "Influencer Score"
        self.closeSearchMenu(change)

    def closeSearchMenu(self, change):
        """Return to the main screen."""
        self.screen = "main"
        self.resetDisplay()
    
def fileHandler(change):
    """Start a new session from the file just uploaded through ``fileUp``.

    change: traitlets change notification (unused; the file is read from
            ``fileUp.value``).
    Does nothing when the widget holds no file, e.g. after it is cleared,
    instead of raising IndexError.
    """
    uploaded = fileUp.value
    if not uploaded:
        return
    startSession(uploaded[0])
    
def userNameToLower(input):
    """Sort key for username columns: lower-case every string in a Series.

    ``input`` is the pandas Series that ``DataFrame.sort_values(key=...)``
    hands over; a new lower-cased Series is returned.
    """
    lowered = input.str.lower()
    return lowered

# Populate the browser with the bundled sample dataset on first run.
autoStartSession("allCensus_sample.csv")

# Rebuild the session whenever the user uploads a replacement dataset.
fileUp.observe(fileHandler, names="value")

# Last expression of the cell, so the notebook renders the output area.
out

Output()