In [None]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from PIL import Image, ImageEnhance
from tqdm import tqdm

import stretch

# Stretch parameters forwarded to stretch.stretch (semantics live in the local `stretch` module).
sigma = 1.5
bg = 0.2

data_path = '../data/calibrated/'
thumbnail_dir = '../data/thumbnails/'
short_files = sorted(glob.glob(data_path+'Short/*'))
longs_files = sorted(glob.glob(data_path+'Long/*'))
# Short/long exposures are paired purely by sorted order, so the counts must agree;
# zip() would otherwise silently drop the unmatched tail.
if len(short_files) != len(longs_files):
    raise ValueError(
        f"Expected matching short/long exposures, found {len(short_files)} short "
        f"and {len(longs_files)} long files under {data_path!r}"
    )


def _stem(path):
    """Return the base name of `path` up to its first '.', independent of the OS path separator."""
    return os.path.basename(path).split('.')[0]


def _load_hwc(path):
    """Read the primary HDU of a FITS file and move its leading (channel) axis last.

    Equivalent to swapaxes(swapaxes(x, 0, 2), 0, 1), i.e. (C, H, W) -> (H, W, C).
    The local path is opened directly: astropy's get_pkg_data_filename is meant
    for package data files, not arbitrary paths on disk.
    """
    return np.moveaxis(fits.getdata(path, ext=0), 0, -1)


def _to_jpg(img, saturation=2.0):
    """Clip a float image to [0, 1], quantise to uint8 and boost colour saturation with PIL."""
    as_uint8 = (np.clip(img, 0, 1) * 255).astype(np.uint8)
    return np.array(ImageEnhance.Color(Image.fromarray(as_uint8)).enhance(saturation))


shorts, longs = [], []
load_raw = True  # False: reuse thumbnails written by a previous run instead of re-stretching FITS data
if load_raw:
    os.makedirs(thumbnail_dir, exist_ok=True)

for short_path, long_path in tqdm(zip(short_files, longs_files), total=len(short_files)):
    filename = _stem(short_path)
    short_thumb = os.path.join(thumbnail_dir, 'short_' + filename + '.jpg')
    long_thumb = os.path.join(thumbnail_dir, 'long_' + filename + '.jpg')
    if load_raw:
        tqdm.write(f'{short_path} {long_path}')  # unlike print, does not break the progress bar
        short_linear = _load_hwc(short_path)
        long_linear = _load_hwc(long_path)
        # Both exposures are stretched with the SHORT exposure's per-channel
        # median and median absolute deviation.
        median = [np.median(short_linear[:, :, c]) for c in range(3)]
        mad = [np.median(np.abs(short_linear[:, :, c] - median[c])) for c in range(3)]
        long, short = stretch.stretch(long_linear, short_linear, bg, sigma, median, mad)

        # dump jpg thumbnails so later runs can set load_raw = False
        short_jpg = _to_jpg(short)
        long_jpg = _to_jpg(long)
        Image.fromarray(short_jpg).save(short_thumb)
        Image.fromarray(long_jpg).save(long_thumb)

        shorts.append(short_jpg)
        longs.append(long_jpg)
    else:
        shorts.append(plt.imread(short_thumb))
        longs.append(plt.imread(long_thumb))
shorts = np.array(shorts)
longs = np.array(longs)
longs.shape, shorts.shape

In [None]:
plt.style.use('dark_background')
subsampling = 4  # NOTE(review): not referenced by any visible cell in this notebook
patch_size = 256  # side length (pixels) of the central crop shown in each inset
scale = 2
ncols = 2

# Panel titles are the short-exposure file stems; the long thumbnails were
# saved under the same stems, so one list labels both grids.
panel_titles = [os.path.basename(path).split('.')[0] for path in short_files]

