In [1]:
import random

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchsummary import summary

from src.nn.regression_dataset import RegressionDataset
from src.nn.create_data_loaders import create_data_loaders
from src.nn.cnn_regressor import CNNRegressor
from src.nn.training import training
from src.nn.plot_losses import plot_losses
from src.data.synthMRWregul import synthMRWregul
import src.ctes.num_ctes as nctes

In [2]:
# Length of each signal sample fed to the network (project-wide constant).
sample_size = nctes.LEN_SAMPLE

# Seed every RNG that can influence this run (e.g. weight initialisation and
# any shuffling / splitting done by the data loaders) so that
# Restart & Run All reproduces the same results.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr

In [3]:
class ToTensor:
    """Convert an array-like sample (e.g. a NumPy ndarray) to a float32 tensor.

    Parameters
    ----------
    device : torch.device, str or None, optional
        Device on which the tensor is created. ``None`` (the default) keeps the
        original behaviour: CUDA when it is available, otherwise the CPU.
        Pass ``device="cpu"`` to keep samples on the host, e.g. when the
        notebook moves batches to the GPU itself.
    """

    def __init__(self, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

    def __call__(self, sample):
        """Return ``sample`` as a new float32 tensor on ``self.device``."""
        # torch.tensor always copies, so the result never aliases the
        # dataset's underlying array. It is created on the target device in a
        # single call instead of the legacy FloatTensor(...).cuda() two-step.
        return torch.tensor(sample, dtype=torch.float32, device=self.device)

In [4]:
# Input dataset and checkpoint locations, relative to this notebook's directory.
data_path, model_path = "../../data/MRW.npz", "../../data/model.pt"

# ToTensor is handed to RegressionDataset as its per-sample transform.
transform = ToTensor()
data = RegressionDataset(data_path, transform, sample_size)

In [5]:
# Loader settings: batch size, then the validation and test fractions
# (presumably fractions of `data` held out from training — see create_data_loaders).
batch_size, valid_size, test_size = 128, 0.2, 0.2

train_loader, valid_loader, test_loader = create_data_loaders(
    batch_size, valid_size, test_size, data
)

In [6]:
# Prefer the GPU whenever PyTorch can see one; fall back to the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device {device}")

Using device cuda


In [7]:
# Build the regressor for inputs of length `sample_size` and move it to `device`.
model = CNNRegressor(input_size=sample_size)
model.to(device=device)

# summary() prints the layer table itself; its return value is also shown as
# the cell's last expression, which rendered the table twice in the saved
# output. The trailing ';' suppresses that second, duplicate copy.
summary(model, (1, sample_size));

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 16, 32768]           --
|    └─Conv1d: 2-1                       [-1, 16, 32768]           16
|    └─BatchNorm1d: 2-2                  [-1, 16, 32768]           32
|    └─ReLU: 2-3                         [-1, 16, 32768]           --
├─Sequential: 1-2                        [-1, 32, 32767]           --
|    └─Conv1d: 2-4                       [-1, 32, 32767]           1,024
|    └─BatchNorm1d: 2-5                  [-1, 32, 32767]           64
|    └─ReLU: 2-6                         [-1, 32, 32767]           --
├─AvgPool1d: 1-3                         [-1, 32, 16384]           --
├─Sequential: 1-4                        [-1, 64, 16381]           --
|    └─Conv1d: 2-7                       [-1, 64, 16381]           8,192
|    └─BatchNorm1d: 2-8                  [-1, 64, 16381]           128
|    └─ReLU: 2-9                         [-1, 64, 16381]           --
├─AvgPoo

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 16, 32768]           --
|    └─Conv1d: 2-1                       [-1, 16, 32768]           16
|    └─BatchNorm1d: 2-2                  [-1, 16, 32768]           32
|    └─ReLU: 2-3                         [-1, 16, 32768]           --
├─Sequential: 1-2                        [-1, 32, 32767]           --
|    └─Conv1d: 2-4                       [-1, 32, 32767]           1,024
|    └─BatchNorm1d: 2-5                  [-1, 32, 32767]           64
|    └─ReLU: 2-6                         [-1, 32, 32767]           --
├─AvgPool1d: 1-3                         [-1, 32, 16384]           --
├─Sequential: 1-4                        [-1, 64, 16381]           --
|    └─Conv1d: 2-7                       [-1, 64, 16381]           8,192
|    └─BatchNorm1d: 2-8                  [-1, 64, 16381]           128
|    └─ReLU: 2-9                         [-1, 64, 16381]           --
├─AvgPoo

In [8]:
# Mean-squared error averaged over all elements ("mean" is the default reduction).
# MSELoss holds no parameters or buffers, so .to() simply returns the same module.
criterion = nn.MSELoss(reduction="mean").to(device)

In [9]:
# Plain SGD over every model parameter (momentum and weight decay left at their defaults of 0).
lr = 0.01
params = model.parameters()  # one-shot generator; SGD consumes it right away
optimizer = torch.optim.SGD(params, lr=lr)

In [10]:
n_epochs = 2

# Run the training loop; it returns the training and validation loss histories.
# model_path is passed in, presumably so training() can save the checkpoint
# that the torch.load cell further down reloads.
# NOTE(review): the saved run of this cell was interrupted (KeyboardInterrupt),
# so train_losses / valid_losses were never assigned in that kernel.
train_losses, valid_losses = training(
    n_epochs,
    train_loader,
    valid_loader,
    model,
    criterion,
    optimizer,
    device,
    model_path,
)

 19%|██████████▉                                               | 6/32 [15:27<1:06:57, 154.51s/it]


KeyboardInterrupt: 

In [None]:
# Plot the training and validation loss curves. The first argument bundles
# n_epochs with the two histories; the second gives one legend label per history.
# NOTE(review): requires the training cell to have completed.
plot_losses(
    [n_epochs, train_losses, valid_losses],
    ["Train", "Val"],
)

In [None]:
# Restore the weights stored at model_path. map_location remaps the saved
# tensors onto `device`, so a checkpoint written on GPU also loads on a CPU-only machine.
model.load_state_dict(
    torch.load(model_path, map_location=device)
)

In [None]:
# TODO: Replace the Flatten layer with some ConvTranspose1d layers placed before the final dense (linear) layer.
# TODO: Predict the statistics on the test data:
    # - the test data can be accessed through `test_loader`;
    # - use the trained model to make the predictions.
# TODO: Evaluate the predicted statistics with the MSE or RMSE, computed for each statistic individually.
# TODO: Reconstruct the test samples by running `synthMRWregul` on the predicted statistics.
# TODO: Evaluate the reconstructions with the MSE or RMSE over the whole length of each original/reconstructed pair.
# TODO: For some test samples, plot the original and its reconstruction together on one figure per pair.