In [1]:
import sys
sys.path.insert(0,'PATH')

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn



import torchvision
import torchvision.transforms as transforms
from models import resnet_rotnet


import os
import argparse


import numpy as np

from sklearn.metrics import roc_curve, roc_auc_score
from sklearn import metrics

In [None]:

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')


In [4]:
from datasets.osr_dataloader import MNIST_OSR, CIFAR10_OSR, CIFAR100_OSR, SVHN_OSR, Tiny_ImageNet_OSR

In [5]:
class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    The class centers are an ``nn.Parameter``; they only move during training
    if ``CenterLoss.parameters()`` is handed to an optimizer.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
        use_gpu (bool): create the centers on the GPU. Silently downgraded to
            ``False`` when CUDA is not available, so the default works on
            CPU-only machines too.
    """
    def __init__(self, num_classes=10, feat_dim=512, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # Never call .cuda() on a host without CUDA; that raises at construction time.
        self.use_gpu = bool(use_gpu) and torch.cuda.is_available()

        centers = torch.randn(self.num_classes, self.feat_dim)
        if self.use_gpu:
            centers = centers.cuda()
        self.centers = nn.Parameter(centers)

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).

        Returns:
            Scalar tensor: sum over the batch of the squared distance between
            each feature and the center of its class, divided by batch_size.
        """
        batch_size = x.size(0)
        # ||x||^2 + ||c||^2 for every (sample, class) pair ...
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # ... minus 2 * x.c. The positional (beta, alpha, mat1, mat2) form of
        # addmm_ is no longer accepted by PyTorch, so use the keyword form.
        distmat = torch.addmm(distmat, x, self.centers.t(), beta=1, alpha=-2)

        # Build the one-hot mask on the same device as the distances so that
        # CPU labels / GPU features (or the reverse) do not raise.
        classes = torch.arange(self.num_classes, device=distmat.device).long()
        labels = labels.to(distmat.device).unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        # The clamp is applied after masking (as in the reference implementation),
        # so every masked-out entry contributes 1e-12 to the sum.
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
        return loss

In [6]:
def get_distance_matrix(
        embeddings: torch.Tensor,  #  [B, E]
    ):
    """Return the [B, B] matrix of pairwise Euclidean distances between rows.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 on the Gram
    matrix. Negative values produced by rounding are clipped to zero, and
    1e-16 is added under the square root, so a zero distance comes out as
    1e-8 rather than exactly 0.
    """
    n = embeddings.size(0)
    gram = torch.matmul(embeddings, embeddings.T)  # [B, B]
    sq_norms = gram.diagonal()  # [B], squared length of every row
    row_term = sq_norms.reshape(1, n)
    col_term = sq_norms.reshape(n, 1)
    sq_dist = row_term - 2.0 * gram + col_term  # [B, B]
    clipped = torch.relu(sq_dist)
    return (clipped + 1e-16).sqrt()  # [B, B]
  
def get_positive_mask(
        labels: torch.Tensor,  # [B]
        device: torch.device
    ):
    """Boolean [B, B] mask that is True where i != j and labels[i] == labels[j].

    Args:
        labels: 1-D tensor of class labels.
        device: kept for backward compatibility; the mask is always built on
            ``labels.device`` so it can be combined with the label comparison
            without a device mismatch (the previous hard-coded ``.cuda()``
            failed on CPU-only hosts and on CPU label tensors).
    """
    B = labels.size(0)
    labels_equal = labels.view(1, B) == labels.view(B, 1)  # [B, B]
    indices_equal = torch.eye(B, dtype=torch.bool, device=labels.device)  # [B, B]
    return labels_equal & ~indices_equal  # [B, B]

def get_negative_mask(
        labels: torch.Tensor,  # [B]
        device: torch.device
    ):
    """Boolean [B, B] mask that is True where labels[i] != labels[j].

    Args:
        labels: 1-D tensor of class labels.
        device: kept for backward compatibility; the mask is always built on
            ``labels.device`` so it can be combined with the label comparison
            without a device mismatch (the previous hard-coded ``.cuda()``
            failed on CPU-only hosts and on CPU label tensors).
    """
    B = labels.size(0)
    labels_equal = labels.view(1, B) == labels.view(B, 1)  # [B, B]
    indices_equal = torch.eye(B, dtype=torch.bool, device=labels.device)  # [B, B]
    return ~labels_equal & ~indices_equal  # [B, B]

