In [2]:
import numpy as np
import torch
from torchvision import transforms
from torchvision.ops import sigmoid_focal_loss


import torch.nn as nn
import torch.nn.functional as F


from torch.autograd import Variable

# https://d2l.ai/chapter_recurrent-modern/gru.html

import matplotlib.pyplot as plt

In [3]:
criterion_reg = nn.MSELoss()
criterion_class = nn.BCELoss()

In [3]:
class FocalLoss(nn.Module):
    def __init__(
            self,
            weight=None,
            gamma=2.0,
            reduction='mean'
    ):
        nn.Module.__init__(self)
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input_tensor, target_tensor):
        log_prob = torch.sigmoid(input_tensor)
        prob = torch.exp(log_prob)
        
        term = ((1.0 - prob) ** self.gamma) * log_prob
        return F.nll_loss(term, target_tensor)
        
        #return F.nll_loss(term, target_tensor, weight=self.weight, reduction=self.reduction)
        

In [4]:
class FocalLoss01(nn.Module):

    def __init__(self, gamma=0, weight=None, size_average=True):
        super(FocalLoss01, self).__init__()

        self.gamma = gamma
        self.weight = weight
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim()>2:
            input = input.contiguous().view(input.size(0), input.size(1), -1)
            input = input.transpose(1,2)
            input = input.contiguous().view(-1, input.size(2)).squeeze()
        if target.dim()==4:
            target = target.contiguous().view(target.size(0), target.size(1), -1)
            target = target.transpose(1,2)
            target = target.contiguous().view(-1, target.size(2)).squeeze()
        elif target.dim()==3:
            target = target.view(-1)
        else:
            target = target.view(-1, 1)

        # compute the negative likelyhood
        weight = Variable(self.weight)
        logpt = -F.cross_entropy(input, target)
        pt = torch.exp(logpt)

        # compute the loss
        loss = -((1-pt)**self.gamma) * logpt

        # averaging (or not) loss
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [5]:
class FocalLoss02(nn.Module):

    def __init__(self, gamma=0, weight=None, size_average=True):
        super(FocalLoss02, self).__init__()

        self.gamma = gamma
        self.weight = weight
        self.size_average = size_average

    def forward(self, input, target):
        # if input.dim()>2:
        #     input = input.contiguous().view(input.size(0), input.size(1), -1)
        #     input = input.transpose(1,2)
        #     input = input.contiguous().view(-1, input.size(2)).squeeze()
        # if target.dim()==4:
        #     target = target.contiguous().view(target.size(0), target.size(1), -1)
        #     target = target.transpose(1,2)
        #     target = target.contiguous().view(-1, target.size(2)).squeeze()
        # elif target.dim()==3:
        #     target = target.view(-1)
        # else:
        #     target = target.view(-1, 1)

        input, target = input.unsqueeze(0), target.unsqueeze(0)

        # compute.unsqueeze(0)F.cross_entropy(input, target)
        logpt = -F.cross_entropy(input, target)
        pt = torch.exp(logpt)

        # compute the loss
        loss = -((1-pt)**self.gamma) * logpt

        # averaging (or not) loss
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [6]:
class FocalLoss03(nn.Module):

    def __init__(self, gamma=0, alpha=1, size_average=True):
        super(FocalLoss03, self).__init__()

        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target):

        input, target = input.unsqueeze(0), target.unsqueeze(0)

        # compute.unsqueeze(0)F.cross_entropy(input, target)
        p = 1/(1+np.exp(-input))
        loss = -( self.alpha*target + (1-self.alpha)*(1-target) ) * (( 1 - ( target*p + (1-target)*(1-p)) )**self.gamma) * ( target*np.log(p)+(1-target)*np.log(1-p) )

        # averaging (or not) loss
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [7]:
class FocalLoss04(nn.Module):

    def __init__(self, gamma=0, alpha=1, size_average=True):
        super(FocalLoss04, self).__init__()

        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target):
        
        input, target = input.unsqueeze(0), target.unsqueeze(0)

        loss = -(target * np.log(input) + (1-target) * np.log(1-input)) #BCE

        # averaging (or not) loss
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [8]:
class FocalLoss05(nn.Module):

    def __init__(self, gamma=0, alpha=1, size_average=True):
        super(FocalLoss05, self).__init__()

        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target):

        input, target = input.unsqueeze(0), target.unsqueeze(0)
        input = torch.clamp(input, min = np.exp(-100)) # so we do not log(0)

        logpt = (target * np.log(input) + (1-target) * np.log(1-input))
        loss = -self.alpha * ((1-np.exp(logpt))**self.gamma) * logpt # for gamma = 0 and alpha = 1 we get the BCELoss

        # averaging (or not) loss
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [9]:
# BEST!!!

