# Load the libraries and helper modules

In [2]:
import src.ImagingPreProc as iPP
import src.BehavPreProc as bPP
import os
from matplotlib import pyplot as plt
from matplotlib.patches import Polygon
import numpy as np
from tkinter.filedialog import askopenfilename, askdirectory, asksaveasfilename

import tifffile as tf
import napari
import pandas as pd
import cv2 as cv
import tkinter as tk
from tkinter import filedialog

# Get the trial info

In [None]:
# rootDirs = [askdirectory(title="Select folder with trial info")]

# print(rootDirs)

# trials = iPP.loadTrialInfo(rootDirs)
# trials.keys()

# Process the data!

In [2]:
def loadTrialInfo(rootDir):
    """Return the paths of all .tif files found under ``rootDir`` (recursively).

    Subdirectories whose name ends in ``.oif.files`` are skipped, so their
    contents are not collected as trials. These are presumably the folders
    that accompany each Olympus .oif acquisition.

    Arguments:
        rootDir = directory to search
    Returns:
        list of .tif file paths, in os.walk order
    """
    trials = []
    for root, dirs, files in os.walk(rootDir):
        # Add tifs to list (match the real extension, not just the text after the last dot)
        trials.extend(os.path.join(root, file) for file in files
                      if os.path.splitext(file)[1] == ".tif")
        # Prune .oif.files directories in place so os.walk does not descend into them
        dirs[:] = [d for d in dirs if not d.endswith(".oif.files")]

    return trials

In [3]:
# trialfileNms = loadTrialInfo(r'C:\Users\ahshenas\Lab\mockglupuffdata\Batch4\tiffs')


In [4]:
# Hidden, always-on-top Tk root so the file dialog opens in front of the notebook
root = tk.Tk()
root.withdraw()
root.attributes('-topmost', 1)

# Paths chosen by the user (empty list if the dialog is cancelled)
trialfileNms = list(filedialog.askopenfilenames(
    parent=root,
    title="Select trial .tif files",
    filetypes=[("TIFF files", "*.tif *.tiff"), ("All files", "*.*")],
))

# Release the Tk instance; otherwise every re-run of this cell leaks another root window
root.destroy()

In [None]:
# Using files manually converted from .oif to .tif, selecting only the first channel
# trialfileNms = [r'C:\Users\ahshenas\Lab\mockglupuffdata\PlotForKirsten\20221230_6s-ss96-wtb_glutpuff_C_02.oif - C=0.tif']

In [3]:
def DFoFfromfirstfms(rawF, fm_interval, baseline_sec=10):
    """ Calculate the DF/F given a raw fluorescence signal.

    The baseline fluorescence F0 of each ROI is the mean of its first
    ``baseline_sec`` seconds of fluorescence, and DF/F = F / F0 - 1.

    Arguments:
        rawF = raw fluorescence, 2-D array with axes [frames, rois]
        fm_interval = frame interval aka time between frames
            (same time unit as baseline_sec)
        baseline_sec = length of the baseline window (default 10)
    Returns:
        float array with the same shape as rawF holding DF/F.
        An ROI whose baseline is 0 yields inf/nan (NumPy division semantics).
    """
    rawF = np.asarray(rawF, dtype=float)

    # Number of baseline frames. Use at least one: with a long frame interval
    # round() can return 0, and the mean of an empty slice is nan for every ROI.
    baseline_end_frame = max(1, round(baseline_sec / fm_interval))

    # Per-ROI baseline (mean over the frame axis), broadcast across all frames
    Fbaseline = rawF[0:baseline_end_frame].mean(axis=0)
    return rawF / Fbaseline - 1

In [4]:
def parseDate(filePath):
    """Return the date prefix of a trial file name.

    Given a path of the format path/{date}_{flyline}-glutpuff_{trial num}.tif
    (ex: path\\20221208_6s-ss96-glutpuff_01.tif), return the date ("20221208").
    Both "\\" and "/" are accepted as separators, because tkinter file dialogs
    return forward-slash paths even on Windows.
    """
    fileNm = filePath.replace("\\", "/").split("/")[-1]
    date = fileNm.split("_")[0]
    return date