def get_triplet_mask(
        labels: torch.Tensor,  # [B]
        device: torch.device
    ):
    """Boolean [B, B, B] mask of valid (anchor i, positive j, negative k) triplets.

    An entry is True when i, j, k are pairwise distinct, labels[i] == labels[j]
    and labels[i] != labels[k].

    Args:
        labels: 1-D tensor of class labels.
        device: kept for backward compatibility; the mask is always built on
            ``labels.device`` (the previous hard-coded ``.cuda()`` failed on
            CPU-only hosts and on CPU label tensors).
    """

    B = labels.size(0)

    # Make sure that i != j != k
    indices_equal = torch.eye(B, dtype=torch.bool, device=labels.device)  # [B, B]
    indices_not_equal = ~indices_equal  # [B, B]
    i_not_equal_j = indices_not_equal.view(B, B, 1)  # [B, B, 1]
    i_not_equal_k = indices_not_equal.view(B, 1, B)  # [B, 1, B]
    j_not_equal_k = indices_not_equal.view(1, B, B)  # [1, B, B]
    distinct_indices = i_not_equal_j & i_not_equal_k & j_not_equal_k  # [B, B, B]

    # Make sure that labels[i] == labels[j] but labels[i] != labels[k]
    labels_equal = labels.view(1, B) == labels.view(B, 1)  # [B, B]
    i_equal_j = labels_equal.view(B, B, 1)  # [B, B, 1]
    i_equal_k = labels_equal.view(B, 1, B)  # [B, 1, B]
    valid_labels = i_equal_j & ~i_equal_k  # [B, B, B]

    return distinct_indices & valid_labels  # [B, B, B]

def test_get_distance_matrix(device_for_tests):
    """Sanity-check get_distance_matrix on three 2-D points, two of them identical."""
    points = torch.FloatTensor([[1, 1], [7, 7], [1, 1]]).to(device=device_for_tests)
    dists = get_distance_matrix(points)

    # Every point is at (numerically) zero distance from itself.
    expected_diagonal = torch.zeros(3, device=device_for_tests)
    assert torch.allclose(torch.diag(dists), expected_diagonal)

    # Distances are symmetric.
    assert torch.allclose(dists, dists.T)

    # The duplicate point is closer to point 0 than the far-away point is.
    assert dists[0, 2] < dists[0, 1]


def test_get_positive_mask(device_for_tests):
    """Only distinct samples that share a label may be marked as positives."""
    # Build the labels on the device under test; creating them on the CPU
    # regardless of `device_for_tests` mixes devices with the mask helper.
    labels = torch.tensor([1, 2, 3, 1], dtype=torch.long, device=device_for_tests)
    pos_mask = get_positive_mask(labels, device_for_tests)
    assert pos_mask[0, 3]
    assert not pos_mask[0, 1]
    assert not pos_mask[0, 0] and not pos_mask[1, 1]


def test_get_negative_mask(device_for_tests):
    """Only samples with different labels may be marked as negatives."""
    # Build the labels on the device under test; creating them on the CPU
    # regardless of `device_for_tests` mixes devices with the mask helper.
    labels = torch.tensor([1, 2, 3, 1], dtype=torch.long, device=device_for_tests)
    neg_mask = get_negative_mask(labels, device_for_tests)
    assert not neg_mask[0, 3]
    assert neg_mask[0, 1]
    assert not neg_mask[0, 0] and not neg_mask[1, 1]


