In [1]:
from comet_ml import Experiment
%matplotlib inline
import matplotlib.pyplot as plt
from fastai.vision import *
import torch
from torchsummary import summary
import numpy as np
from tqdm import tqdm
# All `.cuda()` calls below go to this device.
GPU_ID = 1
torch.cuda.set_device(GPU_ID)

# Seed every RNG the pipeline may touch (Python, NumPy, torch CPU + current
# CUDA device) so a Restart & Run All reproduces the same results.
import random

SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# stage should be in 0 to 5 and for 0, use -1 (this is due to inconsistency in the model generated by PyTorch)
# NOTE(review): nothing in this notebook explains the "-1" remark above; the
# feature folders created below use stage 0 for raw images and 1-5 for the
# hooked layers — confirm how `stage` is meant to map onto them.
hyper_params = {
    "stage": 0,
    "repeated": 1,
    "num_classes": 10,
    "batch_size": 64,
    "num_epochs": 100,
    "learning_rate": 1e-4
}

def save_torch(name:str, tensor):
    """Save a torch tensor to disk as a NumPy ``.npy`` file.

    Args:
        name: Destination path (str or os.PathLike). ``np.save`` appends
            ``.npy`` when the name does not already end with it.
        tensor: A ``torch.Tensor`` on any device. It is detached from the
            autograd graph and moved to host memory before saving. The
            caller's tensor is only read, never modified.
    """
    # No .clone() needed: .cpu() already copies a device tensor, and for a CPU
    # tensor np.save just reads the shared buffer synchronously.
    np.save(name, tensor.detach().cpu().numpy())
    
def load_np_torch(name):
    """Read a ``.npy`` file (e.g. one written by ``save_torch``) as a torch tensor.

    ``name`` may be a str or a path-like object; it is stringified before
    loading. The returned CPU tensor shares memory with the loaded array.
    """
    array = np.load(str(name))
    return torch.from_numpy(array)

In [2]:
# Fetch/extract Imagenette with fastai's `untar_data`; `path` is the dataset
# root (the folder listing further down shows it under ~/.fastai/data/imagenette).
path = untar_data(URLs.IMAGENETTE)

In [6]:
# Stage 0 stores the raw image tensors (save_images); stages 1-5 store the
# activations of the five hooked layers (save_features writes to idx+1).
NUM_STAGES = 6

path_train_feat = path/'train_feat'
path_val_feat = path/'val_feat'

for feat_root in (path_train_feat, path_val_feat):
    feat_root.mkdir(exist_ok=True)
    for stage in range(NUM_STAGES):
        (feat_root/str(stage)).mkdir(exist_ok=True)

# Directory listing order is filesystem-dependent; sort for a stable display.
sorted(path_train_feat.ls())

[PosixPath('/home/navid/.fastai/data/imagenette/train_feat/0'),
 PosixPath('/home/navid/.fastai/data/imagenette/train_feat/2'),
 PosixPath('/home/navid/.fastai/data/imagenette/train_feat/1'),
 PosixPath('/home/navid/.fastai/data/imagenette/train_feat/4'),
 PosixPath('/home/navid/.fastai/data/imagenette/train_feat/5'),
 PosixPath('/home/navid/.fastai/data/imagenette/train_feat/3')]

In [7]:
class SaveFeatures:
    """Record a module's output on every forward pass via a forward hook.

    After ``model(x)`` runs, ``self.features`` holds the hooked module's
    output from that pass; each new pass overwrites it. The output is stored
    as-is (not detached), so a tensor that requires grad keeps its autograd
    graph alive until it is overwritten. Call ``remove()``, or use the
    instance as a context manager, to unregister the hook.
    """

    def __init__(self, m):
        # Sentinel until the first forward pass through `m`, so reading it
        # early gives None instead of an AttributeError.
        self.features = None
        self.handle = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, m, inp, outp):
        """Forward-hook callback: keep a reference to the module's output."""
        self.features = outp

    def remove(self):
        """Unregister the forward hook (safe to call more than once)."""
        self.handle.remove()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Always unhook on exit; never suppress an exception from the body.
        self.remove()
        return False

In [8]:
# Train/val DataBunch from the `train` and `val` folders under `path`, with
# images resized to 224 and ImageNet normalisation stats attached.
# NOTE(review): 224 is repeated literally in the summary() call below — keep in sync.
data = ImageDataBunch.from_folder(path, train = 'train', valid = 'val', bs = hyper_params["batch_size"], size = 224).normalize(imagenet_stats)

In [6]:
# ResNet-34 learner on `data`. freeze() is fastai's call to stop training the
# body; NOTE(review): cnn_learner presumably already freezes by default — confirm.
learn = cnn_learner(data, models.resnet34, metrics = accuracy)
learn.freeze()

