In [1]:
import os

import logging
import time

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision.datasets
import torchvision.transforms as transforms
import torchvision



from mean_teacher import datasets, architectures
from mean_teacher.utils import *

from IPython.display import clear_output

device = "cuda" if torch.cuda.is_available() else "cpu"

import matplotlib.pyplot as plt

LOG = logging.getLogger('main')
NO_LABEL = -1
to_image = transforms.ToPILImage()

In [2]:
dataset_config = datasets.__dict__['sslMini']()

In [3]:
def load_weights(model_arch, pretrained_model_path, state_dict, cuda=True):
        # Load pretrained model
        pretrained_model = torch.load(f=pretrained_model_path, map_location="cuda" if cuda else "cpu")

        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in pretrained_model[state_dict].items():
            name = k[7:] # remove `module.`
            new_state_dict[name] = v

        # Load pre-trained weights in current model
        with torch.no_grad():
            model_arch.load_state_dict(new_state_dict, strict=True)

        # Debug loading
        #print('Parameters found in pretrained model:')
        pretrained_layers = new_state_dict.keys()
        #for l in pretrained_layers:
        #    print('\t' + l)
        #print('')

        for name, module in model_arch.state_dict().items():
            if name in pretrained_layers:
                assert torch.equal(new_state_dict[name].cpu(), module.cpu())
                #print('{} have been loaded correctly in current model.'.format(name))
            else:
                raise ValueError("state_dict() keys do not match")
                
        return model_arch


In [12]:
BATCH_SIZE = 64

evaldir = "/scratch/ehd255/ssl_data_96/supervised/val"

eval_loader = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(evaldir, dataset_config['eval_transformation']),
                                              batch_size=BATCH_SIZE,
                                              shuffle=True,
                                              num_workers=2,
                                              #pin_memory=True,
                                              drop_last=False)


In [5]:
pretrained_model_path = "/scratch/ijh216/ssl/ssl2/2019-05-09_17-19-21/10/transient/checkpoint.350.ckpt" 
log = torch.load(f=pretrained_model_path, map_location="cuda" if False else "cpu")

In [7]:
model_dir = "/scratch/ijh216/ssl/ssl_shake_mini_augment/2019-05-06_18-04-18/10/transient/checkpoint.325.ckpt" 
model = architectures.__dict__['cifar_shakeshake26']().to(device)
model = load_weights(model, model_dir, state_dict="ema_state_dict", cuda=False)

