In [None]:
import os,sys
import torch
import numpy as np
import chart_studio.plotly as py
import plotly
import plotly.express as px
import plotly.graph_objects as go
from plotly.offline import iplot, init_notebook_mode
# Enable plotly rendering in the notebook; connected=False makes plotly load plotly.js
# from the installed plotly package instead of an online CDN.
init_notebook_mode(connected=False)
# Automatically reload edited project modules (larmae_dataset, model, ...) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Run configuration for this notebook. It is written to tmp.yaml, which the dataloader
# and load_model cells below are given as their config path.
cfg = """# LARMAE Config
model:
  weight_by_model: True
  nonzero_pixel_threshold: -0.2
  checkpoint_file: "arxiv/fallen-snowflake-27/modelweights.larmae.102000th.tar"
#  checkpoint_file: "arxiv/silvery-cosmos-16/modelweights.larmae.10000th.tar"
larmaeDataset:
  filelist:
    - larmaedata_run3extbnb_0000.root
  crop_size: 512
  adc_threshold: 10.0
  min_crop_pixels: 1000
  vector_index: 0
  use_old_root2numpy: True
"""

# The file ends with one extra newline after the config text.
with open('tmp.yaml', 'w') as cfg_file:
    cfg_file.write(cfg + "\n")

In [None]:
# Project-local dataset class and multi-process dataloader.
from larmae_dataset import larmaeDataset
from larmae_mp_dataloader import larmaeMultiProcessDataloader

# Build the dataloader from the config written to tmp.yaml above.
# TODO(review): the positional `0, 1` arguments are not self-describing here — name or
# document them (check larmaeMultiProcessDataloader's signature).
loader = larmaeMultiProcessDataloader("tmp.yaml", 0, 1, num_workers=1)

In [None]:
from model import load_model

# Build the network from tmp.yaml. The config's model.checkpoint_file presumably supplies
# the weights, and strict=True presumably enforces exact state-dict key matching — confirm in
# model.load_model.
model = load_model("tmp.yaml", strict=True)

In [None]:
# load data
# Take the first batch from a fresh iterator over the loader; the iterator itself is not kept.
batch = next(iter(loader))
# `batch` is dict-like: list its keys, then show its "entry" field.
print("batch contents: ",batch.keys())
print(batch["entry"])

In [None]:
# Quick look at the loaded input image (batch entry 0, channel 0) before running the model.
# go.Heatmap already sets the trace type, so no explicit type='heatmap' is needed.
plot = go.Heatmap(z=batch["img"][0, 0, :, :], colorscale='Viridis')
fig = go.Figure(data=[plot])
fig.update_layout(title="Input image (batch entry 0, channel 0)")
fig.show()

In [None]:
# run model
imgs = batch['img']
# NOTE(review): the model is put in train() mode even though this is inference under
# no_grad, so any dropout / train-only behaviour stays active. This may be intentional
# (e.g. masking may only happen in train mode) — confirm before switching to eval().
model.train()
with torch.no_grad():
    outputs = model(imgs, return_outputs=True)
# With return_outputs=True the model returns four values: the loss plus the tensors used below.
maeloss, pred_masked, true_masked, masked_indices = outputs
print("maeloss: ", maeloss)

In [None]:
# Label each shape so the output is readable on its own; these are the tensors the
# patch-replacement cell below indexes into.
print("imgs:", imgs.shape)
print("pred_masked:", pred_masked.shape)
print("true_masked:", true_masked.shape)
print("masked_indices:", masked_indices.shape)

In [None]:
# Same import as in the dataloader cell (harmless duplicate).
from larmae_dataset import larmaeDataset

# Split the input image into flattened patches. PATCH_SIZE must match the `patch`
# argument unchunk() is called with below (its default is 16).
PATCH_SIZE = 16
img_chunks = larmaeDataset.chunk(imgs, PATCH_SIZE)
print(img_chunks.shape)



