In [1]:
import argparse
import numpy as np
import torch
import sys
import os
# Parent folder imports
current_dir = os.path.abspath(os.getcwd())
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
from data_loading_sym import PartialMNIST_AE_Dataloader, RotMNIST_AE_Dataloader
from torchvision import models
import pytorch_lightning as pl

# Configuration
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0)
parser.add_argument("--dataloader_batch_sz", type=int, default=256)

# Dataset
parser.add_argument("--dataset_root", type=str,
                    default="../datasets")
parser.add_argument("--dataset", type=str,
                    default="PartMNIST")
parser.add_argument("--customdata_train_path", type=str,
                    default="../datasets/mnist60/invariant_dataset_train.pkl")
parser.add_argument("--customdata_test_path", type=str,
                    default="../datasets/mnist60/invariant_dataset_test.pkl")


# Net params
parser.add_argument("--discrete_groups", default=False, type=bool)
parser.add_argument("--in_channels", default=1, type=int)  # Size of the networks in Inv AE
parser.add_argument("--hidden_dim", default=128, type=int)  # Size of the networks in Inv AE
parser.add_argument("--emb_dim", default=32, type=int)  # Dimension of latent spaces
parser.add_argument("--hidden_dim_theta", default=64, type=int)  # Size of theta network
parser.add_argument("--emb_dim_theta", default=100, type=int)  # Size of embedding space in theta network
parser.add_argument("--use_one_layer", action='store_true', default=False)
parser.add_argument("--pretrained_path", type=str, default="./")  # Pretrained Model Path

config, _ = parser.parse_known_args()

In [2]:
# Set seed
if config.seed == -1:
    config.seed = np.random.randint(0, 100000)
pl.seed_everything(config.seed)

[rank: 0] Global seed set to 0


0

