# AutoPhaseNN
#### PyTorch version (under development)

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="7"

import torch, torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchinfo import summary
from tqdm.notebook import tqdm 
import numpy as np
from numpy.fft import fftn, fftshift
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
import matplotlib.pyplot as plt
import matplotlib

In [2]:
from data_loader import *

In [3]:
plt.style.use('seaborn-white')
#matplotlib.rc('xtick', labelsize=20) 
#matplotlib.rc('ytick', labelsize=20)
matplotlib.rc('font',family='Times New Roman')
matplotlib.rcParams['font.size'] = 20
plt.viridis()
%matplotlib inline

## Define Constants
GPUs, batch size, training and validation sizes, shrink-wrap parameters, and epochs.

In [4]:
# Number of training epochs. The cyclic LR schedule set up later has a 12-epoch period
# (6-epoch half cycle), so 60 epochs ends the run at the bottom of a cycle.
EPOCHS: int = 60

# Peak value that recon_model.forward normalizes its predicted diffraction magnitude to
# (a value <= 0 disables that normalization); it is also handed to the Dataset as `scale_I`.
scale_I = 1

# Directory that receives the best checkpoint; its existing files are replaced on each save.
MODEL_SAVE_PATH: str = '/lcrc/project/AutoPhase/SN/'
print(MODEL_SAVE_PATH)

/lcrc/project/AutoPhase/SN/


In [5]:
# Number of visible CUDA devices (0 on a CPU-only machine).
NGPUS = torch.cuda.device_count()
# Batch size and learning rate scale linearly with the GPU count (DataParallel splits each
# batch across devices). Count at least one device so a CPU-only run does not end up with
# batch_size=0 (DataLoader rejects it) and lr=0.
_n_devices = max(NGPUS, 1)
BATCH_SIZE = _n_devices * 32
LR = _n_devices * 1e-4
print("GPUs:", NGPUS, "Batch size:", BATCH_SIZE, "Learning rate:", LR)

# Number of files taken from the file list; the Dataset splits them train/validation by TRAIN_ratio.
N_TRAIN = 50000
N_VALID = int(0.1*N_TRAIN)  # not referenced by the loaders in this notebook
TRAIN_ratio = 0.9

# Shrink-wrap (support) threshold applied to the predicted amplitude inside the model.
INIT_SW = 0.07
# Threshold reached on the last epoch. (An older note said values > 0.07 seemed to break
# training; revisit if training diverges.)
FINAL_SW = 0.1
CONST_EPOCHS = 0  # epochs to hold INIT_SW before the threshold starts rising

# Per-epoch threshold increment. Guard the denominator so EPOCHS - 1 <= CONST_EPOCHS
# does not divide by zero.
_ramp_epochs = max(EPOCHS - 1 - CONST_EPOCHS, 1)
SW_INCREMENT = (FINAL_SW - INIT_SW) / _ramp_epochs
print("SW Thresh increment", SW_INCREMENT)


GPUs: 1 Batch size: 32 Learning rate: 0.0001
SW Thresh increment 0.0005084745762711864


In [6]:
print ("PyTorch version", torch.__version__)
!nvidia-smi
#!cat /proc/cpuinfo | grep processor

PyTorch version 1.10.0
Fri Aug 12 15:09:10 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.129.06   Driver Version: 470.129.06   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:07:00.0 Off |                    0 |
| N/A   29C    P0    70W / 400W |  30096MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  Off  | 00000000:0F:00.0 Off |                    0 |
| N/A   28C    P0    67W / 400W |  30232MiB / 81251MiB |      0% 

### Plotting Helpers

In [7]:
def plot3(data,titles):
    """Show the middle slice (along axis 0) of three 3D volumes side by side.

    Args:
        data: sequence of at least three volumes; the slice index is half the axis-0
            length of ``data[0]`` and is used for all three panels.
        titles: panel titles; if fewer than three are given, "Plot1".."Plot3" are used.
    """
    panel_titles = titles if len(titles) >= 3 else ["Plot1", "Plot2", "Plot3"]
    mid = data[0].shape[0] // 2
    fig, axes = plt.subplots(1, 3, figsize=(19, 5))
    for k in range(3):
        axis = axes[k]
        image = axis.imshow(data[k][mid])
        plt.colorbar(image, ax=axis)
        axis.set_title(panel_titles[k])

In [8]:
def plot6(data,titles):
    """Show the middle slice (along axis 0) of six 3D volumes side by side.

    Args:
        data: sequence of at least six volumes; the slice index is half the axis-0
            length of ``data[0]`` and is used for every panel.
        titles: panel titles; if fewer than six are given, "Plot1".."Plot6" are used
            (previously only fewer than three triggered the default, so 3-5 titles
            raised IndexError).
    """
    if len(titles) < 6:
        titles = ["Plot1", "Plot2", "Plot3", "Plot4", "Plot5", "Plot6"]
    ind = data[0].shape[0] // 2
    fig, axes = plt.subplots(1, 6, figsize=(19, 3), constrained_layout=True)
    for ix, ax in enumerate(axes):
        im = ax.imshow(data[ix][ind])
        plt.colorbar(im, ax=ax)
        ax.set_title(titles[ix])

## Load and check data, then prepare as tensors

In [9]:
# Read the list of simulated samples. Each line holds a path; only its last component
# (the file name) is kept, and data_path is passed to the Dataset alongside it.
data_path = '/lcrc/project/AutoPhase/CDI_simulation_upsamp_noise/'
dataname_list = os.path.join(data_path, '3D_upsamp.txt')

with open(dataname_list, 'r') as f:
    # Skip blank lines (e.g. a trailing newline) so no empty file names end up in the list.
    filelist = [line.rstrip('\n').split('/')[-1] for line in f if line.strip()]
print('number of available file:%d' % len(filelist))

# The first N_TRAIN files form the train+validation pool (split later by TRAIN_ratio).
train_filelist = filelist[:N_TRAIN]
print('number of training:%d' % len(train_filelist))


number of available file:54028
number of training:50000


In [10]:
# Training and validation splits of the same file list. The project Dataset selects the
# split via `ratio` and `dataset` (presumably the first TRAIN_ratio fraction is training —
# confirm in data_loader).
train_dataset = Dataset(
    train_filelist, data_path, load_all=False, ratio=TRAIN_ratio, dataset='train', scale_I=scale_I)
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NGPUS*8)

validation_dataset = Dataset(
    train_filelist, data_path, load_all=False, ratio=TRAIN_ratio, dataset='validation', scale_I=scale_I)
# Validation losses are averaged over the whole split, so order is irrelevant: no shuffling.
# Use the same worker count as training (it was NGPUS*64, 8x the training loader).
validation_loader = DataLoader(
    validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NGPUS*8)

In [11]:
nconv = 32            # base channel width; deeper layers use 2x, 4x and 8x this
h, w, t = 32, 32, 32  # object grid produced by the decoders; the input is twice this per axis


