In [None]:
import os, sys
# Step one directory up before anything else runs. The later cells use paths
# such as './hyperparams/...' and './checkpoint/...', and presumably the
# project-local packages (data_loader, model, selection, utils) are resolved
# from that parent directory as well -- TODO confirm the expected layout.
# NOTE(review): the path is relative to the *current* working directory, so
# executing this cell a second time in the same kernel moves up one more
# level. Restart the kernel before re-running it.
os.chdir('../')

In [2]:
import argparse
import torch
from tqdm import tqdm
import data_loader.data_loaders as module_data
import loss as module_loss
import model.metric as module_metric
import model.model as module_arch

import easydict
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
import json
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt 

import data_loader.data_loaders as module_data
import model.model as module_arch

from selection.svd_classifier import *
from selection.gmm import *
from selection.util import *

from utils.parse_config import ConfigParser
from utils.util import *
from utils.args import *

In [3]:
# Experiment configuration, resolved relative to the current working directory.
config_file = './hyperparams/multistep/config_cifar10_cce_rn34.json'
# Decode explicitly as UTF-8 instead of relying on the platform's locale
# encoding, which differs between machines.
with open(config_file, 'r', encoding='utf-8') as f:
    config = json.load(f)

# resume_path = './rn34/multistep_asym_40_elr.pth'

In [4]:
def decode(path):
    """Extract the noise settings encoded in a checkpoint file name.

    The file name is split on ``'_'`` and read positionally, e.g.
    ``multistep_asym_40_elr.pth`` -> ``(True, 0.4, 'elr')``.  Only the final
    path component is parsed, so underscores in parent directory names cannot
    shift the fields.

    Args:
        path: Checkpoint file name, optionally prefixed with directories.

    Returns:
        A tuple ``(noisetype, noiserate, lossfunction)``:
        ``noisetype`` is ``True`` when the second field equals ``'asym'``,
        ``noiserate`` is the third field multiplied by 0.01, and
        ``lossfunction`` is the fourth field up to its first ``'.'``.

    Raises:
        IndexError: If the file name has fewer than four ``'_'``-separated fields.
        ValueError: If the third field cannot be converted to ``float``.
    """
    items = os.path.basename(path).split('_')
    noisetype = items[1] == 'asym'
    noiserate = float(items[2]) * 0.01

    return noisetype, noiserate, items[3].split('.')[0]

In [5]:
def make_parse(resume_path, config, noise_rate, noisetype):
    """Build the run options for one checkpoint and stamp the noise settings into ``config``.

    Args:
        resume_path: Stored as ``load_name`` in the returned options.
        config: Configuration mapping with a ``'trainer'`` entry; it is
            modified in place (``percent`` and ``asym`` are overwritten).
        noise_rate: Value written to ``config['trainer']['percent']``.
        noisetype: Value written to ``config['trainer']['asym']``.

    Returns:
        ``(parse, config)`` -- an ``EasyDict`` holding ``load_name``,
        ``reinit`` (always ``False``) and ``distill_mode`` (always
        ``'kmeans'``), together with the same ``config`` object passed in.
    """
    trainer_settings = config['trainer']
    trainer_settings['percent'] = noise_rate
    trainer_settings['asym'] = noisetype

    options = {
        "load_name": resume_path,
        "reinit": False,
        "distill_mode": 'kmeans',
    }
    return easydict.EasyDict(options), config

In [6]:
def extract_cleanidx(teacher, data_loader, parse, print_statistics = True):
    """Load a checkpoint into ``teacher`` and evaluate its clean-sample selection.

    Args:
        teacher: Model that receives the checkpoint's ``'state_dict'``.  It is
            moved to the GPU with ``.cuda()`` and every parameter has
            ``requires_grad`` set to ``False``.
        data_loader: Loader handed to ``get_features`` and ``return_statistics``.
        parse: Options object.  ``parse.load_name`` is appended to
            ``'./checkpoint/'`` to locate the checkpoint and
            ``parse.distill_mode`` is forwarded to ``fine`` as ``fit``.
        print_statistics: When ``False``, anything written to stdout while the
            statistics are computed is discarded.  The statistics themselves
            are always computed and returned.

    Returns:
        ``(selected, precision, recall, specificity, accuracy)`` exactly as
        produced by ``return_statistics``.
    """
    import contextlib
    import io

    # Read the checkpoint once.  The weights are always loaded; the previous
    # second, ``reinit``-guarded load re-applied the same weights.
    # NOTE: torch.load unpickles the file -- only open checkpoints you trust.
    checkpoint = torch.load('./checkpoint/' + parse.load_name)
    teacher.load_state_dict(checkpoint['state_dict'])
    teacher = teacher.cuda()

    for params in teacher.parameters():
        params.requires_grad = False

    features, labels = get_features(teacher, data_loader)
    clean_labels = fine(current_features=features, current_labels=labels, fit=parse.distill_mode)

    # The statistics are the return value, so they must be computed whatever
    # the flag says; previously ``print_statistics=False`` left them unbound.
    # NOTE(review): assumes return_statistics reports via stdout -- confirm.
    if print_statistics:
        statistics = return_statistics(data_loader, clean_labels, datanum=len(labels))
    else:
        with contextlib.redirect_stdout(io.StringIO()):
            statistics = return_statistics(data_loader, clean_labels, datanum=len(labels))

    selected, precision, recall, specificity, accuracy = statistics
    return selected, precision, recall, specificity, accuracy

