In [None]:
import torch, sys, os, random, json, time
import matplotlib.pyplot as plt
import numpy as np

from dataset import *
from scipy.ndimage import zoom
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from neuralop.models import TFNO
from neuralop.utils import count_params
from scipy.io import savemat, loadmat
from tqdm import tqdm

%matplotlib widget

In [None]:
def create_circular_mask(h, w, center=None, radius=None):
    """Return an (h, w) boolean mask that is True inside a disk.

    Parameters
    ----------
    h, w : int
        Height (number of rows) and width (number of columns) of the mask.
    center : tuple of float, optional
        Disk center as ``(x, y)``, i.e. (column, row), in pixel coordinates.
        Defaults to the geometric center of the grid, ``((w-1)/2, (h-1)/2)``.
    radius : float, optional
        Disk radius in pixels. Defaults to the smallest of
        ``center[0], center[1], w - center[0], h - center[1]``.

    Returns
    -------
    numpy.ndarray of bool, shape (h, w)
        True where the pixel's distance to ``center`` is ``<= radius``.
    """
    if center is None:
        # (w-1)/2 is the midpoint between pixel 0 and pixel w-1, which is exact for
        # both even and odd sizes (int(w/2) - 0.5 is half a pixel off for odd sizes).
        center = ((w - 1) / 2, (h - 1) / 2)
    if radius is None:
        # Smallest distance between the center and the image walls.
        radius = min(center[0], center[1], w - center[0], h - center[1])

    # Open grids broadcast to an (h, w) distance map without materialising meshgrids.
    Y, X = np.ogrid[:h, :w]
    dist_from_center = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)

    return dist_from_center <= radius

# Evaluation mask on a 128x128 grid: 1.0 inside a radius-40 disk, 0.0 outside.
# `masked_img` is multiplied into targets and predictions in the final-metrics cells.
img = np.ones((128, 128))
h, w = img.shape[:2]
mask = create_circular_mask(h, w, radius=40)
# np.where builds a fresh float array, so `img` itself stays all ones.
masked_img = np.where(mask, img, 0.0)

In [None]:
# Experiment configuration: device, model architecture, and I/O locations.
mode = 'gpu'
config = 'noise0'
init_features = 32
model_name = f'unet_{init_features}'
correlation = 'high'
resolution = 128

if mode == 'gpu' and torch.cuda.is_available():
    device = torch.device('cuda')
    # After switching device you need to restart the kernel: the default tensor
    # type set below is process-global.
    torch.cuda.set_device(0)
    # NOTE(review): set_default_tensor_type is deprecated in recent PyTorch;
    # torch.set_default_device('cuda') is the replacement on torch >= 2.0.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    if mode == 'gpu':
        # Previously this path called torch.cuda.set_device(0) and crashed.
        print('CUDA is not available; falling back to CPU.')
    device = torch.device('cpu')
    torch.set_default_dtype(torch.float32)

# UNet from torch.hub (fetched on first use).
# NOTE(review): in_channels=3 assumes the dataset inputs have 3 channels — confirm.
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch',
                       'unet', in_channels=3, out_channels=1,
                       init_features=init_features, pretrained=False)

output_dir = f'/home/sci/hdai/Projects/UltrasoundTfno/Checkpoints_big/{model_name}_grf{resolution}_{correlation}_{config}'
# makedirs also creates missing parent directories and is a no-op if it exists.
os.makedirs(output_dir, exist_ok=True)

data_path = '/home/sci/hdai/Projects/Datasets/Ultrasound/GRF/mats'

## A. Training

In [None]:
# Training hyperparameters and loss bookkeeping.
new_split = False        # True: draw a fresh train/test split and overwrite the saved one
save_model = True        # write a checkpoint every `log_test_interval` epochs
resume = False           # True: continue from the epoch `start_epoch_num - 1` checkpoint
n_epochs = 10001
start_epoch_num = 1001   # only used when resume=True; reset to 0 otherwise
log_test_interval = 500
learning_rate = 1e-3
training_loss_list, testing_loss_list = [], []

# Train/test split over sample ids 0..n_total_samples-1, persisted as JSON so every
# run (and every model) uses the same split.
split_dir = '/home/sci/hdai/Projects/UltrasoundTfno/Checkpoints_big'
train_split_path = os.path.join(split_dir, 'training_subset_id_list.json')
test_split_path = os.path.join(split_dir, 'testing_subset_id_list.json')
n_total_samples = 200
n_train_samples = 160
split_seed = 0           # makes a regenerated split reproducible

