### Imports

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
import pandas as pd
from tqdm import tqdm
from pathlib import Path

### Paths

In [2]:
# Every project location is derived from the notebook's working directory.
SRC_DIR = Path.cwd()
ROOT_DIR = SRC_DIR.parent

# Directory constants are kept as plain strings so later os.path.join calls
# keep returning strings.
DATA_DIR = str(ROOT_DIR / 'dataset')
PREPROCESSED_DIR = str(Path(DATA_DIR) / 'preprocessed')
MODEL_DIR = str(ROOT_DIR / 'models')
CBM_DIR = str(Path(MODEL_DIR) / 'cbm')

# Create the output directory (and any missing parents) up front.
Path(CBM_DIR).mkdir(parents=True, exist_ok=True)


In [3]:
# Both label tables live in <DATA_DIR>/csv_mappings: the input table and the
# concept-augmented copy this notebook writes.
CSV_PATH, CSV_CONCEPTS_PATH = (
    os.path.join(DATA_DIR, 'csv_mappings', file_name)
    for file_name in ('train.csv', 'train_concepts.csv')
)

# Data Preparation

In [4]:
# Class id -> class name; the id is the position in this sequence.
CLASS_NAMES = dict(enumerate((
    "amanita",      # 0
    "boletus",      # 1
    "chantelle",    # 2
    "deterrimus",   # 3
    "rufus",        # 4
    "torminosus",   # 5
    "aurantiacum",  # 6
    "procera",      # 7
    "involutus",    # 8
    "russula",      # 9
)))

# Class id -> concept values, ordered as
# (cap_color, cap_shape, cap_texture, ring_present, stem_thickness).
CONCEPT_MAPPING = dict(enumerate((
    ["red", "convex", "scaly", "yes", "thin"],         # 0
    ["brown", "flat", "smooth", "no", "medium"],       # 1
    ["yellow", "convex", "warty", "yes", "thick"],     # 2
    ["white", "bulbous", "smooth", "no", "thin"],      # 3
    ["brown", "inverted", "scaly", "yes", "medium"],   # 4
    ["yellow", "flat", "smooth", "no", "thick"],       # 5
    ["red", "convex", "warty", "yes", "thin"],         # 6
    ["white", "flat", "smooth", "no", "medium"],       # 7
    ["brown", "bulbous", "scaly", "yes", "thick"],     # 8
    ["yellow", "inverted", "smooth", "no", "thin"],    # 9
)))


In [5]:
# Load the training label table; its "Mushroom" column is read below as the class id.
df = pd.read_csv(CSV_PATH)