In [7]:
def make_pd_list(root, config, log_filename):
    """Evaluate every eligible checkpoint in ``root`` and tabulate the statistics.

    File names containing ``'eigen'``, ``'kmeans'`` or ``'c100'`` are skipped.
    The remaining names are processed in sorted order (``os.listdir`` itself
    returns entries in arbitrary order).  For each one the noise settings are
    decoded from the name, a data loader is built from ``config``, and
    ``extract_cleanidx`` is run on a shared ResNet-34 instance.

    Args:
        root: Directory that contains the checkpoint files.  The load name
            passed on is expressed relative to ``./checkpoint/`` because
            ``extract_cleanidx`` prepends that prefix.
        config: Configuration mapping providing ``'seed'``, ``'trainer'`` and
            ``'data_loader'``.  ``make_parse`` overwrites
            ``config['trainer']['percent']`` / ``['asym']`` for every checkpoint.
        log_filename: CSV path, rewritten after each checkpoint so partial
            results survive an interrupted run.

    Returns:
        ``pandas.DataFrame`` with one row per checkpoint and the columns
        ``noisetype``, ``noiserate`` (both stored as strings),
        ``lossfunction``, ``selected``, ``precision``, ``recall``,
        ``specificity`` and ``accuracy``.  Empty (columns only) when no
        checkpoint qualifies.
    """
    # Seed every RNG used downstream so repeated runs are comparable.
    seed = config['seed']
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)

    # load checkpoint path (sorted: listdir order is not stable across runs)
    skip_tokens = ('eigen', 'kmeans', 'c100')
    pathlist = sorted(
        name for name in os.listdir(root)
        if not any(token in name for token in skip_tokens)
    )

    # initialize model
    model = module_arch.resnet34(num_classes=10)

    # make pandas file
    logcolumns = ['noisetype', 'noiserate', 'lossfunction', 'selected', 'precision', 'recall', 'specificity', 'accuracy']
    # Rows are collected as records and the frame is rebuilt from them; this
    # avoids writing strings into a zero-filled float frame and keeps
    # placeholder all-zero rows out of the intermediate CSV.
    records = []
    log_pd = pd.DataFrame(columns=logcolumns)

    # write pandas file
    for name in pathlist:
        noisetype, noiserate, lossfunction = decode(name)

        # Honour ``root`` instead of a hard-coded './rn34/' prefix.
        load_name = os.path.relpath(os.path.join(root, name), './checkpoint/')
        parse, config = make_parse(load_name, config, noiserate, noisetype)

        # load original dataloader
        loader_args = config['data_loader']['args']
        data_loader = getattr(module_data, config['data_loader']['type'])(
            loader_args['data_dir'],
            batch_size=100,
            shuffle=False,
            validation_split=0.0,
            num_batches=loader_args['num_batches'],
            training=True,
            num_workers=loader_args['num_workers'],
            pin_memory=loader_args['pin_memory'],
            config=config)

        selected, precision, recall, specificity, accuracy = extract_cleanidx(model, data_loader, parse)
        records.append([str(noisetype), str(noiserate), lossfunction, selected, precision, recall, specificity, accuracy])

        log_pd = pd.DataFrame(records, columns=logcolumns)
        log_pd.to_csv(log_filename)

    return log_pd

In [None]:
# Evaluate every checkpoint listed under ./checkpoint/rn34/ and write the
# per-checkpoint selection statistics to pretrained_statistics.csv; the
# returned DataFrame is displayed as this cell's output.
# NOTE(review): judging by the progress bars below, each checkpoint takes
# roughly half a minute, so the full sweep is slow.
make_pd_list(root = './checkpoint/rn34/', config=config, log_filename = 'pretrained_statistics.csv')

Files already downloaded and verified
##############
[8 9 1 9 4 8 3 6 3 6]
[3 2 1 1 3 2 2 7 6 5]
Train: 50000 Val: 0


100%|██████████| 500/500 [00:17<00:00, 23.18it/s]
100%|██████████| 10/10 [00:14<00:00,  1.44s/it]
100%|██████████| 50000/50000 [00:00<00:00, 336421.71it/s]
100%|██████████| 500/500 [00:03<00:00, 165.54it/s]


Noisy: 36036, Clean: 13964
Selected samples: 10844 
Precision: 0.8965 
Recall: 0.6962 
Specificity: 0.9689
Accuracy: 0.8927 
Fraction of clean samples/selected samples: 0.8965
Files already downloaded and verified
##############
[3 9 1 1 3 8 2 7 6 5]
[3 2 1 1 3 2 2 7 6 5]
Train: 50000 Val: 0


100%|██████████| 500/500 [00:16<00:00, 29.63it/s]
100%|██████████| 10/10 [00:14<00:00,  1.44s/it]
100%|██████████| 50000/50000 [00:00<00:00, 324385.96it/s]
100%|██████████| 500/500 [00:02<00:00, 169.47it/s]


Noisy: 9006, Clean: 40994
Selected samples: 37773 
Precision: 0.9991 
Recall: 0.9206 
Specificity: 0.9961
Accuracy: 0.9342 
Fraction of clean samples/selected samples: 0.9991
Files already downloaded and verified
##############
[8 9 1 9 4 8 3 6 3 6]
[3 2 1 1 3 2 2 7 6 5]
Train: 50000 Val: 0


100%|██████████| 500/500 [00:16<00:00, 29.91it/s]
 60%|██████    | 6/10 [00:08<00:05,  1.47s/it]