class recon_model(nn.Module):
  """AutoPhaseNN: 3D CNN mapping a diffraction-magnitude volume to object amplitude and phase.

  A shared convolutional encoder feeds two decoders with identical architecture:
  ``decoder1`` predicts the amplitude (sigmoid output) and ``decoder2`` the phase
  (tanh output scaled by pi). The predicted object is zero-padded to twice the
  decoder grid, Fourier transformed and its magnitude returned, so the output can be
  compared directly with the input diffraction data.

  Expects input of shape (batch, 1, 2h, 2w, 2t), i.e. (B, 1, 64, 64, 64) with the
  constants above. Reads the notebook globals ``INIT_SW`` (initial shrink-wrap
  threshold) and ``scale_I`` (output normalization peak).

  Layer order inside each nn.Sequential defines the state_dict keys
  (``encoder.0.weight``, ``decoder1.21.bias``, ...) used by saved checkpoints;
  do not reorder the builders below.
  """

  H, W = h, w

  @staticmethod
  def _conv_block(in_ch, out_ch, stride=1):
    """Return [Conv3d(k=3, pad=1), LeakyReLU(0.01), BatchNorm3d] as a list of layers."""
    return [
      nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
      nn.LeakyReLU(negative_slope=0.01),
      nn.BatchNorm3d(out_ch),
    ]

  @classmethod
  def _make_encoder(cls):
    """Seven conv blocks; four have stride 2, so each spatial axis shrinks 16x (64 -> 4)."""
    stages = (
      (1, nconv, 1),
      (nconv, nconv * 2, 2),
      (nconv * 2, nconv * 2, 1),
      (nconv * 2, nconv * 4, 2),
      (nconv * 4, nconv * 4, 1),
      (nconv * 4, nconv * 8, 2),
      (nconv * 8, nconv * 8, 2),
    )
    layers = []
    for in_ch, out_ch, stride in stages:
      layers += cls._conv_block(in_ch, out_ch, stride)
    return nn.Sequential(*layers)

  @classmethod
  def _make_decoder(cls, out_activation):
    """Three (conv block, conv block, 2x trilinear upsample) stages, a 1-channel conv, then `out_activation`."""
    layers = []
    for in_ch, out_ch in ((nconv * 8, nconv * 4), (nconv * 4, nconv * 2), (nconv * 2, nconv)):
      layers += cls._conv_block(in_ch, out_ch)
      layers += cls._conv_block(out_ch, out_ch)
      # align_corners=False is PyTorch's default for trilinear; stating it silences the warning.
      layers.append(nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False))
    layers.append(nn.Conv3d(nconv, 1, kernel_size=3, stride=1, padding=1))
    layers.append(out_activation)
    return nn.Sequential(*layers)

  def __init__(self):
    super(recon_model, self).__init__()
    # Shrink-wrap threshold: predicted amplitudes below it are zeroed in forward().
    self.sw_thresh = INIT_SW
    self.encoder = self._make_encoder()
    self.decoder1 = self._make_decoder(nn.Sigmoid())  # amplitude head, values in (0, 1)
    self.decoder2 = self._make_decoder(nn.Tanh())     # phase head, values in (-1, 1)

  def forward(self, x):
    """Predict the object and its Fourier magnitude.

    Returns:
        y: |FFT| of the padded complex object, fftshifted; scaled so its per-sample
           maximum is about ``scale_I`` when ``scale_I > 0``.
        complex_x: padded complex object amp * exp(i * ph).
        amp: thresholded amplitude, zero-padded to 2x the decoder grid.
        ph: phase in [-pi, pi], zero-padded the same way.
        support: 1 where ``amp >= sw_thresh``, else 0 (same dtype as amp).
    """
    latent = self.encoder(x)
    amp = self.decoder1(latent)
    ph = self.decoder2(latent)

    amp = torch.clip(amp, min=0, max=1.0)

    # Shrink-wrap support: zero every amplitude below the current threshold.
    mask = torch.tensor([0, 1], dtype=amp.dtype, device=amp.device)
    amp = torch.where(amp < self.sw_thresh, mask[0], amp)

    # tanh gives (-1, 1); scale to the (-pi, pi) phase range.
    ph = ph * np.pi

    # Zero-pad each spatial axis by H/2 on both sides, doubling the grid back to the input size.
    pad = nn.ConstantPad3d(self.H // 2, 0)
    amp = pad(amp)
    ph = pad(ph)

    complex_x = torch.complex(amp * torch.cos(ph), amp * torch.sin(ph))

    # FFT over the three spatial dims only; fftshift must get the same dims or it would
    # also shift the batch/channel axes.
    y = torch.fft.fftn(complex_x, dim=(-3, -2, -1))
    y = torch.fft.fftshift(y, dim=(-3, -2, -1))
    y = torch.abs(y)

    if scale_I > 0:
        max_I = torch.amax(y, dim=[-1, -2, -3], keepdim=True)
        y = scale_I * torch.div(y, max_I + 1e-6)  # epsilon avoids division by zero

    support = torch.where(amp < self.sw_thresh, mask[0], mask[1])
    return y, complex_x, amp, ph, support

### Check the model works

In [12]:
# Smoke test: push one training batch through a fresh, untrained model and check the
# output shapes and dtypes. No backward pass is needed, so skip building the autograd graph.
model = recon_model()
ft_images, amps, phs = next(iter(train_loader))
print("batch size:", ft_images.shape, amps.shape, phs.shape)
with torch.no_grad():
    y, complex_x, amp, ph, _ = model(ft_images)
print(y.shape, complex_x.shape, amp.shape, ph.shape)
print(y.dtype, complex_x.dtype, amp.dtype, ph.dtype)

batch size: torch.Size([32, 1, 64, 64, 64]) torch.Size([32, 1, 64, 64, 64]) torch.Size([32, 1, 64, 64, 64])




torch.Size([32, 1, 64, 64, 64]) torch.Size([32, 1, 64, 64, 64]) torch.Size([32, 1, 64, 64, 64]) torch.Size([32, 1, 64, 64, 64])
torch.float32 torch.complex64 torch.float32 torch.float32


In [13]:
summary(model, input_size=(1, 1, 2 * h, 2 * w, 2 * t), device='cpu')  # per-layer shapes and parameter counts for one input volume

Layer (type:depth-idx)                   Output Shape              Param #
recon_model                              --                        --
├─Sequential: 1-1                        [1, 256, 4, 4, 4]         --
│    └─Conv3d: 2-1                       [1, 32, 64, 64, 64]       896
│    └─LeakyReLU: 2-2                    [1, 32, 64, 64, 64]       --
│    └─BatchNorm3d: 2-3                  [1, 32, 64, 64, 64]       64
│    └─Conv3d: 2-4                       [1, 64, 32, 32, 32]       55,360
│    └─LeakyReLU: 2-5                    [1, 64, 32, 32, 32]       --
│    └─BatchNorm3d: 2-6                  [1, 64, 32, 32, 32]       128
│    └─Conv3d: 2-7                       [1, 64, 32, 32, 32]       110,656
│    └─LeakyReLU: 2-8                    [1, 64, 32, 32, 32]       --
│    └─BatchNorm3d: 2-9                  [1, 64, 32, 32, 32]       128
│    └─Conv3d: 2-10                      [1, 128, 16, 16, 16]      221,312
│    └─LeakyReLU: 2-11                   [1, 128, 16, 16, 16]      -

### Move model to appropriate device

In [14]:
# Fresh, untrained model for the training run (replaces the smoke-test instance).
model = recon_model()

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

if NGPUS > 1:
    # nn.DataParallel uses all visible devices by default and splits each batch along
    # dim 0, e.g. [30, ...] -> three [10, ...] chunks on 3 GPUs.
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

model = model.to(device)

## Set up the optimizer and CyclicLR scheduler

In [15]:
# Optimizer and cyclical learning-rate schedule (Smith, "Cyclical Learning Rates for
# Training Neural Networks"). train() steps the scheduler once per batch, so the step
# size is measured in iterations.
# len(train_loader) is the exact number of batches per epoch, including a final partial
# batch. The old floor(n / batch) + 1 overcounted when the split divided evenly.
iterations_per_epoch = len(train_loader)
# Half-cycle length; the paper recommends 2-10x the number of iterations in an epoch.
step_size = 6 * iterations_per_epoch
print("LR step size is:", step_size, "which is every %d epochs" % (step_size / iterations_per_epoch))

criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# 'triangular2': LR oscillates between LR and 5*LR, halving the amplitude every cycle.
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=LR,
                                              max_lr=LR*5, step_size_up=step_size,
                                              cycle_momentum=False, mode='triangular2')

