In [1]:
!nvidia-smi

Sat Jun 19 05:36:53 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    23W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm import tqdm
from collections import deque
import numpy as np


class Net(nn.Module):
    """Small CNN that maps a 1-channel 28x28 image to 10 class logits.

    Three stages of two unpadded 3x3 convolutions each; every convolution
    shrinks the spatial size by 2 (28 -> 26 -> 24 -> 22 -> 20 -> 18 -> 16),
    so the final feature map is 32 x 16 x 16, which is what ``fc1`` expects.
    Each stage ends with batch normalisation followed by dropout (p=0.2).
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 32 -> 32 channels.
        self.conv1_1 = nn.Conv2d(1, 32, 3, 1)
        self.conv1_2 = nn.Conv2d(32, 32, 3, 1)
        self.bn1 = nn.BatchNorm2d(num_features=32)
        # Stage 2: 32 -> 64 -> 64 channels.
        self.conv2_1 = nn.Conv2d(32, 64, 3, 1)
        self.conv2_2 = nn.Conv2d(64, 64, 3, 1)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        # Stage 3: 64 -> 64 -> 32 channels.
        self.conv3_1 = nn.Conv2d(64, 64, 3, 1)
        self.conv3_2 = nn.Conv2d(64, 32, 3, 1)
        self.bn3 = nn.BatchNorm2d(num_features=32)
        # Classifier over the flattened 32 x 16 x 16 feature map.
        self.fc1 = nn.Linear(32 * 16 * 16, 10)

    def forward(self, x):
        """Return raw (un-activated) logits of shape (batch, 10).

        Args:
            x: image batch of shape (batch, 1, 28, 28).
        """
        # F.dropout defaults to training=True; it must be tied to the module's
        # mode explicitly, otherwise dropout stays active after model.eval().
        x = self.bn1(F.relu(self.conv1_2(F.relu(self.conv1_1(x)))))
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.bn2(F.relu(self.conv2_2(F.relu(self.conv2_1(x)))))
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.bn3(F.relu(self.conv3_2(F.relu(self.conv3_1(x)))))
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.fc1(x.view([-1, 32 * 16 * 16]))
        return x


# Mini-batch size shared by the training and test loaders.
BATCH_SIZE = 128

# Convert images to tensors and normalise with the mean/std constants below.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# MNIST train / test splits, downloaded into ./data on first use.
dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = DataLoader(dataset1, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset2, batch_size=BATCH_SIZE, shuffle=False)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# nn.Module.to moves parameters in place and returns the module, so a single
# call is sufficient.
model = Net().to(device)

# Adam with its default hyper-parameters.
optimizer = torch.optim.Adam(model.parameters())


def calc_circle_loss(y_predict, y_target):
    """Multi-label loss built from two log-sum-exp terms with a zero threshold.

    For each sample the loss is
    ``log(1 + sum_neg exp(s)) + log(1 + sum_pos exp(-s))``, which pushes
    logits of target classes above 0 and all other logits below 0.

    Args:
        y_predict: logits; classes are on the last dimension.
        y_target: 0/1 indicator tensor with the same shape as ``y_predict``.

    Returns:
        Scalar tensor: the loss averaged over all leading dimensions.
    """
    # Optional logit scaling, currently disabled: y_predict = 10 * y_predict
    big = 1e5
    is_positive = y_target.bool()

    # Extra column of zeros acts as the fixed threshold score in both terms.
    zero_col = torch.zeros_like(y_predict[..., :1])

    # Hide positives from the negative term (and vice versa) by filling them
    # with a value whose exponential underflows to zero inside logsumexp.
    neg_scores = torch.cat([y_predict.masked_fill(is_positive, -big), zero_col], dim=-1)
    pos_scores = torch.cat([y_predict.masked_fill(~is_positive, big), zero_col], dim=-1)

    per_sample = torch.logsumexp(neg_scores, dim=-1) + torch.logsumexp(-pos_scores, dim=-1)
    return per_sample.mean()


def calc_bce_loss(y_predict, y_target):
    """Element-wise sigmoid binary cross-entropy on logits (mean-reduced).

    Args:
        y_predict: raw logits.
        y_target: 0/1 indicator tensor of the same shape; cast to float here.

    Returns:
        Scalar tensor with the mean loss over all elements.
    """
    targets_as_float = y_target.float()
    return F.binary_cross_entropy_with_logits(y_predict, targets_as_float)


def _recall_precision(y_predict, y_target):
    """Return (recall, precision) tensors for one batch.

    A class counts as predicted when its logit is strictly greater than 0.
    The 1e-5 added to each denominator keeps the ratio finite when there are
    no targets or no positive predictions in the batch.
    """
    predicted_positive = torch.gt(y_predict, 0)
    true_positive = torch.sum(torch.logical_and(predicted_positive, y_target))
    recall = true_positive / (torch.sum(y_target) + 1e-5)
    precision = true_positive / (torch.sum(predicted_positive) + 1e-5)
    return recall, precision


for epoch in range(10):
    # ---- train ----
    model.train()
    # Running means over the most recent 100 batches only.
    loss_count = deque([], maxlen=100)
    recall_count = deque([], maxlen=100)
    precision_count = deque([], maxlen=100)
    pbar = tqdm(train_loader, position=0, leave=True)
    pbar.set_description("train epoch {}".format(epoch))
    for data, y_target in pbar:
        # data: (batch, 1, 28, 28); y_target: (batch, 10) one-hot labels.
        data, y_target = data.to(device), F.one_hot(y_target, 10).to(device)
        optimizer.zero_grad()

        y_predict = model(data)
        # Swap in calc_bce_loss(y_predict, y_target) to compare against BCE.
        loss = calc_circle_loss(y_predict, y_target)

        loss.backward()
        optimizer.step()
        recall, precision = _recall_precision(y_predict, y_target)

        loss_count.append(loss.item())
        recall_count.append(recall.item())
        precision_count.append(precision.item())

        log_str = "loss={},recall={},precision={}".format(np.mean(loss_count), np.mean(recall_count), np.mean(precision_count))
        pbar.set_postfix_str(log_str)

    # ---- test ----
    model.eval()
    # Means over the whole test pass (no sliding window here).
    loss_count = []
    recall_count = []
    precision_count = []
    pbar = tqdm(test_loader, position=0, leave=True)
    pbar.set_description("test epoch {}".format(epoch))
    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory and time.
    with torch.no_grad():
        for data, y_target in pbar:
            data, y_target = data.to(device), F.one_hot(y_target, 10).to(device)
            y_predict = model(data)

            # Swap in calc_bce_loss(y_predict, y_target) to compare against BCE.
            loss = calc_circle_loss(y_predict, y_target)

            recall, precision = _recall_precision(y_predict, y_target)

            loss_count.append(loss.item())
            recall_count.append(recall.item())
            precision_count.append(precision.item())

            log_str = "loss={},recall={},precision={}".format(np.mean(loss_count), np.mean(recall_count), np.mean(precision_count))
            pbar.set_postfix_str(log_str)


train epoch 0: 100%|██████████| 469/469 [00:13<00:00, 34.33it/s, loss=0.13030975576490164,recall=0.9774998813867569,precision=0.9799993526935578]
test epoch 0: 100%|██████████| 79/79 [00:01<00:00, 43.85it/s, loss=0.1073021200320483,recall=0.9817047805725774,precision=0.9817139498795135]
train epoch 1: 100%|██████████| 469/469 [00:13<00:00, 34.67it/s, loss=0.10946824367158114,recall=0.9824738395214081,precision=0.9858382046222687]
test epoch 1: 100%|██████████| 79/79 [00:01<00:00, 43.35it/s, loss=0.1145825153761956,recall=0.9867482932308053,precision=0.9766291983519928]
train epoch 2: 100%|██████████| 469/469 [00:13<00:00, 34.38it/s, loss=0.08248568762093783,recall=0.9857811313867569,precision=0.986984156370163]
test epoch 2: 100%|██████████| 79/79 [00:01<00:00, 40.49it/s, loss=0.08956730428823753,recall=0.9846715519699869,precision=0.9865026677711101]
train epoch 3: 100%|██████████| 469/469 [00:13<00:00, 34.84it/s, loss=0.07188060081098228,recall=0.9878123813867569,precision=0.98820108

KeyboardInterrupt: ignored