In [67]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from early_exit_resnet import *


In [88]:
head_net_part1 = HeadNetworkPart1(block=Bottleneck, in_planes=64, num_blocks=[3], num_classes=10)
head_net_part2 = HeadNetworkPart2(Bottleneck, 256, [4], num_classes=10)
head_net_part3 = HeadNetworkPart3(block=Bottleneck, in_planes=512, num_blocks=[6], num_classes=10)
tail_net = TailNetwork(block=Bottleneck, in_planes=1024, num_blocks=[3, 4, 6, 3], num_classes=10)


In [69]:
combined_state_dict = torch.load('lightning_logs/version_39/checkpoints/epoch=9-step=520.ckpt')


In [70]:

head1_state_dict = {}
head2_state_dict = {}
head3_state_dict = {}
tail_state_dict = {}


for key, value in combined_state_dict['state_dict'].items():
    if key.startswith('head1'):
        head1_state_dict[key.removeprefix('head1.')] = value
    elif key.startswith('head2'):
        head2_state_dict[key.removeprefix('head2.')] = value
    elif key.startswith('head3'):
        head3_state_dict[key.removeprefix('head3.')] = value
    else:
        tail_state_dict[key.removeprefix('tail.')] = value


In [71]:
torch.save(head1_state_dict,"models/cifar10/head1_resnet50.pth")
torch.save(head2_state_dict,"models/cifar10/head2_resnet50.pth")
torch.save(head3_state_dict,"models/cifar10/head3_resnet50.pth")
torch.save(tail_state_dict,"models/cifar10/tail_resnet50.pth")

In [104]:
head_net_part1.load_state_dict(torch.load("models/cifar10/head1_resnet50.pth"))
# head_net_part2.load_state_dict(torch.load("models/head2_resnet50.pth"))
# head_net_part3.load_state_dict(torch.load("models/head3_resnet50.pth"))
# tail_net.load_state_dict(torch.load("models/tail_resnet50.pth"))


<All keys matched successfully>

In [110]:
def evaluate_model(model, test_loader, device):
    model.eval()  # Set model to evaluation mode
    total_samples = len(test_loader.dataset)
    correct = 0 # Assuming three exits

    with torch.no_grad():  # No need to compute gradients
        for inputs, labels in test_loader:
            # print(inputs)
            inputs = inputs.to(device)
            labels = labels.to(device)
            exit = model(inputs)[1]  # Forward pass
            softmax = torch.nn.Softmax(dim=1)
            exit_soft = softmax(exit)
            predictions = exit_soft.argmax(dim=1)
            print(inputs.shape)
            correct += 1 if (predictions == labels) else 0

    accuracy = correct / total_samples
    return accuracy

In [106]:
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,transform=transform)
dataloader = DataLoader(test_set, batch_size=1, shuffle=True, num_workers=1)


Files already downloaded and verified


In [111]:
accuracy = evaluate_model(head_net_part1,dataloader, 'cpu')
print(accuracy)

torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
t

In [108]:
print(head_net_part1.conv1.weight)

Parameter containing:
tensor([[[[-0.1712,  0.1379, -0.0461],
          [-0.1524,  0.1505,  0.1423],
          [ 0.0199, -0.1704,  0.1499]],

         [[-0.1138,  0.0437, -0.1105],
          [ 0.1410,  0.0629, -0.1853],
          [-0.1650, -0.0616,  0.0838]],

         [[ 0.0817, -0.0064,  0.1190],
          [ 0.0860,  0.1002, -0.0764],
          [ 0.0427, -0.0219,  0.1117]]],


        [[[-0.0328,  0.0188, -0.0039],
          [-0.1313, -0.1366, -0.1489],
          [ 0.1330, -0.1349, -0.0988]],

         [[ 0.0967, -0.0589,  0.0178],
          [-0.0275,  0.0541, -0.0468],
          [ 0.1830,  0.0110,  0.1865]],

         [[-0.1955,  0.0400,  0.0822],
          [-0.1342,  0.1742, -0.1089],
          [ 0.1046,  0.0586,  0.1181]]],


        [[[ 0.0139, -0.0056, -0.1066],
          [-0.0118,  0.1503,  0.0908],
          [ 0.1657, -0.1041, -0.0535]],

         [[-0.0890, -0.1459, -0.1690],
          [-0.0583,  0.0189,  0.0926],
          [ 0.1022, -0.0586, -0.1416]],

         [[ 0.1612, -0

In [3]:
model = EarlyExitResNet50(num_classes=3)
model.load_state_dict(combined_state_dict['state_dict'])



<All keys matched successfully>

In [None]:
from dataset import Flame2DataModule
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

data_module = Flame2DataModule(image_dir="../Flame2/254_RGB", batch_size=32,transform=transform)
data_module.setup()
print(data_module.evaluate_model(model,"cuda"))