In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from PIL import Image

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

## Load and Crop Unmarked MRI Scans

In [527]:
def crop(X):
    res = np.zeros((len(X), len(X[0]), 300, 300))
    for p in range(len(X)):
        for s in range(len(X[p])):
            for i in range(106, 406):
                res[p][s][i - 106] = X[p][s][i][106:406]
    return res
            
    
X = np.load('small_data.npy')
X = crop(X)
X = torch.from_numpy(X)
X = X.to(torch.float32)
X = [X[i-1:i] for i in range(1, len(X))]

y = pd.read_csv('all_target.csv').to_numpy()
y = torch.tensor(y)

N = len(y)

## Hot-encoded y

In [436]:
y6 = torch.zeros(N, 6)
y4 = torch.zeros(N, 4)
y2 = torch.zeros(N, 2)
for i in range(N):
    y6[i][y[i] - 1] = 1
    
    if y[i] < 3:
        y4[i][0] = 1
        y2[i][0] = 1
    
    elif y[i] == 3:
        y4[i][1] = 1
        y2[i][0] = 1
        
    elif y[i] == 4:
        y4[i][2] = 1
        y2[i][1] = 1
    
    else:
        y4[i][3] = 1
        y2[i][1] = 1

## Vascular System Data and Adjacency Matrix 

In [337]:
import re


def get_p(p):
    p = str(p)
    if '*' in p:
        p = re.split('\*|\^|\(|\)', p)
        p = [float(i) for i in p if i]
        
        return round(p[0] * p[1]**p[2], 5)
        
    if '-' in p:
        p = p.split('-')
        p = [float(i) for i in p]
        
        return round(np.mean(p), 5)
    
    return round(float(p), 5)


df = pd.read_excel('Таблица_графа.xlsx').iloc[:-2, :]
df.columns = ['n', 'name', 'p', 'u', 'A']
df.drop(['n', 'name'], axis=1, inplace=True)

df['p'] = df['p'].apply(get_p)
df['u'] = df['u'].apply(get_p)
df['A'] = df['A'].apply(get_p)

df.index += 1

"""
The dictionary with information about
every vessel.
"""
Info = df.to_dict('index')


"""
The adjacency graph of the vascular system.
Keys correspond to the ones of the Info graph.

Keys represent vessels. 
Values: 
    first array: vessels incoming to our vessel
    second array: vessels outgoing from our vessel
"""

Graph = {
    1: [[],[31, 38]], 2: [[27],[58, 59, 60]], 3: [[8],[4]], 4: [[3, 80],[26]], 5: [[6, 99],[7]], 6: [[95, 97],[5]], 
    7: [[5, 101],[23]], 8: [[12, 93],[3]], 9: [[25],[10, 11]], 10: [[9],[15, 62]], 11: [[9],[54]], 12: [[16, 17],[8]], 
    13: [[15],[18, 63]], 14: [[15],[19, 65]], 15: [[10],[13, 14]], 16: [[20, 90],[12]], 17: [[21, 89],[12]], 
    18: [[13],[64]], 19: [[14],[64]], 20: [[91],[16]], 21: [[91],[17]], 22: [[69],[24]], 23: [[7],[24]], 
    24: [[22, 23],[81]], 25: [[27],[9, 61]], 26: [[4,85],[55]], 27: [[28],[2,25]], 28: [[53],[27,66]], 29: [[31],[43]], 
    30: [[31],[53,57]], 31: [[1],[29,30]], 32: [[34,35],[33]], 33: [[32],[55]], 34: [[36,79],[32]], 35: [[39,82],[32]], 
    36: [[40],[34]], 37: [[38],[42]], 38: [[1],[37, 56]], 39: [[40],[35]], 40: [[47,52],[36,39]], 42: [[37],[44,102]], 
    43: [[29],[45,104]], 44: [[42],[46]], 45: [[43],[46]], 46: [[44,45],[106]], 47: [[103],[40]], 48: [[105],[52]], 
    49: [[53],[106]], 50: [[53],[106]], 51: [[107],[52]], 52: [[48,51],[40]], 53: [[30],[28,49,50]], 54: [[11],[92]], 
    55: [[26,33],[]], 56: [[38],[67]], 57: [[30],[68]], 58: [[2],[69]], 59: [[2],[70]], 60: [[2],[71]], 61: [[25],[72]],
    62: [[10],[73]], 63: [[13],[74]], 64: [[18,19],[75]], 65: [[14],[76]], 66: [[28],[77]], 67: [[56],[78]], 
    68: [[57],[83]], 69: [[58],[22]], 70: [[59],[100]], 71: [[60],[98]], 72: [[61],[96]], 73: [[62],[94]], 
    74: [[63],[86]], 75: [[64],[87]], 76: [[65],[88]], 77: [[66],[84]], 78: [[67],[79]], 79: [[78],[34]], 80: [[81],[4]], 
    81: [[24],[80]], 82: [[83],[35]], 83: [[68],[82]], 84: [[77],[85]], 85: [[84],[26]], 86: [[74],[90]], 87: [[75],[91]],
    88: [[76],[89]], 89: [[88],[17]], 90: [[86],[16]], 91: [[87],[20,21]], 92: [[54],[93]], 93: [[92],[8]], 
    94: [[73],[95]], 95: [[94],[6]], 96: [[72],[97]], 97: [[96],[6]], 98: [[71],[99]], 99: [[98],[5]], 100: [[70],[101]],
    101: [[100],[7]], 102: [[42],[103]], 103: [[102],[47]], 104: [[43],[105]], 105: [[104],[48]], 106: [[46,49,50],[107]],
    107: [[106],[51]]
}