In [None]:
def unchunk(img_chunk, patch=16, img_size=512):
    """Reassemble flattened square patches (batch entry 0) into one 2D image.

    Chunk index ``ih * nh + iw`` (``nh = img_size // patch`` patches per side) is placed
    at patch-row ``ih``, patch-column ``iw``, and each chunk's ``patch * patch`` values
    are laid out row-major inside its patch.
    NOTE(review): this ordering is assumed to match ``larmaeDataset.chunk``; the
    round-trip check cell verifies it.

    Args:
        img_chunk: numpy array or CPU tensor indexed ``[batch, chunk, pixel]``; only
            batch entry 0 is used and it must hold ``nh * nh`` chunks of
            ``patch * patch`` values.
        patch: side length in pixels of each square patch.
        img_size: side length in pixels of the square output image; must be a
            multiple of ``patch``. Defaults to 512 (the crop_size in the config).

    Returns:
        A new float64 numpy array of shape ``(1, 1, img_size, img_size)``.

    Raises:
        ValueError: if ``img_size`` is not a positive multiple of ``patch``, or batch
            entry 0 does not have shape ``(nh * nh, patch * patch)``.
    """
    if patch <= 0 or img_size <= 0 or img_size % patch != 0:
        raise ValueError(f"img_size={img_size} is not a positive multiple of patch={patch}")
    nh = img_size // patch  # patches per side

    # np.array (not asarray) always copies, so the result never aliases the caller's data.
    chunks = np.array(img_chunk[0], dtype=np.float64)
    expected_shape = (nh * nh, patch * patch)
    if chunks.shape != expected_shape:
        raise ValueError(
            f"expected batch entry 0 of shape {expected_shape}, got {chunks.shape}"
        )

    # Index as (ih, iw, py, px), then reorder to (ih, py, iw, px) so that the output
    # row is ih*patch + py and the output column is iw*patch + px.
    image = (
        chunks.reshape(nh, nh, patch, patch)
        .transpose(0, 2, 1, 3)
        .reshape(img_size, img_size)
    )
    return image.reshape(1, 1, img_size, img_size)

In [None]:
# Round-trip check: unchunk(chunk(imgs)) should reproduce the input image exactly. If it
# does, unchunk() uses the same patch ordering as larmaeDataset.chunk.
# unchunk() rebuilds only batch entry 0 / channel 0, so compare against that slice alone;
# subtracting the whole batch would broadcast the reconstruction against every entry.
img_inv = unchunk(img_chunks)
test_unchunk = np.abs(np.asarray(imgs[0, 0], dtype=np.float64) - img_inv[0, 0]).sum()
print(test_unchunk)
assert np.isclose(test_unchunk, 0.0), (
    f"unchunk round-trip mismatch: sum |imgs - img_inv| = {test_unchunk}"
)

In [None]:
# Threshold the model output for the masked patches and write it into a copy of the
# input's chunk array, so the reconstruction can be un-chunked and plotted below.
threshold = 0.5  # in the model's pixel units; -0.2 was also tried
below_threshold_value = -0.4  # value assigned to pixels below `threshold`
fill_offset = 0.05  # added to every filled-in pixel (presumably a display offset — confirm)
use_true_masked = False  # debug switch: True fills in true_masked instead of the prediction
# NOTE(review): assumes the inverse normalization ADC = 50*x + 20 — confirm against the dataset.
print("apply threshold at ADC=",threshold*50+20.0)

# replace the masked patches
source = true_masked if use_true_masked else pred_masked
# Tensor.numpy() shares memory with the tensor, so copy before thresholding in place.
# Otherwise pred_masked itself is overwritten, and re-running this cell (e.g. with a
# different threshold) would start from already-thresholded values.
rescale = source.detach().cpu().numpy().copy()
rescale[rescale < threshold] = below_threshold_value
img_pred = np.copy(img_chunks)
img_pred[0, masked_indices, :] = rescale + fill_offset

In [None]:
# Turn the chunk array (with the masked patches replaced) back into a full image.
pred_inv = unchunk(img_chunk=img_pred)

In [None]:
# Plot the original input next to the reconstruction in which the masked patches were
# replaced by the thresholded prediction. To sanity-check unchunk() alone, plot img_inv
# instead of pred_inv.
comparison_panels = (
    ("Original image (batch entry 0, channel 0)", imgs[0, 0, :, :]),
    ("Reconstruction with predicted masked patches", pred_inv[0, 0, :, :]),
)
for title, image in comparison_panels:
    fig = go.Figure(data=[go.Heatmap(z=image, colorscale='Viridis')])
    fig.update_layout(title=title)
    fig.show()

In [None]:
# run this cell to close the data loader properly.
# NOTE(review): `del` only removes the `loader` name. Any cleanup (presumably shutting
# down the worker(s) requested with num_workers=1) happens only if this was the last
# reference, and only if larmaeMultiProcessDataloader's finalizer does it. Confirm that,
# or call an explicit close/shutdown method if the class provides one.
del loader