Run on the `pytorch-1.9.0` kernel.

# 1. Imports

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt
import numpy as np

from typing import Any, Callable, Sequence, Optional

import pathlib
import os

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


# 2. Load Data

In [2]:
class Img_Dataset(Dataset):
    """Dataset of paired images that serves one random square crop per access.

    Parameters
    ----------
    data_set : np.ndarray
        Array indexed as ``data_set[kind, sample, channel, row, col]``.
        ``kind`` 0 is returned as the image and ``kind`` 1 as the label;
        only channel 0 of each sample is read.
    patch_size : int
        Side length, in pixels, of the square patch cropped from each sample.
    seed : int, optional
        Seed of the generator that draws the crop positions.
    device : str or torch.device, optional
        Device on which the returned tensors are placed. When omitted, CUDA
        is used if available and the CPU otherwise.
    """

    def __init__(self, data_set, patch_size, seed=1234, device=None):
        self.data_set = data_set
        self.patch_size = patch_size
        self.seed = seed
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # A single generator for the whole dataset: re-seeding inside
        # __getitem__ would hand every sample the exact same crop position.
        self.rng = np.random.RandomState(seed)

    def __len__(self):
        """Number of image/label pairs (length of the sample axis)."""
        return len(self.data_set[0])

    def __getitem__(self, idx):
        """Return ``(image, label)`` float tensors of shape (1, patch, patch).

        Both tensors are cut from the same randomly drawn window so the pair
        stays pixel-aligned.

        Raises
        ------
        ValueError
            If ``patch_size`` exceeds the height or width of the sample.
        """
        image = self.data_set[0, idx][0]
        label = self.data_set[1, idx][0]

        patch_size = self.patch_size
        # Read the extent from the sample itself instead of assuming 2000x2000.
        img_height, img_width = image.shape[0], image.shape[1]
        if patch_size > img_height or patch_size > img_width:
            raise ValueError(
                f"patch_size {patch_size} does not fit in an image of shape "
                f"({img_height}, {img_width})"
            )

        # randint's upper bound is exclusive; +1 keeps the last valid offset
        # reachable and makes a full-size patch legal.
        x1 = self.rng.randint(img_width - patch_size + 1)
        y1 = self.rng.randint(img_height - patch_size + 1)
        window = (slice(y1, y1 + patch_size), slice(x1, x1 + patch_size))

        # Re-add the channel axis expected by Conv2d: (1, patch, patch).
        image_patch = image[window][np.newaxis, :, :]
        label_patch = label[window][np.newaxis, :, :]

        # .to() accepts CPU and CUDA targets alike; .cuda("cpu") raises.
        image_tensor = torch.from_numpy(image_patch).float().to(self.device)
        label_tensor = torch.from_numpy(label_patch).float().to(self.device)

        return image_tensor, label_tensor

In [3]:
# The data directory is resolved relative to the notebook's working directory.
current_dir = pathlib.Path().resolve()
PLANE_data_path = current_dir / 'Data'
assert PLANE_data_path.exists()

def load(name):
    """Load the ``.npy`` file called ``name`` from ``PLANE_data_path``."""
    return np.load(PLANE_data_path / name)

training_data = load('training_data610-2000.npy')
print(training_data.shape)
# np.asarray hands back the loaded ndarray as-is; np.array would allocate a
# second full copy of the whole array.
training_data = np.asarray(training_data)
print(training_data.shape)


test_data = load('test_data200-2000.npy')
test_data = np.asarray(test_data)

# training_data = load('training_data200-6000.npy')
# training_data = np.array(training_data)

# test_data = load('test_data70-6000.npy')
# test_data = np.array(test_data)

# training_data = load('training_data8000-200.npy')
# training_data = np.array(training_data)

# test_data = load('test_data2000-200.npy')
# test_data = np.array(test_data)

(2, 610, 1, 2000, 2000)
(2, 610, 1, 2000, 2000)


In [4]:
# Wrap the training array so each access yields a 50x50 patch pair, then
# batch it 32 at a time in shuffled order.
train_dataset = Img_Dataset(training_data, patch_size=50)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Draw a single batch and print its shape as a sanity check.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")


# Same construction and shape check for the test array.
test_dataset = Img_Dataset(test_data, patch_size=50)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

test_features, test_labels = next(iter(test_dataloader))
print(f"Feature batch shape: {test_features.size()}")
print(f"Labels batch shape: {test_labels.size()}")

