# In [None]:
import numpy as np


# Placeholder inputs for the gain solver in this cell; replace with real data.
# Antenna numbers appear to follow the 1-based convention of the MATLAB code
# this cell was ported from (ref_ant = 1 names the first antenna).
model = 1  # model visibilities: None -> all ones, scalar -> broadcast to raw_data
vis_noise = None  # visibility noise: None -> all ones, scalar -> broadcast to raw_data
# ndarrays rather than lists: the code below uses .shape and element-wise
# arithmetic/comparisons, which plain Python lists do not support.
raw_data = np.asarray([1, 2, 3, 4])
tolSpec = 1e-4  # stop iterating once the RMS fractional gain change drops below this
minCycle = 64  # the iteration cap is never set below this many cycles
ref_ant = 1  # reference antenna
base_to_ants = None  # not referenced by the code below
ant1arr = np.array([], dtype=int)  # first antenna of each baseline
ant2arr = np.array([], dtype=int)  # second antenna of each baseline

def _match_data_shape(value, like, name):
    """Expand ``value`` to an array shaped like ``like`` (normally ``raw_data``).

    * ``None`` becomes ``np.ones_like(like)``.
    * A non-ndarray value (e.g. a Python scalar) or a size-1 ndarray is
      broadcast by multiplying it into ``np.ones_like(like)``.
    * Any other ndarray is returned unchanged once its shape is checked.

    Raises:
        ValueError: if ``value`` is an ndarray whose shape differs from
            that of ``like``.
    """
    if value is None:
        return np.ones_like(like)
    if not isinstance(value, np.ndarray) or value.size == 1:
        return np.ones_like(like) * value
    # np.shape also accepts plain sequences, which have no .shape attribute.
    if value.shape != np.shape(like):
        raise ValueError(f"{name} must be the same shape as raw_data.")
    return value


# Enforce the existence of a model with the same shape as the data.
model = _match_data_shape(model, raw_data, "model")
# Enforce the existence of the noise (which sets the data weights) with the
# same shape as the data.
vis_noise = _match_data_shape(vis_noise, raw_data, "vis_noise")

# Basic problem sizes for the least-squares (LS) solve.
# n_samp mirrors MATLAB's size(rawData, 2): the length of the second axis,
# or 1 when raw_data has only one axis.
n_samp = np.shape(raw_data)[1] if np.ndim(raw_data) > 1 else 1
# Number of distinct antennas appearing on either end of any baseline.
# NOTE(review): later code indexes length-n_ants gain arrays by antenna
# number, which assumes antennas are numbered 1..n_ants with no gaps — confirm.
n_ants = len(set(ant1arr) | set(ant2arr))
n_base = len(ant1arr)

# Index tables for the linearised LS problem. The fitted parameters are the
# real parts of all n_ants gains followed by their imaginary parts, minus the
# imaginary part of ref_ant (which is not fitted): 2*n_ants - 1 parameters.
ant1arr = np.asarray(ant1arr, dtype=int)
ant2arr = np.asarray(ant2arr, dtype=int)

# 1-based parameter column touched by each baseline, in four groups of n_base:
# Re(g1), Im(g1), Re(g2), Im(g2). Imaginary columns of antennas above ref_ant
# shift down by one because Im(g_ref) has no column of its own.
param_cols = np.concatenate([
    ant1arr,
    ant1arr - (ant1arr > ref_ant) + n_ants,
    ant2arr,
    ant2arr - (ant2arr > ref_ant) + n_ants,
])
# 0-based, column-major (Fortran-order) linear index into an
# (n_base, 2*n_ants - 1) template: row = baseline, column = parameter.
commonCoords = np.tile(np.arange(n_base), 4) + n_base * (param_cols - 1)

# 1-based positions in the full 2*n_ants gain vector (real parts first, then
# imaginary parts), plus the sign applied to the negCoords term.
posCoords = np.concatenate([ant2arr, ant2arr + n_ants, ant1arr, ant1arr + n_ants])
negCoords = np.concatenate([ant2arr + n_ants, ant2arr, ant1arr + n_ants, ant1arr])
signVals = np.concatenate([-np.ones(n_base), np.ones(2 * n_base), -np.ones(n_base)])

# Drop the entries belonging to Im(g_ref), which is not a fitted parameter.
common_mask = (ref_ant + n_ants) != np.concatenate(
    [ant1arr, ant1arr + n_ants, ant2arr, ant2arr + n_ants]
)
commonCoords = commonCoords[common_mask]
posCoords = posCoords[common_mask]
negCoords = negCoords[common_mask]
signVals = signVals[common_mask]

# The flattening steps from the MATLAB version are commented out here:
#   rawData = reshape(transpose(rawData),[],1);   (likewise visNoise, model)
# so the input arrays keep their original shapes at this point.

