# Import Statements


In [None]:
###############REQUIRED FOR PICASSO-LITE##################
import numpy as np
import scipy.optimize as so
from sklearn.preprocessing import KBinsDiscretizer
import random 
import copy
##########################################################
import os
import matplotlib.pyplot as plt
from PIL import Image
os.chdir('C:/Users/Chris')
#import MIunmixer2 as mi

In [None]:
# Report the current working directory (the import cell changes it with os.chdir).
import os

cwd = os.getcwd()
print(cwd)

## Some useful function definitions
### Load n-channel TIFF stacks.
#### For very large files (e.g. stitched images), raise PIL's maximum allowed image size (`Image.MAX_IMAGE_PIXELS`) before loading.


In [None]:
def load_nchan_image(filepath):
    """Load every frame (channel) of a multi-frame image file into numpy arrays.

    Parameters
    ----------
    filepath : str or path-like
        Path to the image (e.g. a multi-channel TIFF stack).

    Returns
    -------
    (num_chans, imchans_list)
        ``num_chans`` is the frame count reported by PIL (1 for single-frame
        images) and ``imchans_list`` holds one numpy array per frame, in frame
        order. If PIL fails to open or decode the file, a warning is printed
        and ``([], [])`` is returned instead (deliberate best-effort).

    Raises
    ------
    FileNotFoundError
        If ``filepath`` does not exist.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"The file {filepath} does not exist")

    imchans_list = []
    try:
        # The context manager closes the underlying file handle when done (the
        # previous version left it open). np.array() copies the pixels, so the
        # arrays remain valid after the file is closed.
        with Image.open(filepath) as im:
            num_chans = getattr(im, "n_frames", 1)
            im.load()  # decode the pixel data of the current (first) frame
            print(f"This is a {num_chans} - channel file.")
            imchans_list.append(np.array(im))

            for i in range(1, num_chans):
                im.seek(i)  # move to frame/channel i
                im.load()  # decode that frame's pixel data
                imchans_list.append(np.array(im))
    except Exception as e:
        # Best-effort loader: report the problem instead of crashing the notebook.
        print(f"Warning: Failed to load the image from {filepath}. Error {str(e)}")
        return [], []

    return num_chans, imchans_list
    

### A function to create montages from TIFF stacks loaded as numpy arrays

In [None]:
def make_montage(im_stack, titles = None, outpath = None, numcols = 5, numrows = 1, lower_pct = 1, upper_pct = 99, show_montage = False):
    """Draw each image of a stack as a grayscale panel in a grid (montage).

    Parameters
    ----------
    im_stack : numpy.ndarray
        Stack indexed along axis 0; ``im_stack[idx]`` is drawn as panel ``idx``.
    titles : sequence of str, optional
        Panel titles. When None or empty, the panel index is used as the title.
    outpath : str, optional
        If truthy, the figure is saved there (tight bbox, 300 dpi).
    numcols : int
        Maximum number of panels per row.
    numrows : int
        Unused: the row count is derived from the number of images and
        ``numcols``. Kept only for backward compatibility.
    lower_pct, upper_pct : float
        Per-panel percentiles (NaNs ignored) mapped to the bottom/top of the
        gray colormap.
    show_montage : bool
        If False, the figure is closed after (optionally) saving it.

    Raises
    ------
    ValueError
        If ``im_stack`` contains no images.
    """
    n_color = im_stack.shape[0]
    if n_color == 0:
        raise ValueError("im_stack must contain at least one image")

    col = int(min(numcols, n_color))
    num_rows = int(np.ceil(n_color / col))  # enough rows for every image

    # squeeze=False always returns a 2-D array of Axes, so .ravel() also works
    # for a single image (plt.subplots(1, 1) would return a bare Axes).
    ff, aa = plt.subplots(num_rows, col, figsize=(3 * col, 3 * num_rows), squeeze=False)

    has_titles = titles is not None and len(titles) > 0

    for idx, ax in enumerate(aa.ravel()):
        ax.axis('off')
        if idx >= n_color:
            continue  # leftover grid cell with no image

        single_slice = np.copy(im_stack[idx, :])
        im_max = np.nanpercentile(single_slice, upper_pct)
        im_min = np.nanpercentile(single_slice, lower_pct)

        # NOTE(review): interpolation=None defers to rcParams['image.interpolation'];
        # pass 'none' if un-resampled pixels were intended.
        ax.imshow(single_slice, vmin=im_min, vmax=im_max, cmap='gray', interpolation=None)
        ax.set_title(titles[idx] if has_titles else str(idx), fontsize=20)

    plt.subplots_adjust(wspace=0.001)

    if outpath:
        plt.savefig(outpath, bbox_inches='tight', dpi=300)

    if not show_montage:
        plt.close()

### The following cell defines functions for making RGB composite images. 

In [None]:
# Named RGB weights that colorize() can tint a grayscale image with.
# Extend this mapping to make more colors available.
color_dict = {
    'blue': [0, 0, 1],
    'green': [0, 1, 0],
    'red': [1, 0, 0],
    'magenta': [1, 0, 1],
    'cyan': [0, 1, 1],
    'yellow': [1, 1, 0],
    'gray': [1, 1, 1],
}


def colorize(im, color):
    """Tint a single-channel image with one of the colors in ``color_dict``.

    ``im`` is a 2-D grayscale array that should already be normalized and
    clipped (e.g. to [0, 1]). Returns a float array of shape (3, rows, cols)
    whose plane ``k`` is ``im`` multiplied by the k-th RGB weight of ``color``.
    An unknown color name raises KeyError.
    """
    rows, cols = im.shape[0], im.shape[1]
    rgb_planes = np.zeros((3, rows, cols))

    weights = color_dict[color]
    for plane in range(3):
        rgb_planes[plane] = im * float(weights[plane])

    return rgb_planes


def normalize_and_clip(im, pct_lo, pct_hi, ignore_sat_pixels = True, sat_threshold = 65000):
    """Percentile-normalize an image to [0, 1] and clip values outside that range.

    The value at percentile ``pct_lo`` maps to 0 and the value at ``pct_hi``
    maps to 1 (NaNs are ignored when computing the percentiles). Results below
    0 or above 1 are clipped; NaN pixels stay NaN.

    Parameters
    ----------
    im : numpy.ndarray
        Image (any shape; percentiles are taken over all elements).
    pct_lo, pct_hi : float
        Percentiles in the range [0, 100].
    ignore_sat_pixels : bool
        NOTE: the name is inverted relative to the behavior. With
        ``ignore_sat_pixels=False``, pixels ``>= sat_threshold`` are EXCLUDED
        from the percentile computation. With True, every pixel is used. The
        behavior is kept as-is because existing callers (save_composite passes
        False for raw channels) depend on it.
    sat_threshold : float
        Value treated as saturated when ``ignore_sat_pixels`` is False
        (default 65000, the previously hard-coded value).

    Returns
    -------
    numpy.ndarray
        Normalized, clipped copy of ``im``. If the two percentiles coincide
        (e.g. a constant image), the scale is undefined. Pixels above that value
        become 1 and all others 0, instead of the NaNs a 0/0 division produces.
    """
    im = np.asarray(im)

    if ignore_sat_pixels:
        sample = im
    else:
        sample = im[im < sat_threshold]  # boolean indexing already returns a copy

    lo = np.nanpercentile(sample, pct_lo)
    hi = np.nanpercentile(sample, pct_hi)

    if hi == lo:
        # Degenerate range: avoid (im - lo) / 0.
        return np.where(im > lo, 1.0, 0.0)

    newim = (im - lo) / (hi - lo)
    return np.clip(newim, 0, 1)
    
    
def save_composite(im_stack, outpath, channels, colors, pct_lo = 5, pct_hi = 99.5):
    """Save an RGB composite image built from selected channels of a stack.

    Each selected channel is percentile-normalized with normalize_and_clip
    (called with ignore_sat_pixels=False, which applies the saturation cutoff),
    tinted with its color via colorize, and the tinted layers are summed. The
    sum is normalized again over all three RGB planes and written with
    ``plt.imsave``.

    Parameters
    ----------
    im_stack : numpy.ndarray
        Image stack of shape (n_channels, rows, cols).
    outpath : str
        Destination file path.
    channels : sequence of int
        Indices of the channels to include.
    colors : sequence of str
        ``color_dict`` keys; ``colors[k]`` tints ``channels[k]``.
    pct_lo, pct_hi : float
        Percentiles for both normalization passes (defaults match the
        previously hard-coded 5 and 99.5).

    Raises
    ------
    ValueError
        If fewer colors than channels are given.
    """
    subset = im_stack[channels, :, :]
    n_layers, rows, cols = subset.shape
    if len(colors) < n_layers:
        raise ValueError(
            f"need one color per channel: got {len(colors)} colors for {n_layers} channels"
        )

    comp_image = np.zeros((3, rows, cols))
    for layer, color in zip(subset, colors):
        normalized = normalize_and_clip(layer, pct_lo, pct_hi, ignore_sat_pixels = False)
        comp_image += colorize(normalized, color)

    rgb_composite = normalize_and_clip(comp_image, pct_lo, pct_hi, ignore_sat_pixels = True)
    # plt.imsave expects channels last: (rows, cols, 3).
    rgb_composite = np.transpose(rgb_composite, (1, 2, 0))

    plt.imsave(outpath, rgb_composite)
    



### Load the 3-channel fluorescence microscopy image provided with the PICASSO paper

In [None]:
# Load the 3-channel laser-scanning confocal TIFF provided by the PICASSO method authors.
filename = '3color_data.tif'
filepath = os.path.join("D:/images/", filename)
num_chans, test = load_nchan_image(filepath)
test = np.asarray(test)  # channels stacked along axis 0
if test.size == 0:
    # load_nchan_image returns empty lists (after printing a warning) when the
    # file cannot be read; stop here instead of failing later in make_montage.
    raise RuntimeError(f"No image data could be loaded from {filepath}")

make_montage(test, titles = ['Channel 1','Channel 2','Channel 3'],
             outpath = "D:/images/raw_montage.png", numcols = 3, numrows = 1,lower_pct = 5, upper_pct = 99.5, show_montage = True)
orig_shape = test.shape

colors = ['blue','green','red']  # tints for channels 0, 1, 2 (keys of color_dict)
# Output name fixed from the misspelled "raw_comoposite.png".
save_composite(test,"D:/images/raw_composite.png",np.arange(3),colors)

### Investigate how the mutual information between images depends on alpha, the coefficient in the difference image $X_i - \alpha X_j$


In [None]:

def mutual_info(im1, im2,qN):
    """Estimate the mutual information (in nats) between two 1-D samples.

    The joint distribution is approximated by a ``qN`` x ``qN`` 2-D histogram.
    Only bins with non-zero joint probability contribute to the sum
    ``sum p(x, y) * log(p(x, y) / (p(x) p(y)))``.
    """
    counts, _, _ = np.histogram2d(im1, im2, bins=qN)

    joint = counts / float(counts.sum())
    marginal_x = joint.sum(axis=1)
    marginal_y = joint.sum(axis=0)
    independent = np.outer(marginal_x, marginal_y)

    occupied = joint > 0
    p_xy = joint[occupied]
    return np.sum(p_xy * np.log(p_xy / independent[occupied]))

In [None]:
#make plots
alphas = np.linspace(0, 1, num=21, endpoint=True)


pair1 = (test[0,:],test[1,:])
pair2 = (test[0,:],test[2,:])
pair3 = (test[1,:],test[2,:])
pair4 = (test[1,:],test[0,:])
pair5 = (test[2,:],test[0,:])
pair6 = (test[2,:],test[1,:])

pairlist = [pair1,pair2,pair3,pair4,pair5,pair6]
difflist = []
negimlist = []
MIlist = []
neglist = []
diffMIlist = []
for alpha in alphas: 
    print(alpha)
    #compare the values of the mutual information for Imagej1 with Image_i - alpha*Image_j as a function of alpha
    
    sub_difflist = []#list of images
    sub_neg = []#list of negative images
    sub_MIlist = []#list of MI values
    sub_neglist = []#list of negative pixel fractions
    sub_diffMI = [] #list of MI values between Xj and the negative image of Xi - alpha*Xj
    for count, pair in enumerate(pairlist):
        #print(alpha, count)
        #form the difference image
        diff = pair[0] - alpha*pair[1]
        imshape = diff.shape
        
        #form the "negative image" caused by over subtraction
        neg_img = np.zeros(imshape)
        neg_img[diff<0] = np.abs(diff[diff<0])#the "negative image" is the absolute value of the negative pixel values. 
        rand_img = np.random.rand(imshape[0],imshape[1])
        neg_img[diff > 0] = 0.01*rand_img[diff > 0] #add random noise to the 'zero' area of the negative image to prevent a false increase in MI
        neg_img = np.abs(neg_img)
        sub_neg.append(neg_img)
        #calculate the negative fraction of pixels
        frac_neg = np.sum(diff < 0)/(imshape[0]*imshape[1])
        sub_neglist.append(frac_neg)



        #set negative pixels to zero in the corrected difference image. 
        diff[diff<0] = 0
        sub_difflist.append(diff)


        #calculate the mutual information between image j and the difference image
        mutinf = mutual_info(pair[1].flatten(),diff.flatten(), qN = 100)
        sub_MIlist.append(mutinf)
        #print(mutinf)        
        #calculate the mutual information between image j and the negative difference image
        diff_MI = mutual_info(pair[1].flatten(), neg_img.flatten(), qN = 100)
        sub_diffMI.append(diff_MI)

        # fig, axes = plt.subplots(4, 1, figsize = (6,12))
        # ax1, ax2, ax3, ax4 = axes.flatten()
        # im_max = np.nanpercentile(pair[0],99.5)
        # im_min = np.nanpercentile(pair[0],5)
        # ax1.imshow(pair[0],vmin = im_min, vmax = im_max,cmap = 'gray',interpolation=None)
        # im_max = np.nanpercentile(pair[1],99.5)
        # im_min = np.nanpercentile(pair[1],5)
        # ax2.imshow(pair[1], vmin = im_min, vmax = im_max,cmap = 'gray',interpolation=None)
        # im_max = np.nanpercentile(diff,99.5)
        # im_min = np.nanpercentile(diff,5)
        # ax3.imshow(diff, vmin = im_min, vmax = im_max,cmap = 'gray',interpolation=None)
        # im_max = np.nanpercentile(neg_img,99.5)
        # im_min = np.nanpercentile(neg_img,5)
        # ax4.imshow(neg_img, vmin = im_min, vmax = im_max,cmap = 'gray',interpolation=None)
        # ax1.set_title(r'Image $X_i$')
        # ax2.set_title(r'Image $X_j$')
        # ax3.set_title(r'$(X_i - \alpha X_j)$')
        # ax4.set_title(r'Negative Image from $(X_i - \alpha X_j)$')
        # # remove the x and y ticks
        # for ax in axes:
        #     ax.set_xticks([])
        #     ax.set_yticks([])
        # fig.tight_layout()
        # #plt.show()
        # plt.savefig(str(alpha) + 'pair_' + str(count) + '.png')
    
    negimlist.append(sub_difflist)
    difflist.append(sub_difflist)
    MIlist.append(sub_MIlist)
    neglist.append(sub_neglist)
    diffMIlist.append(sub_diffMI)

In [None]:
print(np.shape(neglist))  # expected (number of alphas, number of channel pairs)

In [None]:
from matplotlib.lines import Line2D

SMALL_SIZE = 20
MEDIUM_SIZE = 25
BIGGER_SIZE = 40

# rcParams are read when artists are created, so set them before building the
# figure (they used to be set after plt.subplots). These settings are global
# and also affect later plots.
plt.rc('font', size=SMALL_SIZE)          # default text size
plt.rc('axes', titlesize=SMALL_SIZE)     # axes titles
plt.rc('axes', labelsize=SMALL_SIZE)     # x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # x tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # y tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend
plt.rc('figure', titlesize=BIGGER_SIZE)  # figure title

fig, axes = plt.subplots(nrows=4, ncols=1, sharex=True, sharey=False, figsize=(9,12))
ax1, ax2, ax3, ax4 = axes.flatten()

# Rows: alpha index; columns: one per channel pair.
MI_arr = np.array(MIlist)
diffMI_arr = np.array(diffMIlist)
neg_arr = np.array(neglist)
loss = diffMI_arr + MI_arr + neg_arr  # computed once (was computed twice)

pair_lines = ax1.plot(MI_arr)
ax2.plot(diffMI_arr)
ax3.plot(neg_arr)
ax4.plot(loss)

min_losses = np.min(loss, axis=0)
print(min_losses)
min_args = np.argmin(loss, axis=0)  # alpha index of the minimum loss for each pair
print(min_args)
ax4.plot(min_args, min_losses, 'kx')

# Legend handles: thick lines in the same colors as the plotted pairs.
custom_lines = [Line2D([0], [0], color=line.get_color(), lw=4) for line in pair_lines]

# The x axis is the alpha index; label selected ticks with their alpha values
# (the axes share x, so the ticks apply to every panel).
tick_idx = [0, 5, 10, 15, 20]

ax1.set_title(r'$I(X_j , X_i - \alpha X_j)$ vs. $\alpha$')
ax1.set_ylabel(r'$I(X_j , X_i - \alpha X_j)$')
ax1.set_ylim([0, 1.1])
ax1.set_xticks(tick_idx, [str(alphas[k]) for k in tick_idx], color='red', fontsize=12)

ax2.set_title(r'$I(X_j, (X_i - \alpha X_j)_{Neg.Img.})$ vs. $\alpha$')
ax2.set_ylabel(r'$I(X_j, (X_i - \alpha X_j)_{Neg.Img.})$')  # fixed unbalanced parentheses
ax2.set_ylim([0, 1.1])

ax3.set_title(r'Negative Pixel Fraction in $(X_i - \alpha X_j)$ vs. $\alpha$')
ax3.set_ylabel(r'Fraction')
ax3.set_ylim([0, 1.1])

ax4.set_title(r'$L(I,I_{neg. Im.},frac_{neg.})$ vs. $\alpha$')
ax4.set_ylabel(r'Loss')
ax4.set_ylim([0, 1.1])
fig.supxlabel(r'Alpha ($\alpha$)')

fig.tight_layout()
fig.legend(custom_lines, ['${X_0, X_1}$', '${X_0, X_2}$', '${X_1, X_2}$', '${X_1, X_0}$',
                          '${X_2, X_0}$', '${X_2, X_1}$'],loc='upper right', bbox_to_anchor=(1.22, 1.0),title=r"$X_i, X_j$ Pairs")
plt.savefig('plotsoflossprops'+'.png')
plt.show()
# Pandas DataFrames would make these plots simpler to build.

### Unmix the 3 channel dataset

In [None]:
# Flatten the image stack to (channels, pixels) as float and unmix it with PICASSO-Lite.
# NOTE(review): MI_unmixer is defined in a later cell of this notebook
# ("PICASSO-LITE Class Definition"); run that cell first or this raises NameError.
testing = np.reshape(test,(orig_shape[0],-1)).astype(float) #num_chans rows by m times n total pixels (columns)
my_unmixer = MI_unmixer(n_color=3, maxIter = 2, learn_rate = 1, qQ = 300,qN = 100)

my_unmixer.mi_fit(testing) #fit on all pixels (no subsampling is applied here)
unmixed = my_unmixer.mi_transform(np.copy(testing)) #transform a copy so testing itself is not modified by mi_transform
unmixed = np.reshape(unmixed, orig_shape) #back to (channels, rows, cols)


In [None]:
# Region of interest: rows 200-699 and columns 400-899 of every channel.
cropped = unmixed[:, 200:700, 400:900]

# Grayscale montages: full field first, then the crop.
make_montage(
    unmixed,
    titles=['Channel 1', 'Channel 2', 'Channel 3'],
    outpath="D:/images/unmixed_montage.png",
    numcols=3,
    numrows=1,
    lower_pct=5,
    upper_pct=99.5,
    show_montage=True,
)
make_montage(
    cropped,
    titles=['Channel 1', 'Channel 2', 'Channel 3'],
    outpath="D:/images/cropped_unmixed_montage.png",
    numcols=3,
    numrows=1,
    lower_pct=5,
    upper_pct=99.5,
    show_montage=True,
)

# RGB composites with channels 0/1/2 tinted blue/green/red.
colors = ['blue', 'green', 'red']
save_composite(unmixed, "D:/images/unmixed_composite.png", np.arange(3), colors)
save_composite(cropped, "D:/images/cropped_unmixed_composite.png", np.arange(3), colors)

In [None]:
# The per-iteration unmixing matrices are stored in the ops_list attribute.
# The alpha values do not converge to zero: once the major structures have been
# removed from the other channels, the algorithm keeps fitting small random
# noise differences. If a data-driven stopping criterion is wanted, something
# like the gradient of the SSIM could be used. Here the goal is to finish in two
# iterations, so two iterations is the stopping criterion.
print(my_unmixer.ops_list)

In [None]:
print(np.shape(test))  # (channels, rows, cols) of the raw stack

In [None]:
# Same region of interest as the unmixed crop, taken from the raw (mixed) stack.
cropped_test = test[:, 200:700, 400:900]
make_montage(
    cropped_test,
    titles=['Channel 1', 'Channel 2', 'Channel 3'],
    outpath="D:/images/cropped_raw_montage.png",
    numcols=3,
    numrows=1,
    lower_pct=5,
    upper_pct=99.5,
    show_montage=True,
)
save_composite(cropped_test, "D:/images/cropped_raw_composite.png", np.arange(3), colors)

# PICASSO-LITE Class Definition

In [None]:




class MI_unmixer:
    """PICASSO-Lite: mutual-information-based blind unmixing of multi-channel images.

    For every ordered channel pair (i, j), a coefficient alpha in [0, 1] is
    found by minimizing ``objective``:

        I(X_j, X_i - alpha X_j) + I(X_j, negative image of X_i - alpha X_j)
        + fraction of pixels where X_i - alpha X_j < 0

    on a quantized copy of the max-normalized data. All pairwise corrections of
    one iteration are applied together as the matrix P (1 on the diagonal,
    ``-learn_rate * alpha`` off it), each channel is rescaled back to its own
    maximum, and negative results are clipped to 0. This is repeated
    ``maxIter`` times. The matrices are stored in ``ops_list`` and replayed by
    ``mi_transform``.

    fit/transform take data shaped (n_color, n_pixels), i.e. a
    (channels, rows, cols) cube reshaped to 2-D (see ``mi_datashaper``).
    """

    def __init__(self, n_color=0, maxIter = 2, learn_rate = 1, qQ = 300,qN = 100):
        self.n_color = n_color  # number of channels (overwritten by mi_fit)
        self.maxIter = maxIter  # number of unmixing iterations
        self.learn_rate = learn_rate  # scales every alpha before it is applied
        self.qQ = qQ  # number of quantization bins shared by the whole stack
        self.qN = qN  # bins per axis of the 2-D histograms used for MI
        self.ops_list = []  # per-iteration unmixing matrices, filled by mi_fit
        self.encoder = KBinsDiscretizer(n_bins=qQ, encode="ordinal", strategy="uniform", random_state=42)  # quantizer used by quantize()

    def objective(self, x, XY):
        """Loss minimized over the unmixing coefficient ``x`` for one channel pair.

        ``XY`` is a (2, n_pixels) array: ``XY[0]`` is X_i, ``XY[1]`` is X_j.
        Returns I(X_j, X_i - x X_j) + I(X_j, negative image) + the fraction of
        pixels where X_i - x X_j < 0.
        """
        X_i = XY[0, :]
        X_j = XY[1, :]
        total_pix = len(X_j)

        # Fixed-seed noise keeps the objective deterministic across evaluations.
        # A private RandomState(42) draws exactly the values that
        # np.random.seed(42) + np.random.rand used to, without resetting the
        # global NumPy RNG on every call.
        rand_img = np.random.RandomState(42).rand(total_pix)

        scaled_diff = X_i - x * X_j
        frac_neg = np.sum(scaled_diff < 0) / total_pix  # penalize over-subtraction

        # Negative image: |scaled_diff| where negative, faint noise where positive
        # (so the empty area does not inflate the MI), 0 where exactly zero.
        neg_img = np.copy(scaled_diff)
        positive = neg_img > 0
        neg_img[positive] = 0.01 * rand_img[positive]
        neg_img = np.abs(neg_img)

        diff_mi = self.mutual_info(X_j, scaled_diff)
        neg_mi = self.mutual_info(X_j, neg_img)

        return diff_mi + neg_mi + frac_neg

    def mutual_info(self, im1, im2):
        """Mutual information (nats) of two 1-D samples from a qN x qN 2-D histogram."""
        hist_2d, _, _ = np.histogram2d(im1, im2, bins=self.qN)

        # Convert bin counts to probabilities.
        pxy = hist_2d / float(np.sum(hist_2d))
        px = np.sum(pxy, axis=1)  # marginal for x over y
        py = np.sum(pxy, axis=0)  # marginal for y over x
        px_py = px[:, None] * py[None, :]  # product of marginals

        nzs = pxy > 0  # only non-zero joint probabilities contribute
        return np.sum(pxy[nzs] * np.log(pxy[nzs] / px_py[nzs]))

    def quantize(self, img):
        """Quantize ``img`` into ``qQ`` uniform bins, scaled to [0, 1].

        The whole array is binned with one shared set of edges (not per channel),
        preserving within-channel spatial and inter-channel intensity
        relationships. ``self.encoder`` is refit on every call.
        """
        codes = self.encoder.fit_transform(img.reshape((-1, 1)))
        return codes.reshape(img.shape) / (self.qQ - 1)

    def mi_datashaper(self,img_stack):
        """Reshape a (channels, rows, cols) cube into (channels, rows*cols)."""
        reshaped = np.reshape(img_stack,(img_stack.shape[0],-1))

        return reshaped

    def subsample_img(self, im_stack,subsamp_pct = 1.0, mask=None):
        """Randomly subsample pixels (columns) of a (n_color, n_pixels) stack.

        Parameters
        ----------
        im_stack : numpy.ndarray, shape (n_color, n_pixels)
        subsamp_pct : float
            FRACTION (not percent) of the candidate pixels to keep. Values
            >= 1.0 return ``im_stack`` unchanged (the mask is then not applied).
        mask : array-like, shape (n_pixels,), optional
            Flattened mask. Pixels with ``mask > 0`` are EXCLUDED and sampling
            is done over the remaining pixels. The caller's mask is not
            modified (the previous version overwrote its positive entries with 1).

        Returns
        -------
        numpy.ndarray (float), shape (n_color, round(subsamp_pct * n_candidates))
            Pixels drawn without replacement using the global NumPy RNG.
        """
        if not subsamp_pct < 1.0:
            return im_stack

        candidates = np.asarray(im_stack)
        if mask is not None:
            keep = np.flatnonzero(~(np.ravel(mask) > 0))
            candidates = candidates[:, keep]  # stays 2-D even for 1 channel or 1 pixel

        n_color, npix = candidates.shape
        numpixels = int(np.round(subsamp_pct * npix))
        ran_indices = np.random.choice(npix, numpixels, replace=False)

        return candidates[:, ran_indices].astype(float)

    def _normalize_rows(self, X):
        """Divide each row (channel) of float array ``X`` in place by its max; clip negatives to 0.

        Returns ``beta_mat`` (n_color x n_color) whose row c is filled with
        channel c's scale, used to restore each channel's intensity scale after
        unmixing. A channel whose max is not positive (all zero/negative/NaN)
        uses a scale of 1 instead of dividing by zero (which produced NaNs).
        """
        n_color = X.shape[0]
        beta_mat = np.zeros((n_color, n_color))
        for color in range(n_color):
            Xmax = np.nanmax(X[color, :])
            if not Xmax > 0:
                Xmax = 1.0
            beta_mat[color, :] = Xmax

            row = X[color, :] / Xmax
            row[row < 0] = 0
            X[color, :] = row
        return beta_mat

    def mi_fit(self, im_stack):
        """Learn ``maxIter`` unmixing matrices from a (n_color, n_pixels) stack.

        Works on a float64 copy, so the caller's array is left untouched (the
        previous version normalized it in place, truncating integer input).
        Stores the matrices in ``self.ops_list`` with shape
        (maxIter, n_color, n_color), sets ``self.n_color`` and returns ``self``.
        """
        n_color, _ = np.shape(im_stack)
        self.n_color = n_color
        learn_rate = self.learn_rate

        X = np.array(im_stack, dtype=float)  # private working copy
        bound = [(0, 1)]  # alpha is constrained to [0, 1]

        # Initial alphas in [0.1, 0.2). RandomState(42) reproduces the values of
        # np.random.seed(42) + rand without touching the global RNG.
        alpha_mat = 0.1 * np.random.RandomState(42).rand(n_color, n_color) + 0.1
        np.fill_diagonal(alpha_mat, 1)  # diagonal entries are fixed at 1

        P = np.identity(n_color)  # unmixing matrix: 1 on the diagonal, -learn_rate*alpha off it
        unmix_ops_list = np.zeros((self.maxIter, n_color, n_color))  # P for every iteration

        for num_iter in range(1, self.maxIter + 1):
            print('Iteration #:', num_iter)

            beta_mat = self._normalize_rows(X)
            Xq = self.quantize(X)

            for ch in range(n_color):
                for dy in range(n_color):
                    if ch == dy:
                        continue  # diagonal stays 1

                    X_Y = np.vstack((Xq[ch, :], Xq[dy, :]))  # row 0: X_i, row 1: X_j
                    results = so.minimize(self.objective, x0=alpha_mat[ch, dy], args=(X_Y,), bounds=bound,
                                          method='Powell', options={'xtol': 0.001, 'ftol': 0.001, 'maxiter': 5000})

                    # results.x has shape (1,); take the scalar explicitly (assigning a
                    # size-1 array to a single element is deprecated since NumPy 1.25).
                    alpha_mat[ch, dy] = np.ravel(results.x)[0]
                    P[ch, dy] = -learn_rate * alpha_mat[ch, dy]

            unmix_ops_list[num_iter - 1, :, :] = P

            # Unmix in normalized units, then rescale each row to its channel's max.
            X = np.matmul(np.multiply(beta_mat, P), X)
            X[X < 0] = 0  # negative intensities are not physical

        self.ops_list = unmix_ops_list

        return self

    def mi_transform(self,im_stack):
        """Apply the learned unmixing matrices to a (n_color, n_pixels) stack.

        Each stored matrix is replayed as in mi_fit: channels are
        max-normalized (scales taken from the data being transformed), unmixed,
        rescaled and clipped at zero. Works on a float64 copy, so ``im_stack``
        is not modified. If the model has not been fitted, prints an error and
        returns an empty list.
        """
        if len(self.ops_list) == 0:
            print('error: model not trained on data')
            return []

        X = np.array(im_stack, dtype=float)
        for P in self.ops_list:
            beta_mat = self._normalize_rows(X)
            X = np.matmul(np.multiply(beta_mat, P), X)
            X[X < 0] = 0

        return X
        

### A fast, non-negatively constrained linear least-squares fitting algorithm, for comparison with PICASSO-Lite when the individual reference spectra are available:

In [None]:
def fast_linearunmix(design_mat,data):
    """Non-negative linear unmixing by brute-force refitting of negative coefficients.

    A fast, heuristic alternative to NNLS. First an unconstrained least-squares
    fit is done for all pixels. Then, from the largest to the smallest subsets
    of components, pixels whose coefficients in a subset are all negative have
    those coefficients set to 0 and are refit with the remaining components.
    Any negatives left at the end are set to 0.

    Parameters
    ----------
    design_mat : numpy.ndarray, shape (m_channels, n_components)
        Columns are the (min-max normalized) reference spectra. Column 0 is
        treated as the background (autofluorescence) component and is never
        removed during refitting. A constant all-ones column is appended
        internally as an extra offset component.
    data : numpy.ndarray, shape (m_channels, n_pixels)

    Returns
    -------
    numpy.ndarray, shape (n_components + 1, n_pixels)
        Non-negative coefficients. The last row belongs to the appended offset column.
    """
    import itertools  # the original referenced an undefined name `iterator`

    a = np.append(design_mat, np.ones((design_mat.shape[0], 1)), axis=1)

    orig_coeffs, _, _, _ = np.linalg.lstsq(a, data, rcond=None)
    coeffs = np.array(orig_coeffs)

    num_dyes = a.shape[1]  # number of components, including the offset column
    column_listing = np.arange(num_dyes)

    # Marks coefficients that were already zeroed for being negative. Must start
    # at 0: np.empty left arbitrary memory here, so stale values could equal 1.
    already_analyzed = np.zeros(coeffs.shape)

    # Largest candidate subsets first; column 0 (background) is never removed.
    for i in range(num_dyes - 1, 0, -1):
        for combo in itertools.combinations(range(1, num_dyes), i):
            working_on = np.asarray(combo)
            not_workingon = np.setdiff1d(column_listing, working_on)

            totalbad = np.sum(coeffs < 0, axis=0)  # negative coefficients per pixel
            sumofbad = np.sum(coeffs[working_on, :] < 0, axis=0)  # negatives within this subset

            # Pixels whose coefficients in this subset are all negative.
            pixels_selection = (sumofbad == i) & (totalbad > 0)
            if not pixels_selection.any():
                continue

            # Refit the selected pixels with the remaining components only. When
            # i == num_dyes - 1 this is the background-only model a[:, [0]].
            # data stays 2-D even for a single pixel. The former np.squeeze made
            # the fit 1-D, and that branch wrote only the first coefficient into
            # every remaining component.
            aprime = a[:, not_workingon]
            pixels_toanalyze = data[:, pixels_selection]
            temp_coeffs, _, _, _ = np.linalg.lstsq(aprime, pixels_toanalyze, rcond=None)

            # [:, None] broadcasts the component indices against the pixel selection.
            coeffs[working_on[:, None], pixels_selection] = 0
            temp_coeffs[already_analyzed[not_workingon[:, None], pixels_selection] == 1] = 0

            # NOTE(review): refit results are only written back when i <= num_dyes / 2;
            # for larger subsets (including the background-only refit) they are
            # discarded and the remaining coefficients keep their previous values.
            # Confirm this gating is intended.
            if i <= num_dyes / 2:
                temp_coeffs[temp_coeffs < 0] = 0
                coeffs[not_workingon[:, None], pixels_selection] = temp_coeffs

            # Remember which pixel coefficients were zeroed for being negative.
            already_analyzed[working_on[:, None], pixels_selection] = 1

    # Final cleanup of any remaining negatives.
    coeffs[coeffs < 0] = 0

    return coeffs