# End to End model draft 1

## Todo:
- Show a visualization of the difference between using the transform and using normalization
- Stratified Splitting

## Imports

In [1]:
import torch
import torch.nn as nn
import pandas as pd 
import nibabel as nib
import datetime
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from project.preprocessing import z_score_normalization, NiftiToTensorTransform, get_transform
# from project.training import split_train_val_datasets
from project.model import VisionTransformer, ISAVIT
from project.dataset import Dataset, VALDODataset
from torch.utils.data import DataLoader, random_split
from project.utils import collatev2
from sklearn.model_selection import train_test_split
from project.model.feeder import Feeder
from project.utils import memcheck
from project import Fitter
from project.utils import compute_statistics

  check_for_updates()


## Logger

In [2]:
import logging
import os

# Take the timestamp once so the logger name and the log file name always
# agree (two separate now() calls can straddle a second boundary).
timestamp = datetime.datetime.now().strftime("%d%m%y%H%M%S")

# logging.FileHandler does not create missing directories, so make sure the
# target folder exists before opening the log file.
os.makedirs('logs', exist_ok=True)

logger = logging.getLogger(f'Nigel_EndToEnd_log_{timestamp}')
fh = logging.FileHandler(f'logs/nigel_EndTooEnd{timestamp}.log')
formatter = logging.Formatter(
    '%(asctime)s - %(levelname)s - %(message)s'
)

# Record everything from DEBUG upwards, both on the logger and on the file handler.
logger.setLevel(logging.DEBUG)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)

logger.addHandler(fh)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

# =================================

## Import dataset and select target slices

**Gets the whole dataset from a specific location**

In [3]:
ds = Dataset()

### Generate targets

In [4]:
# Load the raw MRI paths and the CMB mask paths for the cohorts in the dataset
## ds.load_raw_mri() can also be changed to ds.load_skullstripped_mri() for the preprocessed data
mri = ds.load_raw_mri()
masks = ds.load_cmb_masks()
# Number of slices of each MRI (size of the third axis). The image's `shape`
# is used so the full voxel array is not loaded just to count slices.
slices = [nib.load(x).shape[2] for x in mri]

## Create a standard dataframe of the unprocessed data
standard_df = pd.DataFrame({
    'mri': mri,
    'masks': masks,
    'slices': slices
})

In [5]:
standard_df.head(3)

Unnamed: 0,mri,masks,slices
0,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,35
1,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,35
2,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,35


**Function to generate all the target slices for each case**

In [6]:
def generate_target_slice(mri, masks, slices):
    """Build one row per (case, slice index) together with microbleed flags.

    For every case ``i`` the mask volume at ``masks[i]`` is loaded with
    nibabel and one row is produced for each index in ``range(slices[i])``.

    Args:
        mri: Indexable collection of MRI paths.
        masks: Indexable collection of mask paths, indexed in step with ``mri``.
        slices: Indexable collection giving the slice count of each case.

    Returns:
        DataFrame with the columns ``mri``, ``masks``, ``target``,
        ``has_microbleed_case`` and ``has_microbleed_slice``. A flag is 1 when
        the mask (whole volume / single slice) has a value above zero, else 0.

    Note:
        A length mismatch between the inputs is only reported with ``print``;
        processing still continues afterwards.
    """
    column_names = [
        'mri',
        'masks',
        'target',
        'has_microbleed_case',
        'has_microbleed_slice',
    ]
    n_cases = len(mri)

    if n_cases != len(masks):
        print(f'Unequal amount of mri cases to cmb masks\t{n_cases} to {len(masks)}')
    if n_cases != len(slices):
        print(f'Unequal amount of mri cases to case slice counts\t{n_cases} to {len(slices)}')

    rows = []
    for case_idx in range(n_cases):
        # The mask volume is loaded once per case and reused for every slice.
        volume = nib.load(masks[case_idx]).get_fdata()
        case_flag = int(volume.max() > 0)

        rows.extend(
            (
                mri[case_idx],
                masks[case_idx],
                slice_idx,
                case_flag,
                int(volume[:, :, slice_idx].max() > 0),
            )
            for slice_idx in range(slices[case_idx])
        )

    return pd.DataFrame(rows, columns=column_names)

# df = generate_target_slice(mri, masks, slices)

In [7]:
# df.head(3)

In [8]:
# df.info()

**Double-check that the slice count of every case matches the count in the raw dataframe**

In [9]:
# ar_targets = df.groupby('mri').target.max()
# ar_slices = standard_df.groupby('mri').slices.max()
# (ar_targets == (ar_slices - 1)).all()

#### Export as metadata(CSV file)

In [10]:
# df.to_csv('targets.csv', index=False)

# ===============================================

In [11]:
data = pd.read_csv('targets.csv')

In [12]:
# Keep only the rows whose MRI path is in the list returned by
# ds.load_raw_mri(1) (presumably cohort 1 - confirm against Dataset).
ch1 = ds.load_raw_mri(1)
in_ch1 = data['mri'].isin(ch1)
data = data[in_ch1]
data.shape

(385, 5)

In [13]:
data.head(5)

Unnamed: 0,mri,masks,target,has_microbleed_case,has_microbleed_slice
0,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,0,1,0
1,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,1,1,0
2,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,2,1,0
3,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,3,1,0
4,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,c:\Users\nigel\Documents\Thesis\Dataset\VALDO_...,4,1,0


In [14]:
data.shape  

(385, 5)

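**385 rows with 5 columns after filtering with `ds.load_raw_mri(1)`** (the figure previously noted here was 7987 cases with 3 columns, presumably for the unfiltered data)
- Columns
    - mri
    - masks
    - target slice
    - has_microbleed_case
    - has_microbleed_slice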


# RPN

### Training Embedder

In [15]:
# Settings for the embedder (autoencoder) stage; later cells read these names.

# Spatial size handed to NiftiToTensorTransform(target_shape=...)
target_shape = (400, 400)
# Forwarded to NiftiToTensorTransform(rpn_mode=...)
rpn_mode = False

# Batch size used by the train and validation DataLoaders
batch_size = 1
# Collate function for the DataLoaders
collate_fn = collatev2

# image_size is given to both SliceEmbedding and Decoder; input_output_dim is
# the embedder's output_dim and the decoder's input_dim
image_size = 400
input_output_dim = 2500

# Training settings collected into the autoencoder Fitter config
epochs = 10
# epochs = 1
loss = nn.MSELoss()
lr = 0.00075

In [16]:
# This target shape is as far as my GPU memory can go
transform = NiftiToTensorTransform(target_shape=target_shape, rpn_mode=rpn_mode)

# Columns of the filtered target table that are handed to the dataset
cases, masks, target = data['mri'], data['masks'], data['target']

In [17]:
dataset = VALDODataset(
    cases=cases,
    masks=masks,
    target=target,
    transform=transform,
)

In [18]:
# 80/20 random split. The train size is computed once so the two lengths
# always add up to len(dataset).
train_size = int(len(dataset) * 0.8)
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Both loaders use the configured `collate_fn` (set to collatev2 in the
# hyperparameter cell) rather than a hard-coded function.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
    collate_fn=collate_fn,
)

# Validation only accumulates losses, so there is no benefit in reshuffling
# it every epoch; keep the order fixed.
val_loader = DataLoader(
    val_dataset,
    shuffle=False,
    batch_size=batch_size,
    collate_fn=collate_fn,
)

In [19]:
from project.model import SliceEmbedding, Autoencoder, Decoder

en = SliceEmbedding(
    image_size=image_size,
    output_dim=input_output_dim,
)

de = Decoder(
    image_size=image_size,
    input_dim=input_output_dim
)

config = {
    'model': Autoencoder(en, de).to(device),
    'optimizer': torch.optim.Adam,
    'device': device,
    'epochs': epochs,
    'loss': loss,
    'lr': lr
}

144


In [20]:
class AEFitter(Fitter):
    """Fitter that trains the autoencoder to reconstruct its own input.

    Every batch is iterated as ``(slices, masks, target, case)`` tuples. Only
    ``slices`` is used: it is both the model input and the reconstruction
    target. Samples whose ``slices`` is ``None`` are logged and skipped.
    """

    def train_one_epoch(self, train_loader):
        """Train for one epoch, stepping the optimizer once per sample.

        Args:
            train_loader: DataLoader whose batches are iterables of
                ``(slices, masks, target, case)`` tuples.

        Returns:
            List with one loss value (NumPy array) per trained sample.
        """
        self.model.train()
        loss_history = []
        counter = 0
        # Last computed loss; stays None until a sample was actually trained,
        # so the progress log never touches an unbound name.
        last_loss = None
        # Report progress against the split being iterated, not the full dataset.
        total = len(train_loader.dataset)
        print("Training========")
        for batch in train_loader:
            for slices, _masks, _target, case in batch:
                if slices is None:
                    logger.error(f'CASE NOT WORKING: {case}')
                    continue
                x = slices.squeeze(1).float().to(self.device)
                y = self.model(x)
                logger.info(f'MEMORY after X, Y, T to device\t{memcheck()}')
                # Reconstruction loss: the output is compared with the input.
                last_loss = self.loss(y, x)
                self.optimizer.zero_grad()
                last_loss.backward()
                self.optimizer.step()
                loss_history.append(last_loss.detach().cpu().numpy())

            counter += len(batch)
            if counter % 100 == 0:
                logger.info(f'Progress:\t{counter}/{total}')
                if last_loss is not None:
                    logger.info(f'Current error:\t{last_loss}')

        return loss_history

    def validation(self, val_loader):
        """Evaluate the reconstruction loss without updating the model.

        Args:
            val_loader: DataLoader whose batches are iterables of
                ``(slices, masks, target, case)`` tuples.

        Returns:
            List with one loss value (NumPy array) per evaluated sample.
        """
        print("Validating========")
        self.model.eval()
        loss_history = []
        counter = 0
        last_loss = None
        total = len(val_loader.dataset)
        with torch.inference_mode():
            for batch in val_loader:
                for slices, _masks, _target, case in batch:
                    if slices is None:
                        logger.error(f'CASE NOT WORKING: {case}')
                        continue
                    x = slices.squeeze(1).float().to(self.device)
                    y = self.model(x)
                    last_loss = self.loss(y, x)
                    loss_history.append(last_loss.detach().cpu().numpy())

                    print("Vlidation: ",x.shape, y.shape)
                counter += len(batch)
                if counter % 100 == 0:
                    logger.info(f'Progress:\t{counter}/{total}')
                    if last_loss is not None:
                        logger.info(f'Current error:\t{last_loss}')

        return loss_history

In [21]:
Autoenfitter = AEFitter(config)

### Encoder Training

In [None]:
# ae_hist = Autoenfitter.fit(train_loader, val_loader)



torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400, 400])
torch.Size([35, 1, 2500])
torch.Size([35, 1, 400

In [37]:
# ae_t_hist, ae_v_hist = ae_hist

### RPN Proper training

**Get the latest weights of the encoder**

In [38]:
import glob
import os.path

folder_path = r"weights/"
# Wildcard only: the separator is supplied by os.path.join, so the search is
# not tied to the Windows backslash.
file_type = r"*pt"
files = glob.glob(os.path.join(folder_path, file_type))
if not files:
    # Fail with a clear message instead of max()'s "empty sequence" error.
    raise FileNotFoundError(
        f'No weight files matching {file_type!r} found in {folder_path!r}'
    )
# Pick the file with the largest os.path.getctime value (creation time on
# Windows, last metadata change on Unix).
weight = max(files, key=os.path.getctime)

print(weight)

weights\Encoder_weights_251024111249.pt


In [39]:
from project.model import RPN

# Fitter configuration for the RPN stage; replaces the autoencoder `config`.
config = {
    'model': RPN(
        # NOTE(review): 2500 equals input_output_dim used for the embedder - presumably these must stay in sync
        input_dim=2500,
        output_dim=4,
        # NOTE(review): the transform is built with target_shape (400, 400) but this is 300 - confirm intended
        image_size=300
    ).to(device),
    # The optimizer class itself is stored, not an instance.
    'optimizer': torch.optim.Adam,
    'device': device,
    'epochs': 1,
    'loss': nn.SmoothL1Loss(),
    # 'loss': nn.MSELoss(),
    'lr': 0.00001
}

# Keep a direct handle on the RPN instance held in the config.
model = config['model']

# Load the embedder weights (currently disabled)
# model.embedder.load_state_dict(torch.load(weight))


16




### RPN Fitter 

In [40]:
class RPNFitter(Fitter):
    """Fitter for the RPN model: one optimizer step per batch.

    Every batch is iterated as ``(slices, masks, target, case)`` tuples. The
    model is called as ``model(x, target)`` and its output is compared with
    ``masks[target]``.
    """

    def train_one_epoch(self, train_loader):
        """Train for one epoch.

        Args:
            train_loader: DataLoader whose batches are iterables of
                ``(slices, masks, target, case)`` tuples.

        Returns:
            List with one loss value (NumPy array) per optimised batch.
        """
        self.model.train()
        loss_history = []
        counter = 0
        # Report progress against the split being iterated, not the full dataset.
        total = len(train_loader.dataset)
        for batch in train_loader:
            Y = []
            T = []
            for slices, masks, target, case in batch:
                # Same guard as AEFitter: samples without slices are logged
                # and skipped.
                if slices is None:
                    logger.error(f'CASE NOT WORKING: {case}')
                    continue
                x = slices.squeeze(1).float().to(self.device)
                masks = masks.squeeze(1).float().to(self.device)
                Y.append(self.model(x, target))
                T.append(masks[target])

            counter += len(batch)
            if not Y:
                # Every sample was skipped; torch.stack would fail on an empty list.
                continue

            losses = self.loss(torch.stack(Y), torch.stack(T))
            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            if counter % 100 == 0:
                logger.info(f'Progress:\t{counter}/{total}')
                logger.info(f'Current error:\t{losses}')
            loss_history.append(losses.detach().cpu().numpy())
        return loss_history
    

fitter = RPNFitter(config)

### RPN Training

In [41]:
rpn_hist = fitter.fit(train_loader, val_loader)

KeyboardInterrupt: 

In [40]:
# rpn_train_history, rpn_val_history = rpn_hist

# ViT

In [41]:
ds = Dataset()

# Reload the per-slice target table and keep only the rows whose MRI path is
# in the list returned by ds.load_raw_mri(1). The unfiltered load_raw_mri() /
# load_cmb_masks() results were overwritten before use, so they are not
# loaded here.
data = pd.read_csv('targets.csv')
ch1 = ds.load_raw_mri(1)
data = data[data.mri.isin(ch1)]

cases = data.mri
masks = data.masks
target = data.target

In [None]:
target_shape = (512, 512)
global_min, global_max = compute_statistics(cases)


normalized_transform = NiftiToTensorTransform(
    target_shape = target_shape, 
    normalization=(global_min, global_max)
)

dataset = VALDODataset(
    cases=cases,
    masks=masks,
    target=target,
    transform=normalized_transform,
)

In [None]:
resize = get_transform(
    height=16,
    width=16,
    p=1.0,
    rpn_mode=False
)

feeder = Feeder(resize)

In [None]:
# 80/20 random split. The train size is computed once so the two lengths
# always add up to len(dataset).
train_size = int(len(dataset) * 0.8)
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Both loaders use the configured `collate_fn` (set to collatev2 in the
# hyperparameter cell) rather than a hard-coded function.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
    collate_fn=collate_fn,
)

# Validation only accumulates losses, so there is no benefit in reshuffling
# it every epoch; keep the order fixed.
val_loader = DataLoader(
    val_dataset,
    shuffle=False,
    batch_size=batch_size,
    collate_fn=collate_fn,
)

In [None]:
feedset = VALDODataset(
    cases=cases,
    masks=masks,
    target=target,
    transform=normalized_transform,
    # normalization=z_score_normalization,
)

### ViT Training Proper

In [42]:
config = {
    'model': ISAVIT(
        d_model=1000,
        patch_size=16,
        dim_ff=2000
    ).to(device),
    'optimizer': torch.optim.Adam,
    'device': device,
    'epochs': 10,
    # 'loss': nn.BCEWithLogitsLoss(),
    'loss': nn.CrossEntropyLoss(),
    # 'loss': nn.MSELoss(),
    'lr': 0.0000001
}

### ViT Fitter

In [43]:
class ViTFitter(Fitter):
    """Fitter for the ViT stage: one loss and one optimizer step per batch.

    Every batch is iterated as ``(slices, masks, target, case)`` tuples. The
    notebook-level ``feedset`` and ``feeder`` objects are used to look up the
    regions of a case and to crop both the slices and the masks.
    """

    def _batch_loss(self, batch):
        """Run the model on every sample of ``batch`` and return the batch loss.

        Returns ``None`` when the batch holds no samples, so callers can skip
        it instead of stacking an empty list.
        """
        Y = []
        T = []
        for slices, masks, target, case in batch:
            num_slices = slices.shape[0]

            # Second element of the located case is reshaped to one row per
            # entry; the row at index `target` is used as the bounding box.
            regions = feedset.locate_case_by_mri(case)
            bboxes = regions[1].view(regions[1].shape[0], -1)
            bbox = bboxes[target].int().tolist()

            # 16 matches the patch_size given to ISAVIT and the height/width
            # of the resize transform (presumably the patch size - confirm).
            x = feeder(slices, bbox, 16)
            t = feeder(masks, bbox, 16)

            # Flatten each cropped slice and move it to the training device.
            x = x.view(num_slices, 1, -1).float().to(self.device)
            t = t.view(num_slices, 1, -1).float().to(self.device)

            Y.append(self.model(x, target))
            T.append(t[target])

        if not Y:
            return None
        return self.loss(torch.stack(Y), torch.stack(T))

    def train_one_epoch(self, train_loader):
        """Train for one epoch.

        The loss is computed once per batch, after all of its samples were
        forwarded, and that same value is back-propagated and recorded. With a
        batch size of 1 this equals one entry per sample.

        Args:
            train_loader: DataLoader whose batches are iterables of
                ``(slices, masks, target, case)`` tuples.

        Returns:
            List with one loss value (NumPy array) per optimised batch.
        """
        self.model.train()
        loss_history = []
        counter = 0
        # Report progress against the split being iterated, not the full dataset.
        total = len(train_loader.dataset)
        for batch in train_loader:
            losses = self._batch_loss(batch)
            counter += len(batch)
            if losses is None:
                continue

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            loss_history.append(losses.detach().cpu().numpy())

            if counter % 100 == 0:
                logger.info(f'Progress:\t{counter}/{total}')
                logger.info(f'Current error:\t{losses}')

        return loss_history

    def validation(self, val_loader):
        """Evaluate the per-batch loss without updating the model.

        Args:
            val_loader: DataLoader whose batches are iterables of
                ``(slices, masks, target, case)`` tuples.

        Returns:
            List with one loss value (NumPy array) per evaluated batch.
        """
        self.model.eval()
        loss_history = []
        counter = 0
        total = len(val_loader.dataset)
        with torch.no_grad():
            for batch in val_loader:
                losses = self._batch_loss(batch)
                counter += len(batch)
                if losses is None:
                    continue

                loss_history.append(losses.detach().cpu().numpy())

                if counter % 100 == 0:
                    logger.info(f'Progress:\t{counter}/{total}')
                    logger.info(f'Current error:\t{losses}')
        return loss_history
        

fitter = ViTFitter(config)

### ViT Training

In [None]:
vit_hist = fitter.fit(train_loader, val_loader)

In [48]:
vit_train_hist, vit_val_hist = vit_hist

In [None]:
sns.lineplot(torch.tensor(np.array(vit_train_hist)).mean(1), label='Training')
sns.lineplot(torch.tensor(np.array(vit_val_hist)).mean(1), label='Validation')

plt.title("Training and Validation History")
plt.xlabel("Epochs")
plt.ylabel("Loss")

plt.legend()

plt.show()

In [None]:
model = config['model']

In [None]:
import datetime

s = f'vit_weights/ViT_test12_weights_{datetime.datetime.now().strftime("%d%m%y%H%M%S")}.pt'
s

In [None]:

# Run the trained model on one sample and show its output as a heatmap.
model = fitter.model
# NOTE(review): `dloader` is not defined in any cell visible in this notebook - confirm which loader is meant
sample = next(enumerate(dloader))
# enumerate yields (index, batch), so sample[1] is the batch and [3] its fourth
# sample. NOTE(review): this needs a batch of at least 4 items, while
# batch_size is set to 1 earlier in the notebook - confirm.
slices, masks, target, case = sample[1][3]

num_slices = slices.shape[0]

# Same bounding-box lookup as in ViTFitter: row `target` of the second element
# returned by feedset.locate_case_by_mri.
regions = feedset.locate_case_by_mri(case)
bboxes = regions[1].view(regions[1].shape[0], -1)
bbox = bboxes[target].int().tolist()

# Crop the slices and the masks with the feeder (same arguments as in ViTFitter).
x = feeder(slices, bbox, 16)
t = feeder(masks, bbox, 16)

# Flatten each slice and move it to the selected device.
x = x.view(num_slices, 1, -1).float().to(device)
masks = t.view(num_slices, 1, -1).float().to(device)

# NOTE(review): the model is called without model.eval() / torch.no_grad() here - confirm this is intended
y = model(x, target)
# fitter.loss(y,x)

sns.heatmap(y.detach().cpu())