## Vascular System Losses

In [494]:
def boundary_loss(output):
    """
    Checks whether the output is 
    physically possible
    """
    # nothing can be negtive
    negative = F.relu(0 - output)
    
    # maximum pressure = 150
    s = int(len(output) / 3 * 2) 
    temp = output.detach().clone()
    temp[s:] = 150
    too_high = F.relu(output - temp)
    
    return torch.mean((negative + too_high)**2)



def orig_cut_loss(output, x):
    """
    Checks whether the outputed values
    of the parameters of the 9 spinal veins 
    remained the same
    """
    
    index = torch.tensor([2, 3, 7, 11, 15, 16, 19, 20, 25]) # vessels indexes in the output
    
    A_loss = torch.mean((output[index] - x[:9])**2)
    u_loss = torch.mean((output[index+107] - x[9:18])**2)
    p_loss = torch.mean((output[index+214] - x[18:])**2)
    
    return A_loss + u_loss + p_loss



def physical_equations_loss(output):
    """
    Checks whether the outputs follows
    the physical laws 
    """
    
    ro = 1.055 # mean blood density
    
    loss1 = 0 # A1u1 = A2u2 + A3u3
    cnt1 = 0
    loss2 = 0 # p1 + ro u1^2 /2 = p2 + ro u2^2 /2 
    cnt2 = 0
    
    for vessel in Graph.keys():

        A1 = output[vessel-1]
        u1 = output[vessel+106]
        p1 = output[vessel+213]
        
        for j in range(2):
            l1 = A1*u1
            for i in Graph[vessel][j]:
                l2 = p1 + 0.5 * ro * u1**2 / 2
                
                Ai = output[i-1]
                ui = output[i+106]
                pi = output[i+213]

                l1 -= Ai*ui
                l2 -= pi + 0.5 * ro * ui**2 / 2
                
                loss2 += l2**2
                cnt2 += 1
            
            loss1 += l1**2
            cnt1 += 1
            
    return loss1/cnt1 + loss2/cnt2

## Model

In [551]:
class CNN(nn.Module):
    def __init__(self, n_channels, size):
        super(CNN, self).__init__()
        
        self.n_channels = n_channels
        self.size = size
        
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=n_channels, out_channels=20, kernel_size=10),
            nn.MaxPool2d(3, 3),
            nn.Conv2d(in_channels=20, out_channels=16, kernel_size=5),
            nn.AvgPool2d(2, 2)     
        )
        
        '''
        Should model 3 parameters for each of 9 
        spinal veins.
        '''
        self.class_spinal = nn.Sequential(
            nn.Linear(16*46*46, 1000),
            nn.Linear(1000, 500),
            nn.Linear(500, 100),
            nn.Linear(100, 50),
            nn.Linear(50, 27)
        )
        
        self.class_params = nn.Sequential(
            nn.Linear(27, 50),
            nn.Linear(50, 100),
            nn.Linear(100, 400),
            nn.Linear(400, 321)
        )
        
        self.class_disease = nn.Sequential(
            nn.Linear(321, 100),
            nn.ReLU(),
            nn.Linear(100, 6),
            nn.ReLU()
        )
        
        self.class_4 = nn.Sequential(
            nn.Linear(6, 10),
            nn.Linear(10, 4),
            nn.ReLU()
        )
        
        self.class_cut = nn.Sequential(
            nn.Linear(4, 10),
            nn.Linear(10, 2),
            nn.ReLU()
        )
        
       
    
    def forward(self, x):
        x = self.conv(x)
        x = x.reshape(x.shape[0], -1) # flattening
        spinal = self.class_spinal(x)
        system = self.class_params(spinal)
        disease = self.class_disease(system)
        class_4 = self.class_4(disease)
        cut = self.class_cut(class_4)
        
