In [None]:
import importlib
import os
import sys
import time

import numpy as np
import torch
from skimage import io
from torch.utils.data import DataLoader



# Make the project sub-directories importable so the project-local modules
# imported below can be found (appended in this order: model, losses, Dataloader).
sys.path.extend(os.getcwd() + '/' + sub_dir + '/' for sub_dir in ('model', 'losses', 'Dataloader'))


from Encoder_model import *
from Decoder_model import *
from Comparator_model import *
from Dense_Spatial_Transformation import * 
#from Inference_MorphSSL import *
from MorphSSL_dataloader import *

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
def my_normalize(I):
    """Min-max rescale an array to the full uint8 range [0, 255].

    The minimum maps to 0 and the maximum to 255. The float -> uint8 cast
    truncates rather than rounds.

    Args:
        I: array-like of numeric values (any shape).

    Returns:
        np.ndarray of dtype uint8 with the same shape as ``I``. A constant input
        (zero dynamic range) yields an all-zero array. Without the guard below,
        it would produce 0/0 = NaN, and NaN cast to uint8 is undefined.
    """
    I = np.asarray(I, dtype=float)
    I = I - np.min(I)  # new array; the caller's input is never modified
    span = np.max(I)
    if span > 0:
        I = I / span
    return (I * 255).astype(np.uint8)

def my_normalize_scale(I):
    """Scale ``|I|`` so its peak maps to 255 and cast to uint8.

    Unlike ``my_normalize``, the minimum is not subtracted: 0 stays 0, and
    negative values are folded onto positive ones via the absolute value.

    Args:
        I: array-like of numeric values (any shape).

    Returns:
        np.ndarray of dtype uint8 with the same shape as ``I``. An all-zero input
        yields all zeros. Without the guard below, it would divide 0 by 0.
    """
    I = np.abs(np.asarray(I, dtype=float))
    peak = np.max(I)
    if peak > 0:
        I = I / peak
    return (I * 255).astype(np.uint8)



def my_visualize(out_lst, nm, n_slices=4, pad_px=30):
    """Tile a list of volumes into a single grid image and save it to ``nm``.

    Each entry of ``out_lst`` becomes one column of the grid. Within a column,
    up to ``n_slices`` slices taken along the last axis are stacked vertically.
    White (255) strips of ``pad_px`` pixels separate slices and columns; a strip
    also follows the last slice and the last column, as in the original layout.

    Args:
        out_lst: sequence of torch tensors. Each is squeezed and indexed as
            ``[:, :, slc]``, so it must squeeze to 3-D with slices on the last
            axis. NOTE(review): presumably (H, W, slices); confirm.
        nm: output file path passed to ``skimage.io.imsave``.
        n_slices: maximum number of slices per column (default 4, as before).
            Clamped to the number of slices actually present.
        pad_px: thickness in pixels of the white separator strips (default 30).

    Raises:
        ValueError: if ``out_lst`` is empty.
    """
    if len(out_lst) == 0:
        raise ValueError("out_lst must contain at least one volume")

    out_col = []
    for vol in out_lst:
        tmp = my_normalize(np.squeeze(vol.detach().cpu().numpy()))

        # The horizontal separator depends only on the column width, so build it once.
        row_pad = np.full((pad_px, tmp.shape[1]), 255, dtype=np.uint8)
        out_row = []
        for slc in range(min(n_slices, tmp.shape[2])):
            out_row.append(tmp[:, :, slc])
            out_row.append(row_pad)
        # Stack the slices of this volume vertically into one column.
        out_row = np.concatenate(out_row, axis=0)

        out_col.append(out_row)
        out_col.append(np.full((out_row.shape[0], pad_px), 255, dtype=np.uint8))

    # Place the columns side by side and save.
    out_col = np.concatenate(out_col, axis=1)
    io.imsave(nm, out_col)

In [None]:
def inference(I_A, I_B, nm, alphas=(0.2, 0.4, 0.6, 0.8, 1.0), slc=slice(15, 19)):
    """Visualize a linear interpolation in feature space from ``I_A`` towards ``I_B``.

    Both inputs are encoded with the global ``encoder_model``. For each weight
    ``k`` in ``alphas``, the features are interpolated as
    ``ftr_A + k * (ftr_B - ftr_A)`` and decoded against ``ftr_A`` by
    ``decoder_model`` into ``(D_AB, C_AB)``. ``I_A`` is warped by ``D_AB`` via
    ``spatial_transform`` (nearest mode, global ``grid_4``), and ``C_AB`` is
    added to the result.

    The saved image has one column per volume: ``I_A``, one per interpolation
    weight, then ``I_B``.

    Args:
        I_A: input volume at time t, already on the model device.
        I_B: input volume at time t+k, same layout as ``I_A``.
        nm: output image path, forwarded to ``my_visualize``.
        alphas: interpolation weights (default unchanged from the original).
        slc: window of B-scans (last axis) to visualize. The default 15:19
            picks B-scans 15-18, near the middle of a 32-B-scan volume. The
            exactly centred window would be 14:18; it is kept as-is so output
            is unchanged.
    """
    out_lst = [I_A[:, :, :, :, slc]]
    # Everything runs under no_grad. Previously only the encoder calls did, so the
    # decoder / spatial_transform passes built autograd graphs that were never used.
    with torch.no_grad():
        ftr_A = encoder_model(I_A)  # time t
        ftr_B = encoder_model(I_B)  # time t+k
        for k in alphas:
            ftr_new = ftr_A + k * (ftr_B - ftr_A)
            D_AB, C_AB = decoder_model(ftr_A, ftr_new)
            I_out, _ = spatial_transform(I_A, D_AB, grid_4, 'nearest')
            I_out = I_out + C_AB
            out_lst.append(I_out[:, :, :, :, slc])

    out_lst.append(I_B[:, :, :, :, slc])
    my_visualize(out_lst, nm)


