In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data.sampler import WeightedRandomSampler

In [2]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
sample_weights = [(x+1)/55 for x in trainset.targets]
trainsampler = WeightedRandomSampler(sample_weights, 10000, replacement=False)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          sampler=trainsampler, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

class_frequency = [0] * 10
for iteration, batch in enumerate(trainloader):
    for class_idx in batch[1].tolist():
        class_frequency[class_idx] += 1

print(class_frequency)

Files already downloaded and verified
Files already downloaded and verified
[223, 399, 571, 815, 954, 1132, 1258, 1428, 1533, 1687]


In [3]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [4]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1),
                               -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(
                -1, input.size(2))  # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=-1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = logpt.detach().exp()

        if self.alpha is not None:
            assert False

        loss = -1 * (1 - pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()

class LADELoss(nn.Module):
    def __init__(self,
                 num_classes=10,
                 img_max=None,
                 prior=None,
                 prior_txt=None,
                 remine_lambda=0.1):
        super().__init__()
        if img_max is not None or prior_txt is not None:
            self.img_num_per_cls = calculate_prior(
                num_classes, img_max, prior, prior_txt,
                return_num=True).float()
            self.prior = self.img_num_per_cls / self.img_num_per_cls.sum()
        else:
            self.prior = None

        self.balanced_prior = torch.tensor(1. / num_classes).float()
        self.remine_lambda = remine_lambda

        self.num_classes = num_classes
        self.cls_weight = (self.img_num_per_cls.float() / torch.sum(self.img_num_per_cls.float()))

    def mine_lower_bound(self, x_p, x_q, num_samples_per_cls):
        N = x_p.size(-1)
        first_term = torch.sum(x_p, -1) / (num_samples_per_cls + 1e-8)
        second_term = torch.logsumexp(x_q, -1) - np.log(N)

        return first_term - second_term, first_term, second_term

    def remine_lower_bound(self, x_p, x_q, num_samples_per_cls):
        loss, first_term, second_term = self.mine_lower_bound(
            x_p, x_q, num_samples_per_cls)
        reg = (second_term**2) * self.remine_lambda
        return loss - reg, first_term, second_term

    def forward(self, y_pred, target, q_pred=None):
        """
        y_pred: N x C
        target: N
        """
        per_cls_pred_spread = y_pred.T * (target == torch.arange(
            0, self.num_classes).view(-1, 1).type_as(target))  # C x N
        pred_spread = (y_pred - torch.log(self.prior + 1e-9) +
                       torch.log(self.balanced_prior + 1e-9)).T  # C x N

        num_samples_per_cls = torch.sum(
            target == torch.arange(0,
                                   self.num_classes).view(-1,
                                                          1).type_as(target),
            -1).float()  # C
        estim_loss, first_term, second_term = self.remine_lower_bound(
            per_cls_pred_spread, pred_spread, num_samples_per_cls)

        loss = -torch.sum(estim_loss * self.cls_weight)
        return loss

def calculate_prior(num_classes, img_max=None, prior=None, prior_txt=None, reverse=False, return_num=False):
    if prior_txt:
        labels = []
        with open(prior_txt) as f:
            for line in f:
                labels.append(int(line.split()[1]))
        occur_dict = dict(Counter(labels))
        img_num_per_cls = [occur_dict[i] for i in range(num_classes)]
    else:
        img_num_per_cls = []
        for cls_idx in range(num_classes):
            if reverse:
                num = img_max * (prior ** ((num_classes - 1 - cls_idx) / (num_classes - 1.0)))
            else:
                num = img_max * (prior ** (cls_idx / (num_classes - 1.0)))
            img_num_per_cls.append(int(num))
    img_num_per_cls = torch.Tensor(img_num_per_cls)

    if return_num:
        return img_num_per_cls
    else:
        return img_num_per_cls / img_num_per_cls.sum()

In [5]:
ce_loss = nn.CrossEntropyLoss()
focal_loss = FocalLoss()
lade_loss = LADELoss(num_classes=10, prior=0.1, img_max=500)

In [6]:
net = Net()
optimizer = optim.Adam(net.parameters(), lr=0.001)

for epoch in range(10):

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = ce_loss(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print('[%d] loss: %.3f' %
          (epoch + 1, running_loss / len(trainloader)))

print('Finished Training')

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

[1] loss: 1.691
[2] loss: 1.426
[3] loss: 1.350
[4] loss: 1.269
[5] loss: 1.247
[6] loss: 1.217
[7] loss: 1.164
[8] loss: 1.155
[9] loss: 1.135
[10] loss: 1.103
Finished Training
Accuracy of the network on the 10000 test images: 47 %


In [7]:
net = Net()
optimizer = optim.Adam(net.parameters(), lr=0.001)

for epoch in range(10):

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = focal_loss(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print('[%d] loss: %.3f' %
          (epoch + 1, running_loss / len(trainloader)))

print('Finished Training')

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

[1] loss: 1.679
[2] loss: 1.386
[3] loss: 1.300
[4] loss: 1.261
[5] loss: 1.233
[6] loss: 1.194
[7] loss: 1.150
[8] loss: 1.123
[9] loss: 1.100
[10] loss: 1.096
Finished Training
Accuracy of the network on the 10000 test images: 51 %


In [9]:
net = Net()
optimizer = optim.Adam(net.parameters(), lr=0.001)

for epoch in range(10):

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = lade_loss(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print('[%d] loss: %.3f' %
          (epoch + 1, running_loss / len(trainloader)))

print('Finished Training')

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

[1] loss: -1.529
[2] loss: -1.614
[3] loss: -1.657
[4] loss: -1.679
[5] loss: -1.698
[6] loss: -1.696
[7] loss: -1.728
[8] loss: -1.729
[9] loss: -1.730
[10] loss: -1.742
Finished Training
Accuracy of the network on the 10000 test images: 52 %
