In [None]:
import numpy as np
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from tqdm import(tqdm)

class DTMKL:
    def __init__(self, X_train_A, Y_train_A, X_train_T, Y_train_T,
                  X_unlabeled_T=None,
                 kernel_types=['linear', 'rbf'], 
                 C=1.0,theta = 1, eta=0.1,
                 gamma_rbf=1.0, epsilon=1e-3, max_iter=100):
        """
        Domain Transfer Multiple Kernel Learning (DTMKL) implementation.
        
        Parameters:
        - X_train_A: Features from auxiliary domain
        - Y_train_A: Labels from auxiliary domain
        - X_train_T: Features from target domain (labeled)
        - Y_train_T: Labels from target domain
        - X_unlabeled_T: Unlabeled features from target domain (for MMD calculation)
        - kernel_types: Types of base kernels ['linear', 'rbf', 'poly1.5']
        - C: SVM regularization parameter
        - eta: Trade-off parameter between domain adaptation and classification error
        - gamma_rbf: RBF kernel parameter
        - degree_poly: Polynomial kernel degree
        - epsilon: Convergence threshold
        - max_iter: Maximum number of iterations
        """
        # Initialize data
        self.X_train_A = X_train_A
        self.Y_train_A = Y_train_A.reshape(-1, 1)
        self.X_train_T = X_train_T
        self.Y_train_T = Y_train_T.reshape(-1, 1)
        self.theta = theta
        # Handle unlabeled target data
        if X_unlabeled_T is None:
            # If no unlabeled data is provided, use labeled target data for MMD
            self.X_unlabeled_T = X_train_T
        else:
            self.X_unlabeled_T = X_unlabeled_T
            
        # Combine labeled data from both domains
        self.X_labeled = np.vstack([X_train_A, X_train_T])
        self.Y_labeled = np.vstack([self.Y_train_A, self.Y_train_T]).flatten()
        self.nA = len(X_train_A)
        self.nT_labeled = len(X_train_T)
        self.nT_unlabeled = len(self.X_unlabeled_T)
        
        # All data for MMD calculation
        self.X_all_mmd = np.vstack([X_train_A,X_train_T,self.X_unlabeled_T])
        self.nA_mmd = self.nA
        self.nT_mmd = self.nT_labeled + self.nT_unlabeled
        
        # Kernel parameters
        self.kernel_types = kernel_types
        self.M = len(kernel_types)
        self.gamma_rbf = gamma_rbf
        self.degree_poly = degree_poly
        
        # Optimization parameters
        self.C = C
        self.eta = eta
        self.epsilon = epsilon
        self.max_iter = max_iter
        
        # Initialize domain indicator vector for MMD
        self.s = self._create_s_vector()
        
        # Precompute kernel matrices
        self.Km_list_labeled = self._precompute_base_kernels(self.X_labeled, self.X_labeled)
        self.Km_list_mmd = self._precompute_base_kernels(self.X_all_mmd, self.X_all_mmd)
        
        # Initialize kernel weights
        self.d = np.ones(self.M) / self.M
        self.alpha = None
        self.b = 0
        self.svm = None

    def _create_s_vector(self):
        """
        Create domain indicator vector for MMD calculation.
        Returns a vector where auxiliary domain entries are 1/nA and 
        target domain entries are -1/nT
        """
        s_A = np.full(self.nA_mmd, 1.0/self.nA_mmd)
        s_T = np.full(self.nT_mmd, -1.0/self.nT_mmd)
        return np.concatenate([s_A, s_T]).reshape(-1,1)  # (n_samples, 1)

    def _kernel_function(self, X1, X2, kernel_type):
        """Calculate kernel matrix for a given kernel type"""
        if kernel_type == 'linear':
            return X1 @ X2.T
        elif kernel_type == 'rbf':
            gamma = 1.0 / (X1.shape[1] * self.gamma_rbf)
            pairwise_dists = np.sum(X1**2, axis=1)[:, np.newaxis] + \
                             np.sum(X2**2, axis=1) - 2 * X1 @ X2.T
            return np.exp(-gamma * pairwise_dists)
        elif kernel_type.startswith("poly"):
            degree = float(kernel_type[4:])
            return (X1 @ X2.T + 1) ** degree
        else:
            raise ValueError(f"Unsupported kernel type: {kernel_type}")

    def _precompute_base_kernels(self, X1, X2):
        """Precompute all base kernel matrices"""
        return [self._kernel_function(X1, X2, kt) for kt in self.kernel_types]

    def compute_mmd_vector(self):
        """
        Compute the MMD vector p = [tr(K1S), ..., tr(KmS)]^T,
        where S = ss^T and s is the domain indicator vector
        """
        return np.array([np.trace(Km @ (self.s @ self.s.T)) for Km in self.Km_list_mmd]).reshape(-1, 1)

    def _combine_kernels(self, d, kernel_matrices):
        """Combine multiple kernel matrices using weights d"""
        return np.sum([d[m] * kernel_matrices[m] for m in range(self.M)], axis=0)

    def _compute_gradient(self, d):
        """
        Compute the gradient of the objective function:
        ∇h = pp^T d + η ∇J
        where p is the MMD vector and ∇J is the gradient of SVM objective
        """
        # MMD part gradient: pp^T * d
        p = self.compute_mmd_vector()
        grad_part1 = (p @ p.T) @ d.reshape(-1, 1)
        
        # SVM objective gradient: η * ∂J/∂d
        y = self.Y_labeled.reshape(-1, 1)
        alpha_y = self.alpha.reshape(-1, 1) * y
        grad_J = np.array([-0.5 * (alpha_y.T @ Km @ alpha_y).item() 
                          for Km in self.Km_list_labeled])
        grad_part2 = self.theta * grad_J.reshape(-1, 1)
        
        return grad_part1 + grad_part2

    def _compute_hessian(self, d):
        """
        Compute the Hessian matrix of the objective function
        H = pp^T + η ∇²J
        """
        p = self.compute_mmd_vector()
        hessian = p @ p.T
        
        # Add a small value to diagonal to ensure positive definiteness
        return hessian + 1e-2 * np.eye(self.M)

    def _project_simplex(self, v):
        """
        Project vector v onto the probability simplex:
        d ≥ 0, Σd = 1
        """
        u = np.sort(v)[::-1]
        cumsum = np.cumsum(u)
        rho = np.where(u > (cumsum - 1) / (np.arange(1, len(v)+1)))[0][-1]
        theta = (cumsum[rho] - 1) / (rho+1)
        return np.maximum(v - theta, 0)

    def fit(self):
        """Main training loop for DTMKL"""
        prev_obj = np.inf
        for iter in tqdm(range(self.max_iter)):
            print(self.d)
            # Step 1: Fix d, train SVM
            K_combined = self._combine_kernels(self.d, self.Km_list_labeled)
            self.svm = SVC(kernel='precomputed', C=self.C)
            self.svm.fit(K_combined, self.Y_labeled)
            
            # Reconstruct full alpha array
            self.alpha = np.zeros(len(self.Y_labeled))
            self.alpha[self.svm.support_] = self.svm.dual_coef_[0] * self.Y_labeled[self.svm.support_]
            self.b = self.svm.intercept_[0]
            
            # Step 2: Fix alpha, update d using second-order method
            gradient = self._compute_gradient(self.d)
            hessian = self._compute_hessian(self.d)
            
            # Use Newton's method with line search
            gradient_col = gradient.reshape(-1, 1)  # 确保列向量
            update_direction = -(np.linalg.inv(hessian) @ gradient_col).flatten()
            
            # Line search for step size
            step_size = self.eta
            d_new = self._project_simplex(self.d - step_size * update_direction)
            
            # # Calculate objective function value
            # p = self.compute_mmd_vector()
            # obj_mmd = 0.5 * (self.d.T @ p)**2
            
            # K_combined = self._combine_kernels(self.d, self.Km_list_labeled)
            # obj_svm = 0.5 * self.alpha @ K_combined @ self.alpha - np.sum(self.alpha)
            # obj = obj_mmd + self.eta * obj_svm
            
            # # Print progress
            # if verbose and (iter % 10 == 0 or iter == self.max_iter-1):
            #     acc = accuracy_score(self.Y_labeled, self.svm.predict(K_combined))
            #     print(f"Iter {iter}: Obj={obj.item():.4f}, Acc={acc:.4f}, d={np.round(self.d,3)}")
            
            # # Check convergence
            # if np.linalg.norm(d_new - self.d) < self.epsilon and abs(obj - prev_obj) < self.epsilon:
            #     if verbose: 
            #         print(f"Converged at iteration {iter}")
            #     break
                
            self.d = d_new
            print(self.d)
            #prev_obj = obj

    def predict(self, X_test):
        """Predict labels for new test data"""
        # Compute kernel matrices between test and training data
        K_test_list = []
        for kt in self.kernel_types:
            K_test = self._kernel_function(X_test, self.X_labeled, kt)
            K_test_list.append(K_test)
        
        # Combine kernels
        K_test_combined = self._combine_kernels(self.d, K_test_list)
        
        # Use SVM to predict
        return self.svm.predict(K_test_combined)

    def evaluate(self, X_test, Y_test):
        """Evaluate model performance"""
        y_pred = self.predict(X_test)
        return accuracy_score(Y_test, y_pred)



