In [1]:
import warnings
from typing import Sequence, Tuple, Union

import torch
import torch.nn as nn

from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.simplelayers import SkipConnection
from monai.utils import alias, export

In [2]:
import matplotlib.pyplot as plt
import pandas as pd
from monai.losses import DiceLoss
from monai.metrics import DiceMetric, compute_meandice
import torch
import torch.nn as nn
import numpy as np
import os
import torch.optim as opt
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
import cv2
import h5py
from scipy.ndimage.interpolation import zoom
import torchvision.transforms as T
import random
from scipy.ndimage.interpolation import zoom
from torch_geometric.nn import SAGEConv
from torch.cuda.amp import GradScaler, autocast

In [3]:
start = []
end = []
for i in range(625):
    if i >= 0 and i <= 24:
        if i == 0:
            start.append(0)
            end.append(1)
            start.append(0)
            end.append(50)
        elif i == 24:
            start.append(24)
            end.append(23)
            start.append(24)
            end.append(49)
        else:
            start.append(i)
            end.append(i-1)
            start.append(i)
            end.append(i+1)
            start.append(i)
            end.append(i+25)
    elif (i >= 600 and i <= 624):
        if i == 600:
            start.append(600)
            end.append(575)
            start.append(600)
            end.append(601)
        elif i == 624:
            start.append(600)
            end.append(575)
            start.append(600)
            end.append(601)
        else:
            start.append(i)
            end.append(i-1)
            start.append(i)
            end.append(i+1)
            start.append(i)
            end.append(i-25)
    elif i % 25 == 0:
        start.append(i)
        end.append(i+1)
        start.append(i)
        end.append(i+25)
        start.append(i)
        end.append(i-25)
    elif i % 25 == 24:
        start.append(i)
        end.append(i-1)
        start.append(i)
        end.append(i+25)
        start.append(i)
        end.append(i-25)     
    else:
        start.append(i)
        end.append(i-1)
        start.append(i)
        end.append(i+25)
        start.append(i)
        end.append(i-25)  
        start.append(i)
        end.append(i+1)
edges = torch.tensor([start, end], dtype=torch.long)

In [4]:
base = './'
data_path = [base + 'train/' + x for x in os.listdir(base + 'train/')]
names = [x.split('/')[-1].split('_')[0] for x in data_path]
counts = pd.Series(names).value_counts().to_dict()
templete =  './train/{x}_{y}.h5'
train_paths = []
for name in counts.keys():
    count = counts[name]
    for i in range(1, count-2):
        temp = []
        for j in range(i-1, i+2):
            temp.append(templete.format(x=name, y=j))
        train_paths.append(temp)

base = './'
data_path = [base + 'val/' + x for x in os.listdir(base + 'val/')]
names = [x.split('/')[-1].split('_')[0] for x in data_path]
counts = pd.Series(names).value_counts().to_dict()
templete =  './val/{x}_{y}.h5'
val_paths = []
for name in counts.keys():
    count = counts[name]
    for i in range(1, count-2):
        temp = []
        for j in range(i-1, i+2):
            temp.append(templete.format(x=name, y=j))
        val_paths.append(temp)

class segmentation(torch.utils.data.Dataset):
    def __init__(self, paths, aug=False, train=True):
        self.paths = paths
        self.train = train
        self.aug = aug
        
    def __getitem__(self, idx):
        paths = self.paths[idx]
        xs = []
        f = h5py.File(paths[1], 'r')
        y = f['gt'][:]
        y = torch.from_numpy(y)
        y = y.float().view(1, 400, 400)
        f.close()
        for path in paths:
            f = h5py.File(path, 'r')
            x = f['ct'][:].astype(np.float)
            x = torch.from_numpy(x)
            x = x.float().view(1, 400, 400)
            xs.append(x)  
            f.close()
        x = torch.cat(xs, dim=0)
        return x, y
        
    def __len__(self):
        return len(self.paths)
    
def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor:
    # if `dim` is bigger, add singleton dim at the end
    if labels.ndim < dim + 1:
        shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape))
        labels = torch.reshape(labels, shape)

    sh = list(labels.shape)

    if sh[dim] != 1:
        raise AssertionError("labels should have a channel with length equal to one.")

    sh[dim] = num_classes

    o = torch.zeros(size=sh, dtype=dtype, device=labels.device)
    labels = o.scatter_(dim=dim, index=labels.long(), value=1)

    return labels


