## Import Libraries

In [1]:
import sys
import os
sys.path.append('/home/michele.prencipe/tesi/transformer/swin2sr')

os.chdir('/home/michele.prencipe/tesi/transformer/swin2sr')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
from data_loader.read_mrc import read_mrc
from skimage import io, color
from utils.utils import set_global_seed
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io.image")


plt.ion()   # interactive mode
set_global_seed(42)

  warn(f"Failed to load image Python extension: {e}")


## Set Directories

In [2]:
from data_loader.biosr_dataset import BioSRDataLoader

# Define your working directory and data directory
work_dir = "."
tensorboard_log_dir = os.path.join(work_dir, "tensorboard_logs")
os.makedirs(tensorboard_log_dir, exist_ok=True)
data_dir = '/group/jug/ashesh/data/BioSR/'


## Create Model

### Swin2SR

In [3]:
from tests.training import Swin2SRModule
import json

model_directory = '/group/jug/Michele/training/2411/biosr/617/'
config_fpath = os.path.join(model_directory,'config.json')
with open(config_fpath,'rb') as f:
    config = json.load(f)

# Initialize the model
model = Swin2SRModule(config)

model.load_state_dict(torch.load(os.path.join(model_directory,'87d8vg2xswin2sr_best.ckpt'))['state_dict'])

AttributeError: 'dict' object has no attribute 'model'

## Stitching Predictions


In [None]:
data_type = 'biosr'#config.data.data_type
gauss_factor = 13600 #config.data.gaussian_factor
poisson_factor = 0# config.data.poisson_factor
noisy_data = True#config.data.noisy
patch_size = 256
tile_size = 128

data_shape = (5,1004,1004)#config.data.data_shape

In [None]:
from predtiler.dataset import get_tiling_dataset, get_tile_manager
from data_loader.biosr_dataloader import SplitDataset

manager = get_tile_manager(data_shape=data_shape, tile_shape=(1,tile_size,tile_size), 
                               patch_shape=(1,patch_size,patch_size))

    
dset_class = get_tiling_dataset(SplitDataset, manager)
dataset = dset_class(           data_type = data_type,
                                patch_size=patch_size,
                                transform=None,
                                noisy_data=noisy_data,
                                poisson_factor=poisson_factor, 
                                gaus_factor=gauss_factor,
                                mode = 'Test'
                                )
print(type(dataset))
test_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=12)
c1_min, c1_max, c2_min, c2_max = test_loader.dataset.get_normalization_params() #of the entire dataset

In [None]:
print(manager)

In [None]:
print(len(test_loader))

In [None]:
from predtiler.tile_stitcher import stitch_predictions

predictions = []
targets = []
model = model.eval()
model = model.cuda()
inputs = []

for inp, targ in test_loader:
    inp, targ = inp.cuda(), targ.cuda()
    if len(inp.shape) == 3:  
        inp = inp.unsqueeze(1)
    if len(targ.shape) == 3:
        targ = targ.unsqueeze(1)        
    pred = model(inp)
    
    
    pred[:,0,:,:] = pred[:,0,:,:]*(c1_max - c1_min) + c1_min
    pred[:,1,:,:] = pred[:,1,:,:]*(c2_max - c2_min) + c2_min
    
    predictions.append(pred.cpu().detach().numpy())
    
    
    targ[:,0,:,:] = targ[:,0,:,:]*(c1_max - c1_min) + c1_min
    targ[:,1,:,:] = targ[:,1,:,:]*(c2_max - c2_min) + c2_min
    
    inputs.append(inp.cpu().detach().numpy())
    targets.append(targ.cpu().detach().numpy())

inputs = np.concatenate(inputs, axis = 0)
predictions = np.concatenate(predictions, axis = 0) # shape: (number_of_patches, C, patch_size, patch_size)
stitched_pred = stitch_predictions(predictions, dataset.tile_manager)
targets = np.concatenate(targets, axis = 0)
print(targets.shape)
print(manager)
print(predictions.shape)
print(stitched_pred.shape)

In [None]:
from data_loader.biosr_no_patching import NoPatchingSplitDataset
dataset_no_patching = NoPatchingSplitDataset(
                              data_type=data_type,
                              transform=None,
                              noisy_data=noisy_data,
                              poisson_factor= poisson_factor, 
                              gaus_factor= gauss_factor, mode = 'Test')


dataloader = DataLoader(dataset_no_patching, batch_size=2, shuffle=False, num_workers=4)
c1_min, c1_max, c2_min, c2_max = test_loader.dataset.get_normalization_params() #of the entire dataset

In [None]:
from predtiler.tile_stitcher import stitch_predictions

predictions = []
targets = []
inputs = []

for inp, targ in dataloader:
    inp, targ = inp.cuda(), targ.cuda()
    if len(inp.shape) == 3:  
        inp = inp.unsqueeze(1)
    if len(targ.shape) == 3:
        targ = targ.unsqueeze(1)        
    targ[:,0,:,:] = targ[:,0,:,:]*(c1_max - c1_min) + c1_min
    targ[:,1,:,:] = targ[:,1,:,:]*(c2_max - c2_min) + c2_min
    
    targets.append(targ.cpu().detach().numpy())
    inputs.append(inp.cpu().detach().numpy())
    
inputs = np.concatenate(inputs, axis = 0)
targets = np.concatenate(targets, axis = 0)
print(targets.shape)

In [None]:
from collections import defaultdict
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from utils.util_calculate_psnr_ssim import calculate_psnr, calculate_ssim
from core.psnr import PSNR

psnr_arr = {0: [], 1: []}
for ch_idx in range(targets.shape[1]):
    if ch_idx == 0:
        data_range = c1_max - c1_min
        
    else: 
        data_range = c2_max - c2_min
    psnr_arr[ch_idx].append(PSNR(targets[:,ch_idx,:,:], stitched_pred[:,:,:,ch_idx], range_= data_range))

print(psnr_arr)
psnr_1 = np.mean(psnr_arr[0])
psnr_2 = np.mean(psnr_arr[1])
print("psnr channel 1:", np.mean(psnr_arr[0]))
print("psnr channel 2:", np.mean(psnr_arr[1]))
print(np.mean([psnr_1, psnr_2]))


In [None]:
# Crea la figura con 1 riga e 2 colonne
fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, figsize=(100, 100))  # 1 riga, 2 colonne

_, target = dataloader.dataset[0]


# Primo subplot
ax1.imshow(stitched_pred[0,:,:,0],vmin = targets[0,0,:,:].min())
ax1.set_title('Stitched Pred 1')

# Secondo subplot
ax2.imshow(stitched_pred[0,:,:,1],vmin = targets[0,1,:,:].min())
ax2.set_title('Stitched Pred 2')

# Secondo subplot
ax3.imshow(targets[0,0,:, :])
ax3.set_title('Target Channel 1')


# Secondo subplot
ax4.imshow(targets[0,1,:,:])
ax4.set_title('Target Channel 2')


# Secondo subplot
ax5.imshow(inputs[0,0, : , :])
ax5.set_title('Input')


# Mostra il grafico
plt.tight_layout()  # Adatta il layout per evitare sovrapposizioni
plt.show()

# 