In [8]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score
from sklearn.svm import SVC
import warnings
warnings.filterwarnings('ignore')
import os
from sklearn.model_selection import train_test_split



def load_newsgroup_data(setting='comp_vs_rec'):
    """
    Load and prepare 20 Newsgroups dataset for domain adaptation experiment.
    
    Parameters:
    - setting: Which setting to use ('comp_vs_rec', 'comp_vs_sci', or 'comp_vs_talk')
    
    Returns:
    - Data for auxiliary and target domains
    """
    print(f"Loading {setting} data...")
    
    # Define category mappings according to the paper's Table 1
    settings = {
        'comp_vs_rec': {
            'auxiliary': ['comp.windows.x', 'rec.sport.hockey'],
            'target': ['comp.sys.ibm.pc.hardware', 'rec.motorcycles'],
            'positive_class': 'comp',
            'negative_class': 'rec'
        },
        'comp_vs_sci': {
            'auxiliary': ['comp.windows.x', 'sci.crypt'],
            'target': ['comp.sys.ibm.pc.hardware', 'sci.med'],
            'positive_class': 'comp',
            'negative_class': 'sci'
        },
        'comp_vs_talk': {
            'auxiliary': ['comp.windows.x', 'talk.politics.mideast'],
            'target': ['comp.sys.ibm.pc.hardware', 'talk.politics.guns'],
            'positive_class': 'comp',
            'negative_class': 'talk'
        }
    }
    
    # Get the specific setting
    current_setting = settings[setting]
    
    # Load the auxiliary domain data
    auxiliary_categories = current_setting['auxiliary']
    auxiliary_data = fetch_20newsgroups(subset='all', 
                                       categories=auxiliary_categories,
                                       shuffle=True, 
                                       random_state=42)
    
    # Load the target domain data
    target_categories = current_setting['target']
    target_data = fetch_20newsgroups(subset='all', 
                                    categories=target_categories,
                                    shuffle=True, 
                                    random_state=42)
    
    # Create labels
    # Positive class (comp) = 1, Negative class (rec/sci/talk) = -1
    auxiliary_labels = []
    for target_idx in auxiliary_data.target:
        category = auxiliary_data.target_names[target_idx]
        if category.startswith(current_setting['positive_class']):
            auxiliary_labels.append(1)
        else:
            auxiliary_labels.append(-1)
    auxiliary_labels = np.array(auxiliary_labels)

    target_labels = []
    for target_idx in target_data.target:
        category = target_data.target_names[target_idx]
        if category.startswith(current_setting['positive_class']):
            target_labels.append(1)
        else:
            target_labels.append(-1)
    target_labels = np.array(target_labels)
    

    
    # Create feature vectors - use TF-IDF on the text data
    vectorizer = TfidfVectorizer(min_df=5, max_df=0.9, sublinear_tf=True, use_idf=True)
    
    # Fit the vectorizer on all data
    all_texts = auxiliary_data.data + target_data.data
    vectorizer.fit(all_texts)
    
    # Transform the auxiliary and target data
    auxiliary_features = vectorizer.transform(auxiliary_data.data).toarray()
    target_features = vectorizer.transform(target_data.data).toarray()
    
    print(f"Auxiliary domain: {len(auxiliary_labels)} samples")
    print(f"Target domain: {len(target_labels)} samples")
    
    return auxiliary_features, auxiliary_labels, target_features, target_labels
