In [1]:
import torch
import sys
import os
import time
import numpy as np
import argparse
sys.path.append("..")

from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler
from umap.umap_ import find_ab_params

from singleVis.custom_weighted_random_sampler import CustomWeightedRandomSampler
from singleVis.SingleVisualizationModel import VisModel
from singleVis.losses import UmapLoss, ReconstructionLoss, SingleVisLoss
from singleVis.edge_dataset import DataHandler
from singleVis.trainer import SingleVisTrainer
from singleVis.data import NormalDataProvider
from singleVis.spatial_edge_constructor import kcSpatialAlignmentEdgeConstructor
# from singleVis.temporal_edge_constructor import GlobalTemporalEdgeConstructor
from singleVis.alignment_edge_constructor import LocalAlignmentEdgeConstructor
from singleVis.projector import TimeVisProjector
from singleVis.eval.evaluator import Evaluator


import torch
import numpy as np

# Training runs compared in this notebook (all pairflip / CIFAR-10):
#   REF_PATH       : reference run (noisy labels, "noisy0.001")
#   CLEAN_PATH     : run trained on clean labels (plotted as 'ref_2')
#   CONFUSION_PATH : benchmark 1
#   EXCHANGE_PATH  : benchmark 2
# All runs live under DATASET_ROOT; set the DATASET_ROOT environment variable to
# point at the dataset tree on another machine (default: the original location).
DATASET_ROOT = os.environ.get("DATASET_ROOT", "/home/yifan/dataset")

REF_PATH = os.path.join(DATASET_ROOT, "noisy", "pairflip", "cifar10", "noisy0.001")
CLEAN_PATH = os.path.join(DATASET_ROOT, "clean", "pairflip", "cifar10", "0")

CONFUSION_PATH = os.path.join(DATASET_ROOT, "confusion", "pairflip", "cifar10", "0")
EXCHANGE_PATH = os.path.join(DATASET_ROOT, "exchange", "pairflip", "cifar10", "0")

# The reference run directory is presumably where the `config` module imported
# next comes from. Guarded so re-running the cell does not append duplicates.
# NOTE(review): append puts REF_PATH last on sys.path, so an unrelated `config`
# module found earlier on the path would shadow this one — confirm acceptable.
if REF_PATH not in sys.path:
    sys.path.append(REF_PATH)


from config import config

# --- Experiment-wide settings, read from the imported `config` dict ---
SETTING, CLASSES, DATASET, GPU_ID = (
    config[key] for key in ("SETTING", "CLASSES", "DATASET", "GPU")
)
EPOCH_START, EPOCH_END, EPOCH_PERIOD = (
    config[key] for key in ("EPOCH_START", "EPOCH_END", "EPOCH_PERIOD")
)

# --- Subject-model training parameters ---
TRAINING_PARAMETER = config["TRAINING"]
NET, LEN = TRAINING_PARAMETER["NET"], TRAINING_PARAMETER["train_num"]

# --- Visualization-model training parameters ---
VISUALIZATION_PARAMETER = config["VISUALIZATION"]
(PREPROCESS, LAMBDA, INIT_NUM, ALPHA, BETA, MAX_HAUSDORFF,
 ENCODER_DIMS, DECODER_DIMS, S_N_EPOCHS, T_N_EPOCHS,
 N_NEIGHBORS, PATIENT, MAX_EPOCH, EVALUATION_NAME) = (
    VISUALIZATION_PARAMETER[key]
    for key in ("PREPROCESS", "LAMBDA", "INIT_NUM", "ALPHA", "BETA", "MAX_HAUSDORFF",
                "ENCODER_DIMS", "DECODER_DIMS", "S_N_EPOCHS", "T_N_EPOCHS",
                "N_NEIGHBORS", "PATIENT", "MAX_EPOCH", "EVALUATION_NAME")
)
# Boundary settings are nested one level deeper, under "BOUNDARY".
B_N_EPOCHS, L_BOUND = (
    VISUALIZATION_PARAMETER["BOUNDARY"][key] for key in ("B_N_EPOCHS", "L_BOUND")
)

VIS_MODEL_NAME = 'vis'

# One segment spanning the whole configured epoch range.
SEGMENTS = [(EPOCH_START, EPOCH_END)]