def test_get_triplet_mask(device_for_tests):
    """Valid triplets need distinct indices, a same-label positive and a different-label negative."""
    # Build the labels on the device under test; creating them on the CPU
    # regardless of `device_for_tests` mixes devices with the mask helper.
    labels = torch.tensor([1, 2, 3, 1, 3], dtype=torch.long, device=device_for_tests)
    mask = get_triplet_mask(labels, device_for_tests)
    # (anchor, positive, negative) combinations that must be accepted
    assert mask[0, 3, 2]
    assert mask[2, 4, 1]
    assert mask[4, 2, 0]
    # repeated indices or a "negative" that shares the anchor's label must be rejected
    assert not mask[0, 0, 0]
    assert not mask[0, 3, 3]
    assert not mask[0, 0, 4]

In [7]:
class TripletLoss(nn.Module):
    """Batch-all triplet loss with a data-driven margin, on linearly projected features.

    The input features are projected with a linear layer, pairwise distances
    are computed, and the margin is set to
    ``max(mean negative distance - mean positive distance, 0)``.
    Only triplets that violate this margin contribute to the loss.

    Args:
        resnet: backbone network. It is stored as ``self.resnet`` and gets an
            ``fc = nn.Identity()`` attribute assigned, exactly as before; it is
            not called in ``forward``.
            NOTE(review): if the backbone uses ``fc`` as its classifier this
            replaces it - confirm against models.resnet_rotnet.
        in_features: size of the incoming feature vectors (default 512).
        embedding_dim: size of the projected embedding (default 10).
    """

    def __init__(self, resnet: nn.Module, in_features: int = 512, embedding_dim: int = 10):
        super().__init__()
        self.resnet = resnet
        self.resnet.fc = nn.Identity()
        # Put the projection where the backbone lives instead of forcing CUDA,
        # so the loss also works on CPU-only machines.
        try:
            target_device = next(resnet.parameters()).device
        except StopIteration:
            target_device = torch.device('cpu')
        self.embeddings = nn.Linear(in_features, embedding_dim).to(target_device)

    def forward(
            self,
            inputs: torch.Tensor,  # [B, in_features] feature vectors (fed straight to the linear layer)
            labels: torch.Tensor  # [B]
        ):
        """Return the scalar triplet loss for one batch.

        Returns a zero that is still attached to the graph when the batch has
        no positive pair, no negative pair, or no margin-violating triplet;
        taking the mean of an empty selection would otherwise yield NaN and
        poison the whole training step.
        """
        B = labels.size(0)
        embeddings = self.embeddings(inputs)  # [B, E]
        distance_matrix = get_distance_matrix(embeddings)  # [B, B]
        labels = labels.to(distance_matrix.device)
        mask_device = distance_matrix.device

        # Differentiable zero used for the degenerate cases below.
        zero_loss = embeddings.sum() * 0.0

        with torch.no_grad():
            mask_pos = get_positive_mask(labels, mask_device)  # [B, B]
            mask_neg = get_negative_mask(labels, mask_device)  # [B, B]
            triplet_mask = get_triplet_mask(labels, mask_device)  # [B, B, B]
            has_pairs = bool(mask_pos.any()) and bool(mask_neg.any())
            if has_pairs:
                mu_pos = torch.mean(distance_matrix[mask_pos])
                mu_neg = torch.mean(distance_matrix[mask_neg])
                # Adaptive margin, never negative.
                margin = torch.clamp(mu_neg - mu_pos, min=0)

        if not has_pairs:
            return zero_loss

        distance_i_j = distance_matrix.view(B, B, 1)  # anchor-positive, [B, B, 1]
        distance_i_k = distance_matrix.view(B, 1, B)  # anchor-negative, [B, 1, B]
        # gap = d(anchor, negative) - d(anchor, positive); larger is better.
        gap = (distance_i_k - distance_i_j)[triplet_mask]  # [valid_triplets]
        violating = gap < margin  # triplets whose gap is still below the margin
        gap = gap[violating]  # [valid_triplets_after_mask]
        if gap.numel() == 0:
            return zero_loss

        # Hinge on the margin: relu(margin - gap) = relu(d_ap - d_an + margin).
        # The previous relu(gap) gave zero gradient to the hardest triplets
        # (gap < 0) and pushed the remaining ones towards a smaller gap, i.e.
        # it pulled negatives closer and positives apart.
        triplet_loss = nn.functional.relu(margin - gap)
        return triplet_loss.mean()

