In [1]:
# Thread limits must be in the environment BEFORE torch is imported: the
# OpenMP / BLAS runtimes that torch loads read these variables when they are
# initialised, so setting them after `import torch` has no effect.
import os
os.environ["OMP_NUM_THREADS"] = "4"  # export OMP_NUM_THREADS=4
os.environ["OPENBLAS_NUM_THREADS"] = "4"  # export OPENBLAS_NUM_THREADS=4
os.environ["MKL_NUM_THREADS"] = "6"  # export MKL_NUM_THREADS=6
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"  # export VECLIB_MAXIMUM_THREADS=4
os.environ["NUMEXPR_NUM_THREADS"] = "6"  # export NUMEXPR_NUM_THREADS=6

# imports
import pathlib
from pathlib import Path

import torch

from model_definitions.def_net import NNmodule

# Presumably the .pt files were pickled on Linux and contain PosixPath objects,
# which cannot be instantiated on Windows. Aliasing PosixPath to WindowsPath
# lets torch.load unpickle them; the dataset-loading cell restores the original
# class from `temp_posix_path`. The alias is only applied on Windows: on POSIX
# it is not needed and would make every `Path(...)` resolve to WindowsPath.
temp_posix_path = pathlib.PosixPath
if os.name == "nt":
    pathlib.PosixPath = pathlib.WindowsPath

In [None]:
# Load the pickled model-zoo splits. The PosixPath alias installed in the
# imports cell is undone in `finally`, so a failing load cannot leave pathlib
# patched for the rest of the kernel session.
# NOTE: torch.load unpickles arbitrary objects - only load trusted files.
dspath = Path("dataset_fmnist_hyp_fix.pt")
try:
    ds = torch.load(dspath)
finally:
    pathlib.PosixPath = temp_posix_path

In [47]:
# Show which splits the loaded object provides, then pull the test-split
# weights and report their dimensions.
available_splits = ds.keys()
print(available_splits)

test_split = ds["testset"]
weights_test = test_split.__get_weights__()
print(f"Weights test shape: {weights_test.shape}")

dict_keys(['trainset', 'valset', 'testset'])


100%|██████████| 18374/18374 [00:08<00:00, 2086.65it/s]

Weights test shape: torch.Size([18374, 2464])





In [15]:
def model_autoload(props, weights, verbose=False):
    """Build an ``NNmodule`` whose architecture accepts the given state dict.

    Every candidate in ``model_types`` is tried in order; the first one for
    which construction and ``load_state_dict`` do not raise ``RuntimeError``
    is returned.

    Args:
        props: model/training config dict. It is NOT modified: a shallow copy
            is filled in (``model::channels_in``, ``model::o_dim``,
            ``model::type`` and fixed optimizer/scheduler entries) and returned.
        weights: ordered state dict of the bare network (keys without the
            ``model.`` prefix, which is added before loading).
        verbose: if True, print the props and each architecture's load error
            when no candidate matches.

    Returns:
        ``(props_copy, model)`` on success, ``(props_copy, None)`` otherwise
        (including when ``weights`` is empty).
    """
    model_types = ["CNN", "CNN2", "CNN3", "Resnet18", "MLP"]
    # Work on a copy so the caller's config (e.g. the dataset's record) is not mutated.
    props = dict(props)
    if not weights:
        if verbose:
            print(f"Could not load model (empty state dict). Props: {props}")
        return props, None

    # Infer input/output sizes from the first and last tensors of the state
    # dict. Assumes the first tensor is the input layer's weight (out, in, ...)
    # and the last one is sized by the output dim - TODO confirm for every
    # architecture in model_types.
    props["model::channels_in"] = weights[next(iter(weights))].shape[1]
    props["model::o_dim"] = weights[next(reversed(weights))].shape[0]
    # Fixed values, presumably only needed so NNmodule can build its
    # optimizer/scheduler; they are overwritten regardless of the input props.
    props["optim::momentum"] = 0.99
    props["scheduler::steps_per_epoch"] = 1000

    # The prefixed state dict is identical for every candidate, so build it once.
    weights2load = {f"model.{key}": value for key, value in weights.items()}
    load_errors = {}
    for model_type in model_types:
        props["model::type"] = model_type
        try:
            model = NNmodule(config=props)
            model.load_state_dict(weights2load)
        except RuntimeError as err:
            # Typically an architecture mismatch (missing/unexpected keys or
            # size mismatch reported by load_state_dict).
            load_errors[model_type] = err
            continue
        return props, model

    if verbose:
        print(f"Could not load model. Props: {props}")
        for model_type, err in load_errors.items():
            print(f"  {model_type}: {err}")
    return props, None

