### Imports

In [26]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage
import random
from pathlib import Path

### Paths

In [27]:
SRC_DIR = Path.cwd()
ROOT_DIR = SRC_DIR.parent
DATA_DIR = os.path.join(ROOT_DIR, 'dataset')
PREPROCESSED_DIR = os.path.join(DATA_DIR, 'preprocessed')
MODEL_DIR = os.path.join(ROOT_DIR, 'models')
CBM_DIR = os.path.join(MODEL_DIR, 'cbm')
os.makedirs(CBM_DIR, exist_ok=True)


In [28]:
CSV_PATH = os.path.join(DATA_DIR, 'csv_mappings', 'train.csv')
CSV_CONCEPTS_PATH = os.path.join(DATA_DIR, 'csv_mappings', 'train_concepts.csv')

# Data Preperation

In [29]:
# Class mapping
CLASS_NAMES = {
    0: "amanita",
    1: "boletus",
    2: "chantelle",
    3: "deterrimus",
    4: "rufus",
    5: "torminosus",
    6: "aurantiacum",
    7: "procera",
    8: "involutus",
    9: "russula"
}

# Concept mapping for each class
CONCEPT_MAPPING = {
    0: ["red", "convex", "scaly", "yes", "thin"],
    1: ["brown", "flat", "smooth", "no", "medium"],
    2: ["yellow", "convex", "warty", "yes", "thick"],
    3: ["white", "bulbous", "smooth", "no", "thin"],
    4: ["brown", "inverted", "scaly", "yes", "medium"],
    5: ["yellow", "flat", "smooth", "no", "thick"],
    6: ["red", "convex", "warty", "yes", "thin"],
    7: ["white", "flat", "smooth", "no", "medium"],
    8: ["brown", "bulbous", "scaly", "yes", "thick"],
    9: ["yellow", "inverted", "smooth", "no", "thin"]
}


In [30]:
df = pd.read_csv(CSV_PATH)

