<a href="https://colab.research.google.com/github/happy-hamburger/Data_Structures_112/blob/main/separateearlyexiting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# In this code we will separate the training of these models to increase accuracy.
# This first piece of code is the early exiting model being trained first.
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import os
import matplotlib.pyplot as plt

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Check CUDA device
if torch.cuda.is_available():
    print(f'Using CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}')
else:
    print('CUDA not available, using CPU.')

# Parameters
batch_size = 64
num_epochs = 2
learning_rate = 0.001
valid_size = 0.1
random_seed = 1
classes = [str(i) for i in range(10)]

# Transformations for training, validation, and test data
transform_train = transforms.Compose([
    transforms.Resize((227, 227)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
])

transform_test = transforms.Compose([
    transforms.Resize((227, 227)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
])

# Load CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Split training data into training and validation
num_train = len(train_dataset)
split = int(valid_size * num_train)
train_data, valid_data = random_split(train_dataset, [num_train - split, split], generator=torch.Generator().manual_seed(random_seed))

# Create data loaders
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# Define AlexNet model with early exit
class EarlyExit(nn.Module):
    def __init__(self, num_classes=10):
        super(EarlyExit, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 96, 11, 4, 0),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(3, 2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(96, 256, 5, 1, 2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 384, 3, 1, 1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.early_exit_fc1 = nn.Linear(384 * 13 * 13, 1024)
        self.early_exit_fc2 = nn.Linear(1024, 10)


    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)

        # Calculate early exit
        early_exit_out = self.early_exit_fc1(out.reshape(out.size(0), -1))
        early_exit_out = self.early_exit_fc2(early_exit_out)
        return  early_exit_out



# Initialize model, loss function, and optimizer
model = EarlyExit(10).to(device)
criterion_early_exit = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


# Directory to save model weights
save_dir = 'saved_models'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)




Using CUDA device: Tesla T4
Files already downloaded and verified
Files already downloaded and verified




In [2]:
class FullModel(nn.Module):
    def __init__(self, num_classes=10):
        super(FullModel, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 96, 11, 4, 0),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(3, 2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(96, 256, 5, 1, 2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 384, 3, 1, 1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.layer4 = nn.Sequential(
            nn.Conv2d(384, 384, 3, 1, 1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.layer5 = nn.Sequential(
            nn.Conv2d(384, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2))
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU())
        self.fc1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4096, 4096))
        self.fc2 = nn.Linear(4096, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out

# Initialize model, loss function, and optimizer
model = EarlyExit(10).to(device)
criterion_early_exit = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)






In [3]:

from pathlib import Path

# Ensure the saved_models directory exists
saved_models_dir = Path('saved_models')
saved_models_dir.mkdir(parents=True, exist_ok=True)

# Training loop for both EarlyExit and FullModel
num_epochs = 50

# Initialize both models
early_exit_model = EarlyExit(10).to(device)
full_model = FullModel(10).to(device)

# Define loss function and optimizers for both models
criterion = nn.CrossEntropyLoss()
optimizer_early_exit = torch.optim.Adam(early_exit_model.parameters(), lr=learning_rate)
optimizer_full = torch.optim.Adam(full_model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    early_exit_model.train()
    full_model.train()

    running_loss_early_exit = 0.0
    running_loss_full = 0.0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients
        optimizer_early_exit.zero_grad()
        optimizer_full.zero_grad()

        # Forward pass for early exit model
        outputs_early_exit = early_exit_model(images)
        loss_early_exit = criterion(outputs_early_exit, labels)

        # Forward pass for full model
        outputs_full = full_model(images)
        loss_full = criterion(outputs_full, labels)

        # Backward pass and optimization for early exit model
        loss_early_exit.backward()
        optimizer_early_exit.step()

        # Backward pass and optimization for full model
        loss_full.backward()
        optimizer_full.step()

        # Accumulate the losses
        running_loss_early_exit += loss_early_exit.item()
        running_loss_full += loss_full.item()

    # Calculate average losses
    avg_loss_early_exit = running_loss_early_exit / len(train_loader)
    avg_loss_full = running_loss_full / len(train_loader)

    print(f'Epoch [{epoch+1}/{num_epochs}], Early Exit Model Loss: {avg_loss_early_exit:.4f}, Full Model Loss: {avg_loss_full:.4f}')

    # Validate the models
    early_exit_model.eval()
    full_model.eval()
    correct_early_exit = 0
    correct_full = 0
    total = 0

    with torch.no_grad():
        for images, labels in valid_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Early exit model validation
            outputs_early_exit = early_exit_model(images)
            _, predicted_early_exit = torch.max(outputs_early_exit.data, 1)
            correct_early_exit += (predicted_early_exit == labels).sum().item()

            # Full model validation
            outputs_full = full_model(images)
            _, predicted_full = torch.max(outputs_full.data, 1)
            correct_full += (predicted_full == labels).sum().item()

            total += labels.size(0)

    accuracy_early_exit = 100 * correct_early_exit / total
    accuracy_full = 100 * correct_full / total

    print(f'Validation Accuracy: Early Exit Model: {accuracy_early_exit:.2f}%, Full Model: {accuracy_full:.2f}%')

# Save the model weights
early_exit_model_path = 'early_exit_model.pth'
full_model_path ='full_model.pth'
torch.save(early_exit_model.state_dict(), early_exit_model_path)
torch.save(full_model.state_dict(), full_model_path)
print(f"Early Exit Model saved to {early_exit_model_path}")
print(f"Full Model saved to {full_model_path}")

print('Finished Training')


Epoch [1/50], Early Exit Model Loss: 6.9750, Full Model Loss: 2.6110
Validation Accuracy: Early Exit Model: 51.52%, Full Model: 22.20%
Epoch [2/50], Early Exit Model Loss: 1.2355, Full Model Loss: 2.0869
Validation Accuracy: Early Exit Model: 52.90%, Full Model: 21.70%
Epoch [3/50], Early Exit Model Loss: 1.1360, Full Model Loss: 1.9445
Validation Accuracy: Early Exit Model: 60.10%, Full Model: 24.42%
Epoch [4/50], Early Exit Model Loss: 1.0673, Full Model Loss: 1.7976
Validation Accuracy: Early Exit Model: 64.70%, Full Model: 37.14%
Epoch [5/50], Early Exit Model Loss: 0.9684, Full Model Loss: 1.6962
Validation Accuracy: Early Exit Model: 63.78%, Full Model: 40.14%
Epoch [6/50], Early Exit Model Loss: 0.8248, Full Model Loss: 1.4712
Validation Accuracy: Early Exit Model: 70.16%, Full Model: 56.32%
Epoch [7/50], Early Exit Model Loss: 0.7251, Full Model Loss: 1.2282
Validation Accuracy: Early Exit Model: 71.12%, Full Model: 62.82%
Epoch [8/50], Early Exit Model Loss: 0.6618, Full Model

In [9]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision

# Define the EarlyExit model class
class EarlyExit(nn.Module):
    def __init__(self, num_classes=10):
        super(EarlyExit, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 96, 11, 4, 0),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(3, 2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(96, 256, 5, 1, 2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 384, 3, 1, 1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.early_exit_fc1 = nn.Linear(384 * 13 * 13, 1024)
        self.early_exit_fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        early_exit_out = self.early_exit_fc1(out.reshape(out.size(0), -1))
        early_exit_out = self.early_exit_fc2(early_exit_out)
        return early_exit_out

# Define the FullModel model class
class FullModel(nn.Module):
    def __init__(self, num_classes=10):
        super(FullModel, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 96, 11, 4, 0),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(3, 2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(96, 256, 5, 1, 2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 384, 3, 1, 1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.layer4 = nn.Sequential(
            nn.Conv2d(384, 384, 3, 1, 1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.layer5 = nn.Sequential(
            nn.Conv2d(384, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2))
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU())
        self.fc1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4096, 4096))
        self.fc2 = nn.Linear(4096, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out

# Load the saved model weights
early_exit_model_path = 'early_exit_model.pth'
full_model_path = 'full_model.pth'

early_exit_model = EarlyExit(10).to(device)
full_model = FullModel(10).to(device)

early_exit_model.load_state_dict(torch.load(early_exit_model_path))
full_model.load_state_dict(torch.load(full_model_path))

early_exit_model.eval()
full_model.eval()

# Data transformations for testing
transform_test = transforms.Compose([
    transforms.Resize((227, 227)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
])

# Load CIFAR-10 test dataset
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

# Testing loop with confidence threshold
with torch.no_grad():
    ee_counter = 0
    full_counter = 0
    threshold = 0.8
    n_correct = 0
    n_samples = 0
    n_class_correct = [0 for i in range(10)]
    n_class_samples = [0 for i in range(10)]

    for images, labels in test_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Pass through early exit model
        early_exit_outputs = early_exit_model(images)
        softmax_outputs = F.softmax(early_exit_outputs, dim=1)

        # Find the maximum probabilities and their indices
        confidence, predicted_classes = torch.max(softmax_outputs, dim=1)

        # Iterate over each sample in the batch
        for i in range(images.size(0)):
            if confidence[i].item() > threshold:
                ee_counter += 1
                predicted = predicted_classes[i]
            else:
                full_counter += 1
                final_output = full_model(images[i].unsqueeze(0))
                _, predicted = torch.max(final_output, 1)

            n_samples += 1
            n_correct += (predicted == labels[i]).sum().item()

            label = labels[i].item()
            if label == predicted.item():
                n_class_correct[label] += 1
            n_class_samples[label] += 1

    test_acc = 100.0 * n_correct / n_samples
    print(f'Test Accuracy of the network with confidence threshold {threshold}: {test_acc:.2f} %')
    print(f'Went through full model {full_counter} times')
    print(f'Early exit counter is {ee_counter}')
    for i in range(10):
        acc = 100.0 * n_class_correct[i] / n_class_samples[i]
        print(f'Accuracy of {classes[i]}: {acc:.2f} %')


Files already downloaded and verified
Test Accuracy of the network with confidence threshold 0.8: 81.41 %
Went through full model 736 times
Early exit counter is 9264
Accuracy of 0: 85.60 %
Accuracy of 1: 88.10 %
Accuracy of 2: 66.10 %
Accuracy of 3: 70.60 %
Accuracy of 4: 80.20 %
Accuracy of 5: 75.60 %
Accuracy of 6: 86.10 %
Accuracy of 7: 82.80 %
Accuracy of 8: 89.10 %
Accuracy of 9: 89.90 %
