In [None]:
import math
import os
import pickle

import cv2
import diplib as dip
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import scipy
import torch
import torch.nn as nn
from PIL import Image, ImageFilter
from scipy import signal
from scipy.io import readsav
from torch.utils.data import Dataset, DataLoader

### Functions for loading data

In [None]:
def _load_data(filename):
    """Read an IDL ``.sav`` file and return the first record of its ``emission_structure`` variable."""
    return readsav(filename)['emission_structure'][0]

def _find_index(arr,val):
    return np.argmin(abs(arr-val))

### Functions for enhancing images

In [None]:
def norm(data):
    """Standardise ``data`` to zero mean and unit standard deviation.

    The mean and (population) standard deviation are taken over all elements.
    A constant input (std == 0) returns an all-zero float array instead of the
    NaNs a 0/0 division would produce.
    """
    mn = data.mean()
    std = data.std()
    if std == 0:
        return np.zeros_like(data, dtype=float)
    return (data - mn) / std

def rescale(data):
    """Min-max scale ``data`` linearly onto [0, 1].

    A constant input (max == min) returns an all-zero float array instead of
    NaNs. This matters because callers cast the result to uint8, and NaN has no
    defined uint8 value.
    """
    lo = data.min()
    span = data.max() - lo
    if span == 0:
        return np.zeros_like(data, dtype=float)
    return (data - lo) / span

def quantfilt(src,thr=0.9):
    """Zero every value below its column's ``thr`` quantile (quantiles taken along axis 0)."""
    cutoff = np.quantile(src, thr, axis=0)
    return np.where(src < cutoff, 0, src)

# gaussian filtering
def gaussblr(src,filt=(31, 3)):
    """Gaussian-blur ``src`` and return the result rescaled to [0, 1].

    ``src`` is first rescaled to 0-255 and cast to uint8 for OpenCV. ``filt`` is
    the OpenCV kernel size (width, height). Sigma 0 lets OpenCV derive it from
    the kernel size.
    """
    as_uint8 = (rescale(src) * 255).astype('uint8')
    blurred = cv2.GaussianBlur(as_uint8, filt, 0)
    return rescale(blurred)

# mean filtering
def meansub(src):
    """Subtract each row's mean, take the absolute deviation, and rescale to [0, 1]."""
    row_means = np.mean(src, axis=1, keepdims=True)
    deviation = np.abs(src - row_means)
    return rescale(deviation)

# morphological filtering
def morph(src):
    """Morphological clean-up, with the result rescaled to [0, 1].

    First a closing with a 4x4 rectangle, then an opening with a 3x1
    (width x height) rectangle. The input is rescaled to 0-255 uint8 first.
    """
    as_uint8 = (rescale(src) * 255).astype('uint8')
    close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (4, 4))
    open_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 1))
    closed = cv2.morphologyEx(as_uint8, cv2.MORPH_CLOSE, close_kernel)
    opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, open_kernel)
    return rescale(opened)

In [None]:
def brightness_reconstruction(img): # doi: 10.1109/TPS.2018.2828863.
    """Brightness-significance map of ``img`` on a 0-255 scale.

    Each pixel's significance is ``log(p + 1) * (p - mean(p))``, where
    ``p = img / 255``. The map is divided by its maximum, negative values are
    zeroed, and the result is scaled back to 0-255.

    If no pixel has positive significance (e.g. a uniform image, where every
    pixel equals the mean), an all-zero map is returned instead of dividing by
    zero.
    """
    im_norm = img / 255
    im_ave = np.mean(im_norm)
    significance = np.log(im_norm + 1) * (im_norm - im_ave)
    peak = np.max(significance)
    if peak <= 0:
        return np.zeros_like(significance)
    probability = significance / peak
    fixed_probability = np.where(probability < 0, 0, probability)
    return fixed_probability * 255

