In [6]:
import os
import re
from datetime import datetime
import tarfile
import zipfile

from torchvision import datasets
from torchvision.models import list_models
import torchvision.models as models
import torch

def log_message(message, level="info"):
    """
    Log a message, with support for info, warning, and error levels.
    If the level is "error", this function raises a ValueError with the message.

    Parameters:
        message (str): The message to log.
        level (str): The level of the message ("info", "warning", or "error").
    """
    if level == "info":
        print(f"[INFO] {message}")
    elif level == "warning":
        print(f"[WARNING] {message}")
    elif level == "error":
        print(f"[ERROR] {message}")
        raise ValueError(message)  # Raise an error if level is "error"

def sanitize_string(s):
    """
    Convert a string to lowercase, strip special characters, and replace spaces with underscores.

    Parameters:
        s (str): The input string to sanitize.

    Returns:
        str: The sanitized string.
    """
    s = s.lower()
    s = re.sub(r'[^a-z0-9_]', '', s)
    return s.replace(' ', '_')

def ensure_directory_exists(dir_path):
    """
    Ensure the specified directory exists; create it if it does not, including required subdirectories.
    Set permissions to 777 for the main directory and its contents recursively.

    Parameters:
        dir_path (str): The path of the directory to ensure exists.

    Returns:
        str: The absolute path of the directory.
    """
    os.makedirs(dir_path, exist_ok=True)
    
    # List of required subdirectories
    subdirs = ['dataset', 'src', 'checkpoints', 'output']
    
    # Create each subdirectory if it doesn’t exist
    for subdir in subdirs:
        subdir_path = os.path.join(dir_path, subdir)
        os.makedirs(subdir_path, exist_ok=True)
    
    # Set permissions recursively for main directory and all subdirectories/files
    for root, dirs, files in os.walk(dir_path):
        os.chmod(root, 0o777)
        for d in dirs:
            os.chmod(os.path.join(root, d), 0o777)
        for f in files:
            os.chmod(os.path.join(root, f), 0o777)

    return os.path.abspath(dir_path)

def check_file_exists(file_path):
    """
    Check if a file exists at the specified path.

    Parameters:
        file_path (str): The path of the file to check.

    Returns:
        bool: True if the file exists, False otherwise.
    """
    return os.path.isfile(file_path)

def is_archive_file(file_path):
    """
    Check if a file is a valid archive (zip or tar).

    Parameters:
        file_path (str): The path of the file to check.

    Returns:
        bool: True if the file is a valid archive, False otherwise.
    """
    return zipfile.is_zipfile(file_path) or tarfile.is_tarfile(file_path)

In [7]:
def setup_project(prj_name, model_type, project_folder=None, model_py_file=None, model_pth_file=None, torch_vision_model=None, dataset_type=None, custom_dataset=None, torch_vision_dataset=None):
    """
    Set up a project with a specified structure, including creating necessary directories, 
    validating model type, and checking for necessary files.

    Parameters:
        prj_name (str): The name of the project.
        model_type (str): The type of model ('untrained', 'custom_pretrained', or 'torch_vision_pretrained').
        project_folder (str, optional): The main folder for the project. A new folder is generated if not provided.
        model_py_file (str, optional): The filename of the Python script defining the model architecture, required for 'untrained' and 'custom_pretrained' models.
        model_pth_file (str, optional): The filename of the .pth file with pre-trained weights, required for 'custom_pretrained' models.
        torch_vision_model (str, optional): The name of the TorchVision model to load if 'torch_vision_pretrained' is selected.
        dataset_type (str, optional): Type of dataset for 'untrained' models ('torch_vision_dataset' or 'custom_dataset').
        custom_dataset (str, optional): Path to the custom dataset file for 'untrained' models with 'custom_dataset' dataset type.
        torch_vision_dataset (str, optional): Name of the TorchVision dataset class for 'untrained' models with 'torch_vision_dataset' dataset type.

    Returns:
        dict: A dictionary with project setup information.
    """
    log_message("Setting up project")
    working_folder = "/home/fastqnn/finn/notebooks/Fast-QNN/outputs/txaviour/"
    prj_info = {}
    
    # Ensure Project Name is provided
    if not prj_name:
        log_message("Project Name is required", level="error")

    # Sanitize project name and set project info
    prj_name_stripped = sanitize_string(prj_name)
    display_name = prj_name
    prj_info["Display_Name"] = display_name
    prj_info["Stripped_Name"] = prj_name_stripped

    if not project_folder:
        # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # for production mode
        timestamp = "0"
        project_folder = f"{working_folder}{prj_name_stripped}_{timestamp}"
    
    folder_path = ensure_directory_exists(project_folder)
    prj_info["Folder"] = folder_path

    # Validate Model Type
    valid_model_types = ["untrained", "custom_pretrained", "torch_vision_pretrained"]
    if model_type not in valid_model_types:
        log_message(f"Model Type must be one of {valid_model_types}", level="error")
    prj_info['Model_Type'] = model_type
    
    # Handle Model Type specific requirements
    if model_type in ["untrained", "custom_pretrained"]:
        if not model_py_file or not check_file_exists(os.path.join(project_folder, 'src', model_py_file)):
            log_message(f"Model Py File '{model_py_file}' does not exist in '{project_folder}'", level="error")
        prj_info["Model_Py_File"] = model_py_file

    if model_type == "custom_pretrained":
        if not model_pth_file or not check_file_exists(os.path.join(project_folder, 'src', model_pth_file)):
            log_message(f"Model Pth File '{model_pth_file}' does not exist in '{project_folder}'", level="error")
        prj_info["Model_Pth_File"] = model_pth_file

    if model_type == "torch_vision_pretrained":
        available_models = list_models(module=models)
        if not torch_vision_model or torch_vision_model not in available_models:
            log_message(f"Torch Vision Model must be one of: {available_models}", level="error")
        prj_info["Torch_Vision_Model"] = torch_vision_model

    # Handle Dataset requirements for Untrained models
    if model_type == "untrained":
        if dataset_type not in ["torch_vision_dataset", "custom_dataset"]:
            log_message("Dataset Type must be either 'Torch Vision' or 'Custom Dataset' for Untrained models", level="error")
        prj_info["Dataset_Type"] = dataset_type

        if dataset_type == "custom_dataset":
            custom_dataset_path = os.path.join(project_folder, 'dataset', custom_dataset)
            if not custom_dataset or not check_file_exists(custom_dataset_path) or not is_archive_file(custom_dataset_path):
                log_message(f"Custom Dataset '{custom_dataset}' must exist in '{project_folder}' and be an archive file (zip or tar)", level="error")
            prj_info["Custom_Dataset"] = custom_dataset

        elif dataset_type == "torch_vision_dataset":
            available_datasets = [cls_name.lower() for cls_name in dir(datasets) if not cls_name.startswith('_')]
            if not torch_vision_dataset or torch_vision_dataset.lower() not in available_datasets:
                log_message(f"Torch Vision Dataset must be one of: {available_datasets}", level="error")
            prj_info["Torch_Vision_Dataset"] = torch_vision_dataset
            
    log_message(f"Project setup complete. {prj_name} has been initialized.")
    return prj_info

