In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from project.dataset import Dataset, VALDODataset
from torch.utils.data import DataLoader
from project.preprocessing import NiftiToTensorTransform, z_score_normalization, NEWNiftiToTensorTransform
from project.utils import collate_fn, plot_mri_slice, plot_all_slices, plot_all_slices_from_array, collatev2
import winsound
from torchvision.models import resnet18, ResNet18_Weights
from project.utils import memcheck, compute_statistics
from project.evaluation import isa_rpn_metric, Tracker, isa_vit_metric
from project import PatchTruther, AnchorFeeder

In [None]:
# Run tracker: later cells attach this run's settings and results to it as
# attributes, and the final cells call it as t() to export them.
t = Tracker()

In [None]:
import logging
from datetime import datetime as dtt
import os

# Output folders: one for the per-run log files, one for saved histories and
# the table of runs.
path = 'logs'
os.makedirs(path, exist_ok=True)
os.makedirs('history', exist_ok=True)

# Timestamp that names the artifacts of this run (log file, histories, weights).
rn = dtt.now()
dte = rn.strftime('%b_%d_%Y_%H%M%S')

logger = logging.getLogger('andy')

# logging.getLogger returns the same object for the same name, so re-running
# this cell would otherwise stack a new FileHandler on top of the previous
# ones and every message would also be appended to the older log files.
for stale_handler in list(logger.handlers):
    if isinstance(stale_handler, logging.FileHandler):
        logger.removeHandler(stale_handler)
        stale_handler.close()

# Build the file name from `path` so the folder is defined in one place only.
fh = logging.FileHandler(os.path.join(path, f'{dte}.log'))
formatter = logging.Formatter(
    '%(asctime)s - %(levelname)s - %(message)s'
)

logger.setLevel(logging.DEBUG)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)

logger.addHandler(fh)

# Record when the run started and which log file belongs to it.
t.date = rn
t.logfile = f'{dte}.log'

dte

In [None]:
# Use the GPU when torch reports one, otherwise fall back to the CPU; the choice
# is also recorded on the tracker.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
t.device = device
device

### Config for fitter

In [None]:
from project.model import RPN

# Settings handed to RPNFitter(config, logger=logger) further down.
config = {
    # The model is instantiated here and moved to `device` immediately.
    # image_size / sqrt(output_dim) (300 / 5 = 60 with these values) is the
    # patch size used later by make_loaders and AnchorFeeder.
    'model': RPN(
        input_dim=512,
        output_dim=25,
        image_size=300,
        global_context=True,
        nh=4,
        # pretrained=True
    ).to(device),
    # Stored as the optimizer class itself; it is not instantiated in this cell.
    'optimizer': torch.optim.Adam,
    'device': device,
    'epochs': 2,
    # Operates on raw logits; the fitter applies sigmoid only when computing metrics.
    'loss': nn.BCEWithLogitsLoss(),
    # 'loss': nn.SmoothL1Loss(),
    # 'loss': nn.MSELoss(),
    # 'loss': nn.L1Loss(),
    'lr': 0.0001
}

# Mirror the configuration on the tracker so it ends up in the exported run record.
t.model = 'RPN'
t.model_hyperparams = config['model'].config
# NOTE(review): reads the 'pretrained' key although the argument is commented
# out above — presumably RPN fills in a default; confirm.
t.uses_resnet = config['model'].config['pretrained']
# Optimizer and loss are stored as their string representations.
t.optimizer = f"{config['optimizer']}"
t.epochs = config['epochs']
t.loss = f"{config['loss']}"
t.lr = config['lr']

#### Load Pretrained Embedder

#### Load RPN Weights

### Load dataset

In [None]:
# NOTE(review): `ds` is not referenced again in the visible cells — confirm it
# is still needed.
ds = Dataset()

# Sample table; later cells read its cohort, mri, masks, target and
# has_microbleed_slice columns.
data = pd.read_csv('targets.csv')
data.shape

In [None]:
# Keep only the rows flagged with has_microbleed_slice == 1 and renumber the index.
# Note: this overwrites `data`, so the unfiltered table is only recoverable by
# re-running the read_csv cell above.
data = data.query('has_microbleed_slice == 1').reset_index(drop=True)
t.only_cmb_slices = True
data