def fourier_shifting(img, radius=1, gain=3):
    """High-pass filter ``img`` by suppressing its lowest spatial frequencies.

    The centred 2-D spectrum (FFT over the first two axes) is multiplied by a
    mask. The mask is 1 everywhere except a filled disc of ``radius`` bins
    around the zero-frequency bin, where it is 0. The magnitude of the inverse
    transform is multiplied by ``gain`` and clipped into a uint8 image.

    Parameters
    ----------
    img : array, (H, W) image.
    radius : int, default 1
        Radius of the suppressed disc, in frequency bins.
    gain : float, default 3
        Brightness gain applied to the filtered magnitude before clipping to [0, 255].

    Returns
    -------
    uint8 array with the same shape as ``img``.
    """
    dft_shift = np.fft.fftshift(np.fft.fft2(img, axes=(0, 1)), axes=(0, 1))

    # Build the mask from the argument's own shape (not a global image): 255 off the disc, 0 on it
    mask = np.zeros(np.shape(img), dtype=np.uint8)
    cy, cx = mask.shape[0] // 2, mask.shape[1] // 2
    cv2.circle(mask, (cx, cy), radius, (255, 255, 255), -1)
    mask = 255 - mask

    # Divide by 255 so the mask acts as a 0/1 multiplier on the spectrum
    dft_shift_masked = dft_shift * mask / 255
    img_filtered = np.fft.ifft2(np.fft.ifftshift(dft_shift_masked, axes=(0, 1)), axes=(0, 1))
    return np.abs(gain * img_filtered).clip(0, 255).astype(np.uint8)

def prob_to_edge(image, threshold):
    """Stretch ``image`` so its maximum maps to 255, cast to uint8, and return its Canny edge map.

    Parameters
    ----------
    image : array with non-negative values (negative values would wrap on the uint8 cast).
    threshold : (low, high) hysteresis thresholds passed to ``cv2.Canny``.

    Returns
    -------
    uint8 edge map of shape ``image.shape[:2]``. If the image maximum is <= 0
    there is nothing to stretch, and an all-zero map is returned instead of
    dividing by zero.
    """
    peak = np.amax(image)
    if peak <= 0:
        return np.zeros(np.shape(image)[:2], dtype=np.uint8)
    ratio = peak / 255
    img8 = (image / ratio).astype('uint8')
    return cv2.Canny(img8, threshold[0], threshold[1])

def dark_filter(img):
    """Zero every pixel whose value is below 5; brighter pixels pass through unchanged."""
    too_dark = img < 5
    return np.where(too_dark, 0, img)

### Load data

In [None]:
# IDL save file holding `emission_structure` (path relative to the notebook's working directory)
SAV_PATH = 'emission_structure_pu_cam240perp_185821.sav'

[inverted, radii, elevation, frames, times, vid_frames, vid_times, vid] = _load_data(SAV_PATH)

### Detecting lines in raw image (lines correspond to XPR and Emission Front)


In [None]:
# Detect straight segments (candidate XPR / Emission Front lines) in one raw frame.
tid = 200
img = np.sqrt(vid[tid])
# Invert and stretch to 0-255 so bright pixels become dark
gray = (255 - 255 * (img - np.min(img)) / (np.max(img) - np.min(img))).astype('uint8')

# reduce the noise using Gaussian filters
kernel_size = 11
blur_gray = cv2.GaussianBlur(gray, (kernel_size, kernel_size), 0)

# Apply Canny edge detector
low_threshold = 10
high_threshold = 20
edges = cv2.Canny(blur_gray, low_threshold, high_threshold)

# Apply Hough transform
rho = 1  # distance resolution in pixels of the Hough grid
theta = np.pi / 180  # angular resolution in radians of the Hough grid
threshold = 5  # minimum number of votes (intersections in Hough grid cell)
min_line_length = 20  # minimum number of pixels making up a line
max_line_gap = 10  # maximum gap in pixels between connectable line segments
line_image = np.zeros((img.shape[0], img.shape[1], 3))  # blank canvas to draw lines on

# `lines` holds segment endpoints with shape (N, 1, 4); OpenCV returns None when nothing is found
lines = cv2.HoughLinesP(edges, rho, theta, threshold, np.array([]), min_line_length, max_line_gap)
if lines is None:
    lines = np.empty((0, 1, 4), dtype=np.int32)

