# Модель болезни

In [None]:
import numpy as np
import scipy.linalg as spl
import scipy.optimize as sopt
import scipy.stats as stats

In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
#from torch.autograd.functional import jacobian

In [None]:
import matplotlib
import matplotlib.pyplot as plt

# Global matplotlib font: large Liberation Sans text so labels stay readable
# on the wide (24x12) figures drawn further down.
font = dict(family='Liberation Sans', weight='normal', size=30)
matplotlib.rc('font', **font)

In [None]:
import sys
# Put the local ./python directory first on the import path; it presumably
# holds the `fuzzy_torch` package imported in the next cell.
sys.path[:0] = ['./python']

In [None]:
from fuzzy_torch import logic
from fuzzy_torch.modules.indicators import *
from fuzzy_torch.modules.ffsa import *

In [None]:
# Fuzzy-logic implementation shared by the automaton and the regressor below.
# `Product` presumably selects the product t-norm — confirm in fuzzy_torch.logic.
Logic = logic.Product

## Модель болезни и управления

In [None]:
class ControlModel(nn.Module):
    """Closed-loop rollout of a controlled ODE system with forward Euler steps.

    At every step the controller maps the current state (and its automaton
    activations) to a control signal. The system right-hand side is then
    integrated with one explicit Euler step, and every state component is
    clamped at zero.

    Args:
        system: callable ``system(parameters, control, variables)`` returning
            the time derivative of ``variables``.
        controller: module called as ``controller(variables, activation, dt)``
            returning ``(control, new_activation)``. It must expose
            ``controller.ffsa.states``, whose length sets the activation width.
        time: total simulated horizon; the step is ``dt = time / n_steps``.
        n_steps: number of Euler steps.
    """

    def __init__(self, system, controller, time, n_steps):
        super().__init__()
        self.system = system
        self.controller = controller
        self.time = time
        self.n_steps = n_steps

    def forward(self, input, parameters):
        """Simulate the closed loop starting from state ``input``.

        Args:
            input: initial state, shape ``(batch, n_variables)``.
            parameters: per-sample system parameters, forwarded to ``system``.

        Returns:
            ``(variables, controls, activations)``, each stacked along dim 1.
            ``variables`` and ``activations`` have ``n_steps + 1`` time points
            (initial values included); ``controls`` has ``n_steps``.
        """
        batch_size = input.size(0)
        n_states = len(self.controller.ffsa.states)

        # Start fully in automaton state 0. Allocate on the input's device and
        # dtype (not the CPU/float32 default) so the rollout also works on GPU
        # or in double precision.
        initial_activation = input.new_zeros((batch_size, n_states))
        initial_activation[:, 0] = 1.0

        variables = [input]
        activations = [initial_activation]
        controls = []

        dt = self.time / self.n_steps
        for _ in range(self.n_steps):
            state = variables[-1]
            control, activation = self.controller(state, activations[-1], dt)
            activations.append(activation)
            controls.append(control)
            # Forward Euler step; the clamp keeps every state component >= 0.
            variables.append(
                torch.clamp(state + dt * self.system(parameters, control, state), min=0)
            )

        return (torch.stack(variables, dim=1),
                torch.stack(controls, dim=1),
                torch.stack(activations, dim=1))

In [None]:
class ImmunityModelSystem:
    """Right-hand side of the antigen / immune-specificity model.

    With V the antigen (viral) load, S the immune specificity and u the drug
    level, the derivatives are::

        dV/dt = exp(-u / m_0) * a * V * (C - V) - b * V * S
        dS/dt = c * V - d * (S - S_0)

    ``parameters`` columns, in order:

    * a: antigen reproduction rate;
    * b: antigen destruction rate;
    * c: specificity production rate;
    * d: specificity recovery rate;
    * C: limiting antigen concentration;
    * S_0: equilibrium specificity;
    * m_0: resistance to the drug.

    ``control[:, 0]`` is the drug level. ``variables[:, 0]`` and
    ``variables[:, 1]`` are V and S; any further state columns get a zero
    derivative.
    """

    def __call__(self, parameters, control, variables):
        a, b, c, d, C, S_0, m_0 = (parameters[:, k] for k in range(7))
        V = variables[:, 0]
        S = variables[:, 1]

        # The drug damps the logistic antigen growth term exponentially.
        growth_factor = torch.exp(-control[:, 0] / m_0)

        derivative = torch.zeros_like(variables)
        derivative[:, 0] = growth_factor * a * V * (C - V) - b * V * S
        derivative[:, 1] = c * V - d * (S - S_0)
        return derivative