In [12]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy@k for the specified values of k"""
    maxk = max(topk)
    labeled_minibatch_size = max(target.ne(NO_LABEL).sum(), 1e-8).item()

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True).item()
        res.append(correct_k * (100.0 / labeled_minibatch_size))
    return res

In [13]:
model_dir = "/scratch/ijh216/ssl/ssl2/2019-05-09_17-19-21/10/transient/checkpoint.350.ckpt" 
model = architectures.__dict__['cifar_shakeshake26']().to(device)
model = load_weights(model, model_dir, state_dict="ema_state_dict", cuda=False)

n_samples = 0.
n_correct_top_1 = 0
n_correct_top_k = 0

for i, (img, target) in enumerate(eval_loader):
    img, target = img.to(device), target.to(device)
    n_samples += BATCH_SIZE

        # Forward
    output = model(img)[0]

        # Top 1 accuracy
    pred_top_1 = torch.topk(output, k=1, dim=1)[1]
    n_correct_top_1 += pred_top_1.eq(target.view_as(pred_top_1)).int().sum().item()

        # Top k accuracy
    pred_top_k = torch.topk(output, k=5, dim=1)[1]
    target_top_k = target.view(-1, 1).expand(BATCH_SIZE, 5)
    n_correct_top_k += pred_top_k.eq(target_top_k).int().sum().item()
    
    if i % 100 == 0:
        print("******************************")
        print("Acc@1", n_correct_top_1/n_samples)
        print("Acc@5", n_correct_top_k/n_samples)
        print("******************************")
    
    # Accuracy
top_1_acc = n_correct_top_1/n_samples
top_k_acc = n_correct_top_k/n_samples

print("******************************")
print("Acc@1", top_1_acc)
print("Acc@5", top_k_acc)
print("******************************")

******************************
Acc@1 0.265625
Acc@5 0.484375
******************************


KeyboardInterrupt: 

In [12]:
import os

In [17]:
for j, i in enumerate(os.listdir("/scratch/ijh216/ssl_mini"+'/supervised/train')):
    if i == "n02485536":
        print(i)
    #print(j, len(os.listdir("/scratch/ijh216/ssl_mini"+'/supervised/train/'+i)))

n02485536


In [20]:
'n02485536_4796.JPEG' in os.listdir("/scratch/ijh216/ssl_mini/supervised/train/n02485536")

True

In [None]:
'n02485536_4796.JPEG'

In [40]:
img_dir = eval_loader.dataset.imgs[0][0]
img = eval_loader.dataset[0][0].unsqueeze(0).to(device)

In [49]:
labels = []

for ckpt in range(0, 325, 5):

    model_dir = "/scratch/ijh216/ssl/ssl_shake_mini_augment/2019-05-06_18-04-18/10/transient/checkpoint.{}.ckpt".format(ckpt) 
    model = architectures.__dict__['cifar_shakeshake26']().to(device)
    model = load_weights(model, model_dir, state_dict="ema_state_dict", cuda=False)


    output = F.softmax(model(img)[0], dim=1)
    labels.append(output)

labels = torch.cat(labels)    

In [52]:
labels = torch.cat(labels) 

In [58]:
torch.cat((labels, labels)).size()

torch.Size([130, 1000])

In [46]:
labels[0].size()

torch.Size([1, 1000])

In [1]:

import os
import shutil
import random
from tqdm import tqdm

import numpy as np

os.makedirs("/scratch/ijh216/sslK")

root_sup_train_dir = "/scratch/ehd255/ssl_data_96/supervised/train"

def randomSelect2(path, n ,out_dir):

    for i in tqdm(os.listdir(path)):
        
        cur_dir = path + "/" + i
        
        all_files = os.listdir(cur_dir)
        
        files_1 = np.random.choice(all_files, n, replace=False)
                      
        for file in files_1:
            new_dir = out_dir + "/" + i + "/."
            if not os.path.exists(new_dir):
                os.makedirs(new_dir)
            
            file_path = cur_dir + "/" + file
            shutil.copy(file_path, new_dir)
           
                            
    return


for i in [1, 2, 4, 8, 16, 32]:
    
    os.makedirs("/scratch/ijh216/sslK/train_{}".format(i))

    output_dir = "/scratch/ijh216/sslK/train_{}".format(i)

    
    randomSelect2(root_sup_train_dir, i, output_dir)

100%|██████████| 1000/1000 [01:06<00:00, 15.02it/s]
100%|██████████| 1000/1000 [01:33<00:00, 12.90it/s]
100%|██████████| 1000/1000 [03:25<00:00,  3.13it/s]
100%|██████████| 1000/1000 [08:06<00:00,  1.49it/s]
100%|██████████| 1000/1000 [15:11<00:00,  1.10it/s]
100%|██████████| 1000/1000 [28:12<00:00,  1.63s/it]


In [3]:
!rm -rf /scratch/ijh216/sslK/unsup

In [4]:
def randomSelect2(path, n ,out_dir):

    for i in tqdm(os.listdir(path)):
        
        cur_dir = path + "/" + i
        
        all_files = os.listdir(cur_dir)
        
        files_1 = np.random.choice(all_files, n, replace=False)
                      
        for file in files_1:
            new_dir = out_dir + "/" + i + "/."
            if not os.path.exists(new_dir):
                os.makedirs(new_dir)
            
            file_path = cur_dir + "/" + file
            shutil.copy(file_path, new_dir)
           
                            
    return

root_sup_train_dir = "/scratch/ehd255/ssl_data_96/unsupervised"
os.makedirs("/scratch/ijh216/sslK/unsup")

output_dir = "/scratch/ijh216/sslK/unsup"

    
randomSelect2(root_sup_train_dir, 50, output_dir)

100%|██████████| 512/512 [18:47<00:00,  1.88s/it]
