In [6]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import confusion_matrix, accuracy_score
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm


In [2]:

# Define the feature extractor (f)
class FeatureExtractor(nn.Module):
    """Two-stage convolutional encoder that maps an image batch to feature vectors.

    Each stage is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2), so the spatial
    size is halved (with floor division) twice before the final linear projection.

    Args:
        in_channels: Number of channels of the input images.
        image_size: Input size, either a single int for square images or a
            ``(height, width)`` pair.
        feature_dim: Length of the feature vector produced for each image.

    Raises:
        ValueError: If the image is too small to survive the two pooling stages.

    The defaults reproduce the original fixed configuration (1x28x28 -> 128).
    """

    def __init__(self, in_channels=1, image_size=28, feature_dim=128):
        super().__init__()
        try:
            height, width = image_size
        except TypeError:
            # A single number means a square image.
            height = width = image_size

        # padding=1 with a 3x3 kernel keeps the size; each MaxPool2d(2) floors it in half.
        pooled_height = height // 2 // 2
        pooled_width = width // 2 // 2
        if pooled_height < 1 or pooled_width < 1:
            raise ValueError(
                f"image_size {image_size!r} is too small for two 2x2 pooling stages"
            )

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.flatten = nn.Flatten()
        # 28x28 inputs give 7x7 maps here, i.e. the original 64 * 7 * 7 inputs.
        self.fc = nn.Linear(64 * pooled_height * pooled_width, feature_dim)

    def forward(self, x):
        """Encode a ``(batch, in_channels, height, width)`` tensor into ``(batch, feature_dim)``."""
        x = self.conv(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

# Define the classification head (θ)
class ClassificationHead(nn.Module):
    """Single linear layer that turns a feature vector into one logit per class."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # The layer stays under the attribute name ``fc``: the training loop reads
        # ``fc.weight``, ``fc.bias`` and ``fc.out_features`` when it widens the head.
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for ``x``."""
        logits = self.fc(x)
        return logits

# Synthesizer: A simple generator for model inversion
class Synthesizer(nn.Module):
    """Linear generator that maps latent feature vectors to images in ``[0, 1]``.

    Args:
        feature_dim: Length of each latent input vector.
        image_shape: Shape of one generated image, e.g. ``(1, 28, 28)``.
            Any sequence of integers is accepted.
    """

    def __init__(self, feature_dim, image_shape):
        super().__init__()
        # Normalise to a tuple of plain ints so it unpacks cleanly in ``view``.
        self.image_shape = tuple(int(dim) for dim in image_shape)
        # np.prod returns a NumPy scalar; nn.Linear should receive a Python int
        # so that ``out_features`` is an ordinary int.
        num_pixels = int(np.prod(self.image_shape))
        self.fc = nn.Linear(feature_dim, num_pixels)

    def forward(self, features):
        """Return synthetic images of shape ``(batch, *image_shape)``."""
        images = self.fc(features).view(-1, *self.image_shape)
        # Sigmoid squashes every pixel into the [0, 1] range.
        return torch.sigmoid(images)


In [3]:
def create_mnist_data(num_classes, task_index,csv_file="mnist.csv", image_shape=(1, 28, 28)):
    """Build the dataset for one class-incremental task from an MNIST-style CSV.

    The CSV must have a ``label`` column; every other column is treated as a
    pixel value. The sorted unique labels are cut into consecutive groups of
    ``num_classes`` and the group at position ``task_index`` is kept.

    Args:
        num_classes: Number of classes in each task.
        task_index: Zero-based index of the task whose classes are selected.
        csv_file: Path to the CSV file, or any file-like object accepted by
            ``pd.read_csv``.
        image_shape: Shape each flattened pixel row is reshaped to.

    Returns:
        A ``TensorDataset`` of ``(images, labels)`` with float32 images divided
        by 255 and int64 labels (the original label values, not re-indexed).

    Raises:
        ValueError: If no class falls into the requested task.
    """
    # Let pandas open the file itself so the handle is always closed.
    df = pd.read_csv(csv_file)

    # Determine the classes for this task
    all_classes = sorted(df['label'].unique())
    start_class = task_index * num_classes
    end_class = start_class + num_classes
    valid_classes = all_classes[start_class:end_class]
    if not valid_classes:
        # Without this check the reshape below fails with an unrelated-looking
        # "ambiguous -1 dimension" error on an empty tensor.
        raise ValueError(
            f"task_index={task_index} with num_classes={num_classes} selects no classes; "
            f"the file contains {len(all_classes)} distinct labels"
        )

    # Keep only the rows that belong to this task
    filtered_df = df[df['label'].isin(valid_classes)]

    # Separate features and labels
    pixels = filtered_df.drop(columns=['label']).values
    labels = filtered_df['label'].values

    # Scale by 255 and reshape to (batch, *image_shape)
    X = torch.tensor(pixels / 255.0, dtype=torch.float32)
    X = X.view(-1, *image_shape)
    y = torch.tensor(labels, dtype=torch.long)

    return TensorDataset(X, y)

In [14]:

# Knowledge distillation loss
def knowledge_distillation_loss(pred, target, temperature=2):
    """Temperature-scaled KL divergence between student and teacher logits.

    Args:
        pred: Student logits, classes along dimension 1.
        target: Teacher logits with the same shape as ``pred``.
        temperature: Divisor applied to both sets of logits before the softmax.

    Returns:
        The batch-mean KL divergence multiplied by ``temperature ** 2``.
    """
    scale = temperature ** 2
    teacher_probs = F.softmax(target / temperature, dim=1)
    # F.kl_div expects its first argument as log-probabilities.
    student_log_probs = F.log_softmax(pred / temperature, dim=1)
    divergence = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    return divergence * scale

from tqdm import tqdm

# Training loop with progress bars
def _expand_classification_head(old_head, feature_dim, num_classes):
    """Return a head with ``num_classes`` outputs whose first rows reuse ``old_head``'s weights."""
    new_head = ClassificationHead(feature_dim, num_classes)
    num_old = old_head.fc.out_features
    with torch.no_grad():
        new_head.fc.weight[:num_old].copy_(old_head.fc.weight)
        new_head.fc.bias[:num_old].copy_(old_head.fc.bias)
    return new_head


def _frozen_snapshot(feature_extractor, classification_head):
    """Deep-copy both modules, put the copies in eval mode and disable their gradients."""
    snapshot = (copy.deepcopy(feature_extractor), copy.deepcopy(classification_head))
    for module in snapshot:
        module.eval()
        module.requires_grad_(False)
    return snapshot


def _train_synthesizer(synthesizer, old_model, criterion, num_old_classes,
                       feature_dim, steps, batch_size, lr):
    """Update the synthesizer so the frozen old model classifies its images as random old classes."""
    print("Training synthesizer...")
    old_extractor, old_head = old_model
    synth_optimizer = optim.Adam(synthesizer.parameters(), lr=lr)
    for _ in tqdm(range(steps), desc="Synthesizer Training", leave=False):
        latent = torch.randn(batch_size, feature_dim)  # Random latent features
        synthetic_images = synthesizer(latent)
        synth_optimizer.zero_grad()
        preds = old_head(old_extractor(synthetic_images))
        random_targets = torch.randint(0, num_old_classes, (batch_size,))
        loss = criterion(preds, random_targets)
        # The old model's parameters are frozen; gradients only reach the synthesizer.
        loss.backward()
        synth_optimizer.step()


def train_r_dfcil(task_data, num_classes_per_task, feature_dim=128, image_shape=(1, 28, 28), epochs=5,
                  batch_size=32, lr=0.001):
    """Train a feature extractor and a growing classification head on a sequence of tasks.

    For every task after the first, the head is widened to cover all classes
    seen so far, the synthesizer is updated against a frozen snapshot of the
    previous model, and a knowledge-distillation term between the new model's
    old-class logits and the snapshot's logits is added to the cross-entropy loss.

    Args:
        task_data: Sequence of ``(train_dataset, test_dataset)`` pairs, one per
            task. Only the training dataset is used here.
        num_classes_per_task: Number of new classes introduced by each task.
        feature_dim: Width of the feature vectors fed to the head and synthesizer.
        image_shape: Shape of the images produced by the synthesizer.
        epochs: Training epochs per task; also the number of synthesizer update
            steps per task.
        batch_size: Mini-batch size for the data loader and the synthesizer.
        lr: Learning rate of the Adam optimizers.

    Returns:
        ``(feature_extractor, classification_head, synthesizer)`` after the last task.

    Raises:
        ValueError: If ``feature_dim`` does not match the feature extractor's output width.

    NOTE(review): the synthesizer is trained, but its images are never fed into
    the main training loop below; distillation runs on the new task's real
    images only. Confirm whether replay of synthetic images was intended.
    """
    feature_extractor = FeatureExtractor()
    # FeatureExtractor() is built with its default configuration, so a different
    # feature_dim would otherwise only surface as a shape error inside the head.
    if feature_extractor.fc.out_features != feature_dim:
        raise ValueError(
            f"feature_dim={feature_dim} does not match the feature extractor output "
            f"width {feature_extractor.fc.out_features}"
        )
    classification_head = ClassificationHead(feature_dim, num_classes_per_task[0])
    synthesizer = Synthesizer(feature_dim, image_shape)

    old_model = None
    criterion = nn.CrossEntropyLoss()

    for task_idx, (train_data, _test_data) in enumerate(task_data):
        print(f"\nTraining on Task {task_idx + 1}/{len(task_data)}")
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        num_classes = sum(num_classes_per_task[:task_idx + 1])
        num_old_classes = num_classes - num_classes_per_task[task_idx]

        if task_idx > 0:
            # Expand the classification head for new classes
            classification_head = _expand_classification_head(classification_head, feature_dim, num_classes)

        # A fresh optimizer per task, because the head's parameters may have been replaced.
        optimizer = optim.Adam(
            list(feature_extractor.parameters()) + list(classification_head.parameters()), lr=lr
        )

        # Train synthesizer using the old model
        if old_model is not None:
            _train_synthesizer(synthesizer, old_model, criterion, num_old_classes,
                               feature_dim, epochs, batch_size, lr)

        # Train new task
        print("Training model...")
        for epoch in tqdm(range(epochs), desc=f"Task {task_idx + 1} Training", leave=False):
            feature_extractor.train()
            classification_head.train()
            epoch_loss = 0.0
            for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}", leave=False):
                optimizer.zero_grad()

                outputs = classification_head(feature_extractor(images))
                loss = criterion(outputs, labels)

                # Add knowledge distillation loss if old model exists
                if old_model is not None:
                    with torch.no_grad():
                        old_outputs = old_model[1](old_model[0](images))
                    # Only the logits of previously seen classes have a teacher counterpart.
                    loss = loss + knowledge_distillation_loss(outputs[:, :num_old_classes], old_outputs)

                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

            # max(..., 1) avoids a ZeroDivisionError when the dataset is empty.
            tqdm.write(f"Epoch {epoch + 1}/{epochs} Loss: {epoch_loss / max(len(train_loader), 1):.4f}")

        # Freeze old model: keep an independent copy, so training on the next
        # task cannot change the distillation targets.
        old_model = _frozen_snapshot(feature_extractor, classification_head)
        print(f"Finished Task {task_idx + 1}/{len(task_data)}")
    return feature_extractor, classification_head, synthesizer



