### Imports

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage
import random
from pathlib import Path

### Paths

In [2]:
SRC_DIR = Path.cwd()
ROOT_DIR = SRC_DIR.parent
DATA_DIR = os.path.join(ROOT_DIR, 'dataset')
PREPROCESSED_DIR = os.path.join(DATA_DIR, 'preprocessed')
MODEL_DIR = os.path.join(ROOT_DIR, 'models')
CBM_DIR = os.path.join(MODEL_DIR, 'cbm')
os.makedirs(CBM_DIR, exist_ok=True)


In [3]:
CSV_PATH = os.path.join(DATA_DIR, 'csv_mappings', 'train.csv')
CSV_CONCEPTS_PATH = os.path.join(DATA_DIR, 'csv_mappings', 'train_concepts.csv')

# Data Preperation

In [4]:
# Class mapping
CLASS_NAMES = {
    0: "amanita",
    1: "boletus",
    2: "chantelle",
    3: "deterrimus",
    4: "rufus",
    5: "torminosus",
    6: "aurantiacum",
    7: "procera",
    8: "involutus",
    9: "russula"
}

# Concept mapping for each class
CONCEPT_MAPPING = {
    0: ["red", "convex", "scaly", "yes", "thin"],
    1: ["brown", "flat", "smooth", "no", "medium"],
    2: ["yellow", "convex", "warty", "yes", "thick"],
    3: ["white", "bulbous", "smooth", "no", "thin"],
    4: ["brown", "inverted", "scaly", "yes", "medium"],
    5: ["yellow", "flat", "smooth", "no", "thick"],
    6: ["red", "convex", "warty", "yes", "thin"],
    7: ["white", "flat", "smooth", "no", "medium"],
    8: ["brown", "bulbous", "scaly", "yes", "thick"],
    9: ["yellow", "inverted", "smooth", "no", "thin"]
}


In [5]:
df = pd.read_csv(CSV_PATH)

