## self-attention-cv: illustration of a training process with sub-volume sampling for 3D segmentation

The iSeg-2019 dataset can be found here: https://iseg2019.web.unc.edu/. I uploaded the training archive (`iSeg-2019-Training.zip`) to my Google Drive and mount it from there.

In [None]:
from google.colab import drive
drive.mount('/gdrive')
import zipfile
root_path = '/gdrive/My Drive/DATASETS/iSeg-2019-Training.zip' 
!echo "Download and extracting folders..."
zip_ref = zipfile.ZipFile(root_path, 'r')
zip_ref.extractall("./")
zip_ref.close()
!echo "Finished"
!pip install torchio
!pip install self-attention-cv

Mounted at /gdrive
Download and extracting folders...
Finished
Collecting torchio
[?25l  Downloading https://files.pythonhosted.org/packages/3d/33/94812ae74a2815fdd5bf7c4e26be75086ebc770309c569380e6f7cc4ad60/torchio-0.18.29-py2.py3-none-any.whl (140kB)
[K     |████████████████████████████████| 143kB 18.9MB/s 
Collecting Deprecated
  Downloading https://files.pythonhosted.org/packages/d4/56/7d4774533d2c119e1873993d34d313c9c9efc88c5e4ab7e33bdf915ad98c/Deprecated-1.2.11-py2.py3-none-any.whl
Collecting SimpleITK<2
[?25l  Downloading https://files.pythonhosted.org/packages/4a/ee/638b6bae2db10e5ef4ca94c95bb29ec25aa37a9d721b47f91077d7e985e0/SimpleITK-1.2.4-cp37-cp37m-manylinux1_x86_64.whl (42.5MB)
[K     |████████████████████████████████| 42.5MB 65kB/s 
Installing collected packages: Deprecated, SimpleITK, torchio
Successfully installed Deprecated-1.2.11 SimpleITK-1.2.4 torchio-0.18.29
Collecting self-attention-cv
  Downloading https://files.pythonhosted.org/packages/69/5b/4163230c657f80a

## Training example

In [None]:
import glob
import torchio as tio
import torch
from torch.utils.data import DataLoader

# Seed torch's global RNG so patch sampling, RandomAffine augmentation and model
# initialisation are reproducible under "Restart & Run All".
seed = 0
torch.manual_seed(seed)

paths_t1 = sorted(glob.glob('./iSeg-2019-Training/*T1.img'))
paths_t2 = sorted(glob.glob('./iSeg-2019-Training/*T2.img'))
paths_seg = sorted(glob.glob('./iSeg-2019-Training/*label.img'))
# An empty glob (wrong working directory, extraction failed) would otherwise
# silently produce an empty subject list.
assert paths_t1, 'no *T1.img files found in ./iSeg-2019-Training/'
assert len(paths_t1) == len(paths_t2) == len(paths_seg)

subject_list = []
for path_t1, path_t2, path_seg in zip(paths_t1, paths_t2, paths_seg):
  # The three lists are paired purely by sorted order, so check that every triple
  # shares the same file-name prefix (i.e. belongs to the same subject).
  prefix = path_t1[:-len('T1.img')]
  assert (path_t2[:-len('T2.img')] == prefix
          and path_seg[:-len('label.img')] == prefix), (path_t1, path_t2, path_seg)
  subject = tio.Subject(t1=tio.ScalarImage(path_t1),
                        t2=tio.ScalarImage(path_t2),
                        label=tio.LabelMap(path_seg))
  subject_list.append(subject)


# Training-time preprocessing: rescale intensities to [0, 1], then random affine
# augmentation. Inference must apply the same RescaleIntensity((0, 1)).
transforms = [tio.RescaleIntensity((0, 1)), tio.RandomAffine()]
transform = tio.Compose(transforms)

subjects_dataset = tio.SubjectsDataset(subject_list, transform=transform)

patch_size = 24          # edge length of the cubic sub-volumes fed to the model
queue_length = 300       # maximum number of patches stored in the queue
samples_per_volume = 50  # patches extracted from each subject
sampler = tio.data.UniformSampler(patch_size)

patches_queue = tio.Queue(
    subjects_dataset,
    queue_length,
    samples_per_volume,
    sampler,
    num_workers=1,
)

# The Queue does its own worker-based loading; the DataLoader wrapping it keeps
# the default num_workers=0, as torchio requires.
patches_loader = DataLoader(patches_queue, batch_size=16)

In [None]:
from self_attention_cv.Transformer3Dsegmentation import Transformer3dSeg

def crop_target(img, target_size):
  """Crop the central ``target_size``^3 block from channel 0 of a label batch.

  Args:
    img: batched label tensor indexed as (batch, channel, x, y, z); only channel 0
      is kept. The spatial axes need not have equal sizes.
    target_size: edge length of the cube to keep.

  Returns:
    Long tensor of shape (batch, target_size, target_size, target_size), the
    class-index format ``CrossEntropyLoss`` expects for its target.

  Raises:
    ValueError: if ``target_size`` exceeds a spatial dimension of ``img``.
  """
  spatial = img.shape[2:]
  if any(target_size > dim for dim in spatial):
    raise ValueError(f'target_size {target_size} exceeds spatial shape {tuple(spatial)}')
  # Per-axis start of the centred block. For a 24^3 patch and target_size=3 this is
  # 10, i.e. voxels [10:13). The slice always spans exactly target_size voxels per
  # axis, including for even target sizes.
  sx, sy, sz = ((dim - target_size) // 2 for dim in spatial)
  return img[:, 0, sx:sx + target_size, sy:sy + target_size, sz:sz + target_size].long()

target_size = 3   # edge length of the predicted label block per sub-volume, as in the paper
patch_dim = 8     # token patch edge length passed to Transformer3dSeg
num_epochs = 50
num_classes = 4
model = Transformer3dSeg(subvol_dim=patch_size, patch_dim=patch_dim,
                         in_channels=2, blocks=2, num_classes=num_classes).cuda()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print(len(patches_loader))
for epoch_index in range(num_epochs):
  # The inference cell switches the model to eval mode; switch back explicitly so
  # re-running this cell always trains in training mode.
  model.train()
  epoch_loss = 0.0
  num_batches = 0
  for patches_batch in patches_loader:
    optimizer.zero_grad()

    # Stack T1 and T2 as the two input channels (in_channels=2); assumes each
    # modality patch is single-channel.
    input_t1 = patches_batch['t1'][tio.DATA]
    input_t2 = patches_batch['t2'][tio.DATA]
    input_tensor = torch.cat([input_t1, input_t2], dim=1).float().cuda()

    # Class logits for the central target_size^3 block of each sub-volume.
    logits = model(input_tensor)

    # The label patch must be cropped to the same central block the model predicts.
    targets = patches_batch['label'][tio.DATA]
    cropped_target = crop_target(targets, target_size).cuda()

    loss = criterion(logits, cropped_target)
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()
    num_batches += 1

  # Mean loss over all batches of the epoch (guarded against an empty loader).
  print(f'epoch {epoch_index} loss {epoch_loss / max(num_batches, 1)}')
    



32
epoch 0 loss 0.8919196542232267
epoch 1 loss 0.6648283805097303
epoch 2 loss 0.6422034237653979
epoch 3 loss 0.5969387196725414
epoch 4 loss 0.5559082502318967
epoch 5 loss 0.49828739656556037
epoch 6 loss 0.48543436681070634
epoch 7 loss 0.3903121284900173
epoch 8 loss 0.38039007951175013
epoch 9 loss 0.2883441626064239
epoch 10 loss 0.35982790421093663
epoch 11 loss 0.2505160081650942
epoch 12 loss 0.2158138483402229
epoch 13 loss 0.20691758676642372
epoch 14 loss 0.20189064626972522
epoch 15 loss 0.24909109192629975
epoch 16 loss 0.18076440347959438
epoch 17 loss 0.23432552934654297
epoch 18 loss 0.23753149663248369
epoch 19 loss 0.21906323085028317
epoch 20 loss 0.20713701904300721
epoch 21 loss 0.22791918559420493
epoch 22 loss 0.20537897133298458
epoch 23 loss 0.20976788646751834
epoch 24 loss 0.19728194228223256
epoch 25 loss 0.21557400706824997
epoch 26 loss 0.16888576995341048
epoch 27 loss 0.1890002822338213
epoch 28 loss 0.20790056818945996
epoch 29 loss 0.186593093278427

## Inference

In [None]:
import itertools

import torch
import torch.nn as nn
import torchio as tio

# Must match the training setup: the model maps a 24^3 two-channel sub-volume to
# logits for a 3^3 block, and training paired those logits with voxels [10:13) of
# the sub-volume along each axis (the block crop_target selects).
subvol_size = 24
target_patch_size = 3
inference_batch_size = 64


def segment_volume(volume, predict_fn, subvol_size, target_size, batch_size=64):
  """Densely label a (C, X, Y, Z) volume with a sub-volume -> central-block model.

  The volume is tiled by non-overlapping target_size^3 blocks. For each block, the
  subvol_size^3 window that contains it at offset (subvol_size - target_size) // 2
  (the same offset crop_target uses) is cut from a zero-padded copy of the volume
  and passed to predict_fn; the argmax of the returned logits labels the block.

  Args:
    volume: tensor of shape (C, X, Y, Z).
    predict_fn: callable mapping (B, C, s, s, s) windows to (B, num_classes, t, t, t)
      logits, with s = subvol_size and t = target_size.
    subvol_size: edge length s of the model input windows.
    target_size: edge length t of the block predicted per window.
    batch_size: number of windows per predict_fn call.

  Returns:
    int32 tensor of shape (1, X, Y, Z) with one predicted label per voxel.
  """
  offset = (subvol_size - target_size) // 2
  shape = volume.shape[1:]
  # Blocks per axis, rounded up so the whole volume is covered.
  n_blocks = [-(-dim // target_size) for dim in shape]
  covered = [n * target_size for n in n_blocks]
  # Pad `offset` voxels before each axis and enough after it for the last window to
  # fit; window padded[p:p+s] then has volume[p:p+t] as its central block.
  pad = []
  for dim, cov in zip(reversed(shape), reversed(covered)):  # F.pad wants last axis first
    pad += [offset, cov - dim + subvol_size - target_size - offset]
  padded = torch.nn.functional.pad(volume, pad)

  labels = torch.zeros((1, *covered), dtype=torch.int32, device=volume.device)
  starts = list(itertools.product(*(range(0, cov, target_size) for cov in covered)))
  for first in range(0, len(starts), batch_size):
    chunk = starts[first:first + batch_size]
    windows = torch.stack([
        padded[:, x:x + subvol_size, y:y + subvol_size, z:z + subvol_size]
        for x, y, z in chunk])
    predicted = predict_fn(windows).argmax(dim=1).to(torch.int32)
    for (x, y, z), block in zip(chunk, predicted):
      labels[0, x:x + target_size, y:y + target_size, z:z + target_size] = block
  # Drop the rounded-up overhang beyond the original volume extent.
  return labels[:, :shape[0], :shape[1], :shape[2]]


# The model was trained on intensities rescaled to [0, 1] (RescaleIntensity in the
# training transform), so apply the same deterministic preprocessing here, without
# the random augmentation. Note that subject_list[0] is a training subject.
subject = tio.RescaleIntensity((0, 1))(subject_list[0])
input_volume = torch.cat([subject['t1'][tio.DATA], subject['t2'][tio.DATA]], dim=0).float().cuda()

model.eval()
with torch.no_grad():
  output_tensor = segment_volume(input_volume, model, subvol_size, target_patch_size,
                                 batch_size=inference_batch_size).cpu()
print(output_tensor.shape)

output tensor shape: torch.Size([4, 1, 3, 3, 3])
torch.Size([1, 144, 192, 256])
