
# Final layer updation of a pre-trained model, if new classes come in the target dataset



In [4]:
!pip install catboost

Collecting ace_tools
  Downloading ace_tools-0.0-py3-none-any.whl.metadata (300 bytes)
Downloading ace_tools-0.0-py3-none-any.whl (1.1 kB)
Installing collected packages: ace_tools
Successfully installed ace_tools-0.0


In [24]:
import numpy as np
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier, Pool
from sklearn.metrics import accuracy_score, log_loss
from sklearn.datasets import fetch_openml
import matplotlib.pyplot as plt
import pandas as pd


In [25]:

# Load the MNIST dataset and filter for the incremental setup
mnist = fetch_openml('mnist_784', version=1)
X, y = mnist.data.to_numpy(), mnist.target.astype(int)

# Split into initial classes (0-7) and increment classes (8 and 9)
initial_indices = np.where(y < 8)[0]
increment_indices = np.where((y == 8) | (y == 9))[0]

# Create train-test splits for the initial set (0-7)
X_initial, X_test_initial, y_initial, y_test_initial = train_test_split(
    X[initial_indices], y[initial_indices], test_size=0.2, random_state=42
)

# Prepare the data for the incremental set (8, 9)
X_increment, X_test_increment, y_increment, y_test_increment = train_test_split(
    X[increment_indices], y[increment_indices], test_size=0.2, random_state=42
)


# Create a test set with all classes (0-9)
X_train_full, X_test_full, y_train_full, y_test_full = train_test_split(X, y, test_size=0.2, random_state=42)


In [26]:

# Initialize CatBoost model
catboost_model = CatBoostClassifier(iterations=100, learning_rate=0.1, depth=5, verbose=10)

# Train the model on the initial set (0-7)
catboost_model.fit(X_initial, y_initial)


<catboost.core.CatBoostClassifier at 0x7df39c269120>

In [27]:

# Evaluate on the initial test set (0-7)
y_pred_initial = catboost_model.predict(X_test_initial)
accuracy_initial = accuracy_score(y_test_initial, y_pred_initial)
loss_initial = log_loss(y_test_initial, catboost_model.predict_proba(X_test_initial))


print("Initial Model Evaluation (Classes 0-7):")
print(f"Accuracy: {accuracy_initial:.4f}")
print(f"Log Loss: {loss_initial:.4f}")


Initial Model Evaluation (Classes 0-7):
Accuracy: 0.9588
Log Loss: 0.1759


In [28]:
# Combine initial and incremental data
X_combined = np.vstack([X_initial, X_increment])
y_combined = np.concatenate([y_initial, y_increment])

# Retrain CatBoost model on the combined dataset
catboost_model_combined = CatBoostClassifier(iterations=100, learning_rate=0.1, depth=5, verbose=10)
catboost_model_combined.fit(X_combined, y_combined)


<catboost.core.CatBoostClassifier at 0x7df389524430>

In [32]:

# Evaluate on the initial test set (0-7)
y_pred_initial = catboost_model.predict(X_test_initial)
accuracy_initial = accuracy_score(y_test_initial, y_pred_initial)
loss_initial = log_loss(y_test_initial, catboost_model.predict_proba(X_test_initial))


print("Initial Model Evaluation after increment learning (Classes 0-7):")
print(f"Accuracy: {accuracy_initial:.4f}")
print(f"Log Loss: {loss_initial:.4f}")


Initial Model Evaluation after increment learning (Classes 0-7):
Accuracy: 0.9588
Log Loss: 0.1759


In [29]:
# Evaluate on the combined test set including new classes
X_test_combined = np.vstack([X_test_initial, X_test_increment])
y_test_combined = np.concatenate([y_test_initial, y_test_increment])

y_pred_combined = catboost_model_combined.predict(X_test_combined)
accuracy_combined = accuracy_score(y_test_combined, y_pred_combined)
loss_combined = log_loss(y_test_combined, catboost_model_combined.predict_proba(X_test_combined))



print("\nAfter Adding Classes (8, 9):")
print(f"Accuracy: {accuracy_combined:.4f}")
print(f"Log Loss: {loss_combined:.4f}")


After Adding Classes (8, 9):
Accuracy: 0.9375
Log Loss: 0.2654


In [30]:
# Evaluate on the full test set (all classes 0-9)
y_pred_full = catboost_model_combined.predict(X_test_full)
accuracy_full = accuracy_score(y_test_full, y_pred_full)
loss_full = log_loss(y_test_full, catboost_model_combined.predict_proba(X_test_full))

print("\nOverall Model Evaluation (All Classes 0-9):")
print(f"Total Accuracy: {accuracy_full:.4f}")
print(f"Total Log Loss: {loss_full:.4f}")


Overall Model Evaluation (All Classes 0-9):
Total Accuracy: 0.9431
Total Log Loss: 0.2475