In [16]:
# Experiment configuration
image_shape = (1, 28, 28)
num_classes_per_task = [5, 5]  # Task 1: 5 classes, Task 2: 5 classes

# Build one MNIST dataset per task; task i uses the i-th group of classes.
task1_data, task2_data = [
    create_mnist_data(class_count, task_index=task_index, image_shape=image_shape)
    for task_index, class_count in enumerate(num_classes_per_task)
]

# One (train, test) pair per task; the same dataset fills both slots here.
task_data = [(dataset, dataset) for dataset in (task1_data, task2_data)]

# Train R-DFCIL
feature_extractor, classification_head, synthesizer = train_r_dfcil(
    task_data,
    num_classes_per_task,
    image_shape=image_shape,
)


Training on Task 1/2
Training model...


Task 1 Training:   0%|          | 0/5 [00:00<?, ?it/s]
Epoch 1/5:   0%|          | 0/670 [00:00<?, ?it/s][A
Epoch 1/5:   0%|          | 1/670 [00:00<01:07,  9.96it/s][A
Epoch 1/5:   0%|          | 2/670 [00:00<01:30,  7.36it/s][A
Epoch 1/5:   0%|          | 3/670 [00:00<01:41,  6.54it/s][A
Epoch 1/5:   1%|          | 4/670 [00:00<01:38,  6.75it/s][A
Epoch 1/5:   1%|          | 5/670 [00:00<01:39,  6.67it/s][A
Epoch 1/5:   1%|          | 6/670 [00:00<01:33,  7.11it/s][A
Epoch 1/5:   1%|          | 8/670 [00:00<01:06,  9.94it/s][A
Epoch 1/5:   1%|▏         | 10/670 [00:01<00:56, 11.59it/s][A
Epoch 1/5:   2%|▏         | 12/670 [00:01<00:55, 11.79it/s][A
Epoch 1/5:   2%|▏         | 14/670 [00:01<00:54, 11.98it/s][A
Epoch 1/5:   2%|▏         | 16/670 [00:01<00:55, 11.68it/s][A
Epoch 1/5:   3%|▎         | 18/670 [00:01<01:04, 10.11it/s][A
Epoch 1/5:   3%|▎         | 20/670 [00:02<01:07,  9.56it/s][A
Epoch 1/5:   3%|▎         | 21/670 [00:02<01:09,  9.29it/s][A
Epoch 1/5:   3%|

Epoch 1/5 Loss: 0.0800



Epoch 2/5:   0%|          | 0/670 [00:00<?, ?it/s][A
Epoch 2/5:   0%|          | 3/670 [00:00<00:25, 26.67it/s][A
Epoch 2/5:   1%|          | 6/670 [00:00<00:24, 26.78it/s][A
Epoch 2/5:   1%|▏         | 9/670 [00:00<00:25, 26.41it/s][A
Epoch 2/5:   2%|▏         | 12/670 [00:00<00:24, 26.87it/s][A
Epoch 2/5:   2%|▏         | 15/670 [00:00<00:23, 27.33it/s][A
Epoch 2/5:   3%|▎         | 18/670 [00:00<00:23, 27.47it/s][A
Epoch 2/5:   3%|▎         | 21/670 [00:00<00:23, 27.38it/s][A
Epoch 2/5:   4%|▎         | 24/670 [00:00<00:23, 27.41it/s][A
Epoch 2/5:   4%|▍         | 27/670 [00:01<00:25, 24.79it/s][A
Epoch 2/5:   4%|▍         | 30/670 [00:01<00:24, 25.63it/s][A
Epoch 2/5:   5%|▍         | 33/670 [00:01<00:24, 26.43it/s][A
Epoch 2/5:   5%|▌         | 36/670 [00:01<00:24, 25.88it/s][A
Epoch 2/5:   6%|▌         | 39/670 [00:01<00:24, 26.19it/s][A
Epoch 2/5:   6%|▋         | 42/670 [00:01<00:23, 26.29it/s][A
Epoch 2/5:   7%|▋         | 45/670 [00:01<00:23, 26.58it/s][A
Epo

Epoch 2/5 Loss: 0.0257



Epoch 3/5:   0%|          | 0/670 [00:00<?, ?it/s][A
Epoch 3/5:   0%|          | 3/670 [00:00<00:26, 24.82it/s][A
Epoch 3/5:   1%|          | 6/670 [00:00<00:25, 25.84it/s][A
Epoch 3/5:   1%|▏         | 9/670 [00:00<00:24, 26.46it/s][A
Epoch 3/5:   2%|▏         | 12/670 [00:00<00:24, 26.68it/s][A
Epoch 3/5:   2%|▏         | 15/670 [00:00<00:25, 25.67it/s][A
Epoch 3/5:   3%|▎         | 18/670 [00:00<00:26, 25.02it/s][A
Epoch 3/5:   3%|▎         | 21/670 [00:00<00:25, 25.75it/s][A
Epoch 3/5:   4%|▎         | 24/670 [00:00<00:24, 26.40it/s][A
Epoch 3/5:   4%|▍         | 27/670 [00:01<00:23, 26.83it/s][A
Epoch 3/5:   4%|▍         | 30/670 [00:01<00:23, 27.22it/s][A
Epoch 3/5:   5%|▍         | 33/670 [00:01<00:24, 26.54it/s][A
Epoch 3/5:   5%|▌         | 36/670 [00:01<00:23, 26.68it/s][A
Epoch 3/5:   6%|▌         | 39/670 [00:01<00:23, 26.90it/s][A
Epoch 3/5:   6%|▋         | 42/670 [00:01<00:25, 25.06it/s][A
Epoch 3/5:   7%|▋         | 45/670 [00:01<00:24, 25.53it/s][A
Epo

Epoch 3/5 Loss: 0.0154



Epoch 4/5:   0%|          | 0/670 [00:00<?, ?it/s][A
Epoch 4/5:   0%|          | 3/670 [00:00<00:26, 25.42it/s][A
Epoch 4/5:   1%|          | 6/670 [00:00<00:25, 25.91it/s][A
Epoch 4/5:   1%|▏         | 9/670 [00:00<00:24, 26.44it/s][A
Epoch 4/5:   2%|▏         | 12/670 [00:00<00:24, 26.62it/s][A
Epoch 4/5:   2%|▏         | 15/670 [00:00<00:25, 25.81it/s][A
Epoch 4/5:   3%|▎         | 18/670 [00:00<00:25, 26.07it/s][A
Epoch 4/5:   3%|▎         | 21/670 [00:00<00:25, 25.03it/s][A
Epoch 4/5:   4%|▎         | 24/670 [00:00<00:26, 24.80it/s][A
Epoch 4/5:   4%|▍         | 27/670 [00:01<00:25, 25.44it/s][A
Epoch 4/5:   4%|▍         | 30/670 [00:01<00:25, 25.44it/s][A
Epoch 4/5:   5%|▍         | 33/670 [00:01<00:24, 25.80it/s][A
Epoch 4/5:   5%|▌         | 36/670 [00:01<00:24, 26.28it/s][A
Epoch 4/5:   6%|▌         | 39/670 [00:01<00:23, 26.67it/s][A
Epoch 4/5:   6%|▋         | 42/670 [00:01<00:24, 26.01it/s][A
Epoch 4/5:   7%|▋         | 45/670 [00:01<00:27, 22.67it/s][A
Epo

Epoch 4/5 Loss: 0.0100



Epoch 5/5:   0%|          | 0/670 [00:00<?, ?it/s][A
Epoch 5/5:   0%|          | 3/670 [00:00<00:26, 25.00it/s][A
Epoch 5/5:   1%|          | 6/670 [00:00<00:25, 26.37it/s][A
Epoch 5/5:   1%|▏         | 9/670 [00:00<00:24, 26.72it/s][A
Epoch 5/5:   2%|▏         | 12/670 [00:00<00:25, 25.51it/s][A
Epoch 5/5:   2%|▏         | 15/670 [00:00<00:27, 23.80it/s][A
Epoch 5/5:   3%|▎         | 18/670 [00:00<00:26, 24.76it/s][A
Epoch 5/5:   3%|▎         | 21/670 [00:00<00:25, 25.06it/s][A
Epoch 5/5:   4%|▎         | 24/670 [00:00<00:24, 25.86it/s][A
Epoch 5/5:   4%|▍         | 27/670 [00:01<00:27, 23.61it/s][A
Epoch 5/5:   4%|▍         | 30/670 [00:01<00:29, 21.64it/s][A
Epoch 5/5:   5%|▍         | 33/670 [00:01<00:31, 20.49it/s][A
Epoch 5/5:   5%|▌         | 36/670 [00:01<00:32, 19.35it/s][A
Epoch 5/5:   6%|▌         | 38/670 [00:01<00:34, 18.37it/s][A
Epoch 5/5:   6%|▌         | 40/670 [00:01<00:34, 18.06it/s][A
Epoch 5/5:   6%|▋         | 42/670 [00:01<00:34, 18.10it/s][A
Epo

Epoch 5/5 Loss: 0.0082
Finished Task 1/2

Training on Task 2/2
Training synthesizer...




Training model...


Task 2 Training:   0%|          | 0/5 [00:00<?, ?it/s]
Epoch 1/5:   0%|          | 0/644 [00:00<?, ?it/s][A
Epoch 1/5:   0%|          | 2/644 [00:00<00:55, 11.58it/s][A
Epoch 1/5:   1%|          | 4/644 [00:00<00:52, 12.13it/s][A
Epoch 1/5:   1%|          | 6/644 [00:00<00:52, 12.16it/s][A
Epoch 1/5:   1%|          | 8/644 [00:00<00:51, 12.28it/s][A
Epoch 1/5:   2%|▏         | 10/644 [00:00<00:50, 12.58it/s][A
Epoch 1/5:   2%|▏         | 12/644 [00:00<00:50, 12.58it/s][A
Epoch 1/5:   2%|▏         | 14/644 [00:01<00:51, 12.24it/s][A
Epoch 1/5:   2%|▏         | 16/644 [00:01<00:50, 12.52it/s][A
Epoch 1/5:   3%|▎         | 18/644 [00:01<00:48, 12.89it/s][A
Epoch 1/5:   3%|▎         | 20/644 [00:01<00:47, 13.15it/s][A
Epoch 1/5:   3%|▎         | 22/644 [00:01<00:47, 13.19it/s][A
Epoch 1/5:   4%|▎         | 24/644 [00:01<00:47, 13.13it/s][A
Epoch 1/5:   4%|▍         | 26/644 [00:02<00:49, 12.46it/s][A
Epoch 1/5:   4%|▍         | 28/644 [00:02<00:49, 12.57it/s][A
Epoch 1/5:   

Epoch 1/5 Loss: 0.1885



Epoch 2/5:   0%|          | 0/644 [00:00<?, ?it/s][A
Epoch 2/5:   0%|          | 2/644 [00:00<00:43, 14.64it/s][A
Epoch 2/5:   1%|          | 4/644 [00:00<00:37, 16.85it/s][A
Epoch 2/5:   1%|          | 6/644 [00:00<00:35, 17.98it/s][A
Epoch 2/5:   1%|          | 8/644 [00:00<00:34, 18.65it/s][A
Epoch 2/5:   2%|▏         | 10/644 [00:00<00:33, 18.74it/s][A
Epoch 2/5:   2%|▏         | 12/644 [00:00<00:33, 19.00it/s][A
Epoch 2/5:   2%|▏         | 14/644 [00:00<00:34, 18.34it/s][A
Epoch 2/5:   2%|▏         | 16/644 [00:00<00:33, 18.70it/s][A
Epoch 2/5:   3%|▎         | 18/644 [00:00<00:35, 17.87it/s][A
Epoch 2/5:   3%|▎         | 20/644 [00:01<00:37, 16.58it/s][A
Epoch 2/5:   3%|▎         | 22/644 [00:01<00:35, 17.43it/s][A
Epoch 2/5:   4%|▎         | 24/644 [00:01<00:34, 18.08it/s][A
Epoch 2/5:   4%|▍         | 27/644 [00:01<00:32, 19.16it/s][A
Epoch 2/5:   5%|▍         | 29/644 [00:01<00:32, 19.14it/s][A
Epoch 2/5:   5%|▍         | 31/644 [00:01<00:31, 19.29it/s][A
Epoc

Epoch 2/5 Loss: 0.0339



Epoch 3/5:   0%|          | 0/644 [00:00<?, ?it/s][A
Epoch 3/5:   0%|          | 2/644 [00:00<00:38, 16.78it/s][A
Epoch 3/5:   1%|          | 4/644 [00:00<00:34, 18.42it/s][A
Epoch 3/5:   1%|          | 6/644 [00:00<00:33, 18.85it/s][A
Epoch 3/5:   1%|          | 8/644 [00:00<00:33, 18.97it/s][A
Epoch 3/5:   2%|▏         | 10/644 [00:00<00:36, 17.27it/s][A
Epoch 3/5:   2%|▏         | 12/644 [00:00<00:36, 17.32it/s][A
Epoch 3/5:   2%|▏         | 14/644 [00:00<00:36, 17.29it/s][A
Epoch 3/5:   2%|▏         | 16/644 [00:00<00:35, 17.93it/s][A
Epoch 3/5:   3%|▎         | 18/644 [00:01<00:35, 17.82it/s][A
Epoch 3/5:   3%|▎         | 21/644 [00:01<00:32, 18.95it/s][A
Epoch 3/5:   4%|▎         | 23/644 [00:01<00:32, 18.91it/s][A
Epoch 3/5:   4%|▍         | 26/644 [00:01<00:31, 19.53it/s][A
Epoch 3/5:   4%|▍         | 28/644 [00:01<00:33, 18.58it/s][A
Epoch 3/5:   5%|▍         | 30/644 [00:01<00:32, 18.88it/s][A
Epoch 3/5:   5%|▍         | 32/644 [00:01<00:32, 18.93it/s][A
Epoc

Epoch 3/5 Loss: 0.0221



Epoch 4/5:   0%|          | 0/644 [00:00<?, ?it/s][A
Epoch 4/5:   0%|          | 1/644 [00:00<01:09,  9.30it/s][A
Epoch 4/5:   0%|          | 3/644 [00:00<01:00, 10.60it/s][A
Epoch 4/5:   1%|          | 5/644 [00:00<00:56, 11.30it/s][A
Epoch 4/5:   1%|          | 7/644 [00:00<00:53, 11.93it/s][A
Epoch 4/5:   1%|▏         | 9/644 [00:00<00:52, 12.17it/s][A
Epoch 4/5:   2%|▏         | 11/644 [00:00<00:51, 12.29it/s][A
Epoch 4/5:   2%|▏         | 13/644 [00:01<00:51, 12.23it/s][A
Epoch 4/5:   2%|▏         | 15/644 [00:01<00:51, 12.11it/s][A
Epoch 4/5:   3%|▎         | 17/644 [00:01<00:50, 12.35it/s][A
Epoch 4/5:   3%|▎         | 19/644 [00:01<00:50, 12.47it/s][A
Epoch 4/5:   3%|▎         | 21/644 [00:01<00:49, 12.51it/s][A
Epoch 4/5:   4%|▎         | 23/644 [00:01<00:46, 13.40it/s][A
Epoch 4/5:   4%|▍         | 25/644 [00:01<00:42, 14.70it/s][A
Epoch 4/5:   4%|▍         | 27/644 [00:02<00:40, 15.20it/s][A
Epoch 4/5:   5%|▍         | 29/644 [00:02<00:37, 16.25it/s][A
Epoch

Epoch 4/5 Loss: 0.0139



Epoch 5/5:   0%|          | 0/644 [00:00<?, ?it/s][A
Epoch 5/5:   0%|          | 2/644 [00:00<00:39, 16.26it/s][A
Epoch 5/5:   1%|          | 4/644 [00:00<00:37, 16.96it/s][A
Epoch 5/5:   1%|          | 6/644 [00:00<00:35, 17.86it/s][A
Epoch 5/5:   1%|          | 8/644 [00:00<00:35, 17.82it/s][A
Epoch 5/5:   2%|▏         | 10/644 [00:00<00:37, 17.12it/s][A
Epoch 5/5:   2%|▏         | 12/644 [00:00<00:35, 17.75it/s][A
Epoch 5/5:   2%|▏         | 14/644 [00:00<00:34, 18.20it/s][A
Epoch 5/5:   2%|▏         | 16/644 [00:00<00:33, 18.59it/s][A
Epoch 5/5:   3%|▎         | 18/644 [00:00<00:33, 18.87it/s][A
Epoch 5/5:   3%|▎         | 20/644 [00:01<00:32, 18.96it/s][A
Epoch 5/5:   3%|▎         | 22/644 [00:01<00:33, 18.74it/s][A
Epoch 5/5:   4%|▎         | 24/644 [00:01<00:33, 18.44it/s][A
Epoch 5/5:   4%|▍         | 26/644 [00:01<00:33, 18.56it/s][A
Epoch 5/5:   4%|▍         | 28/644 [00:01<00:35, 17.19it/s][A
Epoch 5/5:   5%|▍         | 30/644 [00:01<00:34, 17.73it/s][A
Epoc

Epoch 5/5 Loss: 0.0124
Finished Task 2/2




In [10]:
def inference_on_model(feature_extractor, classification_head, loader, num_classes, print_every=10):
    """Evaluate the model on ``loader``, print accuracy and plot a confusion matrix.

    Args:
        feature_extractor: Module mapping image batches to feature vectors.
        classification_head: Module mapping feature vectors to class logits.
        loader: Iterable of ``(images, labels)`` batches.
        num_classes: Number of classes; the confusion matrix covers labels
            ``0 .. num_classes - 1``.
        print_every: Print predictions and labels for every ``print_every``-th
            batch. Pass ``0`` or ``None`` to silence the per-batch output.

    Returns:
        ``(accuracy, cm)``: the accuracy as a fraction in ``[0, 1]`` and the
        confusion matrix array (rows are true labels, columns are predictions).

    Note:
        Both modules are switched to eval mode and moved to CUDA when it is
        available; they stay there after the call.
    """
    feature_extractor.eval()
    classification_head.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    feature_extractor.to(device)
    classification_head.to(device)

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(loader):
            # Forward pass
            logits = classification_head(feature_extractor(images.to(device)))
            batch_preds = torch.argmax(logits, dim=1).cpu().numpy()
            batch_labels = labels.cpu().numpy()

            # Collect results for evaluation
            all_preds.extend(batch_preds.tolist())
            all_labels.extend(batch_labels.tolist())

            # Occasionally display predictions
            if print_every and batch_idx % print_every == 0:
                print(f"Batch {batch_idx} Predictions:")
                print(f"Predicted: {batch_preds}")
                print(f"True:      {batch_labels}")

    # Compute accuracy
    accuracy = accuracy_score(all_labels, all_preds)
    print(f"\nAccuracy: {accuracy * 100:.2f}%")

    # Compute and plot the confusion matrix
    class_ids = list(range(num_classes))
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_ids,
                yticklabels=class_ids)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

    # Hand the metrics back so callers can use them instead of re-parsing the printout.
    return accuracy, cm


In [None]:
# Evaluate the final model on the data of both tasks so accuracy on the earlier
# classes is visible too. NOTE(review): the evaluation set is an assumption;
# the original call was incomplete (`fe` is undefined, three arguments missing).
eval_loader = DataLoader(
    torch.utils.data.ConcatDataset([task1_data, task2_data]),
    batch_size=32,
    shuffle=False,
)
# The trailing semicolon keeps the notebook from echoing any return value.
inference_on_model(feature_extractor, classification_head, eval_loader, sum(num_classes_per_task));