X_aux, y_aux, X_tar, y_tar = load_newsgroup_data('comp_vs_talk')
print("X_aux",X_aux)
print("y_aux",y_aux)
print("X_tar",X_tar)
print("y_tar",y_tar)

        # 随机选择标记样本
pos_indices = np.where(y_tar == 1)[0]
neg_indices = np.where(y_tar == -1)[0]
        
np.random.seed(123)  # 设置随机种子以确保可重复性
labeled_samples_per_class=5
# 确保选择的样本数不超过可用样本数
n_pos = min(labeled_samples_per_class, len(pos_indices))
n_neg = min(labeled_samples_per_class, len(neg_indices))
        
if n_pos < labeled_samples_per_class or n_neg < labeled_samples_per_class:
    print(f"Warning: Requested {labeled_samples_per_class} samples per class, but only found {n_pos} positive and {n_neg} negative samples")
        
pos_labeled_idx = np.random.choice(pos_indices, n_pos, replace=False)
neg_labeled_idx = np.random.choice(neg_indices, n_neg, replace=False)
        
labeled_idx = np.concatenate([pos_labeled_idx, neg_labeled_idx])
unlabeled_idx = np.array([i for i in range(len(y_tar)) if i not in labeled_idx])
        
# 分割目标域数据为标记和未标记
X_tar_labeled = X_tar[labeled_idx]
y_tar_labeled = y_tar[labeled_idx]
X_tar_unlabeled = X_tar[unlabeled_idx]
y_tar_unlabeled = y_tar[unlabeled_idx]  # 用于评估的真实标签
# 只使用线性核
C=5
kernel_types = ['linear',"poly1.5","poly1.6","poly1.7","poly1.8","poly1.9","poly2.0"]
dtmkl = DTMKL(
      X_train_A=X_aux,
      Y_train_A=y_aux, 
      X_train_T=X_tar_labeled,
      Y_train_T=y_tar_labeled,
      X_unlabeled_T=X_tar_unlabeled,
      kernel_types=kernel_types,  # 只使用线性核
      C=C,
      theta = 2e-3,
      eta=2e-3,
      max_iter=50  # 减少迭代次数
)
                        
                        
# 训练DTMKL模型
dtmkl.fit(verbose=False)
                        