### `DataLoader` Generator

In [None]:
def iqr(data, col):
    """Return the rows of ``data`` whose ``col`` value lies within the Tukey fences.

    The fences are ``Q1 - 1.5 * IQR`` and ``Q3 + 1.5 * IQR``, where ``Q1`` and
    ``Q3`` are the 25th and 75th percentiles of ``data[col]``. The fences are
    inclusive: a value is an outlier only when it lies strictly outside them.
    With exclusive fences a column whose IQR is zero (for example a constant
    column) would lose every row.

    Parameters
    ----------
    data : pandas.DataFrame
        Table to filter; it is not modified.
    col : label
        Name of the numeric column used for the outlier test.

    Returns
    -------
    pandas.DataFrame
        The rows inside the fences, with their original index. Rows whose
        ``col`` value is NaN are dropped because they compare as outside.
    """
    q1 = data[col].quantile(0.25)
    q3 = data[col].quantile(0.75)
    spread = q3 - q1
    lower_fence = q1 - 1.5 * spread
    upper_fence = q3 + 1.5 * spread
    # Series.between includes both end points by default.
    return data[data[col].between(lower_fence, upper_fence)]

In [None]:
from sklearn.model_selection import train_test_split

def make_loaders(data,
                 cohort,
                 batch_size,
                 test_size=0.2,
                 random_state=12,
                 target_shape=(300, 300),
                 rpn_mode=True,
                 logger=None,
                 tracker=t
                ):
    """Build the reference dataset and the train/validation loaders for one cohort.

    Parameters
    ----------
    data : pandas.DataFrame
        Sample table; the ``cohort``, ``mri``, ``masks`` and ``target`` columns
        are read.
    cohort : int
        Cohort to keep (rows where ``data.cohort == cohort``). For 1, 2 or 3
        the matching ``cohort1`` / ``cohort2`` / ``cohort3`` flag is set on
        ``tracker``.
    batch_size : int
        Batch size of both loaders.
    test_size : float
        Forwarded to ``train_test_split``.
    random_state : int
        Seed forwarded to ``train_test_split``.
    target_shape : tuple
        Forwarded to both transforms; its first entry is also used to derive
        the patch size.
    rpn_mode : bool
        Forwarded to the train/validation transform (the reference transform
        always uses ``rpn_mode=False``).
    logger : logging.Logger or None
        Receives the summary text; when ``None`` the summary is printed.
    tracker : Tracker
        Object that receives the run settings and supplies
        ``model_hyperparams['output_dim']``. All tracker access goes through
        this argument, so passing a tracker other than the default is honoured.

    Returns
    -------
    tuple
        ``(reference_set, train_loader, val_loader)``.
    """
    # Record the loader settings on the tracker that was passed in.
    if cohort == 1:
        tracker.cohort1 = True
    if cohort == 2:
        tracker.cohort2 = True
    if cohort == 3:
        tracker.cohort3 = True
    tracker.batch_size = batch_size
    tracker.test_size = test_size
    tracker.target_shape = target_shape

    data = data[data.cohort == cohort]
    # data = iqr(data, 'max_value')

    s = f'Creating loaders for Cohort {cohort}\n'

    data_train, data_test = train_test_split(
        data,
        test_size=test_size,
        random_state=random_state
    )

    s += f'TRAIN & TEST: {data_train.shape, data_test.shape}\n'

    # Normalisation statistics come from the training MRIs only, so nothing
    # from the validation split leaks into them.
    paths = data_train.mri.unique().tolist()
    s += f'Total Unique MRI Samples in data_train: {len(paths)}\n'

    global_min, global_max = compute_statistics(paths)
    s += f'GLOBAL MIN & MAX {global_min, global_max}\n'

    # Patch side = image side / sqrt(output_dim)
    # (assumes output_dim counts the cells of a square grid — TODO confirm).
    transform = NEWNiftiToTensorTransform(
        target_shape=target_shape,
        rpn_mode=rpn_mode,
        normalization=(global_min, global_max),
        patch_size=target_shape[0]/(tracker.model_hyperparams['output_dim']**.5)
    )

    # Transform for the reference set: same normalisation, rpn_mode disabled.
    trans = NiftiToTensorTransform(
        target_shape=target_shape,
        rpn_mode=False,
        normalization=(global_min, global_max),
    )

    # The reference set covers the whole cohort (train and validation rows).
    reference_set = VALDODataset(
        cases=data.mri.tolist(),
        masks=data.masks.tolist(),
        target=data.target.tolist(),
        transform=trans
    )

    train_set = VALDODataset(
        cases=data_train.mri.tolist(),
        masks=data_train.masks.tolist(),
        target=data_train.target.tolist(),
        transform=transform
    )
    val_set = VALDODataset(
        cases=data_test.mri.tolist(),
        masks=data_test.masks.tolist(),
        target=data_test.target.tolist(),
        transform=transform
    )

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=batch_size,
        collate_fn=collatev2
    )
    # NOTE(review): the validation loader is shuffled as well, so the per-batch
    # validation losses are not comparable between runs — confirm this is intended.
    val_loader = DataLoader(
        val_set,
        shuffle=True,
        batch_size=batch_size,
        collate_fn=collatev2
    )

    if logger is not None:
        logger.info(s)
    else:
        print(s)

    return reference_set, train_loader, val_loader

