# AutoPhaseNN
#### PyTorch version (under development)

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import torch, torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchinfo import summary
from tqdm.notebook import tqdm 
import numpy as np
from numpy.fft import fftn, fftshift
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
import matplotlib.pyplot as plt
import matplotlib

In [None]:
from data_loader2 import *

In [None]:
# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and later
# releases dropped the old names, so pick whichever this installation provides.
_STYLE_NAME = ('seaborn-v0_8-white'
               if 'seaborn-v0_8-white' in plt.style.available
               else 'seaborn-white')
plt.style.use(_STYLE_NAME)
#matplotlib.rc('xtick', labelsize=20) 
#matplotlib.rc('ytick', labelsize=20)
matplotlib.rc('font',family='Times New Roman')
matplotlib.rcParams['font.size'] = 20
# Make viridis the default colormap for the imshow panels drawn below.
plt.viridis()
%matplotlib inline

## Define Constants
GPUs, Batch size, Training and Validation size, Shrink Wrap params, Epochs

In [None]:
# Number of passes over the training data.
EPOCHS = 60 #Full cycle is 12 epochs, good to end on a minimum

# When > 0 the model rescales its predicted diffraction magnitude so that its
# maximum equals this value; when <= 0 the prediction is left unscaled.
scale_I = 1 # normalize diff or not

# Directory that receives the best-scoring checkpoint (best_model.pth).
MODEL_SAVE_PATH = '/lcrc/project/AutoPhase/CDInodefect_resized32/'
print(MODEL_SAVE_PATH)

In [None]:
NGPUS = torch.cuda.device_count()
# Batch size and learning rate scale with the number of GPUs; on a CPU-only
# machine fall back to a single-replica setting instead of scaling to zero.
_N_REPLICAS = max(NGPUS, 1)
BATCH_SIZE = _N_REPLICAS * 16 # 48
LR = _N_REPLICAS * 5e-3
print("GPUs:", NGPUS, "Batch size:", BATCH_SIZE, "Learning rate:", LR)

# Number of files drawn for training + validation.
# N_TRAIN = 100000 # (including train and validation) Max is ~54k 
N_TRAIN = 1000
N_VALID = int(0.1*N_TRAIN)
TRAIN_ratio = 0.9  # Fraction of the drawn files used for training

# Shrink-wrap amplitude threshold: starting value, final value and ramp.
INIT_SW = 0.07
FINAL_SW = 0.1 #Anything >0.07 seems to break it
CONST_EPOCHS = 0 #How many epochs to not increase SW

# Per-epoch threshold step; zero when there is no epoch left to ramp over.
_SW_RAMP_EPOCHS = EPOCHS - 1 - CONST_EPOCHS
SW_INCREMENT = (FINAL_SW-INIT_SW)/_SW_RAMP_EPOCHS if _SW_RAMP_EPOCHS > 0 else 0.0
print("SW Thresh increment", SW_INCREMENT)


In [None]:
# Report the PyTorch build used by this kernel.
print ("PyTorch version", torch.__version__)
# IPython shell escape that runs the nvidia-smi command-line tool.
!nvidia-smi
#!cat /proc/cpuinfo | grep processor

### Plotting Helper

In [None]:
def plot3(data,titles):
    """Show the central slice of three volumes side by side, each with a colorbar.

    Args:
        data: sequence of at least three arrays; the slice index is half the
            first axis length of ``data[0]`` and is applied to every entry.
        titles: panel titles; replaced by "Plot1".."Plot3" when fewer than
            three are supplied.
    """
    if len(titles) < 3:
        titles = ["Plot1", "Plot2", "Plot3"]
    mid = data[0].shape[0] // 2
    fig, axes = plt.subplots(1, 3, figsize=(19, 5))
    for col in range(3):
        panel = axes[col]
        image = panel.imshow(data[col][mid])
        plt.colorbar(image, ax=panel)
        panel.set_title(titles[col])