class FocalLoss05(nn.Module):

    def __init__(self, gamma=0, alpha=1, size_average=True):
        super(FocalLoss05, self).__init__()

        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target):

        input, target = input.unsqueeze(0), target.unsqueeze(0)
        input = torch.clamp(input, min = np.exp(-100)) # so we do not log(0)

        logpt = (target * np.log(input) + (1-target) * np.log(1-input))
        loss = -self.alpha * ((1-np.exp(logpt))**self.gamma) * logpt # for gamma = 0 and alpha = 1 we get the BCELoss

        # averaging (or not) loss
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [10]:
class FocalLossClass(nn.Module):

    def __init__(self, gamma=0, alpha=1, size_average=True):
        super(FocalLossClass, self).__init__()

        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target):

        input, target = input.unsqueeze(0), target.unsqueeze(0)
        input = torch.clamp(input, min = torch.exp(torch.tensor(-100))) # so we do not log(0)

        logpt = (target * torch.log(input) + (1-target) * torch.log(1-input))
        loss = -self.alpha * ((1-torch.exp(logpt))**self.gamma) * logpt # for gamma = 0 and alpha = 1 we get the BCELoss

        # averaging (or not) loss
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [11]:
# #WORKS!
# class BalancedFocalLossClass(nn.Module):

#     def __init__(self, gamma=0, alpha=0.5, size_average=True):
#         super(BalancedFocalLossClass, self).__init__()

#         self.gamma = gamma
#         self.alpha = alpha
#         self.size_average = size_average

#     def forward(self, input, target):

#         input, target = input.unsqueeze(0), target.unsqueeze(0)
#         input = torch.clamp(input, min = torch.exp(torch.tensor(-100))) # so we do not log(0)

#         pos = (-self.alpha * (1-input)**self.gamma * torch.log(input))
#         neg = (-(1-self.alpha) * (1-1-input)**self.gamma *  torch.log(1-input))

#         #pos = ( (1-input)**self.gamma * torch.log(input))
#         #neg = ( (input)**self.gamma *  torch.log(1-input))

#         loss = -(pos * target + neg * (1-target))

#         # averaging (or not) loss
#         if self.size_average:
#             return loss.mean()
#         else:
#             return loss.sum()

In [12]:
# class BalancedFocalLossClass(nn.Module):

#     def __init__(self, gamma=0, alpha=0.5, size_average=True):
#         super(BalancedFocalLossClass, self).__init__()

#         self.gamma = gamma
#         self.alpha = alpha
#         self.size_average = size_average

#     def forward(self, input, target):

#         input, target = input.unsqueeze(0), target.unsqueeze(0)
        

#         #for logits
#         # pos = (-self.alpha * (1-F.sigmoid(input))**self.gamma * F.logsigmoid(input))
#         # neg = (-(1-self.alpha) * (-F.sigmoid(input))**self.gamma *  F.logsigmoid(1-input))
#         # loss = (pos * target + neg * (1-target))

