Notebook for training the NN. The training data stack should already have been generated; noise is (optionally) added to it here.

In [2]:
# Automatically reload edited local modules (the ones imported below) before each cell runs.
%load_ext autoreload
%autoreload 2
# Interactive (ipympl) matplotlib figures.
%matplotlib widget 

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
import os
import textwrap
import warnings
warnings.filterwarnings("ignore", module="torch.nn.functional")

from pathlib import Path
from datetime import datetime

from image_helpers import * 
from datatransform import datatransform, training_log

from smallUnet import *
from skyrm_find_CNN import *

# Use the first CUDA device when one exists; otherwise fall back to the CPU.
# `gpu` is either an integer device index or the string 'cpu'.
if not torch.cuda.is_available():
    gpu = 'cpu'
    print("GPU is not available, will run on CPU.")
else:
    gpu = 0
    print(f"There are {torch.cuda.device_count()} GPUs available")
    print(f"Running on GPU (index) {gpu}")

# Where trained weights/logs are written, plus a yymmdd stamp used in file names.
savedir = Path('./NN_trained/').absolute()
today = datetime.today().strftime('%y%m%d')

There are 4 GPUs available
Running on GPU (index) 0


In [6]:
NN_name = f"skNet_batch1_{today}"
# Keep `savedir` a pathlib.Path (as in the setup cell): re-binding it to a plain
# string made `savedir / "..."` in the "load a NN" cell raise a TypeError.
savedir = Path('./NN_trained/').absolute()  # dir to save results
savedir.mkdir(parents=True, exist_ok=True)  # torch.save / savefig below need it to exist
nn_path = savedir / f"{NN_name}_final.pt"

#### If loading a NN

In [91]:
# Load previously trained weights instead of training from scratch.
# Path() makes this work whether `savedir` is currently a str or a Path.
nn_path = Path(savedir) / "skNet_batch1_210921_final.pt"
device = torch.device(gpu)  # cuda:<gpu> when a GPU was found, otherwise the CPU
model = smallUnet()
model.to(device)
# map_location lets weights that were saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(nn_path, map_location=device))
model.eval()

# The optimizer is only needed to continue training these weights; inference does not use it.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

## Load training data

In [7]:
# Training hyper-parameters
batch_size = 64
epochs = 600
test_split = 0.2  # fraction of samples held out for testing

# Pre-built training stack (keys 'Images' and 'Labels' are read below)
dataset = np.load('./training_data/batch1/200318_combrot.npz')

In [8]:
images_load = dataset['Images']
# The stored labels were never thresholded; this is a binary classification task,
# so any label value above 0.1 is treated as the positive class.
labels_load = (dataset['Labels'] > 0.1).astype('int')
assert images_load.shape[0] == labels_load.shape[0], "images and labels differ in length"

n_total = images_load.shape[0]
# sklearn's train_test_split rounds a fractional test size *up* and gives the rest to
# training. Mirror that here so the printed counts match the real split; plain int()
# truncation can be off by one.
n_test = int(np.ceil(n_total * test_split))
n_train = n_total - n_test
print(f"Total size Images, Labels:\n {images_load.shape, labels_load.shape}")
print(f"Number training images: {n_train}")
print(f"Number test images: {n_test}")

Total size Images, Labels:
 ((3000, 256, 256), (3000, 256, 256))
Number training images: 2400
Number test images: 600


# Apply noise to training data

In [16]:
# For a single class case, we still need to explicitly specify the single channel
labels_load2 = labels_load[..., None] if np.ndim(labels_load) == 3 else labels_load
# Number of channels in masked data (the training images have a single channel)
ch = labels_load2.shape[-1]
# Define image distortion/noise parameters
zoom = False  # zoom factor
poisson = [10, 45]  # P noise range (scaled units)
gauss = [1, 150]  # G noise range (scaled units)
blur = [1, 50]  # Blurring range (scaled units)
contrast = [3, 18]  # contrast range (< 10 is brighter, > 10 is darker)
salt_and_pepper = [0, 20]  # min/max amount of salted/peppered pixels (scaled units)

notes = textwrap.dedent(
    """Trained on rotated Skx data such that tilt axis is along x axis
    (contrast is left/right)
    """
)

dim_order_in = "channel_last"
dim_order_out = "channel_first"
seed = 42
zoom = False
rotation = False

# Run the augmentor
imaug = datatransform(
    n_channels=ch,
    dim_order_in=dim_order_in,
    dim_order_out=dim_order_out,
    gauss_noise=gauss,
    poisson_noise=poisson,
    salt_and_pepper=salt_and_pepper,
    contrast=contrast,
    blur=blur,
    zoom=zoom,
    rotation=rotation,
    seed=seed,
    squeeze_channels=True,
    classifier=True,
)