In [None]:
def plot6(data,titles):
    """Show the central slice of six volumes in one row, each with a colorbar.

    Args:
        data: sequence of at least six arrays; the slice index is half the
            first axis length of ``data[0]`` and is applied to every entry.
        titles: panel titles; replaced by "Plot1".."Plot6" when fewer than
            six are supplied.
    """
    n_panels = 6
    # Every panel indexes into titles, so fall back to defaults unless all six exist.
    if len(titles) < n_panels:
        titles = ["Plot1", "Plot2", "Plot3", "Plot4", "Plot5", "Plot6"]
    ind = data[0].shape[0]//2
    fig, axes = plt.subplots(1, n_panels, figsize=(19,3), constrained_layout=True)
    for ix, ax in enumerate(axes):
        im = ax.imshow(data[ix][ind])
        plt.colorbar(im, ax=ax)
        ax.set_title(titles[ix])

## Load and check data, then prepare as tensors

In [None]:
# Location of the dataset and of the text file listing its samples.
# Alternative datasets used previously:
# data_path = '/lcrc/project/AutoPhase/CDI_simulation_upsamp_noise/'
# data_path = '/lcrc/project/AutoPhase/CDI_simulation_upsamp_aug_220429/'
data_path = '/lcrc/project/AutoPhase/CDInodefect_resized32_centered/'

# dataname_list = os.path.join(data_path, '3D_upsamp.txt')
dataname_list = os.path.join(data_path, 'CDI_defectFree.txt')

# The context manager closes the file; no explicit close() is needed.
with open(dataname_list, 'r') as f:
    txtfile = f.readlines()

# Keep only the file name: drop any directory prefix and the trailing newline.
filelist = [str(line).split('/')[-1].split('\n')[0] for line in txtfile]
print('number of available file:%d' % len(filelist))

# Draw N_TRAIN indices uniformly at random (with replacement) from the file list.
train_file_indxs = np.random.randint(len(filelist), size=N_TRAIN).astype('int')

train_filelist = [filelist[idx] for idx in train_file_indxs]
print('number of training:%d' % len(train_filelist))


In [None]:
# Training split: reshuffled every epoch.
train_dataset = Dataset(
    train_filelist, data_path, load_all=False, ratio=TRAIN_ratio, dataset='train',scale_I=scale_I)
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NGPUS*8)
# Multi-process alternative (needs a process rank):
# train_sampler = DistributedSampler(
#     train_dataset, num_replicas=NGPUS, rank=rank, shuffle=True)
# train_loader = DataLoader(
#     train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler, shuffle=False, num_workers=NGPUS*8)


# Validation split: evaluated in a fixed order with the same worker budget as
# training (the loss is only averaged, so shuffling adds nothing).
validation_dataset = Dataset(
    train_filelist, data_path, load_all=False, ratio=TRAIN_ratio, dataset='validation',scale_I=scale_I)
validation_loader = DataLoader(
    validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NGPUS*8)
# validation_sampler = DistributedSampler(
#     validation_dataset, num_replicas=NGPUS, rank=rank, shuffle=True)
# validation_loader = DataLoader(
#     validation_dataset, batch_size=BATCH_SIZE, sampler=validation_sampler, shuffle=False, num_workers=NGPUS*8)

In [None]:
nconv = 32  # Base number of convolution filters; deeper stages use multiples of it
h,w,t = 32,32,32  # Reference edge lengths; the notebook feeds inputs of twice this size