# `mdl[0]` is indexed below to pick the layers to hook. summary() prints the
# per-layer output shapes for a 3x224x224 input (torchsummary runs it on its
# default device, presumably CUDA — confirm).
mdl = learn.model
summary(mdl, (3, 224, 224))

In [8]:
# Children of the model body `mdl[0]` to hook. NOTE(review): these presumably
# map to relu, layer1..layer4 of the ResNet-34 body — verify against the
# summary above. save_features writes the i-th hook's output to stage i+1.
HOOKED_LAYER_IDX = (2, 4, 5, 6, 7)

# Unregister hooks left by a previous run of this cell; otherwise every
# re-run stacks another set of forward hooks on the same modules.
for _stale_hook in globals().get('sf', []):
    _stale_hook.remove()

sf = [SaveFeatures(mdl[0][k]) for k in HOOKED_LAYER_IDX]


In [153]:
def _feature_npy_path(item, split:str, stage:int) -> str:
    """Build the ``.npy`` path under ``<root>/<split>_feat/<stage>/`` for one image.

    Assumes the folder layout ``<root>/<split>/<class>/<file>``. File names
    follow the existing convention:

    * train: ``<class>/n01440764_10026.JPEG`` -> ``n01440764_10026.npy``
    * val:   ``n01440764/ILSVRC2012_val_00009111.JPEG`` -> ``n01440764_00009111.npy``

    The path is parsed by component rather than by splitting the whole path
    string on "train"/"val". The string split broke whenever those words
    appeared elsewhere in the path.
    """
    from pathlib import Path

    item = Path(item)
    root = item.parents[2]
    if split == "train":
        stem = item.name.split('.')[0]
    else:
        # Class folder name + the part of the file name after its last "val",
        # without the extension (of any length).
        stem = item.parent.name + item.name.rsplit('val', 1)[-1].rsplit('.', 1)[0]
    return str(root / f"{split}_feat" / str(stage) / f"{stem}.npy")


def save_features(mdl=mdl, data=data):
    """Run every train/val image through ``mdl`` and save the hooked activations.

    Each image gets one forward pass, which fills ``features`` on every hook in
    the global ``sf`` list. The i-th hook's output is written with
    ``save_torch`` to ``<root>/<split>_feat/<i+1>/<name>.npy``.

    Note: the defaults ``mdl=mdl, data=data`` are bound when this cell runs.
    If the learner or data is re-created, re-run this cell or pass them in.
    """
    mdl.eval()
    splits = (("train", data.train_ds), ("val", data.valid_ds))
    # Inference only. Without no_grad, autograd would record the forward pass
    # for any parameters that still require grad, and the stored hook outputs
    # would keep that graph alive.
    with torch.no_grad():
        for split, dataset in splits:
            for i in tqdm(range(len(dataset))):
                image = dataset[i][0]
                # Add a batch dimension without hard-coding the image size.
                mdl(image.data.unsqueeze(0).cuda())
                item = dataset.items[i]
                for stage, hook in enumerate(sf, start=1):
                    save_torch(_feature_npy_path(item, split, stage), hook.features)

In [154]:
# One forward pass per train+val image; writes the hooked activations to the
# *_feat/1..5 folders (about 15 minutes in the recorded run below).
save_features()

100%|██████████| 12894/12894 [14:45<00:00, 14.57it/s] 
100%|██████████| 500/500 [00:38<00:00, 13.12it/s]


In [11]:
def save_images():
    """Save every train/val image tensor as stage 0 of the feature folders.

    For each item, writes ``dataset[i][0].data`` to
    ``<root>/<split>_feat/0/<name>.npy``, using the same naming as the feature
    dumps:

    * train: the file stem
    * val: the class folder name + the file-name suffix after its last "val"

    NOTE(review): the saved tensor is whatever the dataset item returns —
    presumably before the DataBunch's batch-level normalisation; confirm.
    """
    from pathlib import Path

    for split, dataset in (("train", data.train_ds), ("val", data.valid_ds)):
        for i in tqdm(range(len(dataset))):
            # Parse path components instead of splitting the full path string
            # on "train"/"val", which breaks if those words occur elsewhere.
            item = Path(dataset.items[i])
            root = item.parents[2]  # layout: <root>/<split>/<class>/<file>
            if split == "train":
                stem = item.name.split('.')[0]
            else:
                stem = item.parent.name + item.name.rsplit('val', 1)[-1].rsplit('.', 1)[0]
            save_torch(str(root / f"{split}_feat" / "0" / f"{stem}.npy"), dataset[i][0].data)

In [12]:
# Writes the raw image tensors to the *_feat/0 folders (about 6 minutes in the recorded run below).
save_images()

100%|██████████| 12894/12894 [06:02<00:00, 35.61it/s]
100%|██████████| 500/500 [00:12<00:00, 40.79it/s]