In [8]:
def load_dataset(options):
    """Build the open-set data loaders selected by ``options['dataset']``.

    Reads 'dataset', 'known', 'dataroot', 'batch_size' and 'img_size' from
    ``options`` (plus 'unknown' for the cifar100 variants).

    Returns:
        (Data, trainloader, testloader, outloader) where ``Data`` is the
        dataset wrapper that provides the train and test loaders.
    """
    print("{} Preparation".format(options['dataset']))
    name = options['dataset']

    def build(dataset_cls, class_ids):
        # All *_OSR wrappers are constructed with the same keyword arguments.
        return dataset_cls(known=class_ids, dataroot=options['dataroot'],
                           batch_size=options['batch_size'], img_size=options['img_size'])

    if 'cifar100' in name and 'mnist' not in name and 'svhn' not in name:
        # Known classes are drawn with CIFAR10_OSR; the out-of-distribution
        # loader is the test loader of CIFAR100_OSR built from options['unknown'].
        Data = build(CIFAR10_OSR, options['known'])
        trainloader, testloader = Data.train_loader, Data.test_loader
        out_Data = build(CIFAR100_OSR, options['unknown'])
        return Data, trainloader, testloader, out_Data.test_loader

    if 'mnist' in name:
        dataset_cls = MNIST_OSR
    elif name == 'cifar10':
        dataset_cls = CIFAR10_OSR
    elif 'svhn' in name:
        dataset_cls = SVHN_OSR
    else:
        dataset_cls = Tiny_ImageNet_OSR

    Data = build(dataset_cls, options['known'])
    return Data, Data.train_loader, Data.test_loader, Data.out_loader

In [11]:

# Helpers that measure how well the network detects open-set samples.
# ------------------------------------------------------------------------------
def evaluate_openset(networks, dataloader_on, dataloader_off, **options):
    """Return the AUROC of separating closed-set from open-set test samples.

    Args:
        networks: model forwarded to get_openset_scores.
        dataloader_on: loader over closed-set (known class) test data.
        dataloader_off: loader over open-set (unknown class) test data.
        **options: passed through to get_openset_scores.
    """
    # One score per sample, for the closed-set and the open-set data.
    scores_known = get_openset_scores(dataloader_on, networks, open=False, **options)
    scores_unknown = get_openset_scores(dataloader_off, networks, open=True, **options)

    # Ground truth: closed-set samples are class 0, open-set samples are class 1.
    y_true = np.array([0] * len(scores_known) + [1] * len(scores_unknown))

    # Scores in the same order as the labels above.
    y_score = np.concatenate([scores_known, scores_unknown])

    return roc_auc_score(y_true, y_score)

In [12]:

def get_openset_scores(dataloader, networks,open, dataloader_train=None, **options):
    """Return one open-set score per sample of ``dataloader``.

    Thin wrapper around openset_softmax_confidence; ``dataloader_train`` and
    ``**options`` are accepted but not used.
    """
    return openset_softmax_confidence(dataloader, networks, open=open)

In [13]:
def openset_softmax_confidence(dataloader, netC, open=False):
    """Return the negated maximum softmax probability of every sample.

    Args:
        dataloader: iterable of (images, labels) batches; labels are ignored.
        netC: callable returning ``(logits, features)`` for a batch of images.
        open: unused; kept so existing callers can keep passing it.

    Returns:
        1-D numpy array with ``-max(softmax(logits))`` per sample.

    Why negated: the AUROC computation treats closed-set samples as class 0 and
    open-set samples as class 1, and expects larger scores for class 1. The
    maximum softmax probability behaves the other way round (high for
    closed-set, low for open-set), so its sign is flipped.
    """
    openset_scores = []

    with torch.no_grad():
        for images, _labels in dataloader:
            if torch.cuda.is_available():
                images = images.cuda()

            logits, _features = netC(images)
            # Highest class probability of each sample; computed once.
            confidence = F.softmax(logits, dim=1).max(dim=1)[0]
            openset_scores.extend(confidence.cpu().numpy())

    return -np.array(openset_scores)