# Use the configured GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device(f"cuda:{GPU_ID}" if torch.cuda.is_available() else "cpu")

import Model.model as subject_model

# Instantiate the subject network whose class name is config["TRAINING"]["NET"].
# getattr() resolves the same attribute of Model.model as the former eval() call,
# but cannot execute arbitrary expressions coming from the config file.
net = getattr(subject_model, NET)()


def _make_provider(content_path):
    """Build a NormalDataProvider for one training run, sharing `net` and the epoch settings."""
    return NormalDataProvider(
        content_path, net, EPOCH_START, EPOCH_END, EPOCH_PERIOD,
        split=-1, device=DEVICE, classes=CLASSES, verbose=1,
    )


ref_provider = _make_provider(REF_PATH)
clean_provider = _make_provider(CLEAN_PATH)

confusion_provider = _make_provider(CONFUSION_PATH)

exchange_provider = _make_provider(EXCHANGE_PATH)

# Epoch whose training-set representations are loaded for every run.
# NOTE(review): 200 is also hard-coded in the SequenceAlignment call; presumably
# it is the final epoch (EPOCH_END) — confirm before reusing on other configs.
REPR_EPOCH = 200

ref_train_data = ref_provider.train_representation(REPR_EPOCH).squeeze()

confusion_data = confusion_provider.train_representation(REPR_EPOCH).squeeze()

exchange_data = exchange_provider.train_representation(REPR_EPOCH).squeeze()

clean_data = clean_provider.train_representation(REPR_EPOCH).squeeze()

  from .autonotebook import tqdm as notebook_tqdm


Finish initialization...
Finish initialization...
Finish initialization...
Finish initialization...


In [2]:
# Build a SequenceAlignment between the clean run and the reference run.
# NOTE(review): the two positional 200s are presumably the epoch used for each
# provider — confirm against SequenceAlignment's signature.
from representationTrans.sequence_alignment import SequenceAlignment
sa = SequenceAlignment(clean_provider, ref_provider,200,200)

In [3]:
# NOTE(review): the recorded run of this cell failed with
# "NameError: name 'cuda_cka' is not defined". `cuda_cka` is not defined in this
# notebook, so it is presumably referenced inside
# representationTrans.sequence_alignment without being imported — fix it there
# before re-running this cell.
sa.get_alignment_list()

NameError: name 'cuda_cka' is not defined

In [None]:
###  ============================= calculate CCA ============================== ###
import numpy as np
from sklearn.cross_decomposition import CCA

# Consecutive-epoch similarity for the clean run. For each epoch pair (i, i+1),
# i = 1 .. 198, fit a one-component CCA between the two training-set
# representations and record the Pearson correlation of the first pair of
# canonical variates. Entry k of clean_cca_list compares epoch k+1 with k+2.
clean_cca_list = []

# Each epoch's representation is loaded once and reused as the next X.
# .squeeze() matches how representations are loaded in the setup cell; CCA
# expects 2-D (samples, features) inputs.
Y = clean_provider.train_representation(1).squeeze()
for i in range(1, 199):
    X, Y = Y, clean_provider.train_representation(i + 1).squeeze()
    cca = CCA(n_components=1)
    cca.fit(X, Y)
    X_train_r, Y_train_r = cca.transform(X, Y)
    if i % 10 == 0:
        print(i)  # coarse progress indicator
    clean_cca_list.append(np.corrcoef(X_train_r[:, 0], Y_train_r[:, 0])[0, 1])


In [None]:

ref_cca_list = [0.9622451094249629, 0.9835753696934071, 0.9885950360380795, 0.9884658596956643, 0.9907527021024848, 0.9914980617225714, 0.9918927613441185, 0.9924212070471425, 0.9912522945816102, 0.9913458076425505, 0.9907505998051814, 0.9923754440789142, 0.991637639898143, 0.9920791281562766, 0.990701403306507, 0.9899279646265837, 0.9918938558317462, 0.9912421362053275, 0.9929991687054294, 0.9911291153754909, 0.9919374859685238, 0.992178009077236, 0.991496968455006, 0.9910565388564236, 0.9916578567221991, 0.9910457254823529, 0.9892470256898181, 0.9907298273645971, 0.9912576466298185, 0.9916304533230863, 0.9908150199237304, 0.9901749841649476, 0.991380755870939, 0.9909099703048552, 0.9917641356893454, 0.9916125981221563, 0.9913325190694237, 0.9907191708547577, 0.9890562079494036, 0.9909263474569632, 0.9911462777326566, 0.989943884825059, 0.9911790417925539, 0.99235797405449, 0.9909741189128223, 0.9913086937082533, 0.9910876881956175, 0.9906494577882143, 0.9889024095338231, 0.9911112041624691, 0.9900249319789002, 0.9904836631297695, 0.9903575795485646, 0.9898199140097278, 0.9907121827475263, 0.9910025115244324, 0.9919861506050189, 0.9910551665779878, 0.9916800195807682, 0.9910223553220018, 0.9919359255652829, 0.9910943454684946, 0.9915235986837119, 0.9920351119800985, 0.9919620931638461, 0.9908937837284022, 0.9911716410814223, 0.9915540908415501, 0.9920974083449621, 0.992180302449655, 0.9931418510622484, 0.9922916880049685, 0.9923553535977595, 0.9923739846463326, 0.9938629306796499, 0.9923693674651273, 0.9924516539564034, 0.9922977153296574, 0.9919650688685142, 0.993095282636129, 0.9922137264812327, 0.9923043986025436, 0.9919606415483643, 0.9933168237286163, 0.9938471364798467, 0.9941640531593279, 0.9941000139724575, 0.9942449446191827, 0.9939208298901461, 0.9931849440867937, 0.9928654921387595, 0.9938343076755056, 0.9947415876075195, 0.9941610043851405, 0.9945210555624641, 0.993032912946252, 0.9941358655281427, 0.9936290487461124, 0.9945107103654884, 
0.9951999438551333, 0.9956474188916958, 0.9943348721845243, 0.9941831912518191, 0.9951800451441885, 0.9950271350289819, 0.9956146938905595, 0.9961011070623562, 0.995789734099666, 0.9955675541292495, 0.9960121520008333, 0.9953244845802698, 0.9955023174912867, 0.99648174310157, 0.996677528705376, 0.9964995346896914, 0.9970539170939894, 0.9969678657744953, 0.9970443583532566, 0.997001675886903, 0.9973494438733952, 0.9974083738342602, 0.9973101268228184, 0.997373932703503, 0.997613577315616, 0.9977608897728937, 0.9977532416676256, 0.9975120938733212, 0.9981483151714753, 0.9978534613710964, 0.9981032193218553, 0.9982911543045908, 0.9982064831988767, 0.998069523191224, 0.9982984017879011, 0.9985564088133065, 0.9985969327187675, 0.9988338755959894, 0.9984912321123166, 0.9984139775819055, 0.9987142193492315, 0.9989031389076246, 0.9988966838781999, 0.9990361552495063, 0.9990141041137556, 0.9990910612239544, 0.9990442784557647, 0.9989599874636896, 0.9990776947333739, 0.9990330458603452, 0.9991420088071717, 0.9992492581033559, 0.9992357096303092, 0.9992410523396844, 0.9992991281273796, 0.9992632263036039, 0.9991685089780734, 0.9992549937209275, 0.9993566369018072, 0.9993040758318992, 0.9994147333550638, 0.9994508570323385, 0.9993186240599018, 0.9993948263493732, 0.9994450673414637, 0.999516711115316, 0.9994930006866765, 0.9995682045889294, 0.9996393603275299, 0.9995538185347962, 0.9996715685721145, 0.9996132719326178, 0.999704394584593, 0.9997678268311944, 0.9997736256447755, 0.9997350108361696, 0.9997521228037624, 0.9997557276723626, 0.9997600429859627, 0.9998237663061464, 0.999804374280031, 0.9998514695451305, 0.9998153136404436, 0.9998022836468979, 0.999790843464696, 0.9998013659084106, 0.9998246800018652, 0.9998237577430623, 0.9998074747982748, 0.9997494698933964, 0.9998177376540552, 0.999839634409812, 0.9998689725052073, 0.9998471894147288, 0.99982229361615, 0.9998064905079058, 0.9998368819246747, 0.9998379315531063, 0.9998380063732677]
confusion_cca_list = [0.9691079691612061, 0.9822425947952197, 0.9891749811150524, 0.9919902303796175, 0.9913304729659498, 0.9915032314511251, 0.9911266714155959, 0.9900789607090601, 0.9911563756175615, 0.9912015507398092, 0.9912163906280782, 0.9896079431832678, 0.9915716798378974, 0.9911241917700453, 0.9910566052660595, 0.9909186692137999, 0.9901049411705345, 0.990519975813794, 0.9905454542492705, 0.991664291612424, 0.9916398321929895, 0.9902203085963489, 0.9908775338260658, 0.9914335410955846, 0.9902586326216413, 0.9902906772565273, 0.9920870621008104, 0.9902659250606273, 0.9910267790076663, 0.9892729258019238, 0.9919444047223305, 0.9916042439779408, 0.9907261184772774, 0.9923439302731147, 0.993090352861986, 0.9893047156681617, 0.9910741167310907, 0.9913413170665972, 0.9912282224683948, 0.9911295124710229, 0.9915988277565911, 0.9896196679477531, 0.9903756515171117, 0.9912068218020462, 0.9910949391533591, 0.9914142893323672, 0.9915586298130105, 0.991901141324067, 0.9915358930069084, 0.9915331254771972, 0.9909563894320966, 0.991008204316034, 0.9904996623425726, 0.9910403710901854, 0.9913992735251576, 0.9915853967876933, 0.9902144557168321, 0.9906882240681562, 0.9916701421332133, 0.9911325524167488, 0.99110070672332, 0.9912082131858605, 0.9916076479046935, 0.9902612709451922, 0.9916012293857719, 0.9926205232897192, 0.9924459553155616, 0.9918943669584674, 0.9913583709518758, 0.9904128593036354, 0.9927553400226962, 0.9925293077738789, 0.9923774605172413, 0.9931138770538077, 0.9940650948416286, 0.9937185113453069, 0.9926988834945155, 0.9917813812043296, 0.9925319835435433, 0.9930748479486312, 0.9921087947258896, 0.9933455963088015, 0.9942635813291416, 0.9938254192675762, 0.993383134344554, 0.9943651916932125, 0.9939126492933409, 0.9929709307562634, 0.9931457815056719, 0.9938006855994072, 0.9944795523224754, 0.9938984379139449, 0.9946867284772373, 0.9939659720892872, 0.9939644616612384, 0.9947136509939496, 0.9951116902319268, 0.9946465712168454, 0.9949340869805533, 
0.9951884871972384, 0.9952946908430956, 0.9961317516550324, 0.9953308616693018, 0.9957274450116549, 0.9961730380153746, 0.9959514941517723, 0.995947756002263, 0.995764354175928, 0.995375611723924, 0.9959798730168571, 0.996054411683772, 0.9958474189454791, 0.9966503427716941, 0.9962834476692906, 0.9961530521402198, 0.9970342701939289, 0.9970366321403493, 0.997272390662084, 0.9968668520084241, 0.9967596755659891, 0.9972292156430986, 0.9970914024608349, 0.9973317587795717, 0.9978021332233171, 0.9972924413202467, 0.997526231743933, 0.9977195887721713, 0.997307181366895, 0.9976998110527793, 0.9979710728410327, 0.9977882195727032, 0.9980875496022737, 0.9982500074355877, 0.9984595158176603, 0.9981843000421406, 0.9983981120738988, 0.9985273088876732, 0.9984762280292848, 0.9986339990383886, 0.99875492181865, 0.9986611947830945, 0.9988787140564158, 0.9989274017620462, 0.9988797847719695, 0.9990879354208785, 0.9990379493315017, 0.9990215454181116, 0.9992238434692426, 0.9990906061788857, 0.9993130508134689, 0.9992835788737048, 0.9991296311524542, 0.9992599466420964, 0.9991491907713526, 0.9994343157098656, 0.9994872531683596, 0.999369396448298, 0.9994613990284087, 0.9995211713319307, 0.9995836285985453, 0.9996145782606282, 0.9995486111672852, 0.9996039893122889, 0.9995430121408369, 0.9996229632023815, 0.9996345106064188, 0.9996877016094247, 0.9996744002180465, 0.9995990864670491, 0.9996241195503401, 0.9997302920570886, 0.9997531641206577, 0.9997454398907096, 0.9997809724023462, 0.9997383396304627, 0.9997586357279417, 0.999745529711172, 0.9997375612162193, 0.9998126009849321, 0.999803787995494, 0.9998219034019782, 0.9998111502300219, 0.9998291463108531, 0.999795828809873, 0.9997941295141138, 0.9998191182513961, 0.9998090907853469, 0.9998076446038813, 0.9997624099172717, 0.9998412326052819, 0.9998432666543409, 0.9998589879255736, 0.9998524413636124, 0.9998313142170382, 0.9998049925293127, 0.999849888477844, 0.9998380655465523, 0.9998385985815512]
exchange_cca_list = [0.9625416922637114, 0.9772292140858627, 0.9888686550353462, 0.9857288543705459, 0.9897333847102339, 0.9922574393105134, 0.9916061426441319, 0.9919612344401878, 0.9932897196417003, 0.990733623848824, 0.9897844142909603, 0.9908418909316102, 0.9901240933258761, 0.9921455829192927, 0.9932155432141062, 0.9926855986049717, 0.9920253570488698, 0.9909565735447242, 0.9927565942783686, 0.9913983904812115, 0.9925364761147013, 0.9907390000335751, 0.9904167311858342, 0.9915752596482009, 0.990264945129291, 0.9910074250465744, 0.9908377506018619, 0.9904250784625429, 0.9910226837029633, 0.9914889165996984, 0.9912628985612736, 0.990655903613584, 0.9908260156581322, 0.9916205413965046, 0.9910854155104328, 0.9909776048055939, 0.9917358637943079, 0.9911573606033897, 0.989320583175197, 0.9900400641089953, 0.9905048187212195, 0.9888278810533379, 0.9905180074558894, 0.9914415800226986, 0.9907152494613233, 0.9913507670297728, 0.991639450950877, 0.9911213893780079, 0.9900272506155725, 0.9909599951610346, 0.9900330888774284, 0.9907947105640066, 0.9901779698324892, 0.9899963181717432, 0.9903467702966885, 0.9904745318393521, 0.9900316785170274, 0.9914990798990252, 0.9914110609312846, 0.99165059106778, 0.9906750602684019, 0.9908738526409723, 0.989776677129118, 0.9903385493371345, 0.9910786796601072, 0.9917289891590522, 0.9923230569524817, 0.9914130827474287, 0.9920434533549615, 0.9907209753790179, 0.9916361967211017, 0.991762922719466, 0.9924410678222552, 0.9935574972171022, 0.993567462942503, 0.9923182353708428, 0.9928748337474277, 0.9924578778029751, 0.992919532422317, 0.9923917786731269, 0.9935702513203063, 0.9936005354364866, 0.992913230550444, 0.9924635479626384, 0.9924720507744168, 0.9933916485024251, 0.9943133679098782, 0.9941233767024431, 0.9937520684813214, 0.9933933966275583, 0.9943176508375132, 0.9941345985154235, 0.9935444249846745, 0.9941754006345629, 0.9946312048226655, 0.993951903823864, 0.9937424229573918, 0.9942399498569865, 0.9947845890414918, 
0.9952701465275648, 0.9952135883499701, 0.9960578885732313, 0.9957861189231559, 0.9947373914503317, 0.9956055362180464, 0.9958798016997922, 0.9953485948895525, 0.9958069636084204, 0.9961294848008391, 0.9964089920162637, 0.9965139699811376, 0.995878573211714, 0.9965761568549759, 0.9967769570018609, 0.9968031169785241, 0.9972478460723094, 0.9968683791552854, 0.9973225651989476, 0.9969229611055229, 0.9962064542871905, 0.9971577565997749, 0.9968612324962882, 0.9976359574632719, 0.9974411133948995, 0.9977876862532359, 0.9976025661457425, 0.9974657138523548, 0.9973967828026383, 0.9978008865585564, 0.9980145045838275, 0.9980846734314418, 0.998118345964994, 0.9983778240208685, 0.9984168747504433, 0.9980992655517134, 0.9980440808778002, 0.9982826927418399, 0.9985964819381884, 0.9984431229170245, 0.9981673267758756, 0.9981948971763621, 0.99867500230732, 0.9987535083376661, 0.9988488168039115, 0.998777534371259, 0.9988383125389056, 0.9990924135142072, 0.9991409090383826, 0.9989647794282093, 0.9990266600653138, 0.99925257841211, 0.9988558120731599, 0.9992138316637813, 0.9993518178523768, 0.9991643929283764, 0.9992932743858753, 0.9993363093371692, 0.9994161401558553, 0.9995688552492822, 0.9994938000139392, 0.9995470249354611, 0.9996602119560518, 0.999556845499058, 0.9994713280304764, 0.9997075170388626, 0.9996739879008507, 0.9996791571251831, 0.9997915760643519, 0.9997452762245091, 0.9998369843242096, 0.999833240196438, 0.9998195444434842, 0.9997783367457588, 0.9998404655511646, 0.9997996562785472, 0.9998225453636974, 0.9998418957668874, 0.9998010718680105, 0.999841351987324, 0.999868585368663, 0.9998749770450129, 0.9998807463789926, 0.9998787412842799, 0.9998693165836638, 0.9998578877824389, 0.9998617620273133, 0.9998762074676931, 0.9998664674098675, 0.999823400565184, 0.9998840907200951, 0.9999046966817794, 0.999823202886412, 0.9999090066827392, 0.9998853423154211, 0.9998704787689346, 0.9999073108395595, 0.9998966418594631, 0.9998929652618687]
clean_cca_list = [0.9664381333630849, 0.9823183058857132, 0.9898305049472212, 0.9916307161267205, 0.9902366847941002, 0.9909821041506907, 0.9919535893266979, 0.99225772520799, 0.9912563656790486, 0.9913936067098814, 0.9922190424247704, 0.99098657213197, 0.9905438699920097, 0.9926522241579533, 0.9915411665644067, 0.9920985217174182, 0.9901174862998157, 0.9884994204332413, 0.9898336185960632, 0.9896764221559183, 0.9922428128053004, 0.9919667656494353, 0.9911286419771195, 0.9914852049706925, 0.99185156517518, 0.9912726583896807, 0.9912692518320055, 0.9919629210462825, 0.9915466373475277, 0.9912091228369629, 0.9923420947954605, 0.9918821277521025, 0.9909999900808344, 0.9904218834975913, 0.9906341416093438, 0.9896065619159338, 0.9894592850779631, 0.9895685855284996, 0.9906330989455345, 0.99151195830708, 0.99079964640464, 0.9911894303342903, 0.9914081113797879, 0.9909136860097827, 0.9907154095346129, 0.9897934938530927, 0.9896426742573847, 0.990690130565513, 0.9892021560658765, 0.9895942642371922, 0.991439898503177, 0.9902505539294885, 0.9917276092317666, 0.9910320117979324, 0.9903778632931711, 0.9907308988314721, 0.9908287281153826, 0.990957464836934, 0.9902591602321942, 0.9901778969576829, 0.9893212929713242, 0.9899597238231583, 0.9897740796907252, 0.989404708199639, 0.9897341293817226, 0.9911076721435754, 0.9923859635584618, 0.9920135300467058, 0.9908329906916475, 0.9911073751797638, 0.9915878339838115, 0.9917280337284392, 0.9923965747637741, 0.9922828152381195, 0.9920029957331039, 0.9929341433322159, 0.9921611080978424, 0.9930204708646274, 0.9926611911282773, 0.9934293355396077, 0.9911731779304321, 0.9923365461584377, 0.9932484151577886, 0.9936716748815193, 0.993613432016308, 0.993556648605173, 0.9934878327047019, 0.9933242612425928, 0.9943275295485016, 0.9940442040787156, 0.9937744017787109, 0.9942837459020432, 0.9934687194976468, 0.9940454930043204, 0.9940482152868995, 0.9942721702459794, 0.9944867249426934, 0.9945023332125071, 0.9943541895178122, 
0.9946333650368039, 0.9947111379610288, 0.995056788067687, 0.9957342497489945, 0.995339437428384, 0.9959257163799949, 0.9952947492781006, 0.9961833649053766, 0.996247868462556, 0.9959911571369325, 0.9963593804460807, 0.9967279771872858, 0.9968447662503979, 0.9964948430723041, 0.9962391791564471, 0.9964342003163594, 0.9962650366556998, 0.9967413471769525, 0.9970149627937812, 0.9973407555501131, 0.9971993752414847, 0.9971942224685683, 0.9969761182454963, 0.9972789724699017, 0.9974996528034911, 0.9974482791178748, 0.9975834957199025, 0.9980724544587282, 0.9983786080576931, 0.9987019165984972, 0.9982294796022949, 0.9980600783812776, 0.9985797694710253, 0.9981150536792706, 0.9983587353388038, 0.9981711083745048, 0.9983536862114208, 0.9989124824134836, 0.9988643512987458, 0.9984754132219557, 0.99882439082139, 0.9990125667790151, 0.9989257286151529, 0.9988003107831569, 0.9990233575038319, 0.9991526636095149, 0.9990826667389523, 0.9992383541784713, 0.999281424644698, 0.9993259422830337, 0.9993128859892402, 0.9992104762379568, 0.9992923634731601, 0.9992749288483549, 0.9992243323094138, 0.9994637628224613, 0.9995385332390841, 0.999533042682391, 0.9996013861797668, 0.9996104659286753, 0.9994667385435814, 0.999527695504661, 0.9995363942756871, 0.9995493658178972, 0.9994758096770595, 0.9996224049517177, 0.9996867248279919, 0.9996972261739577, 0.9996683018348005, 0.9996771025600594, 0.9997189039797815, 0.9997665430460021, 0.9997670461635924, 0.9997783594004233, 0.9997845474506614, 0.999804442151636, 0.9998030055079398, 0.9998178896303914, 0.9997939310222153, 0.9998266549570507, 0.9998339943899183, 0.9998651307602774, 0.9998541914151953, 0.9998576586987935, 0.9998142102604031, 0.999819897329105, 0.9998313659524504, 0.9998490386553952, 0.9998298904710171, 0.9997862715243909, 0.9998520430449932, 0.9998867782343663, 0.9998867363703333, 0.9998674114044384, 0.9998502096499234, 0.9998247859158637, 0.9998563046703977, 0.9998618870622744, 0.9998646982556855]
import matplotlib.pyplot as plt