#         # for probs
#         input = torch.clamp(input, min = torch.exp(torch.tensor(-100))) # so we do not log(0)
#         pos = (-self.alpha * (1-input)**self.gamma * torch.log(input))
#         neg = (-(1-self.alpha) * (1-1-input)**self.gamma *  torch.log(1-input))
#         loss = (pos * target + neg * (1-target))

#         # averaging (or not) loss
#         if self.size_average:
#             return loss.mean()
#         else:
#             return loss.sum()

In [160]:
# class stableBalancedFocalLossClass(nn.Module):

#     def __init__(self, gamma=0, alpha=0.5, size_average=True):
#         super(stableBalancedFocalLossClass, self).__init__()

#         self.gamma = gamma
#         self.alpha = alpha
#         self.size_average = size_average

#     def forward(self, input, target):

#         input, target = input.unsqueeze(0), target.unsqueeze(0)

#         # fo   r probs
#         min_ind = torch.exp(torch.tensor(-100)) # almost 0
#         max_ind = torch.tensor(1.0)- torch.exp(torch.tensor(-10)) # almost 1
#         input = torch.clamp(input, min = min_ind, max = max_ind) # so we do not log(0)

#         pos = (-self.alpha * (1-input)**self.gamma * torch.log(input))
#         neg = (-(1-self.alpha) * (1-1-input)**self.gamma *  torch.log(1-input))
#         loss = (pos * target + neg * (1-target))

#         # Seem pytorch have something like this..
#         if loss.mean() >= max_ind:
#             floor = 10
#         else:
#             floor = 1

#         loss =  loss * 2 * floor # *2 is just a constant to make it more like BCE

#         # averaging (or not) lossinput = torch.clamp(input, min = torch.exp(torch.tensor(-100))) # so we do not log(0)
        
#         if self.size_average:
#             return loss.mean()
#         else:
#             return loss.sum()

In [12]:
def shannon_entropy(p):
    return (p * np.log(1/p)).sum()

def cross_entropy(p, q):
    return -(p * np.log(q)).sum() # same as (p * np.log(1/q)).sum() 

In [None]:
class BalancedFocalLossClass(nn.Module):

    def __init__(self, gamma=0, alpha=0.5, size_average=True):
        super(BalancedFocalLossClass, self).__init__()

        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target):

        input, target = input.unsqueeze(0), target.unsqueeze(0)
        
        # Numerical stabilityt pytorhc trick.
        log_input = torch.clamp(torch.log(input), -100, 100)
        log_input_rev = torch.clamp(torch.log(1-input), -100, 100)

        # for probs
        pos = (-self.alpha * (1-input)**self.gamma * log_input)
        neg = (-(1-self.alpha) * (1-1-input)**self.gamma * log_input_rev)
        
        loss = (pos * target + neg * (1-target))

        # averaging (or not) loss
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [13]:
criterion_class = nn.BCELoss()

x1 = torch.rand([ 3 ,100, 100])
x1_b = (x1 > 0.5)*1.0 

In [17]:
# criterion_focal01 = FocalLoss01(gamma=1)
# criterion_focal02 = FocalLoss02(gamma=1, size_average=False)
# criterion_focal03 = FocalLoss03(gamma=1, alpha = 1, size_average=True)

# criterion_focal04 = FocalLoss04(gamma=0)
# criterion_focal05 = FocalLoss05()

#criterion_focal06 = FocalLossClass()
criterion_focal07 = BalancedFocalLossClass(gamma=2, alpha=0.05)
# criterion_focal07 = stableBalancedFocalLossClass(gamma=0, alpha=0.5)


# print(criterion_focal01(x1.unsqueeze(0), x1_b.unsqueeze(0)))
# print(criterion_focal02(x1, x1_b))
# print(criterion_focal03(x1, x1_b))


# print(criterion_focal04(x1, x1_b))

# print(criterion_focal05(x1, x1_b))

#print(criterion_focal06(x1, x1_b))

print(criterion_focal07(x1, x1_b)) 


print(criterion_class(x1, x1_b))


