In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from early_exit_resnet import *
import os
import pickle

In [None]:
# Define the transform for preprocessing the images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

class CacheDataset(Dataset):
    def __init__(self, base_dataset, cached_data_file='cached_logits.pkl', models=None, compute_logits=False):
        self.base_dataset = base_dataset
        self.cached_data_file = cached_data_file
        self.models = models
        if not compute_logits and os.path.exists(self.cached_data_file):
            with open(self.cached_data_file, 'rb') as f:
                self.cached_data = pickle.load(f)
        else:
            if compute_logits and models:
                print("Computing and saving logits...")
                self.cached_data = self.compute_and_cache_logits()
            else:
                raise ValueError("Compute logits is set to true but no models provided or cached data file not found.")

    def __len__(self):
        return len(self.base_dataset)
    
    def __getitem__(self, idx):
        image, label = self.base_dataset[idx]
        data = self.cached_data[idx]
        logits = data['logits']
        
        return {
            'image': image,
            'label': label,
            'logits': logits
        }

    def compute_and_cache_logits(self):
        cache = []
        for idx in range(len(self.base_dataset)):
            image, label = self.base_dataset[idx]
            
            x = image.unsqueeze(0)  # unsqueeze to add batch dimension
            features, logit1 = self.models['block1'](x)
            features, logit2 = self.models['block2'](features)
            features, logit3 = self.models['block3'](features)
            logit4 = self.models['block4'](features)

            logit1 = logit1.detach().squeeze(0)
            print(logit1.argmax())
            logit2 = logit2.detach().squeeze(0)
            logit3 = logit3.detach().squeeze(0)
            logit4 = logit4.detach().squeeze(0)
            
            logits_tuple = (logit1, logit2, logit3, logit4)
            
            cache.append({
                'logits': logits_tuple
            })

        with open(self.cached_data_file, 'wb') as f:
            pickle.dump(cache, f)

        return cache
        with open(self.cached_data_file, 'wb') as f:
            pickle.dump(cache, f)

        return cache


# Load CIFAR-10 test set
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Instantiate models
block1 = HeadNetworkPart1(block=Bottleneck, in_planes=64, num_blocks=[3], num_classes=10)
block2 = HeadNetworkPart2(Bottleneck, 256, [4], num_classes=10)
block3 = HeadNetworkPart3(block=Bottleneck, in_planes=512, num_blocks=[6], num_classes=10)
block4 = TailNetwork(block=Bottleneck, in_planes=1024, num_blocks=[3], num_classes=10)

# Load weights
block1.load_state_dict(torch.load("models/cifar10/head1_resnet50.pth"))
block2.load_state_dict(torch.load("models/cifar10/head2_resnet50.pth"))
block3.load_state_dict(torch.load("models/cifar10/head3_resnet50.pth"))
block4.load_state_dict(torch.load("models/cifar10/tail_resnet50.pth"))

# Set models to evaluation mode
block1.eval()
block2.eval()
block3.eval()
block4.eval()

# Combine models into a dictionary
models = {
    'block1': block1,
    'block2': block2,
    'block3': block3,
    'block4': block4
}

# Create custom dataset
# Set compute_logits to True only if you need to compute and cache the logits for the first time
compute_logits = True # Set to False for subsequent runs after initial cache creation
custom_test_set = CacheDataset(test_set, models=models, compute_logits=compute_logits)
# Create dataloader
dataloader = DataLoader(custom_test_set, batch_size=1, shuffle=True, num_workers=1)

# Testing the dataloader
for data in dataloader:
    print(data)
    # {
    #  'image': tensor, 
    #  'exit_number': tensor, 
    #  'logits': tensor, 
    #  'label': tensor
    # }
    break  # Remove this to loop through the entire dataset

Files already downloaded and verified
Computing and saving logits...
tensor(2)
tensor(0)
tensor(8)
tensor(2)
tensor(6)
tensor(2)
tensor(0)
tensor(2)
tensor(5)
tensor(2)
tensor(2)
tensor(7)
tensor(2)
tensor(3)
tensor(0)
tensor(2)
tensor(5)
tensor(2)
tensor(8)
tensor(2)
tensor(7)
tensor(0)
tensor(2)
tensor(9)
tensor(4)
tensor(2)
tensor(2)
tensor(2)
tensor(4)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(9)
tensor(5)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(9)
tensor(2)
tensor(4)
tensor(4)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(8)
tensor(2)
tensor(3)
tensor(2)
tensor(2)
tensor(3)
tensor(4)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(3)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(8)
t

In [5]:
dataloader = DataLoader(custom_test_set, batch_size=1, shuffle=True, num_workers=1)

# # Testing the dataloader
# count = 0 
# for data in dataloader:
#     print(data["logits"])
#     if count == 3:
#         break 

NameError: name 'custom_test_set' is not defined