LR step size is: 8442.0 which is every 6 epochs


## Training and Validation Loop

In [16]:
#Function to update saved model if validation loss is minimum
def update_saved_model(model, path):
    """Replace the checkpoint in directory `path` with the current weights as best_model.pth.

    Creates `path` (including missing parents) if needed and deletes every regular file
    already in it, i.e. the previous best checkpoint. Subdirectories are left alone.
    Saves the state_dict of the underlying network, unwrapped from (Distributed)DataParallel,
    so the checkpoint loads without the same number of GPUs.

    Args:
        model: the network, optionally wrapped in nn.DataParallel / DistributedDataParallel.
        path: destination directory; a trailing separator is optional.
    """
    os.makedirs(path, exist_ok=True)
    for entry in os.listdir(path):
        entry_path = os.path.join(path, entry)
        # os.remove cannot delete directories, so only stale files/links are removed.
        if os.path.isfile(entry_path) or os.path.islink(entry_path):
            os.remove(entry_path)
    # Decide from the model itself rather than a global GPU count; the wrapper keeps the
    # real network in `.module`, and saving that avoids "module."-prefixed keys.
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    network = model.module if isinstance(model, wrappers) else model
    torch.save(network.state_dict(), os.path.join(path, 'best_model.pth'))

In [17]:
def train(trainloader,metrics):
    """Run one training epoch, stepping the optimizer and the CyclicLR scheduler every batch.

    Uses the notebook-level `model`, `device`, `criterion`, `optimizer` and `scheduler`.
    Only the Fourier-magnitude loss is back-propagated. The amplitude and phase losses
    are computed for monitoring only.

    Appends to ``metrics['lrs']`` the scheduler's last LR list after every batch, and to
    ``metrics['losses']`` one ``[ft, amp, phase]`` list of per-batch mean losses.

    Raises:
        ValueError: if `trainloader` yields no batches.
    """
    loss_ft = 0.0
    loss_amp = 0.0
    loss_ph = 0.0
    n_batches = 0

    # Iterate the loader directly so tqdm can read its length and show a real progress bar.
    for ft_images, amps, phs in tqdm(trainloader):
        ft_images = ft_images.to(device)
        amps = amps.to(device)
        phs = phs.to(device)

        y, _, pred_amps, pred_phs, support = model(ft_images)

        loss_f = criterion(y, ft_images)  # the only loss used for gradients
        with torch.no_grad():
            # Monitor-only losses: no autograd graph needed.
            loss_a = criterion(pred_amps, amps)
            # Phase error only inside the predicted support, which may differ from the true amplitude.
            loss_p = criterion(pred_phs * support, phs)

        optimizer.zero_grad()
        loss_f.backward()
        optimizer.step()

        loss_ft += loss_f.item()
        loss_amp += loss_a.item()
        loss_ph += loss_p.item()
        n_batches += 1

        # CyclicLR is stepped per batch, not per epoch.
        scheduler.step()
        metrics['lrs'].append(scheduler.get_last_lr())

    if n_batches == 0:
        raise ValueError("train(): the data loader yielded no batches")
    # Mean of per-batch losses; slightly biased when the last batch is smaller.
    metrics['losses'].append([loss_ft / n_batches, loss_amp / n_batches, loss_ph / n_batches])
    

def validate(validloader,metrics):
    """Evaluate the notebook-level `model` on `validloader` and checkpoint it if it improved.

    Runs without gradient tracking. Appends one ``[ft, amp, phase]`` list of per-batch mean
    losses to ``metrics['val_losses']``. If the mean Fourier loss beats
    ``metrics['best_val_loss']``, updates that entry and saves the model to MODEL_SAVE_PATH.

    Raises:
        ValueError: if `validloader` yields no batches.
    """
    val_loss_ft = 0.0
    val_loss_amp = 0.0
    val_loss_ph = 0.0
    n_batches = 0

    # No backward pass here: skip building autograd graphs (saves memory and time).
    with torch.no_grad():
        for ft_images, amps, phs in validloader:
            ft_images = ft_images.to(device)
            amps = amps.to(device)
            phs = phs.to(device)
            y, _, pred_amps, pred_phs, support = model(ft_images)

            val_loss_ft += criterion(y, ft_images).item()
            val_loss_amp += criterion(pred_amps, amps).item()
            val_loss_ph += criterion(pred_phs * support, phs).item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("validate(): the data loader yielded no batches")
    mean_ft = val_loss_ft / n_batches
    metrics['val_losses'].append([mean_ft, val_loss_amp / n_batches, val_loss_ph / n_batches])

    # Keep only the checkpoint with the lowest validation Fourier loss.
    if mean_ft < metrics['best_val_loss']:
        print("Saving improved model after Val Loss improved from %.5f to %.5f" % (metrics['best_val_loss'], mean_ft))
        metrics['best_val_loss'] = mean_ft
        update_saved_model(model, MODEL_SAVE_PATH)

