In [1]:
import numpy as np
from sklearn.linear_model import LogisticRegression
import pandas as pd
from sklearn.model_selection import train_test_split
import statsmodels.api as sm

In [2]:
data = pd.read_csv('../openreview.csv')
binary_decisions = {'Accept (Oral)':1, 'Accept (Poster)':1, 'Accept (Spotlight)':1, 'Accept (Talk)':1, 'Withdrawn':0, 'Reject':0, 'Invite to Workshop Track':0}
mideast = ['bh', 'cy', 'eg','ir','iq','il','jo', 'kw','lb','om','ps','qa','sa','sy','tr','ae','ye']
eastasian = ['cn', 'jp', 'mn', 'kp','kr', 'tw', 'hk']
southasian = ['af','bd','bt', 'in', 'np', 'pk', 'lk','mv','sg']

In [3]:
def logistic_regression(prediction, mean_reviewer_score, country):
    scores =[]
    COO_indicator=[]
    target=[]
    print(country)

    for rating, decision, COO, region in zip(data['ratings'], data['decisions'], data['COO'], data['regions']):
        if country == 'US':
            if region == 'usa': COO_indicator.append(1)
            else: COO_indicator.append(0)
        elif country == 'Canada':
            if region == 'canada': COO_indicator.append(1)
            else: COO_indicator.append(0)
        elif country == 'South America':
            if region == 'southamerica': COO_indicator.append(1)
            else: COO_indicator.append(0)
        elif country == 'Australia and New Zealand':
            if COO == 'au' or COO == 'nz': COO_indicator.append(1)
            else: COO_indicator.append(0)
        elif country == 'Mid East':
            if COO in mideast: COO_indicator.append(1)
            else: COO_indicator.append(0)
        elif country == 'UK/Ireland':
            if COO == 'uk' or COO == 'ie': COO_indicator.append(1)
            else: COO_indicator.append(0)
        elif country == 'mainland Europe':
            if region == 'europe': COO_indicator.append(1)
            else: COO_indicator.append(0)
        elif country == 'Russia':
            if COO == 'russia': COO_indicator.append(1)
            else: COO_indicator.append(0)
        elif country == 'Africa':
            if region == 'africa': COO_indicator.append(1)
            else: COO_indicator.append(0)
        elif country == 'East Asian':
            if COO in eastasian: COO_indicator.append(1)
            else: COO_indicator.append(0)
        elif country == 'South Asia':
            if COO in southasian: COO_indicator.append(1)
            else: COO_indicator.append(0)
        else:
            print('cannot recognize input country')

        avg_rating=0
        for i in rating.split(';'):
            avg_rating += int(i)
        scores.append(avg_rating/len(rating.split(';')))

        target.append(binary_decisions[decision])

    x=pd.DataFrame()
    x['mean reviewer score']=scores
    x['COO']=COO_indicator
    x['constant']=[1]*len(scores)
    y=np.array(target)

    if country == 'Africa': return 'no data for Africa', None
    logit_model = sm.Logit(y, x)
    summary = logit_model.fit().summary()

    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.3, random_state=0)
    logisticRegr = LogisticRegression(random_state=0)
    model = logisticRegr.fit(x_train, y_train)
    accuracy = logisticRegr.score(x_test, y_test)

    if prediction:
        input_list = [mean_reviewer_score, 1, 1]
        pred = np.array(input_list)
        y_pred = model.predict_proba(pred.reshape(1,-1))
        input_list2 = [mean_reviewer_score, 0, 1]
        pred = np.array(input_list2)
        y_pred2 = model.predict_proba(pred.reshape(1,-1))
        print('From country: ', y_pred)
        print('Not from country: ', y_pred2)

    return summary, accuracy

