In [7]:
"""
Conditional exact Cochran-Armitage trend test
Translated from R (https://github.com/cran/CATTexact/blob/master/R/CATTexact.R) to Python.


#' Conditional exact Cochran-Armitage trend test
#'
#' \code{catt_exact} calculates the Cochran-Armitage trend test statistic (Cochran (1954), Armitage (1955)) and the one-sided p-value for the corresponding conditional exact test.
#' The conditional exact test has been established by Williams (1988). The computation of its p-value is performed using an algorithm following an idea by Mehta, et al. (1992).
#'
#' @param dose.ratings A vector of dose ratings, the i-th entry corresponds to the dose-rating of the i-th group. This vector must be strictly monotonically increasing.
#' @param totals The vector of total individuals per group, the i-th entry corresponds to the total number of individuals in the i-th group.
#' @param cases The vector of incidences per group, the i-th entry corresponds to the number of incidences in the i-th group.
#' @return A list containing the value of the Cochran-Armitage Trend Test Statistic, its exact and asymptotic p-value.
#' @references Armitage, P. Tests for linear trends in proportions and frequencies. \emph{Biometrics}, 11 (1955): 375-386.
#' @references Cochran, W. G. Some methods for strengthening the common \eqn{\chi^2} tests, \emph{Biometrics}. 10 (1954): 417-451.
#' @references Mehta, C. R., Nitin P., and Pralay S. Exact stratified linear rank tests for ordered categorical and binary data. \emph{Journal of Computational and Graphical Statistics}, 1 (1992): 21-40.
#' @references Portier, C., and Hoel D. Type 1 error of trend tests in proportions and the design of cancer screens. \emph{Communications in Statistics-Theory and Methods}, 13 (1984): 1-14.
#' @references Williams, D. A. Tests for differences between several small proportions. \emph{Applied Statistics}, 37 (1988): 421-434.
#' @examples
#' d <- c(1,2,3,4)
#' n <- rep(20,4)
#' r <- c(1,4,3,8)
#'
#' catt_exact(d, n, r)
#'
#' @export

"""

import numpy as np
from scipy import stats

def catt_exact(dose_ratings, totals, cases):
    """Conditional exact Cochran-Armitage trend test.

    Parameters
    ----------
    dose_ratings : sequence of numbers
        Dose rating of each group; must be strictly increasing.
    totals : sequence of integers
        Number of individuals in each group (at least one per group).
    cases : sequence of integers
        Number of incidences in each group (between 0 and the group size).

    Returns
    -------
    dict
        ``"test.statistic"``: the trend statistic (computed with a leading
        minus sign, so more cases at higher doses gives a negative value),
        ``"exact.pvalue"``: the probability, conditional on the total number
        of cases, that the dose-weighted case count is at least the observed
        one, and ``"asymptotic.pvalue"``: the standard normal CDF evaluated at
        the statistic.

    Raises
    ------
    ValueError
        If the inputs differ in length, describe fewer than three groups,
        contain non-integer or out-of-range counts, contain no cases or only
        cases, or if the doses are not strictly increasing.
    """
    # Local imports keep this definition usable without relying on names that
    # are imported elsewhere in the notebook.
    from fractions import Fraction
    from scipy.special import comb

    le_d = len(dose_ratings)
    le_n = len(totals)
    le_r = len(cases)

    if le_d != le_n or le_d != le_r:
        raise ValueError("Length of input is differing!")

    le = le_d

    if le < 3:
        raise ValueError("Need at least three groups")

    # Extract input, calculate total number of cases and individuals
    nk = np.array(totals)
    nhat = np.sum(nk)

    rk = np.array(cases)
    rhat = np.sum(rk)

    dk = np.array(dose_ratings)

    # Input checks
    counts = np.concatenate([nk, rk])
    if np.any(np.round(counts) != counts):
        raise ValueError("The number of totals and cases must be integer")

    if np.any(nk <= 0):
        raise ValueError("There must be at least one individual in every dose group")

    if np.any(rk < 0):
        raise ValueError("The number of cases in each group must be nonnegative")

    if np.any(nk < rk):
        raise ValueError("The number of cases cannot exceed the size of the group")

    if np.all(rk == 0):
        raise ValueError("This test cannot be applied when there are no cases")

    if np.all(nk == rk):
        raise ValueError("This test cannot be applied when the number of cases equals the total number of individuals")

    # "not all(> 0)" (rather than "any(<= 0)") also rejects NaN doses.
    if not np.all(np.diff(dk) > 0):
        raise ValueError("Doses must be strictly monotonically increasing")

    factor = np.sqrt(nhat / ((nhat - rhat) * rhat))

    enum = np.sum((rk - (nk / nhat) * rhat) * dk)
    denom = np.sqrt(np.sum((nk / nhat) * dk**2) - np.sum((nk / nhat) * dk)**2)

    test_statistic = -factor * enum / denom

    def _exact_pvalue(doses, sizes, observed_cases):
        """Return P(sum d_i * R_i >= observed score | sum R_i fixed)."""
        # Rational scores make the tie comparison with the observed score
        # exact, e.g. 0.1 + 0.3 and 2 * 0.2 land on the same key.
        scores = [Fraction(float(d)).limit_denominator(10 ** 12) for d in doses]
        sizes = [int(n) for n in sizes]
        observed_cases = [int(r) for r in observed_cases]

        total_cases = sum(observed_cases)
        observed = sum(s * r for s, r in zip(scores, observed_cases))

        # states[cases used so far][partial score] = number of ways
        states = {0: {Fraction(0): 1}}
        remaining = sum(sizes)
        for score, size in zip(scores, sizes):
            remaining -= size
            following = {}
            for used, weights in states.items():
                # Later groups must still be able to absorb the leftover cases.
                lowest = max(0, total_cases - used - remaining)
                highest = min(size, total_cases - used)
                for r in range(lowest, highest + 1):
                    ways = comb(size, r, exact=True)
                    step = score * r
                    bucket = following.setdefault(used + r, {})
                    for partial, weight in weights.items():
                        key = partial + step
                        bucket[key] = bucket.get(key, 0) + weight * ways
            states = following

        terminal = states.get(total_cases, {})
        tail = sum(w for s, w in terminal.items() if s >= observed)
        return tail / comb(sum(sizes), total_cases, exact=True)

    def _asymptotic_pvalue(statistic):
        """Return the standard normal CDF at the statistic."""
        return stats.norm.cdf(statistic)

    pval_exact_result = _exact_pvalue(dk, nk, rk)
    pval_asy = _asymptotic_pvalue(test_statistic)

    return {"test.statistic": test_statistic, "exact.pvalue": pval_exact_result, "asymptotic.pvalue": pval_asy}


