In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from early_exit_resnet import *

import torchvision.transforms as transforms
from dataset import Flame2DataModule


In [4]:
# Build the standalone sub-networks that the combined checkpoint is split into.
# in_planes per part: head1=64, head2=256, head3=512, tail=1024.
# NOTE(review): these parts use num_classes=10, while EarlyExitResNet50 is later
# built with num_classes=3 for the same checkpoint - confirm which is intended.
head_net_part1 = HeadNetworkPart1(
    block=Bottleneck,
    in_planes=64,
    num_blocks=[3],
    num_classes=10,
)
head_net_part2 = HeadNetworkPart2(
    Bottleneck,
    256,
    [4],
    num_classes=10,
)
head_net_part3 = HeadNetworkPart3(
    block=Bottleneck,
    in_planes=512,
    num_blocks=[6],
    num_classes=10,
)
tail_net = TailNetwork(
    block=Bottleneck,
    in_planes=1024,
    num_blocks=[3, 4, 6, 3],
    num_classes=10,
)


In [2]:
# Load the Lightning checkpoint of the jointly trained early-exit model; the
# model weights are read from its "state_dict" entry by the cells below.
# map_location="cpu" lets a checkpoint saved from a GPU load on any machine;
# load_state_dict later copies the tensors onto each module's own device.
# torch.load unpickles the file - only load checkpoints from trusted sources.
combined_state_dict = torch.load(
    'lightning_logs/version_15/checkpoints/epoch=3-step=149660.ckpt',
    map_location='cpu',
)


In [7]:

# Split the combined Lightning state dict into one state dict per sub-network.
# Keys under "head1." / "head2." / "head3." go to the matching head with that
# prefix stripped, so they line up with the standalone networks' parameter names.
# Every other key is treated as part of the tail; a leading "tail." is stripped.
head1_state_dict = {}
head2_state_dict = {}
head3_state_dict = {}
tail_state_dict = {}

# Test for and strip the *same* dotted prefix, so a key can never be routed to a
# head while keeping an unstripped prefix (e.g. "head10.x" no longer matches "head1").
head_prefixes = (
    ('head1.', head1_state_dict),
    ('head2.', head2_state_dict),
    ('head3.', head3_state_dict),
)

for key, value in combined_state_dict['state_dict'].items():
    for prefix, target in head_prefixes:
        if key.startswith(prefix):
            target[key.removeprefix(prefix)] = value
            break
    else:
        # No head prefix matched: fall back to the tail, as before.
        tail_state_dict[key.removeprefix('tail.')] = value


In [9]:
# Persist each sub-network's weights as its own .pth file.
# torch.save does not create missing parent directories, so make sure models/
# exists; otherwise Restart & Run All fails on a fresh checkout.
os.makedirs("models", exist_ok=True)
torch.save(head1_state_dict, "models/head1_resnet50.pth")
torch.save(head2_state_dict, "models/head2_resnet50.pth")
torch.save(head3_state_dict, "models/head3_resnet50.pth")
torch.save(tail_state_dict, "models/tail_resnet50.pth")

In [10]:
# Reload each part from its standalone file; load_state_dict is strict by
# default, so this checks that the split keys match each sub-network exactly.
# map_location="cpu" keeps this working even if the tensors were saved from a
# GPU; load_state_dict copies them onto each module's own device.
head_net_part1.load_state_dict(torch.load("models/head1_resnet50.pth", map_location="cpu"))
head_net_part2.load_state_dict(torch.load("models/head2_resnet50.pth", map_location="cpu"))
head_net_part3.load_state_dict(torch.load("models/head3_resnet50.pth", map_location="cpu"))
tail_net.load_state_dict(torch.load("models/tail_resnet50.pth", map_location="cpu"))


<All keys matched successfully>

In [12]:
def evaluate_model(model, test_loader, device):
    """Compute the top-1 accuracy of every exit of an early-exit model.

    Args:
        model: module whose forward pass returns an iterable of per-exit
            logit tensors, each assumed to be (batch, num_classes). It is put
            in eval mode and is expected to already live on ``device``.
        test_loader: iterable of ``(inputs, labels)`` batches, e.g. a
            ``torch.utils.data.DataLoader``.
        device: device the inputs and labels are moved to before the forward pass.

    Returns:
        list[float]: one accuracy per exit, in the order the model returns
        them, computed over the samples actually yielded by ``test_loader``.
        Empty if the loader yields no samples.
    """
    model.eval()  # inference behaviour for layers such as dropout / batch-norm
    correct_per_exit = None  # sized from the number of exits in the first batch
    samples_seen = 0

    with torch.no_grad():  # no gradients needed for evaluation
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            exit_logits = list(model(inputs))
            if correct_per_exit is None:
                correct_per_exit = [0] * len(exit_logits)

            # softmax is monotonic, so argmax over the raw logits picks the same class
            for idx, logits in enumerate(exit_logits):
                predictions = logits.argmax(dim=1)
                correct_per_exit[idx] += (predictions == labels).sum().item()

            # count processed samples rather than len(dataset), which is wrong
            # with drop_last=True or a subset sampler
            samples_seen += labels.size(0)

    if not samples_seen:
        return []
    return [correct / samples_seen for correct in correct_per_exit]

In [3]:
# Rebuild the full early-exit model and restore the jointly trained weights
# straight from the combined checkpoint's "state_dict" entry.
# NOTE(review): num_classes=3 here, while the split head/tail networks are built
# with num_classes=10 - confirm which value matches the checkpoint.
model = EarlyExitResNet50(
    num_classes=3,
)
model.load_state_dict(
    combined_state_dict['state_dict'],
)



<All keys matched successfully>

In [None]:
# Evaluate every exit of the rebuilt model on the FLAME2 data.
# Images are resized to 224x224 and converted to tensors.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

data_module = Flame2DataModule(image_dir="../Flame2/254_RGB", batch_size=32, transform=transform)
data_module.setup()

# Fall back to CPU when no GPU is present, and put the model on the same device
# that evaluate_model moves each batch to (it does not move the model itself).
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Use the evaluate_model(model, test_loader, device) defined above; the previous
# call data_module.evaluate_model(model, "cuda") passed no loader.
# NOTE(review): assumes Flame2DataModule provides the standard Lightning
# test_dataloader() hook - confirm in dataset.py.
accuracies = evaluate_model(model, data_module.test_dataloader(), device)
accuracies