class recon_model(nn.Module):
    """3D CNN with one encoder and two decoders (amplitude and phase).

    ``forward`` thresholds the predicted amplitude with ``sw_thresh``, zero-pads
    amplitude and phase by ``H // 2`` voxels on every side, combines them into a
    complex volume and returns the magnitude of its 3D FFT together with the
    intermediate quantities.

    Every sub-network is a flat ``nn.Sequential``, so parameters are addressed
    as ``encoder.<i>``, ``decoder1.<i>`` and ``decoder2.<i>`` in the state dict.
    """

    H,W = h,w

    def __init__(self):
        super(recon_model, self).__init__()

        # Shrink-wrap threshold: amplitudes below it are zeroed in forward().
        # It is a plain attribute so the training loop can change it between epochs.
        self.sw_thresh = INIT_SW

        # Seven conv blocks; the four strided ones each halve the spatial size.
        self.encoder = nn.Sequential(
            *self._conv_block(1, nconv),
            *self._conv_block(nconv, nconv * 2, stride=2),
            *self._conv_block(nconv * 2, nconv * 2),
            *self._conv_block(nconv * 2, nconv * 4, stride=2),
            *self._conv_block(nconv * 4, nconv * 4),
            *self._conv_block(nconv * 4, nconv * 8, stride=2),
            *self._conv_block(nconv * 8, nconv * 8, stride=2),
        )

        self.decoder1 = self._make_decoder(nn.Sigmoid())  # Amplitude model
        self.decoder2 = self._make_decoder(nn.Tanh())  # Phase model

        # Zero padding applied to the decoder outputs before the FFT.
        self.pad = nn.ConstantPad3d(self.H // 2, 0)

    @staticmethod
    def _conv_block(in_ch, out_ch, stride=1):
        """Return [Conv3d(kernel 3, padding 1), LeakyReLU(0.01), BatchNorm3d] as a list."""
        return [
            nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=3,
                      stride=stride, padding=1),
            nn.LeakyReLU(negative_slope=0.01),
            nn.BatchNorm3d(out_ch),
        ]

    @classmethod
    def _make_decoder(cls, final_activation):
        """Build a decoder: three (2 conv blocks + 2x upsample) stages, a 1-channel conv and ``final_activation``."""
        layers = []
        for in_ch, out_ch in ((nconv * 8, nconv * 4),
                              (nconv * 4, nconv * 2),
                              (nconv * 2, nconv)):
            layers += cls._conv_block(in_ch, out_ch)
            layers += cls._conv_block(out_ch, out_ch)
            layers.append(nn.Upsample(scale_factor=2, mode='trilinear'))
        layers.append(nn.Conv3d(nconv, 1, 3, stride=1, padding=1))
        layers.append(final_activation)
        return nn.Sequential(*layers)

    def forward(self,x):
        """Predict amplitude and phase from ``x`` and the resulting FFT magnitude.

        Args:
            x: batch of single-channel volumes; the notebook passes tensors of
                shape ``(N, 1, 2*h, 2*w, 2*t)``.

        Returns:
            Tuple ``(y, complex_x, amp, ph, support)``:
            ``y`` - magnitude of the shifted 3D FFT of ``complex_x``, divided by
            its per-sample maximum and multiplied by ``scale_I`` when
            ``scale_I > 0``;
            ``complex_x`` - padded complex volume ``amp * exp(i * ph)``;
            ``amp`` - padded amplitude after thresholding;
            ``ph`` - padded phase in radians (tanh output times pi);
            ``support`` - 1 where ``amp >= sw_thresh`` and 0 elsewhere.
        """
        latent = self.encoder(x)

        amp = self.decoder1(latent)
        ph = self.decoder2(latent)

        amp = torch.clip(amp, min=0, max=1.0)

        # Scalar constants on the same dtype/device as the prediction.
        zero = torch.zeros((), dtype=amp.dtype, device=amp.device)
        one = torch.ones((), dtype=amp.dtype, device=amp.device)

        # Apply the support: everything below the threshold is treated as empty.
        amp = torch.where(amp < self.sw_thresh, zero, amp)

        # Tanh yields (-1, 1); rescale to (-pi, pi).
        ph = ph * np.pi

        amp = self.pad(amp)
        ph = self.pad(ph)

        complex_x = torch.complex(amp * torch.cos(ph), amp * torch.sin(ph))

        # Transform over the three spatial axes only; the default fftshift would
        # also shift the batch and channel axes.
        y = torch.fft.fftn(complex_x, dim=(-3, -2, -1))
        y = torch.fft.fftshift(y, dim=(-3, -2, -1))
        y = torch.abs(y)

        if scale_I > 0:
            max_I = torch.amax(y, dim=[-1, -2, -3], keepdim=True)
            y = scale_I * torch.div(y, max_I + 1e-6)  # epsilon prevents zero division

        # Binary support mask, returned for visualisation and the phase loss.
        support = torch.where(amp < self.sw_thresh, zero, one)
        return y, complex_x, amp, ph, support

### Check the model works

In [None]:
# Push a single training batch through a fresh model and report shapes/dtypes.
model = recon_model()
_first_batch = next(iter(train_loader), None)
if _first_batch is not None:
    ft_images, amps, phs = _first_batch
    print("batch size:", ft_images.shape, amps.shape, phs.shape)
    y, complex_x, amp, ph, _ = model(ft_images)
    print(y.shape, complex_x.shape, amp.shape, ph.shape)
    print(y.dtype, complex_x.dtype, amp.dtype, ph.dtype)