In [None]:
metrics = {'losses':[],'val_losses':[], 'lrs':[], 'best_val_loss' : np.inf}

# With nn.DataParallel the shrink-wrap threshold lives on the wrapped network.
base_model = model.module if isinstance(model, nn.DataParallel) else model

for epoch in tqdm(range(EPOCHS)):
    model.train()
    train(train_loader, metrics)

    model.eval()
    validate(validation_loader, metrics)

    l = metrics['losses'][-1]
    lv = metrics['val_losses'][-1]
    print(f'Epoch: {epoch} | FT  | Train Loss: {l[0]:.5f} | Val Loss: {lv[0]:.5f}')
    print(f'Epoch: {epoch} | Amp | Train Loss: {l[1]:.4f} | Val Loss: {lv[1]:.5f}')
    print(f'Epoch: {epoch} | Ph  | Train Loss: {l[2]:.3f} | Val Loss: {lv[2]:.5f}')
    print(f'Epoch: {epoch} | SW Thresh: {base_model.sw_thresh:.4f}')
    print(f'Epoch: {epoch} | Ending LR: {metrics["lrs"][-1][0]:.6f}')

    # Shrink-wrap schedule: the threshold stays at INIT_SW through epoch CONST_EPOCHS and then
    # rises by SW_INCREMENT after each epoch, so the final epoch runs at FINAL_SW.
    # There is no increment after the last epoch, so the in-memory model ends at FINAL_SW
    # instead of overshooting it.
    if CONST_EPOCHS <= epoch < EPOCHS - 1:
        base_model.sw_thresh += SW_INCREMENT

  0%|          | 0/60 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Saving improved model after Val Loss improved from inf to 0.00152
Epoch: 0 | FT  | Train Loss: 0.00168 | Val Loss: 0.00152
Epoch: 0 | Amp | Train Loss: 0.0302 | Val Loss: 0.02510
Epoch: 0 | Ph  | Train Loss: 0.056 | Val Loss: 0.03133
Epoch: 0 | SW Thresh: 0.0700
Epoch: 0 | Ending LR: 0.000167


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00152 to 0.00146
Epoch: 1 | FT  | Train Loss: 0.00144 | Val Loss: 0.00146
Epoch: 1 | Amp | Train Loss: 0.0245 | Val Loss: 0.02397
Epoch: 1 | Ph  | Train Loss: 0.031 | Val Loss: 0.03356
Epoch: 1 | SW Thresh: 0.0705
Epoch: 1 | Ending LR: 0.000233


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00146 to 0.00141
Epoch: 2 | FT  | Train Loss: 0.00140 | Val Loss: 0.00141
Epoch: 2 | Amp | Train Loss: 0.0240 | Val Loss: 0.02350
Epoch: 2 | Ph  | Train Loss: 0.030 | Val Loss: 0.02988
Epoch: 2 | SW Thresh: 0.0710
Epoch: 2 | Ending LR: 0.000300


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00141 to 0.00138
Epoch: 3 | FT  | Train Loss: 0.00136 | Val Loss: 0.00138
Epoch: 3 | Amp | Train Loss: 0.0234 | Val Loss: 0.02284
Epoch: 3 | Ph  | Train Loss: 0.030 | Val Loss: 0.02961
Epoch: 3 | SW Thresh: 0.0715
Epoch: 3 | Ending LR: 0.000367


0it [00:00, ?it/s]

Epoch: 4 | FT  | Train Loss: 0.00134 | Val Loss: 0.00140
Epoch: 4 | Amp | Train Loss: 0.0218 | Val Loss: 0.01787
Epoch: 4 | Ph  | Train Loss: 0.029 | Val Loss: 0.02524
Epoch: 4 | SW Thresh: 0.0720
Epoch: 4 | Ending LR: 0.000433


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00138 to 0.00138
Epoch: 5 | FT  | Train Loss: 0.00132 | Val Loss: 0.00138
Epoch: 5 | Amp | Train Loss: 0.0172 | Val Loss: 0.01458
Epoch: 5 | Ph  | Train Loss: 0.025 | Val Loss: 0.02163
Epoch: 5 | SW Thresh: 0.0725
Epoch: 5 | Ending LR: 0.000500


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00138 to 0.00133
Epoch: 6 | FT  | Train Loss: 0.00129 | Val Loss: 0.00133
Epoch: 6 | Amp | Train Loss: 0.0128 | Val Loss: 0.01192
Epoch: 6 | Ph  | Train Loss: 0.021 | Val Loss: 0.02039
Epoch: 6 | SW Thresh: 0.0731
Epoch: 6 | Ending LR: 0.000433


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00133 to 0.00126
Epoch: 7 | FT  | Train Loss: 0.00124 | Val Loss: 0.00126
Epoch: 7 | Amp | Train Loss: 0.0127 | Val Loss: 0.01263
Epoch: 7 | Ph  | Train Loss: 0.021 | Val Loss: 0.02128
Epoch: 7 | SW Thresh: 0.0736
Epoch: 7 | Ending LR: 0.000367


0it [00:00, ?it/s]

Epoch: 8 | FT  | Train Loss: 0.00120 | Val Loss: 0.00142
Epoch: 8 | Amp | Train Loss: 0.0129 | Val Loss: 0.01390
Epoch: 8 | Ph  | Train Loss: 0.021 | Val Loss: 0.02383
Epoch: 8 | SW Thresh: 0.0741
Epoch: 8 | Ending LR: 0.000300


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00126 to 0.00121
Epoch: 9 | FT  | Train Loss: 0.00120 | Val Loss: 0.00121
Epoch: 9 | Amp | Train Loss: 0.0131 | Val Loss: 0.01320
Epoch: 9 | Ph  | Train Loss: 0.021 | Val Loss: 0.02088
Epoch: 9 | SW Thresh: 0.0746
Epoch: 9 | Ending LR: 0.000233


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00121 to 0.00118
Epoch: 10 | FT  | Train Loss: 0.00114 | Val Loss: 0.00118
Epoch: 10 | Amp | Train Loss: 0.0129 | Val Loss: 0.01272
Epoch: 10 | Ph  | Train Loss: 0.021 | Val Loss: 0.02053
Epoch: 10 | SW Thresh: 0.0751
Epoch: 10 | Ending LR: 0.000167


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00118 to 0.00116
Epoch: 11 | FT  | Train Loss: 0.00111 | Val Loss: 0.00116
Epoch: 11 | Amp | Train Loss: 0.0129 | Val Loss: 0.01292
Epoch: 11 | Ph  | Train Loss: 0.021 | Val Loss: 0.02063
Epoch: 11 | SW Thresh: 0.0756
Epoch: 11 | Ending LR: 0.000100


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00116 to 0.00116
Epoch: 12 | FT  | Train Loss: 0.00109 | Val Loss: 0.00116
Epoch: 12 | Amp | Train Loss: 0.0129 | Val Loss: 0.01320
Epoch: 12 | Ph  | Train Loss: 0.021 | Val Loss: 0.02120
Epoch: 12 | SW Thresh: 0.0761
Epoch: 12 | Ending LR: 0.000133


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00116 to 0.00115
Epoch: 13 | FT  | Train Loss: 0.00108 | Val Loss: 0.00115
Epoch: 13 | Amp | Train Loss: 0.0129 | Val Loss: 0.01290
Epoch: 13 | Ph  | Train Loss: 0.021 | Val Loss: 0.02081
Epoch: 13 | SW Thresh: 0.0766
Epoch: 13 | Ending LR: 0.000167