In [5]:
"""
Asymptotic Cochran-Armitage trend test
Translated from R (https://github.com/cran/CATTexact/blob/master/R/CATTexact.R) to Python.


#' Asymptotic Cochran-Armitage trend test
#'
#' \code{catt_asy} calculates the Cochran-Armitage trend test statistic (Cochran (1954), Armitage (1955)) and the one-sided p-value for the corresponding asymptotic test.
#' The exact form of the test statistic used can be found in the paper by Portier and Hoel (1984).
#'
#' @param dose.ratings A vector of dose ratings, the i-th entry corresponds to the dose-rating of the i-th group. This vector must be strictly monotonically increasing.
#' @param totals The vector of total individuals per group, the i-th entry corresponds to the total number of individuals in the i-th group.
#' @param cases The vector of incidences per group, the i-th entry corresponds to the number of incidences in the i-th group.
#' @return A list containing the value of the Cochran-Armitage Trend Test Statistic and its asymptotic p-value.
#' @references Armitage, P. Tests for linear trends in proportions and frequencies. \emph{Biometrics}, 11 (1955): 375-386.
#' @references Cochran, W. G. Some methods for strengthening the common \eqn{\chi^2} tests, \emph{Biometrics}. 10 (1954): 417-451.
#' @references Portier, C., and Hoel D. Type 1 error of trend tests in proportions and the design of cancer screens. \emph{Communications in Statistics-Theory and Methods}, 13 (1984): 1-14.
#' @examples
#' d <- c(1,2,3,4)
#' n <- rep(20,4)
#' r <- c(1,4,3,8)
#'
#' catt_asy(d, n, r)
#'
#' @export

"""

import math
from scipy.stats import norm
from scipy.special import comb

