# PET Residual Neural Network with Linear Layer Pretrained
This is the resnet model with a linear layer at the end; data is padded with black borders. Uses pretrained [medicalnet](https://github.com/Tencent/MedicalNet).

In [None]:
import os
import glob

import torch
import torch.nn as nn

import pandas as pd
from skimage import io, transform
from sklearn import preprocessing
from torchvision import transforms, utils
import adabound
import numpy as np

import nibabel as nib
import random

In [None]:
import os

# MedicalNet's `model` module is imported from inside its own directory
# (presumably so its internal imports resolve against that directory), so the
# working directory is switched only for the duration of the import.
_previous_cwd = os.getcwd()
os.chdir("MedicalNet/")
try:
    import model as mn
finally:
    # Always return to the starting directory, even when the import fails;
    # otherwise re-running this cell would chdir into a nested MedicalNet/.
    os.chdir(_previous_cwd)

In [None]:
# Inspect which names the imported MedicalNet `model` module exposes.
dir(mn)

In [None]:
# Use the GPU if there is one, otherwise CPU
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Path of the clinical-scores CSV file (a file path, despite the `_DIR` suffix).
CSV_DIR = "./scores.csv"
# Root directory containing the per-subject `sub-OAS3xxxx/` scan folders.
DATA_DIR = "../pet_data/"

# Intensity normalisation constants, applied as `(voxel + MIN) / MAX` when a
# scan is loaded by `BrainsDataset.__getitem__`.
# NOTE(review): these look like the negated dataset minimum and the full
# (max - min) range, given the "(-4401596.0, 4831864.0)" comment in the
# min/max scan cell near the end of the file; confirm before reusing them.
MIN = 4401596
MAX = 9233460

## Normalize Data

In [None]:
# Load the clinical scores table from disk.
df = pd.read_csv(CSV_DIR)

# Fit a standard scaler on the three clinical columns; `get_scores` applies it
# to every row it returns.
norm_df = df[["mmse", "cdr", "ageAtEntry"]]
std_scale = preprocessing.StandardScaler()
std_scale.fit(norm_df)

## Dataset Management
Handle CSV diagnosis/signs and get an iterator of brain scans

In [None]:
def get_scores(ID, date):
    """Return the normalised clinical scores recorded after a given scan day.

    Args:
        ID: Subject identifier; rows whose "Subject" column contains it are
            considered.
        date: Day number of the scan. Only clinical visits with a strictly
            later day number are returned.

    Returns:
        A list of ``(days_ahead, mmse, cdr, age)`` tuples, one per later
        visit, where ``days_ahead`` is the visit day minus ``date`` and the
        other three values have been transformed by the module-level
        ``std_scale``. The list is empty when there is no later visit.
    """
    columns = ["mmse", "cdr", "ageAtEntry"]
    # na=False keeps rows with a missing Subject from breaking the boolean mask.
    subject_rows = df[df["Subject"].str.contains(ID, na=False)]

    scores = []
    for _, row in subject_rows.iterrows():
        # The visit day is the last "_"-separated token of the clinical-data
        # ID with its first character dropped.
        cur_date = int(row["ADRC_ADRCCLINICALDATA ID"].split("_")[-1][1:])
        if cur_date <= date:
            continue

        # Missing scores are filled with fixed defaults (mmse 30, cdr 0),
        # without mutating the row handed out by iterrows().
        mmse = 30 if pd.isna(row["mmse"]) else row["mmse"]
        cdr = 0 if pd.isna(row["cdr"]) else row["cdr"]
        visit = pd.DataFrame(
            {
                "mmse": [mmse],
                "cdr": [cdr],
                # Entry age advanced by the visit day expressed in years.
                "ageAtEntry": [row["ageAtEntry"] + cur_date / 365],
            },
            columns=columns,
        )

        scaled = std_scale.transform(visit)[0]
        scores.append((cur_date - date, scaled[0], scaled[1], scaled[2]))

    return scores

In [None]:
# Smoke test: display the normalised later-visit scores for one subject and scan day.
get_scores('OAS30001', 423)  # testing

In [None]:
def get_brains():
    """Yield one training sample per (scan, later clinical visit) pair.

    The directory walk is delegated to ``list_brains`` so both generators
    agree on which scan files exist and on what is printed for missing ones.

    Yields:
        ``(file, subject, day, days_ahead, mmse, cdr, age)`` tuples: the three
        fields produced by ``list_brains`` followed by one score tuple from
        ``get_scores``. Scans with no later visit yield nothing.
    """
    for scan_path, subject, day in list_brains():
        for score in get_scores(subject, day):
            yield (scan_path, subject, day) + score

def list_brains():
    """Yield ``(file, subject, day)`` for every PIB PET scan found on disk.

    Subject folders ``sub-OAS3NNNN`` under ``DATA_DIR`` are probed for subject
    numbers 1 through 11172. For each session folder the expected scan path is
    built; paths that are not files are printed and skipped. ``day`` is the
    integer parsed from the session name after its first five characters.
    """
    for number in range(1, 11173):
        subject_id = f"{number:04d}"
        subject_dir = f"{DATA_DIR}sub-OAS3{subject_id}/"
        if not os.path.isdir(subject_dir):
            continue

        for session in os.listdir(subject_dir):
            scan_path = f"{subject_dir}{session}/pet/sub-OAS3{subject_id}_{session}_acq-PIB_pet.nii.gz"
            if not os.path.isfile(scan_path):
                # Report the expected-but-missing scan file.
                print(scan_path)
                continue
            yield (scan_path, f"OAS3{subject_id}", int(session[5:]))

## Define Dataset
Create an iterable dataset with brain data inheriting from `torch.utils.data.Dataset`. Dataset len is `3594`.

In [None]:
class BrainsDataset(torch.utils.data.Dataset):
    """Map-style dataset of PET scans paired with later clinical scores.

    Each entry is one tuple from ``get_brains``:
    ``(file, subject, day, days_ahead, mmse, cdr, age)``.

    Args:
        transform: Optional callable applied to the normalised voxel array.
            When ``None`` the normalised array is returned unchanged.
    """

    def __init__(self, transform=None):
        self.transform = transform
        # get_brains() only yields a scan once per score returned by
        # get_scores(), so every entry already has a score; querying
        # get_scores() again per sample would only repeat the CSV scan.
        self.brains = list(get_brains())

    def __len__(self):
        """Return the number of (scan, score) samples."""
        return len(self.brains)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(days_ahead, mmse, cdr, age, scan)`` where ``scan`` is the voxel
            data normalised as ``(voxel + MIN) / MAX`` and then passed through
            ``transform`` when one was given.
        """
        path, _subject, _day, days_ahead, mmse, cdr, age = self.brains[index]
        scan = (nib.load(path).get_fdata() + MIN) / MAX
        if self.transform is not None:
            scan = self.transform(scan)
        return days_ahead, mmse, cdr, age, scan

## Create Data Preprocessing and Cropping
Resize images to (128, 128, 63, 24).

In [None]:
class Rescale(object):
    """Resize a brain volume to a fixed target shape."""

    def __init__(self, output_size):
        # Target shape passed to skimage's `transform.resize` on every call.
        self.output_size = output_size

    def __call__(self, brain):
        return transform.resize(brain, self.output_size)


class ToTensor(object):
    """Wrap a NumPy array as a torch tensor."""

    def __call__(self, sample):
        tensor = torch.from_numpy(sample)
        return tensor


## Define Neural Network
Create the network modules inheriting from `torch.nn.Module`: a pretrained 3D ResNet backbone applied to each time frame, followed by fully connected layers.

In [None]:
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a linear projection of every time step.

    Args:
        nIn: Feature size of each input step.
        nHidden: Hidden size of the LSTM in each direction.
        nOut: Feature size of each output step.
    """

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # Both directions are concatenated, hence the doubled input width.
        self.embedding = nn.Linear(2 * nHidden, nOut)

    def forward(self, input):
        """Map a 3-D ``[T, b, nIn]`` sequence to ``[T, b, nOut]``."""
        recurrent = self.rnn(input)[0]
        steps, batch, hidden = recurrent.size()

        # Fold time into the batch axis so the linear layer sees 2-D input.
        flattened = recurrent.view(steps * batch, hidden)
        projected = self.embedding(flattened)  # [T * b, nOut]

        return projected.view(steps, batch, -1)
    
class Net(nn.Module):
    """Per-frame 3D ResNet encoder followed by a fully connected regressor.

    Every time frame of the scan goes through the shared backbone and ``fc``
    head, producing ``num_classes`` features per frame. These are concatenated
    with ``days_ahead`` and ``age`` and mapped to two outputs by ``linear``.

    Args:
        resnet: Backbone module. After ``avg_pool`` its output must flatten to
            4096 values per sample, as ``fc`` requires.
        num_frames: Number of time frames in each scan; sizes the first layer
            of ``linear``. Defaults to 24, the value that was hard-coded.
    """

    def __init__(self, resnet, num_frames=24):
        super().__init__()
        self.num_classes = 64
        self.num_frames = num_frames
        self.resnet = resnet.to(DEVICE)
        self.avg_pool = nn.AvgPool3d(3, stride=(3, 3, 3))
        self.fc = nn.Sequential(
            nn.Linear(4096, 256),
            nn.BatchNorm1d(256),
            nn.ELU(),
            nn.Linear(256, self.num_classes),
        ).to(DEVICE)

        self.linear = nn.Sequential(
            # +2 for the days_ahead and age scalars put in front of the features.
            nn.Linear(self.num_classes * num_frames + 2, 512),
            nn.BatchNorm1d(512),
            nn.ELU(),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.Linear(128, 16),
            nn.BatchNorm1d(16),
            nn.ELU(),
            nn.Linear(16, 2),
        ).to(DEVICE)
        self.max_pool = nn.MaxPool3d(2, stride=(2, 2, 1))

    def _encode_frame(self, frame):
        """Encode one ``(batch, x, y, z)`` frame into ``num_classes`` features."""
        # Add the channel axis, downsample, then run the backbone.
        feature_map = self.resnet(self.max_pool(frame[:, None]))
        pooled = self.avg_pool(feature_map)
        return self.fc(pooled.view(pooled.size(0), -1))

    def forward(self, brain, days_ahead, age):
        """Predict two regression targets per sample.

        Args:
            brain: 5-D tensor indexed as ``(batch, x, y, z, frame)``.
            days_ahead: 1-D tensor with one value per sample.
            age: 1-D tensor with one value per sample.

        Returns:
            Tensor of shape ``(batch, 2)``.
        """
        # Encode every frame exactly once. The previous loop ran over
        # range(T - 1) after already encoding frame 0, so frame 0 was used
        # twice and the last frame was never seen.
        frame_features = [
            self._encode_frame(brain[..., t]) for t in range(brain.shape[-1])
        ]
        extras = torch.stack([days_ahead, age], dim=1)
        features = torch.cat([extras] + frame_features, dim=1).to(DEVICE)
        # Return the output as-is: wrapping it in torch.cuda.FloatTensor failed
        # whenever DEVICE fell back to the CPU.
        return self.linear(features)


In [None]:
def _unpack_batch(batch):
    """Move one loader batch to DEVICE as float32 model inputs and targets.

    ``batch`` is ``(days_ahead, mmse, cdr, age, scan)``. Returns
    ``(scan, days_ahead, age, targets)`` where ``targets`` has one
    ``(mmse, cdr)`` row per sample.
    """
    days_ahead, mmse, cdr, age, scan = batch
    # Tensor.to(device, dtype) works on CPU and GPU alike, unlike the
    # CUDA-only torch.cuda.FloatTensor cast.
    scan = scan.to(DEVICE, dtype=torch.float32)
    days_ahead = days_ahead.to(DEVICE, dtype=torch.float32)
    age = age.to(DEVICE, dtype=torch.float32)
    # Targets are cast to float32 as well so they match the model output dtype.
    targets = torch.stack([mmse, cdr], dim=1).to(DEVICE, dtype=torch.float32)
    return scan, days_ahead, age, targets


def train(model, optimizer, criterion, criterion_test, train_loader, test_loader, writer):
    """Train ``model`` for ``NUM_EPOCHS`` epochs, logging to TensorBoard.

    Every fifth training step the mean training loss of the last five steps is
    printed and logged as "Training Loss", and one batch from ``test_loader``
    is evaluated and logged as "Test Loss". The model's ``state_dict`` is saved
    under ``model_pretrained_cnn/`` after every epoch.

    Args:
        model: Module called as ``model(scan, days_ahead, age)``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss used for the training step.
        criterion_test: Loss used for the periodic evaluation batch.
        train_loader: Loader yielding ``(days_ahead, mmse, cdr, age, scan)``.
        test_loader: Loader with the same batch layout, used for evaluation.
        writer: Object providing ``add_scalar(tag, value, step)``.
    """
    # Create the checkpoint directory up front so the first save cannot fail
    # after a whole epoch of training.
    os.makedirs("model_pretrained_cnn", exist_ok=True)

    step_num = 0
    # Evaluation batches come from the held-out loader; the old code drew them
    # from train_loader, so "Test Loss" was really a second training loss.
    test_iter = iter(test_loader)

    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
        print('-' * 10)
        running_loss = 0
        model.train()

        for i, data_brains in enumerate(train_loader):
            step_num += 1
            scan, days_ahead, age, real_values = _unpack_batch(data_brains)

            outputs = model(scan, days_ahead, age)
            loss = criterion(outputs, real_values)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            torch.cuda.empty_cache()

            running_loss += loss.item()
            if i % 5 != 4:
                continue

            print(f"[{epoch + 1} {i + 1}] loss: {running_loss/5}")
            writer.add_scalar("Training Loss", running_loss/5, step_num)
            running_loss = 0

            # Cycle through the test loader, restarting it when exhausted.
            test_batch = next(test_iter, None)
            if test_batch is None:
                test_iter = iter(test_loader)
                test_batch = next(test_iter, None)

            if test_batch is not None:
                # Evaluate in eval mode so batch-norm uses its running
                # statistics and is not updated by held-out data.
                model.eval()
                with torch.no_grad():
                    _scan, _days_ahead, _age, _real_values = _unpack_batch(test_batch)
                    _outputs = model(_scan, _days_ahead, _age)
                    test_loss = criterion_test(_outputs, _real_values).item()
                model.train()
                writer.add_scalar("Test Loss", test_loss, step_num)

            torch.cuda.empty_cache()

        print("Saving model")
        torch.save(model.state_dict(), f"model_pretrained_cnn/new_model_{epoch}_cnn.pt")

In [None]:
class Params():
    """Settings object passed to MedicalNet's ``generate_model``.

    NOTE(review): the 64 x 64 x 62 input size presumably matches a
    128 x 128 x 63 frame after ``Net.max_pool`` (kernel 2, stride (2, 2, 1));
    confirm against MedicalNet's expectations.
    """

    def __init__(self):
        # Backbone family and depth.
        self.model, self.model_depth = "resnet", 34
        # Spatial size of one input volume.
        self.input_W, self.input_H, self.input_D = 64, 64, 62
        self.resnet_shortcut = "A"
        # Device selection.
        self.no_cuda, self.gpu_id = False, [0]
        self.n_seg_classes = 128
        self.phase = "train"
        # Pretrained weights are looked up relative to the current directory.
        self.pretrain_path = f"{os.getcwd()}/MedicalNet/pretrain/resnet_34_23dataset.pth"
        self.new_layer_names = [
            'upsample1', 'cmp_layer3',
            'upsample2', 'cmp_layer2',
            'upsample3', 'cmp_layer1',
            'upsample4', 'cmp_conv1',
            'conv_seg',
        ]

## Hyperparameters

In [None]:
# Number of passes over the training loader made by `train`.
NUM_EPOCHS = 50
# Samples per batch for both the training and the test DataLoader.
BATCH_SIZE = 2

## Get Data

In [None]:
# Resize every scan to a common (128, 128, 63, 24) shape, then wrap it as a tensor.
scale = Rescale((128, 128, 63, 24))
composed = transforms.Compose([scale, ToTensor()])

# Building the dataset walks DATA_DIR and collects one entry per
# (scan, later clinical visit) pair.
dataset = BrainsDataset(composed)

# Hold out 20% of the samples (rounded down) for testing.
NUM_INSTANCES = len(dataset)
TEST_RATIO = 0.2
TEST_SIZE = int(NUM_INSTANCES * TEST_RATIO)
TRAIN_SIZE = NUM_INSTANCES - TEST_SIZE

In [None]:
# Randomly partition the samples; no seed is set in this cell, so the split
# can differ between runs.
# NOTE(review): the split is per sample and one scan contributes a sample for
# every later visit (see get_brains), so the same scan or subject can land in
# both subsets; confirm that this leakage is acceptable.
train_data, test_data = torch.utils.data.random_split(dataset, (TRAIN_SIZE, TEST_SIZE))
# NOTE(review): Net contains BatchNorm1d layers, which reject a single-sample
# batch in training mode; with an odd TRAIN_SIZE the last batch has one
# sample, so consider drop_last=True.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle = True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle = False)

In [None]:
# Show the current working directory (the relative paths used in this file resolve against it).
import os
os.getcwd()

In [None]:
from torch.utils.tensorboard import SummaryWriter
# Build the MedicalNet backbone described by Params; the second return value is unused.
model, _ = mn.generate_model(Params())

# Wrap the backbone with the per-frame head and the final regressor.
model = Net(model)

# AdamW with its default hyperparameters; mean squared error for both losses.
optimizer = torch.optim.AdamW(model.parameters())  
criterion = torch.nn.MSELoss()
criterion_test = torch.nn.MSELoss()
# TensorBoard writer using its default log directory.
writer = SummaryWriter()

# NOTE(review): train() has no return statement, so `loss` is None here.
loss = train(model, optimizer, criterion, criterion_test, train_loader, test_loader, writer)

In [None]:
# Total number of trainable scalar parameters in the model.
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

In [None]:
# Log the model graph to TensorBoard using the first training batch only.
for data_brains in train_loader:
    # data_brains = days_ahead, mmse, cdr, age, scan
    # scan = [x,y,z,t]
    # Tensor.to(device, dtype) also works when DEVICE is the CPU, unlike the
    # CUDA-only torch.cuda.FloatTensor cast.
    scan = data_brains[4].to(DEVICE, dtype=torch.float32)
    days_ahead = data_brains[0].to(DEVICE, dtype=torch.float32)
    age = data_brains[3].to(DEVICE, dtype=torch.float32)

    writer.add_graph(model, (scan, days_ahead, age))
    break
writer.close()

In [None]:
# Persist the whole model object (train() saves only the state_dict after each epoch).
torch.save(model, f"model_cnn_normal_final_new.pt")

In [None]:
# Release unused GPU memory cached by PyTorch's allocator.
torch.cuda.empty_cache()

## Test model

In [None]:
# Compare predictions with ground truth, in original (un-normalised) units,
# for the first sample of every held-out batch.
score_columns = ["mmse", "cdr", "ageAtEntry"]

# Eval mode: batch-norm uses its running statistics, so a single sample works.
model.eval()
with torch.no_grad():
    # Iterate the held-out split; the old loop read train_loader, so this
    # "test" only ever showed samples the model had been trained on.
    for data_brains in test_loader:
        # data_brains = days_ahead, mmse, cdr, age, scan
        # scan = [x,y,z,t]
        scan = data_brains[4].to(DEVICE, dtype=torch.float32)
        days_ahead = data_brains[0].to(DEVICE, dtype=torch.float32)
        age = data_brains[3].to(DEVICE, dtype=torch.float32)
        real_values = torch.stack([data_brains[1], data_brains[2]], dim=1).to(DEVICE)

        # Predict for the first sample only, keeping the batch axis.
        outputs = model(scan[:1], days_ahead[:1], age[:1])
        print(f"Outputs: {outputs}")
        print(f"Real Values: {real_values}")

        # The scaler was fitted on three columns, so the normalised age is
        # supplied too; .item() hands pandas plain floats instead of tensors.
        scaled_age = age[0].item()

        data_pred = {
            'mmse': [outputs[0][0].item()],
            'cdr': [outputs[0][1].item()],
            'ageAtEntry': [scaled_age],
        }
        data_pred = std_scale.inverse_transform(pd.DataFrame(data_pred, columns=score_columns))

        data_real = {
            'mmse': [real_values[0][0].item()],
            'cdr': [real_values[0][1].item()],
            'ageAtEntry': [scaled_age],
        }
        data_real = std_scale.inverse_transform(pd.DataFrame(data_real, columns=score_columns))

        print(data_pred)
        print(data_real)

In [None]:
# Switch the model back to training mode, then display every parameter tensor.
model.train()
list(model.parameters())

# Get Dimensions of Data
Here we get the dimensions of the data to be cropped

This distribution goes something like this:

```
{(128, 128, 63, 51): 88,
 (128, 128, 63, 25): 60,
 (128, 128, 63, 24): 11,
 (128, 128, 109, 26): 549,
 (128, 128, 63, 52): 76,
 (128, 128, 63, 53): 41,
 (128, 128, 63, 26): 51,
 (128, 128, 63, 50): 34,
 (256, 256, 127, 26): 5,
 (128, 128, 63, 41): 2,
 (128, 128, 2592): 1,
 (128, 128, 74, 25): 1,
 (128, 128, 63, 49): 11,
 (256, 256, 127): 2,
 (128, 128, 63, 23): 4,
 (128, 128, 2832): 4,
 (128, 128, 63, 34): 2,
 (128, 128, 47, 49): 2,
 (128, 128, 63, 48): 1,
 (128, 128, 109, 4): 1,
 (128, 128, 2827): 1,
 (128, 128, 47, 50): 1,
 (128, 128, 109, 20): 2,
 (128, 128, 47, 52): 1,
 (128, 128, 109, 17): 1,
 (128, 128, 109, 6): 1,
 (128, 128, 63, 20): 1,
 (128, 128, 47, 51): 1,
 (128, 128, 63, 45): 1}
```

In [None]:
# Histogram of scan array shapes.
# NOTE(review): get_brains() yields a scan once per later clinical visit, so
# these counts are per sample, not per unique scan file.
dimensions = {}

for brain in get_brains():
    # The image's `shape` comes from the file header, so the voxel data does
    # not have to be loaded just to learn its dimensions.
    shape = nib.load(brain[0]).shape
    dimensions[shape] = dimensions.get(shape, 0) + 1

print(dimensions)

# Total number of samples counted.
nums = sum(dimensions.values())
nums

In [None]:
# Global intensity range over every scan on disk. Both bounds start at 0, so
# the result only ever widens around zero.
smallest, largest = 0, 0

for brain in list_brains():
    data = nib.load(brain[0]).get_fdata()

    # Compare the raw values: the old int(np.max(data)) check truncated the
    # candidate before comparing, so a fractional increase over the current
    # maximum could be missed.
    smallest = min(smallest, np.min(data))
    largest = max(largest, np.max(data))

(smallest, largest)
# should be (-4401596.0, 4831864.0)

In [None]:
# Print every sample whose scan has only three dimensions.
for brain in get_brains():
    # The header shape is enough here; the voxel data is never needed.
    shape = nib.load(brain[0]).shape
    if len(shape) == 3:
        print(brain)

In [None]:
# Scan files recorded as bad data.
# NOTE(review): presumably the three-dimensional scans printed by the cell
# above; these paths start with '../data/' whereas DATA_DIR is '../pet_data/',
# so confirm which location is current.
bad_data = ['../data/sub-OAS30065/ses-d0553/pet/sub-OAS30065_ses-d0553_acq-PIB_pet.nii.gz',
'../data/sub-OAS30229/ses-d0101/pet/sub-OAS30229_ses-d0101_acq-PIB_pet.nii.gz',
'../data/sub-OAS30253/ses-d3948/pet/sub-OAS30253_ses-d3948_acq-PIB_pet.nii.gz',
'../data/sub-OAS30332/ses-d0091/pet/sub-OAS30332_ses-d0091_acq-PIB_pet.nii.gz',
'../data/sub-OAS30472/ses-d1278/pet/sub-OAS30472_ses-d1278_acq-PIB_pet.nii.gz',
'../data/sub-OAS30498/ses-d0120/pet/sub-OAS30498_ses-d0120_acq-PIB_pet.nii.gz',
'../data/sub-OAS30588/ses-d1639/pet/sub-OAS30588_ses-d1639_acq-PIB_pet.nii.gz',
'../data/sub-OAS30896/ses-d1601/pet/sub-OAS30896_ses-d1601_acq-PIB_pet.nii.gz'
]