<a href="https://colab.research.google.com/github/Oscar-Mo/big_brain_BC/blob/main/training_(2%2B1)D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training the model

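In this notebook we train the UNet model that uses the modified (2+1)D `DoubleConv` class. The training set (`split1/train`) consists of 525 samples, representing 80% of the full dataset; the remaining 20% is held out for testing. Each input volume and its lesion mask are resampled to 128 × 128 × 128 voxels before being fed to the network.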

### Setting up

In [None]:
from google.colab import drive
# Mount Google Drive; the relative paths used below (drive/MyDrive/...) assume
# the kernel's working directory is the parent of this mount point (Colab's
# default /content — assumption, confirm if running elsewhere).
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

In [None]:
!pip install bids
!pip install git+https://github.com/npnl/bidsio
import bidsio
# Entities selecting the network inputs: T1-weighted scans in MNI152NLin2009aSym space.
# NOTE(review): the empty-string values for 'subject'/'session' are passed through
# to bidsio as-is; their exact meaning (wildcard vs. required key) is defined by bidsio.
t1w_entities = {'subject': '',
                'session': '',
                'suffix': 'T1w',
                'space': 'MNI152NLin2009aSym'}

# Entities selecting the targets: lesion masks (suffix 'mask', desc 'T1lesion').
lesion_mask_entities = {'suffix': 'mask',
                        'label': 'L',
                        'desc': 'T1lesion'}

# Loader over the training split; both inputs and targets come from the ATLAS derivatives.
bids_loader = bidsio.BIDSLoader(
    root_dir='drive/MyDrive/big_brain/split1/train/',
    data_derivatives_names=['ATLAS'],
    target_derivatives_names=['ATLAS'],
    data_entities=[t1w_entities],
    target_entities=[lesion_mask_entities],
    batch_size=2,
)

In [None]:
# Load one sample to inspect what a single loader item contains, then report
# the loader's size, per-sample image count, image shape and batch size.
tmp = bids_loader.load_sample(0)
print(f'There are {len(bids_loader)} subjects in our dataset.',
      f'Every sample loads {len(tmp)} images.',
      f'Images have these dimensions: {bids_loader.data_shape}',
      f'Every batch will load {bids_loader.batch_size} samples.',
      sep='\n')

There are 525 subjects in our dataset.
Every sample loads 2 images.
Images have these dimensions: (197, 233, 189)
Every batch will load 2 samples.


In [None]:
from torch import nn

In [None]:
import torch
from UNet_model_2_plus_1D import UNet

# Fix torch's RNG (CPU and CUDA) before building the model so that weight
# initialisation, and therefore the whole training run, is reproducible.
SEED = 0
torch.manual_seed(SEED)

# Train on the first GPU when available, otherwise on CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Single-channel input (T1w volume), single-channel output (lesion mask logits).
model = UNet(n_channels=1, n_classes=1)
model.to(device)

LEARNING_RATE = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

In [None]:
class DiceBCELoss(nn.Module):
    """Binary cross-entropy plus soft Dice loss, computed from raw logits.

    ``forward`` expects pre-sigmoid model outputs. The Dice term uses
    ``sigmoid(inputs)``; the BCE term uses the fused
    ``binary_cross_entropy_with_logits``. That avoids the -100 log clamp and
    vanishing gradients that ``sigmoid`` followed by ``binary_cross_entropy``
    suffers from on saturated logits.

    Args:
        weight: Unused; accepted for interface compatibility.
        size_average: Unused; accepted for interface compatibility.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``BCE + (1 - Dice)`` as a scalar tensor.

        Args:
            inputs: Raw (pre-sigmoid) model outputs, any shape.
            targets: Tensor with the same number of elements as ``inputs``,
                values in [0, 1].
            smooth: Additive smoothing term in the Dice numerator and denominator.

        Returns:
            0-dim tensor: mean BCE over all elements plus the soft Dice loss.
        """
        # reshape (not view) so non-contiguous tensors, e.g. permuted outputs, also work.
        logits = inputs.reshape(-1)
        targets = targets.reshape(-1)
        probs = logits.sigmoid()

        intersection = (probs * targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
        # `nn.functional` is used rather than the module-level `F`, which is only
        # imported in a later notebook cell.
        bce = nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction='mean')
        return bce + dice_loss

loss_func = DiceBCELoss()

In [None]:
import torchvision.transforms as T
import torch.nn.functional as F

def downsize_64(batch):
  """Subsample a batch of 3D volumes by 3 per axis and crop the first two spatial axes to 64.

  Keeps every 3rd voxel along D, H and W, crops D and H to at most 64, then
  prepends one zero slice along W. For spatial shape (197, 233, 189) the result
  is 64 x 64 x 64 (W: 189 -> 63 -> 64).

  Args:
    batch: 5D tensor (N, C, D, H, W), or 4D (C, D, H, W) treated as one sample.

  Returns:
    5D tensor (N, C, min(64, ceil(D/3)), min(64, ceil(H/3)), ceil(W/3) + 1)
    with the same device and dtype as ``batch``.
  """
  if batch.dim() == 4:
    batch = batch.unsqueeze(0)
  scale = 3
  batch = batch[:, :, ::scale, ::scale, ::scale]
  batch = batch[:, :, :64, :64]
  # Prepend a single zero slice on the last axis. F.pad allocates on the batch's
  # device/dtype and matches its channel and spatial sizes, unlike a hard-coded
  # CPU torch.zeros(N, 1, 64, 64, 1).
  return F.pad(batch, (1, 0))

def downsize_128(batch, mode='bilinear'):
  """Resample a batch of 3D volumes to 128 x 128 x 128 with ``F.grid_sample``.

  The sampling grid spans [-1, 1] on every axis and uses ``align_corners=True``.
  The input's corner voxels therefore land on the output's corner voxels, and
  each axis is rescaled independently (aspect ratio is not preserved).

  Args:
    batch: 5D floating tensor (N, C, D, H, W), or 4D (C, D, H, W) treated as
      one sample.
    mode: Interpolation mode forwarded to ``F.grid_sample``. The default
      'bilinear' performs trilinear interpolation on 5D input; 'nearest' keeps
      binary masks binary.

  Returns:
    Tensor of shape (N, C, 128, 128, 128) on ``batch``'s device and dtype.
  """
  if batch.dim() == 4:
    batch = batch.unsqueeze(0)
  num_samples = batch.shape[0]
  size = 128
  # Build the grid on the batch's device/dtype; grid_sample needs them to match.
  axis = torch.linspace(-1, 1, size, device=batch.device, dtype=batch.dtype)
  # grid[..., 0] indexes W (x), grid[..., 1] indexes H (y), grid[..., 2] indexes D (z).
  # Output voxel (d, h, w) therefore samples input coordinate (x_w, y_h, z_d).
  z = axis.view(size, 1, 1).expand(size, size, size)
  y = axis.view(1, size, 1).expand(size, size, size)
  x = axis.view(1, 1, size).expand(size, size, size)
  grid = torch.stack((x, y, z), dim=-1)
  # expand shares one grid across the batch instead of copying it N times.
  grid = grid.unsqueeze(0).expand(num_samples, -1, -1, -1, -1)
  return F.grid_sample(batch, grid, mode=mode, align_corners=True)


In [None]:
# Training loop. Every LOG_EVERY batches (counted globally across epochs by `i`),
# print the mean loss of the batches accumulated since the last report, append
# it to loss.txt and checkpoint the current weights.
import os

N_EPOCHS = 7
LOG_EVERY = 10
# NOTE(review): `model_name` was never defined in this notebook (it must have come
# from a deleted cell, so a fresh-kernel run raised NameError at the first save).
# 'split1' matches the directory the loss log was already written to — confirm.
model_name = 'split1'
model_dir = os.path.join('drive/MyDrive/big_brain/models', model_name)
os.makedirs(model_dir, exist_ok=True)
loss_path = os.path.join(model_dir, 'loss.txt')
weights_path = os.path.join(model_dir, 'model_weights_' + model_name + '.pth')

model.train()
i = 0
for epoch in range(N_EPOCHS):
    running_loss = 0.0
    window_batches = 0
    for data, label in bids_loader.load_batches():

        # Resample to 128^3 on CPU, then move to the training device.
        data = downsize_128(torch.Tensor(data)).to(device)
        label = downsize_128(torch.Tensor(label)).to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(data)
        loss = loss_func(outputs, label)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        window_batches += 1
        print('-', end ="")
        if (i % LOG_EVERY == LOG_EVERY - 1):
          # Divide by the batches actually accumulated. `running_loss` is reset every
          # epoch while `i` keeps counting, so the first window of each later epoch
          # holds fewer than LOG_EVERY batches.
          loss_str = f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / window_batches :.8f}'
          print(loss_str)
          running_loss = 0.0
          window_batches = 0

          with open(loss_path, "a") as f:
            f.write(loss_str + '\n')

          torch.save(model.state_dict(), weights_path)
        i += 1




----------[1,    10] loss: 1.48061694
----------[1,    20] loss: 1.40900997
----------[1,    30] loss: 1.34488068
----------[1,    40] loss: 1.29072158
----------[1,    50] loss: 1.26043851
----------[1,    60] loss: 1.23049161
----------[1,    70] loss: 1.20349872
----------[1,    80] loss: 1.18199406
----------[1,    90] loss: 1.16449884
----------[1,   100] loss: 1.15069263
----------[1,   110] loss: 1.13675107
----------[1,   120] loss: 1.12389805
----------[1,   130] loss: 1.10656959
----------[1,   140] loss: 1.10770390
----------[1,   150] loss: 1.10486622
----------[1,   160] loss: 1.09663471
----------[1,   170] loss: 1.08299068
----------[1,   180] loss: 1.08667674
----------[1,   190] loss: 1.07984316
----------[1,   200] loss: 1.07051530
----------[1,   210] loss: 1.06200569
----------[1,   220] loss: 1.04223539
----------[1,   230] loss: 1.04669147
----------[1,   240] loss: 1.03013310
----------[1,   250] loss: 1.04143755
----------[1,   260] loss: 1.05720116
----------[2