# Data weights are the inverse noise.
# NOTE(review): a zero in vis_noise gives an infinite weight, which dataMask
# below does not exclude — confirm that cannot happen upstream.
dataWeight = 1.0 / vis_noise

# tolTest is the change in the solution from cycle to cycle. Starting it at 1
# makes the solver loop run at least once whenever tolSpec <= 1.
tolTest = 1
cycle = 0
# cycleFinLim is the number of iterations to try before bailing out:
# nParams**2 with nParams = 2*n_ants - 1, but never fewer than minCycle.
# That cap gave the best dynamic results (i.e. if a solution was possible,
# Monte Carlo tests showed it always appeared within this many cycles).
# ``**`` is exponentiation; ``^`` would be bitwise XOR in Python.
cycleFinLim = max((2 * n_ants - 1) ** 2, minCycle)
# Samples used in the fit: model, noise and data must all be finite.
dataMask = np.isfinite(model) & np.isfinite(vis_noise) & np.isfinite(raw_data)

# gain_guess provides a first guess based on the baselines of the reference
# antenna. This step seems to cut convergence time in half. The solver needs
# the guess split into separate real and imaginary components.
# NOTE(review): gain_guess is not defined in this cell. It is assumed to take
# (per-baseline data/model ratios, ref_ant, ant1 numbers, ant2 numbers, n_ants)
# with 1-based antenna numbers, as in the MATLAB original, and to return
# n_ants complex gains — confirm.
guess_ant1 = np.asarray(ant1arr, dtype=int)
guess_ant2 = np.asarray(ant2arr, dtype=int)
on_ref_base = (guess_ant1 == ref_ant) | (guess_ant2 == ref_ant)
if n_samp == 1:
    # Single sample: use the data/model ratio of the usable reference baselines.
    refMask = on_ref_base & np.ravel(dataMask, order="F")
    guess_data = (
        np.ravel(raw_data, order="F")[refMask] / np.ravel(model, order="F")[refMask]
    )
else:
    # Several samples: average data and model over samples (ignoring NaNs)
    # before taking the ratio. Column-major reshape matches MATLAB's reshape.
    tempData = np.nanmean(
        np.reshape(raw_data, (n_base, n_samp), order="F"), axis=1
    ) / np.nanmean(np.reshape(model, (n_base, n_samp), order="F"), axis=1)
    refMask = on_ref_base & np.isfinite(tempData)
    guess_data = tempData[refMask]
antGuess = gain_guess(
    guess_data, ref_ant, guess_ant1[refMask], guess_ant2[refMask], n_ants
)
# Solver parameter layout: all real parts, then all imaginary parts.
antGuess = np.concatenate([np.ravel(np.real(antGuess)), np.ravel(np.imag(antGuess))])

# Iterative linearised least-squares fit of the antenna gains.
# Layout of antGuess: antGuess[:n_ants] are the real parts and
# antGuess[n_ants:] the imaginary parts of the per-antenna gains. Im(g_ref)
# is excluded from the fit (guessMask), so it keeps its initial value and
# 2*n_ants - 1 parameters are solved for in each cycle.
fit_ant1 = np.asarray(ant1arr, dtype=int) - 1  # 0-based antenna indices
fit_ant2 = np.asarray(ant2arr, dtype=int) - 1
n_params = 2 * n_ants - 1

# Flatten column-major (baseline index varies fastest) so the sample order
# matches the design matrix built by tiling the per-baseline template n_samp
# times. NOTE(review): assumes raw_data/model/vis_noise are shaped
# (n_base, n_samp) or (n_base,) — confirm.
fit_data = np.ravel(raw_data, order="F")
fit_model = np.ravel(model, order="F")
fit_weight = np.ravel(dataWeight, order="F")
fit_mask = np.ravel(dataMask, order="F")
# Gain-independent factor of the design matrix, as a column for broadcasting.
weighted_model = (fit_weight * fit_model)[:, np.newaxis]
# posCoords/negCoords hold 1-based positions in the gain vector.
pos_idx = np.asarray(posCoords, dtype=int) - 1
neg_idx = np.asarray(negCoords, dtype=int) - 1

# Exclude Im(g_ref) (0-based index n_ants + ref_ant - 1) from the fit.
guessMask = np.ones(2 * n_ants, dtype=bool)
guessMask[n_ants + ref_ant - 1] = False
delGuess = np.zeros(2 * n_ants)
# Flat, column-major buffer for the (n_base, n_params) design-matrix template;
# commonCoords are linear indices into it. Entries outside commonCoords stay 0.
totMask = np.zeros(n_base * n_params, dtype=complex)

