In [1]:
import os

import torch
from torch.utils.data import SubsetRandomSampler, ConcatDataset

from src.dataset import ImageDataset
from src.utils import get_indices
from src.model import MWCNN
from src.layers import DWT, IWT, FFTloss
from src.training import Trainer
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Dataset part used for testing
TEST_SPLIT = 0.2
# Batch size for training. Limited by GPU memory
BATCH_SIZE = 6
# Dataset folder used
DATASETS = ['e9_5_GLM87a_cycle1_8_8', 'e12_5_slide7_round1_section1']
ROOTDIR = 'data\mip2edof_2samples'
# Training Epochs
EPOCHS = 100



## 1 Initiate train and test loader

In [2]:
image_dataset = ImageDataset(ROOTDIR, DATASETS)

train_indices, test_indices = get_indices(len(image_dataset), image_dataset.root_dir, TEST_SPLIT)
train_sampler, test_sampler = SubsetRandomSampler(train_indices), SubsetRandomSampler(test_indices)

trainloader = torch.utils.data.DataLoader(image_dataset, BATCH_SIZE, sampler=train_sampler)
testloader = torch.utils.data.DataLoader(image_dataset, 1, sampler=test_sampler)



## 1.1 Show example of images

In [3]:
# import matplotlib.pyplot as plt

# fig, ax = plt.subplots(nrows=6, ncols=2, figsize=(20, 60))
# for data in trainloader:
#     input_image = data['input_image'].squeeze().permute(1, 2, 0)
#     output_image = data['output_image'].squeeze().permute(1, 2, 0)
#     for i in range(input_image.shape[2]):
#         ax[i,0].imshow(input_image[:,:,i], cmap='gray', vmin=0, vmax=16383, aspect='equal')
#         ax[i,1].imshow(output_image[:,:,i], cmap='gray', vmin=0, vmax=16383, aspect='equal')
#     break

# plt.tight_layout()

## 2 Initiate model

In [4]:
dwt = DWT()
iwt = IWT()

MWCNN_model = MWCNN(16, dwt, iwt).to(device)

## 3. Model Training and Testing

In [5]:
# Training
criterion = torch.nn.L1Loss()
MWCNN_trainer = Trainer(MWCNN_model, criterion, device)

loss_record = MWCNN_trainer.train(EPOCHS,trainloader,mini_batch=5)

print(f'Training finished!')

Starting Training Process
Batch: 05,	Batch Loss: 103.6150218
Batch: 10,	Batch Loss: 111.1386963


KeyboardInterrupt: 

In [9]:
# save the model
torch.save(MWCNN_model, 'MWCNN_model_L1loss_16.pth')

In [4]:
MWCNN_model = torch.load('MWCNN_model_fft_16.pth')

In [8]:
# Testing process on test data.
mwcnn_psnr, mwcnn_mse = MWCNN_trainer.test(testloader)

print(f'PSNR: {mwcnn_psnr}, MSE: {mwcnn_mse}')

PSNR: 28.96020724376925, MSE: 2317197.6719563804


Load model

In [6]:
image_index = test_indices[0]
sample = image_dataset[image_index]
input_image, pred, output_image, score = MWCNN_trainer.predict(sample)

IndexError: too many indices for tensor of dimension 3