# Data preparation

Prepare and save the data in ./data/train/ and ./data/valid/. This only needs to be done once.

In [None]:
from dataset import prepare_data

# One-off preprocessing step: prepares and saves the data under ./data/train/
# and ./data/valid/ (per the note above this cell).
# The positional arguments are forwarded unchanged; their meaning is defined by
# dataset.prepare_data, which is not visible here — TODO document them.
prepare_args = (100, range(1, 300), True)
prepare_data(*prepare_args)


# Optimize Reconstruction

### Initialize

In [None]:
from explore import Optim

# Settings for the reconstruction optimizer.
device = 'cuda'  # NOTE(review): hard-coded; presumably requires a CUDA GPU — confirm
angles = 100
net_file = "network_100a.pt"  # passed to Optim as `name`

# Build the optimizer: `angles` positionally, the rest as keyword options.
optim_options = {"name": net_file, "device": device}
optim = Optim(angles, **optim_options)

### Run the bi-level optimization
Select an image by setting *image_id* and insert the desired value for *m*.<br>
Set *Sparse_rep_learning_rate* to the learning rate of the sparsity network and *Sparse_rep_number_of_epochs* to the number of epochs used to train it.<br>
Set *Dictionary_learnin_rate* to the learning rate of the dictionary network and *Dictionary_number_of_epochs* to the number of epochs used to train it.<br>
Set *number_of_atoms* to the number of atoms in the dictionary and *patchsize* to the patch size.<br>
Set *l1_lambda* to the weight of the L1 regularizer.

In [None]:
# Permutations forwarded to optimize_GD as `perms`.
permutation = {
    "rotate": list(range(0, 360, 360 // 4)),  # [0, 90, 180, 270]
    "zoom": lambda x: [x - 2, x, x + 2],
    "shift": [0],
}

image_id = 169
m = 1.0
patchsize = 2
number_of_atoms = 10
Sparse_rep_learning_rate = 1e-3
Sparse_rep_number_of_epochs = 3
Dictionary_learnin_rate = 1e-5
Dictionary_number_of_epochs = 3
l1_lambda = 4e-3

# Bi-level optimization; the hyper-parameters above are forwarded by keyword.
results = optim.optimize_GD(
    dataid=image_id, epochs=1000, mask_w=11, w_r=1.0, w_c=1.0, w_tv=0.01, lr=1,
    opt_to=m, perms=permutation,
    lr_Dict=Dictionary_learnin_rate, epochs_Dict=Dictionary_number_of_epochs,
    lr_SV=Sparse_rep_learning_rate, epochs_SP=Sparse_rep_number_of_epochs,
    l1_lambda=l1_lambda, patchsize=patchsize, number_of_atoms=number_of_atoms,
)

(dict_loss, spars_network_loss, psnrs, xt, errors, losses, losses_r, losses_c,
 preds, stopiter, data_specs) = results

# Unpack the data specification. `slice_` (not `slice`) avoids shadowing the
# builtin for every later cell; `data_angles` (not `angles`) keeps the
# configured `angles = 100` from the Initialize cell from being overwritten.
(tmean, tvar, slice_, sino, low_dose, loc, malig, data_angles, end_sino,
 ld_sino) = data_specs

# Plot Results

### Plot of the interior and the exterior error

In [None]:
import torch
import matplotlib.pyplot as plt

# Prepare errors. `errors` is transposed twice: the second field of every entry
# is itself a pair whose first/second items are taken as the nodule and
# surrounding errors (layout inferred from the indexing below; it is defined by
# Optim.optimize_GD — TODO confirm).
errors2 = list(zip(*errors))
errors3 = list(zip(*errors2[1]))
# .detach().cpu() lets matplotlib convert the values even if the tensors live
# on the GPU or still track gradients; it is a no-op for plain CPU tensors.
error_nodule = torch.stack(errors3[0])[:, 0, 0].detach().cpu()
error_sur = torch.stack(errors3[1])[:, 0, 0].detach().cpu()

# plot errors
plt.figure(figsize=(5, 3))
plt.plot(error_nodule)
plt.plot(error_sur)
plt.legend(["error nodule", "error surrounding"])

### Plot the PSNR of each iteration

In [None]:
import os

# Plot the per-iteration PSNR values returned by optimize_GD and save the figure
# under res300/<image_id>/<number_of_atoms>/<patchsize>/<m>/.
out_dir = f'res300/{image_id}/{number_of_atoms}/{patchsize}/{m}'
# savefig does not create missing directories, so create the target first.
os.makedirs(out_dir, exist_ok=True)

plt.figure(figsize=(5, 3))
plt.plot(psnrs, linewidth=0.5)
plt.legend(["PSNR of each iteration"])
plt.savefig(f'{out_dir}/WF-PSNR for non_coeff.png')

### Plot the losses of the dictionary network

In [None]:
# Plot the dictionary-network loss history returned by optimize_GD.
plt.figure(figsize=(5, 3))
plt.plot(dict_loss, linewidth=0.5)
plt.xlabel("iteration")
plt.legend(["Dictionary Network Losses"])
plt.show()

### Plot the losses of the sparse network

In [None]:
# Plot the sparse-representation network loss history returned by optimize_GD.
plt.figure(figsize=(5, 3))
plt.plot(spars_network_loss, linewidth=0.5)
plt.legend(["Sparse Network Losses"])
plt.show()

### Plot of the (cropped) reconstruction

In [None]:
import os

import torch
import dataset as ds

out_dir = f'res300/{image_id}/{number_of_atoms}/{patchsize}/{m}'
# savefig does not create missing directories, so create the target first.
os.makedirs(out_dir, exist_ok=True)

# Crop the reconstruction around `loc` with window size optim.nosz*3 and take
# element [0, 0] — presumably the first batch/channel (TODO confirm).
result_xt = ds.crop_center(xt, loc, size=optim.nosz * 3)[0, 0]
# imshow needs host data: move a tensor off the GPU / out of autograd if needed.
if isinstance(result_xt, torch.Tensor):
    result_xt = result_xt.detach().cpu()

fig, ax0 = plt.subplots(1, 1)
ax0.axis('off')
ax0.imshow(result_xt, cmap='viridis', interpolation='bicubic')
fig.savefig(f'{out_dir}/WF-Image for non_coeff.png')

fig, ax0 = plt.subplots(1, 1, figsize=(5, 3))
ax0.plot(preds)
fig.savefig(f'{out_dir}/WF-Prediction for non_coeff.png')
torch.cuda.empty_cache()


### Plot of the loss and the network prediction

In [None]:
import os

import torch

# Left: total loss and its two recorded components; right: network prediction.
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(15, 4))

ax0.plot(losses)
ax0.plot(losses_r)
ax0.plot(losses_c)
ax0.legend(["loss", "loss E1", "loss E2(1)"])
ax0.set_title("Losses")

ax1.plot(preds)
ax1.set_title("Network prediction")

# NOTE(review): this cell saves under Res-normal/<image_id>/<m>/ whereas the
# other cells save under res300/...; confirm which location is intended.
out_dir = f'Res-normal/{image_id}/{m}'
# savefig does not create missing directories, so create the target first.
os.makedirs(out_dir, exist_ok=True)
fig.savefig(f'{out_dir}/WF-Prediction.png')
torch.cuda.empty_cache()

