In [None]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision.transforms import functional as F
from torch.utils.data import random_split
import random
from torchvision.transforms import ToPILImage
from sklearn.preprocessing import LabelEncoder

### Create data with concepts

In [None]:
# Vocabulary of admissible values for every visual concept.
# 'other' (where present) is the catch-all bucket for values not listed explicitly.
concepts = {
    'cap_shape': ['convex', 'flat', 'bell-shaped', 'other'],
    # 'orange' is listed because class 6 below is annotated with it; without it
    # the class-level annotations would use a value outside the vocabulary.
    'cap_color': ['white', 'brown', 'yellow', 'red', 'orange', 'other'],
    'cap_texture': ['smooth', 'scaly', 'warty', 'other'],
    'gill_attachment': ['free', 'attached', 'decurrent'],
    'gill_color': ['white', 'brown', 'yellow', 'other'],
    'stem_shape': ['cylindrical', 'club-shaped', 'bulbous'],
    'stem_color': ['white', 'brown', 'yellow', 'other'],
    'stem_texture': ['smooth', 'fibrous', 'scaly', 'other'],
    'spore_print_color': ['white', 'brown', 'yellow', 'other']
}

# Class-level annotations: class id -> {concept name: concept value}.
# A class only lists the concepts that are annotated for it; every value must
# come from the matching list in `concepts`.
class_concepts = {
    0: {'cap_color': 'white'},
    1: {'cap_shape': 'convex', 'cap_color': 'brown', 'spore_print_color': 'brown'},
    2: {'cap_shape': 'convex', 'cap_color': 'yellow', 'gill_attachment': 'attached'},
    3: {'cap_color': 'brown', 'gill_color': 'brown', 'stem_shape': 'cylindrical'},
    4: {'cap_texture': 'smooth', 'gill_attachment': 'free', 'spore_print_color': 'white'},
    5: {'cap_shape': 'convex', 'stem_color': 'brown', 'stem_texture': 'fibrous'},
    6: {'cap_color': 'orange', 'cap_texture': 'smooth'},
    7: {'cap_shape': 'flat', 'cap_texture': 'scaly', 'stem_shape': 'cylindrical'},
    8: {'gill_attachment': 'decurrent', 'stem_color': 'white'},
    9: {'cap_color': 'red', 'gill_color': 'white', 'spore_print_color': 'white'}
}


In [None]:
# The project root is the parent of the notebook's working directory.
root_path = os.path.dirname(os.getcwd())
csv_path = os.path.join(root_path, "dataset", "csv_mappings")

# To build the test-split file instead, swap "train" for "test" in the two
# file names below ("test.csv" / "test_with_concepts.csv").
original_train_csv_path = os.path.join(csv_path, "train.csv")
train_with_concepts = os.path.join(csv_path, "train_with_concepts.csv")

original_df = pd.read_csv(original_train_csv_path)


def map_class_to_concepts(label):
    """Return the concept annotations declared for class ``label``."""
    return class_concepts[label]


# One independent encoder per concept column.
concept_encoders = {name: LabelEncoder() for name in concepts}

# Step 1: write the textual concept value onto every row of the matching class.
# Rows whose class does not annotate a given concept stay empty (NaN) for now.
for class_id in class_concepts:
    rows_of_class = original_df['Mushroom'] == class_id
    for concept_name, concept_value in map_class_to_concepts(class_id).items():
        original_df.loc[rows_of_class, concept_name] = concept_value

# Step 2: replace text with integer codes; empty cells are treated as 'other'.
for concept_name in concept_encoders:
    if concept_name not in original_df.columns:
        continue
    filled_column = original_df[concept_name].fillna('other')
    original_df[concept_name] = concept_encoders[concept_name].fit_transform(filled_column)

original_df.to_csv(train_with_concepts, index=False)
print(f"CSV with encoded concepts created at: {train_with_concepts}")



### Create CBM

##### Preprocessed data

