In [1]:
import os
import glob
import shutil


In [11]:
path = "logs/amos_ver2/weights/*.pt"

for w in glob.glob(path):
    num = int(w.split("_")[-1].strip(".pt"))
    if num % 10 != 0:
        os.remove(w)

In [18]:

import torch
from torch.nn.modules.loss import _Loss


class MultiNeighborLoss(_Loss):
    def __init__(self, 
                 num_classes: int, 
                 reduction: str = "mean", 
                 centroid_method: str = "mean"):
        super(MultiNeighborLoss, self).__init__()
        self.num_classes = num_classes
        self.reduction = reduction
        self.centroid_method = centroid_method
        self.max_angles = self.num_classes * (self.num_classes - 1) // 2
        
    def forward(self, probs: torch.Tensor, labels: torch.Tensor):
        assert probs.ndim == labels.ndim == 5, "The dimensions of probs and labels should be same and 5."
        
        delta = []
        for i in range(probs.size(0)):
            p_angles, l_angles = self.compute_angles(torch.sigmoid(probs[i, ...])), self.compute_angles(labels[i, ...])
            delta.append(torch.square(p_angles - l_angles))
        
        delta = torch.cat(delta)
        nans = torch.isnan(delta)
        not_nans = ~nans
        print(delta)
        print(delta[not_nans])
        if self.reduction == "mean":
            return 0
        
    def compute_angles(self, t: torch.Tensor) -> torch.Tensor:
        idx = 0 
        angles = torch.zeros(self.max_angles).to(t.device)
        centroids = [None for _ in range(self.num_classes)]
        
        t = torch.argmax(t, dim=0)
        
        for i in range(self.num_classes):
            if i == 0: continue # do not consider backgrounds
            z, y, x = torch.where(t == i)
            centroids[i] = torch.stack(self.compute_centroids(x, y, z))
        
        for i in range(self.num_classes):
            for j in range(i+1, self.num_classes):
                if centroids[i] is not None and centroids[j] is not None:
                    m, n = centroids[i], centroids[j] # 2 vectors to calculate angles
                    angle = torch.acos(torch.dot(m, n) / (torch.norm(m) * torch.norm(n)))
                    
                    # if torch.isnan(angle):
                    #     angle = 0 # torch.randn((1, )).to(x.device)
                    
                    angles[idx] = angle
                    idx += 1
                    
        return angles
    
    def compute_centroids(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        if self.centroid_method == "mean":
            return [torch.mean(x.float()), torch.mean(y.float()), torch.mean(z.float())]
        else:
            raise NotImplementedError(f"The centroid method is not supported. : {self.centroid_method}")

In [19]:
num_classes = 13
device = torch.device("cuda:1")
loss = MultiNeighborLoss(num_classes)
for _ in range(100):
    probs = torch.rand((1, num_classes, 96, 96, 96)).to(device) * 0.0001
    labels = torch.randint(0, 1, (1, num_classes, 96, 96, 96)).to(device)
    
    l = loss(probs, labels)
    
    # print(f"loss : {l:.4f}")

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
tensor([0., 0., 0., 0., 0., 0., 0., 0.