In [6]:
# Names of the five concept attributes, in the same order as the value lists in CONCEPT_MAPPING.
concept_columns = ["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]


In [7]:
# Look up the concept values for each row's class id; an id that is missing
# from CONCEPT_MAPPING raises KeyError.
concept_rows = [CONCEPT_MAPPING[class_id] for class_id in df["Mushroom"]]

# One column per concept attribute, appended to the right of the label table.
concept_df = pd.DataFrame(concept_rows, columns=concept_columns)
updated_df = pd.concat([df, concept_df], axis=1)

In [8]:
# Persist the label table extended with the concept columns; index=False keeps
# the DataFrame row index out of the file.
updated_df.to_csv(CSV_CONCEPTS_PATH, index=False)
print(f"Updated CSV saved to {CSV_CONCEPTS_PATH}")

Updated CSV saved to c:\Users\ilian\Documents\Projects\git_projects\university\mushroom_classification\dataset\csv_mappings\train_concepts.csv


# Training

### Classes & Functions

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

class ConceptBottleneckModel(nn.Module):
    """Image classifier whose class scores are computed only from a concept vector.

    A ResNet-18 backbone has its final fully connected layer replaced so that
    it emits ``num_concepts`` values; a small MLP maps those values to
    ``num_classes`` class scores.

    Args:
        num_concepts: Width of the concept layer.
        num_classes: Number of output classes.
    """

    def __init__(self, num_concepts, num_classes):
        super().__init__()

        # Backbone: ResNet-18 with its last layer resized to the concept width.
        backbone = models.resnet18(pretrained=True)
        backbone.fc = nn.Linear(backbone.fc.in_features, num_concepts)
        self.feature_extractor = backbone

        # Head: concepts -> 128 hidden units -> class scores.
        head_layers = [
            nn.Linear(num_concepts, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return ``(concepts, predictions)`` for the input batch ``x``."""
        concept_values = self.feature_extractor(x)
        return concept_values, self.classifier(concept_values)


class ConceptMushroomDataset(Dataset):
    """Preprocessed image tensors paired with a class label and a one-hot concept vector.

    Each CSV row supplies an image id (``Image``), a class id (``Mushroom``)
    and five categorical concept columns. The concept columns are one-hot
    encoded with ``pandas.get_dummies``.

    Args:
        preprocessed_dir: Directory holding one ``<image id, zero-padded to 5 digits>.pt``
            file per CSV row.
        csv_path: Path of the CSV with the columns described above.
        transform: Optional callable applied to every loaded image.

    Attributes:
        concepts: ``float32`` array with one one-hot encoded row per sample.
        concept_names: Names of the one-hot columns, in the same order as the
            entries of every concept vector (for example ``"cap_color_red"``).
    """

    def __init__(self, preprocessed_dir, csv_path, transform=None):
        self.preprocessed_dir = preprocessed_dir
        self.csv_data = pd.read_csv(csv_path)
        self.transform = transform

        self.image_ids = self.csv_data['Image'].values
        self.labels = self.csv_data['Mushroom'].values

        concept_columns = ["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]
        one_hot = pd.get_dummies(self.csv_data[concept_columns])
        # Keep the column names so callers can label individual concept entries.
        self.concept_names = one_hot.columns.tolist()
        self.concepts = one_hot.values.astype('float32')

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(image, label, concepts)`` for row ``idx``."""
        image_id = str(self.image_ids[idx]).zfill(5)
        label = self.labels[idx]
        concepts = torch.tensor(self.concepts[idx], dtype=torch.float32)

        image_path = os.path.join(self.preprocessed_dir, f"{image_id}.pt")
        # weights_only=True makes the restricted unpickler explicit instead of
        # relying on the version-dependent default (which triggers a
        # FutureWarning on some PyTorch releases).
        # NOTE(review): assumes the .pt files hold plain tensors - confirm.
        image = torch.load(image_path, weights_only=True)

        if self.transform:
            image = self.transform(image)

        return image, label, concepts


def get_data_loaders(preprocessed_dir, csv_path, batch_size=32):
    """Build train/validation/test loaders over one ``ConceptMushroomDataset``.

    The sample indices are split 70% / 15% / 15% with a fixed
    ``random_state`` of 42. Only the training loader shuffles.

    Args:
        preprocessed_dir: Directory with the preprocessed ``.pt`` images.
        csv_path: CSV with labels and concept columns.
        batch_size: Batch size used by all three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    full_dataset = ConceptMushroomDataset(preprocessed_dir, csv_path)

    # First carve off 30% as hold-out, then halve the hold-out into val and test.
    every_index = list(range(len(full_dataset)))
    train_idx, holdout_idx = train_test_split(every_index, test_size=0.3, random_state=42)
    val_idx, test_idx = train_test_split(holdout_idx, test_size=0.5, random_state=42)

    def build_loader(sample_indices, shuffle):
        return DataLoader(Subset(full_dataset, sample_indices), batch_size=batch_size, shuffle=shuffle)

    return (
        build_loader(train_idx, True),
        build_loader(val_idx, False),
        build_loader(test_idx, False),
    )


def train_jointly(model, train_loader, val_loader, device, epochs=10, lambda_concept=0.5, lambda_classification=0.5):
    """Train the concept layer and the classifier together with one weighted loss.

    The per-batch loss is
    ``lambda_concept * MSE(predicted concepts, concepts)
    + lambda_classification * CrossEntropy(predictions, labels)``,
    optimised with Adam (learning rate 0.001). After every epoch the mean
    batch loss and the accuracy on both loaders are printed.

    Args:
        model: Callable module returning ``(concepts, predictions)``.
        train_loader: Iterable of ``(images, labels, concepts)`` training batches.
        val_loader: Iterable of ``(images, labels, concepts)`` validation batches.
        device: Device the batches are moved to.
        epochs: Number of passes over ``train_loader``.
        lambda_concept: Weight of the concept loss.
        lambda_classification: Weight of the classification loss.
    """
    concept_loss_fn = nn.MSELoss()
    class_loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    def run_batches(batches, training):
        """Process every batch once; return (summed batch loss, accuracy in percent)."""
        loss_sum = 0.0
        seen = 0
        hits = 0
        for images, labels, concepts in batches:
            images, labels, concepts = images.to(device), labels.to(device), concepts.to(device)

            if training:
                optimizer.zero_grad()
            predicted_concepts, predictions = model(images)

            concept_term = concept_loss_fn(predicted_concepts, concepts)
            class_term = class_loss_fn(predictions, labels)
            combined = lambda_concept * concept_term + lambda_classification * class_term

            if training:
                combined.backward()
                optimizer.step()

            loss_sum += combined.item()
            seen += labels.size(0)
            hits += (predictions.argmax(dim=1) == labels).sum().item()
        return loss_sum, 100 * hits / seen

    for epoch in range(epochs):
        # Training pass (with progress bar).
        model.train()
        progress = tqdm(train_loader, desc=f"[Joint Training: Epoch {epoch+1}/{epochs}]")
        train_loss, train_accuracy = run_batches(progress, training=True)

        # Validation pass: no gradient tracking, no optimizer steps.
        model.eval()
        with torch.no_grad():
            val_loss, val_accuracy = run_batches(val_loader, training=False)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"  Train Loss: {train_loss / len(train_loader):.4f} | Train Accuracy: {train_accuracy:.2f}%")
        print(f"  Val Loss: {val_loss / len(val_loader):.4f} | Val Accuracy: {val_accuracy:.2f}%")


def evaluate_cbm(model, test_loader, device):
    """Print the classification accuracy of ``model`` over ``test_loader``.

    Args:
        model: Callable module returning ``(concepts, predictions)``.
        test_loader: Iterable of ``(images, labels, concepts)`` batches; the
            concepts are ignored.
        device: Device the images and labels are moved to.
    """
    model.eval()
    hits = 0
    seen = 0

    with torch.no_grad():
        for images, labels, _ in tqdm(test_loader, desc="[Evaluate]"):
            images = images.to(device)
            labels = labels.to(device)

            _, class_scores = model(images)
            # The predicted class is the index of the largest score in each row.
            predicted_labels = torch.max(class_scores, 1).indices

            hits += (predicted_labels == labels).sum().item()
            seen += labels.size(0)

    accuracy = 100 * hits / seen
    print(f"Accuracy: {accuracy:.2f}%")


### Config

In [10]:
# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader, val_loader, test_loader = get_data_loaders(PREPROCESSED_DIR, CSV_CONCEPTS_PATH)

# Read the concept CSV once and derive both model dimensions from it:
# one concept per one-hot column, one class per distinct "Mushroom" value.
concept_table = pd.read_csv(CSV_CONCEPTS_PATH)
num_concepts = pd.get_dummies(
    concept_table[["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]]
).shape[1]
num_classes = len(concept_table['Mushroom'].unique())

model = ConceptBottleneckModel(num_concepts, num_classes).to(device)



### Training

In [None]:
# Train the concept layer and the classifier together for 20 epochs, weighting
# the concept loss and the classification loss equally.
print("Starting Joint Training")
train_jointly(model, train_loader, val_loader, device, epochs=20, lambda_concept=0.5, lambda_classification=0.5)

Starting Joint Training


  image = torch.load(image_path)
[Joint Training: Epoch 1/20]:  12%|█▏        | 6/52 [00:10<01:26,  1.89s/it]

In [None]:
# Print the classification accuracy on the test loader.
print("Evaluating Model")
evaluate_cbm(model, test_loader, device)

##### Intervention

In [13]:
import torch
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage
import random

def show_image_and_predicted_concepts(model, image, true_label, concepts, concept_columns, device, transform=None):
    """Display an image and print the model's predicted concepts and class.

    Args:
        model: Callable module returning ``(concepts, predictions)``.
        image: Image tensor, either unbatched (3-D) or with a leading batch
            dimension of 1. If ``transform`` is given it is applied first.
        true_label: Class id of the image; must be a key of ``CLASS_NAMES``.
        concepts: Annotated concept vector of the image. Accepted for interface
            compatibility; the printed and returned values come from the model.
        concept_columns: One name per entry of the model's concept output, in
            the same order.
        device: Device the image is moved to before the forward pass.
        transform: Optional callable applied to ``image``.

    Returns:
        Tuple ``(predicted_concepts, predicted_class)``: the model's concept
        output as a NumPy array of shape ``(1, num_concepts)`` and the
        predicted class id.

    Raises:
        ValueError: If ``concept_columns`` does not have one name per concept.
    """
    if transform:
        image = transform(image)
    # The model is fed a batch: add the batch dimension when a single
    # unbatched image was passed.
    if image.dim() == 3:
        image = image.unsqueeze(0)
    image = image.to(device)

    pil_image = ToPILImage()(image.cpu().squeeze(0))
    plt.figure(figsize=(8, 8))
    plt.imshow(pil_image)
    plt.title(f"Original Image (True label: {CLASS_NAMES[true_label]})")
    plt.axis('off')
    plt.show()

    # Inference only: switch to eval mode and skip gradient tracking.
    model.eval()
    with torch.no_grad():
        concept_outputs, predictions = model(image)
    predicted_class = predictions.argmax(dim=1).cpu().item()
    # Report what the model predicted, not the annotated concepts.
    predicted_concepts = concept_outputs.cpu().numpy()

    if len(concept_columns) != predicted_concepts.shape[1]:
        raise ValueError(
            f"concept_columns has {len(concept_columns)} names but the model outputs "
            f"{predicted_concepts.shape[1]} concepts; pass one name per concept entry "
            f"(for one-hot concepts, the pandas.get_dummies column names)."
        )

    print(f"Predicted Concepts (before intervention):")
    for idx, concept in enumerate(concept_columns):
        print(f"{concept}: {predicted_concepts[0][idx]}")

    print(f"Original prediction: {CLASS_NAMES[predicted_class]}")
    return predicted_concepts, predicted_class


def intervene_and_show_change(model, image, true_label, concepts, concept_columns, concept_column, device, transform=None):
    """Flip one concept entry and show how the classifier's prediction changes.

    The original prediction is obtained by running the full model on the
    image. The intervention starts from the supplied ``concepts`` vector,
    replaces the selected entry ``v`` by ``1 - v`` and passes the result
    through ``model.classifier`` only, so the change in prediction is caused by
    the edited concept rather than by re-running the image through the network.

    Args:
        model: Model returning ``(concepts, predictions)`` and exposing a
            ``classifier`` sub-module that maps concepts to class scores.
        image: Image tensor, either unbatched (3-D) or with a batch dimension of 1.
        true_label: Class id of the image; must be a key of ``CLASS_NAMES``.
        concepts: Concept vector of the image, 1-D or shaped ``(1, num_concepts)``.
        concept_columns: One name per entry of ``concepts``, in the same order.
        concept_column: Name of the entry to flip; must be in ``concept_columns``.
        device: Device used for the forward passes.
        transform: Optional callable applied to ``image`` before the original
            prediction.

    Raises:
        ValueError: If ``concept_column`` is not in ``concept_columns`` or the
            number of names differs from the length of the concept vector.
    """
    concept_idx = concept_columns.index(concept_column)

    # Work on a (1, num_concepts) batch regardless of how the vector was passed.
    concepts = torch.as_tensor(concepts, dtype=torch.float32)
    if concepts.dim() == 1:
        concepts = concepts.unsqueeze(0)
    if len(concept_columns) != concepts.shape[1]:
        raise ValueError(
            f"concept_columns has {len(concept_columns)} names but the concept vector has "
            f"{concepts.shape[1]} entries; pass one name per entry "
            f"(for one-hot concepts, the pandas.get_dummies column names)."
        )

    # Without a transform the image goes to the model as-is, so it must
    # already carry a batch dimension.
    if transform is None and image.dim() == 3:
        image = image.unsqueeze(0)

    _, predicted_class = show_image_and_predicted_concepts(model, image, true_label, concepts, concept_columns, device, transform)

    # Flip the selected entry on a copy so the caller's tensor is left untouched.
    modified_concepts = concepts.clone().to(device)
    modified_concepts[:, concept_idx] = 1 - modified_concepts[:, concept_idx]

    with torch.no_grad():
        # Only the classifier is evaluated: feeding the image again would
        # discard the edited concepts and reproduce the original prediction.
        predictions = model.classifier(modified_concepts)
        modified_pred_class = predictions.argmax(dim=1).cpu().item()

    print(f"\nAfter intervening on '{concept_column}':")
    for idx, concept in enumerate(concept_columns):
        print(f"{concept}: {modified_concepts[0][idx].item()}")

    print(f"Modified prediction: {CLASS_NAMES[modified_pred_class]}")

    plt.figure(figsize=(8, 6))
    plt.bar([CLASS_NAMES[predicted_class], CLASS_NAMES[modified_pred_class]], [1, 1], color=['blue', 'orange'])
    plt.title(f"Prediction Change after Intervening on {concept_column}")
    plt.ylabel("Prediction")
    plt.show()


In [None]:
# Take the first batch of the test loader; its first sample is used below.
image, true_label, concepts = next(iter(test_loader))

concept_columns = ["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]

# The concept vectors are one-hot encoded with pd.get_dummies, so each entry is
# named by a get_dummies column (attribute_value), not by one of the five raw
# attribute names. Rebuild those names the same way the vectors were built.
one_hot_concept_columns = pd.get_dummies(pd.read_csv(CSV_CONCEPTS_PATH)[concept_columns]).columns.tolist()

random_concept_column = random.choice(one_hot_concept_columns)

intervene_and_show_change(
    model=model,
    image=image[0],
    true_label=true_label[0].item(),
    concepts=concepts[0],
    concept_columns=one_hot_concept_columns,
    concept_column=random_concept_column,
    device=device,
    transform=None
)