In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import models, transforms
from torchvision.datasets import CIFAR10

In [2]:
# Seed PyTorch's global RNG so weight initialisation, DataLoader shuffling and
# dropout masks are repeatable across "Restart Kernel -> Run All" executions.
# (Some GPU kernels may still be non-deterministic.)
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
# Pretrained VGG19 used purely as a fixed feature extractor.
base_model = models.vgg19(pretrained=True)

# Replace the fully connected head with a pass-through; the wrapper model below
# only calls `features` and `avgpool`.
base_model.classifier = nn.Identity()

# Inference mode for the backbone.
# NOTE(review): `train_model` calls `model.train()` on the wrapper, which also
# switches this submodule back to training mode. That is harmless only if the
# retained layers have no mode-dependent behaviour (dropout / batch norm) -- TODO confirm.
base_model.eval()

# Freeze every backbone weight so no gradients are computed for it.
for backbone_weight in base_model.parameters():
    backbone_weight.requires_grad_(False)



In [4]:
# Concept Bottleneck Layer (CBL) on top of a VGG19-style backbone
class ConceptBottleneckVGG19(nn.Module):
    """Backbone features -> narrow "concept" layer -> linear class head.

    Args:
        base_model: Module exposing ``features`` and ``avgpool`` sub-modules.
            The flattened ``avgpool`` output must have 512 * 7 * 7 elements
            per sample (the size hard-coded for VGG19 below).
        num_concepts: Width of the bottleneck layer.
        num_classes: Number of output logits.
    """

    def __init__(self, base_model, num_concepts, num_classes):
        super().__init__()
        self.base_model = base_model
        self.flatten = nn.Flatten()

        # Linear projection into concept space, then ReLU and dropout.
        bottleneck_layers = [
            nn.Linear(512 * 7 * 7, num_concepts),
            nn.ReLU(),
            nn.Dropout(0.5),
        ]
        self.concept_bottleneck = nn.Sequential(*bottleneck_layers)

        # Class logits are a linear function of the concept activations.
        self.classifier = nn.Linear(num_concepts, num_classes)

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        backbone = self.base_model
        pooled = backbone.avgpool(backbone.features(x))
        concepts = self.concept_bottleneck(self.flatten(pooled))
        return self.classifier(concepts)

In [5]:
# Parameters
num_concepts = 32
num_classes = 10
learning_rate = 0.001
model = ConceptBottleneckVGG19(base_model, num_concepts, num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()

# Hand the optimizer only the parameters that can actually receive gradients
# (the bottleneck and the classifier); the frozen backbone weights are skipped.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=learning_rate)

In [6]:
# Preprocessing applied to every CIFAR-10 image: upsample to the backbone's
# input resolution, convert to a tensor, then normalise each channel.
INPUT_SIZE = (224, 224)
CHANNEL_MEAN = [0.485, 0.456, 0.406]  # per-channel mean expected by torchvision's ImageNet-pretrained models
CHANNEL_STD = [0.229, 0.224, 0.225]   # per-channel std expected by torchvision's ImageNet-pretrained models

preprocessing_steps = [
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD),
]
transform = transforms.Compose(preprocessing_steps)

In [7]:
# CIFAR-10 train and test splits; both share the same root folder,
# preprocessing pipeline and download behaviour.
cifar_kwargs = dict(root="./data", transform=transform, download=True)
train_dataset = CIFAR10(train=True, **cifar_kwargs)
test_dataset = CIFAR10(train=False, **cifar_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Split training dataset into training (80%) and validation (20%) sets.
#
# This cell rebinds `train_dataset`, so running it a second time would otherwise
# split the already-reduced subset again. Unwrap a previous split first so the
# cell is idempotent.
full_train_dataset = (
    train_dataset.dataset
    if isinstance(train_dataset, torch.utils.data.Subset)
    else train_dataset
)
train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size

# A dedicated, seeded generator makes the split identical on every run and
# independent of how much global RNG state earlier cells consumed.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_train_dataset, [train_size, val_size], generator=split_generator
)