In [None]:
class FFSARegressor(nn.Module):
    """Controller that blends per-state regressors by fuzzy automaton activation.

    Every automaton state owns a regressor (``ffsa.states``). The output is
    the activation-weighted sum of all state outputs.

    Args:
        logic: fuzzy-logic implementation; stored, not used by ``forward``.
        ffsa: automaton callable as ``ffsa(input, activation, dt)`` that
            exposes ``ffsa.states``.
    """

    def __init__(self, logic, ffsa):
        super().__init__()
        self.logic = logic
        self.ffsa = ffsa
        self.debug = False  # not read anywhere in this class

    def forward(self, input, activation, dt):
        """Advance the automaton one step; return ``(output, new_activation)``."""
        next_activation = self.ffsa(input, activation, dt)

        # (batch, n_states, ...) - one regressor output per automaton state.
        per_state = torch.stack([regressor(input) for regressor in self.ffsa.states], dim=1)

        # Contract the state axis: sum_o next_activation[b, o] * per_state[b, o, ...].
        blended = torch.einsum("bo,bo...->b...", next_activation, per_state)
        return blended, next_activation

## Набор данных

In [None]:
class InfectionDataset(torch.utils.data.Dataset):
    """Synthetic patients: random initial state plus random model parameters.

    Each item is ``(init_variables, parameters)``:

    * ``init_variables``: float32 array ``[V_0, S_0]``, the initial antigen
      load and immune specificity;
    * ``parameters``: float32 array ``[a, b, c, d, C, S_0, m_0]``, in the
      column order read by ``ImmunityModelSystem``.

    Every quantity is log-normal with median ``scale``.

    Args:
        size: number of items reported by ``len()`` (default 16384).
        seed: if ``None`` (default), draws use NumPy's global random state, so
            repeated ``dataset[i]`` calls return different samples. If an int,
            item ``i`` is drawn from a generator seeded with ``(seed, i)`` and
            is reproducible.
    """

    def __init__(self, size=16384, seed=None):
        self.size = size
        self.seed = seed

        self.a_dist = stats.lognorm(scale=2e-7, s=1.0e-1)  # Antigen reproduction rate.
        self.b_dist = stats.lognorm(scale=8e-6, s=1.0e-1)  # Antigen destruction rate.
        self.c_dist = stats.lognorm(scale=8e-7, s=1.0e-1)  # Specificity production rate.
        self.d_dist = stats.lognorm(scale=2e-6, s=1.0e-1)  # Specificity recovery rate.

        self.C_dist = stats.lognorm(scale=100.0, s=1.0e-1) # Limiting antigen concentration.
        self.S_0_dist = stats.lognorm(scale=0.1, s=1.0e-1) # Equilibrium specificity.
        self.m_0_dist = stats.lognorm(scale=1.0, s=1.0e-1) # Resistance to the drug.

        self.init_V_dist = stats.lognorm(scale=1.0, s=3.0e-1)  # Initial antigen load.
        self.init_S_dist = stats.lognorm(scale=0.1, s=3.0e-1)  # Initial specificity.

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Without a bounds check, plain iteration (`for x in dataset`) never
        # stops: Python's legacy sequence protocol ends only on IndexError.
        if idx < 0:
            idx += self.size
        if not 0 <= idx < self.size:
            raise IndexError(f"index out of range for dataset of size {self.size}")

        # None -> each distribution's default (global NumPy) random state.
        rng = None if self.seed is None else np.random.default_rng([self.seed, idx])

        def draw(dist):
            return dist.rvs(1, random_state=rng)[0]

        init_variables = np.array([draw(self.init_V_dist),
                                   draw(self.init_S_dist)]).astype(np.float32)

        parameters = np.array([draw(dist) for dist in (self.a_dist,
                                                       self.b_dist,
                                                       self.c_dist,
                                                       self.d_dist,
                                                       self.C_dist,
                                                       self.S_0_dist,
                                                       self.m_0_dist)]).astype(np.float32)

        return init_variables, parameters

In [None]:
# Shared synthetic-patient dataset, used by the plotting helpers and by the
# training DataLoader below.
dataset = InfectionDataset()

In [None]:
# Notebook display of one random sample: (initial [V, S], 7 model parameters).
dataset[0]

