# **TriTraining Model Theory**


## Theory
Tri-Training is a semi-supervised learning algorithm that trains three classifiers and refines them iteratively through mutual agreement. Each classifier is first trained on a different bootstrap sample of the labeled data and is then refined using unlabeled examples on which the other two classifiers agree.

The main idea is to:
- Train three classifiers on different bootstrap samples.
- Use the classifiers to predict labels for the unlabeled data.
- Update each classifier with the confident predictions agreed upon by the other two classifiers.
- Repeat the process iteratively to improve the classifiers.

## Tri-Training Process
1. **Initialization**:
   - Create three bootstrap samples from the labeled dataset.
   - Train a classifier on each bootstrap sample.

2. **Iteration**:
   - Use each classifier to predict labels for the unlabeled data.
   - For each classifier, identify the unlabeled points on which the other two classifiers agree.
   - Add those points, with the agreed-upon label, to that classifier's training data and retrain it.
   - Repeat until convergence or a stopping criterion is met.
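
The process above can be sketched in Python. This is a simplified sketch, assuming scikit-learn and a single base estimator cloned three times: it omits the error-rate conditions that the original algorithm uses to decide whether a classifier is updated in a round, retrains each classifier on the full labeled set plus its pseudo-labeled points, and stops once the pseudo-labeled sets no longer change. `tri_train` and `tri_predict` are hypothetical helper names, not part of any library.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import resample


def tri_train(base_estimator, X_lab, y_lab, X_unlab, max_iter=10, random_state=0):
    """Simplified Tri-Training (no error-rate checks); returns three fitted classifiers."""
    rng = np.random.RandomState(random_state)

    # Initialization: fit one clone of the base estimator per bootstrap sample
    clfs = []
    for _ in range(3):
        X_boot, y_boot = resample(X_lab, y_lab, random_state=rng)
        clfs.append(clone(base_estimator).fit(X_boot, y_boot))

    prev_sets = None
    for _ in range(max_iter):
        preds = [clf.predict(X_unlab) for clf in clfs]

        # For each C_i, keep the unlabeled points on which C_j and C_k agree
        sets = []
        for i in range(3):
            j, k = [m for m in range(3) if m != i]
            agree = preds[j] == preds[k]
            sets.append((agree, preds[j][agree]))

        # Stop once no classifier would receive a different pseudo-labeled set
        if prev_sets is not None and all(
            np.array_equal(m, pm) and np.array_equal(lab, plab)
            for (m, lab), (pm, plab) in zip(sets, prev_sets)
        ):
            break

        # Retrain each C_i on the labeled data plus its pseudo-labeled points
        for i, (agree, labels) in enumerate(sets):
            X_i = np.vstack([X_lab, X_unlab[agree]])
            y_i = np.concatenate([y_lab, labels])
            clfs[i] = clone(base_estimator).fit(X_i, y_i)
        prev_sets = sets

    return clfs


def tri_predict(clfs, X):
    """Majority vote of the three classifiers (assumes non-negative integer labels)."""
    votes = np.stack([clf.predict(X) for clf in clfs])
    return np.apply_along_axis(lambda v: np.bincount(v).argmax(), 0, votes)


# Usage: iris with roughly 20% of the labels kept
X, y = load_iris(return_X_y=True)
labeled = np.random.RandomState(0).rand(len(y)) < 0.2
clfs = tri_train(DecisionTreeClassifier(random_state=0), X[labeled], y[labeled], X[~labeled])
acc = np.mean(tri_predict(clfs, X[~labeled]) == y[~labeled])
print(f"Accuracy on the unlabeled part: {acc:.2f}")
```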

## Key Steps
1. **Bootstrap Sampling**:
   - Generate three different bootstrap samples from the labeled dataset.
   - Train a separate classifier on each bootstrap sample.

2. **Label Prediction**:
   - Each classifier predicts labels for the unlabeled data independently.

3. **Agreement-Based Selection**:
   - Identify the predictions where two classifiers agree.
   - Use these agreed-upon predictions to update the third classifier.

4. **Classifier Update**:
   - Add the newly pseudo-labeled data to the training set of the third classifier.
   - Retrain the classifier with the updated training set.

5. **Convergence**:
   - Repeat the process until the classifiers' performance stabilizes or the maximum number of iterations is reached.
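
Step 1 relies on bootstrap sampling: each sample is drawn with replacement and has the same size as the labeled set, so on average only about \( 1 - 1/e \approx 63.2\% \) of the distinct labeled points appear in it, which is what makes the three initial classifiers differ. A NumPy sketch follows (`bootstrap_indices` is a hypothetical helper); `X_lab[idx[0]], y_lab[idx[0]]` would then select the first bootstrap sample.

```python
import numpy as np


def bootstrap_indices(n_samples, n_bootstraps=3, seed=0):
    """Return n_bootstraps index arrays, each of length n_samples, drawn with replacement."""
    rng = np.random.default_rng(seed)
    return [rng.integers(0, n_samples, size=n_samples) for _ in range(n_bootstraps)]


n = 10_000
idx = bootstrap_indices(n)
# Fraction of distinct original samples in each bootstrap sample
unique_fracs = [len(np.unique(i)) / n for i in idx]
print([round(f, 3) for f in unique_fracs])  # each value is close to 1 - 1/e ≈ 0.632
```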

## Mathematical Formulation
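1. **Bootstrap Sampling**:
   - Create three bootstrap samples \( D_1, D_2, D_3 \) from the labeled dataset \( D \), each drawn with replacement and of the same size as \( D \).

2. **Agreement-Based Labeling**:
   - For each classifier \( C_i \), update its training set \( D_i \) with the points of the unlabeled set \( U \) on which the other two classifiers \( C_j \) and \( C_k \) agree:
     $$ D_i \leftarrow D_i \cup \{(x, y) \mid x \in U,\ C_j(x) = C_k(x) = y,\ C_i(x) \neq y \} $$
   - The condition \( C_i(x) \neq y \) keeps only points where \( C_i \) currently disagrees with its peers, as in the "tri-training with disagreement" variant; the original Tri-Training algorithm adds every point on which \( C_j \) and \( C_k \) agree, subject to error-rate conditions that decide whether \( C_i \) is updated in a given round.

3. **Retraining**:
   - Retrain classifier \( C_i \) with the updated training set \( D_i \).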
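
The update set can be computed directly from prediction arrays. A minimal NumPy sketch, assuming the three classifiers' predictions on the unlabeled pool \( U \) are already available; `agreement_set` and its `require_disagreement` flag are hypothetical names, with the flag toggling the \( C_i(x) \neq y \) condition.

```python
import numpy as np


def agreement_set(pred_i, pred_j, pred_k, require_disagreement=True):
    """Return (mask, labels) for the unlabeled points added to D_i.

    A point qualifies when C_j and C_k agree; with require_disagreement=True
    it must also be one that C_i currently labels differently (C_i(x) != y).
    """
    mask = pred_j == pred_k
    if require_disagreement:
        mask &= pred_i != pred_j
    return mask, pred_j[mask]


pred_i = np.array([0, 1, 1, 2, 0])
pred_j = np.array([0, 0, 1, 2, 1])
pred_k = np.array([0, 0, 2, 2, 1])
mask, labels = agreement_set(pred_i, pred_j, pred_k)
print(mask, labels)  # mask selects points 1 and 4; labels are [0, 1]
```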

## Advantages
- Leverages both labeled and unlabeled data.
- Reduces the risk of incorrect label propagation through mutual agreement.
- Can improve classification performance with limited labeled data.

## Applications
- Text classification.
- Image recognition.
- Any domain with a large amount of unlabeled data.



## Model Evaluation for Tri-Training Classifier

### 1. Accuracy Score
Formula:
$$
\text{Accuracy} = \frac{\text{TP} + \text{TN}}{\text{TP} + \text{TN} + \text{FP} + \text{FN}}
$$
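Description:
- Accuracy measures the ratio of correct predictions to total predictions. The TP/TN/FP/FN form applies to binary problems; for multiclass problems it is the number of correctly classified samples divided by the total number of samples.
- Commonly used as a primary metric for balanced datasets.

Interpretation:
- Higher accuracy indicates better overall performance.
- Limitations:
  - May be misleading for imbalanced datasets, where always predicting the majority class already scores high.
  - Should be used alongside other metrics for comprehensive evaluation.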
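
The formula above counts outcomes for a binary problem. For a multiclass problem such as the ten-digit example later in this document, accuracy equals the sum of the confusion-matrix diagonal divided by the total number of samples. A small check with made-up labels, assuming scikit-learn is installed:

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

y_true = np.array([0, 1, 2, 2, 1, 0, 2])
y_pred = np.array([0, 2, 2, 2, 1, 0, 1])

cm = confusion_matrix(y_true, y_pred)
acc_from_cm = np.trace(cm) / cm.sum()  # correct predictions sit on the diagonal
print(acc_from_cm, accuracy_score(y_true, y_pred))  # both equal 5/7
```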
---

### 2. Gini Impurity
Formula:
$$
\text{Gini} = 1 - \sum_{i=1}^{c} (p_i)^2
$$
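Description:
- Gini impurity is the probability that a randomly chosen element of a node would be misclassified if it were labeled at random according to the node's class distribution; \( p_i \) is the proportion of class \( i \) in the node and \( c \) is the number of classes.
- Used as a splitting criterion during tree construction, so it applies only when the base classifiers are decision trees or tree ensembles.

Interpretation:
- Ranges from 0 (pure node) to \( 1 - 1/c \), which is 0.5 for binary classification.
- Lower values indicate better class separation.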
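
A direct NumPy translation of the formula, showing the pure-node minimum, the 0.5 binary maximum, and the \( 1 - 1/c \) maximum for three balanced classes. `gini` is a hypothetical helper written for illustration; scikit-learn computes this quantity internally when a tree uses `criterion="gini"`.

```python
import numpy as np


def gini(labels):
    """Gini impurity of a node: 1 - sum_i p_i**2."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)


print(gini([0, 0, 0, 0]))  # pure node -> 0.0
print(gini([0, 0, 1, 1]))  # binary, evenly split -> 0.5
print(gini([0, 1, 2]))     # three classes, evenly split -> about 0.667
```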
---

### 3. Information Gain
Formula:
$$
\text{IG}(T,a) = H(T) - \sum_{v \in \text{values}(a)} \frac{|T_v|}{|T|} H(T_v)
$$
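Description:
- Information gain measures the reduction in entropy after splitting the set \( T \) on attribute \( a \), where \( H(T) = -\sum_{i=1}^{c} p_i \log_2 p_i \) and \( T_v \) is the subset of \( T \) for which \( a = v \).
- An alternative splitting criterion to Gini impurity, likewise relevant only for tree-based base classifiers.

Interpretation:
- Higher values indicate more informative splits.
- Used to select the best feature for splitting each node.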
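
A minimal sketch of entropy and information gain, assuming base-2 logarithms and a categorical attribute; `entropy` and `information_gain` are illustrative helpers, not library functions. A split that separates the classes perfectly gains the full 1 bit, while one that leaves each branch evenly mixed gains nothing.

```python
import numpy as np


def entropy(labels):
    """Shannon entropy H(T) in bits."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))


def information_gain(labels, attribute):
    """IG(T, a) = H(T) - sum_v |T_v|/|T| * H(T_v)."""
    labels, attribute = np.asarray(labels), np.asarray(attribute)
    weighted = 0.0
    for v in np.unique(attribute):
        subset = labels[attribute == v]
        weighted += len(subset) / len(labels) * entropy(subset)
    return entropy(labels) - weighted


y = [0, 0, 1, 1]
a_perfect = ["x", "x", "y", "y"]  # splits the classes perfectly
a_useless = ["x", "y", "x", "y"]  # each branch keeps a 50/50 mix
print(information_gain(y, a_perfect), information_gain(y, a_useless))  # 1.0 0.0
```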
---

### 4. Model Complexity Metrics
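Description:
- Number of iterations: the number of tri-training rounds run before the stopping criterion was met.
- Number of label changes: the total number of times pseudo-labels on unlabeled points changed during training.

Interpretation:
- Fewer iterations mean lower training cost, and few late label changes suggest the pseudo-labels have stabilized; neither guarantees better generalization on its own.
- Used for tuning hyperparameters such as the iteration limit and for improving model efficiency.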
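
Neither metric has one standard definition. One possible way to compute them, assuming the training loop records every unlabeled point's pseudo-label after each iteration (the `history` list and `count_label_changes` are hypothetical) and counting a transition from unlabeled (`-1`) to a label as a change:

```python
import numpy as np


def count_label_changes(history):
    """Total number of pseudo-label changes between consecutive iterations.

    history: list of 1-D arrays, one per snapshot, holding the pseudo-label of
    each unlabeled point (-1 = not yet labeled). history[0] is the initial state.
    """
    changes = 0
    for prev, curr in zip(history, history[1:]):
        changes += int(np.sum(prev != curr))
    return changes


history = [
    np.array([-1, -1, -1, -1]),
    np.array([0, -1, 1, -1]),
    np.array([0, 2, 1, -1]),
    np.array([0, 2, 1, -1]),  # no change: converged
]
n_iter = len(history) - 1
print(n_iter, count_label_changes(history))  # 3 3
```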
---

### 5. Precision
Formula:
$$
\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}
$$
Description:
- Precision is the fraction of positive predictions that are actually positive.
- Important when false positives are costly.

Interpretation:
- Higher precision means fewer false positive predictions.
- Use case: Particularly important in medical diagnosis and spam detection.
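
Computing precision by hand from TP and FP and comparing it with scikit-learn's `precision_score`, on made-up binary labels. For multiclass problems such as digit recognition, `precision_score` needs an `average` argument (for example `average="macro"`).

```python
import numpy as np
from sklearn.metrics import precision_score

y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_pred = np.array([1, 1, 1, 0, 0, 1, 1, 0])

tp = np.sum((y_pred == 1) & (y_true == 1))  # predicted positive, actually positive
fp = np.sum((y_pred == 1) & (y_true == 0))  # predicted positive, actually negative
print(tp / (tp + fp), precision_score(y_true, y_pred))  # 0.6 0.6
```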
---

### 6. Recall (Sensitivity)
Formula:
$$
\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}
$$
Description:
- Recall is the fraction of actual positive cases that the model identifies.
- Critical in scenarios where missing positive cases is costly.

Interpretation:
- Higher recall means fewer false negatives.
- Use case: Essential in medical screening and fraud detection.
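
For multiclass problems, recall is computed per class: in scikit-learn's confusion matrix, rows are true classes, so dividing the diagonal by the row sums gives \( \text{TP} / (\text{TP} + \text{FN}) \) for each class. A sketch with made-up labels, cross-checked against `recall_score(..., average=None)`:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, recall_score

y_true = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 0, 2, 2, 1, 2])

cm = confusion_matrix(y_true, y_pred)
per_class_recall = np.diag(cm) / cm.sum(axis=1)  # TP / (TP + FN) for each class
print(per_class_recall)
print(recall_score(y_true, y_pred, average=None))
```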
---

### 7. Feature Importance
Formula:
$$
\text{Importance}(x_i) = \sum_{t \in \text{splits on }x_i} n_t \cdot \Delta\text{impurity}
$$
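Description:
- Measures the contribution of each feature to the model's decisions; \( n_t \) is the number of training samples reaching node \( t \) and \( \Delta\text{impurity} \) is the impurity decrease produced by the split at \( t \).
- Based on the total reduction in impurity from splits on each feature. It is defined only for tree-based base classifiers, and scikit-learn normalizes the values so they sum to 1.

Interpretation:
- Higher values indicate more influential features.
- Useful for feature selection and model understanding.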
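
For tree-based base estimators, scikit-learn exposes this quantity as `feature_importances_`, normalized so that each tree's values sum to 1. Averaging the three classifiers' importances, as below, is one reasonable way to summarize a tri-trained ensemble, but it is an assumption rather than part of the algorithm; the bootstrap setup mimics only the initialization step.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import resample

X, y = load_iris(return_X_y=True)

# Three trees on bootstrap samples, as in tri-training's initialization step
trees = []
for seed in range(3):
    X_boot, y_boot = resample(X, y, random_state=seed)
    trees.append(DecisionTreeClassifier(random_state=seed).fit(X_boot, y_boot))

# Each tree's impurity-based importances sum to 1; average them across the trees
importances = np.mean([tree.feature_importances_ for tree in trees], axis=0)
print(np.round(importances, 3))
```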
---

### 8. Cross-Validation Scores
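Description:
- K-fold cross-validation provides more robust performance estimates than a single train/test split.
- Reports the metric for each fold along with its mean and standard deviation.
- In a semi-supervised setting, folds should be scored only on held-out labeled samples, since unlabeled points have no ground truth to score against.

Interpretation:
- Low variance across folds indicates stable model performance.
- High variance may suggest overfitting or data inconsistencies.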
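
A sketch of cross-validation in a semi-supervised setting, assuming scikit-learn's `SelfTrainingClassifier` stands in for the tri-training model: only the labeled samples are split into folds, every fold trains on its labeled training part plus all unlabeled samples (marked `-1`), and scores come from held-out labeled samples alone.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.semi_supervised import SelfTrainingClassifier

X, y = load_digits(return_X_y=True)
rng = np.random.RandomState(42)
labeled = rng.rand(len(y)) >= 0.7  # keep roughly 30% of the labels
X_lab, y_lab, X_unlab = X[labeled], y[labeled], X[~labeled]

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = []
for train_idx, test_idx in cv.split(X_lab, y_lab):
    # Train on the fold's labeled part plus all unlabeled points (-1 = unlabeled)
    X_train = np.vstack([X_lab[train_idx], X_unlab])
    y_train = np.concatenate([y_lab[train_idx], np.full(len(X_unlab), -1)])
    model = SelfTrainingClassifier(KNeighborsClassifier(n_neighbors=5))
    model.fit(X_train, y_train)
    # Score only on held-out labeled samples
    scores.append(model.score(X_lab[test_idx], y_lab[test_idx]))

scores = np.array(scores)
print(f"fold scores: {np.round(scores, 3)}, mean={scores.mean():.3f}, std={scores.std():.3f}")
```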
---

### 9. Confusion Matrix
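Description:
- Provides a detailed breakdown of prediction outcomes; for binary classification these are:
  - True Positives (TP)
  - True Negatives (TN)
  - False Positives (FP)
  - False Negatives (FN)
- For \( c \) classes it generalizes to a \( c \times c \) table whose entry \( (i, j) \) counts samples of true class \( i \) predicted as class \( j \).

Interpretation:
- Helps identify specific types of errors.
- Essential for understanding class-wise performance.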
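
For binary labels, scikit-learn's `confusion_matrix(...).ravel()` returns the four counts in the order TN, FP, FN, TP. A small example with made-up labels:

```python
from sklearn.metrics import confusion_matrix

y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 1, 1, 0, 0, 1, 1, 0]

# For binary labels {0, 1}, ravel() yields the counts in the order TN, FP, FN, TP
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
print(tn, fp, fn, tp)  # 2 2 1 3
```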
---


## `tritraining` template: [TriTrainingClassifier](https://github.com/lyttonhao/Tri-Training)

### class tritraining.TriTrainingClassifier(*, base_estimator=None, max_iter=10)

| **Parameter**               | **Description**                                                                                                                                        | **Default**      |
|----------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|------------------|
| `base_estimator`           | The base estimator to be used for tri-training.                                                                                                       | `None`           |
| `max_iter`                 | The maximum number of tri-training iterations.                                                                                                        | `10`             |


| **Attribute**              | **Description**                                                                                                                                        |
|----------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|
| `base_estimators_`         | The fitted clones of the base estimator.                                                                                                              |
| `transductions_`           | The predicted labels for the input data.                                                                                                              |
| `n_iter_`                  | The number of iterations run.                                                                                                                          |


| **Method**                 | **Description**                                                                                                                                        |
|----------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|
| `fit(X, y)`                | Fit the tri-training classifier from the training set.                                                                                                |
| `predict(X)`               | Predict class for X.                                                                                                                                  |
| `predict_proba(X)`         | Predict class probabilities of the input samples X.                                                                                                   |
| `score(X, y)`              | Returns the mean accuracy on the given test data and labels.                                                                                           |
| `get_params()`             | Get parameters for this estimator.                                                                                                                     |
| `set_params(**params)`     | Set the parameters of this estimator.                                                                                                                  |


# XXXXXXXX classification - Example

## Data loading

```python
import numpy as np
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.semi_supervised import SelfTrainingClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, classification_report

# Step 1: Data Import
digits = datasets.load_digits()
X = digits.data
y_true = digits.target          # ground-truth labels, kept intact for evaluation
y = digits.target.copy()        # copy: masking y below must not overwrite y_true

# Step 2: Data Processing
# Standardizing the data
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Create a partially labeled dataset (about 70% of the labels hidden)
rng = np.random.RandomState(42)
random_unlabeled_points = rng.rand(len(y)) < 0.7
y[random_unlabeled_points] = -1  # -1 marks a sample as unlabeled for SelfTrainingClassifier

# Step 3: Model Definition and Training
# Three independent self-training models combined by majority vote. This only
# approximates Tri-Training: the classifiers do not label data for each other.
clf1 = RandomForestClassifier(n_estimators=100, random_state=42)
clf2 = SVC(probability=True, gamma='scale')
clf3 = KNeighborsClassifier(n_neighbors=5)

# Define self-training classifiers
self_training_clf1 = SelfTrainingClassifier(clf1)
self_training_clf2 = SelfTrainingClassifier(clf2)
self_training_clf3 = SelfTrainingClassifier(clf3)

# Train classifiers on labeled and unlabeled samples together
self_training_clf1.fit(X_scaled, y)
self_training_clf2.fit(X_scaled, y)
self_training_clf3.fit(X_scaled, y)

# Step 4: Model Evaluation
# Predict with each classifier
y_pred1 = self_training_clf1.predict(X_scaled)
y_pred2 = self_training_clf2.predict(X_scaled)
y_pred3 = self_training_clf3.predict(X_scaled)

# Combine predictions (majority vote; a three-way tie goes to the smallest label)
y_pred = np.array([y_pred1, y_pred2, y_pred3])
y_pred_combined = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=y_pred)

# Evaluate on the samples whose labels were hidden during training,
# comparing against their true labels (never the -1 placeholders)
mask = random_unlabeled_points
accuracy = accuracy_score(y_true[mask], y_pred_combined[mask])
report = classification_report(y_true[mask], y_pred_combined[mask])

print(f"Accuracy: {accuracy * 100:.2f}%")
print("Classification Report:")
print(report)
```
