In [1]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import os
import shutil
import cv2
import numpy as np
from os import path

In [2]:
from torch.nn.modules.pooling import MaxPool2d

class DoubleConv(nn.Module):
  def __init__(self, input, output):
    super(DoubleConv, self).__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(input, output, 3, 1, 1, bias = False),
        nn.BatchNorm2d(output),
        nn.ReLU(inplace=True),
        nn.Conv2d(output, output, 3, 1, 1, bias = False),
        nn.BatchNorm2d(output),
        nn.ReLU(inplace=True))
  def forward(self, x):
    return self.conv(x)


class UNET(nn.Module):
  def __init__(self,input = 3, output = 1, features = [64, 128, 256, 512]):
    super(UNET, self).__init__()
    self.up = nn.ModuleList()
    self.down = nn.ModuleList()
    self.pool = MaxPool2d(kernel_size=2, stride=2)


    #UNET left side (down)
    for feature in features:
      self.down.append(DoubleConv(input, feature))
      input = feature

    #UNET right side (up)
    for feature in reversed(features):
      self.up.append(nn.ConvTranspose2d(
          feature*2, feature, kernel_size=2, stride=2,
      ))
      self.up.append(DoubleConv(feature*2, feature))

    self.bottleneck = DoubleConv(features[-1], features[-1]*2)
    self.final = nn.Conv2d(features[0], output, kernel_size=1)

  def forward(self, x):
    skip_connections = []

    for down in self.down:
      x = down(x)
      skip_connections.append(x)
      x = self.pool(x)

    x = self.bottleneck(x)
    skip_connections  = skip_connections[::-1]

    for idx in range(0, len(self.up), 2):
      x = self.up[idx](x)
      skip_connection = skip_connections[idx//2]
      if x.shape != skip_connection.shape:
        x = TF.resize(x, size=skip_connection.shape[2:])



      concat_skip = torch.cat((skip_connection, x), dim = 1)
      x = self.up[idx+1](concat_skip)

    return self.final(x)

In [3]:
from google.colab import files
files.upload()
!rm -r ~/.kaggle
!mkdir ~/.kaggle
!mv ./kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

!kaggle datasets download -d mateuszbuda/lgg-mri-segmentation

Saving kaggle.json to kaggle.json
rm: cannot remove '/root/.kaggle': No such file or directory
Downloading lgg-mri-segmentation.zip to /content
 99% 706M/714M [00:05<00:00, 128MB/s]
100% 714M/714M [00:06<00:00, 124MB/s]


In [4]:
import zipfile
zip_ref = zipfile.ZipFile('lgg-mri-segmentation.zip', 'r')
zip_ref.extractall('/content')
zip_ref.close()

In [5]:
os.remove('/content/kaggle_3m/README.md')
os.remove('/content/kaggle_3m/data.csv')

In [6]:
main_folder = '/content/kaggle_3m'
folders = os.listdir(main_folder)

if path.exists('/content/image_dir') == False:
  os.mkdir('/content/image_dir')


if path.exists('/content/mask_dir') == False:
  os.mkdir('/content/mask_dir')

for folder in folders:
    folder_path = os.path.join(main_folder, folder)
    files = [file for file in os.listdir(folder_path) if file.lower() != "data.csv"]

    for file in files:
        file_path = os.path.join(folder_path, file)

        if 'mask' not in file:
            shutil.move(file_path, "image_dir")
        else:
            shutil.move(file_path, "mask_dir")

In [7]:
if path.exists('/content/image_img') == False:
  os.mkdir('/content/train_img')
if path.exists('/content/train_mask') == False:
  os.mkdir('/content/train_mask')

if path.exists('/content/val_img') == False:
  os.mkdir('/content/val_img')
if path.exists('/content/val_mask') == False:
  os.mkdir('/content/val_mask')


In [8]:
img_dir = "/content/image_dir"
msk_dir = "/content/mask_dir"

In [9]:
ct = 0
for files in os.listdir(img_dir):
  file = os.path.join(img_dir, files)
  ct += 1
  if ct <= 3500:
    shutil.move(file, "/content/train_img")
  else:
    shutil.move(file, "/content/val_img")

In [10]:
ct = 0
for files in os.listdir(msk_dir):
  file = os.path.join(msk_dir, files)
  ct += 1
  if ct <= 3500:
    shutil.move(file, "/content/train_mask")
  else:
    shutil.move(file, "/content/val_mask")

In [41]:
from PIL import Image
from torch.utils.data import Dataset
mask_dir = "/content/mask_dir"
image_dir = "/content/image_dir"


class MRIData(Dataset):
  def __init__(self, image_dir, mask_dir, transform=None):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    self.images = os.listdir(image_dir)
    self.masks = os.listdir(mask_dir)

  def __len__(self):
    return len(self.images)

  def __getitem__(self, index):
    img_path = os.path.join(self.image_dir, self.images[index])
    mask_path = os.path.join(self.mask_dir, self.masks[index].replace('_mask.jpg', "_mask.tif"))
    image = np.array(Image.open(img_path).convert("RGB"))
    mask = np.array(Image.open(mask_path).convert('L'), dtype=np.float32)
    mask[mask == 255.0] = 1.0

    if self.transform is not None:
      augmentations = self.transform(image=image, mask=mask)
      image = augmentations["image"]
      mask = augmentations["mask"]



In [12]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 3
NUM_WORKERS = 2
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = "/content/train_img"
TRAIN_MASK_DIR = "/content/train_mask"
VAL_IMG_DIR = "/content/val_img"
VAL_MASK_DIR = "/content/val_mask"


In [29]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
  loop = tqdm(loader)

  for batch_idx, (data, targets) in enumerate(loop):
    data = data.to(device=DEVICE)
    targets = targets.float().unsqueeze(1).to(device=DEVICE)

    with torch.cuda.amp.autocast():
      predictions = model(data)
      loss = loss_fn(predictions, targets)


    optimizer.zero_grad()
    scaler.scaler(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    loop.set_postfix(loss=loss.item())

def main():
  train_transform = A.Compose(
      [
          A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
          A.Rotate(limit=35, p=1.0),
          A.HorizontalFlip(p=0.5),
          A.VerticalFlip(p=0.1),
          A.Normalize(
              mean = [0.0, 0.0, 0.0],
              std = [1.0, 1.0, 1.0],
              max_pixel_value = 255.0
          ),
          ToTensorV2()

      ]
  )

  val_transforms = A.Compose(
      [
          A.Resize(height=IMAGE_HEIGHT, width = IMAGE_WIDTH),
          A.Normalize(
              mean = [0.0, 0.0, 0.0],
              std = [1.0, 1.0, 1.0],
              max_pixel_value = 255.0
          ),
          ToTensorV2()
      ]
  )


  model = UNET(input = 3, output = 1).to(DEVICE)
  loss_fn = nn.BCEWithLogitsLoss()
  optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)

  train_loader, val_loader = get_loaders(
      TRAIN_IMG_DIR,
      TRAIN_MASK_DIR,
      VAL_IMG_DIR,
      VAL_MASK_DIR,
      BATCH_SIZE,
      train_transform,
      val_transforms,
      NUM_WORKERS,
      PIN_MEMORY
  )

  scaler = torch.cuda.amp.GradScaler()
  for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

In [19]:
def save_checkpoint(state, filename = 'my_checkpoint.pth.tar'):
  print("-> Saving Checkpoint")
  torch.save(state,filename)

def load_checkpoint(checkpoint, model):
  print("Loading Checkpoint ->")
  model.load_state_dict(checkpoint["state_dict"])

In [20]:
from torch.utils.data import DataLoader
import torchvision

In [21]:
def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    train_ds = MRIData(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        transform=train_transform,
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_ds = MRIData(
        image_dir=val_dir,
        mask_dir=val_maskdir,
        transform=val_transform,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader

In [22]:
if path.exists('/content/saved_images') == False:
  os.mkdir('/content/saved_images')

In [23]:
def check_accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {dice_score/len(loader)}")
    model.train()

def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx}.png"
        )
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")

    model.train()

In [42]:
main()

  0%|          | 0/219 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79e34ad5d510>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
  0%|          | 0/219 [00:10<?, ?it/s]


TypeError: ignored