In [31]:
concept_columns = ["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]


In [32]:
concept_rows = []
for _, row in df.iterrows():
    class_label = row["Mushroom"]
    concepts = CONCEPT_MAPPING[class_label]
    concept_rows.append(concepts)

concept_df = pd.DataFrame(concept_rows, columns=concept_columns)
updated_df = pd.concat([df, concept_df], axis=1)

In [33]:
updated_df.to_csv(CSV_CONCEPTS_PATH, index=False)
print(f"Updated CSV saved to {CSV_CONCEPTS_PATH}")

Updated CSV saved to c:\Users\ilian\Documents\Projects\git_projects\university\mushroom_classification\dataset\csv_mappings\train_concepts.csv


# Training

### Classes & Functions

In [34]:
class ConceptBottleneckModel(nn.Module):
    def __init__(self, num_concepts, num_classes):
        super(ConceptBottleneckModel, self).__init__()

        self.feature_extractor = models.resnet18(pretrained=True)

        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        
        self.feature_extractor.fc = nn.Linear(self.feature_extractor.fc.in_features, num_concepts)

        for param in self.feature_extractor.fc.parameters():
            param.requires_grad = True

        self.classifier = nn.Sequential(
            nn.Linear(num_concepts, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

        for param in self.classifier.parameters():
            param.requires_grad = True

    def forward(self, x):
        concepts = self.feature_extractor(x)
        predictions = self.classifier(concepts)
        return concepts, predictions


In [35]:
class ConceptMushroomDataset(Dataset):
    def __init__(self, preprocessed_dir, csv_path, transform=None):
        self.preprocessed_dir = preprocessed_dir
        self.csv_data = pd.read_csv(csv_path)
        self.transform = transform

        self.image_ids = self.csv_data['Image'].values
        self.labels = self.csv_data['Mushroom'].values

        concept_columns = ["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]
        self.concepts = pd.get_dummies(self.csv_data[concept_columns]).values.astype('float32')

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = str(self.image_ids[idx]).zfill(5)
        label = self.labels[idx]
        concepts = torch.tensor(self.concepts[idx], dtype=torch.float32)

        image_path = os.path.join(self.preprocessed_dir, f"{image_id}.pt")
        image = torch.load(image_path)

        if self.transform:
            image = self.transform(image)

        return image, label, concepts


In [36]:

def get_data_loaders(preprocessed_dir, csv_path, batch_size=32):
    dataset = ConceptMushroomDataset(preprocessed_dir, csv_path)

    indices = list(range(len(dataset)))
    train_indices, temp_indices = train_test_split(indices, test_size=0.3, random_state=42)
    val_indices, test_indices = train_test_split(temp_indices, test_size=0.5, random_state=42)

    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)
    test_subset = Subset(dataset, test_indices)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader


In [37]:
def calculate_concept_accuracy(predicted_concepts, ground_truth):
    predicted_classes = torch.argmax(predicted_concepts, dim=1)
    true_classes = torch.argmax(ground_truth, dim=1)
    correct = (predicted_classes == true_classes).sum().item()
    return correct

In [38]:
def train_jointly(model, train_loader, val_loader, device, optimizer, epochs=10, lambda_concept=0.5, lambda_classification=0.5):
    criterion_concept = nn.CrossEntropyLoss()  # Use CrossEntropyLoss for categorical concepts
    criterion_classification = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        train_total_loss = 0.0
        train_concept_loss = 0.0
        train_classification_loss = 0.0
        train_total = 0
        train_correct = 0
        train_concept_correct = 0

        for images, labels, concepts in tqdm(train_loader, desc=f"[Joint Training: Epoch {epoch+1}/{epochs}]"):
            images, labels, concepts = images.to(device), labels.to(device), concepts.to(device)

            optimizer.zero_grad()
            predicted_concepts, predictions = model(images)

            # Calculate concept loss (CrossEntropyLoss expects class indices)
            loss_concept = criterion_concept(predicted_concepts, torch.argmax(concepts, dim=1))  # Use argmax on one-hot encoded concepts
            loss_classification = criterion_classification(predictions, labels)
            loss = lambda_concept * loss_concept + lambda_classification * loss_classification

            loss.backward()
            optimizer.step()

            # Track losses and accuracy
            train_total_loss += loss.item()
            train_concept_loss += loss_concept.item()
            train_classification_loss += loss_classification.item()
            
            train_total += labels.size(0)
            train_correct += (predictions.argmax(dim=1) == labels).sum().item()

            # Calculate concept accuracy
            train_concept_correct += calculate_concept_accuracy(predicted_concepts, concepts)

        # Calculate accuracies
        train_accuracy = 100 * train_correct / train_total
        train_concept_accuracy = 100 * train_concept_correct / train_total

        # Calculate combined total accuracy as the average of concept and classification accuracies
        train_total_accuracy = (train_concept_accuracy + train_accuracy) / 2

        # Print stats for the current epoch
        print(f"\nTrain Epoch {epoch+1}/{epochs}:")
        print("-------------------")
        print(f"Concept Loss:          {train_concept_loss / len(train_loader):.4f} \t Concept Accuracy:         {train_concept_accuracy:.2f}%")
        print(f"Classification Loss:   {train_classification_loss / len(train_loader):.4f} \t Classification Accuracy:  {train_accuracy:.2f}%")
        print("---------------------------------------------------------")
        print(f"Total Loss:            {train_total_loss / len(train_loader):.4f} \t Total Accuracy:            {train_total_accuracy:.2f}%\n")

        # Validation phase
        model.eval()
        val_total_loss = 0.0
        val_concept_loss = 0.0
        val_classification_loss = 0.0
        val_total = 0
        val_correct = 0
        val_concept_correct = 0

        with torch.no_grad():
            for images, labels, concepts in val_loader:
                images, labels, concepts = images.to(device), labels.to(device), concepts.to(device)

                predicted_concepts, predictions = model(images)

                loss_concept = criterion_concept(predicted_concepts, torch.argmax(concepts, dim=1))
                loss_classification = criterion_classification(predictions, labels)
                loss = lambda_concept * loss_concept + lambda_classification * loss_classification

                val_total_loss += loss.item()
                val_concept_loss += loss_concept.item()
                val_classification_loss += loss_classification.item()
                val_total += labels.size(0)
                val_correct += (predictions.argmax(dim=1) == labels).sum().item()

                # Concept accuracy calculation
                val_concept_correct += calculate_concept_accuracy(predicted_concepts, concepts)

        # Calculate accuracies for validation
        val_accuracy = 100 * val_correct / val_total
        val_concept_accuracy = 100 * val_concept_correct / val_total

        # Calculate combined total accuracy for validation
        val_total_accuracy = (val_concept_accuracy + val_accuracy) / 2

        # Print stats for the validation set
        print(f"Validate Epoch {epoch+1}/{epochs}:")
        print("-------------------")
        print(f"Concept Loss:          {val_concept_loss / len(val_loader):.4f} \t Concept Accuracy:         {val_concept_accuracy:.2f}%")
        print(f"Classification Loss:   {val_classification_loss / len(val_loader):.4f} \t Classification Accuracy:  {val_accuracy:.2f}%")
        print("---------------------------------------------------------")
        print(f"Total Loss:            {val_total_loss / len(val_loader):.4f} \t Total Accuracy:            {val_total_accuracy:.2f}%\n")


In [39]:
def train_separately_concept(model, train_loader, val_loader, device, optimizer_concept, epochs=10, patience=3):
    criterion_concept = nn.CrossEntropyLoss()
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        train_total_loss = 0.0
        train_concept_correct = 0
        train_total = 0

        for images, labels, concepts in tqdm(train_loader, desc=f"\n[Concept Training: Epoch {epoch+1}/{epochs}]"):
            images, concepts = images.to(device), concepts.to(device)

            optimizer_concept.zero_grad()
            predicted_concepts, _ = model(images)

            # Compute concept loss
            loss_concept = criterion_concept(predicted_concepts, torch.argmax(concepts, dim=1))
            loss_concept.backward()
            optimizer_concept.step()

            train_total_loss += loss_concept.item()
            train_concept_correct += (predicted_concepts.argmax(dim=1) == torch.argmax(concepts, dim=1)).sum().item()
            train_total += concepts.size(0)

        train_concept_accuracy = 100 * train_concept_correct / train_total

        print(f"Train Concept Epoch {epoch+1}/{epochs}:")
        print(f"Concept Loss: {train_total_loss / len(train_loader):.4f}, Accuracy: {train_concept_accuracy:.2f}%")

        # Validation phase
        model.eval()
        val_total_loss = 0.0
        val_concept_correct = 0
        val_total = 0

        with torch.no_grad():
            for images, labels, concepts in val_loader:
                images, concepts = images.to(device), concepts.to(device)

                predicted_concepts, _ = model(images)

                loss_concept = criterion_concept(predicted_concepts, torch.argmax(concepts, dim=1))
                val_total_loss += loss_concept.item()
                val_concept_correct += (predicted_concepts.argmax(dim=1) == torch.argmax(concepts, dim=1)).sum().item()
                val_total += concepts.size(0)

        val_concept_accuracy = 100 * val_concept_correct / val_total

        print(f"Validation Concept Loss: {val_total_loss / len(val_loader):.4f}, Accuracy: {val_concept_accuracy:.2f}%")

        # Early stopping check
        if val_total_loss < best_val_loss:
            best_val_loss = val_total_loss
            patience_counter = 0
            torch.save(model.state_dict(), "best_concept_model.pth")
            print("Best model saved.")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break


In [None]:
import torch.nn as nn
import torch
from tqdm import tqdm

def train_separately_classification(model, train_loader, val_loader, device, optimizer_classification, epochs=10, patience=3):
    criterion_classification = nn.CrossEntropyLoss()
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        train_total_loss = 0.0
        train_correct = 0
        train_total = 0

        for images, labels, concepts in tqdm(train_loader, desc=f"[Classification Training: Epoch {epoch+1}/{epochs}]"):
            images, labels = images.to(device), labels.to(device)
            
            optimizer_classification.zero_grad()
            _, predictions = model(images)
            
            loss_classification = criterion_classification(predictions, labels)
            loss_classification.backward()
            optimizer_classification.step()
            
            train_total_loss += loss_classification.item()
            train_correct += (predictions.argmax(dim=1) == labels).sum().item()
            train_total += labels.size(0)
        
        train_accuracy = 100 * train_correct / train_total
        print(f"\nTrain Classification Epoch {epoch+1}/{epochs}:")
        print(f"Classification Loss: {train_total_loss / len(train_loader):.4f} \t Classification Accuracy: {train_accuracy:.2f}%")

        # Validation phase
        model.eval()
        val_total_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for images, labels, concepts in val_loader:
                images, labels = images.to(device), labels.to(device)
                _, predictions = model(images)
                loss_classification = criterion_classification(predictions, labels)
                val_total_loss += loss_classification.item()
                val_correct += (predictions.argmax(dim=1) == labels).sum().item()
                val_total += labels.size(0)

        val_loss = val_total_loss / len(val_loader)
        val_accuracy = 100 * val_correct / val_total
        print(f"Validation Classification Loss: {val_loss:.4f} \t Validation Accuracy: {val_accuracy:.2f}%")

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            print("Validation loss improved, saving model...")
            torch.save(model.state_dict(), 'best_classification_model.pth')
        else:
            patience_counter += 1
            print(f"No improvement in validation loss. Patience counter: {patience_counter}/{patience}")

        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

In [41]:
def evaluate_model(model, test_loader, device):
    criterion_concept = nn.CrossEntropyLoss()
    criterion_classification = nn.CrossEntropyLoss()

    total_samples = 0
    correct_classifications = 0
    correct_concepts = 0
    total_concept_loss = 0.0
    total_classification_loss = 0.0

    model.eval()
    with torch.no_grad():
        for images, labels, concepts in test_loader:
            images, labels, concepts = images.to(device), labels.to(device), concepts.to(device)

            predicted_concepts, predictions = model(images)

            # Calculate concept and classification loss
            loss_concept = criterion_concept(predicted_concepts, torch.argmax(concepts, dim=1))
            loss_classification = criterion_classification(predictions, labels)

            total_concept_loss += loss_concept.item()
            total_classification_loss += loss_classification.item()

            # Compute accuracy
            total_samples += labels.size(0)
            correct_classifications += (predictions.argmax(dim=1) == labels).sum().item()
            correct_concepts += (predicted_concepts.argmax(dim=1) == torch.argmax(concepts, dim=1)).sum().item()

    concept_accuracy = 100 * correct_concepts / total_samples
    classification_accuracy = 100 * correct_classifications / total_samples
    avg_concept_loss = total_concept_loss / len(test_loader)
    avg_classification_loss = total_classification_loss / len(test_loader)

    print("\nEvaluation Results:")
    print("-------------------")
    print(f"Concept Loss:          {avg_concept_loss:.4f} \t Concept Accuracy:         {concept_accuracy:.2f}%")
    print(f"Classification Loss:   {avg_classification_loss:.4f} \t Classification Accuracy:  {classification_accuracy:.2f}%")
    print("---------------------------------------------------------")

    return concept_accuracy, classification_accuracy


### Config

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [18]:
train_loader, val_loader, test_loader = get_data_loaders(PREPROCESSED_DIR, CSV_CONCEPTS_PATH)

In [19]:
num_concepts = len(pd.get_dummies(pd.read_csv(CSV_CONCEPTS_PATH)[["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]]).columns)
num_classes = len(pd.read_csv(CSV_CONCEPTS_PATH)['Mushroom'].unique())

In [20]:
model = ConceptBottleneckModel(num_concepts, num_classes).to(device)



### Joint training

In [21]:
optimizer = optim.AdamW(model.parameters(), lr=0.0005)

In [22]:
print("Starting Joint Training")
train_jointly(model, train_loader, val_loader, device, optimizer, epochs=20, lambda_concept=0.3, lambda_classification=0.7)

Starting Joint Training


  image = torch.load(image_path)
[Joint Training: Epoch 1/20]:  21%|██        | 11/52 [00:14<00:53,  1.32s/it]


KeyboardInterrupt: 

### Seperate training

In [23]:
optimizer_concept = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer_classification = torch.optim.Adam(model.parameters(), lr=0.001)

In [24]:
train_separately_concept(
    model=model, 
    train_loader=train_loader, 
    val_loader=val_loader, 
    device=device, 
    optimizer_concept=optimizer_concept, 
    epochs=10, 
    patience=3
)

  image = torch.load(image_path)
[Concept Training: Epoch 1/10]: 100%|██████████| 52/52 [00:39<00:00,  1.30it/s]



Train Concept Epoch 1/10:
Concept Loss: 1.2294, Accuracy: 46.22%
Validation Concept Loss: 1.0296, Accuracy: 58.03%
Best model saved.


[Concept Training: Epoch 2/10]: 100%|██████████| 52/52 [00:40<00:00,  1.28it/s]



Train Concept Epoch 2/10:
Concept Loss: 0.9876, Accuracy: 59.09%
Validation Concept Loss: 0.9234, Accuracy: 63.38%
Best model saved.


[Concept Training: Epoch 3/10]: 100%|██████████| 52/52 [00:57<00:00,  1.10s/it]



Train Concept Epoch 3/10:
Concept Loss: 0.8790, Accuracy: 65.92%
Validation Concept Loss: 0.8366, Accuracy: 68.73%
Best model saved.


[Concept Training: Epoch 4/10]: 100%|██████████| 52/52 [00:41<00:00,  1.26it/s]



Train Concept Epoch 4/10:
Concept Loss: 0.8139, Accuracy: 67.43%
Validation Concept Loss: 0.7790, Accuracy: 69.30%
Best model saved.


[Concept Training: Epoch 5/10]: 100%|██████████| 52/52 [00:41<00:00,  1.27it/s]



Train Concept Epoch 5/10:
Concept Loss: 0.7599, Accuracy: 70.94%
Validation Concept Loss: 0.7567, Accuracy: 73.24%
Best model saved.


[Concept Training: Epoch 6/10]: 100%|██████████| 52/52 [01:07<00:00,  1.29s/it]



Train Concept Epoch 6/10:
Concept Loss: 0.7545, Accuracy: 71.24%
Validation Concept Loss: 0.7600, Accuracy: 68.17%


[Concept Training: Epoch 7/10]: 100%|██████████| 52/52 [00:41<00:00,  1.26it/s]



Train Concept Epoch 7/10:
Concept Loss: 0.7208, Accuracy: 72.51%
Validation Concept Loss: 0.7998, Accuracy: 67.32%


[Concept Training: Epoch 8/10]: 100%|██████████| 52/52 [00:41<00:00,  1.25it/s]



Train Concept Epoch 8/10:
Concept Loss: 0.7150, Accuracy: 72.27%
Validation Concept Loss: 0.7247, Accuracy: 70.70%
Best model saved.


[Concept Training: Epoch 9/10]: 100%|██████████| 52/52 [00:42<00:00,  1.22it/s]



Train Concept Epoch 9/10:
Concept Loss: 0.7036, Accuracy: 71.72%
Validation Concept Loss: 0.7185, Accuracy: 73.52%
Best model saved.


[Concept Training: Epoch 10/10]: 100%|██████████| 52/52 [01:36<00:00,  1.85s/it]



Train Concept Epoch 10/10:
Concept Loss: 0.6643, Accuracy: 74.74%
Validation Concept Loss: 0.7313, Accuracy: 71.83%


In [44]:
train_separately_classification(
    model=model, 
    train_loader=train_loader, 
    val_loader=val_loader, 
    device=device, 
    optimizer_classification=optimizer_classification, 
    epochs=10, 
    patience=3
)

  image = torch.load(image_path)
[Classification Training: Epoch 1/10]:  31%|███       | 16/52 [00:05<00:12,  2.83it/s]

### Intervention

In [1]:
def show_image_and_predicted_concepts(model, image, true_label, concepts, concept_columns, device, transform=None):
    if transform:
        image = transform(image).unsqueeze(0)  
    image = image.to(device)
    concepts = torch.tensor(concepts, dtype=torch.float32).to(device).unsqueeze(0)  

    pil_image = ToPILImage()(image.cpu().squeeze(0))
    plt.figure(figsize=(8, 8))
    plt.imshow(pil_image)
    plt.title(f"Original Image (True label: {CLASS_NAMES[true_label]})")
    plt.axis('off')
    plt.show()

    with torch.no_grad():
        predicted_concepts, predictions = model(image)  
        predicted_class = predictions.argmax(dim=1).cpu().item()

    predicted_concepts = predicted_concepts.cpu().numpy()

    print(f"Predicted Concepts (before intervention):")
    for idx, concept in enumerate(concept_columns):
        print(f"{concept}: {predicted_concepts[0][idx]:.4f}")

    print(f"Original prediction: {CLASS_NAMES[predicted_class]}")
    return predicted_concepts, predicted_class


In [None]:
def intervene_and_show_change(model, image, true_label, concepts, concept_columns, concept_column, device, transform=None):
    concept_idx = concept_columns.index(concept_column)
    
    # Get initial concepts and prediction
    predicted_concepts, predicted_class = show_image_and_predicted_concepts(model, image, true_label, concepts, concept_columns, device, transform)
    
    # Modify concept
    modified_concepts = torch.tensor(predicted_concepts, dtype=torch.float32).to(device)
    modified_concepts[0][concept_idx] = 1 - modified_concepts[0][concept_idx]  
    
    # Pass modified concepts to classifier
    with torch.no_grad():
        modified_predictions = model.classifier(modified_concepts)
        modified_pred_class = modified_predictions.argmax(dim=1).cpu().item()

    print(f"\nAfter intervening on '{concept_column}':")
    for idx, concept in enumerate(concept_columns):
        print(f"{concept}: {modified_concepts[0][idx].item():.4f}")

    print(f"Modified prediction: {CLASS_NAMES[modified_pred_class]}")

    # Visual prediction change
    plt.figure(figsize=(8, 6))
    plt.bar(
        [CLASS_NAMES[predicted_class], CLASS_NAMES[modified_pred_class]],
        [1, 1],
        color=['blue', 'orange']
    )
    plt.title(f"Prediction Change after Intervening on '{concept_column}'")
    plt.ylabel("Prediction")
    plt.show()

In [None]:
image, true_label, concepts = next(iter(test_loader))
concept_columns = ["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]
random_concept_column = random.choice(concept_columns)

intervene_and_show_change(
    model=model,
    image=image[0], 
    true_label=true_label[0].item(),  
    concepts=concepts[0], 
    concept_columns=concept_columns,
    concept_column=random_concept_column,
    device=device,
    transform=None
)