0it [00:00, ?it/s]

Epoch: 14 | FT  | Train Loss: 0.00108 | Val Loss: 0.00124
Epoch: 14 | Amp | Train Loss: 0.0129 | Val Loss: 0.01253
Epoch: 14 | Ph  | Train Loss: 0.021 | Val Loss: 0.02019
Epoch: 14 | SW Thresh: 0.0771
Epoch: 14 | Ending LR: 0.000200


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00115 to 0.00115
Epoch: 15 | FT  | Train Loss: 0.00107 | Val Loss: 0.00115
Epoch: 15 | Amp | Train Loss: 0.0129 | Val Loss: 0.01262
Epoch: 15 | Ph  | Train Loss: 0.021 | Val Loss: 0.02011
Epoch: 15 | SW Thresh: 0.0776
Epoch: 15 | Ending LR: 0.000233


0it [00:00, ?it/s]

Epoch: 16 | FT  | Train Loss: 0.00107 | Val Loss: 0.00119
Epoch: 16 | Amp | Train Loss: 0.0129 | Val Loss: 0.01293
Epoch: 16 | Ph  | Train Loss: 0.021 | Val Loss: 0.02021
Epoch: 16 | SW Thresh: 0.0781
Epoch: 16 | Ending LR: 0.000267


0it [00:00, ?it/s]

Epoch: 17 | FT  | Train Loss: 0.00106 | Val Loss: 0.00115
Epoch: 17 | Amp | Train Loss: 0.0129 | Val Loss: 0.01306
Epoch: 17 | Ph  | Train Loss: 0.021 | Val Loss: 0.02096
Epoch: 17 | SW Thresh: 0.0786
Epoch: 17 | Ending LR: 0.000300


0it [00:00, ?it/s]

Epoch: 18 | FT  | Train Loss: 0.00105 | Val Loss: 0.00121
Epoch: 18 | Amp | Train Loss: 0.0129 | Val Loss: 0.01305
Epoch: 18 | Ph  | Train Loss: 0.021 | Val Loss: 0.02089
Epoch: 18 | SW Thresh: 0.0792
Epoch: 18 | Ending LR: 0.000267


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00115 to 0.00114
Epoch: 19 | FT  | Train Loss: 0.00103 | Val Loss: 0.00114
Epoch: 19 | Amp | Train Loss: 0.0129 | Val Loss: 0.01299
Epoch: 19 | Ph  | Train Loss: 0.021 | Val Loss: 0.02101
Epoch: 19 | SW Thresh: 0.0797
Epoch: 19 | Ending LR: 0.000233


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00114 to 0.00111
Epoch: 20 | FT  | Train Loss: 0.00101 | Val Loss: 0.00111
Epoch: 20 | Amp | Train Loss: 0.0128 | Val Loss: 0.01325
Epoch: 20 | Ph  | Train Loss: 0.020 | Val Loss: 0.02105
Epoch: 20 | SW Thresh: 0.0802
Epoch: 20 | Ending LR: 0.000200


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00111 to 0.00108
Epoch: 21 | FT  | Train Loss: 0.00099 | Val Loss: 0.00108
Epoch: 21 | Amp | Train Loss: 0.0128 | Val Loss: 0.01299
Epoch: 21 | Ph  | Train Loss: 0.020 | Val Loss: 0.02058
Epoch: 21 | SW Thresh: 0.0807
Epoch: 21 | Ending LR: 0.000167


0it [00:00, ?it/s]

Epoch: 22 | FT  | Train Loss: 0.00097 | Val Loss: 0.00108
Epoch: 22 | Amp | Train Loss: 0.0128 | Val Loss: 0.01303
Epoch: 22 | Ph  | Train Loss: 0.020 | Val Loss: 0.02050
Epoch: 22 | SW Thresh: 0.0812
Epoch: 22 | Ending LR: 0.000133


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00108 to 0.00106
Epoch: 23 | FT  | Train Loss: 0.00096 | Val Loss: 0.00106
Epoch: 23 | Amp | Train Loss: 0.0128 | Val Loss: 0.01279
Epoch: 23 | Ph  | Train Loss: 0.020 | Val Loss: 0.02029
Epoch: 23 | SW Thresh: 0.0817
Epoch: 23 | Ending LR: 0.000100


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00106 to 0.00106
Epoch: 24 | FT  | Train Loss: 0.00095 | Val Loss: 0.00106
Epoch: 24 | Amp | Train Loss: 0.0128 | Val Loss: 0.01284
Epoch: 24 | Ph  | Train Loss: 0.020 | Val Loss: 0.02051
Epoch: 24 | SW Thresh: 0.0822
Epoch: 24 | Ending LR: 0.000117


0it [00:00, ?it/s]

Epoch: 25 | FT  | Train Loss: 0.00095 | Val Loss: 0.00107
Epoch: 25 | Amp | Train Loss: 0.0128 | Val Loss: 0.01284
Epoch: 25 | Ph  | Train Loss: 0.020 | Val Loss: 0.02033
Epoch: 25 | SW Thresh: 0.0827
Epoch: 25 | Ending LR: 0.000133


0it [00:00, ?it/s]

Epoch: 26 | FT  | Train Loss: 0.00094 | Val Loss: 0.00106
Epoch: 26 | Amp | Train Loss: 0.0128 | Val Loss: 0.01272
Epoch: 26 | Ph  | Train Loss: 0.020 | Val Loss: 0.02033
Epoch: 26 | SW Thresh: 0.0832
Epoch: 26 | Ending LR: 0.000150


0it [00:00, ?it/s]

Epoch: 27 | FT  | Train Loss: 0.00094 | Val Loss: 0.00106
Epoch: 27 | Amp | Train Loss: 0.0128 | Val Loss: 0.01291
Epoch: 27 | Ph  | Train Loss: 0.020 | Val Loss: 0.02053
Epoch: 27 | SW Thresh: 0.0837
Epoch: 27 | Ending LR: 0.000167


0it [00:00, ?it/s]

Epoch: 28 | FT  | Train Loss: 0.00094 | Val Loss: 0.00106
Epoch: 28 | Amp | Train Loss: 0.0128 | Val Loss: 0.01285
Epoch: 28 | Ph  | Train Loss: 0.020 | Val Loss: 0.02054
Epoch: 28 | SW Thresh: 0.0842
Epoch: 28 | Ending LR: 0.000183