In [None]:
def test_plots(n_days=10, steps_per_day=24):
    """Plot the untreated course of one random patient drawn from ``dataset``.

    Integrates ``ImmunityModelSystem`` with zero drug using forward Euler.
    After every step the state is clamped at zero, the same integrator
    ``ControlModel`` uses. With hourly steps the unclamped scheme can
    overshoot V below zero when ``b * S * dt > 1``. Antigen load V and
    specificity S are then plotted against time in seconds.

    Args:
        n_days: simulated horizon in days.
        steps_per_day: Euler steps per day (default 24, i.e. dt = 1 hour).
    """
    init_variables, parameters = dataset[0]
    init_variables = torch.tensor(init_variables)[None, :]
    parameters = torch.tensor(parameters)[None, :]
    control = torch.zeros(1, 1)  # no treatment

    system = ImmunityModelSystem()

    time = n_days * 24 * 60 * 60
    n_steps = n_days * steps_per_day
    dt = time / n_steps

    T = np.linspace(0.0, time, n_steps + 1)
    trajectory = [init_variables]
    with torch.no_grad():
        for _ in range(n_steps):
            state = trajectory[-1]
            trajectory.append(
                torch.clamp(state + dt * system(parameters, control, state), min=0)
            )
    variables = torch.stack(trajectory, dim=1).numpy()[0]

    # Plots.
    fig, ax = plt.subplots()

    fig.set_figheight(12)
    fig.set_figwidth(24)
    ax.grid(color='#000000', alpha=0.15, linestyle='-', linewidth=1, which='major')
    ax.grid(color='#000000', alpha=0.1, linestyle='-', linewidth=0.5, which='minor')

    ax.set_xlabel("$t, \\; с$")
    ax.set_ylabel("$x(t)$")

    ax.plot(T, variables[:, 0], label="$ V $")
    ax.plot(T, variables[:, 1], label="$ S $")

    ax.legend(loc='upper left')
    plt.show()

In [None]:
# Sanity check: untreated dynamics of one random patient.
test_plots()

## Нечёткий конечный автомат

In [None]:
class LogSigmoid(Sigmoid):
    """``Sigmoid`` condition evaluated on ``log(|x| + eps)`` instead of ``x``.

    Working in log space makes the sigmoid act on the order of magnitude of
    ``|x|`` rather than on its raw value.

    Args:
        in_features, weight, offset: forwarded unchanged to ``Sigmoid``
            (presumably from ``fuzzy_torch.modules.indicators``).
        eps: positive shift that keeps the logarithm finite at ``x == 0``.
            The default ``1e-4`` is the previously hard-coded value.
    """

    def __init__(self, in_features, weight=None, offset=None, eps=1e-4):
        super().__init__(in_features=in_features, weight=weight, offset=offset)
        self.eps = eps

    def forward(self, x):
        return super().forward(torch.log(torch.abs(x) + self.eps))

In [None]:
# Fuzzy finite-state automaton built on the chosen `Logic`; its regressors
# (`states`) and `transitions` are filled in by the next cells.
# NOTE(review): `normalize=True` presumably renormalizes the state activations
# each step — confirm in fuzzy_torch.modules.ffsa.
infection_ffsa = TimeDependentFFSA(Logic, normalize=True)

In [None]:
# Per-state regressors: each automaton state maps the system state [V, S]
# (2 inputs) to a single drug control output. Use range(4) for the 4-state
# variant whose extra transitions are kept commented out further down.
infection_ffsa.states = torch.nn.ModuleList([nn.Linear(2, 1) for _ in range(2)])

In [None]:
# Initial regressor weights, replacing nn.Linear's random initialisation:
# state 0 outputs zero (no drug); state 1 outputs (V + S) / 8. Biases are 0.
for regressor, initial_weight in zip(infection_ffsa.states,
                                     (torch.zeros((1, 2)), torch.ones((1, 2)) / 8)):
    regressor.weight = nn.Parameter(initial_weight)
    regressor.bias = nn.Parameter(torch.zeros(1))

# Extra states of the 4-state variant:
#infection_ffsa.states[2].weight = nn.Parameter(torch.ones((1, 2)) / 8)
#infection_ffsa.states[2].bias = nn.Parameter(torch.zeros((1)))

#infection_ffsa.states[3].weight = nn.Parameter(torch.zeros((1, 2)))
#infection_ffsa.states[3].bias = nn.Parameter(torch.zeros((1)))

