In [26]:
import torch


# Hand-checkable toy batch: 4 samples, each with 2 signed radii, plus a cyt
# budget and a conversion efficiency per sample.
rads_b = torch.tensor(
    [[-1., 2.],
     [-4., 1.],
     [1., 2.],
     [4., 5.]],
    dtype=torch.float32,
)
rad_deltas_b = torch.tensor(
    [[1., 1.],
     [2., -1.],
     [-3., -4.],
     [1., 2.]],
    dtype=torch.float32,
)
cyt_b = torch.tensor([5.0, 3.0, 4.0, 10.0], dtype=torch.float32)
efficiency_b = torch.tensor([0.8, 0.85, 0.9, 1.0], dtype=torch.float32)

def grow_torch(rads_b, rad_deltas_b, cyt_b, efficiency_b):
    """Apply requested radius changes to a batch, paying for growth with a cyt budget.

    Each element's requested change is measured in cross-sectional area (CSA),
    computed here as ``radius ** 2``::

        csa_delta = (r + dr) ** 2 - r ** 2

    1. Atrophy: every negative CSA delta is applied in full, and the released area,
       scaled by ``efficiency_b``, is added to that row's cyt.
    2. Growth: each row's positive CSA deltas are funded from the post-atrophy cyt.
       ``min(total desired, available cyt)`` is consumed; the consumed cyt scaled by
       ``efficiency_b`` becomes new area, split across the growing elements in
       proportion to their requested deltas.
    3. The new radius is ``sqrt(new CSA)`` carrying the sign of ``r + dr``
       (0 where ``r + dr == 0``).

    Args:
        rads_b: Current signed radii, shape ``(..., M)``.
        rad_deltas_b: Requested radius changes, same shape as ``rads_b``.
        cyt_b: Cyt available per row, shape ``rads_b.shape[:-1]``. Not modified.
            NOTE(review): a negative budget leads to negative "growth" (and possibly
            NaN radii); presumably callers keep it non-negative -- confirm.
        efficiency_b: Conversion efficiency per row, same shape as ``cyt_b``.

    Returns:
        ``(new_rads_b, new_cyt_b)``: new signed radii shaped like ``rads_b`` and the
        remaining cyt per row shaped like ``cyt_b``, both freshly allocated.
    """
    target_rads_b = rads_b + rad_deltas_b
    csa_b = rads_b ** 2
    csa_deltas_b = target_rads_b ** 2 - csa_b
    # Same device/dtype as the inputs (a bare torch.tensor(0.0) would always be CPU).
    zeros_b = torch.zeros_like(csa_deltas_b)

    # Atrophy. `shrink_b` holds the negative deltas (0 elsewhere), so subtracting its
    # row sum *adds* the freed area to cyt. Done out of place so the caller's tensor
    # is left untouched (and autograd works when cyt_b is a leaf requiring grad).
    shrink_b = torch.where(csa_deltas_b < 0, csa_deltas_b, zeros_b)
    cyt_after_atrophy_b = cyt_b - shrink_b.sum(dim=-1) * efficiency_b
    csa_after_atrophy_b = csa_b + shrink_b

    # Growth. Rows with nothing to grow would divide 0 by 0 when computing each
    # element's share; give them a denominator of 1 so the shares are 0, not NaN.
    grow_mask_b = csa_deltas_b > 0
    grow_b = torch.where(grow_mask_b, csa_deltas_b, zeros_b)
    cyt_desired_b = grow_b.sum(dim=-1)
    denom_b = torch.where(cyt_desired_b > 0, cyt_desired_b, torch.ones_like(cyt_desired_b))
    grow_share_b = grow_b / denom_b.unsqueeze(-1)

    # Never consume more cyt than the row asked for, nor more than it has.
    cyt_consumed_b = torch.minimum(cyt_desired_b, cyt_after_atrophy_b)
    csa_grown_b = cyt_consumed_b * efficiency_b
    new_csa_b = torch.where(
        grow_mask_b,
        csa_after_atrophy_b + csa_grown_b.unsqueeze(-1) * grow_share_b,
        csa_after_atrophy_b,
    )
    new_cyt_b = cyt_after_atrophy_b - cyt_consumed_b

    new_rads_b = torch.sqrt(new_csa_b) * torch.sign(target_rads_b)
    return new_rads_b, new_cyt_b

# Pass a copy of `cyt_b` so re-running this cell always starts from the same budget
# (grow_torch may update its `cyt_b` argument in place).
grow_torch(rads_b, rad_deltas_b, cyt_b.clone(), efficiency_b)


(tensor([[ 0.0000,  2.8284],
         [-2.0000,  0.0000],
         [-1.9235, -2.0000],
         [ 4.3275,  5.6809]]),
 tensor([ 0.8000, 14.0500,  1.0000,  0.0000]))

In [28]:
# Benchmark grow_torch on a large random batch.
# - Seeded so the inputs (and any NaNs they might produce) are reproducible.
# - Cyt budgets are drawn non-negative and efficiencies in [0, 1), like the
#   hand-written example batch; presumably negative values are not meaningful
#   for this model -- confirm.
# - One warm-up call runs before timing; perf_counter is the clock meant for
#   measuring short intervals.
import time

torch.manual_seed(0)
N_SAMPLES, N_RADII, N_ITERS = 10_000, 5, 1000

rads_batch = torch.randn(N_SAMPLES, N_RADII)
rad_deltas_batch = torch.randn(N_SAMPLES, N_RADII)
cyt_batch = torch.randn(N_SAMPLES).abs()
efficiency_batch = torch.rand(N_SAMPLES)

grow_torch(rads_batch, rad_deltas_batch, cyt_batch.clone(), efficiency_batch)  # warm-up

start = time.perf_counter()
for _ in range(N_ITERS):
    # Fresh copy each call: grow_torch may update `cyt_b` in place, which would
    # otherwise make every iteration see a different budget. The copy adds a small
    # amount of work to each timed call.
    grow_torch(rads_batch, rad_deltas_batch, cyt_batch.clone(), efficiency_batch)
print(f"{(time.perf_counter() - start) / N_ITERS:.6f} s per call")

0.0012755818367004394