# 预测
y_pred = dtmkl.predict(X_tar_unlabeled)  
print("y_pred",y_pred)
print((len(y_pred)-np.count_nonzero(y_pred-y_tar_unlabeled))/len(y_pred))     
clf = SVC(kernel='linear', C=C)
clf.fit(X_tar_labeled, y_tar_labeled)
y_pred1 = clf.predict(X_tar_unlabeled)
print("ypre-svm_T",y_pred1)
print((len(y_pred1)-np.count_nonzero(y_pred1-y_tar_unlabeled))/len(y_pred1))

clf = SVC(kernel='linear', C=C)
clf.fit(X_aux, y_aux)
y_pred2 = clf.predict(X_tar_unlabeled)
print("ypre-svm_A",y_pred2)
print((len(y_pred2)-np.count_nonzero(y_pred2-y_tar_unlabeled))/len(y_pred2))

X_combined = np.vstack([X_aux, X_tar_labeled])
y_combined = np.concatenate([y_aux, y_tar_labeled])
clf = SVC(kernel='linear', C=C)
clf.fit(X_combined, y_combined)
y_pred3 = clf.predict(X_tar_unlabeled)
print("ypre-svm_AT",y_pred2)
print((len(y_pred3)-np.count_nonzero(y_pred3-y_tar_unlabeled))/len(y_pred3))


 
        