### Fitter

In [None]:
from project import Fitter

class RPNFitter(Fitter):
    """Training and validation loops for the RPN model.

    Relies on ``self.model``, ``self.optimizer``, ``self.loss``, ``self.device``
    and ``self.log``, which are used here but defined outside this class
    (presumably by ``Fitter`` — confirm).

    Every batch is iterated as ``(slices, masks, target, case)`` tuples. For
    each tuple the model is called as ``model(x, target)`` and its output is
    compared with ``masks[target]``. Metrics are computed per sample; the loss
    is computed once per batch on the stacked outputs and targets.
    """

    # (log label, key in the metric dict), in the order in which the values
    # returned by isa_vit_metric are unpacked.
    _RPN_METRICS = (
        ('Avg Dice', 'dice_score'),
        ('Avg Precision', 'precision_score'),
        ('Avg Recall', 'recall_score'),
        ('Avg F1', 'f1_score'),
        ('Avg FPR', 'fpr'),
    )

    def train_one_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Returns
        -------
        tuple
            ``(loss_history, evaluation_metric)``: one loss value (numpy) per
            batch, and a dict mapping each metric name to its per-sample values.
        """
        self.model.train()
        loss_history = []
        evaluation_metric = {key: [] for _, key in self._RPN_METRICS}
        for counter, batch in enumerate(train_loader, start=1):
            # self.log('----------------- BATCH -----------------')
            losses = self._rpn_batch_loss(batch, evaluation_metric)
            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            self.log(f'Batch:\t{counter}/{len(train_loader)}')
            self.log(f'Batch samples:\t{len(batch)}')
            self.log(f'Current error:\t{losses}\n')

            loss_history.append(losses.detach().cpu().numpy())

        self._rpn_log_averages(f'\nTraining Evaluation Metric:', evaluation_metric)
        return loss_history, evaluation_metric

    def validation(self, val_loader):
        """Evaluate the model on ``val_loader`` without tracking gradients.

        Returns
        -------
        tuple
            ``(loss_history, evaluation_metric)`` with the same layout as
            ``train_one_epoch``.
        """
        self.model.eval()
        loss_history = []
        evaluation_metric = {key: [] for _, key in self._RPN_METRICS}
        with torch.inference_mode():
            for batch in val_loader:
                losses = self._rpn_batch_loss(batch, evaluation_metric)
                loss_history.append(losses.cpu().numpy())
        self._rpn_log_averages(f'\nValidations Evaluation Metric:', evaluation_metric)
        return loss_history, evaluation_metric

    def _rpn_batch_loss(self, batch, evaluation_metric):
        """Forward every sample of ``batch``, record its metrics, return the batch loss."""
        outputs = []
        truths = []
        for slices, masks, target, case in batch:
            # x = slices.squeeze(1).repeat(1, 3, 1, 1).float().to(self.device)
            x = slices.squeeze(1).float().to(self.device)
            masks = masks.squeeze(1).float().to(self.device)
            y = self.model(x, target)
            truth = masks[target].unsqueeze(0)

            # Sigmoid is evaluated once; a prediction counts as positive when it
            # is at or above the median probability of this output.
            probabilities = y.sigmoid().numpy(force=True)
            predicted = probabilities >= np.median(probabilities)
            # The ground truth is binarised the same way for training and
            # validation so both sets of metrics are directly comparable.
            expected = truth.numpy(force=True) > 0

            scores = isa_vit_metric(predicted, expected)
            for (_, key), value in zip(self._RPN_METRICS, scores):
                evaluation_metric[key].append(value)

            outputs.append(y)
            truths.append(truth)
        # The loss receives the raw model outputs, not the probabilities.
        return self.loss(torch.stack(outputs), torch.stack(truths))

    def _rpn_log_averages(self, title, evaluation_metric):
        """Log ``title`` followed by the mean of every metric (NaN when none was recorded)."""
        self.log(title)
        last = len(self._RPN_METRICS) - 1
        for position, (label, key) in enumerate(self._RPN_METRICS):
            values = evaluation_metric[key]
            # An empty loader leaves the lists empty; report NaN instead of
            # failing with ZeroDivisionError.
            average = sum(values) / len(values) if values else float('nan')
            suffix = '\n' if position == last else ''
            self.log(f'{label}: {average}{suffix}')
                

In [None]:
# Fitter built from the config dict above; its messages go to the file logger.
fitter = RPNFitter(config, logger=logger)

### Training

In [None]:
# Reference dataset plus train/validation loaders for cohort 1. No logger is
# passed, so the loader summary is printed instead of written to the log file.
refset, tl, vl = make_loaders(
    data=data,
    cohort=1,
    batch_size=15
)

In [None]:
# Train the model. thist / vhist are plotted below as training / validation loss
# histories; tmhist / vmhist are presumably the matching metric histories —
# confirm against Fitter.fit.
thist, vhist, tmhist, vmhist = fitter.fit(tl, vl)

In [None]:
# Three short beeps (500 Hz, 500 ms each) to signal that training has finished.
for _beep in range(3):
    winsound.Beep(500, 500)

In [None]:
import seaborn as sns
import numpy as np

# Convert the loss histories to tensors and plot the mean over dimension 1.
# Assumes thist / vhist are rectangular (one row per epoch, one column per
# batch) so that np.array yields a 2-D array — TODO confirm.
th = torch.tensor(np.array(thist))
vh = torch.tensor(np.array(vhist))
# print(th.shape)
sns.lineplot(th.mean(1), label='Training history')
sns.lineplot(vh.mean(1), label='Validation history')

In [None]:
# Save both loss-history tensors under history/, named with this run's
# timestamp, and record the file paths on the tracker.
sth = f'history/{dte}_thist.pt'
svh = f'history/{dte}_vhist.pt'
t.saved_thist = sth
t.saved_vhist = svh
torch.save(th, sth)
torch.save(vh, svh)

### Save the weights

In [None]:
# File name for the trained weights, tagged with this run's timestamp. The file
# is written to the current working directory by the save cell below.
s = f'RPN_test15a_weights_{dte}.pt'
s

In [None]:
# The model instance created in the config cell.
model = config['model']

In [None]:
# Record the weights file name on the tracker and write the state dict to it.
t.saved_weights = s
torch.save(model.state_dict(), s)

### Evaluation

In [None]:
# One more validation pass: h is the per-batch loss history, mh the dict of
# per-sample metric lists.
h, mh = fitter.validation(vl)

In [None]:
# One column per metric, one row per validation sample; `mets` holds the
# column means.
valmets = pd.DataFrame(mh)
mets = valmets.mean()

In [None]:
# Copy the averaged validation metrics onto the tracker
# (tracker attribute <- metric column).
for _tracker_attr, _metric_column in (
    ('dice', 'dice_score'),
    ('precision', 'precision_score'),
    ('recall', 'recall_score'),
    ('f1', 'f1_score'),
    ('fpr', 'fpr'),
):
    setattr(t, _tracker_attr, mets[_metric_column])

# Trial

In [None]:
# Use the model held by the fitter for the manual checks below.
model = fitter.model

In [None]:
# enumerate yields (index, batch), so sample[1] is the first batch drawn from
# the (shuffled) training loader.
sample = next(enumerate(tl))

In [None]:
# First (slices, masks, target, case) tuple of that batch.
slices, masks, target, case = sample[1][0]

In [None]:
# Same preprocessing as in RPNFitter: drop dimension 1, cast to float, move to
# the device. T is the preprocessed mask tensor.
# x = slices.squeeze(1).repeat(1, 3, 1, 1).float().to(device)
x = slices.squeeze(1).float().to(device)
T = masks.squeeze(1).float().to(device)

In [None]:
# Raw model output (logits) for the sample.
# NOTE(review): gradients are tracked here and the model's train/eval mode is
# whatever the last fitter call left it in.
y = model(x, target)

In [None]:
# Display the raw logits.
y

In [None]:
# Entries whose probability exceeds a fixed 0.6 threshold (the fitter uses the
# median of the output instead).
y.sigmoid() > 0.6

In [None]:
# Display the probabilities.
y.sigmoid()

In [None]:
# Loss of this single sample against the preprocessed mask selected by `target`.
fitter.loss(y, T[target].unsqueeze(0))

In [None]:
# Mask entry selected by `target`. This indexes the raw `masks` tensor (before
# squeeze(1)), not the preprocessed `T`.
masks[target]

In [None]:
# Metrics for this single sample.
# NOTE(review): differs from RPNFitter in two ways. The threshold is strict
# (`>` median, the fitter uses `>=`). The target is the raw `masks[target]`
# without squeeze(1) / unsqueeze(0). Confirm the shapes line up.
isa_vit_metric((y.sigmoid() > y.sigmoid().median()).numpy(force=True), masks[target].numpy())

In [None]:
# AnchorFeeder built with image_size / sqrt(output_dim), the same patch-size
# expression that make_loaders passes to the transform.
af = AnchorFeeder(t.model_hyperparams['image_size']/(t.model_hyperparams['output_dim']**.5))

In [None]:
# Index (over the flattened output) of the highest probability, as a Python int.
ts = y.sigmoid().argmax().tolist()
ts

In [None]:
# Apply the AnchorFeeder to the selected slice at index `ts`
# (presumably this crops the corresponding patch — confirm against AnchorFeeder).
mris = af(x[target].unsqueeze(0), ts)

# Element [1] of the reference-set lookup is treated as the annotation tensor
# for this case (presumably its masks — confirm against VALDODataset); the same
# index is then applied to it.
anns = refset.locate_case_by_mri(case)[1].float()
patches = af(anns[target].unsqueeze(0), ts)

In [None]:
import matplotlib.pyplot as plt

# Side-by-side heatmaps: AnchorFeeder output for the MRI slice (left) and for
# the annotation (right).
f, a = plt.subplots(1, 2, figsize=(10, 4))

f.tight_layout()

# `mris` was derived from the device tensor, so it is converted with force=True.
sns.heatmap(mris.numpy(force=True), ax=a.flat[0])

sns.heatmap(patches, ax=a.flat[1])

In [None]:
# Audible "notebook finished" alarm (500 Hz, five one-second beeps).
# Bounded on purpose: an endless `while True` loop never returns, so under
# Restart & Run All the "Log Progress" cells below would never execute and the
# run would not be recorded unless the kernel was interrupted by hand.
for _alarm in range(5):
    winsound.Beep(500, 1000)

# Log Progress

In [None]:
# Free-text notes stored with this run's record.
t.notes = '''
no important changes
'''

In [None]:
# Display the tracker's exported record; the next cell writes the same t()
# result to history/runs.csv.
t()

In [None]:
# Append this run to history/runs.csv, creating the file on the first run.
if os.path.exists('history/runs.csv'):
    print('Merging to old df')
    # Assumes the stored table has a 'date' column matching the index of t() —
    # TODO confirm against Tracker.
    prev_df = pd.read_csv('history/runs.csv', index_col='date')
    merged = pd.concat([prev_df, t()])
    merged.to_csv('history/runs.csv')
else:
    print('Making new csv file')
    t().to_csv('history/runs.csv')