# Feature Extractor

In [1]:
import tensorflow as tf
tf.random.set_seed(221)

import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import torch

# Set to True to recompute ResNet-18 features from the raw CIFAR-10 pickles
# (requires timm and the dataset files on disk); False loads the cached
# extracted_features_train.csv / extracted_features_test.csv instead.
extract_feature = False


Helper: load a pickled CIFAR-10 batch

In [2]:
def unpickle(file, encoding='bytes'):
    """Load one pickled CIFAR batch file and return the deserialized object.

    Parameters
    ----------
    file : str or path-like
        Path to the pickle file.
    encoding : str, optional
        Forwarded to ``pickle.load``. The default ``'bytes'`` keeps 8-bit
        strings from Python 2 pickles as ``bytes``; this is why the dict keys
        used elsewhere in this notebook are ``b'data'``, ``b'labels'``, etc.

    Returns
    -------
    object
        The unpickled object (for the CIFAR batches used here, a dict).

    Notes
    -----
    ``pickle.load`` can execute arbitrary code -- only call this on trusted files.
    """
    import pickle
    with open(file, 'rb') as fo:
        # Named `batch` rather than `dict` to avoid shadowing the builtin.
        batch = pickle.load(fo, encoding=encoding)
    return batch

Load a pretrained ResNet-18 (timm) as the feature extractor

In [3]:
if extract_feature:
    import timm

    # Pretrained timm ResNet-18 with the classifier head removed
    # (num_classes=0), switched to eval mode for inference.
    model = timm.create_model(
        'resnet18.a1_in1k',
        pretrained=True,
        num_classes=0,
    ).eval()

    # Preprocessing (resize / normalization) resolved from the model's own
    # data config, in inference mode.
    data_config = timm.data.resolve_model_data_config(model)
    transforms = timm.data.create_transform(is_training=False, **data_config)


Sanity check: display one image from CIFAR-10

In [4]:
if extract_feature:
    # Sanity check: turn the first row of one training batch back into an
    # image and show it with its filename and integer class label.
    # Uses `sample_batch` (not `test`) because `test` is reused later for the
    # test-set feature DataFrame.
    sample_batch = unpickle('E:/Work/DS/Datasets/Raw/cifar-10-batches-py/train/data_batch_1')

    # Each row is reshaped to (3, 32, 32) (channel-first) and transposed to
    # (32, 32, 3), the channel-last layout plt.imshow expects.
    img = sample_batch[b'data'][0].reshape((3, 32, 32)).transpose(1, 2, 0).astype("uint8")

    filename = sample_batch[b'filenames'][0]
    if isinstance(filename, bytes):  # unpickle(encoding='bytes') yields bytes strings
        filename = filename.decode()
    # The class label lives in b'labels'; b'filenames' is only the source file name.
    label = sample_batch[b'labels'][0]

    plt.imshow(img)
    plt.title(f'{filename} (label {label})')
    plt.show()

Load the CIFAR-10 training batches

In [5]:
import pandas as pd
if extract_feature:
    # Directory holding the pickled CIFAR-10 training batch files
    # (e.g. data_batch_1). Every file in it is unpickled, so it must contain
    # only batch files -- NOTE(review): confirm no other files live here.
    base_dir = 'E:/Work/DS/Datasets/Raw/cifar-10-batches-py/train'

    # os.listdir returns entries in arbitrary order; sorting makes the batch
    # order (and therefore the row order of the extracted features) reproducible.
    cifar10 = [
        unpickle(os.path.join(base_dir, batch_file_name))
        for batch_file_name in sorted(os.listdir(base_dir))
    ]


Preprocess CIFAR-10 images and run them through the feature extractor