Loading comp_vs_talk data...
Auxiliary domain: 1928 samples
Target domain: 1892 samples
X_aux [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
y_aux [ 1 -1 -1 ...  1  1  1]
X_tar [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
y_tar [-1  1  1 ... -1  1 -1]


  0%|          | 0/50 [00:00<?, ?it/s]

[0.14285714 0.14285714 0.14285714 0.14285714 0.14285714 0.14285714
 0.14285714]


  2%|▏         | 1/50 [00:14<12:08, 14.86s/it]

[0.15002231 0.14495309 0.14374406 0.14245893 0.1410923  0.13963838
 0.13809093]
[0.15002231 0.14495309 0.14374406 0.14245893 0.1410923  0.13963838
 0.13809093]


  4%|▍         | 2/50 [00:29<11:38, 14.54s/it]

[0.15728071 0.1470768  0.14464291 0.14205577 0.13930449 0.13637738
 0.13326194]
[0.15728071 0.1470768  0.14464291 0.14205577 0.13930449 0.13637738
 0.13326194]


  6%|▌         | 3/50 [00:43<11:20, 14.48s/it]

[0.16463612 0.14922942 0.14555421 0.14164748 0.13749277 0.13307243
 0.12836756]
[0.16463612 0.14922942 0.14555421 0.14164748 0.13749277 0.13307243
 0.12836756]


  8%|▊         | 4/50 [00:57<11:04, 14.44s/it]

[0.17209209 0.15141204 0.14647843 0.14123388 0.13565627 0.12972191
 0.12340539]
[0.17209209 0.15141204 0.14647843 0.14123388 0.13565627 0.12972191
 0.12340539]


 10%|█         | 5/50 [01:12<10:47, 14.40s/it]

[0.17965172 0.15362556 0.14741595 0.1408148  0.13379423 0.1263244
 0.11837334]
[0.17965172 0.15362556 0.14741595 0.1408148  0.13379423 0.1263244
 0.11837334]


 12%|█▏        | 6/50 [01:26<10:31, 14.34s/it]

[0.1873185  0.15587105 0.14836726 0.14039006 0.13190578 0.12287833
 0.11326902]
[0.1873185  0.15587105 0.14836726 0.14039006 0.13190578 0.12287833
 0.11326902]


 14%|█▍        | 7/50 [01:42<10:40, 14.90s/it]

[0.19509654 0.15814974 0.14933287 0.13995946 0.12998992 0.11938181
 0.10808966]
[0.19509654 0.15814974 0.14933287 0.13995946 0.12998992 0.11938181
 0.10808966]


 16%|█▌        | 8/50 [02:00<11:03, 15.79s/it]

[0.20299013 0.16046296 0.15031337 0.13952278 0.12804558 0.11583288
 0.1028323 ]
[0.20299013 0.16046296 0.15031337 0.13952278 0.12804558 0.11583288
 0.1028323 ]


 18%|█▊        | 9/50 [02:18<11:14, 16.46s/it]

[0.21100357 0.16281199 0.15130933 0.13907981 0.1260717  0.11222958
 0.09749403]
[0.21100357 0.16281199 0.15130933 0.13907981 0.1260717  0.11222958
 0.09749403]


 20%|██        | 10/50 [02:35<11:03, 16.59s/it]

[0.21914169 0.16519829 0.15232137 0.13863029 0.1240671  0.10856971
 0.09207156]
[0.21914169 0.16519829 0.15232137 0.13863029 0.1240671  0.10856971
 0.09207156]


 22%|██▏       | 11/50 [02:51<10:43, 16.50s/it]

[0.22740933 0.16762332 0.15335014 0.13817398 0.12203058 0.10485105
 0.08656159]
[0.22740933 0.16762332 0.15335014 0.13817398 0.12203058 0.10485105
 0.08656159]


 24%|██▍       | 12/50 [03:07<10:25, 16.46s/it]

[0.23581185 0.1700887  0.15439633 0.13771061 0.11996082 0.10107119
 0.0809605 ]
[0.23581185 0.1700887  0.15439633 0.13771061 0.11996082 0.10107119
 0.0809605 ]


 26%|██▌       | 13/50 [03:24<10:07, 16.41s/it]

[0.24435495 0.17259617 0.15546071 0.13723989 0.11785642 0.09722749
 0.07526437]
[0.24435495 0.17259617 0.15546071 0.13723989 0.11785642 0.09722749
 0.07526437]


 28%|██▊       | 14/50 [03:40<09:46, 16.29s/it]

[0.25304462 0.17514754 0.15654408 0.13676151 0.11571589 0.09331723
 0.06946913]
[0.25304462 0.17514754 0.15654408 0.13676151 0.11571589 0.09331723
 0.06946913]


 30%|███       | 15/50 [03:56<09:31, 16.34s/it]

[0.26188757 0.17774482 0.1576473  0.13627514 0.11353759 0.08933736
 0.06357022]
[0.26188757 0.17774482 0.1576473  0.13627514 0.11353759 0.08933736
 0.06357022]


 32%|███▏      | 16/50 [04:12<09:15, 16.35s/it]

[0.27089036 0.18039005 0.15877127 0.13578046 0.11131989 0.08528485
 0.05756311]
[0.27089036 0.18039005 0.15877127 0.13578046 0.11131989 0.08528485
 0.05756311]


 34%|███▍      | 17/50 [04:29<09:00, 16.37s/it]

[0.28006066 0.18308551 0.15991699 0.13527706 0.10906091 0.08115623
 0.05144264]
[0.28006066 0.18308551 0.15991699 0.13527706 0.10906091 0.08115623
 0.05144264]


 36%|███▌      | 18/50 [04:45<08:44, 16.40s/it]

[0.28940689 0.1858338  0.16108559 0.13476454 0.10675857 0.07694761
 0.04520299]
[0.28940689 0.1858338  0.16108559 0.13476454 0.10675857 0.07694761
 0.04520299]


 38%|███▊      | 19/50 [05:02<08:29, 16.44s/it]

[0.29893768 0.18863751 0.16227823 0.13424246 0.10441075 0.07265508
 0.0388383 ]
[0.29893768 0.18863751 0.16227823 0.13424246 0.10441075 0.07265508
 0.0388383 ]


 40%|████      | 20/50 [05:18<08:13, 16.45s/it]

[0.30866248 0.19149953 0.16349615 0.13371034 0.1020151  0.06827431
 0.03234209]
[0.30866248 0.19149953 0.16349615 0.13371034 0.1020151  0.06827431
 0.03234209]


 42%|████▏     | 21/50 [05:34<07:52, 16.28s/it]

[0.31859144 0.19442294 0.1647407  0.13316769 0.09956914 0.06380065
 0.02570746]
[0.31859144 0.19442294 0.1647407  0.13316769 0.09956914 0.06380065
 0.02570746]


 44%|████▍     | 22/50 [05:50<07:33, 16.21s/it]

[0.32873495 0.19741087 0.16601325 0.13261396 0.0970703  0.05922937
 0.01892731]
[0.32873495 0.19741087 0.16601325 0.13261396 0.0970703  0.05922937
 0.01892731]


 46%|████▌     | 23/50 [06:07<07:19, 16.27s/it]

[0.33910572 0.20046723 0.16731552 0.13204853 0.09451544 0.05455463
 0.01199293]
[0.33910572 0.20046723 0.16731552 0.13204853 0.09451544 0.05455463
 0.01199293]


 48%|████▊     | 24/50 [06:23<07:05, 16.36s/it]

[0.34971752 0.20359621 0.16864936 0.13147073 0.09190117 0.04977013
 0.00489488]
[0.34971752 0.20359621 0.16864936 0.13147073 0.09190117 0.04977013
 0.00489488]


 50%|█████     | 25/50 [06:40<06:50, 16.43s/it]

[0.3601883  0.206406   0.16962056 0.13048378 0.08882796 0.0444734
 0.        ]
[0.3601883  0.206406   0.16962056 0.13048378 0.08882796 0.0444734
 0.        ]


 52%|█████▏    | 26/50 [06:56<06:34, 16.42s/it]

[0.37006078 0.20844618 0.16978036 0.12864069 0.0848507  0.0382213
 0.        ]
[0.37006078 0.20844618 0.16978036 0.12864069 0.0848507  0.0382213
 0.        ]


 54%|█████▍    | 27/50 [07:12<06:15, 16.32s/it]

[0.38011511 0.21052494 0.1699438  0.12676385 0.08079968 0.03185262
 0.        ]
[0.38011511 0.21052494 0.1699438  0.12676385 0.08079968 0.03185262
 0.        ]


 56%|█████▌    | 28/50 [07:29<05:59, 16.32s/it]

[0.39035965 0.21264407 0.17011107 0.12485173 0.07667152 0.02536195
 0.        ]
[0.39035965 0.21264407 0.17011107 0.12485173 0.07667152 0.02536195
 0.        ]


 58%|█████▊    | 29/50 [07:45<05:42, 16.32s/it]

[0.40080379 0.2148056  0.17028239 0.1229026  0.07246238 0.01874324
 0.        ]
[0.40080379 0.2148056  0.17028239 0.1229026  0.07246238 0.01874324
 0.        ]


 60%|██████    | 30/50 [08:01<05:28, 16.40s/it]

[0.41145694 0.21701156 0.17045796 0.12091471 0.06816843 0.0119904
 0.        ]
[0.41145694 0.21701156 0.17045796 0.12091471 0.06816843 0.0119904
 0.        ]


 62%|██████▏   | 31/50 [08:18<05:12, 16.45s/it]

[0.42233067 0.21926445 0.17063805 0.11888593 0.06378495 0.00509595
 0.        ]
[0.42233067 0.21926445 0.17063805 0.11888593 0.06378495 0.00509595
 0.        ]


 64%|██████▍   | 32/50 [08:35<04:59, 16.66s/it]

[0.4330473  0.22117725 0.17043336 0.11642452 0.05891757 0.
 0.        ]
[0.4330473  0.22117725 0.17043336 0.11642452 0.05891757 0.
 0.        ]


 66%|██████▌   | 33/50 [08:51<04:40, 16.47s/it]

[0.44294114 0.2220899  0.16918519 0.11287319 0.05291058 0.
 0.        ]
[0.44294114 0.2220899  0.16918519 0.11287319 0.05291058 0.
 0.        ]


 68%|██████▊   | 34/50 [09:07<04:20, 16.31s/it]

[0.45299352 0.22301789 0.16791738 0.1092648  0.04680641 0.
 0.        ]
[0.45299352 0.22301789 0.16791738 0.1092648  0.04680641 0.
 0.        ]


 70%|███████   | 35/50 [09:23<04:02, 16.20s/it]

[0.46321116 0.22396189 0.16662912 0.1055969  0.04060093 0.
 0.        ]
[0.46321116 0.22396189 0.16662912 0.1055969  0.04060093 0.
 0.        ]


 72%|███████▏  | 36/50 [09:39<03:45, 16.14s/it]

[0.47360167 0.22492264 0.16531948 0.10186677 0.03428945 0.
 0.        ]
[0.47360167 0.22492264 0.16531948 0.10186677 0.03428945 0.
 0.        ]


 74%|███████▍  | 37/50 [09:55<03:30, 16.20s/it]

[0.48417207 0.22590086 0.16398759 0.09807186 0.02786762 0.
 0.        ]
[0.48417207 0.22590086 0.16398759 0.09807186 0.02786762 0.
 0.        ]


 76%|███████▌  | 38/50 [10:13<03:19, 16.63s/it]

[0.49493086 0.2268974  0.16263242 0.0942091  0.02133021 0.
 0.        ]
[0.49493086 0.2268974  0.16263242 0.0942091  0.02133021 0.
 0.        ]


 78%|███████▊  | 39/50 [10:30<03:02, 16.62s/it]

[0.50588617 0.22791306 0.16125298 0.09027558 0.01467221 0.
 0.        ]
[0.50588617 0.22791306 0.16125298 0.09027558 0.01467221 0.
 0.        ]


 80%|████████  | 40/50 [10:46<02:46, 16.61s/it]

[0.51704754 0.2289488  0.15984809 0.08626784 0.00788773 0.
 0.        ]
[0.51704754 0.2289488  0.15984809 0.08626784 0.00788773 0.
 0.        ]


 82%|████████▏ | 41/50 [11:02<02:27, 16.42s/it]

[0.52842542 0.23000568 0.1584165  0.08218211 0.00097029 0.
 0.        ]
[0.52842542 0.23000568 0.1584165  0.08218211 0.00097029 0.
 0.        ]


 84%|████████▍ | 42/50 [11:18<02:10, 16.26s/it]

[0.53850792 0.22956313 0.15543552 0.07649343 0.         0.
 0.        ]
[0.53850792 0.22956313 0.15543552 0.07649343 0.         0.
 0.        ]


 86%|████████▌ | 43/50 [11:34<01:52, 16.13s/it]

[0.5484972  0.22886812 0.1521631  0.07047158 0.         0.
 0.        ]
[0.5484972  0.22886812 0.1521631  0.07047158 0.         0.
 0.        ]


 88%|████████▊ | 44/50 [11:50<01:37, 16.18s/it]

[0.55863248 0.22816345 0.14884293 0.06436114 0.         0.
 0.        ]
[0.55863248 0.22816345 0.14884293 0.06436114 0.         0.
 0.        ]


 90%|█████████ | 45/50 [12:07<01:21, 16.24s/it]

[0.56892    0.22744875 0.14547296 0.05815829 0.         0.
 0.        ]
[0.56892    0.22744875 0.14547296 0.05815829 0.         0.
 0.        ]


 92%|█████████▏| 46/50 [12:23<01:05, 16.31s/it]

[0.57936492 0.22672365 0.14205151 0.05185992 0.         0.
 0.        ]
[0.57936492 0.22672365 0.14205151 0.05185992 0.         0.
 0.        ]


 94%|█████████▍| 47/50 [12:39<00:48, 16.23s/it]

[0.58997379 0.22598778 0.13857644 0.04546199 0.         0.
 0.        ]
[0.58997379 0.22598778 0.13857644 0.04546199 0.         0.
 0.        ]


 96%|█████████▌| 48/50 [12:55<00:32, 16.19s/it]

[0.60075387 0.22524063 0.13504538 0.03896012 0.         0.
 0.        ]
[0.60075387 0.22524063 0.13504538 0.03896012 0.         0.
 0.        ]


 98%|█████████▊| 49/50 [13:11<00:16, 16.18s/it]

[0.61171134 0.22448184 0.1314563  0.03235053 0.         0.
 0.        ]
[0.61171134 0.22448184 0.1314563  0.03235053 0.         0.
 0.        ]


100%|██████████| 50/50 [13:28<00:00, 16.16s/it]

[0.6228549  0.22371085 0.12780636 0.02562789 0.         0.
 0.        ]





y_pred [-1  1  1 ... -1  1 -1]
0.9011689691817216
ypre-svm_T [-1  1  1 ... -1  1 -1]
0.8804463336875664
ypre-svm_A [-1  1  1 ... -1  1 -1]
0.8767268862911796
ypre-svm_AT [-1  1  1 ... -1  1 -1]
0.9017003188097769
