In [30]:
from typing import List, Union, Any, Dict
from abc import ABC, abstractmethod

import numpy as np
import pandas as pd

import re
import string
import unicodedata

import torch
from transformers import pipeline


def clean_str(s: str) -> str:
    """String pre-processing function, used to reduce noise.
        1. Convert all characters to ASCII
        2. Remove other irrelevant stuff like email address or external url
        3. Remove special symbols like newline character \\n"""
        
    # Normalize special chars
    s = str(s)
    s = (unicodedata.normalize('NFKD', s)
            .encode('ascii', 'ignore').decode())

    # Remove irrelevant info
    s = re.sub(r'\S*@\S*\s?', '', s)     # Email
    s = re.sub(r'\S*https?:\S*', '', s)  # URL
    
    # Keep punctuation and words only
    pattern_keep = (string.punctuation + 
                        string.ascii_letters + 
                        string.digits + 
                        r' ')
    return re.sub(r'[^' + pattern_keep + r']+', '', s)

In [31]:
# Load textual descriptions of interested entities
df_ent = pd.read_csv('../data/SF_all_tone_2k_entities.csv')
df_ent = df_ent.loc[~df_ent.loc[:, 'org_flag']]
df_ent.loc[:, 'description1'] = df_ent.loc[:, 'description1'].map(clean_str)

# Load occupation categories
df_occ = pd.read_csv('../data/categories.csv')
df_occ.loc[:, 'occupation'] = df_occ.loc[:, 'occupation'].str.lower()

In [35]:
df_ent.description1.iloc[0]

'Jacob Gedleyihlekisa Zuma is a South African politician who was the fourth president of South Africa from 2009 to 2018. He is also referred to by his ...'

In [58]:
# Predict occupation
model = pipeline('zero-shot-classification', 'cross-encoder/nli-MiniLM2-L6-H768')
categories = df_occ.occupation.unique().tolist()
def predict_occupation(row: pd.Series) -> str:
    
    ent = row['entity']
    desc = row['description1']
    return model(desc, categories, hypothesis_template=(f'{ent} is a ' + r'{}'))['labels'][0]

In [59]:
df_ent.loc[:, 'occ_pred'] = df_ent.apply(predict_occupation, axis=1)

TypeError: unhashable type: 'slice'

In [None]:
df_ent[:, ['entity', 'description1', 'occ_pred']]

['politician', 'politician', 'politician', 'politician', 'journalist']