# Index k of each *_cca_list is the CCA correlation between epoch k+1 and epoch
# k+2 (presumably produced by the CCA loops over i = 1 .. 198, which yield 198
# values per run). Only the first N_SHOW epoch pairs are plotted.
# NOTE(review): the literal `clean_cca_list` in this cell overwrites the list
# computed by the clean-run CCA cell; rename one of them if both are needed.
N_SHOW = 30

curves = [
    # (label, values, colour). 'ref_2' gets an explicit colour because the
    # default first cycle colour is blue, the same as 'confusion'.
    ('reference', ref_cca_list, 'r'),
    ('ref_2', clean_cca_list, 'm'),
    ('confusion', confusion_cca_list, 'b'),
    ('exchange', exchange_cca_list, 'g'),
]

fig, ax = plt.subplots()
for label, values, colour in curves:
    shown = values[:N_SHOW]
    # x starts at 1 so each tick reads as "epoch K" of the pair (K, K+1).
    ax.plot(range(1, len(shown) + 1), shown, color=colour, lw=0.5, label=label)

ax.set(xlabel='epoch K', ylabel='CCA correlation (1st component)',
       title='Epoch K vs Epoch K+1 CCA')
ax.legend(loc='lower right')
plt.show()

In [None]:
import numpy as np
from sklearn.cross_decomposition import CCA

# Consecutive-epoch similarity for the confusion run. For each epoch pair
# (i, i+1), i = 1 .. 198, fit a one-component CCA between the two training-set
# representations and record the Pearson correlation of the first pair of
# canonical variates. Entry k of ccaConList compares epoch k+1 with k+2.
ccaConList = []

# Each epoch's representation is loaded once and reused as the next X.
# .squeeze() matches how representations are loaded in the setup cell; CCA
# expects 2-D (samples, features) inputs.
Y = confusion_provider.train_representation(1).squeeze()
for i in range(1, 199):
    X, Y = Y, confusion_provider.train_representation(i + 1).squeeze()
    cca = CCA(n_components=1)
    cca.fit(X, Y)
    X_train_r, Y_train_r = cca.transform(X, Y)
    corr = np.corrcoef(X_train_r[:, 0], Y_train_r[:, 0])[0, 1]
    ccaConList.append(corr)
    if i % 10 == 0:
        print(i, corr)  # progress: epoch and its correlation coefficient