In [None]:
class PreprocessedMushroomDataset(Dataset):
    """Dataset of pre-saved image tensors paired with class and concept labels.

    The annotation CSV is read positionally: column 0 is the numeric image id,
    column 1 is the class label and every remaining column is a concept label.
    The image for id ``n`` is loaded from ``<root_dir>/<n zero-padded to 5 digits>.pt``.

    Args:
        csv_file: Path of the annotation CSV.
        root_dir: Directory holding the ``.pt`` tensor files.
        transform: Optional callable applied to the loaded image tensor.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Number of annotated samples (rows of the CSV)."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for row ``idx``.

        ``labels`` is a 1-D tensor laid out as ``[class_label, concept_1, ...]``.
        """
        # File names are the image id zero-padded to five digits.
        image_id = int(self.annotations.iloc[idx, 0])
        tensor_path = os.path.join(self.root_dir, f'{image_id:05d}.pt')
        image = torch.load(tensor_path).float()

        # Class and concept labels share one float tensor so that the default
        # collate function can batch them together.
        class_label = self.annotations.iloc[idx, 1].astype('float')
        concept_labels = self.annotations.iloc[idx, 2:].values.astype('float')
        labels = torch.tensor([class_label, *concept_labels.tolist()])

        if self.transform:
            image = self.transform(image)
        return image, labels


##### Paths

In [None]:
# The project root is the parent of the notebook's working directory.
root_path = os.path.dirname(os.getcwd())

# Pre-processed image tensors, one directory per split.
preprocessed_train_path, preprocessed_test_path = (
    os.path.join(root_path, 'dataset', 'preprocessed', split)
    for split in ('train', 'test')
)

# Annotation CSVs carrying the class label plus the encoded concept columns.
train_csv_path, test_csv_path = (
    os.path.join(root_path, 'dataset', 'csv_mappings', f'{split}_with_concepts.csv')
    for split in ('train', 'test')
)


##### Dataset and Dataloaders

In [None]:
# Full labelled training set, read from the concept-annotated CSV.
# The test split is not loaded here; build it the same way from
# `test_csv_path` / `preprocessed_test_path` when it is needed.
train_dataset = PreprocessedMushroomDataset(
    csv_file=train_csv_path,
    root_dir=preprocessed_train_path,
)

##### Split dataset

In [None]:
# Seed for the train/validation partition so that every run (and every re-run
# of this cell) yields the same split.
SPLIT_SEED = 42

# This cell rebinds `train_dataset` to the training subset. If the cell is
# executed again, unwrap the subset first so the split is always taken from the
# complete dataset instead of shrinking the training set a second time.
full_dataset = (
    train_dataset.dataset
    if isinstance(train_dataset, torch.utils.data.Subset)
    else train_dataset
)

train_size = int(0.8 * len(full_dataset))  # 80% for training
val_size = len(full_dataset) - train_size  # remaining ~20% for validation

train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader shuffles; validation order is kept fixed.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)
#test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

##### Concept Bottleneck Model (CBM)