images_noise, labels_noise = imaug.run(images_load, labels_load2)

training_log(
    NN_name,
    savedir,
    images_shape=images_load.shape,
    labels_shape=labels_load2.shape,
    test_split=test_split,
    batch_size=batch_size,
    epochs=epochs,
    n_channels=ch,
    dim_order_in=dim_order_in,
    dim_order_out=dim_order_out,
    gauss_noise=gauss,
    poisson_noise=poisson,
    salt_and_pepper=salt_and_pepper,
    contrast=contrast,
    blur=blur,
    zoom=zoom,
    rotation=rotation,
    seed=seed,
    notes=notes,
)

labels_noise = labels_noise.squeeze()
print(images_noise.shape, labels_noise.shape)


Saved training params file to:
 /home/amccray/code/SkX_NN/NN_trained/skNet_batch1_220208_training_data_params.txt
(3000, 1, 256, 256) (3000, 256, 256)


### Look at some of the augmented data

In [78]:
s = 30  # index of the first sample shown
n = 5   # number of samples shown (one column each)

# Rows: original image, augmented image, ground truth; one column per sample.
fig, axes = plt.subplots(3, n, figsize=(15, 8), squeeze=False)
for col in range(n):
    idx = s + col
    ax_orig, ax_aug, ax_gt = axes[:, col]

    ax_orig.imshow(images_load[idx], cmap='gray')
    ax_orig.set_title('Original image ' + str(col), fontsize=10)

    ax_aug.imshow(images_noise[idx, 0, :, :], cmap='gray')
    ax_aug.set_title('Augmented image ' + str(col), fontsize=10)

    ax_gt.imshow(labels_noise[idx], cmap='jet', interpolation='gaussian')
    # Same index as the rows above (the old title was off by one).
    ax_gt.set_title('Ground truth ' + str(col), fontsize=10)

    # Only the bottom row keeps x tick labels; only the first column keeps y tick labels.
    for ax in (ax_orig, ax_aug):
        ax.tick_params(axis='x', which='both', bottom=False, labelbottom=False)
    if col != 0:
        for ax in (ax_orig, ax_aug, ax_gt):
            ax.tick_params(axis='y', which='both', left=False, labelleft=False)

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### Split into train/test sets

In [87]:
from sklearn.model_selection import train_test_split

# Hold out `test_split` of the augmented stack for testing. Passing `random_state`
# (the same `seed` used for augmentation) makes the split reproducible; without it,
# every run shuffles differently.
images_all, images_test_all, labels_all, labels_test_all = train_test_split(
    images_noise, labels_noise, test_size=test_split, random_state=seed)
print(images_all.shape, labels_all.shape, images_test_all.shape, labels_test_all.shape)

(2400, 1, 256, 256) (2400, 256, 256) (600, 1, 256, 256) (600, 256, 256)


# Initialize and train a new CNN

In [88]:
rng_seed(42)  # for reproducibility

# Initialize a fresh model on the selected device: `gpu` is a CUDA index or 'cpu',
# and torch.device accepts both (model.cuda('cpu') would fail on the CPU fallback).
model = smallUnet()
model.to(torch.device(gpu))
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [89]:
# Chop the train/test stacks into equal-size batches. Samples that do not fill a whole
# batch (len % batch_size of them) are dropped.
n_train_batches = labels_all.shape[0] // batch_size
n_test_batches = labels_test_all.shape[0] // batch_size
# np.split with 0 sections raises an opaque error; fail with a clear message instead.
assert n_train_batches > 0 and n_test_batches > 0, (
    f"batch_size={batch_size} is larger than the train or test set")
images_allb = np.split(
    images_all[:n_train_batches*batch_size], n_train_batches)
labels_allb = np.split(
    labels_all[:n_train_batches*batch_size], n_train_batches)
images_test_allb = np.split(
    images_test_all[:n_test_batches*batch_size], n_test_batches)
labels_test_allb = np.split(
    labels_test_all[:n_test_batches*batch_size], n_test_batches)
print('image stack shape: ', np.shape(images_all))
# Number of batches + shape of one batch. np.shape(list) would first copy every batch
# into one new (potentially GB-sized) array just to report its shape.
print('batch stack shape: ', (len(images_allb),) + images_allb[0].shape)

image stack shape:  (2400, 1, 256, 256)
batch stack shape:  (37, 64, 1, 256, 256)


### Train