In [14]:

def get_lr(optimizer):
    """Return the 'lr' of the optimizer's first parameter group (None if it has no groups)."""
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

In [15]:
def create_Rot_batch(inputs):
    """Apply one randomly chosen rotation/flip to every image of a batch.

    The eight candidate transforms are, in label order:
    0-3: rotation by 0, 90, 180, 270 degrees in the last two dims;
    4-7: the same four rotations followed by a flip along dim 2.

    Args:
        inputs: image batch of shape [B, C, H, W]. H and W must be equal,
            otherwise the rotated copies cannot be stacked.

    Returns:
        (data, labels): ``data`` has the shape of ``inputs`` and holds one
        transformed copy per sample; ``labels`` is a CPU float tensor of shape
        [B, 1] with the index (0..7) of the transform applied to each sample.
    """
    num_trans = 8
    batch_size = inputs.size(0)

    rotations = [inputs] + [torch.rot90(inputs, k, [2, 3]) for k in (1, 2, 3)]
    flipped = [torch.flip(rotated, [2]) for rotated in rotations]

    # [B, 8, C, H, W]: all candidate views of every sample.
    candidates = torch.stack(rotations + flipped, dim=1)

    # Pick one transform index per sample.
    idx = torch.randint(num_trans, size=(batch_size,))

    sample_rot_data_batch = candidates[torch.arange(batch_size), idx]
    # The label of a sample is simply the index of the transform it received.
    sample_rot_label_batch = idx.to(torch.get_default_dtype()).unsqueeze(1)

    return sample_rot_data_batch, sample_rot_label_batch


