In [1]:
from resnet_dataset import ResnetImageDataset, get_class_str
from resnet_shuffled import ShuffledResnetVariable, ShuffledResnetVariableConv2d

from torchvision.models import resnet18, resnet50, ResNet50_Weights, ResNet18_Weights
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
import os
import csv

from tqdm import tqdm


In [2]:
# Experiment configuration. The image directory defaults to the author's local copy;
# set the RESNET_IMG_DIR environment variable to use another location without editing
# the notebook.
CONFIG = {
    "dataset": {
        # Normalised to end with a single "/" like the default does.
        # NOTE(review): confirm whether ResnetImageDataset concatenates img_dir with file
        # names and therefore relies on the trailing separator.
        "img_dir": os.environ.get(
            "RESNET_IMG_DIR",
            "C:/Users/Leonard/Desktop/DP/early_MLP_implementations/data/resnet18_set/images_unpacked/",
        ).rstrip("/\\") + "/",
        # Cap on how many images are loaded (passed to the dataset as img_num_cap).
        "img_num": 1000,
    },
    "dataloader": {
        "batch_size": 16,
        # Deterministic batch order, so every evaluation sees the same batches.
        "shuffle": False,
    },
}

In [3]:
def test_accuracy(model, dataloader):
    model.eval()
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")

    correct_total = 0
    sample_count = 0

    dataloader = tqdm(dataloader, total=len(dataloader))
    
    for images, labels in dataloader:
        images, labels = images.to("cuda" if torch.cuda.is_available() else "cpu"), labels.to("cuda" if torch.cuda.is_available() else "cpu")
        pred = model.forward(images)
        
        probabilities = torch.nn.functional.softmax(pred, dim=1)
        topk_prob, topk_catid = torch.topk(probabilities, 1)

        topk_catid = torch.squeeze(topk_catid, dim=1).type(torch.float64)
        correct = (labels == topk_catid).sum().item()
        correct_total += correct

        sample_count += len(images)
    return correct_total/sample_count


