# Data Preparation

Prepare and save the data in ./data/train/ and ./data/valid/. This only needs to be done once.

In [None]:
from dataset import prepare_data

# One-time step: prepares and saves the data under ./data/train/ and ./data/valid/.
# NOTE(review): the meaning of the positional arguments (100, the id range, True)
# is defined by dataset.prepare_data; 100 presumably matches `angles` below — confirm.
image_ids = range(1, 300)  # ids 1..299 (the range end is exclusive)
prepare_data(100, image_ids, True)

# Optimize Reconstruction

### Initialize

In [None]:
from explore import Optim

# Settings for the reconstruction optimizer.
angles = 100                  # first positional argument of Optim
net_file = "network_100a.pt"  # NOTE(review): name suggests a network for 100 angles; keep in sync with `angles`
device = 'cuda'               # NOTE(review): hard-coded GPU device; presumably fails on machines without CUDA

optim = Optim(angles, device=device, name=net_file)

### Run the optimization using the K-SVD model and Orthogonal Matching Pursuit
Select an image by choosing an *image id* and insert the desired value for *m*.<br>
Select the *number of nonzero coefficients* allowed in the sparse representation.<br>
Select the *number of atoms* in the dictionary.<br>
Select the *patch size* used to split the image into patches.

### Extraction of the outputs

In [None]:
# Augmentations forwarded to `optimize_ksvd` via `perms`:
#   rotate: [0, 90, 180, 270] (presumably degrees)
#   zoom:   callable mapping a base value x to [x-2, x, x+2]
#   shift:  a single zero shift
permutation = {
    "rotate": list(range(0, 360, 360 // 4)),
    "zoom": lambda x: [x - 2, x, x + 2],
    "shift": [0],
}

# Optimization settings (described in the markdown cell above).
image_id = 169        # id of the image to reconstruct
non_zero_coefs = 1    # number of coefficients allowed to be nonzero in the sparse representation
number_of_atoms = 10  # number of atoms in the dictionary
m = 1.0               # desired value of m, passed as `opt_to`
patchsize = 2         # patch size used for patchifying the image

results = optim.optimize_ksvd(
    image_id,
    epochs=6000,
    mask_w=11.,
    w_r=1.0,
    w_c=1.0,
    w_tv=0.01,
    lr=1,
    opt_to=m,
    perms=permutation,
    non_zero_coefs=non_zero_coefs,
    number_of_atoms=number_of_atoms,
    patchsize=patchsize,
    size_of_crop=16,
)

psnrs, xt, errors, losses, losses_r, losses_c, preds, stopiter, data_specs = results
# `slice_` (trailing underscore) avoids shadowing the builtin `slice`.
# NOTE: this unpacking overwrites `angles` from the initialization cell with
# the value stored in `data_specs`.
tmean, tvar, slice_, sino, low_dose, loc, malig, angles, end_sino, ld_sino = data_specs

# Plot Results

### Plot of the interior and the exterior error

In [None]:
import torch
import matplotlib.pyplot as plt

# prepare errors
# `errors` holds one entry per iteration. The first transpose groups entry
# components across iterations; component [1] is assumed to be a pair
# (nodule error, surrounding error), which the second transpose splits into
# two sequences of tensors. TODO confirm against Optim.optimize_ksvd.
errors2 = list(zip(*errors))
errors3 = list(zip(*errors2[1]))
# Matplotlib converts its input through NumPy, which fails for CUDA tensors
# and for tensors that require grad; detach and move to CPU first
# (a no-op for plain CPU tensors).
error_nodule = torch.stack(errors3[0])[:, 0, 0].detach().cpu()
error_sur = torch.stack(errors3[1])[:, 0, 0].detach().cpu()

# plot errors
plt.figure(figsize=(5, 3))
plt.plot(error_nodule)
plt.plot(error_sur)
plt.title("Error values e0 and e1")
plt.xlabel("iteration")
plt.legend(["error nodule", "error surrounding"])
plt.show()

### Plot PSNR of each iteration

In [None]:
# PSNR trace over the optimization run, drawn through the Axes API.
psnr_fig, psnr_ax = plt.subplots(figsize=(15, 3))
psnr_ax.plot(psnrs, linewidth=0.5)
psnr_ax.set_title("PSNR of each epoch")
psnr_ax.set_xlabel("iteration")
psnr_ax.legend(["PSNR"])
plt.show()

### Plot of the (cropped) reconstruction

In [None]:
import dataset as ds

# Side-by-side crops around `loc`: the low-dose image (titled as the
# filtered-backprojection reconstruction) and the optimized result `xt`.
# NOTE(review): crop size is 3 * optim.nosz; nosz is presumably the nodule
# size — confirm in Optim.
crop_size = optim.nosz * 3
result_xt = ds.crop_center(xt, loc, size=crop_size)[0, 0]
startg_xt = ds.crop_center(low_dose, loc, size=crop_size)[0, 0]

fig, (ax0, ax1) = plt.subplots(1, 2)
for ax, image, caption in (
    (ax0, startg_xt, "reconstr. filtered backpr."),
    (ax1, result_xt, f"reconstr. m={m}"),
):
    ax.axis('off')
    ax.imshow(image)
    ax.set_title(caption)
plt.show()


### Plot of the loss and the network prediction

In [None]:
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(15, 4))

# Left panel: total loss plus the two tracked loss terms
# (losses_r -> "loss E1", losses_c -> "loss E2(1)", per the legend order).
ax0.plot(losses)
ax0.plot(losses_r)
ax0.plot(losses_c)
ax0.legend(["loss", "loss E1", "loss E2(1)"])
ax0.set_title("Losses")

# Right panel: the network prediction values returned by the optimization.
ax1.plot(preds)
ax1.set_title("Network prediction")  # plain string: no placeholders to format
plt.show()