tensor(0.0201)
tensor(0.3062)


In [30]:
criterion_focal07 = BalancedFocalLossClass(gamma=0, alpha=0.5)

print(criterion_focal07(x1, x1_b)) 
print(criterion_class(x1, x1_b))

tensor(0.1536)
tensor(0.3071)


In [44]:
x1 = torch.rand([ 3 ,100, 100]) 

x2 = torch.zeros([ 3 ,100, 100]) + torch.tensor(1.0) - torch.exp(torch.tensor(0))

x1_b = (x1 > 0.5)*1.0 


print(criterion_focal07(x2, x1_b)*2) 
print(criterion_class(x2, x1_b))

tensor(49.8000)
tensor(49.8000)


In [45]:
x2

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

In [19]:
x1 = torch.rand([ 3 ,100, 100]) 

x2 = torch.zeros([ 3 ,100, 100]) + torch.tensor(1.0) - torch.exp(torch.tensor(-12))

x1_b = (x1 > 0.5)*1.0 

if x2.mean() > torch.tensor(1.0) - torch.exp(torch.tensor(-15)):
    floor = 10
else:
    floor = 1

print(criterion_focal07(x2, x1_b) * floor) 
print(criterion_class(x2, x1_b))

tensor(5.6996)
tensor(5.9996)


In [36]:
print(criterion_class)

BCELoss()


In [849]:
criterion_focal_class01 = BalancedFocalLossClass(gamma=0.0, alpha=0.75)
criterion_focal_class02 = BalancedFocalLossClass(gamma=2.0, alpha=0.25)
criterion_focal_class03 = BalancedFocalLossClass(gamma=5.0, alpha=0.25)

In [878]:
print(f'focal, gamma = 0, alpha = 1 : \t {criterion_focal_class01(x1, x1_b)}')
print(f'focal, gamma = 1, alpha = 1 : \t {criterion_focal_class02(x1, x1_b)}')
print(f'focal, gamma = 2, alpha = 1 : \t {criterion_focal_class03(x1, x1_b)}')

focal, gamma = 0, alpha = 1 : 	 0.1558169275522232
focal, gamma = 1, alpha = 1 : 	 0.019834911450743675
focal, gamma = 2, alpha = 1 : 	 -0.0006891923840157688


In [872]:
F.logsigmoid(torch.tensor(.00))

tensor(-0.6931)

In [857]:
# targer
IT = torch.tensor(np.stack([np.identity(100), np.identity(100),  np.identity(100)]), dtype= torch.float32)

# candidates
ZT = torch.zeros([ 3 ,100, 100])
RT = torch.rand([ 3 ,100, 100]) 

# the loss must priorities 0T over ZT

In [863]:
criterion = BalancedFocalLossClass(gamma=2, alpha=0.95)


print(criterion(ZT, IT)) 
print(criterion(RT, IT)) 

tensor(0.9498)
tensor(0.0358)


In [362]:
class FocalLoss_reg(nn.Module):

    def __init__(self, gamma=0, alpha=1, size_average=True):
        super(FocalLoss_reg, self).__init__()

        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target):

        input, target = input.unsqueeze(0), target.unsqueeze(0)
        input = torch.clamp(input, min = np.exp(-100)) # so we do not log(0)

        mse = (target - input)**2

        loss = self.alpha * ((1-np.exp(mse))**self.gamma) * mse # for gamma = 0 and alpha = 1 we get the mse

        # averaging (or not) loss
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [353]:
criterion_reg = nn.MSELoss()

x1 = torch.rand([ 3 ,100, 100])
x1_b = (x1 > 0.5)*1.0 

In [500]:
criterion_focal_reg01 = FocalLoss_reg(gamma=0)
criterion_focal_reg02 = FocalLoss_reg(gamma=1)
criterion_focal_reg03 = FocalLoss_reg(gamma=2)

In [851]:
# targer
IT = torch.tensor(np.stack([np.identity(100), np.identity(100),  np.identity(100)]), dtype= torch.float32)

