In [1]:
import os
import glob
import shutil


In [11]:
path = "logs/amos_ver2/weights/*.pt"

for w in glob.glob(path):
    num = int(w.split("_")[-1].strip(".pt"))
    if num % 10 != 0:
        os.remove(w)

In [110]:

import torch
from torch.nn.modules.loss import _Loss


class MultiNeighborLoss(_Loss):
    def __init__(self, 
                 num_classes: int, 
                 reduction: str = "mean", 
                 centroid_method: str = "mean"):
        super(MultiNeighborLoss, self).__init__()
        self.num_classes = num_classes
        self.reduction = reduction
        self.centroid_method = centroid_method
        self.max_count = self.num_classes * (self.num_classes - 1) // 2
        
    def forward(self, probs: torch.Tensor, labels: torch.Tensor):
        assert probs.ndim == labels.ndim == 5, "The dimensions of probs and labels should be same and 5."
        
        delta = []
        for i in range(probs.size(0)):
            p_angles, l_angles = self.compute_angles(torch.sigmoid(probs[i, ...])), self.compute_angles(labels[i, ...])
            delta.append(torch.square(p_angles - l_angles))
        
        delta = torch.cat(delta)
        not_nans = ~torch.isnan(delta)
        delta = delta[not_nans]
        delta = delta[delta > 0]
        
        if self.reduction == "mean":
            return torch.mean(delta)
        
    def compute_angles(self, t: torch.Tensor) -> torch.Tensor:
        angles = torch.zeros(self.max_count*self.max_count).to(t.device)
        vectors = torch.zeros(self.max_count, 3).to(t.device)
        centroids = torch.zeros((self.num_classes, 3)).to(t.device)
        
        t = torch.argmax(t, dim=0)
        
        for i in range(self.num_classes):
            z, y, x = torch.where(t == i)
            centroids[i] = torch.stack(self.compute_centroids(x, y, z))
            print(centroids[i])
        
        idx = 0
        for i in range(self.num_classes):
            for j in range(i+1, self.num_classes):
                vectors[idx] = centroids[j] - centroids[i]
                idx += 1
    
        idx = 0
        for i in range(self.max_count):
            m = vectors[i]
            for j in range(i+1, self.max_count):
                n = vectors[j]
                angle = torch.acos(torch.dot(m, n) / (torch.norm(m) * torch.norm(n)))
                angles[idx] = angle
                idx += 1
                
        return angles
    
    def compute_centroids(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        if self.centroid_method == "mean":
            return [torch.mean(x.float()), torch.mean(y.float()), torch.mean(z.float())]
        else:
            raise NotImplementedError(f"The centroid method is not supported. : {self.centroid_method}")

In [111]:
num_classes = 16
device = torch.device("cuda:1")
loss = MultiNeighborLoss(num_classes)
for _ in range(100):
    probs = torch.randint(0, 16, (1, num_classes, 96, 96, 96)).to(device)
    labels = torch.randint(0, num_classes, (1, num_classes, 96, 96, 96)).to(device)
    
    l = loss(probs, labels)
    
    print(f"loss : {l:.4f}")

tensor([47.5223, 47.6678, 47.4686], device='cuda:1')
tensor([47.5407, 47.6690, 47.4988], device='cuda:1')
tensor([47.6857, 47.3836, 47.4982], device='cuda:1')
tensor([47.5707, 47.4679, 47.6160], device='cuda:1')
tensor([47.4402, 47.6460, 47.5364], device='cuda:1')
tensor([47.5677, 47.1164, 47.4269], device='cuda:1')
tensor([47.3891, 47.4585, 47.5519], device='cuda:1')
tensor([47.4632, 47.5281, 47.5006], device='cuda:1')
tensor([47.4569, 47.5409, 47.4113], device='cuda:1')
tensor([47.5372, 47.4871, 47.5200], device='cuda:1')
tensor([47.3665, 47.4282, 47.4687], device='cuda:1')
tensor([47.5855, 47.3465, 47.5200], device='cuda:1')
tensor([47.7364, 47.5262, 47.5768], device='cuda:1')
tensor([47.3583, 47.4969, 47.3249], device='cuda:1')
tensor([47.2896, 47.5045, 47.5268], device='cuda:1')
tensor([47.1772, 47.6582, 47.5131], device='cuda:1')
tensor([47.6740, 47.4874, 47.4963], device='cuda:1')
tensor([47.3916, 47.4529, 47.4113], device='cuda:1')
tensor([47.5216, 47.3103, 47.6167], device='cu

KeyboardInterrupt: 

In [1]:
from monai.networks.nets.attentionunet import AttentionUnet

model = AttentionUnet(3, 1, 13, [64, 128, 256, 512, 1024], [2, 2, 2, 2, 2])

print(model)

AttentionUnet(
  (model): Sequential(
    (0): ConvBlock(
      (conv): Sequential(
        (0): Convolution(
          (conv): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (adn): ADN(
            (N): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (D): Dropout(p=0.0, inplace=False)
            (A): ReLU()
          )
        )
        (1): Convolution(
          (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (adn): ADN(
            (N): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (D): Dropout(p=0.0, inplace=False)
            (A): ReLU()
          )
        )
      )
    )
    (1): AttentionLayer(
      (attention): AttentionBlock(
        (W_g): Sequential(
          (0): Convolution(
            (conv): Conv3d(64, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
          )
          (1): BatchNorm3d(32, eps=1e-

In [4]:
from models.attention_diff_unet import AttentionDiffUNet

model = AttentionDiffUNet(3, 3, 13)
x = torch.randn((4, 3, 96, 96, 96))
image = torch.randint(0, 1000, (4, 1))
label = torch.randn((4, 3, 96, 96, 96))
x_start = (label) * 2 - 1
x_t, t, _ = model(x=x_start, pred_type="q_sample")
pred = model(x=x_t, step=t, image=image, pred_type="denoise")

print(f"x_t : {x_t.shape}")
print(f"t : {t.shape}")
print(f"pred : {pred.shape}")

TypeError: 'int' object is not subscriptable

In [9]:
import os
import pickle
import numpy as np

dices = "logs/diff-unet-msd-1/dices.pkl"

with open(dices, 'rb') as file:
    dices = pickle.load(file)


num_classes = 2
scores = np.zeros((len(dices), num_classes))
for i, dice in enumerate(dices):
    dice = list(dice.values())
    scores[i, 0] = dice[0]
    scores[i, 1] = dice[1]
    
print(np.mean(scores, axis=0))

[0.65603514 0.47183063]


In [None]:
import os
import pickle
import numpy as np

dices = "logs/diff-unet-btcv-29/dices.pkl"

with open(dices, 'rb') as file:
    dices = pickle.load(file)


num_classes = 2
scores = np.zeros((len(dices), num_classes))
for i, dice in enumerate(dices):
    dice = list(dice.values())
    scores[i, 0] = dice[0]
    scores[i, 1] = dice[1]
    
print(np.mean(scores, axis=0))

In [5]:
import os
import glob

import nibabel

data_path = "/home/song99/ws/datasets/AMOS"
images = glob.glob(os.path.join(data_path, "imagesTr/*.nii.gz"))

for image in images:
    img = nibabel.load(image)
    print(img.header)
    break


<class 'nibabel.nifti1.Nifti1Header'> object, endian='<'
sizeof_hdr      : 348
data_type       : b''
db_name         : b''
extents         : 0
session_error   : 0
regular         : b'r'
dim_info        : 0
dim             : [  3 768 768  90   1   1   1   1]
intent_p1       : 0.0
intent_p2       : 0.0
intent_p3       : 0.0
intent_code     : none
datatype        : int16
bitpix          : 16
slice_start     : 0
pixdim          : [-1.         0.5703125  0.5703125  5.         0.         0.
  0.         0.       ]
vox_offset      : 0.0
scl_slope       : nan
scl_inter       : nan
slice_end       : 0
slice_code      : unknown
xyzt_units      : 10
cal_max         : 0.0
cal_min         : 0.0
slice_duration  : 0.0
toffset         : 0.0
glmax           : 0
glmin           : 0
descrip         : b'Time=144225.000'
aux_file        : b''
qform_code      : scanner
sform_code      : scanner
quatern_b       : 0.0
quatern_c       : 1.0
quatern_d       : 0.0
qoffset_x       : 233.0
qoffset_y       : -373.4

In [4]:
import os
import glob

import nibabel

data_path = "/home/song99/ws/datasets/BTCV"
images = glob.glob(os.path.join(data_path, "imagesTr/*.nii.gz"))

for image in images:
    img = nibabel.load(image)
    # print(img.affine)
    print(nibabel.orientations.aff2axcodes(img.affine))
    # print(img.header)
    break


[[  -0.90625     0.          0.        231.546875]
 [   0.          0.90625     0.       -231.546875]
 [   0.          0.          3.       -174.      ]
 [   0.          0.          0.          1.      ]]
('L', 'A', 'S')
<class 'nibabel.nifti1.Nifti1Header'> object, endian='<'
sizeof_hdr      : 348
data_type       : b''
db_name         : b''
extents         : 0
session_error   : 0
regular         : b'r'
dim_info        : 0
dim             : [  3 512 512 117   1   1   1   1]
intent_p1       : 0.0
intent_p2       : 0.0
intent_p3       : 0.0
intent_code     : none
datatype        : int16
bitpix          : 16
slice_start     : 0
pixdim          : [1.      0.90625 0.90625 3.      0.      0.      0.      0.     ]
vox_offset      : 0.0
scl_slope       : nan
scl_inter       : nan
slice_end       : 0
slice_code      : unknown
xyzt_units      : 10
cal_max         : 0.0
cal_min         : 0.0
slice_duration  : 0.0
toffset         : 0.0
glmax           : 2861
glmin           : -1024
descrip        

In [None]:
import torch
import torch.nn.functional as F

def window_partition_5d(x: torch.Tensor, window_size: int):
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, D, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * D * num_windows, window_size, window_size, C].
        (Dp, Hp, Wp): padded depth, height, and width before partition
    """
    B, D, H, W, C = x.shape

    pad_d = (window_size - D % window_size) % window_size
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_d > 0 or pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, 0, 0, pad_w, 0, pad_h, 0, pad_d))
    Dp, Hp, Wp = D + pad_d, H + pad_h, W + pad_w

    x = x.view(B, Dp // window_size, window_size, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 5, 4, 6, 7).contiguous().view(-1, window_size, window_size, C)
    return windows, (Dp, Hp, Wp)

x = torch.zeros((1, 96, 96, 96, 768))

window_partition_5d(x, 4)


In [None]:
import torch
from models.attention_unet import AttentionUNet

model = AttentionUNet()
x = torch.ones((1, 1, 96, 96, 96))
out = model(x)

In [4]:
import pickle
import numpy as np

score_path = "logs/diff-unet-msd-1/dice_epoch_500.pkl"

with open(score_path, 'rb') as f:
    dices = pickle.load(f)
    
print(dices)
num_classes = 2
scores = np.zeros((len(dices), num_classes))
for i, dice in enumerate(dices):
    dice = list(dice.values())
    scores[i, 0] = dice[0]
    scores[i, 1] = dice[1]
    
print(np.mean(scores, axis=0))
print(np.mean(scores))

[OrderedDict([('pancreas', 0.6825417876243591), ('cancer', 0.34127089381217957)]), OrderedDict([('pancreas', 0.6981605887413025), ('cancer', 0.34908029437065125)]), OrderedDict([('pancreas', 0.8605238795280457), ('cancer', 0.4739607274532318)]), OrderedDict([('pancreas', 0.5984222888946533), ('cancer', 0.4047772288322449)]), OrderedDict([('pancreas', 0.7038456797599792), ('cancer', 0.3519228398799896)]), OrderedDict([('pancreas', 0.6969820857048035), ('cancer', 0.6875232458114624)]), OrderedDict([('pancreas', 0.7937067151069641), ('cancer', 0.5360509753227234)]), OrderedDict([('pancreas', 0.4904691278934479), ('cancer', 0.26022183895111084)]), OrderedDict([('pancreas', 0.7019046545028687), ('cancer', 0.3509523272514343)]), OrderedDict([('pancreas', 0.7608415484428406), ('cancer', 0.3804207742214203)]), OrderedDict([('pancreas', 0.43491634726524353), ('cancer', 0.3177517354488373)]), OrderedDict([('pancreas', 0.6025535464286804), ('cancer', 0.3012767732143402)]), OrderedDict([('pancreas