# Draw every segment and record its length in a single pass
line_len = []
for x1, y1, x2, y2 in lines[:, 0]:
    cv2.line(line_image, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 5)
    line_len.append(np.hypot(x2 - x1, y2 - y1))

# add the line_image as an extra layer on top of the original image
lines_edges = cv2.addWeighted(cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR), 1, line_image, 0.5, 0, dtype=0)


In [None]:
fig, ax = plt.subplots()
ax.imshow(lines_edges, aspect='auto')
ax.set_title('Final result')


## Other Filtering Methods

In [None]:
# Pick one frame, scale its square root to 0-255, crop it, and lightly blur it.
idx = 131
img = np.sqrt(vid[idx]).copy()
ratio = np.amax(img) / 255
img8 = (img / ratio).astype('uint8')
print(img8.shape)
img8 = img8[0:240, 240:720]  # keep rows 0-239 and columns 240-719
kernel_size = 5
blur_gray = cv2.GaussianBlur(img8, (kernel_size,) * 2, 0)

aspect_num = 1 / 2  # display aspect multiplier applied below

fig, ax = plt.subplots()
ax.imshow(blur_gray, cmap='gray')
(x_left, x_right), (y_low, y_high) = ax.get_xlim(), ax.get_ylim()
ax.set_aspect(abs((x_right - x_left) / (y_low - y_high)) * aspect_num)
ax.set_title(f'Original, Frame {idx}')
plt.show()

In [None]:
# Enhance the blurred frame: brightness reconstruction -> 2-D matched-filter line detector
# -> rescale to [0, 255] -> sharpen -> brightness reconstruction again ("2X BR + 2DMFLD").
kernel = np.array([[0, -1, 0],
                   [-1, 5, -1],
                   [0, -1, 0]])  # sharpening kernel
probability = brightness_reconstruction(blur_gray)
probability = np.array(dip.MatchedFiltersLineDetector2D(probability, sigma=1))  # 10.1109/42.34715
probability = np.where(probability < 0, 0, probability)  # keep only positive detector responses
probability *= 255.0 / probability.max()
probability = cv2.filter2D(probability, -1, kernel)
probability = brightness_reconstruction(probability)

# Enhanced frame on top, cropped original below, both with the same display aspect
fig, ax = plt.subplots(2)
ax[0].imshow(probability, cmap='gray')
ax[1].imshow(img8, cmap='gray')
x_left, x_right = ax[0].get_xlim()
y_low, y_high = ax[0].get_ylim()
panel_aspect = abs((x_right - x_left) / (y_low - y_high)) * aspect_num
for panel in ax:
    panel.set_aspect(panel_aspect)
ax[0].set_title(f'2X BR + 2DMFLD, Frame {idx}')
ax[1].set_title(f'Original, Frame {idx}')
plt.tight_layout()
plt.show()

In [None]:
# Enhance every frame from 264 onward and save a two-panel (enhanced / original) PNG per frame.
kernel_size = 5
aspect_num = 1/2
kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])  # sharpening kernel; frame-independent, so built once
out_dir = 'output_images'
os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing directories

fig, ax = plt.subplots(2)
# NOTE(review): range stops at len(vid) - 2, so the final frame is skipped; confirm this is intended
for i in range(264, len(vid) - 1):
    img = np.sqrt(vid[i]).copy()
    ratio = np.amax(img) / 255
    img8 = (img / ratio).astype('uint8')
    img8 = img8[0:240, 240:720]
    blur_gray = cv2.GaussianBlur(img8, (kernel_size, kernel_size), 0)

    probability = brightness_reconstruction(blur_gray)
    probability = np.array(dip.MatchedFiltersLineDetector2D(probability, sigma=1))  # 10.1109/42.34715
    probability = np.where(probability < 0, 0, probability)
    probability *= 255.0 / probability.max()
    probability = cv2.filter2D(probability, -1, kernel)
    probability = brightness_reconstruction(probability)

    # The axes are reused, so clear them; otherwise every frame's image stays stacked underneath
    for panel in ax:
        panel.clear()
    ax[0].imshow(probability, cmap='gray')
    ax[1].imshow(img8, cmap='gray')
    x_left, x_right = ax[0].get_xlim()
    y_low, y_high = ax[0].get_ylim()
    ax[0].set_aspect(abs((x_right - x_left) / (y_low - y_high)) * aspect_num)
    ax[1].set_aspect(abs((x_right - x_left) / (y_low - y_high)) * aspect_num)
    ax[0].set_title(f'2X BR + 2DMFLD, Frame {i}')
    ax[1].set_title(f'Original, Frame {i}')
    fig.tight_layout()
    # os.path.join instead of a hard-coded backslash, so the path also works on Linux/macOS
    fig.savefig(os.path.join(out_dir, f'{i}.png'))
plt.show()

In [None]:
# Hough transform on the edges of the enhanced frame (`probability`) from the cells above.
rho = 1  # distance resolution in pixels of the Hough grid
theta = np.pi / 180  # angular resolution in radians of the Hough grid
threshold = 5  # minimum number of votes (intersections in Hough grid cell)
min_line_length = 20  # minimum number of pixels making up a line
max_line_gap = 10  # maximum gap in pixels between connectable line segments
canny_thresholds = (10, 20)  # (low, high) Canny thresholds, same values as the raw-frame cell; TODO tune

# Edge map of the enhanced image; the canvas is sized from it so the overlay shapes agree
edge = prob_to_edge(probability, canny_thresholds)
line_image = np.zeros((edge.shape[0], edge.shape[1], 3))

# OpenCV returns None (not an empty array) when no segment is found
lines = cv2.HoughLinesP(edge, rho, theta, threshold, np.array([]), min_line_length, max_line_gap)
if lines is None:
    lines = np.empty((0, 1, 4), dtype=np.int32)

line_len = []
for x1, y1, x2, y2 in lines[:, 0]:
    cv2.line(line_image, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 5)
    line_len.append(np.hypot(x2 - x1, y2 - y1))

# Overlay on the cropped frame `img8`, which has the same shape as `probability`
lines_edges = cv2.addWeighted(cv2.cvtColor(img8, cv2.COLOR_GRAY2BGR), 1, line_image, 0.5, 0, dtype=0)

In [None]:
fig, ax = plt.subplots()
ax.imshow(lines_edges)

### Significance Tiling

In [None]:
# manually test image size for tiling: img_split with (24, 72) tiles needs both remainders to be 0
print(blur_gray.shape)
print(*(size % tile for size, tile in zip(blur_gray.shape, (24, 72))))

In [None]:
def img_split(img, kernel_size, crop=False):
    """Split a 2-D image into non-overlapping tiles, in row-major tile order.

    Parameters
    ----------
    img : 2-D array of shape (H, W).
    kernel_size : (tile_height, tile_width).
    crop : bool, default False
        If True, drop the bottom rows / right columns that do not fill a whole
        tile. If False, H and W must be exact multiples of the tile size.

    Returns
    -------
    array of shape (num_rows * num_cols, tile_height, tile_width).

    Raises
    ------
    ValueError
        If ``img`` is not 2-D, or its shape is not divisible by the tile size
        and ``crop`` is False.
    """
    img = np.asarray(img)
    if img.ndim != 2:
        raise ValueError(f"img_split expects a 2-D image, got shape {img.shape}")
    img_height, img_width = img.shape
    tile_height, tile_width = kernel_size

    num_rows = img_height // tile_height
    num_cols = img_width // tile_width

    if img_height % tile_height or img_width % tile_width:
        if not crop:
            raise ValueError(
                f"image shape {img.shape} is not divisible by tile size {tuple(kernel_size)}; "
                "pass crop=True to drop the remainder")
        img = img[:num_rows * tile_height, :num_cols * tile_width]

    # (rows, th, cols, tw) -> (rows, cols, th, tw) -> flat list of tiles
    tiles = img.reshape(num_rows, tile_height, num_cols, tile_width)
    return tiles.swapaxes(1, 2).reshape(-1, tile_height, tile_width)