0it [00:00, ?it/s]

Epoch: 29 | FT  | Train Loss: 0.00095 | Val Loss: 0.00111
Epoch: 29 | Amp | Train Loss: 0.0128 | Val Loss: 0.01258
Epoch: 29 | Ph  | Train Loss: 0.020 | Val Loss: 0.02013
Epoch: 29 | SW Thresh: 0.0847
Epoch: 29 | Ending LR: 0.000200


0it [00:00, ?it/s]

Epoch: 30 | FT  | Train Loss: 0.00094 | Val Loss: 0.00107
Epoch: 30 | Amp | Train Loss: 0.0127 | Val Loss: 0.01277
Epoch: 30 | Ph  | Train Loss: 0.020 | Val Loss: 0.02044
Epoch: 30 | SW Thresh: 0.0853
Epoch: 30 | Ending LR: 0.000183


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00106 to 0.00105
Epoch: 31 | FT  | Train Loss: 0.00093 | Val Loss: 0.00105
Epoch: 31 | Amp | Train Loss: 0.0127 | Val Loss: 0.01277
Epoch: 31 | Ph  | Train Loss: 0.020 | Val Loss: 0.02023
Epoch: 31 | SW Thresh: 0.0858
Epoch: 31 | Ending LR: 0.000167


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00105 to 0.00105
Epoch: 32 | FT  | Train Loss: 0.00092 | Val Loss: 0.00105
Epoch: 32 | Amp | Train Loss: 0.0127 | Val Loss: 0.01296
Epoch: 32 | Ph  | Train Loss: 0.020 | Val Loss: 0.02055
Epoch: 32 | SW Thresh: 0.0863
Epoch: 32 | Ending LR: 0.000150


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00105 to 0.00104
Epoch: 33 | FT  | Train Loss: 0.00091 | Val Loss: 0.00104
Epoch: 33 | Amp | Train Loss: 0.0127 | Val Loss: 0.01274
Epoch: 33 | Ph  | Train Loss: 0.020 | Val Loss: 0.02026
Epoch: 33 | SW Thresh: 0.0868
Epoch: 33 | Ending LR: 0.000133


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00104 to 0.00103
Epoch: 34 | FT  | Train Loss: 0.00090 | Val Loss: 0.00103
Epoch: 34 | Amp | Train Loss: 0.0127 | Val Loss: 0.01262
Epoch: 34 | Ph  | Train Loss: 0.020 | Val Loss: 0.02019
Epoch: 34 | SW Thresh: 0.0873
Epoch: 34 | Ending LR: 0.000117


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00103 to 0.00103
Epoch: 35 | FT  | Train Loss: 0.00090 | Val Loss: 0.00103
Epoch: 35 | Amp | Train Loss: 0.0127 | Val Loss: 0.01264
Epoch: 35 | Ph  | Train Loss: 0.020 | Val Loss: 0.02005
Epoch: 35 | SW Thresh: 0.0878
Epoch: 35 | Ending LR: 0.000100


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00103 to 0.00103
Epoch: 36 | FT  | Train Loss: 0.00089 | Val Loss: 0.00103
Epoch: 36 | Amp | Train Loss: 0.0127 | Val Loss: 0.01265
Epoch: 36 | Ph  | Train Loss: 0.020 | Val Loss: 0.02018
Epoch: 36 | SW Thresh: 0.0883
Epoch: 36 | Ending LR: 0.000108


0it [00:00, ?it/s]

Epoch: 37 | FT  | Train Loss: 0.00089 | Val Loss: 0.00104
Epoch: 37 | Amp | Train Loss: 0.0127 | Val Loss: 0.01277
Epoch: 37 | Ph  | Train Loss: 0.020 | Val Loss: 0.02023
Epoch: 37 | SW Thresh: 0.0888
Epoch: 37 | Ending LR: 0.000117


0it [00:00, ?it/s]

Epoch: 38 | FT  | Train Loss: 0.00089 | Val Loss: 0.00103
Epoch: 38 | Amp | Train Loss: 0.0127 | Val Loss: 0.01268
Epoch: 38 | Ph  | Train Loss: 0.020 | Val Loss: 0.02013
Epoch: 38 | SW Thresh: 0.0893
Epoch: 38 | Ending LR: 0.000125


0it [00:00, ?it/s]

Epoch: 39 | FT  | Train Loss: 0.00089 | Val Loss: 0.00103
Epoch: 39 | Amp | Train Loss: 0.0127 | Val Loss: 0.01265
Epoch: 39 | Ph  | Train Loss: 0.020 | Val Loss: 0.02016
Epoch: 39 | SW Thresh: 0.0898
Epoch: 39 | Ending LR: 0.000133


0it [00:00, ?it/s]

Epoch: 40 | FT  | Train Loss: 0.00089 | Val Loss: 0.00104
Epoch: 40 | Amp | Train Loss: 0.0127 | Val Loss: 0.01257
Epoch: 40 | Ph  | Train Loss: 0.020 | Val Loss: 0.02021
Epoch: 40 | SW Thresh: 0.0903
Epoch: 40 | Ending LR: 0.000142


0it [00:00, ?it/s]

Epoch: 41 | FT  | Train Loss: 0.00089 | Val Loss: 0.00104
Epoch: 41 | Amp | Train Loss: 0.0127 | Val Loss: 0.01264
Epoch: 41 | Ph  | Train Loss: 0.020 | Val Loss: 0.02001
Epoch: 41 | SW Thresh: 0.0908
Epoch: 41 | Ending LR: 0.000150


0it [00:00, ?it/s]

Epoch: 42 | FT  | Train Loss: 0.00089 | Val Loss: 0.00106
Epoch: 42 | Amp | Train Loss: 0.0127 | Val Loss: 0.01274
Epoch: 42 | Ph  | Train Loss: 0.020 | Val Loss: 0.02069
Epoch: 42 | SW Thresh: 0.0914
Epoch: 42 | Ending LR: 0.000142


0it [00:00, ?it/s]

Epoch: 43 | FT  | Train Loss: 0.00088 | Val Loss: 0.00105
Epoch: 43 | Amp | Train Loss: 0.0127 | Val Loss: 0.01271
Epoch: 43 | Ph  | Train Loss: 0.020 | Val Loss: 0.02036
Epoch: 43 | SW Thresh: 0.0919
Epoch: 43 | Ending LR: 0.000133


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00103 to 0.00103
Epoch: 44 | FT  | Train Loss: 0.00088 | Val Loss: 0.00103
Epoch: 44 | Amp | Train Loss: 0.0127 | Val Loss: 0.01257
Epoch: 44 | Ph  | Train Loss: 0.020 | Val Loss: 0.02005
Epoch: 44 | SW Thresh: 0.0924
Epoch: 44 | Ending LR: 0.000125