In [None]:
print_loss = 50  # print loss every m-th epoch.
device = torch.device(gpu)  # cuda:<gpu> when a GPU was found, otherwise the CPU
# NOTE: each "epoch" here is ONE randomly drawn batch, not a full pass over the data.
# Pre-draw the batch indices used at every step for training and for testing.
batch_ridx = [np.random.randint(0, len(images_allb)) for _ in range(epochs)]
batch_ridx_t = [np.random.randint(0, len(images_test_allb)) for _ in range(epochs)]
# Start training
train_losses, test_losses = [], []
best_test_loss = float('inf')  # lowest test loss seen in earlier epochs
for e in range(epochs):
    model.train()  # put in training mode
    # Training batch and its ground truth as torch tensors on the target device
    images = torch.from_numpy(images_allb[batch_ridx[e]]).float().to(device)
    labels = torch.from_numpy(labels_allb[batch_ridx[e]]).long().to(device)
    # Forward --> Backward --> Optimize
    optimizer.zero_grad()
    prob = model(images)  # call the module (not .forward) so hooks run
    loss = criterion(prob, labels)
    loss.backward()
    optimizer.step()
    train_losses.append(loss.item())
    # Now test the current model state on a random test batch
    model.eval()  # turn off batch norm and/or dropout units
    images_ = torch.from_numpy(images_test_allb[batch_ridx_t[e]]).float().to(device)
    labels_ = torch.from_numpy(labels_test_allb[batch_ridx_t[e]]).long().to(device)
    with torch.no_grad():  # deactivate autograd engine during testing (saves memory)
        prob = model(images_)
        loss = criterion(prob, labels_)
        test_losses.append(loss.item())
    # Print statistics
    if e == 0 or (e+1) % print_loss == 0:
        print('Epoch {:3} .... Training loss: {:8} .... Test loss: {:8}'.format(
            e+1, np.around(train_losses[-1], 8), np.around(test_losses[-1], 8))
        )
    # Save the best model weights (only after the first 100 epochs). Tracking a running
    # minimum avoids rescanning the whole loss history every epoch.
    if e > 100 and test_losses[-1] < best_test_loss:
        torch.save(model.state_dict(),
                   os.path.join(savedir, NN_name + '.pt'))
    best_test_loss = min(best_test_loss, test_losses[-1])
# Save final weights
torch.save(model.state_dict(),
           os.path.join(savedir, NN_name + '_final.pt'))

fig, ax = plt.subplots()
ax.plot(train_losses, label='train')
ax.plot(test_losses, label='test')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
# Save next to the weights, consistent with `savedir` above
fig.savefig(os.path.join(savedir, f"Training_loss_{NN_name}.png"), dpi=600, bbox_inches="tight")
plt.show()

## Viewing output

In [None]:
k, im = 2, 8  # batch index, image index within that batch
test_img = images_test_allb[k][im]
test_lbl = labels_test_allb[k][im]
# Add a batch axis: the model needs a 4D (N, C, H, W) tensor even for a single image
test_img = test_img[np.newaxis, ...]
# Convert to pytorch format and move to the same device as the model
test_img_ = torch.from_numpy(test_img).float().to(torch.device(gpu))
# Inference: eval mode (batch norm / dropout off) and no autograd bookkeeping
model.eval()
with torch.no_grad():
    prediction = model(test_img_)
prediction = F.softmax(prediction, dim=1).cpu().numpy()
prediction = np.transpose(prediction, [0, 2, 3, 1])  # (N, C, H, W) -> (N, H, W, C) for plotting
# plot results
_, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.imshow(test_img[0, 0], cmap='gray')
# lowercase `interpolation`: the capitalized kwarg is not an imshow property and raised
ax2.imshow(prediction[0, :, :, 1], cmap='gray', interpolation='gaussian')
ax1.set_title('Test image')
ax2.set_title('Model prediction')

show_im(test_img[0, 0] + prediction[0, :, :, 1])

# Test a NN on a full image
Skip to here if you already have a trained NN.
For actual experimental data, I have a workflow for bringing in the raw TEM images, filtering them, determining the tilt axis, compensating for the tilt angle, etc.; for this example, an already-processed image is used.

In [90]:
# Tilt-axis direction of the example image, passed to find_skyrms below
tilt_dir = 132.137
# Already-processed example image
test_im = np.load("./test_data/150K_test.npy")
show_im(test_im, title="Example image")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [92]:
# Wrap the trained weights and locate skyrmion centers in the full image. Request CUDA
# only when a GPU was actually found; hard-coding cuda=True broke the CPU fallback.
model = trained_NN(nn_path, cuda=(gpu != 'cpu'))
centers = model.find_skyrms(test_im, tilt_dir, thresh=0.3, gpu=gpu)
show_im(model.prediction[:, :, 0], "Model prediction", simple=True)
show_im_peaks(test_im, centers, title="Skyrmion locations")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …