In [None]:
%matplotlib qt5
%load_ext autoreload
%autoreload 2


import torch 
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as tf
from torch.optim import adam

import matplotlib.pyplot as plt
import numpy as np
import random

import time
import os

import hyperspy.api as hs
from tqdm import tqdm
from joblib import Parallel, delayed

SEED = 0

# Reproducibility: cudnn.deterministic asks for deterministic kernels, and benchmark must be
# False too -- with benchmark=True cuDNN times candidate algorithms at runtime and may select a
# different one on each run. Switch benchmark back to True only if speed matters more than
# bit-for-bit reproducible predictions.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Seed every RNG the notebook imports.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Device type: {device}")

In [None]:
def load_data(images_path = ..., masks_path = ...):
    """Load the simulated diffraction patterns and their masks from ``.npy`` files.

    Parameters
    ----------
    images_path, masks_path : str or os.PathLike
        Files passed to ``np.load``. The defaults are the ``...`` placeholders of the
        original notebook and must be replaced with real paths (or passed explicitly).

    Returns
    -------
    tuple
        ``(images, masks)`` as returned by ``np.load``.
    """
    images = np.load(images_path)
    masks = np.load(masks_path)

    return images, masks

def load_signal(lazy = True, path = ...):
    """Load the experimental signal with hyperspy.

    Parameters
    ----------
    lazy : bool
        Forwarded unchanged to ``hs.load``.
    path : str or os.PathLike
        File to load. Defaults to the ``...`` placeholder of the original notebook,
        which must be replaced with a real path (or passed explicitly).

    Returns
    -------
    Whatever ``hs.load`` returns for ``path``.
    """
    return hs.load(path, lazy = lazy)

def to_tensor(image):
    """Log-normalize diffraction data into float32 tensors on ``device`` for the model.

    The absolute value is log2-scaled (zero entries stay 0), divided by the largest log
    value of the *whole* input, and mapped with ``2*x - 1``. Non-zero magnitudes >= 1
    therefore land in [-1, 1] and zeros map to -1.

    Parameters
    ----------
    image : array_like
        A single 2D pattern, a 3D stack of patterns, or a 4D array of patterns
        (two leading index axes followed by the two pattern axes).

    Returns
    -------
    torch.Tensor or list
        2D input -> one tensor of shape (1, 1, H, W);
        3D input -> a list of such tensors;
        4D input -> a nested list of such tensors (one inner list per leading row).

    Raises
    ------
    ValueError
        If ``image`` is not 2D, 3D or 4D.
    """
    magnitude = np.abs(image)
    if magnitude.ndim not in (2, 3, 4):
        raise ValueError(f"to_tensor expects a 2D, 3D or 4D array, got shape {magnitude.shape}")

    # log2(0) = -inf and triggers a divide warning; those entries are replaced by 0 anyway.
    with np.errstate(divide="ignore"):
        log_image = np.where(magnitude != 0, np.log2(magnitude), 0)

    # NOTE: one global maximum is used for the whole input, not one per pattern.
    peak = np.max(log_image)
    if peak != 0:
        normalized = 2 * (log_image / peak) - 1  # normalize: -1 to 1
    else:
        # e.g. an all-zero pattern: dividing by 0 would produce NaN/-inf, so skip scaling.
        normalized = 2 * log_image - 1

    def _pattern_to_tensor(pattern):
        # (H, W) -> (1, 1, H, W): add the channel and batch dimensions the model is called with.
        return torch.tensor(np.expand_dims(pattern, axis=0), dtype = torch.float32).unsqueeze(0).to(device)

    if normalized.ndim == 2:
        return _pattern_to_tensor(normalized)
    if normalized.ndim == 3:
        return [_pattern_to_tensor(pattern) for pattern in normalized]
    return [[_pattern_to_tensor(pattern) for pattern in row] for row in normalized]


In [None]:
learning_rate = 5e-4
depth = 3
filters = 5
base_path = ...
out_path = ...
# Build the checkpoint path with os.path.join: a leading '\' in the literal is an invalid
# '\s' escape and a Windows-only separator, which breaks the path on Linux/macOS.
filename = f'segmentation_lr{learning_rate}_depth{depth}_filters{filters}_combo_'
PATH = os.path.join(out_path, filename + "model.pth")
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load(PATH, map_location=device)
# NOTE(review): UNet is not imported anywhere in this notebook; import it from the module
# that defines it before running this cell.
model = UNet(in_channels = 1, n_classes = 2, depth = depth, wf = filters, padding = True)
model = model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference mode for layers such as dropout/batch-norm
print('Model Loaded')