Feature batch shape: torch.Size([32, 1, 50, 50])
Labels batch shape: torch.Size([32, 1, 50, 50])
Feature batch shape: torch.Size([32, 1, 50, 50])
Labels batch shape: torch.Size([32, 1, 50, 50])


## 2.1 Plotting Functions

In [5]:
def single_show(sample):
    """
    Display one sample on a single, axis-free panel titled 'Sample'.

    Parameters
    ----------
    sample : object with a ``.cpu()`` method (e.g. a torch tensor)
             Image to draw; it is moved to host memory before plotting
    """
    _, panel = plt.subplots(1, 1, figsize=(8,4))
    host_sample = sample.cpu()
    panel.imshow(host_sample, origin='lower', interpolation='none')
    panel.axis('off')
    panel.set_title('Sample')
    

def pair_show(igen):
    """
    Draw sample ``igen`` of the module-level array ``D`` next to its target.

    The left panel shows ``D[0, igen]`` and the right panel ``D[1, igen]``.
    ``D`` is read from the notebook namespace and must exist before the call.

    Parameters
    ----------
    igen : int
           Index of specific sample of interest
    """
    panels = (
        (D[0, igen], 'Training Sample'),
        (D[1, igen], 'Target Sample'),
    )
    _, axes = plt.subplots(1, 2, figsize=(8,4))
    for axis, (img, title) in zip(axes, panels):
        axis.imshow(img, origin='lower', interpolation='none')
        axis.axis('off')
        axis.set_title(title)
    
def pair_show_model(igen):
    """
    Draw sample ``igen`` of ``D`` next to the matching entry of ``v``.

    The left panel shows ``D[0, igen]`` and the right panel ``v[0][igen]``.
    Both ``D`` and ``v`` are read from the notebook namespace and must exist
    before the call.

    Parameters
    ----------
    igen : int
           Index of specific sample of interest
    """
    panels = (
        (D[0, igen], 'Training Sample'),
        (v[0][igen], 'Denoise Training Sample'),
    )
    _, axes = plt.subplots(1, 2, figsize=(8,4))
    for axis, (img, title) in zip(axes, panels):
        axis.imshow(img, origin='lower', interpolation='none')
        axis.axis('off')
        axis.set_title(title)
    
    
def dn_process_show(igen):
    """
    Draw a sample, its residual noise, and their difference side by side.

    The sample is ``D[0, igen]``, the residual is ``v[0][igen]`` and the third
    panel shows ``sample - residual``. ``D`` and ``v`` are read from the
    notebook namespace and must exist before the call.

    Parameters
    ----------
    igen : int
           Index of specific sample of interest
    """
    noisy = D[0, igen]
    residual = v[0][igen]
    # The denoised estimate is the input with the residual taken out.
    denoised = noisy - residual

    panels = (
        (noisy, 'Training Sample'),
        (residual, 'Residual Noise Sample'),
        (denoised, 'Denoised Training Sample'),
    )
    _, axes = plt.subplots(1, 3, figsize=(14,10))
    for axis, (img, title) in zip(axes, panels):
        axis.imshow(img, origin='lower', interpolation='none')
        axis.axis('off')
        axis.set_title(title)
    
    
def batch_show(batch, igen):
    """
    Draw entry ``igen`` of a (noisy, clean) batch pair side by side.

    Parameters
    ----------
    batch : sequence
            ``batch[0]`` holds the noisy images and ``batch[1]`` the clean ones
    igen : int
           Index of specific sample of interest within the batch
    """
    noisy_images, clean_images = batch[0], batch[1]
    panels = (
        (noisy_images[igen], 'Training Sample'),
        (clean_images[igen], 'Target Sample'),
    )
    _, axes = plt.subplots(1, 2, figsize=(8,4))
    for axis, (img, title) in zip(axes, panels):
        axis.imshow(img, origin='lower', interpolation='none')
        axis.axis('off')
        axis.set_title(title)

def pair_show_DL(dataset, igen):
    """
    Draw the noisy and clean versions of one dataset sample side by side.

    Parameters
    ----------
    dataset : array indexable as ``dataset[kind, sample, channel]``
           ``kind`` 0 is plotted as the noisy image and ``kind`` 1 as the
           clean image; channel 0 is shown
    igen : int
           Index of specific sample of interest
    """
    panels = (
        (dataset[0, igen, 0], 'Noisy Sample'),
        (dataset[1, igen, 0], 'Clean Sample'),
    )
    _, axes = plt.subplots(1, 2, figsize=(8,4))
    for axis, (img, title) in zip(axes, panels):
        axis.imshow(img, origin='lower', interpolation='none')
        axis.axis('off')
        axis.set_title(title)

