In [1]:
import torch.nn as nn 
import torch 
import pandas as pd
import numpy as np 
import matplotlib.pyplot as plt 
from glob import glob 
import os 

from skimage import io
from patchify import patchify, unpatchify

import random 
import torch 
from torch.utils import data 
import torchvision.transforms.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from PIL import Image

import cv2
from glob import glob
import os
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import nibabel as nib

In [2]:
path= '/home/arshad/Downloads/amos22/amos22'
input_paths   = sorted(glob(os.path.join(path, "imagesVa","*.nii.gz")))[:25]
target_paths  = sorted(glob(os.path.join(path, "labelsVa","*.nii.gz")))[:25]

In [3]:
# input_paths

In [4]:
IMAGE_SIZE = 64
BATCH_SIZE = 1
NUM_CLASS = 2

In [5]:
device_ids = [0, 1] # GPUs 0 and 1
devices = [torch.device(f'cuda:{i}') for i in device_ids]

In [6]:
devices

[device(type='cuda', index=0), device(type='cuda', index=1)]

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
class AmosDataLoader(data.Dataset):
  def __init__(
      self, 
      input_paths: list, 
      target_paths: list, 
      transform_input = None, 
      transform_target = None
  ): 

    self.input_paths      = input_paths
    self.target_paths     = target_paths
    self.transform_input  = transform_input
    self.transform_target = transform_target

  def __len__(self):
    return len(self.input_paths)

  def preprocess_img_input(self, input_im):
    input_im = patchify(input_im, (IMAGE_SIZE, IMAGE_SIZE, IMAGE_SIZE), step=IMAGE_SIZE)
    input_im = np.reshape(input_im, (-1, input_im.shape[3], input_im.shape[4], input_im.shape[5]))
    input_im = np.stack((input_im,)*3, axis=-1) #reduced to the 1 channel
    input_im = torch.tensor(input_im).float()/255
    
    input_im = input_im.permute(0,4,1,2,3)
    return input_im
  
  def preprocess_img_output(self, output_im):
    output_im = patchify(output_im, (IMAGE_SIZE, IMAGE_SIZE, IMAGE_SIZE), step=IMAGE_SIZE)
    output_im = np.reshape(output_im, (-1, output_im.shape[3], output_im.shape[4], output_im.shape[5]))
    output_im = np.expand_dims(output_im, axis = 4)
    output_im = torch.tensor(output_im).float()/255
    
    output_im = output_im.permute(0,4,1,2,3)
    return output_im


  def __getitem__(self,x):
    input_im = self.input_paths[x]
    mask_im  = self.target_paths[x]
    input_im    = nib.load(input_im).get_fdata()
    mask_im     = nib.load(mask_im).get_fdata()

    return input_im, mask_im
    
  def collate_fn(self, batch):
    # print(len(batch[0][0]))
    im_ins, im_outs = [], []
    for im_in, im_out  in batch: 
      im_in = self.preprocess_img_input(im_in)
      im_out = self.preprocess_img_output(im_out)

      # im_out = self.preprocess_output(im_out)
      # print(im_in.shape, im_out.shape)
      im_ins.append(im_in)
      im_outs.append(im_out)

    # print(torch.tensor(im_ins).shape)
    return torch.cat(im_ins, dim = 0), torch.cat(im_outs, dim= 0)

In [9]:
train_dl      = AmosDataLoader(input_paths, target_paths)
train_loader  = DataLoader(train_dl, batch_size = BATCH_SIZE, drop_last= True, collate_fn=train_dl.collate_fn)

In [10]:
# for d,k in train_dl: 
#   # # print(len(d))}
#   print(d.shape, k.shape)
#   break

In [11]:
# for d, k in train_loader:
#     print(d.shape, k.shape)
#     break

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class Down3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool3d(2),
            DoubleConv3D(in_channels, out_channels)
        )

    def forward(self, x):
        return self.mpconv(x)

class Up3D(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose3d(in_channels, in_channels//2, kernel_size=2, stride=2)

        self.conv = DoubleConv3D(in_channels, out_channels)

    def forward(self, x1, x2):
        # print(x1.shape, x2.shape)
        x1    = self.up(x1)
        # print(x1.shape)
        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]
        x1    = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2))
        x     = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv3D, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, 1)

    def forward(self, x):
        return self.conv(x)