0it [00:00, ?it/s]

Epoch: 45 | FT  | Train Loss: 0.00088 | Val Loss: 0.00104
Epoch: 45 | Amp | Train Loss: 0.0127 | Val Loss: 0.01269
Epoch: 45 | Ph  | Train Loss: 0.020 | Val Loss: 0.02020
Epoch: 45 | SW Thresh: 0.0929
Epoch: 45 | Ending LR: 0.000117


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00103 to 0.00103
Epoch: 46 | FT  | Train Loss: 0.00087 | Val Loss: 0.00103
Epoch: 46 | Amp | Train Loss: 0.0127 | Val Loss: 0.01256
Epoch: 46 | Ph  | Train Loss: 0.020 | Val Loss: 0.02004
Epoch: 46 | SW Thresh: 0.0934
Epoch: 46 | Ending LR: 0.000108


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00103 to 0.00103
Epoch: 47 | FT  | Train Loss: 0.00087 | Val Loss: 0.00103
Epoch: 47 | Amp | Train Loss: 0.0127 | Val Loss: 0.01265
Epoch: 47 | Ph  | Train Loss: 0.020 | Val Loss: 0.02013
Epoch: 47 | SW Thresh: 0.0939
Epoch: 47 | Ending LR: 0.000100


0it [00:00, ?it/s]

Epoch: 48 | FT  | Train Loss: 0.00086 | Val Loss: 0.00103
Epoch: 48 | Amp | Train Loss: 0.0127 | Val Loss: 0.01252
Epoch: 48 | Ph  | Train Loss: 0.020 | Val Loss: 0.02009
Epoch: 48 | SW Thresh: 0.0944
Epoch: 48 | Ending LR: 0.000104


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00103 to 0.00102
Epoch: 49 | FT  | Train Loss: 0.00086 | Val Loss: 0.00102
Epoch: 49 | Amp | Train Loss: 0.0127 | Val Loss: 0.01277
Epoch: 49 | Ph  | Train Loss: 0.020 | Val Loss: 0.02024
Epoch: 49 | SW Thresh: 0.0949
Epoch: 49 | Ending LR: 0.000108


0it [00:00, ?it/s]

Epoch: 50 | FT  | Train Loss: 0.00086 | Val Loss: 0.00103
Epoch: 50 | Amp | Train Loss: 0.0126 | Val Loss: 0.01245
Epoch: 50 | Ph  | Train Loss: 0.020 | Val Loss: 0.01992
Epoch: 50 | SW Thresh: 0.0954
Epoch: 50 | Ending LR: 0.000113


0it [00:00, ?it/s]

Epoch: 51 | FT  | Train Loss: 0.00086 | Val Loss: 0.00102
Epoch: 51 | Amp | Train Loss: 0.0126 | Val Loss: 0.01260
Epoch: 51 | Ph  | Train Loss: 0.020 | Val Loss: 0.02017
Epoch: 51 | SW Thresh: 0.0959
Epoch: 51 | Ending LR: 0.000117


0it [00:00, ?it/s]

Epoch: 52 | FT  | Train Loss: 0.00086 | Val Loss: 0.00103
Epoch: 52 | Amp | Train Loss: 0.0126 | Val Loss: 0.01267
Epoch: 52 | Ph  | Train Loss: 0.020 | Val Loss: 0.02010
Epoch: 52 | SW Thresh: 0.0964
Epoch: 52 | Ending LR: 0.000121


0it [00:00, ?it/s]

Epoch: 53 | FT  | Train Loss: 0.00086 | Val Loss: 0.00103
Epoch: 53 | Amp | Train Loss: 0.0126 | Val Loss: 0.01280
Epoch: 53 | Ph  | Train Loss: 0.020 | Val Loss: 0.02033
Epoch: 53 | SW Thresh: 0.0969
Epoch: 53 | Ending LR: 0.000125


0it [00:00, ?it/s]

Epoch: 54 | FT  | Train Loss: 0.00086 | Val Loss: 0.00103
Epoch: 54 | Amp | Train Loss: 0.0126 | Val Loss: 0.01262
Epoch: 54 | Ph  | Train Loss: 0.020 | Val Loss: 0.02012
Epoch: 54 | SW Thresh: 0.0975
Epoch: 54 | Ending LR: 0.000121


0it [00:00, ?it/s]

Saving improved model after Val Loss improved from 0.00102 to 0.00102
Epoch: 55 | FT  | Train Loss: 0.00086 | Val Loss: 0.00102
Epoch: 55 | Amp | Train Loss: 0.0126 | Val Loss: 0.01269
Epoch: 55 | Ph  | Train Loss: 0.020 | Val Loss: 0.02018
Epoch: 55 | SW Thresh: 0.0980
Epoch: 55 | Ending LR: 0.000117


0it [00:00, ?it/s]

Epoch: 56 | FT  | Train Loss: 0.00086 | Val Loss: 0.00103
Epoch: 56 | Amp | Train Loss: 0.0126 | Val Loss: 0.01254
Epoch: 56 | Ph  | Train Loss: 0.020 | Val Loss: 0.01997
Epoch: 56 | SW Thresh: 0.0985
Epoch: 56 | Ending LR: 0.000113


0it [00:00, ?it/s]

Epoch: 57 | FT  | Train Loss: 0.00085 | Val Loss: 0.00102
Epoch: 57 | Amp | Train Loss: 0.0126 | Val Loss: 0.01265
Epoch: 57 | Ph  | Train Loss: 0.020 | Val Loss: 0.02012
Epoch: 57 | SW Thresh: 0.0990
Epoch: 57 | Ending LR: 0.000108


0it [00:00, ?it/s]

Epoch: 58 | FT  | Train Loss: 0.00085 | Val Loss: 0.00102
Epoch: 58 | Amp | Train Loss: 0.0126 | Val Loss: 0.01254
Epoch: 58 | Ph  | Train Loss: 0.020 | Val Loss: 0.01998
Epoch: 58 | SW Thresh: 0.0995
Epoch: 58 | Ending LR: 0.000104


0it [00:00, ?it/s]

In [None]:
# Learning-rate trace: metrics['lrs'] holds one LR value per recorded step.
# Dividing the step index by iterations_per_epoch puts the x-axis in epochs.
n_lr_steps = len(metrics['lrs'])
batches = np.arange(n_lr_steps + 1, dtype=float)  # 0.0, 1.0, ..., n (identical to linspace(0, n, n+1))
epoch_list = batches / iterations_per_epoch

# Skip index 0 so each LR value is plotted at the end of its step.
plt.plot(epoch_list[1:], metrics['lrs'], color='C3', linestyle='-')
plt.grid()
plt.ylabel("Learning rate")
plt.xlabel("Epoch")