for imgs, out_path in zip([shorts, longs], ['shorts.jpg', 'longs.jpg']):
    # Size the grid from the data rather than a hard-coded 7x2 (which raised
    # IndexError past 14 images); squeeze=False keeps axarr 2-D even for one row.
    nrows = max(1, (len(imgs) + ncols - 1) // ncols)
    fig, axarr = plt.subplots(nrows, ncols, figsize=(4*ncols*scale, 3*nrows*scale),
                              sharex=True, sharey=True, squeeze=False)
    for i in tqdm(range(len(imgs))):
        ax = axarr[i // ncols, i % ncols]
        ax.set_title(panel_titles[i], fontsize=16)
        # Flipping rows and drawing with origin="lower" looks the same as the
        # default top-down view, but y grows upward so the inset limits read naturally.
        ax.imshow(imgs[i][::-1], origin="lower")
        ax.set_xticks([])
        ax.set_yticks([])

        # Inset in the lower-left corner zooming into a patch_size x patch_size centre crop.
        axins = ax.inset_axes([-0.06125, 0.025, 0.5, 0.5])
        axins.imshow(imgs[i][::-1], origin="lower")
        centerx = imgs[i].shape[1] // 2
        centery = imgs[i].shape[0] // 2
        x1, x2 = centerx - patch_size // 2, centerx + patch_size // 2
        y1, y2 = centery - patch_size // 2, centery + patch_size // 2
        axins.set_xlim(x1, x2)
        axins.set_ylim(y1, y2)
        axins.set_xticks([])
        axins.set_yticks([])
        ax.indicate_inset_zoom(axins, edgecolor="white")
    # Blank out grid cells left over when len(imgs) is not a multiple of ncols.
    for ax in axarr.flat[len(imgs):]:
        ax.axis('off')
    fig.tight_layout()
    fig.savefig(out_path, bbox_inches='tight', pad_inches=0, dpi=150)
    plt.show()
    plt.close(fig)  # release the figure explicitly; matters for non-inline backends

# Quantize colors to 256 k-means clusters (the GIF palette size)

In [None]:
from sklearn.cluster import KMeans
from tqdm import tqdm

# Reduce both saved grids to one shared 256-colour palette so they can be written as GIF frames.
n_colors = 256     # GIF frames hold at most 256 colours
fit_stride = 100   # fit k-means on every 100th pixel to keep fitting fast
kmeans_seed = 0    # fixed seed so the palette is reproducible on Restart & Run All

frames = [plt.imread('shorts.jpg'), plt.imread('longs.jpg')]
# bbox_inches='tight' may in principle give the two figures slightly different
# pixel sizes; crop both to the common extent so they stack into one array.
height = min(f.shape[0] for f in frames)
width = min(f.shape[1] for f in frames)
gif = np.stack([f[:height, :width, :3] for f in frames])
print(gif.shape)
rgbs = gif.reshape(-1, 3)  # same pixel order as concatenating each frame's (H*W, 3) view
print(rgbs.shape)

kmeans = KMeans(n_clusters=n_colors, random_state=kmeans_seed)
kmeans.fit(rgbs[::fit_stride])
pred = kmeans.predict(rgbs)

# Round each centre to the nearest valid 8-bit value once (n_colors entries),
# then look every pixel up by fancy indexing instead of a per-pixel Python loop.
palette = np.clip(np.rint(kmeans.cluster_centers_), 0, 255).astype(np.uint8)
crgbs = palette[pred].reshape(gif.shape)

plt.imshow(crgbs[0])
plt.show()

# Dump GIF

In [None]:
import imageio
from IPython import display

# Write the two quantised grids as a two-frame GIF, one frame per second.
# NOTE(review): newer imageio releases may prefer `duration` over `fps` for GIFs — confirm against the installed version.
imageio.mimsave('dataset.gif', crgbs, fps=1)
# Read the GIF back with a context manager so the file handle is closed
# deterministically; the bare last expression renders it inline.
with open("dataset.gif", 'rb') as gif_file:
    gif_bytes = gif_file.read()
display.Image(gif_bytes)