In [None]:
%matplotlib notebook
from matplotlib import pyplot as plt

import os

import numpy as np

import torch
import torch.optim as optim

from myutils import *

In [None]:
# Pick the compute device: the GPU if PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

In [None]:
# Experiment configuration: image size, dataset size, augmentations, loader settings.

# Seed the global RNGs so weight initialisation and (presumably) the random
# training/test data are reproducible across Restart & Run All.
# NOTE(review): confirm RandomDataset draws from numpy/torch global RNGs.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

nrows = 128
ncols = 128

nsamples = 256

transforms = [
    add_randn_noise
]

params = {'batch_size': 64,
          'shuffle': True,
          # Number of CPU worker processes used for data loading. Values > 0 can
          # speed up loading, but RandomDataset is given `device`; if it creates
          # CUDA tensors inside worker processes, CUDA initialisation may fail,
          # so keep 0 in that case (NOTE(review): confirm against RandomDataset).
          'num_workers': 0}

In [None]:
def _make_loader(dataset):
    """Wrap `dataset` in a DataLoader configured by the shared `params` dict."""
    return torch.utils.data.DataLoader(dataset, **params)


# Independent random datasets for training and evaluation, built identically.
train_set = RandomDataset(nrows, ncols, nsamples, transforms, device)
train_dataloader = _make_loader(train_set)

test_set = RandomDataset(nrows, ncols, nsamples, transforms, device)
test_dataloader = _make_loader(test_set)

In [None]:
# Sanity-check the data pipeline: show one training sample next to its label.
data, labels = next(iter(train_set))

plt.figure(figsize=(9, 6))
for pos, (img, title) in enumerate([(data, 'data'), (labels, 'labels')], start=1):
    plt.subplot(1, 2, pos)
    # matplotlib can only read CPU tensors without grad. The dataset was built
    # with `device`, so the sample may live on the GPU.
    plt.imshow(torch.squeeze(img).detach().cpu(), aspect='auto')
    plt.title(title)
plt.show()

In [None]:
# Build the network, move it to the selected device, and set up loss/optimiser.
cnn_channels = [1, 8, 16, 8, 1]  # channel counts; presumably in -> hidden... -> out (TODO confirm in CNN)
kernel_sizes = [3, 3, 3, 3]      # len(cnn_channels) - 1 entries; presumably one per conv layer

model = CNN(cnn_channels, kernel_sizes)
# Alternative architecture:
# model = UNet(cnn_channels, kernel_sizes, bilinear=False, double_conv=True)

model.to(device)

# Reference torch.nn explicitly instead of relying on `nn` leaking out of
# `from myutils import *`.
criterion = torch.nn.MSELoss()

# Adam with PyTorch's default hyper-parameters (lr=1e-3).
# Alternative: optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(model.parameters())

In [None]:
# Train with early stopping: checkpoint the model whenever the test loss
# improves, and stop once it has failed to improve for `patience` epochs.
epochs = 10
patience = 5
min_loss = float('inf')
# Initialise before the loop: previously it was only set on improvement, so a
# non-improving first epoch (e.g. NaN loss) raised NameError in the else-branch.
n_not_improved = 0
checkpoint_dir = 'models'
checkpoint_path = os.path.join(checkpoint_dir, 'model_checkpoint.pt')

for t in range(epochs):
    # Re-enter training mode each epoch in case the evaluation helper switched
    # the model to eval mode (NOTE(review): confirm what test_loop does).
    model.train()
    train_loss = train_loop(train_dataloader, model, criterion, optimizer)
    test_loss = test_loop(test_dataloader, model, criterion)

    print('Epoch %i Train error: %0.4f Test error: %0.4f' % (t+1, train_loss, test_loss), end='\r')

    if test_loss < min_loss:
        n_not_improved = 0
        print("\nTest error decreased: saving model")
        min_loss = test_loss
        os.makedirs(checkpoint_dir, exist_ok=True)
        # Saves the whole pickled module; the loading cell relies on this format.
        torch.save(model, checkpoint_path)
    else:
        n_not_improved += 1
        # Stop after exactly `patience` non-improving epochs (was off by one).
        if n_not_improved >= patience:
            print('\nTest error has not decreased for %i iterations. Returning' % n_not_improved)
            break