In [6]:
concept_columns = ["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]


In [7]:
concept_rows = []
for _, row in df.iterrows():
    class_label = row["Mushroom"]
    concepts = CONCEPT_MAPPING[class_label]
    concept_rows.append(concepts)

concept_df = pd.DataFrame(concept_rows, columns=concept_columns)
updated_df = pd.concat([df, concept_df], axis=1)

In [8]:
updated_df.to_csv(CSV_CONCEPTS_PATH, index=False)
print(f"Updated CSV saved to {CSV_CONCEPTS_PATH}")

Updated CSV saved to c:\Users\ilian\Documents\Projects\git_projects\university\mushroom_classification\dataset\csv_mappings\train_concepts.csv


# Training

### Classes & Functions

In [9]:
class ConceptBottleneckModel(nn.Module):
    def __init__(self, num_concepts, num_classes):
        super(ConceptBottleneckModel, self).__init__()

        self.feature_extractor = models.resnet18(pretrained=True)

        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        
        self.feature_extractor.fc = nn.Linear(self.feature_extractor.fc.in_features, num_concepts)

        for param in self.feature_extractor.fc.parameters():
            param.requires_grad = True

        self.classifier = nn.Sequential(
            nn.Linear(num_concepts, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

        for param in self.classifier.parameters():
            param.requires_grad = True

    def forward(self, x):
        concepts = self.feature_extractor(x)
        predictions = self.classifier(concepts)
        return concepts, predictions


In [10]:
class ConceptMushroomDataset(Dataset):
    def __init__(self, preprocessed_dir, csv_path, transform=None):
        self.preprocessed_dir = preprocessed_dir
        self.csv_data = pd.read_csv(csv_path)
        self.transform = transform

        self.image_ids = self.csv_data['Image'].values
        self.labels = self.csv_data['Mushroom'].values

        concept_columns = ["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]
        self.concepts = pd.get_dummies(self.csv_data[concept_columns]).values.astype('float32')

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = str(self.image_ids[idx]).zfill(5)
        label = self.labels[idx]
        concepts = torch.tensor(self.concepts[idx], dtype=torch.float32)

        image_path = os.path.join(self.preprocessed_dir, f"{image_id}.pt")
        image = torch.load(image_path)

        if self.transform:
            image = self.transform(image)

        return image, label, concepts


In [11]:

def get_data_loaders(preprocessed_dir, csv_path, batch_size=32):
    dataset = ConceptMushroomDataset(preprocessed_dir, csv_path)

    indices = list(range(len(dataset)))
    train_indices, temp_indices = train_test_split(indices, test_size=0.3, random_state=42)
    val_indices, test_indices = train_test_split(temp_indices, test_size=0.5, random_state=42)

    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)
    test_subset = Subset(dataset, test_indices)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader


In [12]:
def calculate_concept_accuracy(predicted_concepts, ground_truth):
    predicted_classes = torch.argmax(predicted_concepts, dim=1)
    true_classes = torch.argmax(ground_truth, dim=1)
    correct = (predicted_classes == true_classes).sum().item()
    return correct

In [13]:
def train_jointly(model, train_loader, val_loader, device, optimizer, epochs=10, lambda_concept=0.5, lambda_classification=0.5):
    criterion_concept = nn.CrossEntropyLoss()  # Use CrossEntropyLoss for categorical concepts
    criterion_classification = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        train_total_loss = 0.0
        train_concept_loss = 0.0
        train_classification_loss = 0.0
        train_total = 0
        train_correct = 0
        train_concept_correct = 0

        for images, labels, concepts in tqdm(train_loader, desc=f"[Joint Training: Epoch {epoch+1}/{epochs}]"):
            images, labels, concepts = images.to(device), labels.to(device), concepts.to(device)

            optimizer.zero_grad()
            predicted_concepts, predictions = model(images)

            # Calculate concept loss (CrossEntropyLoss expects class indices)
            loss_concept = criterion_concept(predicted_concepts, torch.argmax(concepts, dim=1))  # Use argmax on one-hot encoded concepts
            loss_classification = criterion_classification(predictions, labels)
            loss = lambda_concept * loss_concept + lambda_classification * loss_classification

            loss.backward()
            optimizer.step()

            # Track losses and accuracy
            train_total_loss += loss.item()
            train_concept_loss += loss_concept.item()
            train_classification_loss += loss_classification.item()
            
            train_total += labels.size(0)
            train_correct += (predictions.argmax(dim=1) == labels).sum().item()

            # Calculate concept accuracy
            train_concept_correct += calculate_concept_accuracy(predicted_concepts, concepts)

        # Calculate accuracies
        train_accuracy = 100 * train_correct / train_total
        train_concept_accuracy = 100 * train_concept_correct / train_total

        # Calculate combined total accuracy as the average of concept and classification accuracies
        train_total_accuracy = (train_concept_accuracy + train_accuracy) / 2

        # Print stats for the current epoch
        print(f"\nTrain Epoch {epoch+1}/{epochs}:")
        print("-------------------")
        print(f"Concept Loss:          {train_concept_loss / len(train_loader):.4f} \t Concept Accuracy:         {train_concept_accuracy:.2f}%")
        print(f"Classification Loss:   {train_classification_loss / len(train_loader):.4f} \t Classification Accuracy:  {train_accuracy:.2f}%")
        print("---------------------------------------------------------")
        print(f"Total Loss:            {train_total_loss / len(train_loader):.4f} \t Total Accuracy:            {train_total_accuracy:.2f}%\n")

        # Validation phase
        model.eval()
        val_total_loss = 0.0
        val_concept_loss = 0.0
        val_classification_loss = 0.0
        val_total = 0
        val_correct = 0
        val_concept_correct = 0

        with torch.no_grad():
            for images, labels, concepts in val_loader:
                images, labels, concepts = images.to(device), labels.to(device), concepts.to(device)

                predicted_concepts, predictions = model(images)

                loss_concept = criterion_concept(predicted_concepts, torch.argmax(concepts, dim=1))
                loss_classification = criterion_classification(predictions, labels)
                loss = lambda_concept * loss_concept + lambda_classification * loss_classification

                val_total_loss += loss.item()
                val_concept_loss += loss_concept.item()
                val_classification_loss += loss_classification.item()
                val_total += labels.size(0)
                val_correct += (predictions.argmax(dim=1) == labels).sum().item()

                # Concept accuracy calculation
                val_concept_correct += calculate_concept_accuracy(predicted_concepts, concepts)

        # Calculate accuracies for validation
        val_accuracy = 100 * val_correct / val_total
        val_concept_accuracy = 100 * val_concept_correct / val_total

        # Calculate combined total accuracy for validation
        val_total_accuracy = (val_concept_accuracy + val_accuracy) / 2

        # Print stats for the validation set
        print(f"Validate Epoch {epoch+1}/{epochs}:")
        print("-------------------")
        print(f"Concept Loss:          {val_concept_loss / len(val_loader):.4f} \t Concept Accuracy:         {val_concept_accuracy:.2f}%")
        print(f"Classification Loss:   {val_classification_loss / len(val_loader):.4f} \t Classification Accuracy:  {val_accuracy:.2f}%")
        print("---------------------------------------------------------")
        print(f"Total Loss:            {val_total_loss / len(val_loader):.4f} \t Total Accuracy:            {val_total_accuracy:.2f}%\n")


In [14]:
def evaluate_cbm(model, test_loader, device):
    model.eval()
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for images, labels, _ in tqdm(test_loader, desc="[Evaluate]"):
            images, labels = images.to(device), labels.to(device)
            _, predictions = model(images)

            _, predicted_labels = torch.max(predictions, 1)
            total_correct += (predicted_labels == labels).sum().item()
            total_samples += labels.size(0)

    accuracy = 100 * total_correct / total_samples
    print(f"Accuracy: {accuracy:.2f}%")


### Config

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
train_loader, val_loader, test_loader = get_data_loaders(PREPROCESSED_DIR, CSV_CONCEPTS_PATH)

In [17]:
num_concepts = len(pd.get_dummies(pd.read_csv(CSV_CONCEPTS_PATH)[["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]]).columns)
num_classes = len(pd.read_csv(CSV_CONCEPTS_PATH)['Mushroom'].unique())

In [21]:
model = ConceptBottleneckModel(num_concepts, num_classes).to(device)

### Training

In [22]:
optimizer = optim.AdamW(model.parameters(), lr=0.0005)

In [20]:
print("Starting Joint Training")
train_jointly(model, train_loader, val_loader, device, optimizer, epochs=20, lambda_concept=0.3, lambda_classification=0.7)

Starting Joint Training


  image = torch.load(image_path)
[Joint Training: Epoch 1/20]: 100%|██████████| 52/52 [00:33<00:00,  1.57it/s]



Train Epoch 1/20:
-------------------
Concept Loss:          1.5831 	 Concept Accuracy:         30.27%
Classification Loss:   2.2814 	 Classification Accuracy:  14.02%
---------------------------------------------------------
Total Loss:            2.0719 	 Total Accuracy:            14.02%

Validate Epoch 1/20:
-------------------
Concept Loss:          1.2807 	 Concept Accuracy:         39.72%
Classification Loss:   2.2182 	 Classification Accuracy:  19.15%
---------------------------------------------------------
Total Loss:            1.9369 	 Total Accuracy:            19.15%



[Joint Training: Epoch 2/20]: 100%|██████████| 52/52 [00:35<00:00,  1.47it/s]



Train Epoch 2/20:
-------------------
Concept Loss:          1.1972 	 Concept Accuracy:         48.40%
Classification Loss:   2.0995 	 Classification Accuracy:  33.53%
---------------------------------------------------------
Total Loss:            1.8288 	 Total Accuracy:            33.53%

Validate Epoch 2/20:
-------------------
Concept Loss:          1.0728 	 Concept Accuracy:         57.75%
Classification Loss:   1.9322 	 Classification Accuracy:  49.58%
---------------------------------------------------------
Total Loss:            1.6744 	 Total Accuracy:            49.58%



[Joint Training: Epoch 3/20]: 100%|██████████| 52/52 [00:42<00:00,  1.22it/s]



Train Epoch 3/20:
-------------------
Concept Loss:          1.0627 	 Concept Accuracy:         57.52%
Classification Loss:   1.7785 	 Classification Accuracy:  48.04%
---------------------------------------------------------
Total Loss:            1.5638 	 Total Accuracy:            48.04%

Validate Epoch 3/20:
-------------------
Concept Loss:          0.9520 	 Concept Accuracy:         63.38%
Classification Loss:   1.5948 	 Classification Accuracy:  52.39%
---------------------------------------------------------
Total Loss:            1.4020 	 Total Accuracy:            52.39%



[Joint Training: Epoch 4/20]: 100%|██████████| 52/52 [00:40<00:00,  1.29it/s]



Train Epoch 4/20:
-------------------
Concept Loss:          0.9646 	 Concept Accuracy:         63.93%
Classification Loss:   1.4443 	 Classification Accuracy:  56.86%
---------------------------------------------------------
Total Loss:            1.3004 	 Total Accuracy:            56.86%

Validate Epoch 4/20:
-------------------
Concept Loss:          0.8902 	 Concept Accuracy:         67.32%
Classification Loss:   1.3255 	 Classification Accuracy:  59.44%
---------------------------------------------------------
Total Loss:            1.1949 	 Total Accuracy:            59.44%



[Joint Training: Epoch 5/20]: 100%|██████████| 52/52 [00:43<00:00,  1.19it/s]



Train Epoch 5/20:
-------------------
Concept Loss:          0.9097 	 Concept Accuracy:         65.38%
Classification Loss:   1.2111 	 Classification Accuracy:  64.83%
---------------------------------------------------------
Total Loss:            1.1207 	 Total Accuracy:            64.83%

Validate Epoch 5/20:
-------------------
Concept Loss:          0.8445 	 Concept Accuracy:         67.32%
Classification Loss:   1.2096 	 Classification Accuracy:  61.13%
---------------------------------------------------------
Total Loss:            1.1000 	 Total Accuracy:            61.13%



[Joint Training: Epoch 6/20]: 100%|██████████| 52/52 [00:44<00:00,  1.18it/s]



Train Epoch 6/20:
-------------------
Concept Loss:          0.8736 	 Concept Accuracy:         66.53%
Classification Loss:   1.1038 	 Classification Accuracy:  66.47%
---------------------------------------------------------
Total Loss:            1.0347 	 Total Accuracy:            66.47%

Validate Epoch 6/20:
-------------------
Concept Loss:          0.8350 	 Concept Accuracy:         67.89%
Classification Loss:   1.1089 	 Classification Accuracy:  64.23%
---------------------------------------------------------
Total Loss:            1.0267 	 Total Accuracy:            64.23%



[Joint Training: Epoch 7/20]: 100%|██████████| 52/52 [00:44<00:00,  1.16it/s]



Train Epoch 7/20:
-------------------
Concept Loss:          0.8387 	 Concept Accuracy:         67.98%
Classification Loss:   0.9937 	 Classification Accuracy:  68.88%
---------------------------------------------------------
Total Loss:            0.9472 	 Total Accuracy:            68.88%

Validate Epoch 7/20:
-------------------
Concept Loss:          0.8137 	 Concept Accuracy:         67.04%
Classification Loss:   1.0423 	 Classification Accuracy:  64.51%
---------------------------------------------------------
Total Loss:            0.9737 	 Total Accuracy:            64.51%



[Joint Training: Epoch 8/20]:   6%|▌         | 3/52 [00:07<01:56,  2.39s/it]

In [None]:
print("Evaluating Model")
evaluate_cbm(model, test_loader, device)

##### Intervention

In [1]:
def show_image_and_predicted_concepts(model, image, true_label, concepts, concept_columns, device, transform=None):
    if transform:
        image = transform(image).unsqueeze(0)
    image = image.to(device)
    concepts = torch.tensor(concepts, dtype=torch.float32).to(device)

    pil_image = ToPILImage()(image.cpu().squeeze(0))
    plt.figure(figsize=(8, 8))
    plt.imshow(pil_image)
    plt.title(f"Original Image (True label: {CLASS_NAMES[true_label]})")
    plt.axis('off')
    plt.show()

    with torch.no_grad():
        _, predictions = model(image)
        predicted_class = predictions.argmax(dim=1).cpu().item()
        predicted_concepts = concepts.cpu().numpy()

    print(f"Predicted Concepts (before intervention):")
    for idx, concept in enumerate(concept_columns):
        print(f"{concept}: {predicted_concepts[0][idx]}")

    print(f"Original prediction: {CLASS_NAMES[predicted_class]}")
    return predicted_concepts, predicted_class


def intervene_and_show_change(model, image, true_label, concepts, concept_columns, concept_column, device, transform=None):
    concept_idx = concept_columns.index(concept_column)
    predicted_concepts, predicted_class = show_image_and_predicted_concepts(model, image, true_label, concepts, concept_columns, device, transform)
    modified_concepts = concepts.clone()
    modified_concepts[concept_idx] = 1 - modified_concepts[concept_idx]

    with torch.no_grad():
        modified_concepts = modified_concepts.unsqueeze(0).to(device)
        _, predictions = model(image)
        modified_pred_class = predictions.argmax(dim=1).cpu().item()

    print(f"\nAfter intervening on '{concept_column}':")
    for idx, concept in enumerate(concept_columns):
        print(f"{concept}: {modified_concepts[0][idx].item()}")

    print(f"Modified prediction: {CLASS_NAMES[modified_pred_class]}")

    plt.figure(figsize=(8, 6))
    plt.bar([CLASS_NAMES[predicted_class], CLASS_NAMES[modified_pred_class]], [1, 1], color=['blue', 'orange'])
    plt.title(f"Prediction Change after Intervening on {concept_column}")
    plt.ylabel("Prediction")
    plt.show()


In [None]:
image, true_label, concepts = next(iter(test_loader))
concept_columns = ["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]
random_concept_column = random.choice(concept_columns)

intervene_and_show_change(
    model=model,
    image=image[0],
    true_label=true_label[0].item(),
    concepts=concepts[0],
    concept_columns=concept_columns,
    concept_column=random_concept_column,
    device=device,
    transform=None
)