In [1]:
from monai.networks.nets import AttentionUnet, Unet
from torchsummary import summary 
import torch
import torch.nn as nn

def count_parameters(model):
    """Return how many trainable scalar weights ``model`` holds.

    Parameters whose ``requires_grad`` flag is False (frozen weights) are
    skipped; each remaining parameter contributes its element count.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [64]:
# MONAI AttentionUnet with 4 resolution levels (64 -> 512 feature channels).
# AttentionUnet consumes exactly len(channels) - 1 strides (one per
# down-sampling step between consecutive levels); the previous 4th stride was
# silently ignored, so it is dropped here (parameter count is unchanged).
# For a 5-level net, append 1024 to `channels` AND another 2 to `strides`.
attunet = AttentionUnet(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
    channels=[64, 128, 256, 512],
    strides=[2, 2, 2],
)

count_parameters(attunet)

7874288

In [3]:
# Residual MONAI UNet baseline: five feature-channel levels joined by four
# stride-2 steps, num_res_units=2, ReLU activations and no normalization layers.
unet_kwargs = dict(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
    channels=[32, 32, 64, 128, 128],
    strides=[2, 2, 2, 2],
    num_res_units=2,
    act='RELU',
    norm=None,
)
unet = Unet(**unet_kwargs)

In [4]:
from models import get_attention_unet

In [2]:
from models import get_unet_small, get_unet_large, get_attention_unet, count_parameters

# Instantiate the three project model variants, then print their trainable
# parameter counts on one line in the order: attention, small, large.
# (`count_parameters` here is the version imported from `models`.)
unetS, unetL, aunet = get_unet_small(), get_unet_large(), get_attention_unet()

print(*(count_parameters(net) for net in (aunet, unetS, unetL)))

31739380 1848152 14614520


In [1]:
from data_utils import SegDataset

# Original frames/masks live in f1/m1, augmented ones in f2/m2 under one root;
# aug_prop is forwarded to SegDataset unchanged (its semantics live in data_utils).
data_root = './data'
ds = SegDataset(
    image_dir=f'{data_root}/f1',
    mask_dir=f'{data_root}/m1',
    aug_image_dir=f'{data_root}/f2',
    aug_mask_dir=f'{data_root}/m2',
    aug_prop=0.5,
)

In [2]:
# Show the mask and frame path lists SegDataset collected.
# NOTE(review): presumably index-aligned (mask i pairs with frame i) — verify in data_utils.
(ds.masks, ds.frames)

(['./data/m1\\ch2_ed_mask_50.png',
  './data/m1\\ch2_ed_mask_51.png',
  './data/m1\\ch2_ed_mask_52.png',
  './data/m1\\ch2_ed_mask_53.png',
  './data/m2\\ch2_ed_mask_55.png',
  './data/m2\\ch2_ed_mask_54.png'],
 ['./data/f1\\ch2_ed_frame_50.png',
  './data/f1\\ch2_ed_frame_51.png',
  './data/f1\\ch2_ed_frame_52.png',
  './data/f1\\ch2_ed_frame_53.png',
  './data/f2\\ch2_ed_frame_55.png',
  './data/f2\\ch2_ed_frame_54.png'])

In [5]:
import torch
import torch.nn.functional as F

# Plain cross-entropy on raw class logits; 'mean' is PyTorch's default
# reduction. This name is reassigned to a DiceCELoss in a later cell.
loss_func = torch.nn.CrossEntropyLoss()

In [9]:
# Shape sanity check on random data. A dedicated, seeded generator makes the
# tensors reproducible across re-runs without touching the global RNG.
NUM_CLASSES = 4
gen = torch.Generator().manual_seed(0)
y_true = torch.randint(NUM_CLASSES, (20, 1, 256, 256), generator=gen)  # labels, (B, 1, H, W)
x_true = torch.rand((20, 1, 256, 256), generator=gen)                  # images, (B, 1, H, W)
y_pred = unetS(x_true)
print(y_pred.shape)

# one_hot: (B, 1, H, W) -> (B, 1, H, W, C); squeeze(1) -> (B, H, W, C);
# permute -> (B, C, H, W). squeeze(1) removes only the singleton channel axis —
# a bare squeeze() would also drop the batch axis when B == 1 and break permute.
yb = F.one_hot(y_true.long(), NUM_CLASSES).squeeze(1).permute(0, 3, 1, 2)
print(yb.shape)

torch.Size([20, 4, 256, 256])
torch.Size([20, 4, 256, 256])


In [83]:
# Convert integer labels (B, 1, H, W) to a one-hot target (B, C, H, W).
# squeeze(1) drops only the channel axis, so this also works for B == 1
# (a bare squeeze() would remove the batch axis and make permute fail).
yb = F.one_hot(y_true.long(), 4).squeeze(1).permute(0, 3, 1, 2)
yb.shape

torch.Size([20, 4, 256, 256])

In [12]:
from monai.losses import DiceCELoss

# Combined Dice + cross-entropy loss evaluated on the unetS logits against the
# one-hot float target. Per the MONAI docs, `softmax` and `include_background`
# configure the Dice term, while CE works on the raw logits —
# NOTE(review): confirm this for the installed MONAI version.
loss_func = DiceCELoss(softmax=True, include_background=False)

target_onehot = yb.float()
loss_func(y_pred, target_onehot)

tensor(2.2116, grad_fn=<AddBackward0>)

In [1]:
from data_utils import CustomTransform, SegDataset
