In [65]:
import json
import operator
import numpy as np

In [176]:
with open('MeGCBERT_Error_analysis2.json', 'r') as f:
    data = json.load(f)
with open('MeHGCNN2.json', 'r') as f:
    data2 = json.load(f)
with open('ReTesting/stc3-nddq-test-annotations.json', 'r') as f:
    label = json.load(f)

## Preprocessing

In [55]:
def label_preprocess(l):
    lDQ = []
    for tmp in l['annotations']:
        lDQ.append(tmp['quality'])

    lDQ_processed = label_format(lDQ)
    lA, lE, lS = lDQ_processed['A'], lDQ_processed['E'], lDQ_processed['S']
    return lA, lE, lS

In [45]:
def label_format(quality):
    DQ = {}
    qualityA = {-2: 0, -1: 0, 0: 0, 1: 0, 2: 0}
    qualityS = {-2: 0, -1: 0, 0: 0, 1: 0, 2: 0}
    qualityE = {-2: 0, -1: 0, 0: 0, 1: 0, 2: 0}
    quality_keys = [-2, -1, 0, 1, 2]

    for q in quality:
        qualityA[q['A']] += 1
        qualityS[q['S']] += 1
        qualityE[q['E']] += 1

    for k in quality_keys:
        qualityA[k] /= 19
        qualityS[k] /= 19
        qualityE[k] /= 19

    DQ['A'] = [qualityA[k] for k in sorted(qualityA.keys())]
    DQ['S'] = [qualityS[k] for k in sorted(qualityS.keys())]
    DQ['E'] = [qualityE[k] for k in sorted(qualityE.keys())]
    return DQ

## Error Analysis

In [58]:
def normalize(pred, truth):
    if len(pred) != len(truth):
        raise ValueError("pred and truth have different lengths")
    if len(pred) == 0 or len(truth) == 0:
        raise ValueError("pred or truth are empty")

    pred, truth = np.asarray(pred), np.asarray(truth)
    if not ((pred >= 0).all() and (truth >= 0).all()):
        raise ValueError("probability distribution should not be negative")
    pred, truth = pred / pred.sum(), truth / truth.sum()
    return pred, truth

def nmd(pred, truth):
    pred, truth = normalize(pred, truth)
    cum_p, cum_q = np.cumsum(pred), np.cumsum(truth)
    return (np.abs(cum_p - cum_q)).sum() / (len(pred) - 1.)

def distance_weighted(pred, truth, i):
    return np.sum([np.abs(i - j) * ((pred[j] - truth[j]) ** 2) for j in range(len(pred))])

def order_aware_div(pred, truth):
    return np.mean([distance_weighted(pred, truth, i) for i in range(len(pred)) if pred[i] > 0])

def rsnod(pred, truth):
    pred, truth = normalize(pred, truth)
    sod = (order_aware_div(pred, truth) + order_aware_div(truth, pred)) / 2.
    return np.sqrt((sod / (len(pred) - 1)))

In [182]:
def _toList(quality):
    l = [quality[k] for k in sorted(quality.keys())]
    l[0], l[1] = l[1], l[0]
    return l

In [183]:
def DQerror_analysis(data, label):        
    maxNMDA = 0
    maxNMDE = 0
    maxNMDS = 0
    maxRSNODA = 0
    maxRSNODE = 0
    maxRSNODS = 0
    idsA = set()
    idsE = set()
    idsS = set()

    for d, l in zip(data, label):
        assert d['id'] == l['id'], 'ID not match {} & {}'.format(d['id'], l['id'])
        dA, dE, dS = _toList(d['quality']['A']), _toList(d['quality']['E']), _toList(d['quality']['S']),
        lA, lE, lS = label_preprocess(l)
              
        curNMDA = nmd(dA, lA)
        curRSNODA = rsnod(dA, lA)
        curNMDE = nmd(dE, lE)
        curRSNODE = rsnod(dE, lE)
        curNMDS = nmd(dS, lS)
        curRSNODS = rsnod(dS, lS)
        
        if curNMDA > maxNMDA:
            maxNMDA = curNMDA
            idsA.add(d['id'])
            
        if curNMDE > maxNMDE:
            maxNMDE = curNMDE
            idsE.add(d['id'])
            
        if curNMDS > maxNMDS:
            maxNMDS = curNMDS
            idsS.add(d['id'])
        
        if curRSNODA > maxRSNODA:
            maxRSNODA = curRSNODA
            idsA.add(d['id'])
            
        if curRSNODE > maxRSNODE:
            maxRSNODE = curRSNODE
            idsE.add(d['id'])
            
        if curRSNODS > maxRSNODS:
            maxRSNODS = curRSNODS
            idsS.add(d['id'])
    
    return idsA, idsE, idsS