In [5]:
test = train_paths[0]
xs = []
f = h5py.File(test[1], 'r')
y_test = f['gt'][:]
y_test = torch.from_numpy(y_test)
y_test = y_test.float().view(1, 400, 400)
f.close()
for path in test:
    f = h5py.File(path, 'r')
    x_test = f['ct'][:].astype(np.float)
    x_test = torch.from_numpy(x_test)
    x_test = x_test.float().view(1, 400, 400)
    xs.append(x_test)  
    f.close()
x_test = torch.cat(xs, dim=0)
x_test = x_test.view(1, 3, 400, 400)

In [10]:
class UNet_GNN(nn.Module):
    def __init__(
        self,
        dimensions: int,
        in_channels: int,
        out_channels: int,
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Union[Sequence[int], int] = 3,
        up_kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 0,
        act: Union[Tuple, str] = Act.PRELU,
        norm: Union[Tuple, str] = Norm.INSTANCE,
        dropout=0.0,
    ) -> None:
        super().__init__()
        delta = len(strides) - (len(channels) - 1)
        self.dimensions = dimensions
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.strides = strides
        self.kernel_size = kernel_size
        self.up_kernel_size = up_kernel_size
        self.num_res_units = num_res_units
        self.act = act
        self.norm = norm
        self.dropout = dropout
        self.downs = []
        self.ups = []
        self.downs = []
        
        
        #graph sage conv
        self.sageconv1 = SAGEConv(in_channels = 256, out_channels = 256)
        self.relu = nn.ReLU(inplace=True)
        self.sageconv2 = SAGEConv(in_channels = 256, out_channels = 256)
        self.linear =  nn.Linear(512, 256)
        def _create_block(
            inc: int, outc: int, channels: Sequence[int], strides: Sequence[int], is_top: bool
        ) -> nn.Sequential:
            c = channels[0]
            s = strides[0]

            subblock: nn.Module

            if len(channels) > 2:
                _create_block(c, c, channels[1:], strides[1:], False)
                upc = c * 2
            else:
                self.subblock = self._get_bottom_layer(c, channels[1])
                upc = c + channels[1]

            self.downs.append(self._get_down_layer(inc, c, s, is_top))
            self.ups.append(self._get_up_layer(upc, outc, s, is_top))
        
        _create_block(in_channels, out_channels, self.channels, self.strides, True)
        print(len(self.ups), len(self.downs))
        self.up1, self.up2, self.up3, self.up4 = self.ups
        del self.ups
        self.down1, self.down2, self.down3, self.down4 = self.downs
        del self.downs

    def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:
        if self.num_res_units > 0:
            return ResidualUnit(
                self.dimensions,
                in_channels,
                out_channels,
                strides=strides,
                kernel_size=self.kernel_size,
                subunits=self.num_res_units,
                act=self.act,
                norm=self.norm,
                dropout=self.dropout,
            )
        return Convolution(
            self.dimensions,
            in_channels,
            out_channels,
            strides=strides,
            kernel_size=self.kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
        )

    def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module:
        return self._get_down_layer(in_channels, out_channels, 1, False)

    def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:
        conv: Union[Convolution, nn.Sequential]

        conv = Convolution(
            self.dimensions,
            in_channels,
            out_channels,
            strides=strides,
            kernel_size=self.up_kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
            conv_only=is_top and self.num_res_units == 0,
            is_transposed=True,
        )

        if self.num_res_units > 0:
            ru = ResidualUnit(
                self.dimensions,
                out_channels,
                out_channels,
                strides=1,
                kernel_size=self.kernel_size,
                subunits=1,
                act=self.act,
                norm=self.norm,
                dropout=self.dropout,
                last_conv_only=is_top,
            )
            conv = nn.Sequential(conv, ru)

        return conv

    def forward(self, x: torch.Tensor, device, edges) -> torch.Tensor:
        edges = edges.to(device)
        xs = []        
        for m in [self.down4, self.down3, self.down2, self.down1]:
            x = m(x)
            #print(x.shape)
            xs.append(x)
        
        x = self.subblock(x)   
        #print(x.shape)
        graph_x = x.view(x.shape[0], 256, -1).permute(0, 2, 1)
        #print(x.shape)
        graph_x = self.sageconv1(x=graph_x, edge_index=edges)
        graph_x = self.relu(graph_x)
        graph_x = self.sageconv2(x=graph_x, edge_index=edges)
        graph_x = self.relu(graph_x)
        graph_x = graph_x.permute(0, 2, 1).view(graph_x.shape[0], 256, 25, 25).float()
        
        x = torch.cat([x, graph_x], dim=1).permute(0, 2, 3, 1)
        x = self.linear(x)
        x = self.relu(x)
        x = x.permute(0, 3, 1, 2)
        #print(x.shape)
        for m, cat in zip([self.up1, self.up2, self.up3, self.up4], xs[::-1]):
            x = torch.cat([cat, x], dim=1)
            x = m(x)
            #print(x.shape)

        return x