In [20]:
def logistic_regression2():
    scores =[]
    us_indicator=[]
    canada_indicator=[]
    south_america_indicator=[]
    aunz_indicator=[]
    mideast_indicator=[]
    ukireland_indicator=[]
    mainland_europe_indicator=[]
    russia_indicator=[]
    africa_indicator=[]
    east_asian_indicator=[]
    south_asia_indicator=[]
    target=[]

    for rating, decision, COO, region in zip(data['ratings'], data['decisions'], data['COO'], data['regions']):
        print(region)
        print(region=='NAN')
        if region !='NAN':
            if region == 'usa':
                us_indicator.append(1)
                canada_indicator.append(0)
                south_america_indicator.append(0)
                aunz_indicator.append(0)
                mideast_indicator.append(0)
                ukireland_indicator.append(0)
                mainland_europe_indicator.append(0)
                russia_indicator.append(0)
                africa_indicator.append(0)
                east_asian_indicator.append(0)
                south_asia_indicator.append(0)
            elif region == 'canada':
                us_indicator.append(0)
                canada_indicator.append(1)
                south_america_indicator.append(0)
                aunz_indicator.append(0)
                mideast_indicator.append(0)
                ukireland_indicator.append(0)
                mainland_europe_indicator.append(0)
                russia_indicator.append(0)
                africa_indicator.append(0)
                east_asian_indicator.append(0)
                south_asia_indicator.append(0)
            elif region == 'southamerica': 
                us_indicator.append(0)
                canada_indicator.append(0)
                south_america_indicator.append(1)
                aunz_indicator.append(0)
                mideast_indicator.append(0)
                ukireland_indicator.append(0)
                mainland_europe_indicator.append(0)
                russia_indicator.append(0)
                africa_indicator.append(0)
                east_asian_indicator.append(0)
                south_asia_indicator.append(0)
            elif COO == 'au' or COO == 'nz': 
                us_indicator.append(0)
                canada_indicator.append(0)
                south_america_indicator.append(0)
                aunz_indicator.append(1)
                mideast_indicator.append(0)
                ukireland_indicator.append(0)
                mainland_europe_indicator.append(0)
                russia_indicator.append(0)
                africa_indicator.append(0)
                east_asian_indicator.append(0)
                south_asia_indicator.append(0)
            elif COO in mideast: 
                us_indicator.append(0)
                canada_indicator.append(0)
                south_america_indicator.append(0)
                aunz_indicator.append(0)
                mideast_indicator.append(1)
                ukireland_indicator.append(0)
                mainland_europe_indicator.append(0)
                russia_indicator.append(0)
                africa_indicator.append(0)
                east_asian_indicator.append(0)
                south_asia_indicator.append(0)
            elif COO == 'uk' or COO == 'ie': 
                us_indicator.append(0)
                canada_indicator.append(0)
                south_america_indicator.append(0)
                aunz_indicator.append(0)
                mideast_indicator.append(0)
                ukireland_indicator.append(1)
                mainland_europe_indicator.append(0)
                russia_indicator.append(0)
                africa_indicator.append(0)
                east_asian_indicator.append(0)
                south_asia_indicator.append(0)
            elif region == 'europe': 
                us_indicator.append(0)
                canada_indicator.append(0)
                south_america_indicator.append(0)
                aunz_indicator.append(0)
                mideast_indicator.append(0)
                ukireland_indicator.append(0)
                mainland_europe_indicator.append(1)
                russia_indicator.append(0)
                africa_indicator.append(0)
                east_asian_indicator.append(0)
                south_asia_indicator.append(0)
            elif COO == 'russia': 
                us_indicator.append(0)
                canada_indicator.append(0)
                south_america_indicator.append(0)
                aunz_indicator.append(0)
                mideast_indicator.append(0)
                ukireland_indicator.append(0)
                mainland_europe_indicator.append(0)
                russia_indicator.append(1)
                africa_indicator.append(0)
                east_asian_indicator.append(0)
                south_asia_indicator.append(0)
            elif region == 'africa': 
                us_indicator.append(0)
                canada_indicator.append(0)
                south_america_indicator.append(0)
                aunz_indicator.append(0)
                mideast_indicator.append(0)
                ukireland_indicator.append(0)
                mainland_europe_indicator.append(0)
                russia_indicator.append(0)
                africa_indicator.append(1)
                east_asian_indicator.append(0)
                south_asia_indicator.append(0)
            elif COO in eastasian: 
                us_indicator.append(0)
                canada_indicator.append(0)
                south_america_indicator.append(0)
                aunz_indicator.append(0)
                mideast_indicator.append(0)
                ukireland_indicator.append(0)
                mainland_europe_indicator.append(0)
                russia_indicator.append(0)
                africa_indicator.append(0)
                east_asian_indicator.append(1)
                south_asia_indicator.append(0)
            elif COO in southasian: 
                us_indicator.append(0)
                canada_indicator.append(0)
                south_america_indicator.append(0)
                aunz_indicator.append(0)
                mideast_indicator.append(0)
                ukireland_indicator.append(0)
                mainland_europe_indicator.append(0)
                russia_indicator.append(0)
                africa_indicator.append(0)
                east_asian_indicator.append(0)
                south_asia_indicator.append(1)
            else:
                print('cannot recognize input country')

        avg_rating=0
        for i in rating.split(';'):
            avg_rating += int(i)
        scores.append(avg_rating/len(rating.split(';')))
        target.append(binary_decisions[decision])

    x=pd.DataFrame()
    x['us indicator']=us_indicator
    x['canada indicator']=canada_indicator
    x['south america indicator']=south_america_indicator
    x['AU/NZ indicator']=aunz_indicator
    x['mideast indicator']=mideast_indicator
    x['uk/ireland indicator']=ukireland_indicator
    x['mainland europe indicator']=mainland_europe_indicator
    x['russia indicator']=russia_indicator
    x['africa indicator']=africa_indicator
    x['east asia indicator']=east_asian_indicator
    x['south asia indicator'] = south_america_indicator
    x['constant']=[1]*len(us_indicator)
    x['mean reviewer score']=scores
    y=np.array(target)

    #if country == 'Africa': return 'no data for Africa', None
    logit_model = sm.Logit(y, x)
    summary = logit_model.fit().summary()

    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.3, random_state=0)
    logisticRegr = LogisticRegression(random_state=0)
    model = logisticRegr.fit(x_train, y_train)
    accuracy = logisticRegr.score(x_test, y_test)


    return summary, accuracy