## The following cells predict masks for simulated diffraction patterns, or for a few selected diffraction patterns

In [None]:
# Load the simulated diffraction patterns and their masks (the masks are not used below).
images, masks = load_data()

In [None]:
def predict_mask(image):
    """Segment one diffraction pattern with ``model``.

    ``image`` is converted by ``to_tensor`` (so it is log-normalized by its own maximum),
    passed through the model without gradient tracking, and reduced with ``argmax`` over
    the class channel (dim 1).

    Returns
    -------
    numpy.ndarray
        Per-pixel class indices with the batch dimension removed.
    """
    with torch.no_grad():
        logits = model(to_tensor(image))
    # Index of the channel with the highest score for every pixel.
    return torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()


# One predicted class map per simulated pattern, in the same order as `images`.
prediction = [predict_mask(image) for image in images]


In [None]:
# Show each simulated pattern next to its predicted mask: one figure per pattern instead
# of two untitled windows, which matters with the interactive qt5 backend set at the top.
for pattern, mask in zip(images, prediction):
    fig, (ax_pattern, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))
    ax_pattern.imshow(pattern, norm = "symlog")
    ax_pattern.set_title("Diffraction pattern (symlog)")
    ax_mask.imshow(mask)
    ax_mask.set_title("Predicted mask")
    fig.tight_layout()

## Predicting an entire signal

In [None]:
# Reuse the loader defined above instead of duplicating the hs.load call.
signal = load_signal(lazy=False)
#If RAM is a concern, it may be necessary to crop the signal.
#signal = signal.inav[:128,:128]
data = signal.data
# The whole-signal prediction below stores one 2D mask per prediction[x, y], so the data
# must be 4D (two index axes followed by the two pattern axes).
assert data.ndim == 4, f"expected a 4D signal, got shape {data.shape}"


In [None]:
# Convert every pattern of the signal into a (1, 1, H, W) float32 tensor on `device`,
# returned as a nested list (one inner list per row of `data`).
# NOTE(review): this rebinds `images`, which previously held the simulated patterns, and
# keeps every pattern on `device` at once -- crop the signal above if memory is tight.
# NOTE(review): here the log-normalization divides by the maximum of the WHOLE signal,
# whereas the per-pattern cells above normalize each image by its own maximum -- confirm
# which one matches how the model was trained.
images = to_tensor(data)

In [None]:
# Patterns per forward pass: larger is faster on a GPU; lower it if the device runs out of memory.
BATCH_SIZE = 32


def process_image(x, y, img):
    """Segment one (1, 1, H, W) pattern tensor and return ``(x, y, mask)``."""
    with torch.no_grad():
        pred = model(img)
        output = torch.argmax(pred, dim=1)  # Get the index of the channel with the highest probability
        return x, y, output.squeeze(0).cpu().numpy()


def predict_patterns(patterns):
    """Segment a list of (1, 1, H, W) tensors in one forward pass; returns (N, H, W) class indices."""
    with torch.no_grad():
        logits = model(torch.cat(patterns, dim=0))
    return torch.argmax(logits, dim=1).cpu().numpy()


# Class indices are small non-negative integers (the model is built with n_classes = 2),
# so uint8 suffices and needs 1/8 of the memory of np.zeros' default float64.
prediction = np.zeros(data.shape, dtype=np.uint8)

# Batched inference in this process: the model and tensors already live on `device`, so
# sending them to separate worker processes only adds copies and per-process CUDA setup.
# Every row of the scan is processed; the progress bar advances once per row.
for x, row in enumerate(tqdm(images, desc="scan rows")):
    for start in range(0, len(row), BATCH_SIZE):
        stop = start + BATCH_SIZE
        prediction[x, start:stop] = predict_patterns(row[start:stop])

In [None]:
# Replace the data of the loaded signal with the predicted class masks, so the same signal
# object is plotted and saved in the cells below.
# NOTE(review): this mutates `signal` in place; the raw patterns are still referenced by `data`.
signal.data = prediction

In [None]:
# Plot the signal, which now holds the predicted masks.
signal.plot()

In [None]:
# Save the segmented signal. `...` is a placeholder -- replace it with the output file path.
signal.save(...)