In [16]:
def train(epoch,trainloader,f, CE_COEF, CEN_COEF, TRP_COEF, ROT_COEF=0.3):
    """Train the global ``net`` for one epoch and log the averaged losses.

    Relies on the module-level ``net``, ``optimizer``, ``criterion``,
    ``criterion_centerloss``, ``criterion_tripletloss`` and ``device``.

    Args:
        epoch: epoch number, used for logging only.
        trainloader: iterable of (inputs, targets) batches.
        f: writable text stream that receives a copy of the console log.
        CE_COEF: weight of the cross-entropy term.
        CEN_COEF: weight of the center-loss term.
        TRP_COEF: weight of the triplet-loss term.
        ROT_COEF: weight of the rotation-prediction loss (default 0.3, the
            value that used to be hard-coded).
    """
    print('\nEpoch: %d' % epoch)
    print("Current lr : {}".format(get_lr(optimizer)))

    f.write('\nEpoch: %d \n' % epoch)
    f.write("Current lr : {} \n".format(get_lr(optimizer)))

    net.train()
    train_loss = 0.0
    cls_loss = 0.0
    rot_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for inputs, targets in trainloader:
        targets = targets.long()
        inputs, targets = inputs.to(device), targets.to(device)

        # Self-supervised branch: randomly rotated/flipped copies and the
        # index of the transform as the classification target.
        Rot_inputs, Rot_labels = create_Rot_batch(inputs)
        Rot_inputs = Rot_inputs.to(device)
        Rot_labels = torch.reshape(Rot_labels.to(device), (-1,)).long()

        optimizer.zero_grad()
        outputs, features = net(inputs)
        Rot_outputs, Rot_features = net(Rot_inputs, Rot=True)

        c_loss = criterion(outputs, targets) * CE_COEF
        # A term with weight 0 contributes nothing; skipping it saves the
        # computation and keeps a NaN in a disabled term out of the total.
        if CEN_COEF != 0:
            c_loss = c_loss + criterion_centerloss(features, targets) * CEN_COEF
        if TRP_COEF != 0:
            c_loss = c_loss + criterion_tripletloss(features, targets) * TRP_COEF
        Rot_loss = criterion(Rot_outputs, Rot_labels) * ROT_COEF

        total_loss = c_loss + Rot_loss

        total_loss.backward()
        optimizer.step()

        # Accumulate plain floats: summing the tensors themselves keeps every
        # batch's autograd graph alive until the end of the epoch.
        cls_loss += c_loss.item()
        rot_loss += Rot_loss.item()
        train_loss += total_loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

    # Guard the averages so an empty loader is reported as zeros instead of crashing.
    batch_div = max(num_batches, 1)
    sample_div = max(total, 1)

    print("Train result")
    print('Loss: %.3f , cls_loss: %.3f , rot_loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/batch_div, cls_loss/batch_div, rot_loss/batch_div, 100.*correct/sample_div, correct, total))

    f.write("Train result \n")
    f.write('Loss: %.3f , cls_loss: %.3f , rot_loss: %.3f | Acc: %.3f%% (%d/%d) \n' % (train_loss/batch_div, cls_loss/batch_div, rot_loss/batch_div, 100.*correct/sample_div, correct, total))

In [17]:
def test(epoch,testloader,f, CE_COEF, CEN_COEF, TRP_COEF):
    """Evaluate the global ``net`` on ``testloader`` and log loss and accuracy.

    Relies on the module-level ``net``, ``criterion``, ``criterion_centerloss``,
    ``criterion_tripletloss`` and ``device``. Puts ``net`` into eval mode.

    Args:
        epoch: unused; kept for symmetry with ``train``.
        testloader: iterable of (inputs, targets) batches.
        f: writable text stream that receives a copy of the console log.
        CE_COEF: weight of the cross-entropy term.
        CEN_COEF: weight of the center-loss term.
        TRP_COEF: weight of the triplet-loss term.
    """
    net.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for inputs, targets in testloader:
            targets = targets.long()
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, features = net(inputs)

            loss = criterion(outputs, targets) * CE_COEF
            # A term with weight 0 contributes nothing; skipping it keeps a
            # NaN in a disabled term out of the reported loss.
            if CEN_COEF != 0:
                loss = loss + criterion_centerloss(features, targets) * CEN_COEF
            if TRP_COEF != 0:
                loss = loss + criterion_tripletloss(features, targets) * TRP_COEF

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    # Guard the averages so an empty loader is reported as zeros instead of crashing.
    batch_div = max(num_batches, 1)
    sample_div = max(total, 1)

    print("Test result")
    print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss/batch_div, 100.*correct/sample_div, correct, total))

    f.write("Test result \n")
    f.write('Loss: %.3f | Acc: %.3f%% (%d/%d) \n' % (test_loss/batch_div, 100.*correct/sample_div, correct, total))


In [None]:
# Debugging aid: print the current working directory and list its entries.
print(os.getcwd())
print(os.listdir(os.getcwd()))

