In [2]:
import json 
import os
from glob import glob

import cv2
import numpy as np
import torch
from tqdm import tqdm 

from model import VAE


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
print(device, cuda_available)

# Size of the latent vector; the encoder's final layer emits twice this many values,
# which later cells split into a mean half and a log-variance half.
latentDim = 25
checkpoint_path = "./ckpt_files/anar_lat_25_052324.ckpt"

# Deserialize onto the CPU first so the checkpoint loads on machines without CUDA;
# the model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location="cpu")

model = VAE(latentDim)
model.load_state_dict(checkpoint["state_dict"])
model = model.to(device)

# Switch to inference mode; as the cell's last expression this also displays the model.
model.eval()


cpu False


VAE(
  (encoder): Encoder(
    (conv_stack): Sequential(
      (0): Conv2d(1, 32, kernel_size=(11, 11), stride=(1, 1), padding=same)
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=same)
      (4): ReLU()
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (7): ReLU()
      (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (9): Flatten()
      (10): Linear(in_features=8192, out_features=50, bias=True)
    )
  )
  (decoder): Decoder(
    (inverse_conv_stack): Sequential(
      (0): Linear(in_features=25, out_features=1024, bias=True)
      (1): ReLU()
      (2): UnFlatten()
      (3): ConvTranspose2d(16, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (4): ReLU()
      (5): ConvTranspose2d(128, 64, ke

In [3]:
# Read data from string 
def process_string_mech(dir, toNpy = True):
    """Parse the numbers and the mechanism label encoded in an image file name.

    The file name (last path component, image extension removed) is split on
    whitespace. Tokens that parse as floats are collected; the token that does
    not parse is the label. Floats seen before the label and floats seen after
    it are returned separately. If several tokens fail to parse, the last one
    is returned as the label.

    Example path:
    ./outputs-4bar/-0.001 2.728 5.504 -1.565 -5.632 -2.481 -8.711 9.682 1.320 -5.630 -7.171 3.601 RRRP  0.42 0.026 0.732 -0.026 0.42 2.011 0. 0. 1. .jpg

    Args:
        dir: Path (or bare file name) of the image. Both '/' and '\\' are
            accepted as separators, so paths returned by glob on Windows and
            on Linux/macOS parse the same way.
        toNpy: When True, return the leading floats as an (N, 2) ndarray and
            the trailing floats as a 3x3 np.matrix. A 6-value tail is first
            completed with the row [0, 0, 1]. When False, return the plain,
            unpadded Python lists.

    Returns:
        (floats_before, letter_string, floats_after). letter_string is None
        when every token parsed as a float.

    Raises:
        ValueError: With toNpy=True, if the leading floats cannot be reshaped
            to (-1, 2) or the trailing floats to (3, 3).
    """
    # glob on Windows joins the last component with a backslash
    # ('./outputs-4bar/RRRP\\<name>.jpg'). Splitting on '/' alone would leave the
    # folder name fused to the first number, so normalize the separator first.
    file_name = dir.replace('\\', '/').split('/')[-1]

    # Drop a trailing image extension. Only known extensions are removed so that a
    # bare string ending in a number such as '... 1.5' is not truncated.
    image_extensions = ('jpg', 'jpeg', 'jpe', 'jfif', 'png', 'bmp', 'gif', 'tif', 'tiff', 'webp')
    stem, dot, extension = file_name.rpartition('.')
    if dot and extension.strip().lower() in image_extensions:
        file_name = stem

    # Split the string by whitespace
    parts = file_name.split()

    floats_before = []
    floats_after = []
    letter_string = None

    # Separate the floats that come before the label from those that come after it
    for part in parts:
        try:
            num = float(part)
        except ValueError:
            # If conversion fails, this part is the letter string
            letter_string = part
            continue
        if letter_string is None:
            floats_before.append(num)
        else:
            floats_after.append(num)

    if toNpy:
        # Same completion the extraction cell applies to 6-value tails: append
        # the row [0, 0, 1] so the values fill a 3x3 matrix.
        if len(floats_after) == 6:
            floats_after = floats_after + [0.0, 0.0, 1.0]
        floats_before = np.array(floats_before).reshape((-1, 2))
        # np.matrix is kept so callers receive the same type as before.
        floats_after = np.matrix(floats_after).reshape((3, 3))

    # Optional sanity check: compare len(floats_before) with the count expected
    # for your specific mechanism and report the offending path if it differs.
    return floats_before, letter_string, floats_after

In [9]:
# Encode a sample of every four-bar mechanism type with the VAE encoder and save,
# per type, the latent vectors (z) and the feature rows parsed from the file names.
file_path = './KV_468_062324.json'

# Open and read the JSON file that maps a mechanism type name to its numeric id
with open(file_path, 'r') as file:
    KVdict = json.load(file) 

image_folder = './outputs-4bar/'
z_folder = './outputs-z/'
e_folder = './outputs-encoded/'
date = '062324'  # tag placed at the front of every saved file name

setSize = 1000 # len(imgStrings) # determine how many samples for each type. 
batchSize = 1000
four_bar = ['RRRR', 'RRRP', 'RRPR', 'PRPR']  # 
six_bar  = ['Watt1T1A1', 'Watt1T2A1', 'Watt1T3A1', 'Watt1T1A2', 'Watt1T2A2', 'Watt1T3A2', 
            'Watt2T1A1', 'Watt2T2A1', 'Watt2T1A2', 'Watt2T2A2', 'Steph1T1', 'Steph1T2',
            'Steph1T3', 'Steph2T1A1', 'Steph2T2A1', 'Steph3T1A1', 'Steph3T2A1', 'Steph3T1A2', 
            'Steph3T2A2', 'Steph2T1A2', 'Steph2T2A2']
eight_bar = list(set(KVdict.keys()) - set(four_bar) - set(six_bar))

for mechType in four_bar:
    batchImg = []
    result_zSet = []
    result_featSet = []
    value = KVdict[mechType]
    # NOTE(review): glob returns paths in arbitrary order, so which `setSize` files
    # are sampled may differ between machines — sort the list if that matters.
    imgStrings = glob(image_folder + mechType + '/*')
    print(mechType, len(imgStrings))
    nSamples = min(setSize, len(imgStrings))
    for i in tqdm(range(nSamples)): 
        img = cv2.imread(imgStrings[i], cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals an unreadable file by returning None. Skip it before
            # anything is appended so images and feature rows stay aligned.
            print('could not read image, skipping:', imgStrings[i])
        else:
            floats_before, letter_string, floats_after = process_string_mech(imgStrings[i], toNpy = False)
            if len(floats_after) == 6:
                # Complete a 6-value tail with the row [0, 0, 1]
                floats_after = floats_after + [0, 0, 1]
            elif len(floats_after) != 9: 
                print('what?', floats_after)

            batchImg.append(img / 255) # This /255 works better than not doing it 
            result_featSet.append(np.array(floats_before + [KVdict[letter_string]] + floats_after, dtype= float).flatten().tolist())

        # Encode when the batch is full, and also on the final iteration so a partial
        # last batch is not lost (this replaces the duplicated flush after the loop).
        if batchImg and (len(batchImg) >= batchSize or i == nSamples - 1):
            # (1, B, H, W) -> (B, 1, H, W): one grayscale channel per image
            images = torch.from_numpy(np.array([batchImg])).swapaxes(0,1).float().to(device)
            # Inference only: do not record an autograd graph for the forward pass
            with torch.no_grad():
                x = model.encoder(images)
                mean, logvar = x[:, : model.latent_dim], x[:, model.latent_dim :]
                # NOTE(review): reparameterize presumably draws random noise, so the
                # saved z would vary between runs unless a torch seed is set — confirm.
                z = model.reparameterize(mean, logvar)
            z = z.cpu().numpy()
            result_zSet.append(z)
            batchImg = []

    if len(result_zSet) > 0:
        result_zSet = np.concatenate(result_zSet)

        batchZname = z_folder + date + '-z-' + str(int(value))
        batchEname = e_folder + date + '-encoded-' + str(int(value))

        os.makedirs(z_folder, exist_ok=True)
        os.makedirs(e_folder, exist_ok=True)

        np.save(batchZname, result_zSet)
        np.save(batchEname, np.array(result_featSet))

RRRR 40662


100%|██████████| 1000/1000 [00:03<00:00, 321.99it/s]


RRRP 36526


100%|██████████| 1000/1000 [00:02<00:00, 403.76it/s]


RRPR 35064


100%|██████████| 1000/1000 [00:02<00:00, 391.34it/s]


PRPR 39771


100%|██████████| 1000/1000 [00:02<00:00, 413.85it/s]


In [10]:
# Sanity check: print every feature row whose length differs from the first row's.
# `result_featSet` is what the extraction loop left behind, i.e. the rows of the
# last mechanism type it processed.
length = len(result_featSet[0])
mismatched = [row for row in result_featSet if len(row) != length]
for row in mismatched:
    print(row, len(row))

