In [1]:
from google.colab import drive
 
drive.mount('/content/drive')


Mounted at /content/drive


In [2]:
%cd '/content/drive/MyDrive/BrainTumorSegmentation_Oct10/'

/content/drive/MyDrive/BrainTumorSegmentation_Oct10


In [3]:
import os
import albumentations as A
import cv2
import numpy as np
from scipy.ndimage.morphology import binary_dilation
import torch
from torch.nn.functional import sigmoid

class EarlyStopping():
    """
    Stops training when loss stops decreasing in a PyTorch module.
    """
    def __init__(self, patience:int = 6, min_delta: float = 0, weights_path: str = 'weights.pt'):
        """
        :param patience: number of epochs of non-decreasing loss before stopping
        :param min_delta: minimum difference between best and new loss that is considered
            an improvement
        :paran weights_path: Path to the file that should store the model's weights
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.weights_path = weights_path

    def __call__(self, val_loss: float, model: torch.nn.Module):
        if self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            torch.save(model.state_dict(), self.weights_path)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

    def load_weights(self, model: torch.nn.Module):
        """
        Loads weights of the best model.
        :param model: model to which the weigths should be loaded
        """
        return model.load_state_dict(torch.load(self.weights_path))
            

def get_file_row(path):
    """Produces ID of a patient, image and mask filenames from a particular path"""
    path_no_ext, ext = os.path.splitext(path)
    filename = os.path.basename(path)
    
    patient_id = '_'.join(filename.split('_')[:3]) # Patient ID in the csv file consists of 3 first filename segments
    
    return [patient_id, path, f'{path_no_ext}_mask{ext}']

def iou_pytorch(predictions: torch.Tensor, labels: torch.Tensor, e: float = 1e-7):
    """Calculates Intersection over Union for a tensor of predictions"""
    predictions = sigmoid(predictions)
    predictions = torch.where(predictions > 0.5, 1, 0)
    labels = labels.byte()
    
    intersection = (predictions & labels).float().sum((1, 2))
    union = (predictions | labels).float().sum((1, 2))
    
    iou = (intersection + e) / (union + e)
    return iou

def dice_pytorch(predictions: torch.Tensor, labels: torch.Tensor, e: float = 1e-7):
    """Calculates Dice coefficient for a tensor of predictions"""
    predictions = sigmoid(predictions)
    predictions = torch.where(predictions > 0.5, 1, 0)
    labels = labels.byte()
    
    intersection = (predictions & labels).float().sum((1, 2))
    return ((2 * intersection) + e) / (predictions.float().sum((1, 2)) + labels.float().sum((1, 2)) + e)    

def BCE_dice(output, target, alpha=0.2):
    bce = torch.nn.functional.binary_cross_entropy(sigmoid(output), target)
    soft_dice = 1 - dice_pytorch(output, target).mean()
    return alpha*bce + (1 - alpha) * soft_dice    



  from scipy.ndimage.morphology import binary_dilation


In [4]:
import os
import time
import albumentations as A
import cv2
import numpy as np
import pandas as pd
from scipy.ndimage.morphology import binary_dilation
from glob import glob
# from data_frame_utils import get_file_row, iou_pytorch, dice_pytorch, BCE_dice, EarlyStopping
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from tqdm import tqdm

  from scipy.ndimage.morphology import binary_dilation


In [5]:
!pip install ml_collections
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ml_collections
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.9/77.9 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: ml_collections
  Building wheel for ml_collections (setup.py) ... [?25l[?25hdone
  Created wheel for ml_collections: filename=ml_collections-0.1.1-py3-none-any.whl size=94506 sha256=c7d613615e2117b2e0efa2e593c7028a7fd6a7afb7b372f8d85f4848fcdf1106
  Stored in directory: /root/.cache/pip/wheels/7b/89/c9/a9b87790789e94aadcfc393c283e3ecd5ab916aed0a31be8fe
Successfully built ml_collections
Installing collected packages: ml_collections
Successfully installed ml_collections-0.1.1


In [6]:
from TransUNet.datasets.dataset_synapse import MriDataset
from TransUNet.networks.vit_seg_modeling import VisionTransformer as ViT_seg
from TransUNet.networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg

In [8]:
def training_loop(writer, epochs, model, train_loader, valid_loader, optimizer, loss_fn, lr_scheduler):
    history = {'train_loss': [], 'val_loss': [], 'val_IoU': [], 'val_dice': []}
    early_stopping = EarlyStopping(patience=7)
    
    for epoch in range(1, epochs + 1):
        start_time = time.time()
        
        running_loss = 0
        model.train()
        for i, data in enumerate(tqdm(train_loader)):
            img, mask = data
            img, mask = img.to(device), mask.to(device)
            # print(img.shape, mask.shape)
            predictions = model(img)
            predictions = predictions.squeeze(1)
            # print(torch.max(predictions))
            loss = loss_fn(predictions, mask)
            running_loss += loss.item() * img.size(0)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        
        model.eval()
        with torch.no_grad():
            running_IoU = 0
            running_dice = 0
            running_valid_loss = 0
            for i, data in enumerate(valid_loader):
                img, mask = data
                img, mask = img.to(device), mask.to(device)
                predictions = model(img)
                predictions = predictions.squeeze(1)
                running_dice += dice_pytorch(predictions, mask).sum().item()
                running_IoU += iou_pytorch(predictions, mask).sum().item()
                loss = loss_fn(predictions, mask)
                running_valid_loss += loss.item() * img.size(0)
        train_loss = running_loss / len(train_loader.dataset)
        val_loss = running_valid_loss / len(valid_loader.dataset)
        val_dice = running_dice / len(valid_loader.dataset)
        val_IoU = running_IoU / len(valid_loader.dataset)
        
        history['train_loss'].append(train_loss)
        writer.add_scalar("Training/Train loss", train_loss, epoch)
        writer.add_scalar("Training/Val loss", val_loss, epoch)
        writer.add_scalar("Metric/Val IoU", val_IoU, epoch)
        writer.add_scalar("Metric/Val Dice", val_dice, epoch)

        history['val_loss'].append(val_loss)
        history['val_IoU'].append(val_IoU)
        history['val_dice'].append(val_dice)
        print(f'Epoch: {epoch}/{epochs} | Training loss: {train_loss} | Validation loss: {val_loss} | Validation Mean IoU: {val_IoU} '
         f'| Validation Dice coefficient: {val_dice}')
        
        lr_scheduler.step(val_loss)
        if early_stopping(val_loss, model):
            early_stopping.load_weights(model)
            break
    model.eval()
    return history


In [9]:
# https://drive.google.com/drive/folders/1YxaM1yS6m_zuzmGZDSNjXeN6vNFeqyRK
mri_data='/content/drive/MyDrive/Brain_Tumor/input/lgg-mri-segmentation/kaggle_3m/'

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
csv_path = mri_data+'data.csv'
files_dir =mri_data
file_paths = glob(f'{files_dir}/*/*[0-9].tif')
df = pd.read_csv(csv_path)
imputer = SimpleImputer(strategy="most_frequent")
df = pd.DataFrame(imputer.fit_transform(df), columns=df.columns)

In [11]:
filenames_df = pd.DataFrame((get_file_row(filename) for filename in file_paths), columns=['Patient', 'image_filename', 'mask_filename'])
df = pd.merge(df, filenames_df, on="Patient")
train_df, test_df = train_test_split(df, test_size=0.3, random_state=42)
test_df, valid_df = train_test_split(test_df, test_size=0.5, random_state=42)

In [12]:
!pip install tensorboardX

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorboardX
  Downloading tensorboardX-2.6-py2.py3-none-any.whl (114 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/114.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.5/114.5 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.6


In [13]:
from tensorboardX import SummaryWriter
writer = SummaryWriter("tensorboard_logs")

In [14]:
transform = A.Compose([A.ChannelDropout(p=0.3),A.RandomBrightnessContrast(p=0.3),A.ColorJitter(p=0.3),])

In [15]:
# train_dataset = MriDataset(train_df, transform)
train_dataset = MriDataset(train_df)
valid_dataset = MriDataset(valid_df)
test_dataset = MriDataset(test_df)

In [16]:
batch_size = 8
img_size = 256

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1)
config_vit = CONFIGS_ViT_seg["R50-ViT-B_16"]
config_vit.n_classes = 1
config_vit.n_skip = 3


In [17]:
!wget https://storage.googleapis.com/vit_models/imagenet21k/R50+ViT-B_16.npz 

--2023-05-16 08:09:58--  https://storage.googleapis.com/vit_models/imagenet21k/R50+ViT-B_16.npz
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.128.128, 74.125.124.128, 172.217.212.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.128.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 461217452 (440M) [application/octet-stream]
Saving to: ‘R50+ViT-B_16.npz.1’


2023-05-16 08:10:13 (31.0 MB/s) - ‘R50+ViT-B_16.npz.1’ saved [461217452/461217452]



In [18]:
model = ViT_seg(config_vit, img_size=img_size, num_classes=config_vit.n_classes).cuda()
# weight = np.load('../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz')
weight = np.load('R50+ViT-B_16.npz')
model.load_from(weights=weight)

optimizer = Adam(model.parameters(), lr=0.005)
epochs = 100
lr_scheduler = ReduceLROnPlateau(optimizer=optimizer, patience=2,factor=0.1)
loss_fn = BCE_dice
history = training_loop(writer, epochs, model, train_loader, valid_loader, optimizer, loss_fn, lr_scheduler)

100%|██████████| 344/344 [03:18<00:00,  1.73it/s]


Epoch: 28/100 | Training loss: 0.05809656541459778 | Validation loss: 0.09827528504117207 | Validation Mean IoU: 0.8454220998085151 | Validation Dice coefficient: 0.8787152710607496


100%|██████████| 344/344 [03:18<00:00,  1.73it/s]


Epoch: 29/100 | Training loss: 0.05496762559579061 | Validation loss: 0.09565662667786672 | Validation Mean IoU: 0.849231250407332 | Validation Dice coefficient: 0.8820126695148015


100%|██████████| 344/344 [03:18<00:00,  1.73it/s]


Epoch: 30/100 | Training loss: 0.053607974891671226 | Validation loss: 0.09782857329091867 | Validation Mean IoU: 0.8464731006299036 | Validation Dice coefficient: 0.879276500313969


100%|██████████| 344/344 [03:18<00:00,  1.73it/s]


Epoch: 31/100 | Training loss: 0.054419648057992494 | Validation loss: 0.09720521357731293 | Validation Mean IoU: 0.8468482623666019 | Validation Dice coefficient: 0.8800474255771961


100%|██████████| 344/344 [03:18<00:00,  1.73it/s]


Epoch: 32/100 | Training loss: 0.05599526549473475 | Validation loss: 0.09630805831876094 | Validation Mean IoU: 0.8478387371968414 | Validation Dice coefficient: 0.8811972731250828


100%|██████████| 344/344 [03:18<00:00,  1.73it/s]


Epoch: 33/100 | Training loss: 0.054573893992902 | Validation loss: 0.094902740052696 | Validation Mean IoU: 0.8498085030054642 | Validation Dice coefficient: 0.8829386412087133


100%|██████████| 344/344 [03:18<00:00,  1.73it/s]


Epoch: 34/100 | Training loss: 0.05375122954743413 | Validation loss: 0.0966704132445788 | Validation Mean IoU: 0.8476023779077045 | Validation Dice coefficient: 0.8807310160944017


100%|██████████| 344/344 [03:18<00:00,  1.73it/s]


Epoch: 35/100 | Training loss: 0.05449831532925865 | Validation loss: 0.09678663690969096 | Validation Mean IoU: 0.8475285691730047 | Validation Dice coefficient: 0.8805896371097888


100%|██████████| 344/344 [03:18<00:00,  1.73it/s]


Epoch: 36/100 | Training loss: 0.053800759506549105 | Validation loss: 0.09600555979955293 | Validation Mean IoU: 0.8482875775482694 | Validation Dice coefficient: 0.8815690816459009


100%|██████████| 344/344 [03:18<00:00,  1.73it/s]


Epoch: 37/100 | Training loss: 0.054189440330347595 | Validation loss: 0.0962824126547676 | Validation Mean IoU: 0.8480597665754416 | Validation Dice coefficient: 0.8812073141841565


100%|██████████| 344/344 [03:18<00:00,  1.73it/s]


Epoch: 38/100 | Training loss: 0.053620310550569476 | Validation loss: 0.09531656396672696 | Validation Mean IoU: 0.8494129423367776 | Validation Dice coefficient: 0.882422973341861


100%|██████████| 344/344 [03:18<00:00,  1.73it/s]


Epoch: 39/100 | Training loss: 0.054738486923258854 | Validation loss: 0.09535883420191983 | Validation Mean IoU: 0.8492475671283269 | Validation Dice coefficient: 0.8823637008666992


100%|██████████| 344/344 [03:18<00:00,  1.73it/s]


Epoch: 40/100 | Training loss: 0.05393441031455156 | Validation loss: 0.09694806170722339 | Validation Mean IoU: 0.8473347429501809 | Validation Dice coefficient: 0.880387417744782