In [None]:
# Train vs. validation curves for each loss term, one subplot per term.
# Column k of metrics['losses'] / metrics['val_losses'] is the k-th term in
# LOSS_TERMS order (FT, Amp, Ph), matching the per-epoch log lines above.
losses_arr = np.array(metrics['losses'])
val_losses_arr = np.array(metrics['val_losses'])

LOSS_TERMS = ["FT", "Amp", "Ph"]
fig, ax = plt.subplots(len(LOSS_TERMS), sharex=True, figsize=(15, 8))
for k, loss_name in enumerate(LOSS_TERMS):
    ax[k].plot(losses_arr[:, k], 'C3o-', label=f"Train {loss_name} loss")
    ax[k].plot(val_losses_arr[:, k], 'C0o-', label=f"Val {loss_name} loss")
    ax[k].set(ylabel='Loss')
    ax[k].grid()
    # Legend is placed outside the axes, to the right of each subplot.
    ax[k].legend(loc='center right', bbox_to_anchor=(1.5, 0.5))

# Label the shared x-axis BEFORE tight_layout so the label is included in the layout.
ax[-1].set_xlabel("Epochs")
fig.tight_layout()


In [None]:
# Test split: the N_TEST files that follow the first N_TRAIN entries of filelist
# (presumably the files used for training/validation — confirm).
N_TEST = 100
TEST_BATCH_SIZE = 16
test_filelist = filelist[N_TRAIN:N_TRAIN + N_TEST]
print('number of test data:%d' % len(test_filelist))

# Build the test dataset and loader. The loader does not shuffle, so predictions
# come back in test_filelist order.
test_dataset = Dataset(
    test_filelist, data_path, load_all=False, ratio=TRAIN_ratio, dataset='test', scale_I=scale_I, shuffle=False)
test_loader = DataLoader(
    test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False, num_workers=NGPUS)
# NOTE(review): single-process evaluation. A DistributedSampler would be needed
# if this cell is ever run under DistributedDataParallel.

# load saved model

In [None]:
# Optional: path to a checkpoint written during training. The training log above
# prints "Saving improved model ..." whenever the validation loss improves.
# NOTE(review): the file name 'best_model.pth' is assumed to match the name used
# by the training loop's save call — confirm before uncommenting.
# model_path = MODEL_SAVE_PATH + 'best_model.pth'

In [None]:
# Optional: rebuild the network and load the saved checkpoint. While these lines
# stay commented out, the evaluation cell below uses the in-memory `model` left
# over from the training cells (i.e. the last-epoch weights).
# NOTE(review): the loaded model must also be moved to `device`. If the checkpoint
# was saved from a GPU or DataParallel model, torch.load may need map_location
# and the state-dict keys may carry a "module." prefix — verify.
# model = recon_model()
# model.load_state_dict(torch.load(model_path))

In [None]:
# Run the trained model over the test set and collect predictions and ground
# truth as per-sample NumPy arrays (stacked and squeezed in the next cell).
model.eval()  # imp when have dropout etc
ft_results = []
complex_results = []
amp_preds = []
ph_preds = []
supports = []

ft_test_array = []
amp_test_array = []
ph_test_array = []


def _batch_to_numpy(t):
    """Detach a tensor, copy it to host memory, and return it as a NumPy array."""
    return t.detach().cpu().numpy()


# Inference only: no_grad() stops autograd from recording a graph for every
# test batch, which otherwise wastes (GPU) memory.
with torch.no_grad():
    for ft_images, amps, phs in test_loader:
        ft_images = ft_images.to(device)
        y, complex_x, amp, ph, support = model(ft_images)

        # Copy each batch to the host once, then split it into per-sample entries.
        # The [sample] wrapping for FT/complex/ground-truth entries mirrors the
        # original list layout; the next cell squeezes the extra axis away.
        ft_results.extend([s] for s in _batch_to_numpy(y))
        complex_results.extend([s] for s in _batch_to_numpy(complex_x))
        amp_preds.extend(_batch_to_numpy(amp))
        ph_preds.extend(_batch_to_numpy(ph))
        supports.extend(_batch_to_numpy(support))

        # ground truth
        ft_test_array.extend([s] for s in _batch_to_numpy(ft_images))
        amp_test_array.extend([s] for s in _batch_to_numpy(amps))
        ph_test_array.extend([s] for s in _batch_to_numpy(phs))
        

In [None]:
def _stack_squeeze(samples):
    """Stack a list of per-sample arrays into one array and drop every length-1 axis.

    Caveat: squeeze removes ALL size-1 axes, including the leading sample axis
    when the list holds a single sample.
    """
    return np.squeeze(np.asarray(samples))


# Model outputs
ft_results = _stack_squeeze(ft_results)
complex_results = _stack_squeeze(complex_results)
amp_preds = _stack_squeeze(amp_preds)
ph_preds = _stack_squeeze(ph_preds)
supports = _stack_squeeze(supports)

print(ft_results.shape, ft_results.dtype)

# Ground truth
ft_test_array, amp_test_array, ph_test_array = (
    _stack_squeeze(arr) for arr in (ft_test_array, amp_test_array, ph_test_array)
)

print(ft_test_array.shape, ft_test_array.dtype)

In [None]:
# Test images: show a few randomly chosen test samples.
#   row 0: FT input, predicted FT, and their difference
#   row 1: support, true amplitude, predicted amplitude
#   row 2: true phase, predicted phase masked by the support, and their difference
N_SHOW = 2        # NOTE(review): plot6 presumably expects 2 samples x 3 panels — confirm
PLOT_SEED = None  # set to an int for a reproducible selection
n_test = len(ft_test_array)
rng = np.random.default_rng(PLOT_SEED)
# Draw distinct indices from the ACTUAL number of test samples. The previous
# randint(100) assumed exactly 100 samples and could repeat an index.
# choice(..., replace=False) raises ValueError if n_test < N_SHOW.
to_plot = rng.choice(n_test, size=N_SHOW, replace=False)
# to_plot = [20,50]

plots = {0: [], 1: [], 2: []}
titles = {0: [], 1: [], 2: []}
for i in to_plot:
    plots[0] += [ft_test_array[i], ft_results[i], ft_test_array[i] - ft_results[i]]
    titles[0] += [f"FT input, {i}", f"Pred FT, {i}", f"Diff FT, {i}"]

    plots[1] += [supports[i], amp_test_array[i], amp_preds[i]]
    titles[1] += [f"Support, {i}", f"True Amp {i}", f"Predicted Amp {i}"]

    ph_masked = ph_preds[i] * supports[i]  # zero the predicted phase outside the support
    plots[2] += [ph_test_array[i], ph_masked, ph_test_array[i] - ph_masked]
    titles[2] += [f"True ph, {i}", f"Pred Ph {i}", f"Diff Ph {i}"]

for j in range(3):
    plot6(plots[j], titles[j])