In [8]:
def load_torch_vision_model(model_name, src_folder, initial_channels = 3, max_size = 4096):
    """
    Loads a pre-trained model from TorchVision and ensures all downloads are in the src folder of the Project.

    Parameters:
        initial_channels: 
        max_size: 
        model_name (str): The name of the pre-trained model to load (e.g., 'alexnet', 'resnet50').
        src_folder (str): The folder where model downloads will be stored. Default is 'src'.
                
    Returns:
        torch.nn.Module: The loaded pre-trained model.
    """
    log_message(f"Loading Torch Vision Model: {model_name}")
    # Ensure the src folder exists
    os.makedirs(src_folder, exist_ok=True)
    # Set TORCH_HOME to the src folder to store the model downloads there
    os.environ['TORCH_HOME'] = src_folder
    # Load the model from torch.hub with the specified model name
    torch_model = torch.hub.load('pytorch/vision:v0.10.0', model_name, pretrained=True)
    # List of common input shapes to test first (both square and non-square)
    common_shapes = [
        (1, initial_channels, 32, 32),  # Typical for many models
        (1, initial_channels, 28, 28),  # Typical for many models
        (1, initial_channels, 224, 224),  # Typical for many models
        (1, initial_channels, 299, 299),  # For models like Inception
        (1, initial_channels, 128, 128),  # Smaller size
        (1, initial_channels, 256, 256),  # Larger square size
        (1, initial_channels, 320, 240),  # Common non-square size
        (1, initial_channels, 240, 320),  # Non-square (swapped dimensions)
        (1, initial_channels, 256, 128),  # Non-square
        (1, initial_channels, 128, 256),  # Non-square (swapped)
    ]

    # Try common shapes first
    for shape in common_shapes:
        try:
            dummy_input = torch.rand(*shape)
            torch_model(dummy_input)
            log_message(f"Compatible common input shape found: {shape}")
            return torch_model, shape
        except RuntimeError:
            continue

    # If no common shape worked, test all possible square and non-square shapes up to max_size
    for width in range(1, max_size + 1, 1):  # Step by 16 for efficiency
        for height in range(1, max_size + 1, 1):
            try:
                dummy_input = torch.rand(1, initial_channels, width, height)
                torch_model(dummy_input)
                torch_model_input_shape = (1, initial_channels, width, height)
                log_message(f"Compatible input shape found: {torch_model_input_shape}")
                return torch_model, torch_model_input_shape
            except RuntimeError:
                continue

    log_message("Could not determine a compatible input shape within the specified range.", level="error")


In [9]:
prj_name_input = "AlexNet 1W1A"
prj_folder_input = sanitize_string(prj_name_input)
model_type_input = "torch_vision_pretrained"
torch_vision_model_input = "alexnet"
Project_Info = setup_project(prj_name=prj_name_input, model_type=model_type_input, torch_vision_model=torch_vision_model_input)

input_model= None
torch_vision_shape = None

if Project_Info['Model_Type'] == "untrained":
    log_message("Training for untrained models are not supported at the moment!", level="error")
elif Project_Info['Model_Type'] == "custom_pretrained":
    log_message("Custom Pretrained models are not supported at the moment!", level="error")
elif Project_Info['Model_Type'] == "torch_vision_pretrained":
    torch_vision_folder = os.path.join(Project_Info['Folder'],"src")
    input_model, torch_vision_shape = load_torch_vision_model(Project_Info['Torch_Vision_Model'], torch_vision_folder)
else:
    log_message("Unsupported Model Type", level="error")

[INFO] Setting up project
[INFO] Project setup complete. AlexNet 1W1A has been initialized.
[INFO] Loading Torch Vision Model: alexnet


Using cache found in /home/fastqnn/finn/notebooks/Fast-QNN/outputs/txaviour/alexnet1w1a_0/src/hub/pytorch_vision_v0.10.0


[INFO] Compatible common input shape found: (1, 3, 224, 224)


In [10]:
from brevitas.export import export_qonnx
from qonnx.util.cleanup import cleanup as qonnx_cleanup

export_onnx_path = os.path.join(Project_Info['Folder'],"checkpoints",Project_Info['Stripped_Name']+"_export.onnx")
export_qonnx(input_model, torch.randn(torch_vision_shape), export_onnx_path)
qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)