In [None]:
!pip install -U scikit-fuzzy

In [1]:
!python --version

Python 3.6.8 :: Anaconda, Inc.


In [None]:
import pandas as pd
import numpy as np
from zoobot import label_metadata, schemas
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
from scipy.optimize import linear_sum_assignment as linear_assignment
import skfuzzy

In [16]:
def findChoice(frac):
    choice = np.zeros_like(frac)
    choice[np.arange(len(frac)), frac.argmax(1)] = 1
    return choice

def getQuestionClasses(auto_f, volunteers, question):
    qcol_name = question.text+'_total-votes'
    fcol_names = [(cols.text+'_fraction') for cols in question.answers]
    anscol_names = [cols.text for cols in question.answers]
    valid_feats = []
    
    valid_vol = volunteers.query('`{}`/`smooth-or-featured_total-votes` >= 0.5'.format(qcol_name))
    valid_idx = valid_vol.index.tolist()
    vol_results = valid_vol[fcol_names].values
    
    auto_values = auto_f.values
    
    for i in valid_idx:
        valid_feats.append(auto_values[i])
        
    rounded_vol_results = findChoice(np.asarray(vol_results))
    support = len(rounded_vol_results)
    
    _,u,_,_,_,_,_ = skfuzzy.cmeans(np.transpose(valid_feats), c=3, m=2, error=1e-4, maxiter=300,seed=None)
    pred_results = np.argmax(u, axis=0)
    
    vol_classes = np.argmax(rounded_vol_results, axis=1)
    
    return valid_idx, support, anscol_names, np.array(pred_results), np.array(vol_classes)

In [3]:
def makeComp(pred, vol, i, numLabels):
    compAll = []
    for j in range(numLabels):
        vol_where = np.where(vol==j)
        pred_where = np.where(pred==i)
        comp = np.intersect1d(vol_where[0], pred_where[0])
        compAll.append(len(comp)/len(vol_where[0]))
    return compAll

def _make_cost_m(cm):
    s = np.max(cm)
    return (- cm + s)

def labelMap(vol, pred):
    cm = confusion_matrix(vol, pred)
    indexes = linear_assignment(_make_cost_m(cm))
    indexes = np.asarray(indexes)
    return indexes[1]
    
def convertLabels(lmap, pred):
    conv_preds = []
    for i in range(len(pred)):
        conv_preds.append(lmap[pred[i]])
    return np.array(conv_preds)

In [4]:
auto_features = pd.read_csv("../autoencoder/extracted_features.csv")

In [5]:
auto_features = auto_features.drop('file_loc',axis=1)

In [6]:
decals_test = pd.read_csv('../Ilifu_data/decals_ilifu_test.csv')
schema = schemas.Schema(label_metadata.decals_pairs, label_metadata.get_gz2_and_decals_dependencies(label_metadata.decals_pairs))

{smooth-or-featured, indices 0 to 2, asked after None: (0, 2), disk-edge-on, indices 3 to 4, asked after smooth-or-featured_featured-or-disk, index 1: (3, 4), has-spiral-arms, indices 5 to 6, asked after disk-edge-on_no, index 4: (5, 6), bar, indices 7 to 9, asked after disk-edge-on_no, index 4: (7, 9), bulge-size, indices 10 to 14, asked after disk-edge-on_no, index 4: (10, 14), how-rounded, indices 15 to 17, asked after smooth-or-featured_smooth, index 0: (15, 17), edge-on-bulge, indices 18 to 20, asked after disk-edge-on_yes, index 3: (18, 20), spiral-winding, indices 21 to 23, asked after has-spiral-arms_yes, index 5: (21, 23), spiral-arm-count, indices 24 to 29, asked after has-spiral-arms_yes, index 5: (24, 29), merging, indices 30 to 33, asked after None: (30, 33)}


In [17]:
total_report = {}
for question in label_metadata.decals_pairs:
    idxs, support, anscols, valid_preds, valid_vol = getQuestionClasses(auto_features, decals_test, schema.get_question(question))
    lmap = labelMap(valid_vol, valid_preds)
    conv_preds = convertLabels(lmap, valid_preds)
    question_report = precision_recall_fscore_support(y_pred=conv_preds, y_true=valid_vol, average='weighted')
    total_report[question] = {
        'precision': question_report[0],
        'recall': question_report[1],
        'f1': question_report[2],
        'support': support
    }

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [18]:
report_df = pd.DataFrame.from_dict(total_report, orient='index')
report_df

Unnamed: 0,precision,recall,f1,support
smooth-or-featured,0.348344,0.110483,0.084641,49917
disk-edge-on,0.710671,0.504176,0.544913,15445
has-spiral-arms,0.71371,0.57355,0.617568,11380
bar,0.425038,0.453691,0.434058,11380
bulge-size,0.270337,0.316608,0.29109,11380
how-rounded,0.461095,0.440263,0.449872,32526
edge-on-bulge,0.677051,0.575758,0.575119,2475
spiral-winding,0.397989,0.443926,0.417089,7499
spiral-arm-count,0.383374,0.349113,0.359439,7499
merging,0.741674,0.522468,0.607943,49247