In [None]:
# Automaton transitions: ContinuousFuzzyTransition(source, target, condition,
# speed). Only 0 -> 1 is active. The commented-out entries are the extra
# transitions of a 4-state variant; they need the matching extra states.
# Each condition is a LogSigmoid over log(|[V, S]| + eps).
# NOTE(review): weight=[1, 0] presumably makes only V enter the condition,
# with offset 4.0 the switching point in log space; `speed` presumably scales
# how fast activation flows per unit time. Confirm both against Sigmoid and
# ContinuousFuzzyTransition in fuzzy_torch.
infection_ffsa.transitions = torch.nn.ModuleList([
    # Active: state 0 -> state 1.
    ContinuousFuzzyTransition(0, 1, LogSigmoid(
        in_features=2,
        weight=np.array([1.0, 0.0]).astype(np.float32),
        offset=np.array([4.0, 0.0]).astype(np.float32)),
        speed=1e-6),
    # Disabled (4-state variant): 1->2, 1->3, 2->1, 2->3.
    #ContinuousFuzzyTransition(1, 2, LogSigmoid(
    #    in_features=2,
    #    weight=np.array([-1.0, 1.0]).astype(np.float32),
    #    offset=np.array([3.5, 1.0]).astype(np.float32)),
    #    speed=1e-6),
    #ContinuousFuzzyTransition(1, 3, LogSigmoid(
    #    in_features=2,
    #    weight=np.array([-1.0, 1.0]).astype(np.float32),
    #    offset=np.array([2.0, 1.0]).astype(np.float32)),
    #    speed=1e-6),
    #ContinuousFuzzyTransition(2, 1, LogSigmoid(
    #    in_features=2,
    #    weight=np.array([1.0, 1.0]).astype(np.float32),
    #    offset=np.array([4.0, 1.0]).astype(np.float32)),
    #    speed=1e-6),
    #ContinuousFuzzyTransition(2, 3, LogSigmoid(
    #    in_features=2,
    #    weight=np.array([-1.0, 1.0]).astype(np.float32),
    #    offset=np.array([2.0, 1.0]).astype(np.float32)),
    #    speed=1e-6)
])

## Регрессор

In [None]:
# Controller: blends the per-state regressors of `infection_ffsa` by the
# automaton's activations.
infection_regressor = FFSARegressor(Logic, infection_ffsa)

## Итоговая модель

In [None]:
# Training horizon: 10 days (in seconds) integrated with 8 Euler steps per hour.
system = ImmunityModelSystem()

n_days = 10
n_steps = n_days * 24 * 8      # 8 steps per hour
time = n_days * 24 * 60 * 60   # horizon in seconds
dt = time / n_steps            # 450 s per step

In [None]:
# Closed-loop model: the immunity ODE driven by the FFSA controller over the
# horizon configured in the previous cell.
model = ControlModel(system, infection_regressor, time, n_steps)

In [None]:
# Batches of 256 synthetic patients, read in index order in the main process.
# Every option is spelled out even where it equals DataLoader's default.
dataloader = DataLoader(dataset, batch_size=256, shuffle=False,
                        num_workers=0, collate_fn=None, pin_memory=False)

In [None]:
def _rms(x):
    """Root-mean-square of all elements of ``x``, with a finite gradient at 0.

    ``sqrt(mean(x**2))`` back-propagates as ``inf * 0 = NaN`` when every
    element is zero (e.g. an all-zero control trajectory).
    ``torch.linalg.vector_norm`` uses a zero subgradient at the origin, so
    ``||x|| / sqrt(N)`` gives the same value without NaN gradients.
    """
    return torch.linalg.vector_norm(x) / x.numel() ** 0.5


def my_loss(variables, controls, control_weight=5.0):
    """Control objective: RMS antigen load plus weighted RMS drug usage.

    Args:
        variables: state trajectories, shape ``(batch, time, n_variables)``;
            channel 0 (antigen load V) is penalised.
        controls: control trajectories; all elements are penalised.
        control_weight: weight of the drug term (default 5.0).

    Returns:
        Scalar tensor ``rms(V) + control_weight * rms(controls)``.
    """
    return _rms(variables[:, :, 0]) + control_weight * _rms(controls)

In [None]:
import torch.optim as optim

# Adam over every trainable parameter of the closed-loop model, lr = 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Train the controller: 50 passes over the synthetic dataset.
# The last batch's `variables`, `controls`, `activations` and `loss` stay
# bound at module level for the inspection cells below.
for epoch in range(50):  # loop over the dataset multiple times
    running_loss = 0.0
    n_batches = 0
    for i, data in enumerate(dataloader):
        # Each batch is [initial states, model parameters] from the dataset.
        init_variables, parameters = data

        optimizer.zero_grad()

        # Roll out the closed loop, then back-propagate through every step.
        # The graph is used for a single backward, so nothing is retained.
        variables, controls, activations = model(init_variables, parameters)
        loss = my_loss(variables, controls)
        loss.backward()

        # Zero out non-finite gradient entries before the optimizer step.
        # Parameters that got no gradient (grad is None) are skipped, since
        # nan_to_num(None) would raise.
        for param in model.parameters():
            if param.grad is not None:
                param.grad = torch.nan_to_num(param.grad, nan=0.0, posinf=0.0, neginf=0.0)
        # Optional: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        n_batches += 1
        print(f'[{epoch + 1}, {i + 1:5d}] loss: {batch_loss:.3f}')

    if n_batches:
        print(f'[{epoch + 1}] mean loss: {running_loss / n_batches:.3f}')

