In [9]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")
%matplotlib inline
import matplotlib.pyplot as plt
from fastai.vision import *
import torch
from torchsummary import summary
torch.cuda.set_device(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)

from models.custom_resnet import *
from utils import _get_accuracy

In [10]:
def check(model_name, dataset):
    """Print the accuracy of the stagewise and no-teacher student checkpoints.

    Builds an ``ImageDataBunch`` for ``dataset``, instantiates the student
    network ``model_name`` and, for each of the two checkpoints
    ``../saved_models/<dataset>/<model_name>_classifier/model0.pt`` (printed
    as ``stagewise``) and ``.../<model_name>_no_teacher/model0.pt`` (printed
    as ``no_teacher``), loads the weights and prints the value returned by
    ``_get_accuracy`` on the validation data loader.

    Args:
        model_name: one of 'resnet10', 'resnet14', 'resnet18', 'resnet20',
            'resnet26'.
        dataset: one of 'imagenette', 'imagewoof', 'cifar10'.

    Raises:
        ValueError: if ``dataset`` or ``model_name`` is not supported. Both
            arguments are validated before any dataset is fetched.
    """
    # Name of the attribute on fastai's ``URLs`` for each supported dataset.
    url_attrs = {'imagenette': 'IMAGENETTE', 'cifar10': 'CIFAR', 'imagewoof': 'IMAGEWOOF'}
    if dataset not in url_attrs:
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected one of {sorted(url_attrs)}"
        )

    builders = {
        'resnet10': resnet10,
        'resnet14': resnet14,
        'resnet18': resnet18,
        'resnet20': resnet20,
        'resnet26': resnet26,
    }
    if model_name not in builders:
        raise ValueError(
            f"unsupported model_name {model_name!r}; expected one of {sorted(builders)}"
        )

    path = untar_data(getattr(URLs, url_attrs[dataset]))

    # CIFAR-10 uses a 'test' folder for validation, 32px images and its own
    # normalisation statistics; the other datasets use 'val', 224px and the
    # ImageNet statistics.
    if dataset == 'cifar10':
        val, sz, stats = 'test', 32, cifar_stats
    else:
        val, sz, stats = 'val', 224, imagenet_stats

    tfms = get_transforms(do_flip=False)
    data = ImageDataBunch.from_folder(
        path, train='train', valid=val, bs=64, size=sz, ds_tfms=tfms
    ).normalize(stats)

    net = builders[model_name](pretrained=False, progress=False)
    net.cuda()
    # Explicitly select inference mode so the result does not depend on
    # whether _get_accuracy switches modes itself (no-op if it already does).
    net.eval()

    checkpoint_root = '../saved_models/' + dataset + '/'
    for label, suffix in (('stagewise', '_classifier'), ('no_teacher', '_no_teacher')):
        savename = checkpoint_root + model_name + suffix + '/model0.pt'
        net.load_state_dict(torch.load(savename, map_location='cpu'))
        print(label + ' : ', _get_accuracy(data.valid_dl, net))

In [8]:
# Evaluate every student depth on Imagenette (full training data).
for student_name in ('resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26'):
    check(student_name, 'imagenette')

stagewise :  0.974
no_teacher :  0.918
stagewise :  0.988
no_teacher :  0.912
stagewise :  0.988
no_teacher :  0.914
stagewise :  0.988
no_teacher :  0.916
stagewise :  0.99
no_teacher :  0.906


In [13]:
# Evaluate every student depth on Imagewoof (full training data).
for student_name in ('resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26'):
    check(student_name, 'imagewoof')

stagewise :  0.906
no_teacher :  0.802
stagewise :  0.928
no_teacher :  0.786
stagewise :  0.924
no_teacher :  0.792
stagewise :  0.92
no_teacher :  0.798
stagewise :  0.934
no_teacher :  0.802


In [14]:
# Evaluate every student depth on CIFAR-10 (full training data).
for student_name in ('resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26'):
    check(student_name, 'cifar10')

stagewise :  0.8475
no_teacher :  0.7788
stagewise :  0.8497
no_teacher :  0.775
stagewise :  0.8599
no_teacher :  0.7735
stagewise :  0.8646
no_teacher :  0.7808
stagewise :  0.8662
no_teacher :  0.783


