In [1]:
import torch
import torchvision.transforms as tvtf
import sys
sys.path.append('..')
from datasets import PCam
import timm
import torch.nn as nn
from tqdm import tqdm
from post_hoc_equivariant import *

In [2]:
# Identity normalisation: mean 0 / std 1 leaves the ToTensor() output unchanged.
# NOTE(review): confirm this matches the preprocessing the pretrained checkpoint
# expects; substitute the dataset statistics here if it does not.
data_mean = [0., 0., 0.]
data_stddev = [1., 1., 1.]

# Both evaluation loaders share one batch size; a batch size of 1 would make
# every evaluation loop take one step per image.
EVAL_BATCH_SIZE = 128

transform_test = tvtf.Compose([
    # An int size scales the shorter image edge to 96 px (aspect ratio is kept).
    tvtf.Resize(96, interpolation=tvtf.InterpolationMode.BICUBIC),
    tvtf.CenterCrop(96),  # Crop the central 96x96 window
    tvtf.ToTensor(),  # Convert the image to a PyTorch tensor
    tvtf.Normalize(mean=data_mean, std=data_stddev),  # Normalize the tensor
])

# data_fraction keeps only a small subset of each split for quick evaluation.
validation_set = PCam(root="../data", train=False, valid=True, download=True, transform=transform_test, data_fraction=0.005)
test_set = PCam(root="../data", train=False, download=True, transform=transform_test, data_fraction=0.01)

# shuffle=False keeps the sample order fixed between runs.
val_loader = torch.utils.data.DataLoader(
    validation_set,
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

Total length of the dataset: 32768
Reduced length of the dataset: 163
Total length of the dataset: 32768
Reduced length of the dataset: 327


In [33]:
class PretrainedResnet50(nn.Module):
    """Pretrained timm ResNet-50 (PCam checkpoint) split into trunk and head.

    The child modules of the loaded timm model are partitioned into two
    ``nn.Sequential`` stacks: every child except the last one, and the last
    child on its own (presumably the final linear classifier, as the attribute
    names suggest).
    """

    def __init__(self):
        super().__init__()
        # https://huggingface.co/1aurent/resnet50.tiatoolbox-pcam/
        backbone = timm.create_model(
            model_name="hf-hub:1aurent/resnet50.tiatoolbox-pcam",
            pretrained=True,
        )
        # Materialise the children once, then split off the final module.
        children = list(backbone.children())
        trunk, head = children[:-1], children[-1]
        self.layers_before_last_linear = nn.Sequential(*trunk)
        self.mlp_head = nn.Sequential(head)

    def forward(self, x, output_cls=False):
        """Run the network on ``x``.

        Args:
            x: Input batch passed to the trunk.
            output_cls: When true, return the trunk output instead of applying
                the head.

        Returns:
            The trunk output if ``output_cls`` is true, otherwise the head's
            output computed from it.
        """
        features = self.layers_before_last_linear(x)
        return features if output_cls else self.mlp_head(features)

In [43]:
# Instantiate the baseline classifier. The constructor asks timm for the
# pretrained hf-hub checkpoint, so this may need network access or a warm cache.
model = PretrainedResnet50()

In [44]:
# Baseline: top-1 accuracy (%) of the unwrapped model over test_loader.
model.eval()
correct = 0
total = 0
with torch.no_grad():  # inference only, so skip autograd bookkeeping
    for inputs, labels in tqdm(test_loader):
        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)  # index of the largest logit per sample
        correct += predicted.eq(labels).sum().item()
        total += labels.shape[0]
test_acc = 100 * correct / total
print(test_acc)

100%|██████████| 3/3 [00:10<00:00,  3.51s/it]

83.79204892966361





In [48]:
# Mean pooling: wrap the baseline in PostHocEquivariantMean with 4 rotations
# and flips enabled. NOTE(review): presumably averages the model's outputs over
# the transformed copies of each input - confirm in post_hoc_equivariant.
model.eval()  # eval() returns the module itself, so rebinding is unnecessary
eq_model_mean = PostHocEquivariantMean(model, flips=True, n_rotations=4)

In [49]:
# Top-1 accuracy (%) of the mean-pooling wrapper over test_loader.
eq_model_mean.eval()
correct = 0
total = 0
with torch.no_grad():  # inference only, so skip autograd bookkeeping
    for inputs, labels in tqdm(test_loader):
        outputs = eq_model_mean(inputs)
        predicted = outputs.argmax(dim=1)  # index of the largest output per sample
        correct += predicted.eq(labels).sum().item()
        total += labels.shape[0]
test_acc = 100 * correct / total
print(test_acc)

100%|██████████| 3/3 [01:45<00:00, 35.04s/it]

85.62691131498471





In [38]:
# Max pooling: wrap the baseline in PostHocEquivariantMax with 4 rotations
# and flips enabled. NOTE(review): presumably takes a maximum over the model's
# outputs for the transformed copies of each input - confirm in
# post_hoc_equivariant.
model.eval()  # eval() returns the module itself, so rebinding is unnecessary
eq_model_max = PostHocEquivariantMax(model, flips=True, n_rotations=4)

In [39]:
# Top-1 accuracy (%) of the max-pooling wrapper over test_loader.
eq_model_max.eval()
correct = 0
total = 0
with torch.no_grad():  # inference only, so skip autograd bookkeeping
    for inputs, labels in tqdm(test_loader):
        outputs = eq_model_max(inputs)
        predicted = outputs.argmax(dim=1)  # index of the largest output per sample
        correct += predicted.eq(labels).sum().item()
        total += labels.shape[0]
test_acc = 100 * correct / total
print(test_acc)


100%|██████████| 3/3 [01:40<00:00, 33.37s/it]

85.62691131498471