#         return (spinal, system, torch.softmax(disease, dim=0),
#                 torch.softmax(class_4, dim=0), torch.softmax(cut, dim=0))
        
        return (spinal, system, disease, class_4, cut)


## Training

In [500]:
x_train = X[:60]
y6_train = y6[:60]
y4_train = y4[:60]
y2_train = y2[:60]

x_test = X[60:]
y6_test = y6[60:]
y4_test = y4[60:]
y2_test = y2[60:]

In [563]:
model = CNN(X[0].shape[1], X[0].shape[2])

mse = nn.MSELoss()
learning_rate = 0.1
optim = torch.optim.SGD(model.parameters(), lr=learning_rate)

yi = 0
for pat in x_train:
    
    spinal, system, disease, class_4, cut = model(pat)
    
    loss = 0
    loss += boundary_loss(spinal[0]) + boundary_loss(system[0])
    print(loss)
    loss += orig_cut_loss(system[0], spinal[0])
    print(loss)
    loss += physical_equations_loss(system[0])
    print(loss)
    loss += mse(disease[0], y6_train[yi])
    print(loss)
    loss += mse(class_4[0], y4_train[yi])
    print(loss)
    loss += mse(cut[0], y2_train[yi])
    print(loss)
    yi += 1
    
    loss.backward()
    optim.step()
    optim.zero_grad()
    

tensor(33.4218, grad_fn=<AddBackward0>)
tensor(188.1968, grad_fn=<AddBackward0>)
tensor(190.8452, grad_fn=<AddBackward0>)
tensor(190.9993, grad_fn=<AddBackward0>)
tensor(191.1252, grad_fn=<AddBackward0>)
tensor(191.6252, grad_fn=<AddBackward0>)
tensor(inf, grad_fn=<AddBackward0>)
tensor(inf, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)

KeyboardInterrupt: 

In [552]:
cut

tensor([[-0.1115, -0.2747]], grad_fn=<AddmmBackward>)

## Load marked MRI scans

In [None]:
def convert(file):
    """
    converts the image into an array
    """
    
    image = Image.open(file)
    return image, np.asarray(image)[:, :, :-1]


def get_cut(file):
    """
    returns the share of blue pixels
    in percents
    """
    
    _, arr = convert(file)
    cnt = 0
    pat = [0, 0, 255]
    for i in range(160, 380):
        for j in range(100, 450):
            if list(arr[i, j, :]) == pat:
                cnt += 1
                
    return round(cnt / 512 / 512 * 100, 4)


def get_vein(file):
    """
    Determines what vein out of the 3 
    is on the scan. 
    Returns:
        0 - left and right external iliac veins
        1 - four veins - left and right external
            and internal iliac veins
        2 - 
    """

X = []
patients = sorted(os.listdir('Размеченные снимки МРТ'))
print(patients)
for patient in patients:
    if patient != '.DS_Store':
        print(f'Patient: {patient},', end='\t')
        files = os.listdir('Размеченные снимки МРТ/' + patient)
        x = []
        for file in sorted(files):
            path = 'Размеченные снимки МРТ/' + patient + '/' + file
            if 'png' in path:
                cut = get_cut(path)
                x.append(cut)
        print('Done')
        X.append(x)

## NN

In [None]:

class VascularSystem(nn.Module):
    """
    Input:
        Average cuts of the two main
        spinal veins: A_0, A_1
    Output:
        Average cut, velocity, and pressure
        for 107 vessels of a human.
        [A_i] * 107 , [u_i] * 107, [p_i] * 107.
        A_0, A_1 = input
    """
    def __init__(self):
        super(VascularSystem, self).__init__()
        
        self.layers = nn.Sequential(
            nn.Linear(2, 100),
            nn.Linear(100, 100),
            nn.Linear(100, 321)
        )
        
        
    def forward(self, x):
        return self.layers(x)
    

    
X = torch.randn(100, 2)
model = VascularSystem()

optim = torch.optim.SGD(model.parameters(), lr=0.05)
# scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=100, gamma=0.005)

for epoch in range(1, 1001):
    
    pred = model(X[0])
    
    lp1, lp2 = physical_equations_loss(pred)
    l = lp1 + lp2 + boundary_loss(pred) + orig_cut_loss(pred, x)

    l.backward()

    optim.step()
    optim.zero_grad()

#     scheduler.step()
        
    if epoch % 100 < 1:
        print(f'Epoch: {epoch} \t loss = {l:.10f}')