if new_split:
    # A dedicated, seeded RNG: the split no longer depends on global random state.
    training_subset_id_list = random.Random(split_seed).sample(range(n_total_samples), n_train_samples)
    train_id_set = set(training_subset_id_list)
    testing_subset_id_list = [idx for idx in range(n_total_samples) if idx not in train_id_set]
    os.makedirs(split_dir, exist_ok=True)
    with open(train_split_path, "w") as fp:
        json.dump(training_subset_id_list, fp)
    with open(test_split_path, "w") as fp:
        json.dump(testing_subset_id_list, fp)
else:
    with open(train_split_path, "r") as fp:
        training_subset_id_list = json.load(fp)
    with open(test_split_path, "r") as fp:
        testing_subset_id_list = json.load(fp)

In [None]:
model = model.to(device)
n_params = count_params(model)
print(f'\nOur model has {n_params} parameters.')
print('\n### MODEL ###\n', model)

if resume:
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(f'{output_dir}/epoch_{start_epoch_num-1}_checkpoints.pth.tar',
                            map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
else:
    start_epoch_num = 0

# batch_size equals the split size, so each dataloader yields the whole split as one batch.
training_dataset = PairDataset(data_path, config=config, correlation=correlation, subset_idx_list=training_subset_id_list)
training_dataloader = DataLoader(training_dataset, batch_size=len(training_subset_id_list), shuffle=False, num_workers=0)
print(training_dataset.subset_idx_list)

testing_dataset = PairDataset(data_path, config=config, correlation=correlation, subset_idx_list=testing_subset_id_list)
testing_dataloader = DataLoader(testing_dataset, batch_size=len(testing_subset_id_list), shuffle=False, num_workers=0)
print(testing_dataset.subset_idx_list)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
if resume:
    # Restore Adam's moment estimates and current learning rate; otherwise a resumed
    # run silently restarts the optimizer from scratch.
    # NOTE(review): the scheduler state is not checkpointed, so ReduceLROnPlateau
    # restarts its patience counter on resume.
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
criterion = torch.nn.MSELoss()

In [None]:
# Training run. Each epoch: one optimisation pass over the training split, one
# no-grad evaluation pass over the test split, LR scheduling on the training loss,
# and every `log_test_interval` epochs a log line plus an optional checkpoint.
for epoch in tqdm(range(start_epoch_num, start_epoch_num + n_epochs)):
    # ---- training pass ----
    model.train()
    epoch_loss = 0
    for sample in training_dataloader:
        x, y_true = sample['input'].to(device), sample['output'].to(device)
        optimizer.zero_grad(set_to_none=True)
        y_pred = model(x)
        training_loss = criterion(y_pred, y_true)
        training_loss.backward()
        optimizer.step()
        epoch_loss += training_loss.item()

    # ---- evaluation pass ----
    # eval() is required here. In train mode, normalisation layers such as BatchNorm
    # would normalise with test-batch statistics and also fold the test data into
    # their running statistics, which are part of the saved state_dict.
    model.eval()
    testing_loss = 0.0
    with torch.no_grad():
        for sample in testing_dataloader:
            x, y_true = sample['input'].to(device), sample['output'].to(device)
            y_pred = model(x)
            # Summed over batches, matching how `epoch_loss` is accumulated.
            testing_loss += criterion(y_pred, y_true).item()

    training_loss_list.append(epoch_loss)
    testing_loss_list.append(testing_loss)
    # ReduceLROnPlateau monitors the training loss, not the test loss.
    scheduler.step(epoch_loss)

    if epoch % log_test_interval == 0:
        print(f'epoch {epoch} MSELoss: {epoch_loss}, test MSELoss: {testing_loss}')
        if save_model:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
            }, f'{output_dir}/epoch_{epoch}_checkpoints.pth.tar')

# Persist the per-epoch loss histories to output_dir.
with open(f'{output_dir}/training_loss.json', "w") as fp:
    json.dump(training_loss_list, fp)
with open(f'{output_dir}/testing_loss.json', "w") as fp:
    json.dump(testing_loss_list, fp)

## B. Inference