In [4]:
# Standard ImageNet per-channel mean / std used to normalise inputs for the
# pretrained torchvision ResNet weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every dataset sample.
# NOTE(review): ToPILImage implies the dataset yields a tensor/ndarray image; confirm.
transform_pipe = transforms.Compose(
    [
        transforms.ToPILImage(),
        # Fixed square resize; the aspect ratio is not preserved (no center crop).
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

In [5]:
# Evaluation dataset (capped at CONFIG["dataset"]["img_num"] images) and its loader.
test_dataset = ResnetImageDataset(
    img_dir=CONFIG["dataset"]["img_dir"],
    img_num_cap=CONFIG["dataset"]["img_num"],
    transform=transform_pipe,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=CONFIG["dataloader"]["batch_size"],
    shuffle=CONFIG["dataloader"]["shuffle"],
)

# List the directory only once: it holds tens of thousands of files, so a second
# os.listdir just for the second message was wasted I/O.
n_files_in_dir = len(os.listdir(CONFIG["dataset"]["img_dir"]))
print(f"The directory contains {n_files_in_dir} image files")
print(f"We are using {CONFIG['dataset']['img_num']}/{n_files_in_dir} images")

The dataset contains 643 unique classes
The directory contains 50000 image files
We are using 1000/50000 images


In [6]:
# Results CSV for the shuffling sweep. The handle `f` stays open because later cells
# append rows through `writer` and close it with `f.close()`.
results_path = "../data/accuracy_scores_2.csv"
os.makedirs(os.path.dirname(results_path), exist_ok=True)
# newline="" is required by the csv module: otherwise text mode on Windows turns the
# writer's "\r\n" row terminator into "\r\r\n", leaving blank lines between rows.
f = open(results_path, "w", newline="", encoding="utf-8")
writer = csv.writer(f)

header = ["model", "layer", "num_samples", "num_shuffled", "accuracy"]
# Trailing ";" hides writerow's return value (characters written) from the cell output.
writer.writerow(header);



47

In [7]:
# Pretrained ResNet-50. Both shuffling wrappers below are constructed around this
# same instance.
# NOTE(review): if either wrapper modifies `res50` in place (its parameters or layers),
# the two experiments share state; consider giving each wrapper its own copy.
# Not verifiable from this notebook; check resnet_shuffled.
res50 = resnet50(weights=ResNet50_Weights.DEFAULT)

# Wrapper used by the fc-layer shuffle sweep in the next cell.
shuffled_res = ShuffledResnetVariable(model=res50)

# Conv2d-shuffling wrapper, set to 0 shuffled weights as a sanity check.
shuffled_res_conv1 = ShuffledResnetVariableConv2d(model=res50)
shuffled_res_conv1.change_n_shuffled(0)

# NOTE(review): with 0 shuffled weights this should presumably match the unshuffled
# baseline (the fc sweep reports 0.798 at num_shuffled=0), but the recorded output is
# 0.003, close to chance level. Suspect ShuffledResnetVariableConv2d; investigate
# before trusting any conv-shuffling results.
acc = test_accuracy(model=shuffled_res_conv1, dataloader=test_dataloader)
print(acc)

  indices = torch.argsort(torch.rand_like(self.weight.T), dim=-1)
100%|██████████| 63/63 [00:50<00:00,  1.25it/s]

0.003





In [8]:
# Sweep how many fc-layer weights are shuffled, from 0 up to all of them, and record
# top-1 accuracy for each setting as one CSV row.
FC_IN_FEATURES = 2048  # in_features of ResNet-50's fc layer (upper end of the sweep)
SHUFFLE_STEP = 32

# The shuffling appears to be random (the setup cell emitted a warning from a
# torch.rand_like-based argsort), so seed torch to make the sweep reproducible.
torch.manual_seed(0)

for n_shuffled in range(0, FC_IN_FEATURES + 1, SHUFFLE_STEP):
    shuffled_res.change_n_shuffled(n_shuffled)
    accuracy = test_accuracy(model=shuffled_res, dataloader=test_dataloader)

    row = ["resnet50", "fc", CONFIG["dataset"]["img_num"], n_shuffled, accuracy]
    writer.writerow(row)
    # Flush each row so completed results survive if this long sweep is interrupted.
    f.flush()
    print(row)


100%|██████████| 63/63 [00:34<00:00,  1.80it/s]


['resnet50', 'fc', 1000, 0, 0.798]


100%|██████████| 63/63 [00:34<00:00,  1.81it/s]


['resnet50', 'fc', 1000, 32, 0.789]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 64, 0.779]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 96, 0.773]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 128, 0.768]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 160, 0.761]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 192, 0.748]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 224, 0.723]


100%|██████████| 63/63 [00:32<00:00,  1.96it/s]


['resnet50', 'fc', 1000, 256, 0.709]


100%|██████████| 63/63 [00:32<00:00,  1.96it/s]


['resnet50', 'fc', 1000, 288, 0.72]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 320, 0.692]


100%|██████████| 63/63 [00:31<00:00,  2.00it/s]


['resnet50', 'fc', 1000, 352, 0.692]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 384, 0.677]


100%|██████████| 63/63 [00:33<00:00,  1.89it/s]


['resnet50', 'fc', 1000, 416, 0.665]


100%|██████████| 63/63 [00:36<00:00,  1.74it/s]


['resnet50', 'fc', 1000, 448, 0.665]


100%|██████████| 63/63 [00:32<00:00,  1.93it/s]


['resnet50', 'fc', 1000, 480, 0.659]


100%|██████████| 63/63 [00:32<00:00,  1.96it/s]


['resnet50', 'fc', 1000, 512, 0.61]


100%|██████████| 63/63 [00:31<00:00,  2.02it/s]


['resnet50', 'fc', 1000, 544, 0.615]