In [5]:
# Name for this batch of results; used as the output file name and passed to iPP.saveDFDat.
# The trailing ": " keeps the typed answer from running into the prompt text.
expt = input("Select name to save the preprocessed data: ")

# NOTE(review): machine-specific absolute path -- change RESULTS_DIR on another computer.
RESULTS_DIR = 'C:/Users/ahshenas/Lab/mockglupuffdata/results'
outfileNm = f'{RESULTS_DIR}/{expt}'

# Stack axes, as indexed below: [Z?, frame, channel, X, Y]
frameidx = 1  # index of stack shape with frames
ch = 0        # channel to be used


def _frame_interval(tif_path):
    """Return the frame interval ('finterval') stored in a .tif's ImageJ metadata.

    Raises ValueError when the file has no ImageJ metadata or no 'finterval'
    entry, instead of an opaque AttributeError/TypeError.
    """
    with tf.TiffFile(tif_path) as tif:
        imagej_metadata = tif.imagej_metadata or {}
    finterval = imagej_metadata.get("finterval")
    if finterval is None:
        raise ValueError(f"{tif_path}: no ImageJ 'finterval' metadata; cannot determine the frame interval")
    return float(finterval)


def _roi_fluorescence(frames, masks):
    """Sum the pixels inside each ROI mask for every frame.

    frames: array [frames, X, Y]; masks: sequence of masks, each holding X*Y
    elements. Returns a float array [frames, rois]. A single matrix product
    replaces the per-frame/per-ROI loops. Summing in float also avoids overflow
    when the stack has an integer pixel type.
    """
    n_frames = frames.shape[0]
    n_pixels = int(np.prod(frames.shape[1:]))
    frames_flat = np.asarray(frames, dtype=float).reshape(n_frames, n_pixels)
    masks_flat = np.asarray(masks, dtype=float).reshape(len(masks), n_pixels)
    return frames_flat @ masks_flat.T


expt_dat = dict()
rois = []  # ROIs from the previous trial, offered as a starting point in napari
for i, trial in enumerate(trialfileNms):
    print(trial)

    # Load the stack
    [stack, nCh, nDiscardFBFrames, fpv] = iPP.loadTif(trial)
    print(f"stack shape: {stack.shape}")

    # Get frame interval (time between frames), needed for the DF/F baseline window
    fm_interval = _frame_interval(trial)

    # Keep only the selected channel; the channel axis stays, with length 1
    stack = stack[:, :, ch:ch+1, :, :]

    ### iPP.getROIs
    # Mean over frames -> [Z?, 1, X, Y]; this is the image the ROIs are drawn on
    mean_stack = stack.mean(axis=frameidx)
    # Load the mean image in napari
    viewer = napari.Viewer()
    viewer.add_image(mean_stack)
    if (i > 0) and (len(rois) > 0):
        # TODO if is a new brain, remove the old rois
        #   If the dict for the date exists add the rois from that date to napari
        viewer.add_shapes(rois, shape_type='Polygon', name='Shapes')
    napari.run()

    # Use the ROIs that were drawn in napari to get image masks
    ### iPP.getPolyROIs
    shape_x = stack.shape[3]
    shape_y = stack.shape[4]
    layer_names = [layer.name for layer in viewer.layers]
    if 'Shapes' in layer_names:
        rois = viewer.layers['Shapes'].data
        all_masks = viewer.layers['Shapes'].to_masks(mask_shape=(shape_x, shape_y))
    else:
        # No shapes were drawn: record an empty ROI set instead of raising KeyError
        print("No 'Shapes' layer found; no ROIs for this trial")
        rois = []
        all_masks = np.zeros((0, shape_x, shape_y), dtype=bool)

    ### iPP.FfromROIs
    # The channel axis was sliced to length 1 above, so index 0 (not `ch`) selects it
    rawF = _roi_fluorescence(stack[0, :, 0, :, :], all_masks)
    rawF_G = rawF

    # Get the DF/F
    DF_G = DFoFfromfirstfms(rawF_G, fm_interval)

    print(f"rawF.shape = {rawF.shape}")
    print(f"DF_G.shape = {DF_G.shape}")

    # Save the processed data
    expt_dat[trial] = {'trialName': trial,
                       'stack_mean_G': np.squeeze(mean_stack),
                       'rawF_G': rawF_G,
                       'DF_G': DF_G,
                       'all_masks': all_masks,
                       'rois': rois,
                       'fm_interval': fm_interval,
                       }