# 3. Define Network

In [6]:
# Define model
class DnCNN(nn.Module):
    """Single-channel convolutional denoiser with a residual output.

    The stack is: one Conv+ReLU block, ``num_layers - 2`` Conv+BatchNorm+ReLU
    blocks, and a final Conv back to one channel. All convolutions are 3x3
    with padding 1, so the spatial size is preserved. ``forward`` subtracts
    the stack's output from its input.

    Parameters
    ----------
    num_layers : int
        Total number of convolutional layers.
    num_features : int
        Channel count of every hidden feature map.
    """

    def __init__(self, num_layers=17, num_features=64):
        super().__init__()
        head = nn.Sequential(
            nn.Conv2d(1, num_features, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        # Built left to right, in the same order the stack is applied.
        hidden = [self._hidden_block(num_features) for _ in range(num_layers - 2)]
        tail = nn.Conv2d(num_features, 1, kernel_size=3, padding=1)
        self.layers = nn.Sequential(head, *hidden, tail)

        self._initialize_weights()

    @staticmethod
    def _hidden_block(num_features):
        """Build one Conv -> BatchNorm -> ReLU block with constant width."""
        return nn.Sequential(
            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features),
            nn.ReLU(inplace=True),
        )

    def _initialize_weights(self):
        """Kaiming-normal conv weights; batch-norm scale 1 and shift 0."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)

    def forward(self, inputs):
        """Return ``inputs`` minus the residual predicted by the layer stack."""
        return inputs - self.layers(inputs)

In [7]:
# Build the network with its default depth and width and move it to `device`.
model = DnCNN().to(device)
#output = model(train_features)
# v = output.cpu().detach().numpy()# output = model(tensor)
# plt.imshow(v[0][0])
# print(output.type)

# 4. Optimize Model Parameters

In [8]:
# Summed (not averaged) squared error; the training loop divides it by
# twice the batch size.
loss_fn = nn.MSELoss(reduction='sum')
# Adam over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)#, momentum=0.9)

Training loop

In [9]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Parameters
    ----------
    dataloader : torch.utils.data.DataLoader
        Yields ``(X, y)`` batches; ``dataloader.dataset`` must support ``len``.
    model : torch.nn.Module
        Network to train; it is switched to training mode.
    loss_fn : callable
        Maps ``(prediction, target)`` to a scalar tensor. The result is
        divided by ``2 * batch_size`` before back-propagation.
    optimizer : torch.optim.Optimizer
        Optimizer that updates the parameters of ``model``.

    Prints the scaled loss every 100 batches. Returns nothing.
    """
    size = len(dataloader.dataset)
    # Batches follow the model's own device. This works on CPU and GPU alike,
    # whereas Tensor.cuda() rejects a CPU target.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        if target_device is not None:
            X, y = X.to(target_device), y.to(target_device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y) / (2 * len(X))

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            current = batch * len(X)
            # .item() logs a plain float instead of a graph-attached tensor.
            print(f"loss: {loss.item():>7f} [{current:>5d}/{size:>5d}]")

In [10]:
# Number of full passes over the training dataloader.
epochs = 200
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    # One pass over every batch; train() prints the loss every 100 batches.
    train(train_dataloader, model, loss_fn, optimizer)
print("Done!")

Epoch 1
-------------------------------
loss: 779.220215 [    0/  610]
Epoch 2
-------------------------------
loss: 6.191767 [    0/  610]
Epoch 3
-------------------------------
loss: 1.975497 [    0/  610]
Epoch 4
-------------------------------
loss: 1.399537 [    0/  610]
Epoch 5
-------------------------------
loss: 1.187336 [    0/  610]
Epoch 6
-------------------------------
loss: 1.239075 [    0/  610]
Epoch 7
-------------------------------
loss: 0.992081 [    0/  610]
Epoch 8
-------------------------------
loss: 0.994866 [    0/  610]
Epoch 9
-------------------------------
loss: 0.881231 [    0/  610]
Epoch 10
-------------------------------
loss: 0.904989 [    0/  610]
Epoch 11
-------------------------------
loss: 0.917664 [    0/  610]
Epoch 12
-------------------------------
loss: 0.794490 [    0/  610]
Epoch 13
-------------------------------
loss: 0.856810 [    0/  610]
Epoch 14
-------------------------------
loss: 0.860271 [    0/  610]
Epoch 15
------------------

In [11]:
# Saving Models
# Only the state_dict (the weights) is written, into ./Model_params, which
# must already exist.
current_dir = pathlib.Path().resolve()
model_params_path = current_dir / 'Model_params'
assert model_params_path.exists()
name = "2k_model_bs32_e200.pth"
path = model_params_path / name
torch.save(model.state_dict(), path)
# Report the file that was actually written, not a fixed "model.pth".
print(f"Saved PyTorch Model State to {path}")

Saved PyTorch Model State to model.pth


In [12]:
# print(test_data[0][0].shape)
# print(test_data[1][0].shape)
# print(test_data.type)

# test_noise, test_clean = next(iter(test_dataloader))
# test_noise.shape

In [13]:
# NOTE(review): `break` outside a loop is a SyntaxError (see the recorded
# output below); presumably a deliberate stop so "Run All" halts here - confirm.
break 

SyntaxError: 'break' outside loop (371449011.py, line 1)

In [None]:
# def show_model_flow(model, model_pth, dataloader, idx):
def show_model_flow(model, model_path, dataset, idx):
    """
    Plot the network input, network output and ground truth for one sample.

    Parameters
    ----------
    model : callable
           Model class (e.g. ``DnCNN``); called with no arguments to build
           the network
    model_path : str or path-like
           File name of a saved ``state_dict`` inside ``./Model_params``
    dataset : np.ndarray
           ``dataset[0][i]`` is the i-th input and ``dataset[1][i]`` the
           matching truth; each sample carries a leading channel axis
    idx : int
           Index of the sample of interest
    """
    current_dir = pathlib.Path().resolve()
    model_params_path = current_dir / 'Model_params'
    assert model_params_path.exists()
    weights_file = model_params_path / model_path

    net = model()
    net.to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    net.load_state_dict(torch.load(str(weights_file), map_location=device))
    net.eval()
    with torch.no_grad():
        # Only the requested sample goes through the network. Sending the
        # first 200 full-size images at once costs far more memory and, in
        # eval mode, gives the same result for this sample.
        test_noise = torch.Tensor(dataset[0][idx]).unsqueeze(0).to(device)
        test_clean = torch.Tensor(dataset[1][idx])

        output = net(test_noise)

        # Drop the batch axis; channel 0 is what gets drawn.
        noise_img = test_noise.cpu().numpy()[0]
        output_img = output.cpu().numpy()[0]
        clean_img = test_clean.numpy()

        fig, ax = plt.subplots(1, 3, figsize=(16,12))
        panels = (
            (noise_img[0], 'Input'),
            (output_img[0], 'Output'),
            (clean_img[0], 'Truth'),
        )
        for axis, (img, title) in zip(ax, panels):
            axis.imshow(img, origin='lower', interpolation='none')
            axis.axis('off')
            axis.set_title(title)

In [None]:
# Samples 0-2 are rendered with each of the three checkpoints in turn
# (5, 10 and 50 epochs); samples 3-20 use the 50-epoch checkpoint only.
for sample_idx in range(3):
    for weights_name in ("model_bs64_e5.pth", "model_bs64_e10.pth", "model_bs64_e50.pth"):
        show_model_flow(DnCNN, weights_name, test_data, sample_idx)

for sample_idx in range(3, 21):
    show_model_flow(DnCNN, "model_bs64_e50.pth", test_data, sample_idx)

In [None]:
# # def show_model_flow(model, model_pth, dataloader, idx):
# def show_model_flow_comp(model, model_pth, dataset, idx):    
#     """
#     Function to display a comparison of the input, residual, and truth image for a 
#     specified sample from the DnCNN trained model of choice
    
#     Parameters:
#     -----------
#     model: torch.nn.module
#            The DnCNN-B architecture as described in Zheng et al.
#     model_pth: .pth file 
#            A file that contains the weights of a trained DnCNN model.
#     dataset: np.ndarray
#            Numpy array of pairs of noisy/clean images
#     idx: int
#            Index of the sample of interest
#     """
#     # Loading model & model weights 
#     model = model()
#     model.to(device)
#     model.load_state_dict(torch.load(str(model_pth)))
#     model.eval();
#     # telling pytorch this is for inference and not learning, so keeps the weights unchanged
#     with torch.no_grad():
                          
#         # Load noisy images to GPU
#         test_noise = torch.Tensor(dataset[0][:200])
#         test_noise = test_noise.to(device)
        
#         # Load clean images to GPU
#         test_clean = torch.Tensor(dataset[1][:200])
#         test_clean = test_clean.to(device)

#         # Obtain output of DnCNN model & convert from GPU tensor to np.array
#         output = model(test_noise)
#         resid_img = output.cpu().detach().numpy()
        
#         # Convert data from GPU tensor to np.array & obtain scaling for plots
#         test_noise = test_noise.cpu().detach().numpy()
#         vmin, vmax = np.percentile(test_noise[idx], (1,99))
#         test_clean = test_clean.cpu().detach().numpy()
        

#         fix, ax = plt.subplots(1, 3, figsize=(16,12))
#         ax[0].imshow(test_noise[idx][0], vmin=vmin, vmax=vmax, origin='lower', interpolation='none')
#         ax[0].axis('off')
#         ax[0].set_title('Input')
#         ax[1].imshow(resid_img[idx][0], vmin=vmin, vmax=vmax, origin='lower', interpolation='none')
#         ax[1].axis('off')
#         ax[1].set_title('Output')
#         ax[2].imshow(test_clean[idx][0], vmin=vmin, vmax=vmax, origin='lower', interpolation='none')
#         ax[2].axis('off')
#         ax[2].set_title('Truth')

In [None]:
# def show_model_flow(model, model_pth, dataloader, idx):
def show_model_flow_comp(model, model_path, dataset, idx):
    """
    Plot the network input, network output and truth for one sample, all
    drawn on the colour scale of the input image.

    The scale runs from the 1st to the 99th percentile of the input sample.

    Parameters:
    -----------
    model: callable
           Model class (e.g. ``DnCNN``); called with no arguments to build
           the network
    model_path: str or path-like
           File name of a saved ``state_dict`` inside ``./Model_params``
    dataset: np.ndarray
           Numpy array of pairs of noisy/clean images; ``dataset[0][i]`` is
           the i-th noisy image and ``dataset[1][i]`` the matching clean one
    idx: int
           Index of the sample of interest
    """
    # Loading model & model weights
    current_dir = pathlib.Path().resolve()
    model_params_path = current_dir / 'Model_params'
    assert model_params_path.exists()
    weights_file = model_params_path / model_path

    net = model()
    net.to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    net.load_state_dict(torch.load(str(weights_file), map_location=device))
    net.eval()
    # Inference only: no gradients are tracked inside this context.
    with torch.no_grad():
        # Only the requested sample goes through the network. Sending the
        # first 200 full-size images at once costs far more memory and, in
        # eval mode, gives the same result for this sample.
        test_noise = torch.Tensor(dataset[0][idx]).unsqueeze(0).to(device)
        test_clean = torch.Tensor(dataset[1][idx])

        output = net(test_noise)

        # Back to numpy, dropping the batch axis.
        noise_img = test_noise.cpu().numpy()[0]
        output_img = output.cpu().numpy()[0]
        clean_img = test_clean.numpy()

        # Shared colour limits taken from the input sample.
        vmin, vmax = np.percentile(noise_img, (1,99))

        fig, ax = plt.subplots(1, 3, figsize=(16,12))
        panels = (
            (noise_img[0], 'Input'),
            (output_img[0], 'Output'),
            (clean_img[0], 'Truth'),
        )
        for axis, (img, title) in zip(ax, panels):
            axis.imshow(img, vmin=vmin, vmax=vmax, origin='lower', interpolation='none')
            axis.axis('off')
            axis.set_title(title)

In [None]:
# Sample 20 drawn first on the shared colour scale, then with each panel
# auto-scaled.
for plot_fn in (show_model_flow_comp, show_model_flow):
    plot_fn(DnCNN, "model_bs64_e50.pth", test_data, 20)

Plots of comparisons where the `vmin` & `vmax` of the input image are used for all 3 plots

In [None]:
# Samples 30-41, each rendered with the 50-epoch checkpoint.
for sample_idx in range(30, 42):
    show_model_flow_comp(DnCNN, "model_bs64_e50.pth", test_data, sample_idx)

In [None]:
print(test_data.shape)

# 1st/99th percentiles: first over every noisy and every clean image, then
# over the first noisy and first clean image alone.
for subset in (test_data[0], test_data[1], test_data[0][0], test_data[1][0]):
    print(np.percentile(subset, (1, 99)))

# Using dataloader that makes the test images the 50x50 patches

In [None]:
# NOTE(review): `break` outside a loop is a SyntaxError; presumably a
# deliberate stop so "Run All" halts before the cells below - confirm.
break 

In [None]:
def show_model_flow(model, model_pth, dataloader, idx):
    """
    Plot input, network output and truth for one entry of a dataloader batch.

    This reuses the name of the array-based ``show_model_flow`` defined
    earlier in the notebook, so running this cell replaces that version.

    Parameters
    ----------
    model : callable
           Model class (e.g. ``DnCNN``); called with no arguments to build
           the network
    model_pth : str or path-like
           Path of a saved ``state_dict``, passed to ``torch.load`` as given
    dataloader : iterable
           Yields ``(noisy, clean)`` tensor batches; one batch is drawn per call
    idx : int
           Position of the sample of interest inside the drawn batch
    """
    net = model()
    net.to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    net.load_state_dict(torch.load(str(model_pth), map_location=device))
    net.eval()
    with torch.no_grad():
        test_noise, test_clean = next(iter(dataloader))
        # The batch must sit on the same device as the network.
        test_noise = test_noise.to(device)

        output = net(test_noise)

        output_img = output.cpu().numpy()
        noise_img = test_noise.cpu().numpy()
        clean_img = test_clean.cpu().numpy()

        fig, ax = plt.subplots(1, 3, figsize=(16,12))
        panels = (
            (noise_img[idx][0], 'Input'),
            (output_img[idx][0], 'Output'),
            (clean_img[idx][0], 'Truth'),
        )
        for axis, (img, title) in zip(ax, panels):
            axis.imshow(img, origin='lower', interpolation='none')
            axis.axis('off')
            axis.set_title(title)

In [None]:
# Batch positions to inspect with the 5-epoch checkpoint, in viewing order.
for batch_pos in (0, 10, 20, 25, 2, 12, 16, 31, 28, 29, 3, 4, 8, 9):
    show_model_flow(DnCNN, "model_bs32_e5.pth", test_dataloader, batch_pos)

In [None]:
# # Loading Models
# model = DnCNN()
# model.to(device)
# model.load_state_dict(torch.load("model.pth"))
# model.eval();
# with torch.no_grad():
#     test_noise, test_clean = next(iter(test_dataloader))
#     #test_features = test_features.to(device)
#     #test_labels = test_labels.cpu().detach().numpy()

#     output = model(test_noise)
#     v = output.cpu().detach().numpy()
    
#     test_noise = test_noise.cpu().detach().numpy()
#     test_clean = test_clean.cpu().detach().numpy()
    
#     fix, ax = plt.subplots(1, 3, figsize=(16,12))
#     ax[0].imshow(test_noise[0][0], origin='lower', interpolation='none')
#     ax[0].axis('off')
#     ax[0].set_title('Input Noisy Test Sample')
#     ax[1].imshow(v[0][0], origin='lower', interpolation='none')
#     ax[1].axis('off')
#     ax[1].set_title('Denoised Test Sample')
#     ax[2].imshow(test_clean[0][0], origin='lower', interpolation='none')
#     ax[2].axis('off')
#     ax[2].set_title('Coadded Test Sample')
#     #print(output)

In [None]:
# Manual walk-through of the patch-cropping steps for one training sample.
data = training_data
idx = 10
image = data[0, idx]
label = data[1, idx]

patch_size = 50
seed = 123
rng = np.random.RandomState(seed)


# Crop origins are drawn as if the image were 200x200.
img_width = 200
img_height = 200

#randomly crop patch from training set
x1 = rng.randint(img_width - patch_size)
y1 = rng.randint(img_height - patch_size)
S = (slice(y1, y1 + patch_size), slice(x1, x1 + patch_size))

# create new arrays for training patches
image_patch = image[0][S]
label_patch = label[0][S]

# .to() accepts both CPU and CUDA targets; .cuda() rejects "cpu".
image = torch.from_numpy(image_patch).float().to(device)
label = torch.from_numpy(label_patch).float().to(device)