#####################################################################

def visualize_linear_interpolation(val_loader, out_pth):
    """Write one interpolation image per validation pair into ``out_pth``.

    Puts the global encoder and decoder in eval mode. For every batch from
    ``val_loader``, it calls ``inference`` and saves the result as
    ``<nm_A>_<nm_B>.png``.

    Args:
        val_loader: iterable of dicts with keys 'I_A', 'I_B', 'nm_A', 'nm_B'.
        out_pth: output directory. Joined with ``os.path.join``, so it works
            with or without a trailing separator. The original string
            concatenation silently produced a wrong path without one.
    """
    decoder_model.eval()
    encoder_model.eval()

    for i, sample in enumerate(val_loader):
        print(i)  # simple progress indicator

        # Only the first name of each batch is used; this matches the
        # batch_size=1 loader configured in this notebook.
        nm = sample['nm_A'][0] + '_' + sample['nm_B'][0] + '.png'

        I_A = sample['I_A'].to(device)
        I_B = sample['I_B'].to(device)

        inference(I_A, I_B, os.path.join(out_pth, nm))
        

In [None]:
def worker_init_fn(worker_id):
    """Seed NumPy's and Python's global RNGs inside a DataLoader worker.

    Follows the PyTorch reproducibility recipe. Inside a worker,
    ``torch.initial_seed()`` equals the loader's per-iterator base seed plus
    ``worker_id``, so the seed differs across workers and epochs.

    The previous version had two problems:
    - It derived the seed from NumPy's global state. That state is copied
      unchanged into every worker, so the same streams could repeat every epoch.
    - Adding ``worker_id`` to that uint32 could exceed the ``2**32 - 1`` limit
      of ``np.random.seed``.

    Args:
        worker_id: worker index supplied by DataLoader. Unused, because
            ``torch.initial_seed()`` already incorporates it. Kept for the
            required ``worker_init_fn`` signature.
    """
    import random  # local import keeps this cell self-contained

    worker_seed = torch.initial_seed() % 2**32  # fits np.random.seed's valid range
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [None]:
eps = 10**(-14)  # small numerical constant; appears unused by the other cells here
VOL_SHAPE = (192, 192, 32)  # grid size for create_mesh_grid; last axis presumably = B-scans per volume
grid_4 = create_mesh_grid(*VOL_SHAPE)  # Spatial index used to apply the spatial deformation

############### Dataloader #####################
# see comments in /Dataloader/MorphSSL_dataloader.py  for img_pth, img_pairs
# The data root can be overridden with the MORPHSSL_DATA_ROOT environment variable.
# The default reproduces the original hard-coded location.
DATA_ROOT = os.environ.get(
    'MORPHSSL_DATA_ROOT',
    '/msc/home/achakr83/PINNACLE/SSL_training/May30/final_full_training/',
)
val_data = val_dataset(
    img_pth=os.path.join(DATA_ROOT, 'preprocessed_SSL_images2/'),  # trailing '/' kept as in the original
    img_pairs=os.path.join(DATA_ROOT, 'step4_final_val_ssl_data.npz'),
)
# batch_size=1: visualize_linear_interpolation uses only the first name of each batch.
val_loader = DataLoader(dataset=val_data, batch_size=1, shuffle=False, num_workers=2,
                        pin_memory=False, drop_last=False, worker_init_fn=worker_init_fn)


########### Instantiate the Models ##############
#### Encoder ###
encoder_model = Encoder_Architecture(base_chnls=16, out_ftr_dim=(64*2))
encoder_model.to(device)
#### Decoder ###
decoder_model = Decoder_Architecture(in_dim=64, first_dim=512)
decoder_model.to(device)

In [None]:
CKPT_PATH = 'best_weight-0.007980789116118103.pt'
# map_location lets a checkpoint saved from GPU tensors load on a CPU-only machine.
# torch.load unpickles the file, so only load checkpoints from trusted sources.
checkpoint = torch.load(CKPT_PATH, map_location=device)
encoder_model.load_state_dict(checkpoint['model_state_dict_encoder_model'])
# Decoder weights are stored under 'model_state_dict_deform_model'. Fall back to
# 'model_state_dict_decoder_model' for checkpoints that use that key instead.
dec_key = ('model_state_dict_deform_model'
           if 'model_state_dict_deform_model' in checkpoint
           else 'model_state_dict_decoder_model')
decoder_model.load_state_dict(checkpoint[dec_key])
del checkpoint

In [None]:
# Output directory. The trailing separator is kept because out_pth is later
# combined with file names.
out_pth = os.path.join(os.getcwd(), 'visualize_linear_interpolation', '')
# exist_ok avoids the exists()/makedirs race. It still fails loudly if the path
# exists as a regular file, instead of failing later when images are saved.
os.makedirs(out_pth, exist_ok=True)

visualize_linear_interpolation(val_loader, out_pth)