In [11]:
class Trainer():
    def __init__(self,model,train_set,test_set,opts):
        self.model = model  # neural net
        # device agnostic code snippet
        self.device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
        print(self.device)
        self.model.to(self.device)
        
        self.epochs = opts['epochs']
        self.scaler = GradScaler()
        self.optim = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.001)
        self.criterion = DiceLoss(to_onehot_y=True, softmax=True, squared_pred=False)                     # loss function
        self.train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                                        batch_size=opts['batch_size'],
                                                        shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                                       batch_size=opts['batch_size'],
                                                       shuffle=False)
        #self.tb = SummaryWriter(log_dir='./runs/unet_2/')
        self.best_loss = 0
        
    def train(self):
        for epoch in range(self.epochs):
            self.model.train() #put model in training mode
            self.tr_loss = []
            for i, (data,labels) in tqdm(enumerate(self.train_loader),
                                                   total = len(self.train_loader)):
                data, labels = data.to(self.device),labels.to(self.device)
                self.optimizer.zero_grad()  
                outputs = self.model(data, self.device, edges)   
                loss = self.criterion(outputs, labels)
                loss.backward()                        
                self.optimizer.step()                  
                self.tr_loss.append(loss.item())
                #self.tb.add_scalar("Train Loss", np.mean(self.tr_loss), epoch)
            
            self.test(epoch) # run through the validation set
        self.tb.close()
            
    def test(self,epoch):
            
            self.model.eval()    # puts model in eval mode - not necessary for this demo but good to know
            self.test_loss = []
            self.test_dice = []
            self.test_acc = []
            
            for i, (data, labels) in enumerate(self.test_loader):
                
                data, labels = data.to(self.device),labels.to(self.device)
                
                with torch.no_grad():
                    outputs = self.model(data, self.device, edges)
                loss = self.criterion(outputs, labels)
                self.test_loss.append(loss.item())
                outputs = torch.nn.functional.softmax(outputs, 1)
                _, predicted = torch.max(outputs.data, 1)
                predicted = predicted.view(-1, 1, 400, 400)
                temp_dice = compute_meandice(one_hot(predicted, 2), one_hot(labels, 2), include_background=False).detach().cpu().numpy()
                if np.nanmean(temp_dice) == np.nanmean(temp_dice):
                    self.test_dice.append(np.nanmean(temp_dice))
                self.test_acc.append((predicted == labels).sum().item() / (predicted.size(0)*400*400))
               
            print('epoch: {}, train loss: {}, test loss: {}'.format( 
                  epoch+1, np.mean(self.tr_loss), np.mean(self.test_loss)))
            print('epoch: {}, test dice: {}, test acc: {}'.format( 
                  epoch+1, np.nanmean(self.test_dice), np.mean(self.test_acc)))           
            #self.tb.add_scalar("Val Loss", np.mean(self.test_loss), epoch)
            #self.tb.add_scalar("Val dice", np.nanmean(self.test_dice), epoch)
            #self.tb.add_scalar("Val acc", np.mean(self.test_acc), epoch)
            if np.nanmean(self.test_dice) > self.best_loss:
                self.best_loss = np.nanmean(self.test_dice)
                #torch.save(self.model, './model_weights/best_unet_2.pt')

In [12]:
model = UNet_GNN(
    dimensions=2,
    in_channels=3,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)

4 4


In [13]:
x = model(x_test, torch.device('cpu'), edges)

In [14]:
train_set, val_set = segmentation(train_paths[:2000]), segmentation(val_paths[:1000], train=False)

opts = {
    'lr': 5e-4,
    'epochs': 40,
    'batch_size': 32
}
train = Trainer(model, train_set, val_set, opts)
train.train()

cuda:4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))


epoch: 1, train loss: 0.6082643393486266, test loss: 0.596521932631731
epoch: 1, test dice: 0.12290536612272263, test acc: 0.6805059387207032


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=63.0), HTML(value='')))




KeyboardInterrupt: 

In [None]:
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])

In [None]:
params