In [6]:
def process_minibatch(minibatch, feature_model=None, preprocess=None):
    """Run a minibatch of raw CIFAR image rows through the feature extractor.

    Parameters
    ----------
    minibatch : sequence of numpy arrays
        Raw image rows; each must reshape to (3, 32, 32) (channel-first).
        Presumably uint8, as stored in the CIFAR pickles -- TODO confirm for
        other data sources, since PIL needs an image-compatible dtype.
    feature_model : callable, optional
        Model applied to the stacked, preprocessed batch. Defaults to the
        global ``model`` created in the ResNet-18 cell.
    preprocess : callable, optional
        Per-image transform from a PIL image to a tensor. Defaults to the
        global ``transforms`` created in the ResNet-18 cell.

    Returns
    -------
    torch.Tensor
        One feature row per input image (512 features per row for the
        ResNet-18 used in this notebook).
    """
    if feature_model is None:
        feature_model = model
    if preprocess is None:
        preprocess = transforms

    # Channel-first row -> (32, 32, 3) array -> PIL image -> model-ready tensor.
    batch_data = torch.stack([
        preprocess(Image.fromarray(image.reshape(3, 32, 32).transpose(1, 2, 0)))
        for image in minibatch
    ])
    # Inference only: no_grad skips building an autograd graph, which saves
    # memory and time (callers never backpropagate through these features).
    with torch.no_grad():
        processed_data = feature_model(batch_data)
    return processed_data