In [None]:
# Load the checkpoint written at this epoch for all inference cells below.
epoch = 10000
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint = torch.load(f'{output_dir}/epoch_{epoch}_checkpoints.pth.tar', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
# Inference mode: normalisation layers use their stored running statistics instead
# of the statistics of whatever (possibly single-sample) batch is passed in.
model.eval()

### B.1 Loss trend

In [None]:
# Training / testing loss curves on a log scale.
# In a fresh kernel where training was not re-run, the in-memory histories are
# empty; fall back to the JSON files the training cell writes to `output_dir`.
train_loss_path = f'{output_dir}/training_loss.json'
test_loss_path = f'{output_dir}/testing_loss.json'
if not training_loss_list and os.path.isfile(train_loss_path):
    with open(train_loss_path) as fp:
        training_loss_list = json.load(fp)
if not testing_loss_list and os.path.isfile(test_loss_path):
    with open(test_loss_path) as fp:
        testing_loss_list = json.load(fp)

fig, ax = plt.subplots(figsize=(5, 4))
ax.set_xlabel('Epochs', fontsize=14)
ax.set_ylabel('Loss', fontsize=14)
ax.plot(training_loss_list, color='#ffb901', linewidth=3, label='Training set')
ax.plot(testing_loss_list, color='#f25022', linewidth=3, label='Testing set')
ax.set_yscale('log')
ax.legend()

fig.savefig(f'{output_dir}/{model_name}_loss_trend_{correlation}_{config}.pdf', bbox_inches='tight', dpi=300)

### B.2 Test samples

In [None]:
# Visual check on the first `sample_num` test cases, one row per sample:
# input channel 0 | ground truth | prediction. Predictions are kept in `out_list`.
out_list = []
sample_num = 40

testing_dataset = PairDataset(data_path, config=config, correlation=correlation, subset_idx_list=testing_subset_id_list)
testing_dataloader = DataLoader(testing_dataset, batch_size=len(testing_subset_id_list), shuffle=False, num_workers=0)
test_samples = testing_dataloader.dataset
# Never index past the end of the split.
sample_num = min(sample_num, len(test_samples))

# Single-sample batches below: use stored running statistics, not batch statistics.
model.eval()
fig = plt.figure(figsize=(8, sample_num*2))

with torch.no_grad():  # inference only; no autograd graph needed
    for index in range(sample_num):
        data = test_samples[index]
        x = data['input'].to(device)
        y = data['output'].to(device)
        abs_id = data['abs_id']
        rel_id = data['rel_id']
        out = model(x.unsqueeze(0))  # add a batch axis of size 1
        out_np = out.to('cpu').squeeze().detach().numpy()
        out_list.append(out_np)

        ax1 = fig.add_subplot(sample_num, 3, index*3 + 1)
        im1 = ax1.imshow(x[0].to('cpu'), cmap='viridis', vmin=0, vmax=1)
        ax1.set_title(f'{abs_id}, {rel_id} TOF')
        plt.colorbar(im1, ax=ax1)
        ax1.axis('off')

        ax2 = fig.add_subplot(sample_num, 3, index*3 + 2)
        im2 = ax2.imshow(y.to('cpu').squeeze(), cmap='viridis', vmin=0, vmax=1)
        if index == 0:
            ax2.set_title('Ground Truth SS')
        ax2.axis('off')

        ax3 = fig.add_subplot(sample_num, 3, index*3 + 3)
        im3 = ax3.imshow(out_np, cmap='viridis', vmin=0, vmax=1)
        if index == 0:
            ax3.set_title('Predicted SS')
        plt.colorbar(im3, ax=ax3)
        ax3.axis('off')

fig.tight_layout()
fig.savefig(f'{output_dir}/{model_name}_test_{epoch}_{correlation}_{config}.pdf', dpi=300)

### B.3 Train samples

In [None]:
# Visual check on the first `sample_num` training cases, one row per sample:
# input channel 0 | ground truth | prediction. Predictions are kept in `out_list`.
out_list = []
sample_num = 80

training_dataset = PairDataset(data_path, config=config, correlation=correlation, subset_idx_list=training_subset_id_list)
training_dataloader = DataLoader(training_dataset, batch_size=len(training_subset_id_list), shuffle=False, num_workers=0)
train_samples = training_dataloader.dataset
# Never index past the end of the split.
sample_num = min(sample_num, len(train_samples))

# Single-sample batches below: use stored running statistics, not batch statistics.
model.eval()
fig = plt.figure(figsize=(8, sample_num*2))

with torch.no_grad():  # inference only; no autograd graph needed
    for index in range(sample_num):
        data = train_samples[index]
        x = data['input'].to(device)
        y = data['output'].to(device)
        abs_id = data['abs_id']
        rel_id = data['rel_id']
        out = model(x.unsqueeze(0))  # add a batch axis of size 1
        out_np = out.to('cpu').squeeze().detach().numpy()
        out_list.append(out_np)

        ax1 = fig.add_subplot(sample_num, 3, index*3 + 1)
        im1 = ax1.imshow(x[0].to('cpu'), cmap='viridis', vmin=0, vmax=1)
        ax1.set_title(f'{abs_id}, {rel_id} TOF')
        plt.colorbar(im1, ax=ax1)
        ax1.axis('off')

        ax2 = fig.add_subplot(sample_num, 3, index*3 + 2)
        im2 = ax2.imshow(y.to('cpu').squeeze(), cmap='viridis', vmin=0, vmax=1)
        if index == 0:
            ax2.set_title('Ground Truth SS')
        ax2.axis('off')

        ax3 = fig.add_subplot(sample_num, 3, index*3 + 3)
        im3 = ax3.imshow(out_np, cmap='viridis', vmin=0, vmax=1)
        if index == 0:
            ax3.set_title('Predicted SS')
        plt.colorbar(im3, ax=ax3)
        ax3.axis('off')

fig.tight_layout()
fig.savefig(f'{output_dir}/{model_name}_train_{epoch}_{correlation}_{config}.pdf', dpi=300)

### B.4 Final train/test masked MSE and relative L2 error

In [None]:
# Masked MSE and relative L2 error of the loaded model on the whole training split.
# Only pixels inside the circular evaluation mask (`masked_img`) contribute.
model.eval()  # same normalisation behaviour as at inference time
y_true_batches, y_pred_batches = [], []
with torch.no_grad():
    for sample in training_dataloader:
        x = sample['input'].to(device)
        y_true_batches.append(sample['output'].to(device))
        y_pred_batches.append(model(x))
# Concatenate so every batch is scored, not only the last one.
y_true = torch.cat(y_true_batches, dim=0)
y_pred = torch.cat(y_pred_batches, dim=0)

# The (H, W) mask broadcasts over the sample axis, so the split size is not hard-coded.
# reshape keeps the sample axis even for a one-sample split.
# NOTE(review): assumes single-channel targets/predictions (model has out_channels=1).
mask_tensor = torch.from_numpy(masked_img).to(device)
y_true_masked = y_true.reshape(-1, *mask_tensor.shape) * mask_tensor
y_pred_masked = y_pred.reshape(-1, *mask_tensor.shape) * mask_tensor

# Per-sample squared error summed over pixels and divided by the full image area
# (resolution**2), not by the number of pixels inside the mask.
mse_array = torch.sum((y_true_masked - y_pred_masked) ** 2, dim=(1, 2)) / (resolution ** 2)
y_true_flat = y_true_masked.cpu().numpy().ravel()
y_pred_flat = y_pred_masked.cpu().numpy().ravel()
l2 = float(np.linalg.norm(y_true_flat - y_pred_flat, 2) / np.linalg.norm(y_true_flat, 2))

train_metrics = {
    'mse_mean': torch.mean(mse_array).item(),
    'mse_std': torch.std(mse_array).item(),
    'l2_rel_error': l2,
}
with open(f'{output_dir}/final_train_mse.json', "w") as fp:
    json.dump(train_metrics, fp)
print(train_metrics)

In [None]:
# Masked MSE and relative L2 error of the loaded model on the whole test split.
# Only pixels inside the circular evaluation mask (`masked_img`) contribute.
model.eval()  # same normalisation behaviour as at inference time
y_true_batches, y_pred_batches = [], []
with torch.no_grad():
    for sample in testing_dataloader:
        x = sample['input'].to(device)
        y_true_batches.append(sample['output'].to(device))
        y_pred_batches.append(model(x))
# Concatenate so every batch is scored, not only the last one.
y_true = torch.cat(y_true_batches, dim=0)
y_pred = torch.cat(y_pred_batches, dim=0)

# The (H, W) mask broadcasts over the sample axis, so the split size is not hard-coded.
# reshape keeps the sample axis even for a one-sample split.
# NOTE(review): assumes single-channel targets/predictions (model has out_channels=1).
mask_tensor = torch.from_numpy(masked_img).to(device)
y_true_masked = y_true.reshape(-1, *mask_tensor.shape) * mask_tensor
y_pred_masked = y_pred.reshape(-1, *mask_tensor.shape) * mask_tensor

# Per-sample squared error summed over pixels and divided by the full image area
# (resolution**2), not by the number of pixels inside the mask.
mse_array = torch.sum((y_true_masked - y_pred_masked) ** 2, dim=(1, 2)) / (resolution ** 2)
y_true_flat = y_true_masked.cpu().numpy().ravel()
y_pred_flat = y_pred_masked.cpu().numpy().ravel()
l2 = float(np.linalg.norm(y_true_flat - y_pred_flat, 2) / np.linalg.norm(y_true_flat, 2))

test_metrics = {
    'mse_mean': torch.mean(mse_array).item(),
    'mse_std': torch.std(mse_array).item(),
    'l2_rel_error': l2,
}
with open(f'{output_dir}/final_test_mse.json', "w") as fp:
    json.dump(test_metrics, fp)
print(test_metrics)