In [9]:
# Mini-batch loaders: only the training loader shuffles; validation and test
# keep dataset order.
BATCH_SIZE = 32
eval_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=False)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, **eval_loader_kwargs)
test_loader = DataLoader(test_dataset, **eval_loader_kwargs)

In [10]:
def validate_model(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` and print/return loss and accuracy.

    Args:
        model: ``nn.Module`` producing class logits of shape (batch, classes).
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)`` returning a
            scalar tensor holding the batch-mean loss.

    Returns:
        tuple[float, float]: ``(mean loss per sample, accuracy in percent)``.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()

    # Run on whatever device the model lives on rather than relying on a
    # notebook-global variable; parameter-free models fall back to the CPU.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(run_device), labels.to(run_device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Weight the batch-mean loss by batch size so a smaller final
            # batch does not skew the overall mean.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("val_loader yielded no samples")

    # Divide by the samples actually seen: with drop_last or a sampler this
    # differs from len(val_loader.dataset).
    val_loss = loss_sum / total
    val_acc = 100 * correct / total
    print(f"Validation Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%")
    return val_loss, val_acc

In [11]:
# Training loop
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=10):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Args:
        model: ``nn.Module`` producing class logits of shape (batch, classes).
        train_loader: Iterable yielding ``(inputs, labels)`` training batches.
        val_loader: Loader passed to ``validate_model`` after every epoch.
        criterion: Loss callable ``criterion(outputs, labels)`` returning a
            scalar tensor holding the batch-mean loss.
        optimizer: Optimizer that updates the model's trainable parameters.
        epochs: Number of passes over ``train_loader``.

    Returns:
        list[tuple[float, float]]: One ``(train loss, train accuracy %)`` entry
        per completed epoch.

    Raises:
        ValueError: If ``train_loader`` yields no samples in an epoch.
    """
    # Run on whatever device the model lives on rather than relying on a
    # notebook-global variable; parameter-free models fall back to the CPU.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    history = []
    for epoch in range(epochs):
        model.train()
        loss_sum = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(run_device), labels.to(run_device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Statistics (batch-mean loss weighted by batch size)
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        # Divide by the samples actually seen: with drop_last or a sampler
        # this differs from len(train_loader.dataset).
        epoch_loss = loss_sum / total
        epoch_acc = 100 * correct / total
        history.append((epoch_loss, epoch_acc))

        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

        # Validation
        validate_model(model, val_loader, criterion)

    return history

In [15]:
# Run six training epochs; per-epoch training and validation metrics are
# printed as each epoch finishes.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=6,
)

Epoch 1/6, Loss: 0.7805, Accuracy: 71.22%
Validation Loss: 0.4469, Accuracy: 84.55%
Epoch 2/6, Loss: 0.6616, Accuracy: 75.58%
Validation Loss: 0.4228, Accuracy: 85.54%
Epoch 3/6, Loss: 0.5894, Accuracy: 77.97%
Validation Loss: 0.4287, Accuracy: 85.36%
Epoch 4/6, Loss: 0.5410, Accuracy: 79.62%
Validation Loss: 0.4170, Accuracy: 85.97%
Epoch 5/6, Loss: 0.4969, Accuracy: 81.03%
Validation Loss: 0.4258, Accuracy: 86.03%
Epoch 6/6, Loss: 0.4625, Accuracy: 82.64%
Validation Loss: 0.4484, Accuracy: 86.13%


In [16]:
# Evaluate the model on the test set
def test_model(model, test_loader):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_acc = 100 * correct / total
    print(f"Test Accuracy: {test_acc:.2f}%")

In [17]:
# Report accuracy on the held-out CIFAR-10 test split.
test_model(model=model, test_loader=test_loader)

Test Accuracy: 85.70%


In [18]:
# Save the trained model
# The state dict covers every registered submodule, including the frozen backbone.
checkpoint_path = "saved_models/cbm_vgg19_model.pth"

# torch.save does not create missing parent directories; make sure the target
# folder exists so a fresh checkout does not fail after a full training run.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)