In [1]:
import argparse
import sys
import os
# Parent folder imports
current_dir = os.path.abspath(os.getcwd())
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
import numpy as np
import torch
from collections import OrderedDict
from modules_sym import PartEqMod
from PIL import Image
import pytorch_lightning as pl
from torch.utils.data import Dataset
from torchvision.transforms import Resize, ToTensor


# Configuration
def _str2bool(value):
    """Convert a command-line string such as 'true'/'false' or '1'/'0' to a bool.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the raw
    string, so every non-empty value -- including ``"False"`` -- becomes True.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept as False because bool("") was False before.
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0)  # -1 draws a random seed
parser.add_argument("--dataloader_batch_sz", type=int, default=100)

# Dataset
parser.add_argument("--dataset_root", type=str,
                    default="../datasets")
parser.add_argument("--dataset", type=str,
                    default="PartMNIST")
parser.add_argument("--customdata_train_path", type=str,
                    default="../datasets/mnist60/invariant_dataset_train.pkl")
parser.add_argument("--customdata_test_path", type=str,
                    default="../datasets/mnist60/invariant_dataset_test.pkl")


# Net params
parser.add_argument("--discrete_groups", default=False, type=_str2bool)  # e.g. --discrete_groups True
parser.add_argument("--n_cyclic_groups", default=8, type=int)  # Number of cyclic groups -- presumably C1..C8; TODO confirm
parser.add_argument("--in_channels", default=1, type=int)  # Number of input image channels
parser.add_argument("--hidden_dim", default=128, type=int)  # Size of the networks in Inv AE
parser.add_argument("--emb_dim", default=32, type=int)  # Dimension of latent spaces
parser.add_argument("--hidden_dim_theta", default=64, type=int)  # Size of theta network
parser.add_argument("--emb_dim_theta", default=100, type=int)  # Size of embedding space in theta network
parser.add_argument("--use_one_layer", action='store_true', default=False)
parser.add_argument("--pretrained_path", type=str, default="./")  # Pretrained Model Path

# parse_known_args ignores arguments it does not recognise (e.g. the ones the
# Jupyter kernel is launched with) instead of aborting.
config, _ = parser.parse_known_args()

# Seed the run: a seed of -1 is replaced by a randomly drawn one, any other
# value is used as given. The chosen seed is stored back on the config.
config.seed = np.random.randint(low=0, high=100000) if config.seed == -1 else config.seed
pl.seed_everything(config.seed)

[rank: 0] Global seed set to 0


0

In [2]:
# Define the Rotated MNIST dataset for out-of-distribution detection
def RotMNIST_OOD_Dataloader(config, train=False, test=True, custom_batchsize=0, shuffle=True,
                            equiv_dict=""):
    """Build a DataLoader of randomly rotated MNIST digits with OOD labels.

    Args:
        config: object providing ``customdata_train_path``, ``customdata_test_path``
            and ``dataloader_batch_sz``.
        train: load the file at ``config.customdata_train_path`` (takes precedence).
        test: load the file at ``config.customdata_test_path`` when ``train`` is falsy.
        custom_batchsize: batch size to use; a falsy value falls back to
            ``config.dataloader_batch_sz``.
        shuffle: forwarded to the DataLoader.
        equiv_dict: mapping from class label (int) to the rotation bound, in
            degrees, that counts as in-distribution for that class.

    Returns:
        A single-element list holding the DataLoader. Each item is
        ``(image, label, is_out_of_distrib)``.

    Raises:
        ValueError: if neither ``train`` nor ``test`` is set.
    """
    if not (train or test):
        # Previously this surfaced later as an opaque AttributeError on `self.data`.
        raise ValueError("RotMNIST_OOD_Dataloader needs train=True or test=True to pick a data file.")
    print("Loading out-of-distribution MNIST Dataset for train:",train,", and for test:",test)

    class MNISTRotationDataset(Dataset):
        def __init__(self, train=train, test=test, equiv_dict=equiv_dict):
            self.train = train
            self.test = test
            data_path = config.customdata_train_path if self.train else config.customdata_test_path
            # ndmin=2 keeps the (rows, columns) layout even for a single-row file.
            self.data = np.loadtxt(data_path, ndmin=2)
            self.num_samples = len(self.data)
            # All columns but the last hold the 28x28 pixels; the last is the class label.
            self.x = self.data[:, :-1].reshape(len(self.data), 28, 28)

            # Transforms
            self.resize28 = Resize(28)
            self.toTensor = ToTensor()

            self.y = self.data[:, -1]
            self.true_thetas_dict = equiv_dict

        def __len__(self):
            return self.num_samples

        def __getitem__(self, index):
            x = self.x[index]
            y = int(self.y[index])

            # Random rotation angle (degrees), drawn anew on every access
            rotation = np.random.uniform(-180, 180)
            # Rotate the image using PIL (the array is converted as-is, no rescaling is applied)
            imgRot = Image.fromarray(x)
            imgRot = self.toTensor(self.resize28(imgRot.rotate(rotation, Image.BILINEAR)))

            # Make sure the image has shape (channel, height, width) = (1, 28, 28)
            imgRot = imgRot.reshape(1, 28, 28)

            # Define the out-of-distribution label: 0 when the rotation lies within
            # [-true_theta, true_theta] for this class, 1 otherwise.
            # NOTE(review): the bound is always read as a symmetric angle range; for
            # discrete-group experiments the caller passes group indices here --
            # confirm that this is the intended ground truth.
            true_theta = self.true_thetas_dict[y]
            is_out_of_distrib = torch.tensor(0 if -true_theta <= rotation <= true_theta else 1, dtype=torch.float)

            y = torch.from_numpy(np.array(self.y[index])).float()
            return imgRot, y, is_out_of_distrib

    dataset = MNISTRotationDataset(train=train, test=test, equiv_dict=equiv_dict)
    batch_size_value = int(custom_batchsize) if custom_batchsize else config.dataloader_batch_sz
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size_value,
                                             shuffle=shuffle,
                                             num_workers=0,
                                             drop_last=False)
    # Wrapped in a list: callers index the result with [0].
    return [dataloader]

In [3]:
# For discrete groups: rotation angles (in degrees) of the elements of the cyclic
# groups C1..C8, i.e. multiples of 360/n (the C7 entries are rounded to 0.1 degree).
rotation_angles_dict = {
    'C1': [0],
    'C2': [0, 180],
    'C3': [0, 120, -120],
    'C4': [0, 90, 180, -90],
    'C5': [0, 72, 144, -72, -144],
    'C6': [0, 60, 120, 180, -60, -120],
    'C7': [0, 51.4, 102.8, 154.2, -51.4, -102.8, -154.2],
    'C8': [0, 45, 90, 135, 180, -45, -90, -135],
}
# Use the GPU when there is one, but fall back to the CPU so that merely
# defining this lookup table does not crash on machines without CUDA.
_rotation_angles_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rotation_angles_tensors = {
    group: torch.tensor(angles, device=_rotation_angles_device)
    for group, angles in rotation_angles_dict.items()
}


# Settings shared by every MNIST experiment; applied before the per-experiment overrides.
_MNIST_OOD_COMMON = {
    # MNIST test file used to create the Out-of-Distribution rotated MNIST dataset
    "customdata_test_path": "../datasets/mnist_test.amat",
    "in_channels": 1,
    "emb_dim_theta": 128,
    "hidden_dim_theta": 64,
}

# Per experiment: (config overrides for the pre-trained model, per-class level of symmetry).
_MNIST_OOD_EXPERIMENTS = {
    "ROTMNIST60": (
        {"pretrained_path": "../models/mnist60/best_model_theta.pt",
         "hidden_dim": 64, "emb_dim": 200},
        {0: 60., 1: 60., 2: 60., 3: 60., 4: 60.,
         5: 60., 6: 60., 7: 60., 8: 60., 9: 60.},
    ),
    "ROTMNIST60-90": (
        {"pretrained_path": "../models/mnist6090/best_model_theta.pt",
         "hidden_dim": 64, "emb_dim": 200},
        {0: 60., 1: 60., 2: 60., 3: 60., 4: 60.,
         5: 90., 6: 90., 7: 90., 8: 90., 9: 90.},
    ),
    "MNISTMULTIPLE": (
        {"pretrained_path": "../models/mnistmultiple/best_model_theta.pt",
         "hidden_dim": 64, "emb_dim": 200},
        {0: 0, 1: 18, 2: 36, 3: 54, 4: 72,
         5: 90, 6: 108, 7: 126, 8: 144, 9: 162},
    ),
    "MNISTMULTIPLE_GAUSSIAN": (
        {"pretrained_path": "../models/mnistgaussian/best_model_theta.pt",
         "hidden_dim": 64, "emb_dim": 200},
        # Per-class standard deviations, used directly as the symmetry bound.
        {0: 0, 1: 9, 2: 18, 3: 27, 4: 36,
         5: 45, 6: 54, 7: 63, 8: 72, 9: 81},
    ),
    "MNISTC2C4": (
        {"pretrained_path": "../models/mnistc2c4/best_model_theta.pt",
         "discrete_groups": True,
         # Net
         "hidden_dim": 164, "emb_dim": 200,
         "hidden_dim_theta": 32, "emb_dim_theta": 100},
        # Group indices (index k stands for C{k+1}): classes 0-4 -> C2, classes 5-9 -> C4.
        {0: 1, 1: 1, 2: 1, 3: 1, 4: 1,
         5: 3, 6: 3, 7: 3, 8: 3, 9: 3},
    ),
}


def _configure_ood_experiment(config, EXPERIMENT):
    """Write the settings of a known experiment onto `config` and return its per-class symmetry dict."""
    if EXPERIMENT not in _MNIST_OOD_EXPERIMENTS:
        raise ValueError(
            f"Unknown experiment {EXPERIMENT!r}; expected one of {sorted(_MNIST_OOD_EXPERIMENTS)}."
        )
    overrides, true_thetas_dict = _MNIST_OOD_EXPERIMENTS[EXPERIMENT]
    for key, value in {**_MNIST_OOD_COMMON, **overrides}.items():
        setattr(config, key, value)
    return true_thetas_dict


def _min_circular_distance(angle, group_angles):
    """Smallest absolute distance in degrees, wrap-around aware, from `angle` to any of `group_angles`."""
    # Map every difference into [-180, 180) so that e.g. -179 and 180 are 1 degree apart.
    wrapped = torch.remainder(angle - group_angles + 180.0, 360.0) - 180.0
    return wrapped.abs().min()


def ood_test(config, EXPERIMENT, error_margin=5):
    """Evaluate the out-of-distribution symmetry detector of a pre-trained model.

    No training is needed: the OOD classifier is built from the pre-trained
    SSL-Sym model. A test sample is predicted out-of-distribution when
      * continuous case: |predicted rotation| exceeds the predicted symmetry level;
      * discrete case: the predicted rotation is more than `error_margin` degrees
        away from every angle of the predicted cyclic group.

    Args:
        config: hyper-parameter namespace. It is shallow-copied before the
            experiment settings are written, so the caller's object is left
            untouched and settings cannot leak from one experiment to the next.
        EXPERIMENT: one of "ROTMNIST60", "ROTMNIST60-90", "MNISTMULTIPLE",
            "MNISTMULTIPLE_GAUSSIAN", "MNISTC2C4".
        error_margin: tolerance in degrees used in the discrete-group case.

    Returns:
        Accuracy of the OOD prediction in percent (NaN if the dataset is empty).

    Raises:
        ValueError: if `EXPERIMENT` is not a known experiment.
    """
    import copy  # local import: only needed to protect the caller's config

    config = copy.copy(config)
    true_thetas_dict = _configure_ood_experiment(config, EXPERIMENT)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load SSL-SYM model
    net = PartEqMod(hparams=config)
    state_dict = torch.load(config.pretrained_path, map_location=device)
    # Drop the "model." part of the checkpoint keys so they match the bare network.
    new_state_dict = OrderedDict((k.replace("model.", ""), v) for k, v in state_dict.items())

    print(f"Loading pre-trained model for {EXPERIMENT}.")
    load_result = net.load_state_dict(new_state_dict, strict=False)
    # strict=False ignores mismatching keys silently; report them so that a
    # partially initialised network does not go unnoticed.
    missing_keys = list(getattr(load_result, "missing_keys", []))
    unexpected_keys = list(getattr(load_result, "unexpected_keys", []))
    if missing_keys or unexpected_keys:
        print(f"Warning: checkpoint mismatch for {EXPERIMENT}: "
              f"{len(missing_keys)} missing key(s), {len(unexpected_keys)} unexpected key(s).")
    net.to(device)
    net.eval()

    # Load test mnist dataset with OOD labels
    test_dataloader = RotMNIST_OOD_Dataloader(config, equiv_dict=true_thetas_dict,
                                              train=False, test=True, shuffle=True)[0]
    correct_predictions = 0
    total_predictions = 0
    for x, _, ood_label in test_dataloader:
        x = x.to(device)
        ood_label = ood_label.to(device)

        with torch.no_grad():
            # Encoder pass
            emb, v = net.encoder(x)
            rot = net.get_rotation_matrix(v)
            # assumes one predicted angle (degrees) per sample after squeeze -- TODO confirm
            degrees_rot = net.get_degrees(rot).squeeze()

            # Predict levels of symmetry
            degrees_theta = net.theta_function(x).squeeze()
            if config.discrete_groups:
                # Index of the most likely group; index k stands for C{k+1}.
                degrees_theta = torch.argmax(degrees_theta, dim=1)

        # Out-of-distribution symmetry detector
        if config.discrete_groups:
            min_diffs = torch.zeros(degrees_rot.size(0), device=degrees_rot.device)
            for i in range(degrees_rot.size(0)):
                group_key = f'C{degrees_theta[i].item() + 1}'  # Construct the group key (e.g., 'C1', 'C2', etc.)
                group_angles = rotation_angles_tensors[group_key].to(degrees_rot.device)
                min_diffs[i] = _min_circular_distance(degrees_rot[i], group_angles)

            # Determine out-of-distribution samples (outside the error margin)
            is_out_of_distribution = (min_diffs > error_margin).float()
        else:
            is_out_of_distribution = (degrees_rot.abs() > degrees_theta).float()

        # Update counters
        correct_predictions += (is_out_of_distribution == ood_label.float()).sum().item()
        total_predictions += ood_label.size(0)

    # Compute accuracy (guarding against an empty dataset)
    accuracy = (correct_predictions / total_predictions) * 100 if total_predictions else float("nan")
    print(f"Accuracy of Out-of-Distribution Symmetries classifier in {EXPERIMENT}: {accuracy:.4f}\n")
    return accuracy

In [4]:
# Run OOD-sym prediction experiments
import pandas as pd
results = {}

for EXPERIMENT in ["ROTMNIST60", "ROTMNIST60-90", "MNISTMULTIPLE", "MNISTMULTIPLE_GAUSSIAN", "MNISTC2C4"]:
    print(f"Out-of-Distribution Symmetry Prediction. Experiment: {EXPERIMENT}")
    acc = ood_test(config, EXPERIMENT)
    results[EXPERIMENT] = acc

# One row per experiment, a single "Accuracy" column (percent).
df_results = pd.DataFrame.from_dict(results, orient="index", columns=["Accuracy"])
print("Out-of-distribution symmetry detection")
print(df_results)

# Save the table: first to ./plots (created if missing), then to the fallback
# location in the home directory. Only file-system errors are caught, and a
# failure is reported rather than silently ignored.
try:
    os.makedirs("plots", exist_ok=True)
    df_results.to_csv("plots/ood_results.csv")
except OSError as primary_error:
    home_directory = os.path.expanduser('~')
    file_path = os.path.join(home_directory, "Projects/alonso_syms/ood_results.csv")
    try:
        df_results.to_csv(file_path)
        print(f"File saved to {file_path}")
    except OSError as fallback_error:
        print(f"Could not save OOD results: {primary_error}; fallback also failed: {fallback_error}")

Out-of-Distribution Symmetry Prediction. Experiment: ROTMNIST60


  full_mask[mask] = norms.to(torch.uint8)
  full_mask[mask] = norms.to(torch.uint8)


Loading pre-trained model for ROTMNIST60.
Loading out-of-distribution MNIST Dataset for train: False , and for test: True
Accuracy of Out-of-Distribution Symmetries classifier in ROTMNIST60: 92.3680

Out-of-Distribution Symmetry Prediction. Experiment: ROTMNIST60-90
Loading pre-trained model for ROTMNIST60-90.
Loading out-of-distribution MNIST Dataset for train: False , and for test: True
Accuracy of Out-of-Distribution Symmetries classifier in ROTMNIST60-90: 90.9300

Out-of-Distribution Symmetry Prediction. Experiment: MNISTMULTIPLE
Loading pre-trained model for MNISTMULTIPLE.
Loading out-of-distribution MNIST Dataset for train: False , and for test: True
Accuracy of Out-of-Distribution Symmetries classifier in MNISTMULTIPLE: 89.2160

Out-of-Distribution Symmetry Prediction. Experiment: MNISTMULTIPLE_GAUSSIAN
Loading pre-trained model for MNISTMULTIPLE_GAUSSIAN.
Loading out-of-distribution MNIST Dataset for train: False , and for test: True
Accuracy of Out-of-Distribution Symmetries c