This case study works through the examples of Wassmer and Dragalin (2015). That paper uses closed testing to control the Type I error rate while adapting the trial to focus patient recruitment on the subpopulations where the treatment performs best.

There are 3 case studies in the paper, in Sections 4.1, 4.2, and 4.3.
4.1 uses binomial outcomes with a Bonferroni rule, with 2 subgroups x 2 treatments (T vs C) = 4 parameters, and two analysis times, including the final. This should be improvable.
4.2 uses exponential outcomes and a log-rank test, with Dunnett's test for multiple comparisons, with 2 subgroups x 2 treatments (T vs C) = 4 parameters, and four analysis times, including the final. We expect this one to be essentially tight (hard to improve), because Dunnett's test should be tight.
4.3 uses exponential outcomes and log-rank tests with the Bonferroni rule, with 4 subgroups x 2 treatments = 8 parameters. This is too many for proof-by-simulation to cover. We could still give a better power analysis than the one reported in the paper, which fixes many of these parameters and varies the treatment effects.

Of these, we will do 4.1, and then show that it can also be done under exponential outcomes, yielding improvements in both cases.

If we want to strengthen this further, we could try to compare against 4.2, but I don't think this is strictly necessary.


In [7]:
import math
import random

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
from scipy.special import expit, logit

import imprint as ip
from imprint.nb_util import setup_nb

# setup_nb is a handy function for setting up some nice plotting defaults.
setup_nb()

Example 4.2 has survival outcomes.
There are 2 treatment arms and 2 subgroups:
40% of patients are Stage II, 60% of patients are Stage III.

We will use exponential outcomes to represent overall survival:

If all patients start at the same time, then, so long as our analysis uses only the ranks of the failure times, the results implicitly cover the class of proportional hazards models with bounded hazard, via an implicit reparametrization of time.
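A small illustration of the rank argument, under the simultaneous-entry assumption: applying the same strictly increasing time change to every patient turns exponential outcomes into a proportional hazards model with a non-constant but bounded hazard, while leaving the order of failures (and hence any rank-based statistic) unchanged. The hazards 1 and 0.9 echo the guesses used later in this notebook; the time change $G(t) = e^t - 1$ is an arbitrary choice for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Exponential failure times: control hazard 1, treatment hazard 0.9.
control = rng.exponential(scale=1.0, size=50)
treated = rng.exponential(scale=1 / 0.9, size=50)
times = np.concatenate([control, treated])

# Strictly increasing time change G(t) = exp(t) - 1, applied to everyone.
# An arm with exponential hazard h now has cumulative hazard h * log(1 + u),
# i.e. hazard h / (1 + u): bounded by h, not constant, and still proportional
# across arms (ratio 0.9).
warped = np.expm1(times)

# The order of failures is unchanged, so any statistic built only from
# failure ranks (such as the log-rank test) gives identical results.
same_order = np.array_equal(np.argsort(times), np.argsort(warped))
print(same_order)
```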

(This is not realistic because of run-in, but the degree to which the assumption is violated does not obviously favor treatment or control. Clinical trials, I believe, generally measure outcomes on a time-from-treatment-to-event scale. ...Hold on, this is already a problem in the normal course of things: what do trials currently do for a log-rank test under a time-from-treatment analysis, when observations spill over across interims and earlier observations effectively change? Anyway...)

Most of the patients who died were in the Stage III group.

Let's assume the sequence of patient arrivals is known (we may revisit this analysis later with sequential conditioning). This analysis also effectively conditions on the final number of deaths anyway.

There were 1123 patients assigned to each treatment group, with a 40/60 split between Stage II and Stage III in each.

They weight the interim analyses using the observed information rates

$t_1, t_2, t_3, t_4 = 0.53, 0.70, 0.77, 1.0$,

with stage weights $w_k = \sqrt{t_k - t_{k-1}}$ (taking $t_0 = 0$).
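A quick numerical check of these weights (a sketch that only restates the arithmetic above): the squared weights add back up to the information rates, so a weighted sum of independent N(0, 1) stage statistics, divided by $\sqrt{t_k}$, is still standard normal at every look.

```python
import numpy as np

# Information rates t_1..t_4 for the four analyses (t_0 = 0).
info_rates = np.array([0.53, 0.70, 0.77, 1.00])

# Stage weights w_k = sqrt(t_k - t_{k-1}).
weights = np.sqrt(np.diff(np.concatenate(([0.0], info_rates))))
print(weights.round(3))

# Cumulative sums of w_k^2 recover t_k, so Var(sum_{j<=k} w_j Z_j) = t_k.
print(np.cumsum(weights**2))
```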

The subgroup S is the Stage III patients.
The full population F is Stage II + Stage III.

They use log-rank statistics, with a normal approximation, for the elementary hypotheses.

For every hypothesis (including the intersection hypothesis), they use alpha-spending, which yields the critical values
3.090, 2.976, 2.854, 1.968
at the four analyses.
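As a rough sanity check on these boundaries, the sketch below computes the nominal one-sided level $1 - \Phi(c_k)$ at each analysis and estimates the overall null crossing probability by Monte Carlo. It assumes (our assumption, not stated in the text) that the boundaries are applied to the standardized cumulative inverse-normal statistic built from independent stage-wise z-statistics with the weights $w_k = \sqrt{t_k - t_{k-1}}$.

```python
import numpy as np
from statistics import NormalDist

crit = np.array([3.090, 2.976, 2.854, 1.968])    # reported critical values
info_rates = np.array([0.53, 0.70, 0.77, 1.00])  # information rates t_k

# Nominal one-sided level at each look: 1 - Phi(c_k).
std_normal = NormalDist()
nominal = np.array([1 - std_normal.cdf(c) for c in crit])

# Monte Carlo under H0. Assumption: boundaries apply to
#   Z*_k = sum_{j<=k} w_j Z_j / sqrt(t_k),  Z_j ~ N(0, 1) independent,
#   w_j = sqrt(t_j - t_{j-1}).
rng = np.random.default_rng(0)
n_sims = 200_000
weights = np.sqrt(np.diff(np.concatenate(([0.0], info_rates))))
stage_z = rng.standard_normal((n_sims, crit.size))
cum_z = np.cumsum(stage_z * weights, axis=1) / np.sqrt(info_rates)
overall_level = np.mean((cum_z > crit).any(axis=1))

print("nominal levels:", nominal.round(4))
print("estimated overall one-sided level:", overall_level)
```

The final look carries almost all of the level here, which is why calibrating only the last critical value is a natural place to look for gains.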

We will tune the last of these (1.968) by calibration, to measure our loss.

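To get stage-wise z-statistics, we consider increments of the log-rank statistic: if $d_k$ is the number of events at analysis $k$ and $LR_k$ is the log-rank statistic computed on those events, then

$$Z_k = \frac{\sqrt{d_k}\, LR_k - \sqrt{d_{k-1}}\, LR_{k-1}}{\sqrt{d_k - d_{k-1}}}$$

is used as a z-test.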
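A minimal sketch of this increment formula (the function name `stagewise_z` is hypothetical). The cumulative statistics below are the placeholder values used later in this notebook, and the event counts are `int(528 * t_k)`. Because rescaling every $d_k$ by the same constant leaves $Z_k$ unchanged, information rates can stand in for event counts.

```python
import numpy as np

def stagewise_z(lr_cum, d):
    """Stage-wise z-statistics from cumulative log-rank statistics LR_k at
    cumulative event counts d_k (with d_0 = 0 and LR_0 = 0)."""
    lr_cum = np.asarray(lr_cum, dtype=float)
    d = np.asarray(d, dtype=float)
    scaled = np.sqrt(d) * lr_cum                        # sqrt(d_k) * LR_k
    num = np.diff(np.concatenate(([0.0], scaled)))      # sqrt(d_k) LR_k - sqrt(d_{k-1}) LR_{k-1}
    den = np.sqrt(np.diff(np.concatenate(([0.0], d))))  # sqrt(d_k - d_{k-1})
    return num / den

lr = [1.2, 1.67, 1.349, 1.73]  # placeholder cumulative statistics
d = [279, 369, 406, 528]       # int(528 * t_k) events at each analysis
z = stagewise_z(lr, d)
print(z)

# Same result when event counts are replaced by fractions of the final count.
print(np.allclose(z, stagewise_z(lr, np.array(d) / 528)))
```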

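For the intersection hypothesis, we use Dunnett's test on the two z-statistics:

$$p = 1 - F_{\Sigma,\mathrm{df}}\big(\max(Z_1, Z_2)\big)$$

where $F_{\Sigma,\mathrm{df}}$ is the joint CDF of the two statistics, evaluated with $\max(Z_1, Z_2)$ in both coordinates, and $\Sigma$ is a correlation matrix with ones on the diagonal and off-diagonal entries

$$\frac{\sqrt{n_{0,S} + n_{1,S}}}{\sqrt{n_{0,F} + n_{1,F}}}.$$

So the off-diagonal correlation is the square root of the ratio of sample sizes in the smaller group S and the larger group F that contains it, where "sample sizes" count only events (failures).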
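Why a square root: if each log-rank statistic is treated as a normalized sum of roughly independent per-event score contributions (a simplifying assumption), the subgroup and full-population statistics share the subgroup's contributions, so their correlation is $\sqrt{d_S / d_F}$. The sketch below checks this in a plain normal analogue with 60% of the events in S; it does not simulate actual log-rank statistics.

```python
import numpy as np

# Normal analogue: each event contributes an independent N(0, 1) score,
# and the first 120 of 200 events (60%) belong to subgroup S.
rng = np.random.default_rng(1)
n_sims = 20_000
d_full = 200
d_sub = 120

scores = rng.standard_normal((n_sims, d_full))
z_sub = scores[:, :d_sub].sum(axis=1) / np.sqrt(d_sub)  # subgroup statistic
z_full = scores.sum(axis=1) / np.sqrt(d_full)           # full-population statistic

empirical_corr = np.corrcoef(z_sub, z_full)[0, 1]
# Compare with sqrt(d_S / d_F) (about 0.775) and the plain ratio d_S / d_F (0.6).
print(empirical_corr, np.sqrt(d_sub / d_full), d_sub / d_full)
```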

We'll just use a normal approximation ourselves, with the code:

import scipy.stats
c = 0.5    # c = max(Z_1, Z_2)
rho = 0.0  # off-diagonal correlation from Sigma
1 - scipy.stats.multivariate_normal.cdf([c, c], mean=[0, 0], cov=[[1, rho], [rho, 1]])
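To see what the correlation buys, the sketch below wraps that normal-approximation Dunnett p-value in a function and compares it with the Bonferroni p-value for the same intersection hypothesis. Both function names are hypothetical, and $\rho = \sqrt{0.6}$ is an illustrative value corresponding to 60% of events in the Stage III subgroup.

```python
import numpy as np
import scipy.stats

def dunnett_p_two(z1, z2, rho):
    """One-sided Dunnett-type p-value for max(z1, z2), assuming (Z1, Z2) is
    bivariate normal under H0 with unit variances and correlation rho."""
    c = max(z1, z2)
    cov = [[1.0, rho], [rho, 1.0]]
    return 1 - scipy.stats.multivariate_normal.cdf([c, c], mean=[0, 0], cov=cov)

def bonferroni_p_two(z1, z2):
    """Bonferroni p-value for the same intersection hypothesis."""
    return min(1.0, 2 * (1 - scipy.stats.norm.cdf(max(z1, z2))))

rho = np.sqrt(0.6)
p_d = dunnett_p_two(2.0, 1.5, rho)
p_b = bonferroni_p_two(2.0, 1.5)
print(p_d, p_b)  # Dunnett is never larger than Bonferroni
```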

In [1]:
# The FOLFOX4 group is our treatment
# LV5FU2 is our control
# We also have our subgroups, stage 2 or 3 cancer.
# We list their estimated hazard values here (used in the text)

# Now, assuming that the trial will be cut at 528 failures, 
# WLOG we can set hcontrol_3 = 1
hcontrol_3 = 1
htreat_3 = 0.9
hcontrol_2 = 0.5
htreat_2 = 0.5
# I'm guessing these numbers will give us about similar outcomes to what was observed

In [2]:
# We now list the relevant intersection hypotheses,
# which are identified by which subgroups are part of them.
# In this example, there are 2 elementary hypotheses.
hypo3_live = True
hypofull_live = True


In [109]:
# Simulate one trial. Both arms are the same size, each with a 40/60 split of
# Stage II / Stage III patients. (1128 per arm is used here; the text quotes 1123.)
npatients_perarm = 1128
npatientstotal = 2 * npatients_perarm
npatients_2_perarm = int(npatients_perarm * 0.4)
npatients_3_perarm = int(npatients_perarm * 0.6) + 1
# A unit-rate exponential divided by h has hazard h.
treat2outcomes = scipy.stats.expon.rvs(size=npatients_2_perarm) / htreat_2
ctrl2outcomes = scipy.stats.expon.rvs(size=npatients_2_perarm) / hcontrol_2
ctrl3outcomes = scipy.stats.expon.rvs(size=npatients_3_perarm) / hcontrol_3
treat3outcomes = scipy.stats.expon.rvs(size=npatients_3_perarm) / htreat_3

information_weights = [0.53, 0.7, 0.77, 1.0]
# Number of failures at each analysis, with the trial cut at 528 failures
stopping_points = [int(528 * x) for x in information_weights]

# Pool all failure times; the analysis at d failures happens at the d-th smallest time.
outcomes_list = np.concatenate([ctrl2outcomes, ctrl3outcomes, treat3outcomes, treat2outcomes])
sorted_outcomes = np.sort(outcomes_list)
stoptimes = [sorted_outcomes[d - 1] for d in stopping_points]

[597, 789, 868, 1128]

In [110]:
# The cleaner way to do this would be to edit the following logrank function
# to take a sequence of censoring times instead of a single one, returning the z-statistic at each.
# Because the function below is already vectorized and so close to what we want,
# I'm leaving it this way for now and coding ahead without a runnable single-trial version.

@jax.jit
def logrank_test(all_rvs, group, censoring_time):
    # Standardized log-rank statistic (O - E) / sqrt(V) for group 0 (group == False),
    # treating all failures after censoring_time as censored.
    n0 = jnp.array([jnp.sum(~group), jnp.sum(group)])  # initial size of each group
    ordering = jnp.argsort(all_rvs)
    ordered = all_rvs[ordering]
    ordered_group = group[ordering]
    include = ordered <= censoring_time  # failures observed by the cutoff
    # One row per group: 1 where that group has an observed failure at this time
    event_now = jnp.stack((~ordered_group, ordered_group), axis=0) * include
    events_so_far = jnp.concatenate(
        (jnp.zeros((2, 1)), event_now.cumsum(axis=1)[:, :-1]), axis=1
    )
    Nij = n0[:, None] - events_so_far  # number at risk in each group
    Oij = event_now                    # observed failures in each group
    Nj = Nij.sum(axis=0)
    Oj = Oij.sum(axis=0)
    Eij = Nij * (Oj / Nj)              # expected failures under H0
    Vij = Eij * ((Nj - Oj) / Nj) * ((Nj - Nij) / (Nj - 1))  # hypergeometric variance
    # When Nj == 1 the variance term is NaN (division by zero); count it as 0.
    denom = jnp.sum(jnp.where(~jnp.isnan(Vij[0]), Vij[0], 0), axis=0)
    return jnp.sum(Oij[0] - Eij[0], axis=0) / jnp.sqrt(denom)
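As a sketch of that "sequence of censoring times" idea, here is a NumPy version that evaluates the same (O - E) / sqrt(V) statistic for group 0 at several cutoffs in one pass, using prefix sums. It assumes simultaneous entry and no tied failure times; `logrank_at_cutoffs` is a hypothetical name, not part of imprint, and the hazards in the usage example are illustrative.

```python
import numpy as np

def logrank_at_cutoffs(times, group, cutoffs):
    """Standardized log-rank statistic (O - E) / sqrt(V) for group 0 (group == False),
    evaluated at several calendar cutoffs in one pass.

    Assumes every patient enters at time 0 and there are no tied failure times,
    so each ordered time is a single event; failures after a cutoff are treated
    as censored at that cutoff.
    """
    times = np.asarray(times, dtype=float)
    group = np.asarray(group, dtype=bool)
    order = np.argsort(times)
    t = times[order]
    in_group0 = (~group[order]).astype(float)

    n = t.size
    at_risk = n - np.arange(n)  # total number at risk just before each ordered time
    prior_group0 = np.concatenate(([0.0], np.cumsum(in_group0)[:-1]))
    at_risk0 = in_group0.sum() - prior_group0  # group-0 patients at risk at that time

    p0 = at_risk0 / at_risk
    o_minus_e = in_group0 - p0  # O - E for group 0, one event per ordered time
    var = p0 * (1 - p0)         # hypergeometric variance with a single event

    # idx[k] = number of failures at or before cutoff k; prefix sums give the
    # statistic restricted to those failures.
    idx = np.searchsorted(t, cutoffs, side="right")
    cum_oe = np.concatenate(([0.0], np.cumsum(o_minus_e)))
    cum_var = np.concatenate(([0.0], np.cumsum(var)))
    return cum_oe[idx] / np.sqrt(cum_var[idx])


# Usage with illustrative hazards: control (group 0) hazard 1, treatment hazard 0.5.
rng = np.random.default_rng(2)
n_per_arm = 300
control = rng.exponential(scale=1.0, size=n_per_arm)
treated = rng.exponential(scale=1 / 0.5, size=n_per_arm)
times = np.concatenate([control, treated])
group = np.concatenate([np.zeros(n_per_arm, dtype=bool), np.ones(n_per_arm, dtype=bool)])

# Cut at the 100th, 200th and 300th failures, mirroring stopping at fixed event counts.
cutoffs = np.sort(times)[[99, 199, 299]]
lr_stats = logrank_at_cutoffs(times, group, cutoffs)
print(lr_stats)
```

Evaluating all cutoffs from one sort plus prefix sums is the same idea one would use to vectorize the JAX version over `stoptimes`.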


In [132]:
# These calls won't run as written: logrank_test is already in multidimensional (vectorized) form,
# but it takes a single censoring time and needs to be modified to output one log-rank statistic per censoring point.
#z_3_elementary = logrank_test(np.concatenate([treat3outcomes, ctrl3outcomes]), group = blah, censoring_times = stoptimes)
#z_full_elementary = logrank_test(np.concatenate([treat3outcomes, treat2outcomes, ctrl3outcomes, ctrl2outcomes]), group = blah, censoring_times = stoptimes)

# Putting in placeholder stats
Logrank_stats_3 = [1.2, 1.67, 1.349, 1.73]
Logrank_stats_full = [0.3746, 0.87, 0.55, 1.11]

# Now we compute the stage-wise (independent-increment) z-statistics, using information rates in place of event counts
num_3 = np.diff(np.append(0, np.sqrt(information_weights) * Logrank_stats_3))
denom_3 = np.sqrt(np.diff(np.append(0, information_weights)))
z_3_elementary = num_3/denom_3

num_full = np.diff(np.append(0, np.sqrt(information_weights) * Logrank_stats_full))
denom_full = np.sqrt(np.diff(np.append(0, information_weights)))
z_full_elementary = num_full/denom_full

# Now we compute the p-values for each period
p_3_elementary = 1-scipy.stats.norm.cdf(z_3_elementary)
p_full_elementary = 1-scipy.stats.norm.cdf(z_full_elementary)

p_intersection = np.full_like(p_3_elementary, 2)

# Dunnett correlation at each stage: the off-diagonal is
# sqrt(stage-k Stage III failures / stage-k total failures).
for i in range(4):
    if i == 0:
        numfailures_3 = sum(treat3outcomes <= stoptimes[i]) + sum(ctrl3outcomes <= stoptimes[i])
        numfailures_total = stopping_points[i]
    else:
        numfailures_3 = (sum(treat3outcomes <= stoptimes[i]) - sum(treat3outcomes <= stoptimes[i-1])
                         + sum(ctrl3outcomes <= stoptimes[i]) - sum(ctrl3outcomes <= stoptimes[i-1]))
        numfailures_total = stopping_points[i] - stopping_points[i-1]
    offdiag = np.sqrt(numfailures_3 / numfailures_total)
    c = max(z_3_elementary[i], z_full_elementary[i])
    p_intersection[i] = 1 - scipy.stats.multivariate_normal.cdf([c, c], mean=[0, 0], cov=[[1, offdiag], [offdiag, 1]])
# In fact there is no selection! So we simply use Dunnett on the elementary hypotheses here...

In [134]:
# Now we perform a weighted combination, according to the information weights.
# No subsets can be dropped, so this is a simple weighted sum.
info_decomp = np.append(information_weights[0], np.diff(information_weights))
z_weights = np.sqrt(info_decomp)
p_3_combined = 1 - scipy.stats.norm.cdf(np.sum(scipy.stats.norm.ppf(1 - p_3_elementary) * z_weights))
p_full_combined = 1 - scipy.stats.norm.cdf(np.sum(scipy.stats.norm.ppf(1 - p_full_elementary) * z_weights))
p_intersection_combined = 1 - scipy.stats.norm.cdf(np.sum(scipy.stats.norm.ppf(1 - p_intersection) * z_weights))

array([ 1.2       ,  1.2699385 , -0.80687685,  1.13902522])

In [135]:
# compute final rejection from closed testing:
alpha = 0.05
reject_3 = (p_intersection_combined < alpha) & (p_3_combined < alpha)
reject_full = (p_intersection_combined < alpha) & (p_full_combined < alpha)

array([ 0.3746    ,  1.10397683, -0.92703793,  1.30817139])
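The two rejection rules above are the closed testing principle for two elementary hypotheses: an elementary hypothesis is rejected only if both it and the intersection hypothesis are rejected. Written as a small function (a sketch; `closed_test_two` is a hypothetical helper, not part of imprint):

```python
def closed_test_two(p_sub, p_full, p_intersection, alpha):
    """Closed testing for elementary hypotheses H_S and H_F: each is rejected
    only if it and the intersection of H_S and H_F are both rejected at level alpha."""
    reject_intersection = p_intersection < alpha
    return {
        "reject_S": reject_intersection and p_sub < alpha,
        "reject_F": reject_intersection and p_full < alpha,
    }

# The intersection is not rejected, so neither elementary hypothesis can be.
print(closed_test_two(0.01, 0.04, 0.06, alpha=0.05))
# The intersection is rejected, and each elementary p-value is below alpha.
print(closed_test_two(0.01, 0.04, 0.02, alpha=0.05))
```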

In [None]:
# These pooled two-proportion z-statistics (binomial outcomes) rely on outcome arrays and
# npatientsperarm_firststage, which are not defined in this notebook.
#HRplus_pooledaverage = outcomesHRplustreat.mean()/2 + outcomesHRpluscontrol.mean()/2
#denominatorHRplus = np.sqrt(HRplus_pooledaverage*(1-HRplus_pooledaverage)*(1/npatientsperarm + 1/npatientsperarm))
#zHRplus_stage1 = (outcomesHRplustreat.mean() - outcomesHRpluscontrol.mean())/denominatorHRplus

# Note: this could be expressed more simply as a single function if we could nicely combine treatment outcome columns.
TNBC_pooledaverage = outcomesTNBCtreat.mean()/2 + outcomesTNBCcontrol.mean()/2
denominatorTNBC = np.sqrt(TNBC_pooledaverage*(1-TNBC_pooledaverage)*(1/npatientsperarm_firststage + 1/npatientsperarm_firststage))
zTNBC_stage1 = (outcomesTNBCtreat.mean() - outcomesTNBCcontrol.mean())/denominatorTNBC

totally_pooledaverage = outcomesTNBCtreat.mean()/4 + outcomesTNBCcontrol.mean()/4 + outcomesHRplustreat.mean()/4 + outcomesHRpluscontrol.mean()/4
denominatortotallypooled = np.sqrt(totally_pooledaverage*(1-totally_pooledaverage)*(1/(2*npatientsperarm_firststage) + 1/(2*npatientsperarm_firststage)))
zfull_stage1 = (outcomesTNBCtreat.mean()/2 + outcomesHRplustreat.mean()/2 - outcomesTNBCcontrol.mean()/2 - outcomesHRpluscontrol.mean()/2)/denominatortotallypooled
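The cell above builds pooled two-proportion z-statistics by hand. A compact, self-contained version of the same formula, written with success counts so that it also covers unequal arm sizes (`pooled_z` is a hypothetical name, and the counts in the usage line are made up):

```python
import numpy as np

def pooled_z(successes_treat, n_treat, successes_ctrl, n_ctrl):
    """Two-sample z-statistic for a difference in proportions, using the pooled
    variance estimate under H0: p_treat == p_ctrl."""
    p_t = successes_treat / n_treat
    p_c = successes_ctrl / n_ctrl
    p_pool = (successes_treat + successes_ctrl) / (n_treat + n_ctrl)
    se = np.sqrt(p_pool * (1 - p_pool) * (1 / n_treat + 1 / n_ctrl))
    return (p_t - p_c) / se

print(pooled_z(60, 100, 50, 100))
```

With equal arm sizes, averaging the two arm means, as the cell above does, gives the same pooled proportion; the four-way average used there for the full population matches pooling counts only when all four groups have the same size.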

In [31]:
# Done! Just ignore the below scratch

In [24]:
class LogRankVlad:
    def __init__(self, seed, max_K, *, n, censoring_time):
        self.max_K = max_K
        self.censoring_time = censoring_time
        self.n = n
        self.family = "exponential"
        self.family_params = {"n": n}

        self.key = jax.random.PRNGKey(seed)
        self.samples = jax.random.exponential(self.key, shape=(max_K, n, 4))
        self.group = jnp.concatenate(
            [jnp.zeros((max_K, n)), jnp.ones((max_K, n))], axis=1
        ).astype(bool)
        self.vmap_logrank_test = jax.vmap(
            jax.vmap(logrank_test, in_axes=(0, 0, None)), in_axes=(0, None, None)
        )

    def sim_batch(self, begin_sim, end_sim, theta, null_truth, detailed=False):
        control_hazard = -theta[:, 0]
        treatment_hazard = -theta[:, 1]
        hazard_ratio = treatment_hazard / control_hazard
        control_rvs = jnp.tile(self.samples[None, :, :, 0], (hazard_ratio.shape[0], 1, 1))
        treatment_rvs = self.samples[None, :, :, 1] / hazard_ratio[:, None, None]
        all_rvs = jnp.concatenate([control_rvs, treatment_rvs], axis=2)
        test_stat = -self.vmap_logrank_test(
            all_rvs, self.group, self.censoring_time
        )
        return test_stat

DeviceArray([          nan, 5.2642647e-02,           nan, 5.2622620e-02,           nan,
                       nan,           nan,           nan, 5.2641593e-02, 2.2225672e-01, ...,
                       nan, 5.2633159e-02,           nan, 5.2621566e-02,           nan,
                       nan, 2.8201441e-09,           nan,           nan,           nan],            dtype=float32)

In [None]:
g = ip.cartesian_grid([-1, -1], [-1, -1], n=[1, 1], null_hypos=[ip.hypo("theta0 > theta1")])

In [None]:
lr = LogRankVlad(0, 200000, n=100, censoring_time=10000000)
stats = lr.sim_batch(0, lr.max_K, g.get_theta(), g.get_null_truth())