In [1]:
import numpy as np
import pandas as pd
from glob import glob
import matplotlib.pyplot as plt

from sklearn.metrics import classification_report

import nibabel as nib

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import MultiplicativeLR

from torchsummary import summary

%matplotlib inline

They have only one size.

In [2]:
# set(nib.load(b).get_fdata().shape for b in glob('blurred/*.BRIK'))

## Model

In [3]:
from model import *

In [4]:
model = DeepBrain(inplanes=179, planes=3)

  nn.init.xavier_uniform(m.weight, gain=nn.init.calculate_gain('relu'))


In [5]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
print(summary(model, input_size=(179, 72, 72, 36)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1        [-1, 3, 72, 72, 36]             540
       BatchNorm3d-2        [-1, 3, 72, 72, 36]               6
              ReLU-3        [-1, 3, 72, 72, 36]               0
            Conv3d-4       [-1, 16, 36, 36, 18]           1,312
       BatchNorm3d-5       [-1, 16, 36, 36, 18]              32
              ReLU-6       [-1, 16, 36, 36, 18]               0
            Conv3d-7       [-1, 32, 36, 36, 18]             544
       BatchNorm3d-8       [-1, 32, 36, 36, 18]              64
            Conv3d-9       [-1, 32, 36, 36, 18]          13,856
      BatchNorm3d-10       [-1, 32, 36, 36, 18]              64
             ReLU-11       [-1, 32, 36, 36, 18]               0
           Conv3d-12       [-1, 32, 36, 36, 18]          27,680
      BatchNorm3d-13       [-1, 32, 36, 36, 18]              64
             ReLU-14       [-1, 32, 36,

## DataLoader

In [6]:
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    
    def __init__(self, datatype='blur', test=False):
        
        if datatype == 'blur':
            self.datafiles = glob('blurred/*.BRIK')
            
        elif datatype == 'scale':
            self.datafiles = glob('scaled/*.BRIK')
            
        else:
            pass
    
        if test: self.datafiles[-5:]
        else: self.datafiles[:-5]
        self.demographics = pd.read_csv('participants_with_runs.csv')
        
    def __getitem__(self, idx):
        
        x = torch.tensor(nib.load(self.datafiles[idx]).get_fdata()).float()
        y = torch.tensor(self.demographics.loc[idx].YB).float()
        
        return x, y
    
    def __len__(self):
        return len(self.datafiles)

In [7]:
dset = MyDataset()
dloader = DataLoader(dset)

## Run

In [8]:
# loss_fn = nn.BCELoss(weight=torch.Tensor([1/44, 1/25]))
loss_fn = nn.BCELoss()
optimizer = optim.Adam(model.parameters())
scheduler = MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.97)

In [9]:
losses, ans, accs = [], [], []
for e in range(100):
    
    print(f"Epoch {e}")
    epoch_pred, epoch_true = [], []
    bth_loss = 0
    for i, (x, y) in enumerate(dloader):
        
        x, y = x.permute(0, 4, 1, 2, 3).to(device), y.to(device)
        epoch_true.append(round(float(y)))
        
        optimizer.zero_grad()
        
        y_pred = model.forward(x).to(device)
        epoch_pred.append(float(y_pred.squeeze(1)))
        
        loss = loss_fn(y_pred.reshape(1), y)
        bth_loss += loss.item()
        
        loss.backward()
        optimizer.step()
#         scheduler.step()
        
    torch.cuda.empty_cache()
    
    losses.append(bth_loss / len(dloader))
    ans.append((epoch_true, epoch_pred))
    
#     if e % 10 == 0:
#         plt.title(f'Loss of Epoch {e+1}')
#         plt.plot(losses)
#         plt.grid()
#         plt.close()
        
#         plt.title(f'Accuracy of Epoch {e+1}')
#         plt.plot(accs)
#         plt.grid()
#         plt.close()
    
    print(classification_report(epoch_true, np.array(epoch_pred) > 0.5))
    accs.append(classification_report(epoch_true, np.array(epoch_pred) > 0.5).split('\n')[5].split()[1])

Epoch 0
              precision    recall  f1-score   support

           0       0.66      0.91      0.76        44
           1       0.50      0.16      0.24        25

    accuracy                           0.64        69
   macro avg       0.58      0.53      0.50        69
weighted avg       0.60      0.64      0.57        69

Epoch 1
              precision    recall  f1-score   support

           0       0.63      0.84      0.72        44
           1       0.30      0.12      0.17        25

    accuracy                           0.58        69
   macro avg       0.46      0.48      0.44        69
weighted avg       0.51      0.58      0.52        69

Epoch 2
              precision    recall  f1-score   support

           0       0.63      0.70      0.67        44
           1       0.35      0.28      0.31        25

    accuracy                           0.55        69
   macro avg       0.49      0.49      0.49        69
weighted avg       0.53      0.55      0.54       

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.64      1.00      0.78        44
           1       0.00      0.00      0.00        25

    accuracy                           0.64        69
   macro avg       0.32      0.50      0.39        69
weighted avg       0.41      0.64      0.50        69

Epoch 18
              precision    recall  f1-score   support

           0       0.64      1.00      0.78        44
           1       0.00      0.00      0.00        25

    accuracy                           0.64        69
   macro avg       0.32      0.50      0.39        69
weighted avg       0.41      0.64      0.50        69

Epoch 19
              precision    recall  f1-score   support

           0       0.64      1.00      0.78        44
           1       0.00      0.00      0.00        25

    accuracy                           0.64        69
   macro avg       0.32      0.50      0.39        69
weighted avg       0.41      0.64      0.50        69

E

              precision    recall  f1-score   support

           0       0.67      0.45      0.54        44
           1       0.38      0.60      0.47        25

    accuracy                           0.51        69
   macro avg       0.53      0.53      0.50        69
weighted avg       0.56      0.51      0.51        69

Epoch 43
              precision    recall  f1-score   support

           0       0.71      0.61      0.66        44
           1       0.45      0.56      0.50        25

    accuracy                           0.59        69
   macro avg       0.58      0.59      0.58        69
weighted avg       0.62      0.59      0.60        69

Epoch 44
              precision    recall  f1-score   support

           0       0.77      0.61      0.68        44
           1       0.50      0.68      0.58        25

    accuracy                           0.64        69
   macro avg       0.64      0.65      0.63        69
weighted avg       0.67      0.64      0.64        69

E

              precision    recall  f1-score   support

           0       0.57      0.52      0.55        44
           1       0.28      0.32      0.30        25

    accuracy                           0.45        69
   macro avg       0.43      0.42      0.42        69
weighted avg       0.47      0.45      0.46        69

Epoch 68
              precision    recall  f1-score   support

           0       0.67      0.55      0.60        44
           1       0.39      0.52      0.45        25

    accuracy                           0.54        69
   macro avg       0.53      0.53      0.52        69
weighted avg       0.57      0.54      0.55        69

Epoch 69
              precision    recall  f1-score   support

           0       0.58      0.48      0.53        44
           1       0.30      0.40      0.34        25

    accuracy                           0.45        69
   macro avg       0.44      0.44      0.43        69
weighted avg       0.48      0.45      0.46        69

E

              precision    recall  f1-score   support

           0       0.67      0.23      0.34        44
           1       0.37      0.80      0.51        25

    accuracy                           0.43        69
   macro avg       0.52      0.51      0.42        69
weighted avg       0.56      0.43      0.40        69

Epoch 93
              precision    recall  f1-score   support

           0       0.50      0.20      0.29        44
           1       0.31      0.64      0.42        25

    accuracy                           0.36        69
   macro avg       0.41      0.42      0.36        69
weighted avg       0.43      0.36      0.34        69

Epoch 94
              precision    recall  f1-score   support

           0       0.50      0.05      0.08        44
           1       0.35      0.92      0.51        25

    accuracy                           0.36        69
   macro avg       0.43      0.48      0.30        69
weighted avg       0.45      0.36      0.24        69

E

## Guided Backpropagation

In [8]:
# torch.save(model, 'model.pth')
model= torch.load('model.pth')



In [9]:
import nibabel as nib
from guided_backprop import GuidedBackprop

_ = model.eval()

# Remove LogSoftmax
# model.classifier = nn.Sequential(*list(model.classifier.children())[:-1])

GBP = GuidedBackprop(model)

In [14]:
x, y = dset[1]

In [20]:
x.permute(3, 0, 1, 2).shape

torch.Size([179, 72, 72, 36])

In [21]:
inputs, label = x.permute(3, 0, 1, 2), torch.BoolTensor(0)
input_img = nn.Parameter(torch.FloatTensor(inputs).unsqueeze(0), requires_grad=True).to(device)

guided_grads = GBP.generate_gradients(input_img, label)

export_gradient = np.zeros((3,75,93,81))
# export_gradient[:, 8:-8, 8:-8, :-10] = guided_grads
# nifti_img = nib.Nifti1Image(export_gradient.transpose(1, 2, 3, 0), np.eye(4))
nifti_img = nib.Nifti1Image(guided_grads.transpose(1, 2, 3, 0), np.eye(4))
nifti_img.to_filename('vis_tmp.nii.gz')

tensor([[0.]], device='cuda:0')


In [24]:
nib.load('vis_tmp.nii.gz').get_data()[[nib.load('vis_tmp.nii.gz').get_fdata() != 0]]


* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  """Entry point for launching an IPython kernel.
  """Entry point for launching an IPython kernel.


array([], dtype=float32)