def catt_asy(dose_ratings, totals, cases):
    """Asymptotic Cochran-Armitage trend test.

    Parameters
    ----------
    dose_ratings : sequence of numbers
        Dose rating of each group; must be strictly increasing.
    totals : sequence of integers
        Number of individuals in each group (at least one per group).
    cases : sequence of integers
        Number of incidences in each group (between 0 and the group size).

    Returns
    -------
    dict
        ``"test.statistic"``: the trend statistic (computed with a leading
        minus sign) and ``"asymptotic.pvalue"``: the result of passing that
        statistic to ``aspvalue``.

    Raises
    ------
    ValueError
        If the inputs differ in length, describe fewer than three groups,
        contain non-integer or out-of-range counts, contain no cases or only
        cases, or if the doses are not strictly increasing.
    """
    le_d = len(dose_ratings)
    le_n = len(totals)
    le_r = len(cases)

    if le_d != le_n or le_d != le_r:
        raise ValueError("Length of input is differing!")

    le = le_d

    if le < 3:
        raise ValueError("Need at least three groups")

    # Materialise the inputs as lists so that "+" below always concatenates;
    # on NumPy arrays it would add elementwise, and list + tuple would fail.
    nk = list(totals)
    nhat = sum(nk)

    rk = list(cases)
    rhat = sum(rk)

    dk = list(dose_ratings)

    # Input checks
    # "x % 1 != 0" flags fractional values of any numeric type (including
    # NumPy scalars) and is also true for NaN and infinity.
    if any(x % 1 != 0 for x in nk + rk):
        raise ValueError("The number of totals and cases must be integer")

    if any(x <= 0 for x in nk):
        raise ValueError("There must be at least one individual in every dose group")

    if any(x < 0 for x in rk):
        raise ValueError("The number of cases in each group must be nonnegative")

    if any(n < r for n, r in zip(nk, rk)):
        raise ValueError("The number of cases cannot exceed the size of the group")

    if all(x == 0 for x in rk):
        raise ValueError("This test cannot be applied when there are no cases")

    if all(n == r for n, r in zip(nk, rk)):
        raise ValueError("This test cannot be applied when the number of cases equals the total number of individuals")

    if not all(later > earlier for earlier, later in zip(dk, dk[1:])):
        raise ValueError("Doses must be strictly monotonically increasing")

    factor = math.sqrt(nhat / ((nhat - rhat) * rhat))

    enum = sum((r - (n / nhat) * rhat) * d for d, n, r in zip(dk, nk, rk))

    mean_sq = sum((n / nhat) * d ** 2 for d, n in zip(dk, nk))
    mean = sum((n / nhat) * d for d, n in zip(dk, nk))
    denom = math.sqrt(mean_sq - mean ** 2)

    test_statistic = -factor * enum / denom

    pval_asy = aspvalue(test_statistic)

    return {"test.statistic": test_statistic, "asymptotic.pvalue": pval_asy}


def pval_exact(dk, nk, rk):
    """One-sided conditional exact p-value of the trend test.

    Conditional on the total number of cases, every allocation
    ``(r_1, ..., r_k)`` with ``0 <= r_i <= nk[i]`` has probability
    ``prod(comb(nk[i], r_i)) / comb(sum(nk), sum(rk))``. The p-value is the
    probability that the score ``sum(dk[i] * r_i)`` is at least the observed
    score ``sum(dk[i] * rk[i])``.

    Parameters
    ----------
    dk : sequence of numbers
        Dose rating of each group.
    nk : sequence of integers
        Number of individuals in each group.
    rk : sequence of integers
        Observed number of cases in each group.

    Returns
    -------
    float
        The exact upper-tail p-value.

    Raises
    ------
    ValueError
        If the three inputs do not have the same length.
    """
    from fractions import Fraction

    if not (len(dk) == len(nk) == len(rk)):
        raise ValueError("Length of input is differing!")

    # Rational scores make the tie comparison with the observed score exact,
    # e.g. 0.1 + 0.3 and 2 * 0.2 land on the same key.
    scores = [Fraction(float(d)).limit_denominator(10 ** 12) for d in dk]
    sizes = [int(n) for n in nk]
    observed_cases = [int(r) for r in rk]

    nhat = sum(sizes)
    rhat = sum(observed_cases)
    a0 = sum(s * r for s, r in zip(scores, observed_cases))

    # states[cases used so far][partial score] = number of ways to get there
    states = {0: {Fraction(0): 1}}
    remaining = nhat
    for score, size in zip(scores, sizes):
        remaining -= size
        following = {}
        for used, weights in states.items():
            # The groups still to come must be able to absorb the leftover
            # cases, and no group can hold more cases than individuals.
            lowest = max(0, rhat - used - remaining)
            highest = min(size, rhat - used)
            for r in range(lowest, highest + 1):
                ways = comb(size, r, exact=True)
                step = score * r
                bucket = following.setdefault(used + r, {})
                for partial, weight in weights.items():
                    key = partial + step
                    bucket[key] = bucket.get(key, 0) + weight * ways
        states = following

    terminal = states.get(rhat, {})
    tail = sum(weight for total, weight in terminal.items() if total >= a0)

    pval = tail / comb(nhat, rhat, exact=True)

    return pval


def aspvalue(statistic):
    """Return the standard normal CDF evaluated at ``statistic``."""
    return norm.cdf(statistic)