100%|██████████| 63/63 [00:31<00:00,  1.97it/s]


['resnet50', 'fc', 1000, 576, 0.615]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 608, 0.601]


100%|██████████| 63/63 [00:31<00:00,  2.00it/s]


['resnet50', 'fc', 1000, 640, 0.57]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 672, 0.574]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 704, 0.566]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 736, 0.541]


100%|██████████| 63/63 [00:31<00:00,  2.00it/s]


['resnet50', 'fc', 1000, 768, 0.532]


100%|██████████| 63/63 [00:31<00:00,  2.01it/s]


['resnet50', 'fc', 1000, 800, 0.537]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 832, 0.503]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 864, 0.499]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 896, 0.468]


100%|██████████| 63/63 [00:31<00:00,  2.01it/s]


['resnet50', 'fc', 1000, 928, 0.456]


100%|██████████| 63/63 [00:31<00:00,  2.01it/s]


['resnet50', 'fc', 1000, 960, 0.434]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 992, 0.422]


100%|██████████| 63/63 [00:31<00:00,  2.02it/s]


['resnet50', 'fc', 1000, 1024, 0.398]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 1056, 0.397]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 1088, 0.385]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 1120, 0.373]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 1152, 0.382]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 1184, 0.352]


100%|██████████| 63/63 [00:31<00:00,  2.01it/s]


['resnet50', 'fc', 1000, 1216, 0.35]


100%|██████████| 63/63 [00:32<00:00,  1.97it/s]


['resnet50', 'fc', 1000, 1248, 0.29]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 1280, 0.297]


100%|██████████| 63/63 [00:31<00:00,  2.00it/s]


['resnet50', 'fc', 1000, 1312, 0.271]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 1344, 0.248]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 1376, 0.25]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 1408, 0.231]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 1440, 0.22]


100%|██████████| 63/63 [00:32<00:00,  1.97it/s]


['resnet50', 'fc', 1000, 1472, 0.21]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 1504, 0.217]


100%|██████████| 63/63 [00:31<00:00,  2.01it/s]


['resnet50', 'fc', 1000, 1536, 0.223]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 1568, 0.189]


100%|██████████| 63/63 [00:31<00:00,  2.01it/s]


['resnet50', 'fc', 1000, 1600, 0.166]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 1632, 0.138]


100%|██████████| 63/63 [00:31<00:00,  2.00it/s]


['resnet50', 'fc', 1000, 1664, 0.149]


100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


['resnet50', 'fc', 1000, 1696, 0.121]


100%|██████████| 63/63 [00:31<00:00,  2.00it/s]


['resnet50', 'fc', 1000, 1728, 0.129]


100%|██████████| 63/63 [00:31<00:00,  2.00it/s]


['resnet50', 'fc', 1000, 1760, 0.084]


100%|██████████| 63/63 [00:31<00:00,  2.00it/s]


['resnet50', 'fc', 1000, 1792, 0.087]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 1824, 0.07]


100%|██████████| 63/63 [00:31<00:00,  2.00it/s]


['resnet50', 'fc', 1000, 1856, 0.06]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 1888, 0.064]


100%|██████████| 63/63 [00:31<00:00,  2.00it/s]


['resnet50', 'fc', 1000, 1920, 0.045]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]


['resnet50', 'fc', 1000, 1952, 0.031]


100%|██████████| 63/63 [00:31<00:00,  2.01it/s]


['resnet50', 'fc', 1000, 1984, 0.025]


100%|██████████| 63/63 [00:31<00:00,  2.02it/s]


['resnet50', 'fc', 1000, 2016, 0.014]


100%|██████████| 63/63 [00:31<00:00,  1.99it/s]

['resnet50', 'fc', 1000, 2048, 0.002]





In [9]:
# Close the results CSV so any buffered rows reach disk. Closing an already-closed
# file is a no-op, so the guard only makes re-running this cell explicit.
if not f.closed:
    f.close()