In [1]:
import os

import numpy as np
import torch
from torch.utils.data import SubsetRandomSampler

from src.dataset import ImageDataset
from src.layers import DWT, IWT, FFTloss
from src.model import MWCNN
from src.training import Trainer
from src.utils import get_indices
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so SubsetRandomSampler draw order (and, presumably, the model's
# weight initialisation) is repeatable across Restart & Run All.
# NOTE(review): if src.utils.get_indices draws from `random` or numpy instead of torch,
# seed those generators here as well — TODO confirm.
SEED = 42
torch.manual_seed(SEED)

# Dataset part used for testing
TEST_SPLIT = 0.2
# Batch size for training. Limited by GPU memory
BATCH_SIZE = 6
# Dataset folder used
DATASET_USED = 'e9_5_GLM87a_cycle1_8_8'
# Folder that contains the dataset folders. Built with os.path.join instead of a literal
# backslash: 'data\mip2edof_2samples' only worked on Windows, and '\m' is an invalid
# string escape sequence.
DATA_ROOT = os.path.join('data', 'mip2edof_2samples')
# Full Dataset path
DATASET_PATH = os.path.join(DATA_ROOT, DATASET_USED)
# Training Epochs
EPOCHS = 20



## 1. Initialize the train and test loaders

In [2]:
# Both loaders wrap the same ImageDataset; the index lists returned by get_indices decide
# which samples each loader may draw (presumably disjoint — the split lives in src.utils).
image_dataset = ImageDataset(DATASET_PATH, DATASET_USED)

train_indices, test_indices = get_indices(
    len(image_dataset),
    image_dataset.root_dir,
    TEST_SPLIT,
)
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

# Training batches hold BATCH_SIZE samples; evaluation goes one sample at a time.
trainloader = torch.utils.data.DataLoader(
    image_dataset, batch_size=BATCH_SIZE, sampler=train_sampler
)
testloader = torch.utils.data.DataLoader(
    image_dataset, batch_size=1, sampler=test_sampler
)



## 1.1 Show example images

In [3]:
# Optional sanity check: plot every channel of one training batch, input on the left
# column and output_image on the right. Disabled (commented out); uncomment to run.
# NOTE(review): before re-enabling this cell:
#   - with BATCH_SIZE > 1, .squeeze() does not remove the batch dimension, so
#     .permute(1, 2, 0) fails on the resulting 4-D tensor; index the first sample
#     instead (e.g. data['input_image'][0]) — assumes batches are (N, C, H, W), TODO confirm;
#   - nrows=6 is hard-coded while the inner loop runs over input_image.shape[2]; the two
#     must agree, otherwise ax[i, ...] goes out of range or rows stay empty;
#   - vmax=16383 (= 2**14 - 1) presumably matches a 14-bit intensity range — confirm.
# import matplotlib.pyplot as plt

# fig, ax = plt.subplots(nrows=6, ncols=2, figsize=(20, 60))
# for data in trainloader:
#     input_image = data['input_image'].squeeze().permute(1, 2, 0)
#     output_image = data['output_image'].squeeze().permute(1, 2, 0)
#     for i in range(input_image.shape[2]):
#         ax[i,0].imshow(input_image[:,:,i], cmap='gray', vmin=0, vmax=16383, aspect='equal')
#         ax[i,1].imshow(output_image[:,:,i], cmap='gray', vmin=0, vmax=16383, aspect='equal')
#     break

# plt.tight_layout()

## 2. Initialize the model

In [4]:
# Wavelet transform modules (implemented in src.layers) handed to the network.
dwt, iwt = DWT(), IWT()

# NOTE(review): the meaning of the leading positional argument (2) is defined by
# src.model.MWCNN — TODO give it a named constant once documented.
MWCNN_model = MWCNN(2, dwt, iwt)
MWCNN_model = MWCNN_model.to(device)

## 3. Model Training and Testing

In [5]:
# Training. FFTloss (src.layers) is the criterion; Trainer (src.training) presumably runs
# EPOCHS epochs over trainloader — its log also shows a learning-rate scheduler stepping.
# NOTE(review): `mini_batch=100` is passed straight through; its meaning is defined in
# src.training — TODO document it or lift it into the config cell.
criterion = FFTloss()
MWCNN_trainer = Trainer(MWCNN_model, criterion, device)

loss_record = MWCNN_trainer.train(EPOCHS, trainloader, mini_batch=100)

# Plain string literal: the message has no placeholders, so no f-prefix is needed.
print('Training finished!')

Starting Training Process
Epoch: 001,  Loss:68333928.4444444,  Epoch: 002,  Loss:65867077.0370370,  Epoch: 003,  Loss:46292045.1851852,  Epoch: 004,  Loss:74616954.4444444,  Epoch: 005,  Loss:62482273.9259259,  Epoch 00006: reducing learning rate of group 0 to 8.5000e-04.
Epoch: 006,  Loss:61707977.3333333,  Epoch: 007,  Loss:74710519.4074074,  

KeyboardInterrupt: 

In [None]:
# Persist the trained network to 'model.pth'. The whole nn.Module object is pickled (not
# only its state_dict), so loading it back requires the src.model classes to be importable
# under the same module paths.
torch.save(obj=MWCNN_model, f='model.pth')

In [6]:
# Evaluate the trained MWCNN on the held-out split (testloader, one sample per batch).
# NOTE(review): the metric returned by Trainer.test is defined in src.training — TODO
# document what "Score" measures and whether higher is better.
mwcnn_score = MWCNN_trainer.test(testloader)
# `unet_score` was a misleading name for this MWCNN result; it is kept as an alias so any
# later cell that still reads it keeps working.
unet_score = mwcnn_score

print(f'Score {mwcnn_score}')

Score 27.188481788906746


## 4. Predict on a single test sample

In [7]:
# Run the trained model on the first held-out sample.
# Indexing the dataset directly bypasses the DataLoader's default_collate, which is what
# converts numpy arrays to tensors for trainloader/testloader. The previous run of this
# cell failed with "'numpy.ndarray' object has no attribute 'cpu'", so numpy fields are
# converted with torch.as_tensor (the same conversion default_collate applies).
# Assumes the dataset returns a dict, as the data['input_image'] access on loader
# batches suggests.
# NOTE(review): no batch dimension is added here — TODO confirm whether Trainer.predict
# expects a single sample or a batch of one.
image_index = test_indices[0]
raw_sample = image_dataset[image_index]
sample = {
    key: torch.as_tensor(value) if isinstance(value, np.ndarray) else value
    for key, value in raw_sample.items()
}
image, output, d_score = MWCNN_trainer.predict(sample)

AttributeError: 'numpy.ndarray' object has no attribute 'cpu'