In [7]:
def feature_extraction(file, minibatch_size=200, extractor=None):
    """Extract features for every image in a list of unpickled CIFAR batches.

    Parameters
    ----------
    file : iterable of dict
        Unpickled CIFAR batches. Each must have ``b'data'`` (one image row per
        entry) and ``b'labels'`` (one label per row).
    minibatch_size : int, optional
        Number of images passed to the extractor at once (default 200).
    extractor : callable, optional
        Maps a slice of ``b'data'`` to a 2-D feature array or tensor with one
        row per image. Defaults to ``process_minibatch``.

    Returns
    -------
    pandas.DataFrame
        A ``Class`` column followed by ``features_0`` ... ``features_{d-1}``,
        one row per image, in batch order. If ``file`` is empty, an empty
        frame with only a ``Class`` column is returned.

    Raises
    ------
    ValueError
        If ``minibatch_size`` is not positive.
    """
    if minibatch_size <= 0:
        raise ValueError('minibatch_size must be positive, got %r' % (minibatch_size,))
    if extractor is None:
        extractor = process_minibatch

    frames = []
    for k, batch in enumerate(file):
        print('processing batch', k + 1)
        images = batch[b'data']
        labels = batch[b'labels']

        # Take the image count from the data instead of assuming 10000, and
        # round up so a trailing partial minibatch is not silently dropped.
        n_images = len(images)
        n_minibatches = -(-n_images // minibatch_size)
        for mb_index, start in enumerate(range(0, n_images, minibatch_size)):
            print('processing minibatch %d out of %d' % (mb_index + 1, n_minibatches))
            stop = start + minibatch_size

            features = extractor(images[start:stop])
            if hasattr(features, 'detach'):  # torch tensor -> numpy
                features = features.detach().cpu().numpy()
            features = np.asarray(features)

            # Feature width comes from the extractor output, not a hard-coded 512.
            minibatch_df = pd.DataFrame(
                features,
                columns=[f'features_{j}' for j in range(features.shape[1])],
            )
            # 'Class' first, matching the column order of the saved CSVs.
            minibatch_df.insert(0, 'Class', labels[start:stop])
            frames.append(minibatch_df)

    if not frames:
        return pd.DataFrame(columns=['Class'])
    # One concat at the end; concatenating inside the loop is quadratic.
    return pd.concat(frames, ignore_index=True)

In [8]:
if not extract_feature:
    # Reuse the features cached by a previous extraction run.
    output_df = pd.read_csv('extracted_features_train.csv')
else:
    # Recompute features from the raw training batches, then cache them.
    output_df = feature_extraction(cifar10)
    output_df.to_csv('extracted_features_train.csv', index=False)

In [9]:
# Preview the first rows of the training features, then their (rows, columns) shape.
print(output_df.head(), output_df.shape, sep='\n')

   Class  features_0  features_1  features_2  features_3  features_4  \
0      6    0.040742    0.035822    0.222851    0.015412    0.486823   
1      9    0.672252    0.000000    0.342131    0.000000    0.073739   
2      9    0.301723    0.000000    0.134019    0.081969    0.069278   
3      4    0.000000    0.000000    0.128621    0.001375    0.219229   
4      1    0.142378    0.000000    0.067117    0.044040    0.281984   

   features_5  features_6  features_7  features_8  ...  features_502  \
0    0.015395    0.189043    0.500762    0.499615  ...      0.109454   
1    0.000000    0.123008    0.251687    0.029337  ...      0.135525   
2    0.001227    0.045751    0.047892    0.386089  ...      0.219279   
3    0.063678    0.005415    0.165327    0.888420  ...      0.119232   
4    0.221451    0.032381    0.138317    0.583788  ...      0.000000   

   features_503  features_504  features_505  features_506  features_507  \
0      0.391888      0.035869      0.041942      0.680499  

Repeat the feature extraction for the test set

In [10]:
if not extract_feature:
    # Reuse the test-set features cached by a previous extraction run.
    test = pd.read_csv('extracted_features_test.csv')
else:
    # NOTE(review): 'test_batch' is resolved relative to the working directory,
    # unlike the absolute training path used above -- confirm the file is there.
    cifar10_test = [unpickle('test_batch')]
    test = feature_extraction(cifar10_test)
    test.to_csv('extracted_features_test.csv', index=False)

In [11]:
# Separate the label column from the feature matrix.
y_train = output_df['Class']
transformed_X_train = output_df.drop(columns=['Class'])

In [12]:
# First ten rows of the training feature matrix.
transformed_X_train.iloc[:10]

Unnamed: 0,features_0,features_1,features_2,features_3,features_4,features_5,features_6,features_7,features_8,features_9,...,features_502,features_503,features_504,features_505,features_506,features_507,features_508,features_509,features_510,features_511
0,0.040742,0.035822,0.222851,0.015412,0.486823,0.015395,0.189043,0.500762,0.499615,0.259726,...,0.109454,0.391888,0.035869,0.041942,0.680499,0.61715,0.306534,0.007564,0.148896,0.012736
1,0.672252,0.0,0.342131,0.0,0.073739,0.0,0.123008,0.251687,0.029337,0.153819,...,0.135525,1.281373,0.405868,0.0,0.058778,0.07456,0.063702,1.13217,0.026337,0.0
2,0.301723,0.0,0.134019,0.081969,0.069278,0.001227,0.045751,0.047892,0.386089,0.0,...,0.219279,0.207228,0.072956,0.002798,0.102566,0.043786,0.204728,0.508546,0.370723,0.291213
3,0.0,0.0,0.128621,0.001375,0.219229,0.063678,0.005415,0.165327,0.88842,0.382979,...,0.119232,0.244319,0.048261,0.030268,0.302496,0.018989,0.047324,0.234554,0.118996,0.113055
4,0.142378,0.0,0.067117,0.04404,0.281984,0.221451,0.032381,0.138317,0.583788,0.302565,...,0.0,0.603535,0.303451,0.0,0.038988,0.058191,0.0,0.796191,0.0,0.107168
5,1.035197,0.058702,0.095312,0.03745,0.053576,0.01323,0.259931,0.056354,0.170153,0.235934,...,0.091722,0.405168,0.130961,0.093053,0.0,0.21382,0.491387,0.074986,0.036422,0.115289
6,0.273047,0.07476,0.267747,0.831988,0.431168,0.0,0.513849,0.082853,0.031651,0.264359,...,0.57858,0.111145,0.319857,0.160226,0.279058,0.053023,0.0,0.13941,0.300243,0.138246
7,0.001875,0.0,0.02555,0.014554,0.588634,0.062992,0.616284,0.03779,0.070613,0.0,...,0.098866,0.067091,0.008325,0.0,0.0,0.019947,0.597903,0.878241,0.015466,0.088653
8,0.037763,0.0,0.136463,0.37093,0.484324,0.004421,0.222002,0.478854,0.056549,0.032293,...,0.194435,0.268003,0.049738,0.03389,0.275827,0.030199,0.0,0.235028,0.012788,0.300287
9,0.023211,0.0,0.229256,0.159833,0.620174,0.016002,0.034277,0.158599,0.309702,0.08502,...,0.158672,0.374362,0.029272,0.097099,0.212243,0.047205,0.058019,0.053052,0.056359,0.052717