# candidates
ZT = torch.zeros([ 3 ,100, 100])
OT = torch.zeros([ 3 ,100, 100]) + 1

# the loss must priorities 0T over ZT

In [501]:
print(f'MSE: \t\t\t\t {criterion_reg(ZT, IT)}')
print(f'focal, gamma = 0, alpha = 1 : \t {criterion_focal_reg01(ZT, IT)}')
print(f'focal, gamma = 1, alpha = 1 : \t {criterion_focal_reg02(ZT, IT)}')
print(f'focal, gamma = 2, alpha = 1 : \t {criterion_focal_reg03(ZT, IT)}')


print('\n')

print(f'focal, gamma = 1, alpha = 1 : \t {criterion_focal_reg01(OT, IT)}')
print(f'focal, gamma = 2, alpha = 1 : \t {criterion_focal_reg02(OT, IT)}')
print(f'focal, gamma = 3, alpha = 1 : \t {criterion_focal_reg03(OT, IT)}')



MSE: 				 0.3157702684402466
focal, gamma = 0, alpha = 1 : 	 24.743696212768555
focal, gamma = 1, alpha = 1 : 	 -inf
focal, gamma = 2, alpha = 1 : 	 inf


focal, gamma = 1, alpha = 1 : 	 -0.33513134717941284
focal, gamma = 2, alpha = 1 : 	 -10.566515922546387
focal, gamma = 3, alpha = 1 : 	 586.1552124023438


In [365]:
def MSE(input, target):
    loss = (target - input)**2
    return loss.mean()

In [415]:
def MSE(input, target):
    se = target - input
    lse = np.exp(se)  #torch.clamp(se, min = np.exp(-100)))
    mlse = lse.mean()

    return mlse

    # return - np.log(loss.mean())

In [421]:
noise = torch.rand([ 3 ,100, 100]) * 10

# targer
IT = torch.tensor(np.stack([np.identity(100), np.identity(100),  np.identity(100)]), dtype= torch.float32) * noise

# candidates
ZT = torch.zeros([ 3 ,100, 100])
OT = (torch.zeros([ 3 ,100, 100]) + 1 ) * noise

# the loss must priorities 0T over ZT

In [422]:
print(MSE(ZT, IT))
print(MSE(OT, IT))

tensor(19.1620)
tensor(0.1097)


In [418]:
OT.min()

tensor(0.0006)

In [426]:
class FocalLoss_reg(nn.Module):

    def __init__(self, gamma=0, alpha=1, size_average=True):
        super(FocalLoss_reg, self).__init__()

        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target):

        input, target = input.unsqueeze(0), target.unsqueeze(0)
        #input = torch.clamp(input, min = np.exp(-100)) # could do this for no negatives???

        error = target - input
        exp_error = np.exp(error)  #torch.clamp(se, min = np.exp(-100)))
        loss = exp_error.mean()

        # averaging (or not) loss
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [450]:
class ShrinkageLoss(nn.Module):

    def __init__(self, a=10, c=0.2, size_average=True):
        super(ShrinkageLoss, self).__init__()

        self.a = a
        self.c = c
        self.size_average = size_average

    def forward(self, input, target):

        input, target = input.unsqueeze(0), target.unsqueeze(0) 

        l = torch.abs(target - input)     #F.l1_loss(input, target)

        loss = (l**2)/(1 + torch.exp(self.a*(self.c-l)))

        # averaging (or not) loss
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [451]:
criterion_focal_reg = ShrinkageLoss(a=10, c=0.2)

In [461]:
criterion_focal_reg = ShrinkageLoss(a=1, c=12)

print(criterion_focal_reg(ZT, IT))
print(criterion_focal_reg(OT, IT))


tensor(0.0083)
tensor(1.0195)


In [330]:
x1_b.dtype

torch.float32

In [332]:
IT

