In [1]:
import torch
import numpy as np

from models import MLP
from helpers import model_to_list

from tqdm.notebook import tqdm
import os
import csv
import yaml

In [2]:
# Load the project configuration from config.yaml (path is relative to the
# current working directory of the kernel).
# NOTE(review): yaml.FullLoader can construct more than plain data types; if
# config.yaml only contains plain mappings/strings, yaml.safe_load would be the
# more conservative choice -- confirm before changing.
with open('config.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Directory the generated CSV files are written to (see weights_to_csv).
datasets_path = config["directories"]["datasets_path"]
# Root directory containing one sub-directory of saved models per dataset
# (see weights_to_csv).
models_path = config["directories"]["models_path"]

In [3]:
def weights_to_csv(dataset_name: str) -> bool:
    '''
    Save all model weights from a dataset into single csv file.

    Every regular file inside `<models_path>/<dataset_name>` is loaded as an
    MLP state dict, flattened with `model_to_list` and written as one row of
    `<datasets_path>/<dataset_name>.csv`. Each row holds the weights followed
    by the angle, which is parsed from the file name: the second
    underscore-separated token of the name must be an integer.

    Files are processed in sorted name order so the csv is reproducible, and
    sub-directories (e.g. `.ipynb_checkpoints`) are ignored.

    Parameters:
        dataset_name (str): name of the dataset.

    Returns:
        bool: True if the csv file was created, False otherwise.

    Raises:
        IndexError, ValueError: if a file name does not contain an integer
            angle as its second underscore-separated token.
    '''

    # Header width used only when the directory contains no model files, so
    # that an empty dataset still yields the historical 151-weight header.
    default_weight_count = 151

    path = os.path.join(models_path, dataset_name)
    if not os.path.isdir(path):
        print("No models for this dataset found!")
        return False

    # os.listdir returns entries in arbitrary order; sort for a deterministic
    # csv and drop anything that is not a regular file.
    models = sorted(
        name for name in os.listdir(path)
        if os.path.isfile(os.path.join(path, name))
    )

    csv_path = os.path.join(datasets_path, f"{dataset_name}.csv")
    # newline="" lets the csv writer control line endings itself; without it
    # the '\n' terminator would be translated to '\r\n' on Windows.
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f, lineterminator = '\n')
        header_written = False

        for model in tqdm(models):
            angle = int(model.split("_")[1])
            m = MLP()
            # map_location="cpu" allows checkpoints saved on a GPU to be
            # loaded on machines without one; the freshly built MLP lives on
            # the CPU anyway.
            state_dict = torch.load(os.path.join(path, model), map_location="cpu")
            m.load_state_dict(state_dict)
            # assumes model_to_list returns a flat 1-D array/tensor -- TODO confirm
            row = model_to_list(m).tolist()

            if not header_written:
                # Size the header from the actual data instead of a hard-coded
                # count so it always matches the rows.
                header = [f"weight_{i}" for i in range(len(row))]
                header.append("angle")
                writer.writerow(header)
                header_written = True

            row.append(angle)
            writer.writerow(row)

        if not header_written:
            # No model files at all: still emit a header-only csv.
            header = [f"weight_{i}" for i in range(default_weight_count)]
            header.append("angle")
            writer.writerow(header)

    return True

In [4]:
# Export the weights of every model saved for the "four_angles" dataset to
# four_angles.csv inside datasets_path; the returned bool is shown as the
# cell output.
weights_to_csv("four_angles")

  0%|          | 0/8000 [00:00<?, ?it/s]

True