In [3]:
def train(config, sym_std):
    if sym_std:
        print("Symmetry Standardization")
        # Change config to load datasets with sym std
        if "MNIST" in EXPERIMENT:
            if EXPERIMENT == "ROTMNIST60":
                config.customdata_train_path = "../models/mnist60/invariant_dataset_train.pkl"
                config.customdata_test_path = "../models/mnist60/invariant_dataset_test.pkl"
            if EXPERIMENT == "ROTMNIST60-90":
                config.customdata_train_path = "../models/mnist6090/invariant_dataset_train.pkl"
                config.customdata_test_path = "../models/mnist6090/invariant_dataset_test.pkl"
            if EXPERIMENT == "ROTMNIST":
                config.customdata_train_path = "../models/mnistrot/invariant_dataset_train.pkl"
                config.customdata_test_path = "../models/mnistrot/invariant_dataset_test.pkl"
            if EXPERIMENT == "MNISTMULTIPLE":
                config.customdata_train_path = "../models/mnistmultiple/invariant_dataset_train.pkl"
                config.customdata_test_path = "../models/mnistmultiple/invariant_dataset_test.pkl"
            if EXPERIMENT == "MNISTMULTIPLE_GAUSSIAN":
                config.customdata_train_path = "../models/mnistgaussian/invariant_dataset_train.pkl"
                config.customdata_test_path = "../models/mnistgaussian/invariant_dataset_test.pkl"
            if EXPERIMENT == "MNISTC2C4":
                config.customdata_train_path = "../models/mnistc2c4/invariant_dataset_train.pkl"
                config.customdata_test_path = "../models/mnistc2c4/invariant_dataset_test.pkl"
    else:
        print("NO Symmetry Standardization")
        if "MNIST" in EXPERIMENT:
            if EXPERIMENT == "ROTMNIST60":
                config.customdata_train_path = "../datasets/mnist60_train.pkl"
                config.customdata_test_path = "../datasets/mnist60_test.pkl"
            if EXPERIMENT == "ROTMNIST60-90":
                config.customdata_train_path = "../datasets/mnist60_90_train.pkl"
                config.customdata_test_path = "../datasets/mnist60_90_test.pkl"
            if EXPERIMENT == "MNISTMULTIPLE":
                config.customdata_train_path = "../datasets/mnist_multiple_train.pkl"
                config.customdata_test_path = "../datasets/mnist_multiple_test.pkl"
            if EXPERIMENT == "ROTMNIST":
                config.customdata_train_path = "../datasets/mnist_all_rotation_normalized_float_train_valid.amat"
                config.customdata_test_path = "../datasets/mnist_all_rotation_normalized_float_test.amat"
            if EXPERIMENT == "MNISTMULTIPLE_GAUSSIAN":
                config.customdata_train_path = "../datasets/mnist_multiple_gaussian_train.pkl"
                config.customdata_test_path = "../datasets/mnist_multiple_gaussian_test.pkl"
            if EXPERIMENT == "MNISTC2C4":
                config.customdata_train_path = "../datasets/mnist_c2c4_train.pkl"
                config.customdata_test_path = "../datasets/mnist_c2c4_test.pkl"

    # Train data loading
    if ".pkl" in config.customdata_train_path and "MNIST" in EXPERIMENT:
        main_dataloader = PartialMNIST_AE_Dataloader(config, train=True, test=False, shuffle=True)
        train_dataloader = main_dataloader[0]
        val_dataloader = main_dataloader[1]
        num_classes = 10
    if ".amat" in config.customdata_train_path and "MNIST" in EXPERIMENT:
        main_dataloader = RotMNIST_AE_Dataloader(config, train=True, test=False, shuffle=True)
        train_dataloader = main_dataloader[0]
        val_dataloader = main_dataloader[1]
        num_classes = 10

    # Supervised baseline: ResNet-18

    print("Loading ResNet model")
    model = models.resnet18(weights=None)
    model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)  # grayscale input
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model = model.cuda()
        
    # Training
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    best_val_loss = float('inf')
    best_model = None
    for epoch in range(100):
        model.train()
        for x, label in train_dataloader:
            x = x.cuda()

            label = label.long().cuda()

            optimizer.zero_grad()

            # Forward
            outputs = model(x)

            # Backprop
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

        model.eval()
        total_loss = 0
        with torch.no_grad():
            for x, label in val_dataloader:
                x = x.cuda()
                label = label.long().cuda()

                outputs = model(x)
                loss = criterion(outputs, label)
                total_loss += loss.item()

        avg_val_loss = total_loss / len(val_dataloader)
        if (epoch+1) % 10 == 0:
            print(f"Epoch {epoch+1} validation loss: {avg_val_loss}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model = model.state_dict()

    # Load the best model
    model.load_state_dict(best_model)

    # Test data loading
    if ".pkl" in config.customdata_train_path and "MNIST" in EXPERIMENT:
        test_dataloader = PartialMNIST_AE_Dataloader(config, train=False, test=True, shuffle=True,
                                                     no_val_split=True)
        test_dataloader = test_dataloader[0]
    if ".amat" in config.customdata_train_path and "MNIST" in EXPERIMENT:
        test_dataloader = RotMNIST_AE_Dataloader(config, train=False, test=True, shuffle=True,
                                                 no_val_split=True)
        test_dataloader = test_dataloader[0]

    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for x, label in test_dataloader:
            x = x.cuda()
            label = label.long().cuda()

            outputs = model(x)

            # Get predicted class
            _, predicted = outputs.max(1)

            # Update
            total += label.size(0)
            correct += (predicted == label).sum().item()

    test_accuracy = 100 * correct / total
    print(f"Test Accuracy: {test_accuracy:.2f}%")
    return test_accuracy

In [4]:
import pandas as pd
resnet_results = {}

for EXPERIMENT in ["ROTMNIST60","ROTMNIST60-90", "MNISTMULTIPLE", "MNISTMULTIPLE_GAUSSIAN", "MNISTC2C4", "ROTMNIST"]:
    print(EXPERIMENT)
    acc_sym_std = train(config, sym_std=True)
    acc_no_sym_std = train(config, sym_std=False)
    resnet_results[EXPERIMENT] = (acc_no_sym_std, acc_sym_std)
  
df_results = pd.DataFrame.from_dict(resnet_results, orient="index", columns=["No Sym. Std.", "Symmetry Standardization"])
print("Supervised baseline comparison - Symmetry Standardization")
print(df_results)
try:
    df_results.to_csv(f"plots/resnet_results.csv")
except:
    try:
        home_directory = os.path.expanduser('~')
        file_path = os.path.join(home_directory, "Projects/alonso_syms/resnet_results.csv")
        df_results.to_csv(file_path)
        print(f"File saved to {file_path}")
    except:
        pass

ROTMNIST60
Symmetry Standardization
Loading for train: True , and for test: False
Loading ResNet model
Epoch 10 validation loss: 0.1446353793144226
Epoch 20 validation loss: 0.12452930361032485
Epoch 30 validation loss: 0.11473075971007347
Epoch 40 validation loss: 0.09664391130208969
Epoch 50 validation loss: 0.09833477810025215
Epoch 60 validation loss: 0.10033205188810826
Epoch 70 validation loss: 0.10069638285785913
Epoch 80 validation loss: 0.10334982443600893
Epoch 90 validation loss: 0.10432642884552479
Epoch 100 validation loss: 0.10550979524850845
Loading for train: False , and for test: True
Test Accuracy: 97.40%
NO Symmetry Standardization
Loading for train: True , and for test: False
Loading ResNet model
Epoch 10 validation loss: 0.18458586186170578
Epoch 20 validation loss: 0.16842306032776833
Epoch 30 validation loss: 0.19789064154028893
Epoch 40 validation loss: 0.15927521586418153
Epoch 50 validation loss: 0.12216870784759522
Epoch 60 validation loss: 0.1091009950265288

In [5]:
from collections import OrderedDict
from modules_sym import PartEqMod
import pytorch_lightning as pl
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
def train(config):
    if "MNIST" in EXPERIMENT:
        config.in_channels = 1
        config.emb_dim_theta = 128
        config.hidden_dim_theta = 64
        if EXPERIMENT == "ROTMNIST60":
            config.customdata_train_path = "../datasets/mnist60_train.pkl"
            config.customdata_test_path = "../datasets/mnist60_test.pkl"
            config.pretrained_path = "../models/ie-ae/mnist60/best_model_theta.pt"
            # Net
            config.hidden_dim = 64
            config.emb_dim = 200
            true_thetas_dict = {0: 60., 1: 60., 2: 60., 3: 60., 4: 60.,
                                5: 60., 6: 60., 7: 60., 8: 60., 9: 60.}
        if EXPERIMENT == "ROTMNIST60-90":
            config.customdata_train_path = "../datasets/mnist60_90_train.pkl"
            config.customdata_test_path = "../datasets/mnist60_90_test.pkl"
            config.pretrained_path = "../models/ie-ae/mnist6090/best_model_theta.pt"
            # Net
            config.hidden_dim = 64
            config.emb_dim = 200
            true_thetas_dict = {0: 60., 1: 60., 2: 60., 3: 60., 4: 60.,
                                5: 90., 6: 90., 7: 90., 8: 90., 9: 90.}
        if EXPERIMENT == "MNISTMULTIPLE":
            config.customdata_train_path = "../datasets/mnist_multiple_train.pkl"
            config.customdata_test_path = "../datasets/mnist_multiple_test.pkl"
            config.pretrained_path = "../models/ie-ae/mnistmultiple/best_model_theta.pt"
            # Net
            config.hidden_dim = 64
            config.emb_dim = 200
            true_thetas_dict = {0: 0, 1: 18, 2: 36, 3: 54, 4: 72,
                                5: 90, 6: 108, 7: 126, 8: 144, 9: 162}
        if EXPERIMENT == "MNISTMULTIPLE_GAUSSIAN":
            config.customdata_train_path = "../datasets/mnist_multiple_gaussian_train.pkl"
            config.customdata_test_path = "../datasets/mnist_multiple_gaussian_test.pkl"
            config.pretrained_path = "../models/ie-ae/mnistgaussian/best_model_theta.pt"
            # Net
            config.hidden_dim = 64
            config.emb_dim = 200
            std_dev_dict = {0: 0, 1: 9, 2: 18, 3: 27, 4: 36,
                            5: 45, 6: 54, 7: 63, 8: 72, 9: 81}
            true_thetas_dict = std_dev_dict
        if EXPERIMENT == "ROTMNIST":
            config.customdata_train_path = "../datasets/mnist_all_rotation_normalized_float_train_valid.amat"
            config.customdata_test_path = "../datasets/mnist_all_rotation_normalized_float_test.amat"
            config.pretrained_path = "../models/ie-ae/mnistrot/best_model_theta.pt"
            # Net
            config.hidden_dim = 64
            config.emb_dim = 200
            config.hidden_dim_theta = 32
            true_thetas_dict = {0: 180., 1: 180., 2: 180., 3: 180., 4: 180.,
                                5: 180., 6: 180., 7: 180., 8: 180., 9: 180.}
        if EXPERIMENT == "MNISTC2C4":
            config.customdata_train_path = "../datasets/mnist_c2c4_train.pkl"
            config.customdata_test_path = "../datasets/mnist_c2c4_test.pkl"
            config.pretrained_path = "../models/ie-ae/mnistc2c4/best_model_theta.pt"
            # Net
            config.hidden_dim = 164
            config.emb_dim = 200
            config.hidden_dim_theta = 32
            config.emb_dim_theta = 100
            true_thetas_dict = {0: 1, 1: 1, 2: 1, 3: 1, 4: 1,  # class 0 is C1, class 1 is C2...
                                5: 3, 6: 3, 7: 3, 8: 3, 9: 3}
        
            
    # Train data loading
    if ".pkl" in config.customdata_train_path and "MNIST" in EXPERIMENT:
        main_dataloader = PartialMNIST_AE_Dataloader(config, train=True, test=False, shuffle=True, no_val_split=True)
        train_dataloader = main_dataloader[0]
        num_classes = 10
    if ".amat" in config.customdata_train_path and "MNIST" in EXPERIMENT:
        main_dataloader = RotMNIST_AE_Dataloader(config, train=True, test=False, shuffle=True, no_val_split=True)
        train_dataloader = main_dataloader[0]
        num_classes = 10
    
    # Load SSL-SYM model
    try:
        net = PartEqMod(hparams=config)
        state_dict = torch.load(config.pretrained_path)
        
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k.replace("model.", "")  # remove "model."
            new_state_dict[name] = v
        keys_to_load = {k: v for k, v in new_state_dict.items() if "theta_function" not in k}
        
        print(f"Loading pre-trained model for {EXPERIMENT}.")
        
        net.load_state_dict(keys_to_load, strict=False)
        net.cuda()
        net.eval()
    except:
        print("Error loading state dict")
        print(EXPERIMENT)
        #print(list(map(lambda x: x.shape, keys_to_load.values())))
        return -1

    features_list = []
    labels_list = []

    with torch.no_grad():
        for x, label in train_dataloader:
            x = x.cuda()
            label = label.cuda().long()
            features, _ = net.encoder(x)
            features = features.squeeze()
            features_list.append(features.detach().cpu().numpy())
            labels_list.append(label.cpu().numpy())

    # Convert lists to arrays
    features_array = np.concatenate(features_list, axis=0)
    labels_array = np.concatenate(labels_list, axis=0)

    # KNN classifier
    knn_classifier = KNeighborsClassifier(n_neighbors=5)
    knn_classifier.fit(features_array, labels_array)

    # Evaluation
            
    # Test data loading
    if ".pkl" in config.customdata_train_path and "MNIST" in EXPERIMENT:
        test_dataloader = PartialMNIST_AE_Dataloader(config, train=False, test=True, shuffle=True,
                                                     no_val_split=True)
        test_dataloader = test_dataloader[0]
    if ".amat" in config.customdata_train_path and "MNIST" in EXPERIMENT:
        test_dataloader = RotMNIST_AE_Dataloader(config, train=False, test=True, shuffle=True,
                                                 no_val_split=True)
        test_dataloader = test_dataloader[0]

    features_list = []
    labels_list = []

    with torch.no_grad():
        for x, label in test_dataloader:
            x = x.cuda()
            label = label.cuda().long()
            features, _ = net.encoder(x)
            features = features.squeeze()
            features_list.append(features.detach().cpu().numpy())
            labels_list.append(label.cpu().numpy())
    # Convert lists to arrays
    features_array = np.concatenate(features_list, axis=0)
    labels_array = np.concatenate(labels_list, axis=0)

    predicted_labels = knn_classifier.predict(features_array)
    
    # Compute accuracy
    accuracy = accuracy_score(labels_array, predicted_labels)
    print(f"Test Accuracy: {accuracy:.4f}")
    return accuracy

    

In [6]:
import pandas as pd
import os
ieae_results = {}
# all: ["ROTMNIST60", "ROTMNIST60-90", "MNISTMULTIPLE", "MNISTMULTIPLE_GAUSSIAN", "MNISTC2C4", "ROTMNIST"]
for EXPERIMENT in ["ROTMNIST60", "ROTMNIST60-90", "MNISTMULTIPLE", "MNISTMULTIPLE_GAUSSIAN", "MNISTC2C4", "ROTMNIST"]:
    print(EXPERIMENT)
    acc = train(config)
    ieae_results[EXPERIMENT] = acc
  
df_results = pd.DataFrame.from_dict(ieae_results, orient="index", columns=["IE-AE + KNN"])
print("IE-AE Invariant embeddings + KNN")
print(df_results)
try:
    df_results.to_csv(f"plots/ieae_knn_results.csv")
except:
    try:
        home_directory = os.path.expanduser('~')
        file_path = os.path.join(home_directory, "Projects/alonso_syms/ieae_knn_results.csv")
        df_results.to_csv(file_path)
        print(f"File saved to {file_path}")
    except:
        pass

ROTMNIST60
Loading for train: True , and for test: False


  full_mask[mask] = norms.to(torch.uint8)
  full_mask[mask] = norms.to(torch.uint8)


Loading pre-trained model for ROTMNIST60.
Loading for train: False , and for test: True
Test Accuracy: 0.9559
ROTMNIST60-90
Loading for train: True , and for test: False
Loading pre-trained model for ROTMNIST60-90.
Loading for train: False , and for test: True
Test Accuracy: 0.9532
MNISTMULTIPLE
Loading for train: True , and for test: False
Loading pre-trained model for MNISTMULTIPLE.
Loading for train: False , and for test: True
Test Accuracy: 0.9580
MNISTMULTIPLE_GAUSSIAN
Loading for train: True , and for test: False
Loading pre-trained model for MNISTMULTIPLE_GAUSSIAN.
Loading for train: False , and for test: True
Test Accuracy: 0.9565
MNISTC2C4
Loading for train: True , and for test: False
Error loading state dict
MNISTC2C4
ROTMNIST
Loading for train: True , and for test: False
Loading pre-trained model for ROTMNIST.
Loading for train: False , and for test: True
Test Accuracy: 0.9525
IE-AE Invariant embeddings + KNN
                        IE-AE + KNN
ROTMNIST60                  0.