In [184]:
DQerror_analysis(data,label)

({'3575941147776865',
  '3663184819022212',
  '3718912216646049',
  '3789298845556342',
  '3848972026478415',
  '3850365923254701',
  '4022842192155296',
  '4059161753539505',
  '4095409889100831',
  '4136604979135705',
  '4228745059319600'},
 {'3718912216646049',
  '3782529536386965',
  '3822586650143107',
  '3848972026478415',
  '4022842192155296',
  '4059161753539505',
  '4221836747009055'},
 {'3718912216646049',
  '3804383832045677',
  '3848972026478415',
  '4022842192155296',
  '4059161753539505',
  '4132444007180169',
  '4221836747009055'})

In [189]:
# Error analysis check
for d, d2, l in zip(data, data2, label):
    if l['id'] == '3718912216646049' or l['id'] == '4221836747009055':
        print(l['turns'])
        dA, dE, dS = _toList(d['quality']['A']), _toList(d['quality']['E']), _toList(d['quality']['S'])
        d2A, d2E, d2S = _toList(d2['quality']['A']), _toList(d2['quality']['E']), _toList(d2['quality']['S'])
        lA, lE, lS = label_preprocess(l)
        print(dA)
        print(d2A)
        print(lA)
        print()

[{'utterances': ["Compared to Unicom, the speed of Telecom 3G doesn't have the advantage, so the Unicom dual SIM card is better than Telecom. However, the dual SIM card phone of Telecom doesn’t support Telecom and Unicom’s dual SIM card and dual 3G any more, which is still outmoded and only engaged in cdma2000+gsm. I will abandon the Telecom phone number sooner or later and many people will give up Telecom phone number! @ China Telecom Beijing Customer Service @ ZTE Corporation @ Huawei Fans Club @ Coolpad Official Weibo @ Xue Tao’s Design Road"], 'sender': 'customer'}, {'utterances': ['Hello! Where did you surf the Internet? If convenient, could you please send the business number and detailed address to me? I am very grateful for your precious suggestions. I will feed back to you'], 'sender': 'helpdesk'}, {'utterances': ["Telecom cdma2000 is not as quick as Unicom wcdma, which is determined by the technology and can't be changed. I just want to use Telecom and Unicom 3G on the same p

## ACC

In [227]:
def DQACC(data, label):    
    from pprint import pprint
    def _toList(quality):
        l = [quality[k] for k in sorted(quality.keys())]
        l[0], l[1] = l[1], l[0]
        return l
    
    correctA = 0
    correctE = 0
    correctS = 0
    cmA = [[0 for x in range(5)] for y in range(5)] 
    cmE = [[0 for x in range(5)] for y in range(5)] 
    cmS = [[0 for x in range(5)] for y in range(5)] 
    
    for d, l in zip(data, label):
        assert d['id'] == l['id'], 'ID not match {} & {}'.format(d['id'], l['id'])
        dA, dE, dS = _toList(d['quality']['A']), _toList(d['quality']['E']), _toList(d['quality']['S']),
        lA, lE, lS = label_preprocess(l)
        
        d_idxA = dA.index(max(dA))
        d_idxE = dE.index(max(dE))
        d_idxS = dS.index(max(dS))
        
        maxA = max(lA)
        maxE = max(lE)
        maxS = max(lS)
        l_indicesA = [i for i, x in enumerate(lA) if x == maxA]
        l_indicesE = [i for i, x in enumerate(lE) if x == maxE]
        l_indicesS = [i for i, x in enumerate(lS) if x == maxS]
        
        for l_idxA in l_indicesA:
            cmA[d_idxA][l_idxA] += 1
        
        for l_idxE in l_indicesE:
            cmE[d_idxE][l_idxE] += 1
            
        for l_idxS in l_indicesA:
            cmS[d_idxS][l_idxS] += 1
            
        if d_idxA not in l_indicesA:
            print(d['id'])
        
        correctA = correctA + 1 if d_idxA in l_indicesA else correctA
        correctE = correctE + 1 if d_idxE in l_indicesE else correctE
        correctS = correctS + 1 if d_idxS in l_indicesS else correctS
        
#     pprint(cmA)
#     pprint(cmE)
#     pprint(cmS)

    return correctA / 390, correctE / 390, correctS / 390

In [228]:
DQACC(data, label)

4059161753539505
4022842192155296
3906308417819163
4230898565609970
3670008993717977
4170708147859382
4136604979135705
4079388452755871
4229802662028408
3955159908365412
4095409889100831
3786677996271452
4106893935333575
4149091275267452
4221407908748613
3663184819022212
4123631946721258
3848972026478415
4112104343138661
4080949546864041
4205840107796224
4211405613394373
3849312264071011
4207453434212744
3626843729416100
4232672933535868
4091277618515823
3622094619700594
4049755295968214
4231853198745520
4013401988125952
4167377174031311
4157278477060400
3864503135524202
3952324063474435
4224817022759822
4230955466585544
4220489154580263
4165579365808192
3825046810242194
4136292722007598
3804383832045677
3777859678345451
4016731338518342
3970715911227972
3652824510840990
4217894759777453
3836924639066650
3834873678442069
3748873702923912
4181918142457583
3789435626027083
3944768482398385
4226565858151376
3997616767822908
3675432954483271
4228706065049719
3994233775579414
34702990118842

(0.658974358974359, 0.658974358974359, 0.7282051282051282)

In [229]:
DQACC(data2, label)

4022842192155296
3906308417819163
3670008993717977
4170708147859382
4136604979135705
4229802662028408
3955159908365412
4095409889100831
3926256229298112
3786677996271452
4106893935333575
4221407908748613
3663184819022212
3848972026478415
4112104343138661
4080949546864041
3980067338393709
4211405613394373
3849312264071011
3972229676995092
4207453434212744
3626843729416100
4091277618515823
3622094619700594
4049755295968214
4231853198745520
4013401988125952
4167377174031311
4157278477060400
4224817022759822
4230955466585544
4220489154580263
4165579365808192
3590332534845032
3825046810242194
4136292722007598
3453944166480384
3777859678345451
4156229766374281
3652824510840990
4094209810335722
4217894759777453
3836924639066650
3623635279875356
3834873678442069
3789435626027083
3944768482398385
4226565858151376
3997616767822908
4232678481350000
3889986434893062
4228706065049719
3994233775579414
3470299011884270
4004885068113649
3787426159483132
4060892825611272
3938125489261217
42330140131847

(0.7102564102564103, 0.6615384615384615, 0.7769230769230769)

In [196]:
4079542001558440: Bert對，w2v不對
3718912216646049: 都錯

SyntaxError: invalid syntax (<ipython-input-196-b07dd2347f46>, line 1)

In [241]:
# Error analysis check
# 3965041341078954
# BERT: 對，整個distribution幾乎一樣
# W2V : 對，只有最高分精準


# 4022842192155296
# BERT: 錯，但只錯一點
# W2V : 錯，而且錯很多

# 4149091275267452
# BERT:錯
# W2V :對
for d, d2, l in zip(data, data2, label):
    if l['id'] == '4221038969605399':
#     if l['id'] == '4022842192155296':
        print(l['turns'])
        dA, dE, dS = _toList(d['quality']['A']), _toList(d['quality']['E']), _toList(d['quality']['S'])
        d2A, d2E, d2S = _toList(d2['quality']['A']), _toList(d2['quality']['E']), _toList(d2['quality']['S'])
        lA, lE, lS = label_preprocess(l)
        print(dA)
        print(d2A)
        print(lA)
        print()

[{'utterances': ['| When the cellphone USB with type c interface is inserted into M1L, why is there no display... Does it need the operation of other USB debugging? @ Smartisan @Smartisan Customer Service|| Huainan · Jindi...|'], 'sender': 'customer'}, {'utterances': ["H Hello! M1L cellphone doesn’t support OTG function. The charging interface can't recognize external equipment."], 'sender': 'helpdesk'}, {'utterances': ['Are all Smartisan cellphones not suitable for OTG external memory function ... How about the cellphones released newly?'], 'sender': 'customer'}, {'utterances': ['H Smartisan Pro and Smartisan Pro 2 both support OTG function. If you are urgent to use the data in the USB, you are suggested to export the data to the computer and then copy it to the cellphone.'], 'sender': 'helpdesk'}, {'utterances': ['OK, Thank you!.'], 'sender': 'customer'}]
[0.0443945266, 0.01495077, 0.0858940631, 0.4301918149, 0.4245687425]
[0.0259245094, 0.0221765321, 0.0747838169, 0.4130394757, 0.46

In [242]:
nmd([0.0443945266, 0.01495077, 0.0858940631, 0.4301918149, 0.4245687425], 
    [0.0, 0.10526315789473684, 0.05263157894736842, 0.3684210526315789, 0.47368421052631576])

0.038020796146660839

In [243]:
nmd([0.0259245094, 0.0221765321, 0.0747838169, 0.4130394757, 0.4640755951], 
    [0.0, 0.10526315789473684, 0.05263157894736842, 0.3684210526315789, 0.47368421052631576])

0.031926269134064073