print('Finished Training')

In [None]:
# Learned regressor weight and bias of every automaton state.
for state in model.controller.ffsa.states:
    print(state.weight, state.bias, sep='\n')

In [None]:
# Learned condition weight/bias and speed of every automaton transition.
for transition in model.controller.ffsa.transitions:
    print(transition.condition.linear.weight,
          transition.condition.linear.bias,
          transition.speed,
          sep='\n')

In [None]:
# State trajectory [V, S] of the first patient in the last training batch.
print(variables[0])

In [None]:
# Automaton activations over time for patient index 1 of the last training batch.
# NOTE(review): the neighbouring cells inspect index 0 — confirm index 1 is intended.
print(activations[1])

In [None]:
# Drug (control) trajectory of the first patient in the last training batch.
print(controls[0])

In [None]:
def test_control_plots(n_days=100, steps_per_day=24):
    """Compare one random patient's untreated and controller-treated course.

    Draws a patient from ``dataset``, integrates ``ImmunityModelSystem`` with
    no drug, and runs ``model`` on the same patient over the same horizon.
    Three figures follow: the untreated V and S; the treated V, S and drug
    dose; and untreated vs treated V.

    The untreated baseline uses the same integrator as ``ControlModel``
    (forward Euler, state clamped at zero), so the two curves are comparable.
    ``model.time`` and ``model.n_steps`` are changed only for this rollout
    and restored afterwards, so the training configuration is not silently
    overwritten.

    Args:
        n_days: simulated horizon in days (time axis in seconds).
        steps_per_day: Euler steps per day (default 24, i.e. dt = 1 hour).
    """
    def new_axes():
        """Create a wide figure with the shared grid and axis labels."""
        fig, ax = plt.subplots()

        fig.set_figheight(12)
        fig.set_figwidth(24)
        ax.grid(color='#000000', alpha=0.15, linestyle='-', linewidth=1, which='major')
        ax.grid(color='#000000', alpha=0.1, linestyle='-', linewidth=0.5, which='minor')

        ax.set_xlabel("$t, \\; с$")
        ax.set_ylabel("$x(t)$, у.е.")
        return ax

    init_variables, parameters = dataset[0]
    init_variables = torch.tensor(init_variables)[None, :]
    parameters = torch.tensor(parameters)[None, :]
    control = torch.zeros(1, 1)  # no treatment

    system = ImmunityModelSystem()

    time = n_days * 24 * 60 * 60
    n_steps = n_days * steps_per_day
    dt = time / n_steps
    T = np.linspace(0.0, time, n_steps + 1)

    # Untreated baseline.
    trajectory = [init_variables]
    with torch.no_grad():
        for _ in range(n_steps):
            state = trajectory[-1]
            trajectory.append(
                torch.clamp(state + dt * system(parameters, control, state), min=0)
            )
    variables = torch.stack(trajectory, dim=1).numpy()[0]

    # Treated rollout on the same horizon; restore the model's configuration
    # even if the rollout fails.
    saved_time, saved_n_steps = model.time, model.n_steps
    model.time, model.n_steps = time, n_steps
    try:
        new_variables, controls, _ = model(init_variables, parameters)
    finally:
        model.time, model.n_steps = saved_time, saved_n_steps
    new_variables = new_variables[0].detach().numpy()
    controls = controls[0].detach().numpy()

    # Untreated course.
    ax = new_axes()
    ax.plot(T, variables[:, 0], label="Вирусная нагрузка")
    ax.plot(T, variables[:, 1], label="Иммунная реакция")
    ax.legend(loc='upper right')
    plt.show()

    # Treated course; one control value per Euler step, plotted at step ends.
    ax = new_axes()
    ax.plot(T, new_variables[:, 0], label="Вирусная нагрузка")
    ax.plot(T, new_variables[:, 1], label="Иммунная реакция")
    ax.plot(T[1:], controls[:, 0], label="Лекарство")
    ax.legend(loc='upper right')
    plt.show()

    # Antigen load with vs. without treatment.
    ax = new_axes()
    ax.plot(T, variables[:, 0], label="Вирусная нагрузка без лечения")
    ax.plot(T, new_variables[:, 0], label="Вирусная нагрузка с лечением")
    ax.legend(loc='upper right')
    plt.show()

In [None]:
# Compare untreated vs. controller-treated course for one random patient.
test_control_plots()