iPP.saveDFDat(outfileNm, expt, expt_dat)

Select name to save the preprocessed datatestpreproc
C:/Users/ahshenas/Lab/mockglupuffdata/20221208_6s-ss96-glutpuff_01.tif
stack shape: (1, 117, 2, 256, 256)
rawF.shape = (117, 12)
DF_G.shape = (117, 12)


In [None]:
# Show which fields were stored for one trial. expt_dat is keyed by the paths the
# file dialog returned, so look up the first stored trial instead of hard-coding
# a path from another batch (which raised KeyError).
first_trial = next(iter(expt_dat))
print(expt_dat[first_trial].keys())

In [None]:
# Frame interval read for the most recently processed trial
print(f"{fm_interval}")

In [None]:
### Display rois
import matplotlib.patheffects as PathEffects

# Outline the ROIs of the first processed trial on its mean image. ROIs are numbered
# 1..N in drawing order, the same order as the columns of rawF_G / DF_G.
plt.figure(figsize=(10, 10))
panel = plt.axes([0.1, 0.1, 0.75, 0.75])
if expt_dat:
    t = next(iter(expt_dat))
    mean_img = expt_dat[t]['stack_mean_G']
    print(mean_img.shape)
    panel.imshow(mean_img)
    panel.axis('off')
    panel.set_title(os.path.basename(t))
    for j, r in enumerate(expt_dat[t]['rois']):
        r = np.asarray(r)
        # napari vertices end in (row, col); matplotlib wants (x=col, y=row).
        # Taking the last two columns works for 2-D as well as N-D (e.g. 4-D) shapes.
        rows, cols = r[:, -2], r[:, -1]
        panel.add_patch(Polygon(np.column_stack([cols, rows]), closed=True, fill=False,
                                edgecolor=(1, 1, 1, 0.5)))
        # Dark outline keeps the white labels readable on bright regions
        panel.text(cols.mean(), rows.mean(), str(j + 1), ha='center', va='center',
                   fontsize=5, color='w',
                   path_effects=[PathEffects.withStroke(linewidth=1, foreground='k')])

In [None]:
# Vertex-array shape of every ROI, then the vertices of the first ROI
for roi_vertices in rois:
    print(roi_vertices.shape)
first_roi = rois[0]
print(first_roi)

In [None]:
# Shape of the mask generated for the first ROI
first_mask = all_masks[0]
print(first_mask.shape)

In [None]:
# Shapes of the channel-sliced stack and of its frame-averaged image
for arr in (stack, mean_stack):
    print(arr.shape)

In [None]:
# DF/F of the last processed trial (NumPy summarizes large arrays)
print(f"{DF_G}")

In [None]:
# Stack shape: as stored, with the channel axis indexed away, and with all singleton axes dropped
shapes_to_show = (stack.shape, stack[:, :, 0].shape, np.squeeze(stack).shape)
for shape in shapes_to_show:
    print(shape)

In [None]:
# Shape of the stacked ROI masks: [rois, ...mask dims]
mask_stack_shape = all_masks.shape
print(mask_stack_shape)

In [None]:
# all_masks is a NumPy array of per-ROI masks (one per drawn shape), not a dict or
# record array, so indexing it with 'mask' raised IndexError. Summarize each mask
# instead of dumping it pixel by pixel.
print(all_masks.shape)
for idx, mask in enumerate(all_masks):
    print(f"ROI {idx + 1}: mask shape {mask.shape}, {int(np.count_nonzero(mask))} pixels inside")

In [None]:
# Raw summed ROI fluorescence of the last processed trial, axes [frames, rois]
print(f"{rawF_G}")