class UNet3D(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=False):
        super().__init__()
        self.in_channels  = in_channels
        self.out_channels = out_channels
        self.bilinear     = bilinear

        self.conv1    = DoubleConv3D(in_channels, 64)
        self.down1    = Down3D(64, 128)
        self.down2    = Down3D(128, 256)
        self.down3    = Down3D(256, 512)
        self.down4    = Down3D(512, 1024)
        self.up1      = Up3D(1024, 512, bilinear)
        self.up2      = Up3D(512, 256, bilinear)
        self.up3      = Up3D(256, 128, bilinear)
        self.up4      = Up3D(128, 64, bilinear)
        self.outconv  = OutConv3D(64, out_channels)

    # def forward(self, x):
    #     # print(x.shape)
    #     # x = x.unsqueeze(1)
    #     x1 = self.conv1(x)
    #     x2 = self.down1(x1)
    #     x3 = self.down2(x2)
    #     x4 = self.down3(x3)
    #     x5 = self.down4(x4)
    #     # print(x5.shape, x4.shape)
    #     x6 = self.up1(x5, x4)
    #     x7 = self.up2(x6, x3)
    #     x8 = self.up3(x7, x2) 
    #     x9 = self.up4(x8, x1)
    #     output= self.outconv(x9)
    #     # print(x6.shape)
    #     # up network
        
    def forward(self, x):
        # print(x.shape)
        # x = x.unsqueeze(1)
        x1 = self.conv1(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        # x4 = self.down3(x3)
        # x5 = self.down4(x4)
        # print(x5.shape, x4.shape)
        # x = self.up1(x5, x4)
        # x = self.up2(x, x3)
        x = self.up3(x, x2) 
        x = self.up4(x, x1)
        
        output= self.outconv(x)
        # print(x6.shape)
        # up network
        del x1
        del x2
        del x3
        del x4
        del x5
        del x

        return output

In [13]:
model = nn.DataParallel(UNet3D(3, NUM_CLASS)).to(device)

In [14]:
# model

In [15]:
def myloss(A, B, reduction='none'): 
    out = A.clone()
    out -= B
    return out.pow_(2)

In [16]:
def train_batch(data, model, optimizer):
    model.train()
    ims_in, ims_out = data
    optimizer.zero_grad()
    pred_img = model(ims_in.to(device))    
    total_loss = myloss(pred_img.to('cpu'), ims_out) #custom loss function
    total_loss.backward()
    optimizer.step()
    return total_loss

In [17]:
n_epoch = 1
#criteria = #nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [18]:
train_loss = []
for epoch in range(n_epoch):
  print(epoch)
  epoch_loss = []

  for ix, ims in enumerate(train_loader):
    loss = train_batch(ims, model, optimizer)
    epoch_loss.append(loss)
  print('avg_loss', sum(epoch_loss)/len(epoch_loss))
  train_loss.append(sum(epoch_loss)/len(epoch_loss))

0


OutOfMemoryError: Caught OutOfMemoryError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/arshad/Downloads/venv/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arshad/Downloads/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_29007/3434515746.py", line 102, in forward
    x2 = self.down1(x1)
         ^^^^^^^^^^^^^^
  File "/home/arshad/Downloads/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_29007/3434515746.py", line 29, in forward
    return self.mpconv(x)
           ^^^^^^^^^^^^^^
  File "/home/arshad/Downloads/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arshad/Downloads/venv/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/home/arshad/Downloads/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_29007/3434515746.py", line 18, in forward
    return self.conv(x)
           ^^^^^^^^^^^^
  File "/home/arshad/Downloads/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arshad/Downloads/venv/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/home/arshad/Downloads/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arshad/Downloads/venv/lib/python3.11/site-packages/torch/nn/modules/batchnorm.py", line 171, in forward
    return F.batch_norm(
           ^^^^^^^^^^^^^
  File "/home/arshad/Downloads/venv/lib/python3.11/site-packages/torch/nn/functional.py", line 2450, in batch_norm
    return torch.batch_norm(
           ^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.12 GiB (GPU 0; 23.65 GiB total capacity; 21.57 GiB already allocated; 1001.31 MiB free; 21.59 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
