In [4]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
import pandas as pd
from tqdm import tqdm
from pathlib import Path

# Reuse paths and constants from training.py
SRC_DIR = Path.cwd()
ROOT_DIR = SRC_DIR.parent
DATA_DIR = os.path.join(ROOT_DIR, 'dataset')
PREPROCESSED_DIR = os.path.join(DATA_DIR, 'preprocessed')
CSV_PATH = os.path.join(DATA_DIR, 'csv_mappings', 'train.csv')
MODEL_DIR = os.path.join(ROOT_DIR, 'models')
CBM_DIR = os.path.join(MODEL_DIR, 'cbm')
os.makedirs(CBM_DIR, exist_ok=True)

# Extended Dataset for Concept Labels
class ConceptMushroomDataset(Dataset):
    def __init__(self, preprocessed_dir, csv_path, transform=None):
        self.preprocessed_dir = preprocessed_dir
        self.csv_data = pd.read_csv(csv_path)
        self.transform = transform

        # Images and Labels
        self.image_ids = self.csv_data['Image'].values
        self.labels = self.csv_data['Mushroom'].values
        self.concepts = self.csv_data.iloc[:, 2:].values  # Assuming concepts are columns 2+

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = str(self.image_ids[idx]).zfill(5)
        label = self.labels[idx]
        concepts = torch.tensor(self.concepts[idx], dtype=torch.float32)

        # Load image
        image_path = os.path.join(self.preprocessed_dir, f"{image_id}.pt")
        image = torch.load(image_path)

        if self.transform:
            image = self.transform(image)

        return image, label, concepts

# CBM Model Definition
class ConceptBottleneckModel(nn.Module):
    def __init__(self, num_concepts, num_classes):
        super(ConceptBottleneckModel, self).__init__()

        # Feature extractor (ResNet)
        self.feature_extractor = models.resnet50(pretrained=True)
        self.feature_extractor.fc = nn.Linear(self.feature_extractor.fc.in_features, num_concepts)

        # Bottleneck to Classifier
        self.classifier = nn.Sequential(
            nn.Linear(num_concepts, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        concepts = self.feature_extractor(x)
        predictions = self.classifier(concepts)
        return concepts, predictions

# Data Preparation
def get_data_loaders(preprocessed_dir, csv_path, batch_size=32):
    # Init
    dataset = ConceptMushroomDataset(preprocessed_dir, csv_path)

    # Split 
    indices = list(range(len(dataset)))
    train_indices, temp_indices = train_test_split(indices, test_size=0.3, random_state=42)
    val_indices, test_indices = train_test_split(temp_indices, test_size=0.5, random_state=42)

    # Subsets
    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)
    test_subset = Subset(dataset, test_indices)

    # Dataloaders
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

# Training Loop for Concepts
def train_concept_layer(model, train_loader, val_loader, device, epochs=10):
    criterion = nn.MSELoss()  # Mean squared error for concept predictions
    optimizer = optim.Adam(model.feature_extractor.parameters(), lr=0.001)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for images, _, concepts in tqdm(train_loader, desc=f"[Train Concepts: Epoch {epoch+1}/{epochs}]"):
            images, concepts = images.to(device), concepts.to(device)

            optimizer.zero_grad()
            predicted_concepts, _ = model(images)
            loss = criterion(predicted_concepts, concepts)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}")

# Training Loop for Classification
def train_classification_layer(model, train_loader, val_loader, device, epochs=10):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for images, labels, _ in tqdm(train_loader, desc=f"[Train Classification: Epoch {epoch+1}/{epochs}]"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            _, predictions = model(images)
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}")

# Evaluation
def evaluate_cbm(model, test_loader, device):
    model.eval()
    total_correct = 0
    total_samples = 0
    
    with torch.no_grad():
        for images, labels, _ in tqdm(test_loader, desc="[Evaluate]"):
            images, labels = images.to(device), labels.to(device)
            _, predictions = model(images)
            
            _, predicted_labels = torch.max(predictions, 1)
            total_correct += (predicted_labels == labels).sum().item()
            total_samples += labels.size(0)

    accuracy = 100 * total_correct / total_samples
    print(f"Accuracy: {accuracy:.2f}%")


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader, val_loader, test_loader = get_data_loaders(PREPROCESSED_DIR, CSV_PATH)

num_concepts = 20  # Example concept count
num_classes = len(pd.read_csv(CSV_PATH)['Mushroom'].unique())

model = ConceptBottleneckModel(num_concepts, num_classes).to(device)

print("Training Concept Layer")
train_concept_layer(model, train_loader, val_loader, device)

print("Training Classification Layer")
train_classification_layer(model, train_loader, val_loader, device)

print("Evaluating Model")
evaluate_cbm(model, test_loader, device)



Training Concept Layer


  image = torch.load(image_path)
  return F.mse_loss(input, target, reduction=self.reduction)
[Train Concepts: Epoch 1/10]:   0%|          | 0/52 [00:01<?, ?it/s]


RuntimeError: The size of tensor a (20) must match the size of tensor b (0) at non-singleton dimension 1