In [21]:
summary, score = logistic_regression2()
print(summary)
print('Accuracy: ', score, '\n')
'''countries = ['US', 'Canada', 'South America', 'Australia and New Zealand', 'Mid East', 'UK/Ireland', 'mainland Europe', 'Russia', 'Africa', 'East Asian', 'South Asia']
for country in countries:
    #summary, score = logistic_regression(True, 6.0, country)
    summary, score = logistic_regression2()
    print(summary)
    print('Accuracy: ', score, '\n')'''

usa
False
usa
False
NAN
True
NAN
True
europe
False
usa
False
usa
False
NAN
True
NAN
True
NAN
True
usa
False
usa
False
NAN
True
usa
False
usa
False
NAN
True
NAN
True
NAN
True
europe
False
usa
False
europe
False
NAN
True
usa
False
usa
False
NAN
True
NAN
True
NAN
True
canada
False
NAN
True
NAN
True
asia
False
asia
False
asia
False
NAN
True
NAN
True
usa
False
usa
False
usa
False
usa
False
usa
False
canada
False
europe
False
usa
False
usa
False
canada
False
NAN
True
NAN
True
europe
False
NAN
True
europe
False
usa
False
usa
False
NAN
True
usa
False
usa
False
NAN
True
usa
False
usa
False
usa
False
NAN
True
canada
False
usa
False
usa
False
usa
False
usa
False
usa
False
usa
False
NAN
True
usa
False
usa
False
NAN
True
NAN
True
NAN
True
usa
False
NAN
True
NAN
True
NAN
True
NAN
True
NAN
True
usa
False
canada
False
NAN
True
usa
False
NAN
True
NAN
True
canada
False
usa
False
europe
False
NAN
True
NAN
True
NAN
True
NAN
True
canada
False
europe
False
NAN
True
usa
False
NAN
True
NAN
True
asia
False
asi

ValueError: Length of values (8553) does not match length of index (4623)