In [8]:
import matplotlib.pyplot as plt

from glob import glob

import numpy as np

import torch

from tqdm.notebook import tqdm

# import wandb

import pickle5 as pickle

import sys
sys.path.append("..")

import random

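# Autoencoder (AE) test results

Reconstructions from each saved encoder/decoder checkpoint on a sample of injection images, followed by 2-D views of the encoder's latent space (raw dimensions and t-SNE).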

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

SEED = 42  # fixes which images are sampled so reruns compare the same inputs

# Encoder checkpoints are expected to be named "<int>_...enc.pickle"; sort them
# numerically by that integer prefix (lexical order would put "10_" before "2_").
encoder_dirs = glob('/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/models/*enc.pickle')
encoder_dirs = sorted(encoder_dirs, key=lambda x: int(x.split('/')[-1].split('_')[0]))

dataset_dir = '/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/injections/'
# Sorted so the seeded sample below does not depend on filesystem listing order.
img_dirs = sorted(glob(dataset_dir + '*.npy'))

n_of_photos = 10

total_enc_count = len(encoder_dirs)
assert total_enc_count > 0, 'no *enc.pickle checkpoints found'
assert img_dirs, f'no .npy images found in {dataset_dir}'

# Sample without replacement (capped at the number of available images) using a
# private RNG, so the global `random` state is left untouched.
rand_selected_imgs_dir = random.Random(SEED).sample(img_dirs, k=min(n_of_photos, len(img_dirs)))
n_cols = len(rand_selected_imgs_dir)
for img_path in rand_selected_imgs_dir:
    print(img_path.split('/')[-1])

# Load each sampled image once (not once per encoder) and add a leading batch dim.
sample_imgs = [
    torch.from_numpy(np.load(img_path)).float().unsqueeze(0).to(device)
    for img_path in rand_selected_imgs_dir
]

# Two rows per checkpoint: inputs on the first row, decoder outputs on the second.
fig, axes = plt.subplots(
    2 * total_enc_count, n_cols,
    figsize=(2 * n_cols, 4 * total_enc_count),
    squeeze=False,
)

for enc_idx, enc_dir in enumerate(encoder_dirs):
    # Swap "enc" -> "dec" in the file name only, so a directory whose name
    # happens to contain "enc" is never rewritten.
    models_dir, enc_name = enc_dir.rsplit('/', 1)
    dec_dir = f'{models_dir}/{enc_name.replace("enc", "dec")}'

    # NOTE: pickle.load can execute arbitrary code -- only load trusted checkpoints.
    # `enc_best` / `dec_best` stay defined after the loop (holding the LAST
    # checkpoint); the encoding cell below relies on `enc_best`.
    with open(enc_dir, 'rb') as fin:
        enc_best = pickle.load(fin).to(device)
    with open(dec_dir, 'rb') as fin:
        dec_best = pickle.load(fin).to(device)
    enc_best.eval()
    dec_best.eval()

    orig_row = axes[2 * enc_idx]
    recon_row = axes[2 * enc_idx + 1]
    with torch.no_grad():
        for img_idx, img in enumerate(sample_imgs):
            dec_out = dec_best(enc_best(img))
            orig_row[img_idx].imshow(img.squeeze().cpu().numpy())
            recon_row[img_idx].imshow(dec_out.squeeze().cpu().numpy())

    # Label each row pair once, above its middle column.
    mid_col = n_cols // 2
    orig_row[mid_col].set_title(f'{enc_name}/Original')
    recon_row[mid_col].set_title(f'{enc_name}/Dec_OUT')

for ax in axes.flat:
    ax.axis('off')
plt.show()

In [None]:
import pandas as pd

# Encode every test sample with the last encoder loaded above (`enc_best`) and
# build one row per sample: the flattened latent vector plus its label.
# NOTE(review): `test_dataset` is not defined in any cell shown here, so this cell
# fails on a fresh kernel -- TODO define it explicitly. Each sample is presumably
# an (image_tensor, label) pair; confirm against where it is built.
enc_best.eval()  # loop-invariant: set once, not per sample
encoded_rows = []
with torch.no_grad():
    for sample in tqdm(test_dataset):
        img = sample[0].unsqueeze(0).to(device)  # add a batch dimension
        label = sample[1]
        # Unwrap scalar tensor labels so later `astype(str)` yields "3", not "tensor(3)".
        if torch.is_tensor(label) and label.numel() == 1:
            label = label.item()
        latent = enc_best(img).flatten().cpu().numpy()
        row = {f"Enc. Variable {i}": value for i, value in enumerate(latent)}
        row['label'] = label
        encoded_rows.append(row)

# Build the frame once from the list of records (not row-by-row).
encoded_samples = pd.DataFrame(encoded_rows)
encoded_samples.head()

In [None]:
import plotly.express as px

# Plot the first two raw latent dimensions against each other, one colour per label.
label_text = encoded_samples['label'].astype(str)
px.scatter(
    encoded_samples,
    x='Enc. Variable 0',
    y='Enc. Variable 1',
    color=label_text,
    opacity=0.7,
)

In [None]:
from sklearn.manifold import TSNE
import plotly.express as px  # imported here too so this cell does not depend on the previous one

TSNE_SEED = 42  # t-SNE is stochastic; fix the seed so the embedding is reproducible

# Project the latent vectors (every column except the label) down to 2-D.
latent_features = encoded_samples.drop(columns=['label'])
# scikit-learn requires perplexity < n_samples; the default of 30 fails on small test sets.
perplexity = min(30.0, len(latent_features) - 1)
tsne = TSNE(n_components=2, perplexity=perplexity, random_state=TSNE_SEED)
tsne_results = tsne.fit_transform(latent_features)

fig = px.scatter(tsne_results, x=0, y=1,
                 color=encoded_samples.label.astype(str),
                 labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two', 'color': 'label'},
                 title='t-SNE of encoder latent vectors')
fig.show()