### Imports

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
import pandas as pd
from tqdm import tqdm
from pathlib import Path

### Paths

In [2]:
SRC_DIR = Path.cwd()
ROOT_DIR = SRC_DIR.parent
DATA_DIR = os.path.join(ROOT_DIR, 'dataset')
PREPROCESSED_DIR = os.path.join(DATA_DIR, 'preprocessed')
MODEL_DIR = os.path.join(ROOT_DIR, 'models')
CBM_DIR = os.path.join(MODEL_DIR, 'cbm')
os.makedirs(CBM_DIR, exist_ok=True)


In [3]:
CSV_PATH = os.path.join(DATA_DIR, 'csv_mappings', 'train.csv')
CSV_CONCEPTS_PATH = os.path.join(DATA_DIR, 'csv_mappings', 'train_concepts.csv')

# Data Preperation

In [4]:
# Class mapping
CLASS_NAMES = {
    0: "amanita",
    1: "boletus",
    2: "chantelle",
    3: "deterrimus",
    4: "rufus",
    5: "torminosus",
    6: "aurantiacum",
    7: "procera",
    8: "involutus",
    9: "russula"
}

# Concept mapping for each class
CONCEPT_MAPPING = {
    0: ["red", "convex", "scaly", "yes", "thin"],
    1: ["brown", "flat", "smooth", "no", "medium"],
    2: ["yellow", "convex", "warty", "yes", "thick"],
    3: ["white", "bulbous", "smooth", "no", "thin"],
    4: ["brown", "inverted", "scaly", "yes", "medium"],
    5: ["yellow", "flat", "smooth", "no", "thick"],
    6: ["red", "convex", "warty", "yes", "thin"],
    7: ["white", "flat", "smooth", "no", "medium"],
    8: ["brown", "bulbous", "scaly", "yes", "thick"],
    9: ["yellow", "inverted", "smooth", "no", "thin"]
}


In [5]:
df = pd.read_csv(CSV_PATH)

