In [None]:

import sys
sys.path.append('./')
from fit.datamodules.super_res import MNIST_SResFITDM
from fit.utils.tomo_utils import get_polar_rfft_coords_2D

from fit.modules.SResTransformerModule import SResTransformerModule

from matplotlib import pyplot as plt
from matplotlib import gridspec

import torch

import numpy as np

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

import os
import ssl
from os.path import exists

import wget
# TLS certificate verification stays ON by default. Overriding
# ssl._create_default_https_context disables verification for every
# urllib-based HTTPS request in this kernel (presumably including the
# checkpoint download below, which `wget` performs via urllib). That would let
# anyone on the network path substitute the checkpoint file that is later
# loaded. Set this to True ONLY as a last resort when the download fails with
# a certificate error on a machine whose CA store is broken.
ALLOW_UNVERIFIED_HTTPS = False
if ALLOW_UNVERIFIED_HTTPS:
    ssl._create_default_https_context = ssl._create_unverified_context

In [None]:
# Seed the pseudo-random generators (python `random`, NumPy, torch) through
# Lightning so the notebook is reproducible; the returned seed is displayed.
seed_everything(seed=22122020)

In [None]:
# MNIST super-resolution data module; data lives under DATA_ROOT.
DATA_ROOT = './datamodules/data/'
BATCH_SIZE = 32

dm = MNIST_SResFITDM(batch_size=BATCH_SIZE, root_dir=DATA_ROOT)
dm.prepare_data()  # presumably fetches/prepares the raw data on first run
dm.setup()         # presumably builds the train/val/test splits

In [None]:
# Polar coordinates (radius, angle) of the rfft coefficient grid for the
# ground-truth image shape, plus the orderings later handed to the model as
# dst_flatten_order / dst_order.
polar_grid = get_polar_rfft_coords_2D(img_shape=dm.gt_shape)
r, phi, flatten_order, order = polar_grid

In [None]:
# Attention geometry: number of heads and per-head query size.
# The model width is derived from these as n_heads * d_query.
n_heads, d_query = 8, 32

In [None]:
# Hyper-parameters of the super-resolution Fourier Image Transformer,
# collected in one place and expanded into the constructor.
model_hparams = dict(
    d_model=n_heads * d_query,        # total width = heads x per-head size
    img_shape=dm.gt_shape,
    coords=(r, phi),
    dst_flatten_order=flatten_order,
    dst_order=order,
    loss='prod',
    lr=0.0001,
    weight_decay=0.01,
    n_layers=8,
    n_heads=n_heads,
    d_query=d_query,
    dropout=0.1,
    attention_dropout=0.1,
)
model = SResTransformerModule(**model_hparams)

In [None]:
# Request a GPU only when CUDA is actually present. A hard-coded gpus=1 makes
# the Lightning Trainer constructor raise on CPU-only machines, even though
# training is optional in this notebook and inference runs on the CPU.
n_gpus = 1 if torch.cuda.is_available() else 0

# Keep the best checkpoint (lowest 'Train/avg_val_loss') and the last one.
checkpoint_cb = ModelCheckpoint(
    dirpath=None,  # None -> Lightning picks its default checkpoint directory
    save_top_k=1,
    verbose=False,
    save_last=True,
    monitor='Train/avg_val_loss',
    mode='min',
)

trainer = Trainer(max_epochs=100,
                  gpus=n_gpus,
                  callbacks=[checkpoint_cb],
                  deterministic=True)

In [None]:
# Set to True to train your own model instead of only using the pre-trained
# checkpoint. NOTE(review): the checkpoint loaded further below presumably
# overwrites any weights trained here.
TRAIN_FROM_SCRATCH = False
if TRAIN_FROM_SCRATCH:
    trainer.fit(model, datamodule=dm)

In [None]:
# Fetch the pre-trained checkpoint once; later runs reuse the local copy.
CKPT_URL = 'https://download.fht.org/jug/fit/sres_model_mnist.ckpt'
ckpt_path = './models/sres/mnist_sres.ckpt'
if not exists(ckpt_path):
    # wget.download moves the finished file to `out` but does not create
    # missing parent directories, so ensure the target folder exists on a
    # fresh clone.
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    # NOTE(review): checkpoints are typically deserialised via pickle
    # (torch.load); only fetch them over a verified connection.
    wget.download(CKPT_URL, out=ckpt_path)

In [None]:
# Load the pre-trained weights and keep the model on the CPU for inference.
checkpoint_file = './models/sres/mnist_sres.ckpt'
model.load_test_model(checkpoint_file)
model.cpu();

In [None]:
num_rings = 5

# Integer frequency coordinates of the (presumably rfft half-plane) DFT grid:
# columns run 0..n_cols-1, rows are centred on zero. The floor division on the
# negated height keeps the number of rows equal to dft_shape[0].
n_rows, n_cols = model.dft_shape[0], model.dft_shape[1]
row_freqs = range(-n_rows // 2 + 1, n_rows // 2 + 1)
x, y = np.meshgrid(range(n_cols), row_freqs)

# Mask every coefficient whose rounded radius lies inside the first
# `num_rings` rings; the mask's size becomes the model's input sequence length
# (the "prefix" of presumably known low-frequency coefficients).
radii = np.sqrt(x ** 2 + y ** 2, dtype=np.float32)
selected_rings = np.round(radii) < num_rings
model.input_seq_length = selected_rings.sum()

plt.imshow(selected_rings)
plt.title('Prefix');

In [None]:
# Take the first batch of the test set. Fail loudly on an empty loader instead
# of leaving fc / mag_min / mag_max unbound for later cells.
first_batch = next(iter(dm.test_dataloader()), None)
if first_batch is None:
    raise RuntimeError('test dataloader yielded no batches')
fc, (mag_min, mag_max) = first_batch

In [None]:
# Run inference explicitly in eval mode (dropout off) and without autograd.
# This is harmless if get_lowres_pred_gt already does so internally, and it
# keeps predictions deterministic and memory-light if it does not.
model.eval()
with torch.no_grad():
    lowres, pred_img, gt = model.get_lowres_pred_gt(fc, mag_min, mag_max)

In [None]:
sample = 30  # index of the test-batch item to visualise

# Every panel is drawn on the ground-truth intensity range, so input,
# prediction and target are directly comparable.
vmin, vmax = gt[sample].min(), gt[sample].max()

fig = plt.figure(figsize=(31/2., 10/2.))
# Three image columns (width 10) separated by narrow spacer columns (width 0.5).
gs = gridspec.GridSpec(1, 5, width_ratios=[10, 0.5, 10, 0.5, 10])
ax0, ax1, ax2 = (plt.subplot(gs[col]) for col in (0, 2, 4))
plt.subplots_adjust(top=1, bottom=0, right=1, left=0,
                    hspace=0, wspace=0)

panels = (
    (ax0, lowres[sample], 'Low-Resolution Input'),
    (ax1, pred_img[sample], 'Prediction'),
    (ax2, gt[sample], 'Ground Truth'),
)
for ax, image, title in panels:
    # Hide ticks: these are images, not plots.
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())
    ax.imshow(image, cmap='gray', vmin=vmin, vmax=vmax)
    ax.set_title(title)
    ax.axis('equal')