In [1]:
import os, sys
os.chdir('../')

In [2]:
import argparse
import torch
from tqdm import tqdm
import data_loader.data_loaders as module_data
import loss as module_loss
import model.metric as module_metric
import model.model as module_arch

import easydict
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
import json
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt 

import data_loader.data_loaders as module_data
import model.model as module_arch

from selection.svd_classifier import *
from selection.gmm import *
from selection.util import *

from utils.parse_config import ConfigParser
from utils.util import *
from utils.args import *

In [3]:
config_file = './hyperparams/multistep/config_cifar10_cce_rn34.json'
with open(config_file, 'r') as f:
    config = json.load(f)

# resume_path = './rn34/multistep_asym_40_elr.pth'

In [4]:
model = module_arch.resnet34(num_classes=10)

In [6]:
parse, config = make_parse(resume_path, config, 0.4, True)

In [8]:
data_loader = getattr(module_data, config['data_loader']['type'])(
    config['data_loader']['args']['data_dir'],
    batch_size= 100,
    shuffle=config['data_loader']['args']['shuffle'],
    validation_split=0.0,
    num_batches=config['data_loader']['args']['num_batches'],
    training=True,
    num_workers=config['data_loader']['args']['num_workers'],
    pin_memory=config['data_loader']['args']['pin_memory'],
    config=config
)

Files already downloaded and verified
##############
[3 2 1 1 3 0 0 7 6 5]
[3 2 1 1 3 2 2 7 6 5]
Train: 50000 Val: 0


In [11]:
selected, precision, recall, specificity, accuracy = extract_cleanidx(model, data_loader, parse)

100%|██████████| 500/500 [00:16<00:00, 22.15it/s]
100%|██████████| 10/10 [00:15<00:00,  1.37s/it]
100%|██████████| 50000/50000 [00:00<00:00, 345861.51it/s]
100%|██████████| 500/500 [00:02<00:00, 167.56it/s]

Selected samples: 31607 
Precision: 0.8145 
Recall: 0.6314 
Specificity: 0.3645
Accuracy: 0.5821 
Fraction of clean samples/selected samples: 0.8145





(31607, 0.8145, 0.6314, 0.3645, 0.5821)

In [15]:
pathlist = os.listdir('./checkpoint/rn34/')
pathlist = [path for path in pathlist if ('eigen' not in path) and ('kmeans' not in path) and ('c100') not in path]

In [28]:
logcolumns = ['name', 'selected', 'precision', 'recall', 'specificity', 'accuracy']
log_pd = pd.DataFrame(np.zeros([1, len(logcolumns)]), columns = logcolumns)

In [29]:
log_pd.loc[0] = [pathlist[0], 31607, 0.8145, 0.6314, 0.3645, 0.5821]

In [30]:
log_pd

Unnamed: 0,name,selected,precision,recall,specificity,accuracy
0,multistep_sym_80_elr.pth,31607.0,0.8145,0.6314,0.3645,0.5821


In [19]:
def decode(path):
    items = path.split('_')
    noisetype = True if items[1]=='asym' else False
    noiserate = float(items[2]) * 0.01
    
    return noisetype, noiserate, items[3].split('.')[0]

In [5]:
def make_parse(resume_path, config, noise_rate, noisetype):
    parse = easydict.EasyDict({
    "load_name" : resume_path,
    "reinit": False,
    "distill_mode": 'kmeans'
    })
    
    config['trainer']['percent'] = noise_rate
    config['trainer']['asym'] = noisetype
    
    return parse, config

In [6]:
def extract_cleanidx(teacher, data_loader, parse, print_statistics = True):
    teacher.load_state_dict(torch.load('./checkpoint/' + parse.load_name)['state_dict'])
    teacher = teacher.cuda()

    if not parse.reinit: teacher.load_state_dict(torch.load('./checkpoint/' + parse.load_name)['state_dict'])
    for params in teacher.parameters(): params.requires_grad = False
    
    features, labels = get_features(teacher, data_loader)
    clean_labels = fine(current_features=features, current_labels=labels, fit = parse.distill_mode)
    
    if print_statistics: 
        selected, precision, recall, specificity, accuracy = return_statistics(data_loader, clean_labels, datanum=len(labels))
    
    return selected, precision, recall, specificity, accuracy

In [17]:
def make_pd_list(root, config, log_filename):
    # load checkpoint path
    pathlist = os.listdir(root)
    pathlist = [path for path in pathlist if ('eigen' not in path) and ('kmeans' not in path) and ('c100') not in path]
    
    # load original dataloader
    data_loader = getattr(module_data, config['data_loader']['type'])(
    config['data_loader']['args']['data_dir'],
    batch_size= 100,
    shuffle=config['data_loader']['args']['shuffle'],
    validation_split=0.0,
    num_batches=config['data_loader']['args']['num_batches'],
    training=True,
    num_workers=config['data_loader']['args']['num_workers'],
    pin_memory=config['data_loader']['args']['pin_memory'],
    config=config)
    
    # initialize model
    model = module_arch.resnet34(num_classes=10)
    
    # make pandas file
    logcolumns = ['noisetype', 'noiserate', 'lossfunction', 'selected', 'precision', 'recall', 'specificity', 'accuracy']
    log_pd = pd.DataFrame(np.zeros([len(pathlist), len(logcolumns)]), columns = logcolumns)
    
    # write pandas file
    for i in range(len(pathlist)):
        noisetype, noiserate, lossfunction = decode(pathlist[i])
        parse, config = make_parse('./rn34/' + pathlist[i], config, noiserate, noisetype)
        selected, precision, recall, specificity, accuracy = extract_cleanidx(model, data_loader, parse)
        log_pd.loc[i] = [str(noisetype), str(noiserate), lossfunction, pathlist[i].split('.')[0], selected, precision, recall, specificity, accuracy]
        log_pd.to_csv(log_filename)
        
    return log_pd

In [None]:
make_pd_list(root = './checkpoint/rn34/', config=config, log_filename = 'pretrained_statistics.csv')

Files already downloaded and verified
##############
[8 9 1 9 3 8 3 7 3 5]
[3 2 1 1 3 2 2 7 6 5]
Train: 50000 Val: 0


100%|██████████| 500/500 [01:31<00:00,  5.39it/s]
 20%|██        | 2/10 [00:03<00:12,  1.52s/it]