In [1]:
import os
import glob
import shutil


In [11]:
path = "logs/amos_ver2/weights/*.pt"

for w in glob.glob(path):
    num = int(w.split("_")[-1].strip(".pt"))
    if num % 10 != 0:
        os.remove(w)

In [105]:

import torch
from torch.nn.modules.loss import _Loss


class MultiNeighborLoss(_Loss):
    def __init__(self, 
                 num_classes: int, 
                 reduction: str = "mean", 
                 centroid_method: str = "mean"):
        super(MultiNeighborLoss, self).__init__()
        self.num_classes = num_classes
        self.reduction = reduction
        self.centroid_method = centroid_method
        self.max_count = self.num_classes * (self.num_classes - 1) // 2
        
    def forward(self, probs: torch.Tensor, labels: torch.Tensor):
        assert probs.ndim == labels.ndim == 5, "The dimensions of probs and labels should be same and 5."
        
        delta = []
        for i in range(probs.size(0)):
            p_angles, l_angles = self.compute_angles(torch.sigmoid(probs[i, ...])), self.compute_angles(labels[i, ...])
            delta.append(torch.square(p_angles - l_angles))
        
        delta = torch.cat(delta)
        not_nans = ~torch.isnan(delta)
        delta = delta[not_nans]
        delta = delta[delta > 0]
        
        print(delta)
        if self.reduction == "mean":
            return torch.mean(delta)
        
    def compute_angles(self, t: torch.Tensor) -> torch.Tensor:
        angles = torch.zeros(self.max_count*self.max_count).to(t.device)
        vectors = torch.zeros(self.max_count, 3).to(t.device)
        centroids = torch.zeros((self.num_classes, 3)).to(t.device)
        
        t = torch.argmax(t, dim=0)
        
        for i in range(self.num_classes):
            z, y, x = torch.where(t == i)
            centroids[i] = torch.stack(self.compute_centroids(x, y, z))
        
        idx = 0
        for i in range(self.num_classes):
            for j in range(i+1, self.num_classes):
                vectors[idx] = centroids[j] - centroids[i]
                idx += 1
    
        idx = 0
        for i in range(self.max_count):
            m = vectors[i]
            for j in range(i+1, self.max_count):
                n = vectors[j]
                angle = torch.acos(torch.dot(m, n) / (torch.norm(m) * torch.norm(n)))
                angles[idx] = angle
                idx += 1
                
        return angles
    
    def compute_centroids(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        if self.centroid_method == "mean":
            return [torch.mean(x.float()), torch.mean(y.float()), torch.mean(z.float())]
        else:
            raise NotImplementedError(f"The centroid method is not supported. : {self.centroid_method}")

In [107]:
num_classes = 16
device = torch.device("cuda:1")
loss = MultiNeighborLoss(num_classes)
for _ in range(100):
    probs = torch.randint(0, 16, (1, num_classes, 96, 96, 96)).to(device)
    labels = torch.randint(0, num_classes, (1, num_classes, 96, 96, 96)).to(device)
    
    l = loss(probs, labels)
    
    print(f"loss : {l:.4f}")

tensor([1.7866e-01, 7.4428e-07, 1.5919e-02,  ..., 2.1335e+00, 2.3881e+00,
        7.1768e-03], device='cuda:1')
loss : 0.8479
tensor([0.2745, 0.6340, 0.0576,  ..., 0.1568, 0.2313, 0.7691], device='cuda:1')
loss : 0.7851
tensor([0.0354, 0.0464, 0.1652,  ..., 1.4588, 3.6913, 0.5091], device='cuda:1')
loss : 0.8139
tensor([2.9965e-01, 3.1527e+00, 4.7777e-05,  ..., 9.2324e-02, 1.7611e-02,
        2.9290e-02], device='cuda:1')
loss : 0.9489
tensor([0.1273, 0.3987, 1.8056,  ..., 1.3059, 0.0624, 0.7973], device='cuda:1')
loss : 0.8713
tensor([2.6181, 0.0215, 0.1153,  ..., 0.0474, 0.5314, 0.8961], device='cuda:1')
loss : 0.8205
tensor([0.6195, 1.7064, 0.4644,  ..., 0.0512, 0.3895, 0.1583], device='cuda:1')
loss : 0.9941
tensor([4.0904, 0.7928, 0.6264,  ..., 0.0153, 0.0919, 0.1823], device='cuda:1')
loss : 0.8226
tensor([0.7202, 0.0067, 0.8958,  ..., 0.0260, 0.0068, 0.0593], device='cuda:1')
loss : 0.8866
tensor([0.0357, 1.5144, 0.1976,  ..., 1.0550, 1.7174, 0.0803], device='cuda:1')
loss : 0.8