In [1]:
import os
import torch
from PIL import Image
import torchvision.transforms as transforms
import sys
import matplotlib.pyplot as plt
import torch
from pathlib import Path
import numpy as np

# Class names; the integer class id parsed from a tensor file name is used as
# an index into this list (the ten names match the CIFAR-10 label set).
classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Same folder that is used further down as the input directory of .pt files.
env_path = './Act_Max_Img_MLP_2'

# Register the folder on the import path, but only once, so re-running the
# cell does not keep appending duplicates.
if sys.path.count(env_path) == 0:
    sys.path.append(env_path)

def norm_01(array):
    """Linearly rescale ``array`` so that its values span [0, 1].

    The global minimum (over all elements) is mapped to 0 and the global
    maximum to 1.

    Args:
        array: Non-empty NumPy array (or array-like) of numbers.

    Returns:
        Array of the same shape as ``array`` holding
        ``(array - min) / (max - min)``.  If every element is equal, an
        all-zero array is returned instead of the NaNs that the plain
        formula (0 / 0) would produce.

    Raises:
        ValueError: If ``array`` is empty (raised by ``np.min``).
    """
    lowest = np.min(array)
    span = np.max(array) - lowest
    # A constant input has zero span; dividing by 1 instead keeps the
    # already-zero numerator finite and preserves the usual result dtype.
    scale = span if span != 0 else 1
    return (array - lowest) / scale

def save_im(batch, filepath, ch, ht, wd):
    """Save a tensor as an image file.

    The tensor is reshaped to ``(ch, ht, wd)``, reordered to
    height-width-channel layout, min-max normalised with ``norm_01`` and
    written with ``plt.imsave``.

    Args:
        batch: Torch tensor with exactly ``ch * ht * wd`` elements.
        filepath: Destination path handed to ``plt.imsave``.
        ch: Number of channels.
        ht: Image height in pixels.
        wd: Image width in pixels.

    Raises:
        RuntimeError: If ``batch`` cannot be reshaped to ``(ch, ht, wd)``.
    """
    # detach() drops autograd tracking; cpu() is a no-op for tensors already
    # in host memory but is required before .numpy() for any other device.
    chw = batch.detach().cpu().reshape(ch, ht, wd)
    hwc = chw.permute(1, 2, 0).numpy()
    plt.imsave(filepath, norm_01(hwc))

def convert_pt_to_image(input_path, output_path, image_shape=(3, 64, 64)):
    """Convert every ``.pt`` tensor file in ``input_path`` into a PNG image.

    File names are expected to look like ``class_<index>_....pt``: the
    integer between the first and second underscore selects an entry of the
    module-level ``classes`` list, and the image is written to
    ``<output_path>/<class name>/<file stem>.png`` via ``save_im``.

    Conversion is best-effort: a file that fails for any reason (unparsable
    name, class index out of range, load or save error) is reported on
    stdout and skipped.  The names of all failed files are written, one per
    line, to ``<output_path>/err_log.txt``; that log is rewritten on every
    call and is empty when nothing failed.

    Args:
        input_path: Directory that is scanned (non-recursively) for ``.pt``
            files.
        output_path: Directory that receives the per-class sub-folders and
            the error log.  Created if it does not exist.
        image_shape: ``(channels, height, width)`` passed on to ``save_im``.
            Defaults to ``(3, 64, 64)``.

    Raises:
        OSError: If ``input_path`` cannot be listed or ``output_path`` cannot
            be created or written to.
    """
    os.makedirs(output_path, exist_ok=True)

    pt_files = [f for f in os.listdir(input_path) if f.endswith(".pt")]

    # Names of files that could not be converted.
    err_files = []

    for pt_file in pt_files:
        print(f"Processing file {pt_file}")
        pt_filepath = os.path.join(input_path, pt_file)

        try:
            # Validate the file name before doing the (more expensive) load.
            class_index = int(pt_file.split("_")[1])
            if not 0 <= class_index < len(classes):
                # Previously such files were skipped without any trace.
                raise ValueError(f"class index {class_index} is out of range")

            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code -- only point this at trusted directories.
            pt_tensor = torch.load(pt_filepath, map_location=torch.device('cpu'))

            # One sub-folder per class name.
            image_filefolder = os.path.join(output_path, classes[class_index])
            os.makedirs(image_filefolder, exist_ok=True)

            image_filename = os.path.splitext(pt_file)[0] + ".png"
            image_filepath = os.path.join(image_filefolder, image_filename)
            save_im(pt_tensor, image_filepath, *image_shape)

            print(f"Successfully saved file {pt_file} as an image")

        except Exception as e:
            # Deliberately broad: one bad file must not abort the whole run.
            err_files.append(pt_file)
            print(f"An error occurred in file {pt_file}: {e}")

    # Keep the log next to the images instead of in a hard-coded folder that
    # may not exist when a different output_path is used.
    err_file_path = os.path.join(output_path, "err_log.txt")
    with open(err_file_path, "w") as file:
        for item in err_files:
            file.write(f"{item}\n")


# Run the conversion: tensors are read from the input folder and the PNG
# files (plus the error log) end up under the output folder.
output_directory = "./Act_Max_Img_MLP_Raw_Images"  # Replace with the actual path
input_directory = "./Act_Max_Img_MLP_2"  # Replace with the actual path

convert_pt_to_image(
    input_path=input_directory,
    output_path=output_directory,
)


  from .autonotebook import tqdm as notebook_tqdm


Processing file class_8_ker_3_startss_0.1_startsig_0.55_epoch_200.pt
Successfully saved file class_8_ker_3_startss_0.1_startsig_0.55_epoch_200.pt as an image
Processing file class_4_ker_3_startss_0.001_startsig_0.55_epoch_300.pt
Successfully saved file class_4_ker_3_startss_0.001_startsig_0.55_epoch_300.pt as an image
Processing file class_9_ker_3_startss_0.0001_startsig_0.55_epoch_600.pt
Successfully saved file class_9_ker_3_startss_0.0001_startsig_0.55_epoch_600.pt as an image
Processing file class_0_ker_7_startss_0.01_startsig_0.325_epoch_700.pt
Successfully saved file class_0_ker_7_startss_0.01_startsig_0.325_epoch_700.pt as an image
Processing file class_6_ker_3_startss_0.0001_startsig_0.1_epoch_500.pt
Successfully saved file class_6_ker_3_startss_0.0001_startsig_0.1_epoch_500.pt as an image
Processing file class_2_ker_5_startss_1e-05_startsig_0.1_epoch_200.pt
Successfully saved file class_2_ker_5_startss_1e-05_startsig_0.1_epoch_200.pt as an image
Processing file class_1_ker_5_st