In [None]:
# Layer-by-layer summary for one single-channel volume of twice the reference size.
summary(model, input_size=(1, 1, 2 * h, 2 * w, 2 * t), device='cpu')

### Move model to appropriate device

In [None]:
# Prefer CUDA when it is available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fresh network; with several GPUs each batch is split along dim 0 across them.
model = recon_model()
if NGPUS > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)  # Uses all visible devices by default

model = model.to(device)

## Setup optimizer and cyclicLR

In [None]:
#Optimizer details
# Take the batch count from the loader itself so that it is exact even when the
# dataset size is a multiple of the batch size.
iterations_per_epoch = max(len(train_loader), 1)
step_size = 6*iterations_per_epoch #Paper recommends 2-10 number of iterations, step_size is half cycle
print("LR step size is:", step_size, "which is every %d epochs" %(step_size/iterations_per_epoch))

criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr = LR)
# Triangular cyclic schedule between LR and 5*LR whose amplitude halves every
# cycle; it is stepped once per batch.
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=LR, 
                                              max_lr=LR*5, step_size_up=step_size,
                                              cycle_momentum=False, mode='triangular2')

## Training and Validation Loop

In [None]:
#Function to update saved model if validation loss is minimum
def update_saved_model(model, path):
    """Write the model weights to ``best_model.pth`` inside ``path``.

    The directory (including missing parents) is created when needed and an
    existing ``best_model.pth`` is overwritten. Other files in the directory
    are left alone.

    Args:
        model: network to save; a ``DataParallel`` / ``DistributedDataParallel``
            wrapper is unwrapped first so the checkpoint loads into the bare
            model regardless of the number of GPUs.
        path: target directory, with or without a trailing separator.
    """
    os.makedirs(path, exist_ok=True)
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        model = model.module
    torch.save(model.state_dict(), os.path.join(path, 'best_model.pth'))

