In [None]:
import sys
sys.path.append("/home/conradb/git/ifg-ssl")
import os
import torch
import torchvision.transforms as transforms 
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import MAE.models_mae as models_mae

In [None]:
def show_image(image, title=''):
    """Display a single channels-last RGB image with matplotlib.

    Args:
        image: array-like of shape (H, W, 3) holding values expected in [0, 1].
            Torch tensors (on any device, with or without grad) and numpy arrays
            are both accepted.
        title: optional title drawn above the image.
    """
    # torch.as_tensor is a no-op for CPU tensors and wraps numpy arrays, so the
    # torch ops below work for either input; detach/cpu make it plottable.
    image = torch.as_tensor(image).detach().cpu()
    assert image.ndim == 3 and image.shape[2] == 3, \
        f'expected an H x W x 3 image, got shape {tuple(image.shape)}'
    # Scale [0, 1] floats to 0-255, clamp out-of-range values and cast to int
    # (truncating); imshow interprets integer RGB data as 0-255.
    plt.imshow(torch.clip(image * 255, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')

def prepare_model(chkpt_dir, arch='mae_vit_base_patch16', **load_kwargs):
    """Build an MAE model from `models_mae` and load weights from a checkpoint file.

    Args:
        chkpt_dir: path to the checkpoint *file* handed to ``torch.load``
            (despite the name, not a directory).
        arch: name of the constructor in ``models_mae`` to instantiate.
        **load_kwargs: extra keyword arguments forwarded to ``torch.load``;
            ``map_location`` defaults to ``'cpu'``. For example, PyTorch >= 2.6
            defaults to ``weights_only=True``, which may reject checkpoints that
            also pickle non-tensor objects; pass ``weights_only=False`` only for
            files you trust.

    Returns:
        The model with ``checkpoint['model']`` loaded non-strictly: missing or
        unexpected keys are printed rather than raised.
    """
    model = getattr(models_mae, arch)()
    load_kwargs.setdefault('map_location', 'cpu')
    # SECURITY: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(chkpt_dir, **load_kwargs)
    # strict=False lets partially matching checkpoints load; print the key report
    # so mismatches are still visible in the notebook output.
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    return model

def run_one_image(img, model, mask_ratio=0.75):
    """Run one image through an MAE model and plot its masking/reconstruction.

    Shows four panels side by side: the original, the masked input, the raw
    reconstruction, and the reconstruction with the visible patches pasted back.

    Args:
        img: a single channels-last image of shape (H, W, C), e.g. the
            ``np.array(...) / 255`` array from the loading cell. Numpy arrays and
            tensors are both accepted (via ``torch.as_tensor``).
        model: an MAE-style model callable as
            ``model(x, mask_ratio=...) -> (loss, pred, mask)`` that also exposes
            ``unpatchify`` and ``patch_embed.patch_size``.
        mask_ratio: forwarded to the model; defaults to 0.75 (the value that was
            previously hard-coded).
    """
    # Add a batch dimension and move channels first: (H, W, C) -> (1, C, H, W).
    x = torch.einsum('nhwc->nchw', torch.as_tensor(img).unsqueeze(0))

    # Inference only: every output is just plotted, so don't build an autograd graph.
    with torch.no_grad():
        _loss, pred, mask = model(x.float(), mask_ratio=mask_ratio)
        y = torch.einsum('nchw->nhwc', model.unpatchify(pred)).cpu()

        # Expand the per-patch mask (N, L) to (N, L, p*p*3) so unpatchify can lay it
        # out like an image and it can be multiplied with the pixels.
        patch = model.patch_embed.patch_size[0]
        mask = mask.unsqueeze(-1).repeat(1, 1, patch ** 2 * 3)
        mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask).cpu()

    x = torch.einsum('nchw->nhwc', x)
    im_masked = x * (1 - mask)               # visible patches only
    im_paste = x * (1 - mask) + y * mask     # reconstruction pasted into the masked patches

    # Size this figure directly instead of overwriting plt.rcParams, which would
    # change the default size of every later figure in the notebook.
    plt.figure(figsize=(24, 24))
    panels = [
        (x, "original"),
        (im_masked, "masked"),
        (y, "reconstruction"),
        (im_paste, "reconstruction + visible"),
    ]
    for position, (image, title) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        show_image(image[0], title)

    plt.show()


In [None]:
# 2018 Kīlauea interferogram: LiCSAR frame 087D_07004_060904, pair 20180411_20180511.
# Same file on the Windows share: U:\scratch\SDF25\LiCSAR-web-tools\087D_07004_060904\interferograms\20180411_20180511\20180411_20180511.geo.diff.png
img_path = '/scratch/SDF25/LiCSAR-web-tools/087D_07004_060904/interferograms/20180411_20180511/20180411_20180511.geo.diff.png'

# Force 3-channel RGB before resizing: a PNG stored as RGBA or palette ("P") would
# otherwise yield a non-(224, 224, 3) array and fail the check below (alpha is dropped).
# The `with` block closes the file; convert()/resize() return independent images.
with Image.open(img_path) as im:
    img = np.array(im.convert('RGB').resize((224, 224))) / 255

assert img.shape == (224, 224, 3), f'unexpected image shape {img.shape}'

plt.rcParams['figure.figsize'] = [5, 5]
show_image(torch.tensor(img))

In [None]:
# Alternative preprocessing: grayscale replicated to 3 channels, then ToTensor,
# which scales pixel values to [0, 1] and returns a C x H x W float tensor.
with Image.open(img_path) as im:
    img = im.resize((224, 224))  # resize() returns a new image that outlives the closed file

# Keep the pipeline under its own name: binding it to `transforms` shadowed the
# torchvision.transforms module, so re-running this cell raised AttributeError.
to_gray_tensor = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])

# NOTE(review): this rebinds `img` from the H x W x 3 array of the previous cell to a
# channels-first tensor; run_one_image expects channels-last input, so re-run the
# loading cell before calling it.
img = to_gray_tensor(img)
img.shape