In [16]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import json
import glob
import os
from PIL import Image

In [17]:
class TinyImageNetDataset(Dataset):
    """TinyImageNet split that maps WordNet-ID folder names to sequential labels.

    A class's label is the 0-based line index of its WNID in ``wnids.txt``.

    Supported layouts, where ``root_dir`` is the split directory:

    * train: ``<root_dir>/<wnid>/images/*.JPEG`` -> ``(path, label)`` samples.
    * test: ``root_dir`` named ``test`` (or ``test/images``), images in
      ``<root_dir>/images/*.JPEG`` or directly in ``<root_dir>/*.JPEG``
      -> ``(path, -1)`` samples, since no labels are read for this split.

    ``wnids.txt`` is looked up in the parent of ``root_dir``, falling back to
    the grandparent (for a ``.../test/images`` root). Samples are sorted by
    path so the dataset order is deterministic.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # Map folder names to sequential IDs. Blank lines (e.g. a trailing
        # newline) are ignored so they never become a class.
        with open(self._find_wnids(root_dir), "r") as f:
            self.wnids = [line.strip() for line in f if line.strip()]
        self.class_to_idx = {wnid: idx for idx, wnid in enumerate(self.wnids)}

        # Collect all image paths and their labels.
        self.samples = []
        if self._is_test_split(root_dir):
            image_dir = os.path.join(root_dir, "images")
            if not os.path.isdir(image_dir):
                image_dir = root_dir  # root_dir already points at the image folder
            for img_file in sorted(glob.glob(os.path.join(image_dir, "*.JPEG"))):
                self.samples.append((img_file, -1))  # -1 = no label available
        else:
            for wnid in self.wnids:
                class_dir = os.path.join(root_dir, wnid, "images")
                label = self.class_to_idx[wnid]
                # glob returns [] for a missing class_dir.
                for img_file in sorted(glob.glob(os.path.join(class_dir, "*.JPEG"))):
                    self.samples.append((img_file, label))

    @staticmethod
    def _is_test_split(root_dir):
        """Return True if ``root_dir`` is the ``test`` split dir or its ``images`` subdir.

        Only the trailing path components are inspected, so unrelated parent
        directories such as ``latest/`` or ``tests/`` do not turn a train
        split into test mode.
        """
        parts = os.path.normpath(root_dir).split(os.sep)
        if len(parts) >= 2 and parts[-1] == "images":
            return parts[-2] == "test"
        return parts[-1] == "test"

    @staticmethod
    def _find_wnids(root_dir):
        """Return the path of the ``wnids.txt`` that belongs to ``root_dir``."""
        # normpath drops a trailing separator so dirname yields the real parent.
        parent = os.path.dirname(os.path.normpath(root_dir))
        candidates = [
            os.path.join(parent, "wnids.txt"),
            os.path.join(os.path.dirname(parent), "wnids.txt"),
        ]
        for candidate in candidates:
            if os.path.isfile(candidate):
                return candidate
        # Neither exists: return the primary location so open() reports it.
        return candidates[0]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]

        # The context manager releases the file handle right after decoding;
        # convert() returns an independent image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

In [18]:
def _read_final_test_acc(progress_path):
    """Return ``test_acc`` from the last data row of a progress CSV, or ``None``.

    ``None`` is returned when the file is missing, has no data rows, has no
    ``test_acc`` column, or the last row's ``test_acc`` cell is empty. The csv
    module is used (not ``split(",")``) so quoted fields containing commas do
    not shift the column positions.
    """
    import csv

    if not os.path.exists(progress_path):
        return None
    last_row = None
    with open(progress_path, "r", newline="") as f:
        for row in csv.DictReader(f):
            if any(row.values()):  # skip rows whose cells are all empty
                last_row = row
    if last_row is None:
        return None
    value = last_row.get("test_acc")
    if not value:
        return None
    return float(value)


def load_model_data(data_path, source_type="pt_file", checkpoint_name="checkpoint_000060"):
    """
    Load model data from either a .pt file or a folder of training runs.

    Args:
        data_path: path to the data source (.pt file or folder).
        source_type: "pt_file" (``torch.load`` the file as-is) or "folder".
        checkpoint_name: checkpoint sub-directory read in "folder" mode.

    Folder layout (one sub-directory per trained model)::

        data_path/<sample>/params.json
        data_path/<sample>/progress.csv                 (optional)
        data_path/<sample>/<checkpoint_name>/checkpoints

    Samples without ``<checkpoint_name>`` are skipped. Each ``props`` is the
    parsed params.json plus ``test_acc`` from the last row of progress.csv
    when that value is available.

    Returns:
        For "pt_file": whatever object the file contains (presumably a dict
        of splits). For "folder": ``{"trainset": [(props, weights), ...],
        "testset": []}`` in sorted folder order, with weights on the CPU.

    Raises:
        ValueError: for an unsupported ``source_type``.
    """
    # NOTE: torch.load unpickles arbitrary objects - only load trusted files.
    if source_type == "pt_file":
        return torch.load(data_path)
    if source_type != "folder":
        raise ValueError(f"Unsupported source_type: {source_type}")

    data = {"trainset": [], "testset": []}
    # Sorted so the model order is identical across runs and filesystems.
    sample_folders = sorted(
        folder for folder in glob.glob(os.path.join(data_path, "*")) if os.path.isdir(folder)
    )

    for sample_folder in sample_folders:
        # Check the checkpoint first so skipped samples cost no extra I/O.
        checkpoint_path = os.path.join(sample_folder, checkpoint_name)
        if not os.path.exists(checkpoint_path):
            continue

        with open(os.path.join(sample_folder, "params.json"), "r") as f:
            props = json.load(f)

        test_acc = _read_final_test_acc(os.path.join(sample_folder, "progress.csv"))
        if test_acc is not None:
            props["test_acc"] = test_acc

        # map_location="cpu": the downstream sanity check runs on CPU, and this
        # also allows loading GPU-saved checkpoints on a machine without CUDA.
        weights = torch.load(os.path.join(checkpoint_path, "checkpoints"), map_location="cpu")
        data["trainset"].append((props, weights))

    return data

In [21]:
def _build_sanity_dataset(dataset, data_root="./data"):
    """Return the labelled dataset used for the sanity check.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    if dataset == "MNIST":
        transform = transforms.Compose([transforms.ToTensor()])
        return torchvision.datasets.MNIST(root=data_root, train=False,
                                          download=True, transform=transform)
    if dataset == "CIFAR10":
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        return torchvision.datasets.CIFAR10(root=data_root, train=False,
                                            download=True, transform=transform)
    if dataset == "FashionMNIST":
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
        return torchvision.datasets.FashionMNIST(root=data_root, train=False,
                                                 download=True, transform=transform)
    if dataset == "TinyImageNet":
        transform = transforms.Compose([
            transforms.Resize((64, 64)),  # TinyImageNet is 64x64
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # The train split is used because it is the one read with labels.
        return TinyImageNetDataset(
            root_dir=f"{data_root}/tiny-imagenet-200/train",
            transform=transform
        )
    raise ValueError(f"Unsupported dataset: {dataset}")


def _collect_sanity_batches(test_dataset, n_batches, batch_size, seed):
    """Materialise a fixed, seeded random sample of labelled batches once.

    Every model is then scored on exactly the same images, so results are
    reproducible and comparable, and images are decoded only once instead of
    once per model.
    """
    n_samples = min(len(test_dataset), n_batches * batch_size)
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(len(test_dataset), generator=generator)[:n_samples].tolist()
    loader = DataLoader(torch.utils.data.Subset(test_dataset, indices),
                        batch_size=batch_size, shuffle=False)
    batches = []
    for data, target in loader:
        # Drop samples with unknown labels (-1, used for the TinyImageNet test split).
        valid_mask = target != -1
        if valid_mask.any():
            batches.append((data[valid_mask], target[valid_mask]))
    return batches


def _sanity_accuracy(model, batches):
    """Return the top-1 accuracy of ``model`` on ``batches`` (0 if there are no samples)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in batches:
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return correct / total if total > 0 else 0


def validate_and_save(ds, subset, threshold=0.8, dataset="FashionMNIST", sanity_threshold=0.5,
                      n_batches=10, batch_size=64, seed=0, data_root="./data"):
    """
    Validate models by recorded accuracy and by a quick re-evaluation.

    Despite the name, nothing is written to disk; the caller saves the result.

    Args:
        ds: dataset dictionary containing model data.
        subset: dataset subset name, e.g. "trainset"; ``ds[subset][i]`` must
            be a ``(props, weights)`` pair.
        threshold: minimum ``props["test_acc"]`` (missing counts as 0).
        dataset: dataset used for the sanity check ("MNIST", "CIFAR10",
            "FashionMNIST", "TinyImageNet").
        sanity_threshold: minimum accuracy on the sanity batches.
        n_batches: number of sanity batches evaluated per model.
        batch_size: samples per sanity batch.
        seed: seed selecting the fixed sanity sample.
        data_root: root folder of the sanity datasets.

    Returns:
        List of ``(props_loaded, model)`` pairs that passed both checks.

    Raises:
        ValueError: for an unsupported ``dataset``.
    """
    test_dataset = _build_sanity_dataset(dataset, data_root)
    sanity_batches = _collect_sanity_batches(test_dataset, n_batches, batch_size, seed)

    acc_key = "test_acc"
    models = ds[subset]
    valid_models = []
    pbar = tqdm(range(len(models)), total=len(models), desc="Validating models")
    for model_n in pbar:
        props, weights = models[model_n]

        # Check accuracy threshold (use .get so a missing key never raises below).
        test_acc = props.get(acc_key, 0)
        if test_acc < threshold:
            continue

        props_loaded, model = model_autoload(props, weights)
        if model is None:
            pbar.set_description(f"Model {model_n}: acc={test_acc:.3f} - FAILED to load")
            continue

        sanity_accuracy = _sanity_accuracy(model, sanity_batches)
        if sanity_accuracy >= sanity_threshold:
            valid_models.append((props_loaded, model))
            pbar.set_description(f"Model {model_n}: acc={test_acc:.3f}, sanity_acc={sanity_accuracy:.3f} - PASSED")
        else:
            pbar.set_description(f"Model {model_n}: acc={test_acc:.3f}, sanity_acc={sanity_accuracy:.3f} - FAILED sanity check")

    pbar.close()
    print(f"Valid models: {len(valid_models)}/{len(models)}")
    return valid_models

In [None]:
# Validate each model zoo and merge the models that pass into one training file.
# Results are cached per zoo. With recompute=False an existing cache file is
# reused and only built when missing, so the cell also works on a fresh run
# instead of failing on a torch.load of a file that was never produced.
torch.cuda.empty_cache()

ACC_THRESHOLD = 0.5     # minimum recorded test accuracy
SANITY_THRESHOLD = 0.5  # minimum accuracy in the quick re-evaluation


def _load_or_validate(cache_path, data_path, dataset, source_type="pt_file", recompute=False):
    """Return the validated models of one zoo, reading or writing ``cache_path``.

    If ``recompute`` is False and ``cache_path`` exists, it is loaded as-is.
    Otherwise the zoo at ``data_path`` is validated and the result is saved
    to ``cache_path``.
    """
    if not recompute and os.path.exists(cache_path):
        return torch.load(cache_path)
    model_zoo = load_model_data(data_path, source_type=source_type)
    valid = validate_and_save(
        model_zoo,
        "trainset",
        threshold=ACC_THRESHOLD,
        dataset=dataset,
        sanity_threshold=SANITY_THRESHOLD,
    )
    torch.save(valid, cache_path)
    return valid


# MNIST / FashionMNIST are always re-validated; the expensive ResNet and CIFAR
# zoos are read from their caches when present.
valid_mnist = _load_or_validate(
    "valid_mnist_models.pt", "./data/dataset_mnist_hyp_rand.pt", "MNIST", recompute=True
)
valid_fmnist = _load_or_validate(
    "valid_fmnist_models.pt", "./data/dataset_fmnist_hyp_rand.pt", "FashionMNIST", recompute=True
)
valid_resnet = _load_or_validate(
    "valid_tinyimagenet_resnet_models.pt",
    "./data/tiny-imagenet_resnet18_kaiming_uniform_subset",
    "TinyImageNet",
    source_type="folder",
)
valid_cifar = _load_or_validate(
    "valid_cifar_models.pt", "./data/dataset_cifar_large_hyp_rand.pt", "CIFAR10"
)

all_valid = valid_resnet + valid_cifar + valid_mnist + valid_fmnist
print(f"Valid ResNet models: {len(valid_resnet)}")
print(f"Valid CIFAR models: {len(valid_cifar)}")
print(f"Valid MNIST models: {len(valid_mnist)}")
print(f"Valid FashionMNIST models: {len(valid_fmnist)}")
print(f"Total valid models: {len(all_valid)}")
torch.save(all_valid, "valid_train_models.pt")

Model 82039: acc=0.828, sanity_acc=0.662 - PASSED: 100%|██████████| 82040/82040 [2:04:07<00:00, 11.02it/s]               


Valid models: 23839/82040


Model 81367: acc=0.779 - FAILED to load: 100%|██████████| 81368/81368 [2:41:31<00:00,  8.40it/s]                         


Valid models: 34745/81368
Valid ResNet models: 116
Valid CIFAR models: 12658
Valid MNIST models: 23839
Valid FashionMNIST models: 34745
Total valid models: 71358