In [None]:
def train(trainloader,metrics):
    """Run one optimisation pass over ``trainloader`` and log epoch-mean losses.

    Uses the module-level ``model``, ``device``, ``criterion``, ``optimizer``
    and ``scheduler``. Only the diffraction (FT) loss drives the gradients;
    amplitude and phase losses are computed for monitoring.

    Args:
        trainloader: iterable yielding ``(ft_images, amps, phs)`` batches.
        metrics: dict with list entries ``'lrs'`` (one learning-rate entry is
            appended per batch) and ``'losses'`` (one ``[ft, amp, ph]`` triple
            is appended per call).

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    loss_ft = 0.0
    loss_amp = 0.0
    loss_ph = 0.0
    n_batches = 0

    for ft_images, amps, phs in tqdm(trainloader):
        ft_images = ft_images.to(device) #Move everything to device
        amps = amps.to(device)
        phs = phs.to(device)

        y, _, pred_amps, pred_phs, support = model(ft_images) #Forward pass

        #Compute losses
        loss_f = criterion(y, ft_images)
        loss_a = criterion(pred_amps,amps) #Monitor amplitude loss
        loss_p = criterion(pred_phs*support,phs) #Monitor phase loss but only within support (which may not be same as true amp)
        loss = loss_f #Use only FT loss for gradients

        #Zero current grads and do backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_ft += loss_f.item()
        loss_amp += loss_a.item()
        loss_ph += loss_p.item()
        n_batches += 1

        #Update the LR according to the schedule -- CyclicLR updates each batch
        scheduler.step()
        metrics['lrs'].append(scheduler.get_last_lr())

    if n_batches == 0:
        raise ValueError("trainloader yielded no batches")

    #Mean over batches -- slightly inaccurate because the last batch may be smaller
    metrics['losses'].append([loss_ft/n_batches,loss_amp/n_batches,loss_ph/n_batches])
    

def validate(validloader,metrics):
    """Evaluate the model on ``validloader`` and checkpoint it when it improves.

    Uses the module-level ``model``, ``device`` and ``criterion``. The caller
    is expected to have switched the model to eval mode.

    Args:
        validloader: iterable yielding ``(ft_images, amps, phs)`` batches.
        metrics: dict; one ``[ft, amp, ph]`` triple is appended to
            ``'val_losses'``, and ``'best_val_loss'`` is lowered (and the model
            saved to ``MODEL_SAVE_PATH``) when the mean FT loss beats it.

    Raises:
        ValueError: if ``validloader`` yields no batches.
    """
    val_loss_ft = 0.0
    val_loss_amp = 0.0
    val_loss_ph = 0.0
    n_batches = 0

    # No gradients are needed for evaluation; skipping them saves memory and time.
    with torch.no_grad():
        for ft_images, amps, phs in validloader:
            ft_images = ft_images.to(device)
            amps = amps.to(device)
            phs = phs.to(device)
            y, _, pred_amps, pred_phs, support = model(ft_images) #Forward pass

            val_loss_ft += criterion(y, ft_images).item()
            val_loss_amp += criterion(pred_amps,amps).item()
            val_loss_ph += criterion(pred_phs*support,phs).item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("validloader yielded no batches")

    mean_ft = val_loss_ft/n_batches
    metrics['val_losses'].append([mean_ft,val_loss_amp/n_batches,val_loss_ph/n_batches])

    #Update saved model if val loss is lower
    if mean_ft < metrics['best_val_loss']:
        print("Saving improved model after Val Loss improved from %.5f to %.5f" %(metrics['best_val_loss'],mean_ft))
        metrics['best_val_loss'] = mean_ft
        update_saved_model(model, MODEL_SAVE_PATH)

In [None]:
# Per-batch learning rates, per-epoch [ft, amp, ph] losses and the best FT validation loss.
metrics = {'losses':[],'val_losses':[], 'lrs':[], 'best_val_loss' : np.inf}

# sw_thresh lives on the underlying network, which DataParallel exposes as .module.
net = model.module if isinstance(model, nn.DataParallel) else model
# Clamp towards FINAL_SW in whichever direction the ramp moves.
_clamp_sw = min if SW_INCREMENT >= 0 else max

for epoch in tqdm(range(EPOCHS)):

    #Set model to train mode
    model.train()

    #Training loop
    train(train_loader,metrics)

    #Switch model to eval mode
    model.eval()

    #Validation loop
    validate(validation_loader,metrics)
    train_l = metrics['losses'][-1]
    val_l = metrics['val_losses'][-1]

    print(f'Epoch: {epoch} | FT  | Train Loss: {train_l[0]:.5f} | Val Loss: {val_l[0]:.5f}')
    print(f'Epoch: {epoch} | Amp | Train Loss: {train_l[1]:.4f} | Val Loss: {val_l[1]:.5f}')
    print(f'Epoch: {epoch} | Ph  | Train Loss: {train_l[2]:.3f} | Val Loss: {val_l[2]:.5f}')
    print(f'Epoch: {epoch} | SW Thresh: {net.sw_thresh:.4f}')
    print(f'Epoch: {epoch} | Ending LR: {metrics["lrs"][-1][0]:.6f}')

    #Keep the SW threshold fixed for the first CONST_EPOCHS epochs, then ramp it
    #every epoch without ever moving past FINAL_SW
    if epoch >= (CONST_EPOCHS-1):
        net.sw_thresh = _clamp_sw(net.sw_thresh + SW_INCREMENT, FINAL_SW)

In [None]:
# One learning-rate entry was logged per batch; express the batch counter in epochs.
n_logged = len(metrics['lrs'])
batches = np.linspace(0, n_logged, num=n_logged + 1)
epoch_list = batches / iterations_per_epoch

# Skip the leading 0 so that entry k is plotted at the end of batch k.
plt.plot(epoch_list[1:], metrics['lrs'], 'C3-')
plt.grid()
plt.ylabel("Learning rate")
plt.xlabel("Epoch")

In [None]:
# Columns of both arrays are [FT, Amp, Ph]; one row per epoch.
losses_arr = np.array(metrics['losses'])
val_losses_arr = np.array(metrics['val_losses'])

fig, ax = plt.subplots(3, sharex=True, figsize=(15, 8))
for row, tag in enumerate(("FT", "Amp", "Ph")):
    panel = ax[row]
    panel.plot(losses_arr[:, row], 'C3o-', label=f"Train {tag} loss")
    panel.plot(val_losses_arr[:, row], 'C0o-', label=f"Val {tag} loss")
    panel.set(ylabel='Loss')
    panel.grid()
    # Legend sits outside the axes, to the right.
    panel.legend(loc='center right', bbox_to_anchor=(1.5, 0.5))

plt.tight_layout()
plt.xlabel("Epochs")


In [None]:
# load test data
# Training files were drawn at random from the whole list, so a plain slice
# could contain some of them; keep only files that were never used for training.
_train_names = set(train_filelist)
test_filelist = [name for name in filelist[N_TRAIN:] if name not in _train_names][:100]
print('number of test data:%d' % len(test_filelist))

# load test data 
test_dataset = Dataset(
    test_filelist, data_path, load_all=False, ratio=TRAIN_ratio, dataset='test',scale_I=scale_I,shuffle=False)
test_loader = DataLoader(
    test_dataset, batch_size=16, shuffle=False, num_workers=NGPUS)
# test_sampler = DistributedSampler(
#     test_dataset, num_replicas=1, rank=rank, shuffle=False)
# test_loader = DataLoader(
#     test_dataset, batch_size=16, sampler=test_sampler, shuffle=False, **kwargs)

# load saved model

In [None]:
# model_path = MODEL_SAVE_PATH + 'best_model.pth'

In [None]:
# model = recon_model()
# model.load_state_dict(torch.load(model_path))

In [None]:
model.eval() #imp when have dropout etc
# Predictions, one entry per test sample.
ft_results=[]
complex_results=[]
amp_preds = []
ph_preds  = []
supports = []

# Matching ground truth, one entry per test sample.
ft_test_array = []
amp_test_array = []
ph_test_array = []


def _to_numpy(tensor):
    """Return ``tensor`` as a NumPy array on the CPU, detached from autograd."""
    return tensor.detach().to("cpu").numpy()


# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    for i, (ft_images,amps,phs) in enumerate(test_loader):
        ft_images = ft_images.to(device)
        y, complex_x, amp, ph, support = model(ft_images)
        for j in range(amp.shape[0]):
            #prediction
            ft_results.append([_to_numpy(y[j])])
            complex_results.append([_to_numpy(complex_x[j])])
            amp_preds.append(_to_numpy(amp[j]))
            ph_preds.append(_to_numpy(ph[j]))
            supports.append(_to_numpy(support[j]))

            #ground truth
            ft_test_array.append([_to_numpy(ft_images[j])])
            amp_test_array.append([_to_numpy(amps[j])])
            ph_test_array.append([_to_numpy(phs[j])])
        

In [None]:
# Stack the per-sample predictions and drop all length-1 axes.
ft_results, complex_results, amp_preds, ph_preds, supports = (
    np.array(collected).squeeze()
    for collected in (ft_results, complex_results, amp_preds, ph_preds, supports)
)

print(ft_results.shape, ft_results.dtype)

# Same treatment for the ground truth.
ft_test_array, amp_test_array, ph_test_array = (
    np.array(collected).squeeze()
    for collected in (ft_test_array, amp_test_array, ph_test_array)
)

print(ft_test_array.shape, ft_test_array.dtype)

In [None]:
# Test images
# Pick two random samples from however many test results are available.
to_plot = np.random.randint(len(ft_test_array), size=2)
# to_plot = [20,50]

# Row 0: diffraction patterns, row 1: support/amplitude, row 2: phase.
plots = {0:[], 1:[], 2:[]}
titles = {0:[], 1:[], 2:[]}
for i in to_plot:
    plots[0] += [ft_test_array[i],ft_results[i],ft_test_array[i]-ft_results[i]]
    titles[0] += [f"FT input, {i}", f"Pred FT, {i}", f"Diff FT, {i}"]
    
    plots[1] += [supports[i],amp_test_array[i],amp_preds[i]]
    titles[1] += [f"Support, {i}", f"True Amp {i}", f"Predicted Amp {i}"]
    
    # Compare the phase only inside the predicted support.
    tmp = ph_preds[i]*supports[i]
    plots[2] += [ph_test_array[i],tmp,ph_test_array[i]-tmp]
    titles[2] += [f"True ph, {i}", f"Pred Ph {i}", f"Diff Ph {i}"]

for j in range(3):
    plot6(plots[j], titles[j])