In [11]:
def check_ld(model_name, dataset):
    """Print the accuracy of the less-data stagewise and no-teacher checkpoints.

    Same procedure as the full-data check, but the data is read from the
    ``new`` sub-folder of the dataset directory and the checkpoints from
    ``../saved_models/<dataset>/less_data/``. For each of
    ``<model_name>_classifier/model0.pt`` (printed as ``stagewise``) and
    ``<model_name>_no_teacher/model0.pt`` (printed as ``no_teacher``) the
    weights are loaded and the value returned by ``_get_accuracy`` on the
    validation data loader is printed.

    Args:
        model_name: one of 'resnet10', 'resnet14', 'resnet18', 'resnet20',
            'resnet26'.
        dataset: one of 'imagenette', 'imagewoof', 'cifar10'.

    Raises:
        ValueError: if ``dataset`` or ``model_name`` is not supported. Both
            arguments are validated before any dataset is fetched.
    """
    # Name of the attribute on fastai's ``URLs`` for each supported dataset.
    url_attrs = {'imagenette': 'IMAGENETTE', 'cifar10': 'CIFAR', 'imagewoof': 'IMAGEWOOF'}
    if dataset not in url_attrs:
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected one of {sorted(url_attrs)}"
        )

    builders = {
        'resnet10': resnet10,
        'resnet14': resnet14,
        'resnet18': resnet18,
        'resnet20': resnet20,
        'resnet26': resnet26,
    }
    if model_name not in builders:
        raise ValueError(
            f"unsupported model_name {model_name!r}; expected one of {sorted(builders)}"
        )

    path = untar_data(getattr(URLs, url_attrs[dataset]))
    new_path = path/'new'

    # CIFAR-10 uses 32px images and its own normalisation statistics; the
    # other datasets use 224px and the ImageNet statistics.
    if dataset == 'cifar10':
        sz, stats = 32, cifar_stats
    else:
        sz, stats = 224, imagenet_stats

    tfms = get_transforms(do_flip=False)
    # The 'new' tree is read with train/val/test folders for every dataset,
    # so the validation folder is always 'val' here (including CIFAR-10).
    data = ImageDataBunch.from_folder(
        new_path, train='train', valid='val', test='test', bs=64, size=sz, ds_tfms=tfms
    ).normalize(stats)

    net = builders[model_name](pretrained=False, progress=False)
    net.cuda()
    # Explicitly select inference mode so the result does not depend on
    # whether _get_accuracy switches modes itself (no-op if it already does).
    net.eval()

    checkpoint_root = '../saved_models/' + dataset + '/less_data/'
    for label, suffix in (('stagewise', '_classifier'), ('no_teacher', '_no_teacher')):
        savename = checkpoint_root + model_name + suffix + '/model0.pt'
        net.load_state_dict(torch.load(savename, map_location='cpu'))
        print(label + ' : ', _get_accuracy(data.valid_dl, net))

In [12]:
# Evaluate every student depth on Imagenette (reduced training data).
for student_name in ('resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26'):
    check_ld(student_name, 'imagenette')

stagewise :  0.954
no_teacher :  0.848
stagewise :  0.95
no_teacher :  0.85
stagewise :  0.956
no_teacher :  0.854
stagewise :  0.958
no_teacher :  0.85
stagewise :  0.96
no_teacher :  0.832


In [15]:
# Evaluate every student depth on Imagewoof (reduced training data).
for student_name in ('resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26'):
    check_ld(student_name, 'imagewoof')

stagewise :  0.858
no_teacher :  0.632
stagewise :  0.89
no_teacher :  0.616
stagewise :  0.89
no_teacher :  0.602
stagewise :  0.876
no_teacher :  0.6
stagewise :  0.898
no_teacher :  0.588


In [16]:
# Evaluate every student depth on CIFAR-10 (reduced training data).
for student_name in ('resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26'):
    check_ld(student_name, 'cifar10')

stagewise :  0.8159
no_teacher :  0.6635
stagewise :  0.8255
no_teacher :  0.6489
stagewise :  0.8328
no_teacher :  0.6466
stagewise :  0.8324
no_teacher :  0.6519
stagewise :  0.8364
no_teacher :  0.6409


In [24]:
def check_teacher(model_name, dataset):
    """Print the accuracy of the saved teacher network on ``dataset``.

    Builds an ``ImageDataBunch`` for ``dataset``, creates a fastai
    ``cnn_learner`` around ``models.resnet34``, loads the learner weights
    named ``resnet34_<load_name>_bs64`` and prints the value returned by
    ``_get_accuracy`` on the validation data loader.

    Args:
        model_name: teacher architecture; only 'resnet34' is supported.
        dataset: one of 'imagenette', 'imagewoof', 'cifar10'.

    Raises:
        ValueError: if ``model_name`` or ``dataset`` is not supported. Both
            arguments are validated before any dataset is fetched.
    """
    if model_name != 'resnet34':
        raise ValueError(
            f"unsupported model_name {model_name!r}; only 'resnet34' is available"
        )

    # Name of the attribute on fastai's ``URLs`` for each supported dataset.
    url_attrs = {'imagenette': 'IMAGENETTE', 'cifar10': 'CIFAR', 'imagewoof': 'IMAGEWOOF'}
    if dataset not in url_attrs:
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected one of {sorted(url_attrs)}"
        )

    path = untar_data(getattr(URLs, url_attrs[dataset]))

    # CIFAR-10 uses a 'test' folder for validation, 32px images and its own
    # normalisation statistics; its saved learner is named with the dataset
    # string minus its last two characters ('cifar10' -> 'cifar').
    if dataset == 'cifar10':
        val, sz, stats = 'test', 32, cifar_stats
        load_name = dataset[:-2]
    else:
        val, sz, stats = 'val', 224, imagenet_stats
        load_name = dataset

    tfms = get_transforms(do_flip=False)
    data = ImageDataBunch.from_folder(
        path, train='train', valid=val, bs=64, size=sz, ds_tfms=tfms
    ).normalize(stats)

    learn = cnn_learner(data, models.resnet34, metrics=accuracy)
    learn = learn.load('resnet34_' + load_name + '_bs64')
    learn.freeze()

    net = learn.model.cuda()
    # Explicitly select inference mode so the result does not depend on
    # whether _get_accuracy switches modes itself (no-op if it already does).
    net.eval()
    print(_get_accuracy(data.valid_dl, net))

In [25]:
# Evaluate the ResNet-34 teacher on each dataset.
for dataset_name in ('imagenette', 'imagewoof', 'cifar10'):
    check_teacher('resnet34', dataset_name)

0.992
0.914
0.8751