In [19]:
import numpy as np
# Class splits per dataset name. Every entry holds five lists of class
# indices; the run script below looks up splits[options['dataset']] and uses
# each list as the 'known' classes of one trial.
splits = {
    'mnist': [
        [2, 4, 5, 9, 8, 3],
        [3, 2, 6, 9, 4, 0],
        [5, 8, 3, 2, 4, 6],
        [3, 7, 8, 4, 0, 5],
        [6, 3, 4, 9, 8, 2]
    ],
    'svhn': [
        [5, 3, 7, 2, 8, 6],
        [3, 8, 7, 6, 2, 5],
        [8, 9, 4, 7, 2, 1],
        [3, 8, 2, 5, 0, 6],
        [4, 9, 2, 7, 1, 0]
    ],
    'cifar10': [
        [0, 6, 4, 9, 1, 7],
        [7, 6, 4, 9, 0, 1],
        [1, 5, 7, 3, 9, 4],
        [8, 6, 1, 9, 0, 7],
        [2, 4, 1, 7, 9, 6]
    ],
    # NOTE(review): four indices per trial, all below 10 - presumably the known
    # CIFAR-10 classes used together with the 'cifar100-*' unknowns; confirm.
    'cifar100': [
        [4, 7, 9, 1],
        [6, 7, 1, 9],
        [9, 6, 1, 7],
        [6, 4, 9, 1],
        [1, 0, 9, 8]
    ],
    'cifar100-10': [
        [30, 25, 1, 9, 8, 0, 46, 52, 49, 71],
        [41, 9, 49, 40, 73, 60, 48, 30, 95, 71],
        [8, 9, 49, 40, 73, 60, 48, 95, 30, 71],
        [95, 60, 30, 73, 46, 49, 68, 99, 8, 71],
        [33, 2, 3, 97, 46, 21, 64, 63, 88, 43]
    ],
    'cifar100-50': [
        [27, 94, 29, 77, 88, 26, 69, 48, 75, 5, 59, 93, 39, 57, 45, 40, 78, 20, 98, 47, 66, 70, 91, 76, 41, 83, 99, 32, 53, 72, 2, 95, 21, 73, 84, 68, 35, 11, 55, 60, 30, 25, 1, 9, 8, 0, 46, 52, 49, 71],
        [65, 97, 86, 24, 45, 67, 2, 3, 91, 98, 79, 29, 62, 82, 33, 76, 0, 35, 5, 16, 54, 11, 99, 52, 85, 1, 25, 66, 28, 84, 23, 56, 75, 46, 21, 72, 55, 68, 8, 69, 41, 9, 49, 40, 73, 60, 48, 30, 95, 71],
        [20, 83, 65, 97, 94, 2, 93, 16, 67, 29, 62, 33, 24, 98, 5, 86, 35, 54, 0, 91, 52, 66, 85, 84, 56, 11, 1, 76, 25, 55, 21, 99, 72, 41, 23, 75, 28, 68, 69, 46, 8, 9, 49, 40, 73, 60, 48, 95, 30, 71],
        [92, 82, 77, 64, 5, 33, 62, 56, 70, 0, 20, 28, 67, 14, 84, 53, 91, 29, 85, 2, 52, 83, 75, 35, 11, 21, 72, 98, 55, 1, 41, 76, 25, 66, 69, 9, 48, 54, 40, 23, 95, 60, 30, 73, 46, 49, 68, 99, 8, 71],
        [47, 6, 19, 0, 62, 93, 59, 65, 54, 70, 34, 55, 23, 38, 72, 76, 53, 31, 78, 96, 77, 27, 92, 18, 82, 50, 98, 32, 1, 75, 83, 4, 51, 35, 80, 11, 74, 66, 36, 42, 33, 2, 3, 97, 46, 21, 64, 63, 88, 43]
    ],
    # All 100 class indices, repeated once per trial.
    'cifar100-100': [
        np.arange(100).tolist(),
        np.arange(100).tolist(),
        np.arange(100).tolist(),
        np.arange(100).tolist(),
        np.arange(100).tolist()
    ],
    'tiny_imagenet': [
        [108, 147, 17, 58, 193, 123, 72, 144, 75, 167, 134, 14, 81, 171, 44, 197, 152, 66, 1, 133],
        [198, 161, 91, 59, 57, 134, 61, 184, 90, 35, 29, 23, 199, 38, 133, 19, 186, 18, 85, 67],
        [177, 0, 119, 26, 78, 80, 191, 46, 134, 92, 31, 152, 27, 60, 114, 50, 51, 133, 162, 93],
        [98, 36, 158, 177, 189, 157, 170, 191, 82, 196, 138, 166, 43, 13, 152, 11, 75, 174, 193, 190],
        [95, 6, 145, 153, 0, 143, 31, 23, 189, 81, 20, 21, 89, 26, 36, 170, 102, 177, 108, 169]
    ]
}

In [None]:
# Grid of loss weightings (ce, cen, trp) that always sum to 1: the
# cross-entropy share goes from 100% down to 50% in 10% steps, and the
# remainder is divided between center loss and triplet loss in 10% steps.
li = np.asarray(
    [
        [ce_pct, 100 - ce_pct - trp_pct, trp_pct]
        for ce_pct in range(100, 40, -10)
        for trp_pct in range(0, 100 - ce_pct + 1, 10)
    ],
    dtype=float,
) * 0.01