In [6]:
concept_columns = ["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]


In [7]:
concept_rows = []
for _, row in df.iterrows():
    class_label = row["Mushroom"]
    concepts = CONCEPT_MAPPING[class_label]
    concept_rows.append(concepts)

concept_df = pd.DataFrame(concept_rows, columns=concept_columns)
updated_df = pd.concat([df, concept_df], axis=1)

In [8]:
updated_df.to_csv(CSV_CONCEPTS_PATH, index=False)
print(f"Updated CSV saved to {CSV_CONCEPTS_PATH}")

Updated CSV saved to c:\Users\ilian\Documents\Projects\git_projects\university\mushroom_classification\dataset\csv_mappings\train_concepts.csv


# Training

### Classes & Functions

In [9]:
class ConceptMushroomDataset(Dataset):
    def __init__(self, preprocessed_dir, csv_path, transform=None):
        self.preprocessed_dir = preprocessed_dir
        self.csv_data = pd.read_csv(csv_path)
        self.transform = transform

        # Images and Labels
        self.image_ids = self.csv_data['Image'].values
        self.labels = self.csv_data['Mushroom'].values

        # Parse concepts
        concept_columns = ["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]
        self.concepts = pd.get_dummies(self.csv_data[concept_columns]).values.astype('float32')

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = str(self.image_ids[idx]).zfill(5)
        label = self.labels[idx]
        concepts = torch.tensor(self.concepts[idx], dtype=torch.float32)

        # Load image
        image_path = os.path.join(self.preprocessed_dir, f"{image_id}.pt")
        image = torch.load(image_path)

        if self.transform:
            image = self.transform(image)

        return image, label, concepts


In [10]:
class ConceptBottleneckModel(nn.Module):
    def __init__(self, num_concepts, num_classes):
        super(ConceptBottleneckModel, self).__init__()

        # Feature extractor (ResNet)
        self.feature_extractor = models.resnet50(pretrained=True)
        self.feature_extractor.fc = nn.Linear(self.feature_extractor.fc.in_features, num_concepts)

        # Bottleneck to Classifier
        self.classifier = nn.Sequential(
            nn.Linear(num_concepts, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        concepts = self.feature_extractor(x)
        predictions = self.classifier(concepts)
        return concepts, predictions

In [11]:
# Data Preparation
def get_data_loaders(preprocessed_dir, csv_path, batch_size=32):
    # Init
    dataset = ConceptMushroomDataset(preprocessed_dir, csv_path)

    # Split 
    indices = list(range(len(dataset)))
    train_indices, temp_indices = train_test_split(indices, test_size=0.3, random_state=42)
    val_indices, test_indices = train_test_split(temp_indices, test_size=0.5, random_state=42)

    # Subsets
    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)
    test_subset = Subset(dataset, test_indices)

    # Dataloaders
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

In [12]:
def train_concept_layer(model, train_loader, val_loader, device, epochs=10):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.feature_extractor.parameters(), lr=0.001)

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_total = 0
        train_correct = 0

        for images, _, concepts in tqdm(train_loader, desc=f"[Train Concepts: Epoch {epoch+1}/{epochs}]"):
            images, concepts = images.to(device), concepts.to(device)

            optimizer.zero_grad()
            predicted_concepts, _ = model(images)
            loss = criterion(predicted_concepts, concepts)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_total += concepts.size(0)
            train_correct += (predicted_concepts.round() == concepts).all(dim=1).sum().item()

        train_accuracy = 100 * train_correct / train_total

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_total = 0
        val_correct = 0

        with torch.no_grad():
            for images, _, concepts in val_loader:
                images, concepts = images.to(device), concepts.to(device)
                predicted_concepts, _ = model(images)
                loss = criterion(predicted_concepts, concepts)

                val_loss += loss.item()
                val_total += concepts.size(0)
                val_correct += (predicted_concepts.round() == concepts).all(dim=1).sum().item()

        val_accuracy = 100 * val_correct / val_total

        # Print epoch stats
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"  Train Loss: {train_loss / len(train_loader):.4f} | Train Accuracy: {train_accuracy:.2f}%")
        print(f"  Val Loss: {val_loss / len(val_loader):.4f} | Val Accuracy: {val_accuracy:.2f}%")


In [13]:
def train_classification_layer(model, train_loader, val_loader, device, epochs=10):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_total = 0
        train_correct = 0

        for images, labels, _ in tqdm(train_loader, desc=f"[Train Classification: Epoch {epoch+1}/{epochs}]"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            _, predictions = model(images)
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_total += labels.size(0)
            train_correct += (predictions.argmax(dim=1) == labels).sum().item()

        train_accuracy = 100 * train_correct / train_total

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_total = 0
        val_correct = 0

        with torch.no_grad():
            for images, labels, _ in val_loader:
                images, labels = images.to(device), labels.to(device)
                _, predictions = model(images)
                loss = criterion(predictions, labels)

                val_loss += loss.item()
                val_total += labels.size(0)
                val_correct += (predictions.argmax(dim=1) == labels).sum().item()

        val_accuracy = 100 * val_correct / val_total

        # Print epoch stats
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"  Train Loss: {train_loss / len(train_loader):.4f} | Train Accuracy: {train_accuracy:.2f}%")
        print(f"  Val Loss: {val_loss / len(val_loader):.4f} | Val Accuracy: {val_accuracy:.2f}%")


In [14]:
# Evaluation
def evaluate_cbm(model, test_loader, device):
    model.eval()
    total_correct = 0
    total_samples = 0
    
    with torch.no_grad():
        for images, labels, _ in tqdm(test_loader, desc="[Evaluate]"):
            images, labels = images.to(device), labels.to(device)
            _, predictions = model(images)
            
            _, predicted_labels = torch.max(predictions, 1)
            total_correct += (predicted_labels == labels).sum().item()
            total_samples += labels.size(0)

    accuracy = 100 * total_correct / total_samples
    print(f"Accuracy: {accuracy:.2f}%")


### Config

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader, val_loader, test_loader = get_data_loaders(PREPROCESSED_DIR, CSV_CONCEPTS_PATH)

num_concepts = len(pd.get_dummies(pd.read_csv(CSV_CONCEPTS_PATH)[["cap_color", "cap_shape", "cap_texture", "ring_present", "stem_thickness"]]).columns)
num_classes = len(pd.read_csv(CSV_CONCEPTS_PATH)['Mushroom'].unique())

model = ConceptBottleneckModel(num_concepts, num_classes).to(device)



### Training

In [53]:
print("Training Concept Layer")
train_concept_layer(model, train_loader, val_loader, device)

Training Concept Layer


  image = torch.load(image_path)
[Train Concepts: Epoch 1/10]:  15%|█▌        | 8/52 [00:30<02:54,  3.97s/it]

In [None]:
print("Training Classification Layer")
train_classification_layer(model, train_loader, val_loader, device)

In [None]:
print("Evaluating Model")
evaluate_cbm(model, test_loader, device)