In [1]:
from CollateFn.CollateFnBase import CollateFnBase

# Point CollateFnBase's class-level id<->label maps at the ATESP label set
# (both ATESP maps are attributes already defined on CollateFnBase).
CollateFnBase.id2label, CollateFnBase.label2id = (
    CollateFnBase.ATESPid2label,
    CollateFnBase.ATESPlabel2id,
)

from ModelSummary.ModelOutputsRecord import ModelOutputsRecord
import os
import numpy as np
import pandas as pd

In [2]:
class PairMap:
    """Aggregate one test metric over saved model records, grouped by language pair.

    Every entry in ``logDir`` whose name ends with ``'d'`` is loaded with
    ``ModelOutputsRecord.load``. Records are then bucketed by
    ``dataParams.Source`` / ``dataParams.Target`` into keys such as ``'en2es'``,
    and :meth:`analyse` reports mean, std and record count of a metric per pair.
    """

    # Full language name (as stored in dataParams.Source/Target) -> short code
    # used to build pair keys like 'en2es'. Insertion order fixes column order.
    LANGUAGE_CODES = {
        'english': 'en',
        'spanish': 'es',
        'french': 'fr',
    }

    def __init__(self, logDir='./ModelSummary/Record|0|0|0|0|0/') -> None:
        """Load every record found in ``logDir`` and index them by language pair.

        Args:
            logDir: directory containing the saved records. A trailing path
                separator is optional.
        """
        self.logDir = logDir
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sorting makes modelRecords (and the index lists) reproducible.
        self.recordsPath = sorted(os.listdir(logDir))
        self.modelRecords = []
        self.loadDir()
        self.loadSource2TargetIndex()

    def loadDir(self):
        """Load each entry of ``self.recordsPath`` ending in 'd' into ``self.modelRecords``."""
        for recordPath in self.recordsPath:
            # NOTE(review): presumably this selects record files by extension
            # (any name ending in 'd'); confirm against the on-disk naming.
            if recordPath.endswith('d'):
                # os.path.join works whether or not logDir ends with a separator
                # (plain string concatenation silently produced a wrong path).
                modelRecord = ModelOutputsRecord.load(os.path.join(self.logDir, recordPath))
                self.modelRecords.append(modelRecord)

    def loadSource2TargetIndex(self):
        """Build ``self.source2targetMapIndex``: pair key -> indices into ``self.modelRecords``.

        Every ordered pair of known languages is pre-created, so pairs that have
        no records still appear (with an empty list).

        Raises:
            KeyError: if a record's Source or Target is not in LANGUAGE_CODES.
        """
        codes = self.LANGUAGE_CODES
        self.source2targetMapIndex = {
            codes[s] + '2' + codes[t]: [] for s in codes for t in codes
        }
        for i, modelRecord in enumerate(self.modelRecords):
            params = modelRecord.dataParams
            pairKey = codes[params.Source] + '2' + codes[params.Target]
            self.source2targetMapIndex[pairKey].append(i)

    def analyse(self, key):
        """Summarise metric ``key`` from each record's ``analyseModel()['test']``, per pair.

        Args:
            key: metric name looked up in ``analyseModel()['test']`` (e.g. 'MicroF1').

        Returns:
            DataFrame with one column per pair key and three rows:
            0 = mean * 100 and 1 = population std (ddof=0) * 100, both rounded
            to 3 decimals; 2 = number of records. Pairs without records get
            NaN for mean/std and 0 for the count.
        """
        analyseDic = {}
        for source2target, indexs in self.source2targetMapIndex.items():
            if not indexs:
                # Same (NaN, NaN, 0) result numpy would give, without the
                # "Mean of empty slice" / "Degrees of freedom" RuntimeWarnings.
                analyseDic[source2target] = (np.nan, np.nan, 0)
                continue
            valueList = [self.modelRecords[index].analyseModel()['test'][key] for index in indexs]
            analyseDic[source2target] = (
                np.round(np.mean(valueList) * 100, 3),
                np.round(np.std(valueList) * 100, 3),
                len(indexs),
            )
        return pd.DataFrame(analyseDic)

In [5]:

# MicroF1 summary (mean / std x100, record count) for ./1LayerAttentionRecordNER/
pairMap = PairMap(logDir="./1LayerAttentionRecordNER/")
microF1Table = pairMap.analyse(key='MicroF1')
print(microF1Table)
# Result:   en2es 66.087 ± 1.522 | en2fr 56.651 ± 1.894
# Baseline (./RecordXLMRBase/): en2es 66.949 ± 0.380 | en2fr 57.197 ± 1.951

    en2en   en2es   en2fr  es2en  es2es  es2fr  fr2en  fr2es  fr2fr
0  70.712  66.087  56.651    NaN    NaN    NaN    NaN    NaN    NaN
1   1.199   1.522   1.894    NaN    NaN    NaN    NaN    NaN    NaN
2   5.000   5.000   5.000    0.0    0.0    0.0    0.0    0.0    0.0


In [7]:
# MicroF1 summary (mean / std x100, record count) for ./1LayerAttentionRecordTwitterSentiment/
pairMap = PairMap(logDir="./1LayerAttentionRecordTwitterSentiment/")
microF1Table = pairMap.analyse(key='MicroF1')
print(microF1Table)
# Result:   en2es 65.251 ± 2.484 | en2fr 53.297 ± 3.114
# Baseline (./RecordXLMRBase/): en2es 66.949 ± 0.380 | en2fr 57.197 ± 1.951

    en2en   en2es   en2fr  es2en  es2es  es2fr  fr2en  fr2es  fr2fr
0  71.931  65.251  53.297    NaN    NaN    NaN    NaN    NaN    NaN
1   0.819   2.484   3.114    NaN    NaN    NaN    NaN    NaN    NaN
2   5.000   5.000   5.000    0.0    0.0    0.0    0.0    0.0    0.0


In [8]:
# MicroF1 summary (mean / std x100, record count) for ./RecordNER/
pairMap = PairMap(logDir="./RecordNER/")
microF1Table = pairMap.analyse(key='MicroF1')
print(microF1Table)
# Result:   en2es 66.605 ± 1.025 | en2fr 56.878 ± 0.859
# Baseline (./RecordXLMRBase/): en2es 66.949 ± 0.380 | en2fr 57.197 ± 1.951

   en2en   en2es   en2fr  es2en  es2es  es2fr  fr2en  fr2es  fr2fr
0    NaN  66.605  56.878    NaN    NaN    NaN    NaN    NaN    NaN
1    NaN   1.025   0.859    NaN    NaN    NaN    NaN    NaN    NaN
2    0.0   5.000   5.000    0.0    0.0    0.0    0.0    0.0    0.0


In [9]:
# MicroF1 summary (mean / std x100, record count) for ./RecordTwitterSentiment/
pairMap = PairMap(logDir="./RecordTwitterSentiment/")
microF1Table = pairMap.analyse(key='MicroF1')
print(microF1Table)
# Result:   en2es 63.964 ± 2.008 | en2fr 52.236 ± 1.783
# Baseline (./RecordXLMRBase/): en2es 66.949 ± 0.380 | en2fr 57.197 ± 1.951

   en2en   en2es   en2fr  es2en  es2es  es2fr  fr2en  fr2es  fr2fr
0    NaN  63.964  52.236    NaN    NaN    NaN    NaN    NaN    NaN
1    NaN   2.008   1.783    NaN    NaN    NaN    NaN    NaN    NaN
2    0.0   5.000   5.000    0.0    0.0    0.0    0.0    0.0    0.0


In [4]:
# MicroF1 summary (mean / std x100, record count) for ./RecordXLMRBase/.
# This run is the baseline the other cells compare against.
pairMap = PairMap(logDir="./RecordXLMRBase/")
microF1Table = pairMap.analyse(key='MicroF1')
print(microF1Table)
# Baseline result: en2es 66.949 ± 0.380 | en2fr 57.197 ± 1.951

   en2en   en2es   en2fr  es2en  es2es  es2fr  fr2en  fr2es  fr2fr
0    NaN  66.949  57.197    NaN    NaN    NaN    NaN    NaN    NaN
1    NaN   0.380   1.951    NaN    NaN    NaN    NaN    NaN    NaN
2    0.0   5.000   5.000    0.0    0.0    0.0    0.0    0.0    0.0
