In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from ee_model import *

class HeadNetwork(nn.Module):
    def __init__(self, block, in_planes, num_blocks, num_classes=10):
        super(HeadNetwork, self).__init__()
        self.in_planes = in_planes
        
        self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.early_exit_1 = EarlyExitBlock(64 * block.expansion, num_classes)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.early_exit_2 = EarlyExitBlock(128 * block.expansion, num_classes)
        
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        ee1_out = self.early_exit_1(out)
        out = self.layer2(out)
        ee2_out = self.early_exit_2(out)
        return out, ee1_out, ee2_out


  warn(


In [2]:
class HeadNetworkPart1(nn.Module):
    def __init__(self, block, in_planes, num_blocks, num_classes=10):
        super(HeadNetworkPart1, self).__init__()
        self.in_planes = in_planes

        self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(in_planes)
        
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        
        self.early_exit_1 = EarlyExitBlock(64 * block.expansion, num_classes)
        
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        ee1_out = self.early_exit_1(out)
        return out, ee1_out


In [3]:
class HeadNetworkPart2(nn.Module):
    def __init__(self, block, in_planes, num_blocks, num_classes=10):
        super(HeadNetworkPart2, self).__init__()
        self.in_planes = in_planes  

        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.early_exit_2 = EarlyExitBlock(128 * block.expansion, num_classes)
        
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.layer2(x)
        ee2_out = self.early_exit_2(out)
        return out, ee2_out


In [4]:
class TailNetwork(nn.Module):
    def __init__(self, block, in_planes, num_blocks, num_classes=10):
        super(TailNetwork, self).__init__()
        self.in_planes = in_planes 

        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.layer3(x)
        out = self.layer4(out)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        final_out = self.linear(out)
        return final_out


In [5]:
head_net_part1 = HeadNetworkPart1(block=Bottleneck, in_planes=64, num_blocks=[3, 4], num_classes=10)
head_net_part2 = HeadNetworkPart2(block=Bottleneck, in_planes=256, num_blocks=[3, 4], num_classes=10)
tail_net = TailNetwork(block=Bottleneck, in_planes=512, num_blocks=[3, 4, 6, 3], num_classes=10)


In [29]:
combined_state_dict = torch.load('resnet50.pth')

head1_state_dict = {}
head2_state_dict = {}
tail_state_dict = {}


for key, value in combined_state_dict.items():
    if key.startswith('conv1') or key.startswith('bn1') or key.startswith('layer1') or key.startswith('early_exit_1'):
        head1_state_dict[key] = value
    elif key.startswith('layer2') or key.startswith('early_exit_2'):
        head2_state_dict[key] = value
    else:
        tail_state_dict[key] = value


In [6]:
torch.save(head1_state_dict,"data/head1_resnet50.pth")
torch.save(head2_state_dict,"data/head2_resnet50.pth")
torch.save(tail_state_dict,"data/tail_resnet50.pth")

NameError: name 'head1_state_dict' is not defined

In [18]:
# Load state dictionaries into the respective models
head_net_part1.load_state_dict(torch.load("models/head1_resnet50.pth"))
head_net_part2.load_state_dict(torch.load("models/head2_resnet50.pth"))
tail_net.load_state_dict(torch.load("models/tail_resnet50.pth"))
head_net_part1.eval()
head_net_part2.eval()
tail_net.eval()
head_net_part1 = head_net_part1.to("cuda")
head_net_part2 = head_net_part2.to("cuda")
tail_net = tail_net.to("cuda")


In [39]:
head_net_part1 = HeadNetworkPart1(block=Bottleneck, in_planes=64, num_blocks=[3, 4], num_classes=10)
head_net_part2 = HeadNetworkPart2(block=Bottleneck, in_planes=256, num_blocks=[3, 4], num_classes=10)
tail_net = TailNetwork(block=Bottleneck, in_planes=512, num_blocks=[3, 4, 6, 3], num_classes=10)

# Loads Saved Models
head_net_part1.load_state_dict(torch.load("head1_resnet50.pth"))
head_net_part2.load_state_dict(torch.load("head2_resnet50.pth"))
tail_net.load_state_dict(torch.load("tail_resnet50.pth"))
head_net_part1.eval()
head_net_part2.eval()
tail_net.eval()


TailNetwork(
  (layer3): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats

In [20]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
full_testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = DataLoader(full_testset, batch_size=64, shuffle=True, num_workers=2)


Files already downloaded and verified


In [21]:
total_samples = len(testloader.dataset)
correct_predictions = [0, 0, 0]  

with torch.no_grad():  
    for inputs, labels in testloader:
        inputs = inputs.to("cuda")
        labels = labels.to("cuda")

        output_part1, ee1_output = head_net_part1(inputs)
        output_part2, ee2_output = head_net_part2(output_part1)
        output = tail_net(output_part2) 
        softmax = torch.nn.Softmax(dim=1)
        
        
        predictions = softmax(ee1_output).argmax(dim=1)
        correct_predictions[0] += (predictions == labels).type(torch.float).sum().item()
        predictions = softmax(ee2_output).argmax(dim=1)
        correct_predictions[1] += (predictions == labels).type(torch.float).sum().item()
        predictions = softmax(output).argmax(dim=1)
        correct_predictions[2] += (predictions == labels).type(torch.float).sum().item()

accuracies = [correct / total_samples for correct in correct_predictions]
print(accuracies)

[0.6814, 0.7622, 0.7974]