tensor([[[1., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.]],

        [[1., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.]],

        [[1., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.]]])

In [201]:
shannon_entropy(x1)

tensor(7491.8901)

In [220]:
cross_entropy(x1_b, x1)

tensor(4593.0488)

In [224]:
cross_entropy(x1_b, x1)

tensor(4593.0488)

In [223]:
F.cross_entropy( x1_b, x1)

tensor(223.4796)

In [225]:
F.cross_entropy?

[0;31mSignature:[0m
[0mF[0m[0;34m.[0m[0mcross_entropy[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0minput[0m[0;34m:[0m [0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtarget[0m[0;34m:[0m [0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mweight[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0msize_average[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mbool[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mignore_index[0m[0;34m:[0m [0mint[0m [0;34m=[0m [0;34m-[0m[0;36m100[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreduce[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mbool[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreduction[0m[0;34m:[0m [0mstr[0m [0;34m=[0m [0;34m'mean'[0m[0;34m,[0m[0;34m

In [196]:
-(x1_b * np.log(x1)).mean()

tensor(0.1531)

In [197]:
(x1_b * (np.log(1/x1))).mean()

tensor(0.1531)

In [96]:
print(criterion_focal02(x1, x1_b))
print(criterion_class(x1, x1_b))


tensor(1.0988)
tensor(0.3070)


In [3]:
class RMSLELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        
    def forward(self, pred, actual):
        return torch.sqrt(self.mse(torch.log(pred + 1), torch.log(actual + 1)))

In [13]:
t = 3

if t == 1 or t == 0:
    print('cool')

In [7]:
cr.

RMSLELoss(
  (mse): MSELoss()
)

In [10]:
cr = RMSLELoss()
print(f'sgoairb: {cr}\n..................')

sgoairb: RMSLELoss(
  (mse): MSELoss()
)
..................


In [24]:
x1 = torch.rand([1, 3 ,100, 100])
x1_b = (x1 > 0.5)*1.0 

In [28]:
x1 = torch.rand([3,100, 100])
x1_b = (x1 > 0.5)*1.0 

In [25]:
criterion_focal = FocalLoss2d()

In [29]:
print(criterion_focal)

FocalLoss2d()


In [35]:
criterion_focal(x1.unsqueeze(0), x1_b.unsqueeze(0))

tensor(1.4365)

In [11]:
x1 = torch.rand([1, 3 ,100, 100])
x1_b = (x1 > 0.5)*1.0 

sigmoid_focal_loss(x1, x1_b, reduction= 'mean')

tensor(0.1051)

In [100]:
x1 = torch.rand([1, 3 ,100, 100])#.type(torch.LongTensor)
x1_b = (x1 > 0.5)*1.0 

criterion_focal(x1.reshape(-1), x1_b.reshape(-1).type(torch.LongTensor))

tensor(-0.2466)

tensor(-0.3178)

In [None]:
x1 = torch.randn([1, 3 ,100, 100]).float()

#x1 = torch.randn([1, 3 ,100, 100]).float()
x2 = torch.randn([1, 3 ,100, 100]).float()

x3 = torch.randn([1, 3 ,100, 100]).float()
x4 = torch.randn([1, 3 ,100, 100]).float()

In [42]:
# x1_ = x1.reshape(-1)
# x2_ = x2.reshape(-1)

# mask = x1_ > 0

# x1_[mask].shape

# criterion_reg(x1_, x2_)


tensor(1.9961)

In [50]:
type(losses_list) == list

True

In [55]:
losses_list = []

for i in range(3):

    x1_ = x1[:,i,:,:].reshape(-1)
    x2_ = x2[:,i,:,:].reshape(-1)
    mask = (x3[:,i,:,:].reshape(-1) > 0.0001) | (x4[:,i,:,:].reshape(-1) > 0.0001)

    losses_list.append(criterion_reg(x1_[mask], x2_[mask]))


In [49]:
losses_list

[tensor(2.0036), tensor(1.9821), tensor(1.9411)]

In [None]:
for i in range(3):
    losses_list.append(criterion_class(t1_pred_class[:,i,:,:], t1_binary[:,i,:,:]))

In [37]:
losses_list = []

for i in range(3):
    print(i)
    losses_list.append(torch.tensor(1.0))

for i in range(3):
    losses_list.append(torch.tensor(2.0))


0
1
2


In [38]:
losses = torch.stack(losses_list)

In [39]:
losses[:3].sum()

tensor(3.)

In [40]:
losses[-3:].sum()


tensor(6.)

In [41]:
for i in range(6):
    print(losses[i])

tensor(1.)
tensor(1.)
tensor(1.)
tensor(2.)
tensor(2.)
tensor(2.)


In [46]:
hidden_channels = 32
D = 16 


X = torch.rand(1, 3, D, D)
H = torch.rand(1, hidden_channels, D, D)

In [48]:
H+X

RuntimeError: The size of tensor a (32) must match the size of tensor b (3) at non-singleton dimension 1

In [61]:
c1 = torch.nn.Conv2d(32,32, 3, padding= 'same')
c2 = torch.nn.Conv2d(32,32, 3, padding= 'same', bias = False)

c3 = torch.nn.Conv2d(32,32, 3, padding= 'same')
c4 = torch.nn.Conv2d(32,32, 3, padding= 'same', bias = False)

c5 = torch.nn.Conv2d(32,32, 3, padding= 'same')
c6 = torch.nn.Conv2d(32,32, 3, padding= 'same', bias = False)

In [63]:
Z = torch.sigmoid(c1(X) + c2(H))

R = torch.sigmoid(c3(X) + c4(H))

H_tilde = torch.tanh(c5(X) + c6(torch.mul(R,H)))

H = torch.mul(torch.mul(Z,H) + (1 - Z), H_tilde)

In [65]:
H.shape

torch.Size([1, 32, 16, 16])

In [7]:
vol = np.zeros([1,48,3,180,180])
tens  =torch.tensor(vol)

torch.stack((tens,tens)).shape

torch.Size([2, 1, 48, 3, 180, 180])

In [2]:
vol = np.zeros([1,48,3,180,180])

In [64]:
vol[:,:,0,:,:] = 1
vol[:,:,1,:,:] = 2
vol[:,:,2,:,:] = 3


In [65]:
print(vol.shape)
print(vol[:,:,0,:,:].mean())
print(vol[:,:,1,:,:].mean())
print(vol[:,:,2,:,:].mean())

(1, 48, 3, 180, 180)
1.0
2.0
3.0


In [79]:
N = vol.shape[0] # batch size. Always 1
C = vol.shape[1] # months
D = vol.shape[2] # features
H = vol.shape[3] # height
W = vol.shape[4] # width

vol2 = vol.reshape(N, C*D, H, W)
print(vol2.shape)

(1, 144, 180, 180)


In [78]:
# transformer = transforms.Compose([transforms.RandomRotation((0,360)), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomVerticalFlip(p=0.5)])
transformer = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5), transforms.RandomVerticalFlip(p=0.5)])

In [80]:
#data augmentation (can be turned of for final experiments)        
vol2 = transformer(torch.tensor(vol2)) # rotations and flips # skip for now... '''''''''''''''''''''''''''''''''''''''''''''''''''''' bug only take 4 dims.. could just squezze the batrhc dom and then give it again afterwards?#train_tensor = train_tensor.reshape(N, C, D, H, W)

In [81]:
vol3 = vol2.reshape(N,C,D,H,W)
print(vol3.shape)
print(vol3[:,:,0,:,:].mean())
print(vol3[:,:,1,:,:].mean())
print(vol3[:,:,2,:,:].mean())

torch.Size([1, 48, 3, 180, 180])
tensor(1., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(3., dtype=torch.float64)


In [42]:
(vol3 == vol).all()

True