# min_loss is the best test loss seen, not the last epoch's loss.
print('\nFinished Training. Best test error: %0.4f' % min_loss)

In [None]:
# Reload the best checkpoint written by the training loop (lowest test loss)
# and map its tensors onto the current device.
checkpoint_path = os.path.join('models', 'model_checkpoint.pt')
try:
    # The checkpoint is a fully pickled module, so weights_only must be False
    # (it defaults to True from PyTorch 2.6). Unpickling can execute arbitrary
    # code: only load checkpoints you created yourself.
    model = torch.load(checkpoint_path, map_location=device, weights_only=False)
except TypeError:
    # PyTorch < 1.13 does not accept the weights_only argument.
    model = torch.load(checkpoint_path, map_location=device)

In [None]:
# Run the reloaded model on one sample from the held-out test set
# (previously drawn from train_set despite the test_* names).
test_in, test_labels = next(iter(test_set))

model.eval()
with torch.no_grad():
    # Add a leading batch dimension and make sure the input is on the model's device.
    test_out = model(test_in[None, :].to(device))

In [None]:
# Side-by-side view of network input, target label, and network output.
panels = [(test_in, 'input'), (test_labels, 'label'), (test_out, 'output')]

plt.figure(figsize=(9, 6))
for pos, (img, title) in enumerate(panels, start=1):
    plt.subplot(1, 3, pos)
    # test_out lives on the model's device; matplotlib needs CPU tensors without grad.
    plt.imshow(torch.squeeze(img).detach().cpu(), aspect='auto')
    plt.title(title)
plt.show()

In [None]:
# Compare one column of the input, the network output, and the label.
trace_col = 20  # arbitrary column to inspect

plt.figure(figsize=(9, 6))
for tensor in (test_in, test_out, test_labels):
    # Move to CPU (and drop any grad) so matplotlib can read the values.
    plt.plot(torch.squeeze(tensor).detach().cpu()[:, trace_col])
plt.xlabel('row index')
plt.legend(['in', 'out', 'labels'])
plt.show()

In [None]:
# Load the sample field data as a float tensor and display it with every row
# divided by that row's standard deviation (transposed, so rows run along x).
data = torch.tensor(
    np.load('../Sample Data/Stryde_input_data_csg1.npy'),
    dtype=torch.float,
    requires_grad=False,
)

row_std = torch.std(data, dim=1)
plt.figure(figsize=(9, 6))
plt.imshow(data.T / row_std, clim=[-.1, .1], aspect='auto', cmap='gray')
plt.colorbar()
plt.show()

In [None]:
# Apply the trained model to the real data. Add batch and channel dimensions and
# run on the model's device: `data` was created on the CPU, so passing it as-is
# fails when the model is on a GPU. The squeezed result goes back to the CPU for plotting.
model.eval()
with torch.no_grad():
    test_out = torch.squeeze(model(data[None, None, :, :].to(device))).cpu()

In [None]:
# Show the network output on the real data, each row scaled by its own std.
output_row_std = torch.std(test_out, dim=1)
fig, ax = plt.subplots(figsize=(9, 6))
im = ax.imshow(test_out.T / output_row_std, clim=[-1, 1], aspect='auto', cmap='gray')
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Overlay the middle column of the network output and of the raw input data,
# labelled like the synthetic-data comparison plot above.
plt.figure(figsize=(9, 6))
plt.plot(test_out[:, test_out.shape[1] // 2], label='out')
plt.plot(data[:, data.shape[1] // 2], label='in')
plt.xlabel('row index')
plt.legend()
plt.show()