In [None]:
EPOCH = 10
if __name__=='__main__':
    # Entry point: sweep every loss weighting in `li` over all class splits.
    options={}

    options['dataset']='cifar10'
    options['dataroot']='./data'
    options['batch_size']=64

    img_size = 32
    num_classes = 10
    dataset_splits = splits[options['dataset']]

    for ce_coef, cen_coef, trp_coef in li:
        f1 = "{:.2f}".format(ce_coef)
        f2 = "{:.2f}".format(cen_coef)
        f3 = "{:.2f}".format(trp_coef)
        AUROC_list=[]

        log_path = "PATH/log/RotNet_{}_{}_{}_{}.txt".format(options['dataset'], f1, f2, f3)
        # Make sure the log directory exists before opening the file for writing.
        os.makedirs(os.path.dirname(log_path), exist_ok=True)

        with open(log_path, 'w') as f:
            print("ce_coef :{:.2f}".format(ce_coef))
            print("cen_coef :{:.2f}".format(cen_coef))
            print("trp_coef :{:.2f}".format(trp_coef))
            print()
            f.write("ce_coef :{:.2f}\n".format(ce_coef))
            f.write("cen_coef :{:.2f}\n".format(cen_coef))
            f.write("trp_coef :{:.2f}\n".format(trp_coef))
            f.write("\n")
            for i in range(len(dataset_splits)):
                # Splits are consumed from the last one to the first.
                known = dataset_splits[len(dataset_splits) - i - 1]
                unknown = list(set(list(range(0, 10))) - set(known))

                options.update(
                    {
                        'item': i,
                        'known': known,
                        'unknown': unknown,
                        'img_size': img_size
                    }
                )

                Data, trainloader, testloader, outloader = load_dataset(options)

                # Build a fresh model, losses, optimizer and scheduler for
                # every split. Reusing one network across splits lets it see,
                # during earlier splits, the classes that a later split treats
                # as unknown, which contaminates the open-set AUROC.
                print('==> Building model..')
                net = resnet_rotnet.ResNet18()
                #net.fc =nn.Linear(512,num_classes)
                net = net.to(device)

                criterion = nn.CrossEntropyLoss()
                # Only ask for GPU-resident centers when the run is on the GPU.
                criterion_centerloss = CenterLoss(num_classes=num_classes, use_gpu=(device == 'cuda'))
                criterion_tripletloss = TripletLoss(net)

                # The class centers and the triplet projection layer are
                # learnable too; without handing them to the optimizer they
                # would stay at their random initial values.
                trainable_params = list(net.parameters())
                trainable_params += list(criterion_centerloss.parameters())
                trainable_params += list(criterion_tripletloss.embeddings.parameters())
                optimizer = optim.SGD(trainable_params, lr=0.1,
                                      momentum=0.9, weight_decay=5e-4)
                scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

                for epoch in range(start_epoch, start_epoch+EPOCH):
                    print("dataset :{} , split: {}".format(options['dataset'], i))
                    f.write("dataset :{} , split: {} \n".format(options['dataset'], i))
                    train(epoch,trainloader,f, ce_coef, cen_coef, trp_coef)  # one training epoch
                    test(epoch,testloader,f, ce_coef, cen_coef, trp_coef)  # closed-set evaluation

                    # Open-set evaluation: AUROC computed from the network, the
                    # closed-set test loader and the open-set test loader.
                    cur_auroc=evaluate_openset(net,testloader,outloader)

                    print("AUROC : {:.2f} ".format(cur_auroc))
                    print("")
                    f.write("AUROC : {:.2f} \n\n".format(cur_auroc))

                    scheduler.step()
                # Keep the AUROC reached after the final epoch of this split.
                AUROC_list.append(cur_auroc)

            AUROC=np.asarray(AUROC_list)
            print(f"AUROC list : {AUROC}")
            print("split mean AUROC :{:.2f}".format(np.mean(AUROC)))

            f.write(f"AUROC list : {AUROC}\n")
            f.write("split mean AUROC :{:.2f}\n".format(np.mean(AUROC)))
            f.write("\n")