def recombine_tiles(img, img_shape, kernel_size):
    """Inverse of ``img_split``: stitch row-major tiles back into one image.

    ``img`` has shape (num_rows * num_cols, tile_height, tile_width). The result
    has shape ``img_shape``, which must be a whole number of tiles in each
    direction.
    """
    height, width = img_shape
    th, tw = kernel_size
    grid = img.reshape(height // th, width // tw, th, tw)
    # (rows, cols, th, tw) -> (rows, th, cols, tw) so rows of pixels line up
    return grid.transpose(0, 2, 1, 3).reshape(height, width)

In [None]:
# Split the blurred frame into 24x72 (rows x columns) tiles; img_split needs the frame
# dimensions to be exact multiples of the tile size.
bg_tiles = img_split(img=blur_gray, kernel_size=(24, 72))

In [None]:
def prob_tiles(img, kernel_size, prob_fn=None):
    """Apply a significance map to each tile of ``img`` independently, then stitch the tiles back together.

    Parameters
    ----------
    img : 2-D array whose shape is divisible by ``kernel_size``.
    kernel_size : (tile_height, tile_width).
    prob_fn : callable, optional
        Maps one 2-D tile to an array of the same shape. Defaults to
        ``brightness_reconstruction``.
        NOTE(review): the name this function previously called, ``get_probability``,
        is not defined in the visible notebook. ``brightness_reconstruction`` is
        assumed to be the intended per-tile map; confirm.

    Returns
    -------
    float array with the same shape as ``img``.
    """
    if prob_fn is None:
        prob_fn = brightness_reconstruction
    # Tile in float so per-tile results are not truncated when img is uint8
    tiles = img_split(np.asarray(img, dtype=float), kernel_size)
    processed = np.empty_like(tiles)
    for k, tile in enumerate(tiles):
        processed[k] = prob_fn(tile)
    return recombine_tiles(processed, np.shape(img), kernel_size)

In [None]:
tiled_probability = prob_tiles(blur_gray, (48, 120))
plt.imshow(tiled_probability, cmap='gray')

## Training a model to predict X point coordinates based on synthetic XPR

In [None]:
# Split boundaries over the shuffled samples: [0, cutoff_idx) train,
# [cutoff_idx, end_idx) validation, [end_idx, ...) test.
cutoff_idx, end_idx = 1000, 1500
num_train_idx, num_val_idx = cutoff_idx, end_idx - cutoff_idx

In [None]:
# Synthetic XPR images (`data['image']`) and targets (`data['RZ']`), stacked along axis 2,
# so X[:, :, k] / y[:, :, k] is sample k (presumably one 2-D entry per sample; confirm).
# NOTE(review): pickle.load can execute arbitrary code, so only load files from a trusted source.
SYNTH_PATH = 'synthetic_outs.pl'
SPLIT_SEED = 0  # fixed seed so the train/valid/test split is reproducible across runs

with open(SYNTH_PATH, 'rb') as fh:  # closes the file handle after loading
    data = pickle.load(fh)
X = np.int_(np.dstack(list(data['image'].values())))
y = np.dstack(list(data['RZ'].values()))

rand_ind = np.random.default_rng(SPLIT_SEED).permutation(X.shape[2])

X_train = X[:, :, rand_ind[:cutoff_idx]]
y_train = y[:, :, rand_ind[:cutoff_idx]]

X_valid = X[:, :, rand_ind[cutoff_idx:end_idx]]
y_valid = y[:, :, rand_ind[cutoff_idx:end_idx]]

X_test = X[:, :, rand_ind[end_idx:]]
y_test = y[:, :, rand_ind[end_idx:]]

In [None]:
fig, ax = plt.subplots()
ax.pcolormesh(X_train[:, :, 100])  # training sample 100
ax.set_title("Intended")
plt.show()

Task 1: Train a model to predict a single X point using XPR synthetic data

Task 2: Load the synthetic data (synthetic_outs_2d_ver2.pl) for both the XPR and the Emission Front, and train a model to detect both the inner and outer X points.

Task 3: Detect the XPR and Emission Front in the raw images, then redo Tasks 1 and 2 using the detected lines instead of synthetic data.

In [None]:
# optimisation hyperparameters
init_lr, batch_size, epochs = 0.001, 4, 2

# dummy dims: placeholders until the real input/output sizes are settled
input_dim, output_dim = 4, 2

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class Model(nn.Module):
    """Small CNN regressor: 5x5 conv -> ReLU -> 2x2 max-pool -> flatten -> linear.

    Parameters
    ----------
    input_dim : int
        Number of input channels for the convolution.
    output_dim : int
        Number of regression outputs.

    The linear layer is an ``nn.LazyLinear``, whose input width is inferred on
    the first forward pass. The model therefore works for any image size that
    survives the 5x5 convolution and 2x2 pooling. PyTorch's lazy-module docs
    recommend a dry-run forward pass (after moving the model to its device) to
    materialise the parameters.
    """

    def __init__(self, input_dim, output_dim):
        super(Model, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.conv = nn.Conv2d(in_channels=input_dim, out_channels=20, kernel_size=(5, 5))
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        # in_features depends on the image size, so let PyTorch infer it instead of hard-coding it
        self.linear = nn.LazyLinear(output_dim)

    def forward(self, x):
        """Map (N, C, H, W) -> (N, output_dim), or an unbatched (C, H, W) -> (output_dim,)."""
        x = self.maxpool(self.relu(self.conv(x)))
        # nn.Linear only acts on the last axis, so merge (C, H', W') into one feature axis first
        x = torch.flatten(x, start_dim=-3)
        return self.linear(x)

In [None]:
class TVDataset(torch.utils.data.Dataset):
    """Map-style dataset over arrays whose third axis indexes samples.

    Sample ``i`` is the pair ``(X[:, :, i], y[:, :, i])``.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        n_samples = self.X.shape[2]
        return n_samples

    def __getitem__(self, index):
        features = self.X[:, :, index]
        target = self.y[:, :, index]
        return features, target

In [None]:
model = Model(input_dim, output_dim).to(device)
opt = torch.optim.Adam(model.parameters(), lr=init_lr)
loss_fn = nn.MSELoss()

# per-epoch training/validation history, one list per tracked metric
H = {metric: [] for metric in ("train_loss", "train_acc", "valid_loss", "valid_acc")}

In [None]:
# convert every split to float32 torch tensors
X_train_d, y_train_d, X_valid_d, y_valid_d, X_test_d, y_test_d = (
    torch.from_numpy(arr).float()
    for arr in (X_train, y_train, X_valid, y_valid, X_test, y_test)
)

# data load debugging
dataset = TVDataset(X_train_d, y_train_d)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Training loop: one pass over `dataloader` per epoch, optimising `model` with `opt` and `loss_fn`.
for e in range(epochs):
    model.train()  # switch to training mode (Model as defined has no dropout/batch-norm layers)
    
    # per-epoch accumulators; total_val_loss / val_correct are not used in the lines shown here
    # (presumably a validation pass follows; TODO confirm)
    total_train_loss = 0
    total_val_loss = 0
    train_correct = 0
    val_correct = 0
    
    for i, (inputs, labels) in enumerate(dataloader):
        
        # NOTE(review): TVDataset samples are 2-D (X[:, :, index]), so batches have no channel axis.
        # Conv2d interprets a 3-D input as unbatched (C, H, W); confirm the intended input layout.
        (inputs, labels) = (inputs.to(device), labels.to(device))
        
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        
        # standard backprop / parameter update
        opt.zero_grad()
        loss.backward()
        opt.step()
        
        total_train_loss += loss.item()
        
        # NOTE(review): argmax-equality "accuracy" is not meaningful for an MSE regression target
        train_correct += (outputs.argmax(1) == labels).sum().item()
        
    # mean training loss per batch for this epoch
    H["train_loss"].append(total_train_loss / len(dataloader))
    # NOTE(review): divides by the number of batches, not samples, so this is not a fraction in [0, 1]
    H["train_acc"].append(train_correct / len(dataloader))