In [None]:
class ConceptBottleneckModel(nn.Module):
    """Small CNN that predicts concepts first and the class from the concepts.

    The network has three stages:

    * ``conv``        - two conv/ReLU/max-pool stages (3 -> 64 -> 128 channels),
      each pooling step halving the spatial resolution;
    * ``fc_concepts`` - maps the flattened feature map to ``num_concepts`` values;
    * ``fc_task``     - maps the predicted concepts to ``num_classes`` logits.

    Args:
        num_concepts: Size of the concept bottleneck.
        num_classes: Number of output classes.
        image_size: Spatial size of the input images, either a single int for
            square inputs or a ``(height, width)`` pair. Defaults to 224, which
            gives the 128 * 56 * 56 flattened feature size.

    Raises:
        ValueError: If ``image_size`` is too small to survive the two pooling steps.
    """

    def __init__(self, num_concepts, num_classes, image_size=224):
        super().__init__()

        if isinstance(image_size, int):
            height, width = image_size, image_size
        else:
            height, width = image_size

        # The 3x3 convolutions (stride 1, padding 1) keep the spatial size;
        # each 2x2 max-pool with stride 2 halves it, rounding down.
        pooled_height = (height // 2) // 2
        pooled_width = (width // 2) // 2
        if pooled_height < 1 or pooled_width < 1:
            raise ValueError(
                f"image_size {image_size!r} is too small for two 2x2 pooling stages"
            )

        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc_concepts = nn.Sequential(
            nn.Linear(128 * pooled_height * pooled_width, 256),
            nn.ReLU(),
            nn.Linear(256, num_concepts)
        )
        self.fc_task = nn.Sequential(
            nn.Linear(num_concepts, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return ``(class_logits, concepts)`` for a batch of images ``x``.

        ``x`` must have 3 channels and the spatial size given at construction.
        """
        features = self.conv(x)
        # Flatten everything except the batch dimension.
        features = torch.flatten(features, start_dim=1)
        concepts = self.fc_concepts(features)
        # The task head sees only the predicted concepts (the bottleneck).
        output = self.fc_task(concepts)
        return output, concepts


##### Training

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Width of the concept bottleneck and number of target classes.
num_concepts, num_classes = 9, 10

model = ConceptBottleneckModel(num_concepts=num_concepts, num_classes=num_classes)
model = model.to(device)

# Cross-entropy for the class logits, mean squared error for the concept outputs.
criterion_task = nn.CrossEntropyLoss()
criterion_concept = nn.MSELoss()

# A single Adam optimizer updates every parameter of the model.
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [None]:
num_epochs = 20

# Weight of the task loss in the combined objective; the concept loss gets
# (1 - alpha). Defined once instead of being re-assigned on every batch.
alpha = 0.5


def _run_epoch(dataloader, train_optimizer=None):
    """Run one full pass over ``dataloader`` and return averaged metrics.

    Uses the notebook-level ``model``, ``device``, ``criterion_task``,
    ``criterion_concept`` and ``alpha``.

    Args:
        dataloader: Yields ``(images, labels)`` batches where ``labels[:, 0]``
            is the class label and ``labels[:, 1:]`` are the concept labels.
        train_optimizer: When given, the model is put in training mode and the
            optimizer is stepped after every batch. When ``None``, the model is
            put in evaluation mode and no gradients are computed.

    Returns:
        ``(loss, task_loss, concept_loss, accuracy)``. The three losses are
        averaged over the number of batches; accuracy is a percentage over all
        samples seen.
    """
    is_training = train_optimizer is not None
    model.train(is_training)  # model.train(False) is the same as model.eval()

    running_loss = 0.0
    running_task_loss = 0.0
    running_concept_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    with torch.set_grad_enabled(is_training):
        for images, labels in dataloader:
            images = images.to(device)
            class_labels = labels[:, 0].to(device).long()
            concept_labels = labels[:, 1:].to(device).float()

            if is_training:
                train_optimizer.zero_grad()

            outputs, predicted_concepts = model(images)

            # Task and concept losses, then their weighted combination.
            loss_task = criterion_task(outputs, class_labels)
            loss_concepts = criterion_concept(predicted_concepts, concept_labels)
            loss = alpha * loss_task + (1 - alpha) * loss_concepts

            if is_training:
                loss.backward()
                train_optimizer.step()

            running_loss += loss.item()
            running_task_loss += loss_task.item()
            running_concept_loss += loss_concepts.item()

            # Accuracy bookkeeping: the predicted class is the arg-max logit.
            _, predicted_classes = torch.max(outputs, 1)
            correct_predictions += (predicted_classes == class_labels).sum().item()
            total_predictions += class_labels.size(0)

    num_batches = len(dataloader)
    return (
        running_loss / num_batches,
        running_task_loss / num_batches,
        running_concept_loss / num_batches,
        correct_predictions / total_predictions * 100,
    )


for epoch in range(num_epochs):
    # Training
    epoch_loss, epoch_task_loss, epoch_concept_loss, accuracy = _run_epoch(
        train_dataloader, train_optimizer=optimizer
    )

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}, "
          f"Task Loss: {epoch_task_loss:.4f}, Concept Loss: {epoch_concept_loss:.4f}, "
          f"Accuracy: {accuracy:.2f}%")

    # Validation
    val_epoch_loss, val_epoch_task_loss, val_epoch_concept_loss, val_accuracy = _run_epoch(
        val_dataloader
    )

    print(f"Validation Loss: {val_epoch_loss:.4f}, "
          f"Validation Task Loss: {val_epoch_task_loss:.4f}, "
          f"Validation Concept Loss: {val_epoch_concept_loss:.4f}, "
          f"Validation Accuracy: {val_accuracy:.2f}%")

print('Training and validation completed')