# Complete either by hitting the cycle limit or hitting the tolerance target.
while tolTest >= tolSpec and cycle < cycleFinLim:
    # Gain correction factors, used below to form the data residuals.
    gainScale = antGuess[:n_ants] + 1j * antGuess[n_ants:]
    gainApp = gainScale[fit_ant1] * np.conj(gainScale[fit_ant2])

    cycle += 1

    # "Generic" template based on the first-order Taylor expansion of the gains.
    totMask[commonCoords] = antGuess[pos_idx] + 1j * (antGuess[neg_idx] * signVals)
    template = totMask.reshape((n_base, n_params), order="F")

    # Design matrix corrected for noise; real and imaginary parts are stacked
    # as separate real-valued equations.
    xVals = (np.tile(template, (n_samp, 1)) * weighted_model)[fit_mask]
    xVals = np.concatenate([xVals.real, xVals.imag])

    # Residuals against the current gains, corrected for noise.
    dataRes = ((fit_data - fit_model * np.tile(gainApp, n_samp)) * fit_weight)[fit_mask]
    dataRes = np.concatenate([dataRes.real, dataRes.imag])

    alphaMatrix = xVals.T @ xVals
    betaMatrix = xVals.T @ dataRes
    try:
        delGuess[guessMask] = np.linalg.solve(alphaMatrix, betaMatrix)
    except np.linalg.LinAlgError:
        # Singular normal equations. MATLAB's backslash returns non-finite
        # values rather than raising; propagate NaNs so the loop stops
        # (NaN >= tolSpec is False) and the error check flags the result.
        delGuess[guessMask] = np.nan

    # RMS fractional change in the solutions: sqrt(mean(|dg|^2 / |g|^2)).
    # Both the real and imaginary deltas are squared before summing.
    re_del, im_del = delGuess[:n_ants], delGuess[n_ants:]
    re_gain, im_gain = antGuess[:n_ants], antGuess[n_ants:]
    tolTest = np.sqrt(
        np.sum((re_del**2 + im_del**2) / (re_gain**2 + im_gain**2)) / n_ants
    )
    # Add the new deltas to the solutions, but damp them in the case of large
    # changes to prevent positive feedback loops. The damping is based on how
    # large the deltas need to be for second-order effects to create such
    # loops. It increases the number of cycles required for convergence by
    # ~50%, but greatly enhances stability.
    antGuess = antGuess + delGuess / np.sqrt(1 + tolTest)

    # Prevent a degenerate case where the phase of the refant becomes 180
    # instead of zero (i.e. a negative, real number for the gain solution).
    # Re(g_ref) sits at 0-based index ref_ant - 1.
    if antGuess[ref_ant - 1] < 0:
        antGuess = -antGuess

# Final quality check of the gain solution and optional error estimates.
# compute_errors stands in for MATLAB's `nargout >= 2`: when True, gainCovar
# and gainErrs are produced alongside gainsTable.
compute_errors = True

# Are there any NaNs, Infs or zeros in the gains solution?
errCheck = bool(
    np.any((antGuess[:n_ants] == 0) & (antGuess[n_ants:2 * n_ants] == 0))
    or not np.all(np.isfinite(antGuess))
)
# If we reached the cycle limit or the check above failed, mark the
# solutions as bad by returning zeros.
if cycle >= cycleFinLim or errCheck:
    gainsTable = np.zeros(n_ants, dtype=complex)
    if compute_errors:
        gainCovar = np.zeros((n_ants, n_ants))
        gainErrs = np.zeros(n_ants)
else:
    gainsTable = antGuess[:n_ants] + 1j * antGuess[n_ants:]
    if compute_errors:
        err_ant1 = np.asarray(ant1arr, dtype=int) - 1  # 0-based antenna indices
        err_ant2 = np.asarray(ant2arr, dtype=int) - 1
        noise_vec = np.ravel(vis_noise, order="F")
        # One row per baseline, with the baseline's noise placed in the
        # columns of its two antennas.
        # NOTE(review): the entries are vis_noise rather than 1/vis_noise, and
        # exactly one noise value per baseline is assumed (n_samp == 1;
        # otherwise the assignment below fails) — confirm intent.
        xVals = np.zeros((n_base, n_ants))
        base_rows = np.arange(n_base)
        xVals[base_rows, err_ant1] = noise_vec
        xVals[base_rows, err_ant2] = noise_vec
        covarMatrix = np.linalg.inv(xVals.T @ xVals)
        # gainApp was formed at the start of the last solver cycle, i.e. from
        # the gains before the final update.
        dataRes = (
            np.ravel(raw_data, order="F")
            - np.ravel(model, order="F") * np.tile(gainApp, n_samp)
        ) / noise_vec
        used_res = dataRes[np.ravel(dataMask, order="F")]
        # sum(|r|^2) over the used samples, per degree of freedom.
        redChiSqVal = np.sum(np.abs(used_res) ** 2) / (np.count_nonzero(dataMask) - n_ants)
        gainCovar = covarMatrix * redChiSqVal
        # NOTE(review): this is the covariance diagonal (variances), not their
        # square roots — confirm that is what callers expect.
        gainErrs = np.diag(gainCovar)
