In [80]:
import os
import random
import torch
import numpy as np
import pandas as pd
import shutil
# from analysis import *
import argparse
from sys import platform

In [81]:
# Seed every random number generator this notebook imports so that a
# Restart & Run All reproduces the same results.
SEED = 12
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)  # PyTorch ignores this call when CUDA is unavailable
np.random.seed(SEED)

# Report the CUDA setup. The per-device queries raise on a machine without
# CUDA, so they are only issued when a device is actually available.
cuda_available = torch.cuda.is_available()
print(f"Device Available: {cuda_available}")
print(f"Device Count: {torch.cuda.device_count()}")
if cuda_available:
    print(f"Current Device Index: {torch.cuda.current_device()}")
    print(f"Device Name: {torch.cuda.get_device_name(0)}")
else:
    print("Current Device Index: n/a (CUDA not available)")
    print("Device Name: n/a (CUDA not available)")


Device Available: True
Device Count: 1
Current Device Index: 0
Device Name: NVIDIA TITAN V


# Load image filenames from EEG dataset

In [82]:
# Pick the directory that holds the preprocessed EEG .pth files for this OS.
if platform == "linux" or platform == "linux2":
    torch_models_dir = r"/media/titan/AI Research1/Data/CVPR2017"
elif platform == "win32":
    torch_models_dir = r"D:\Data\CVPR2021-02785\CVPR2021-02785\preprocessed\torch_models"
else:
    # Fail here with a clear message instead of a NameError further down.
    raise RuntimeError(f"No torch_models_dir configured for platform {platform!r}")

# os.listdir() returns entries in arbitrary order, so sort before unpacking by
# position. For the six files in this cell's recorded output, lexicographic
# order matches the order of the names below. The unpack still requires the
# directory to contain exactly six entries.
torch_model_files = sorted(os.listdir(torch_models_dir))
block_splits_all, block_splits_single, eeg_14_70, eeg_55_95, eeg_5_95, eeg_raw = torch_model_files
print(torch_model_files)

['block_splits_by_image_all.pth', 'block_splits_by_image_single.pth', 'eeg_14_70_std.pth', 'eeg_55_95_std.pth', 'eeg_5_95_std.pth', 'eeg_signals_raw_with_mean_std.pth']


In [83]:
# Full path of the file bound to eeg_5_95 inside torch_models_dir; printed so
# the reader can confirm which file the next cell loads.
eeg_path = os.path.join(torch_models_dir, eeg_5_95)

print(eeg_path)

/media/titan/AI Research1/Data/CVPR2017/eeg_5_95_std.pth


In [84]:
# NOTE(review): torch.load unpickles the file, so only point it at trusted data.
eeg_dataset = torch.load(eeg_path)
# Positional unpack of the loaded mapping: it must hold exactly three entries,
# taken in the mapping's own key order.
dataset, labels, images = eeg_dataset.values()

In [85]:
# Quick sanity check: number of labels, number of image ids, and the first
# image id, one per line.
print(len(labels), len(images), images[0], sep="\n")

40
1996
n02951358_31190


In [90]:
# For every image id, take the part before the first underscore as its class
# and append the .JPEG extension to get its filename.
class_list = [image_id.partition("_")[0] for image_id in images]
image_idx_list = ["{}.JPEG".format(image_id) for image_id in images]

# One row per image, kept in the order of `images`, and saved without the index.
df_imagenet = pd.DataFrame(
    data={"class": class_list, "image_filename": image_idx_list}
)
df_imagenet.to_csv(path_or_buf='imagenet_filenames_original.csv', index=False)
df_imagenet.head()

Unnamed: 0,class,image_filename
0,n02951358,n02951358_31190.JPEG
1,n03452741,n03452741_16744.JPEG
2,n04069434,n04069434_10318.JPEG
3,n02951358,n02951358_34807.JPEG
4,n03452741,n03452741_5499.JPEG


In [91]:
# Rebind df_imagenet to a copy ordered by class and save that ordering to its
# own CSV. Later cells read this sorted frame.
df_imagenet = df_imagenet.sort_values("class")
df_imagenet.to_csv('imagenet_filenames_sorted.csv', index=False)

In [92]:
# Images per class. The recorded output lists 40 classes: 36 with 50 images
# and 4 with 49 (1996 rows in total).
df_imagenet.value_counts('class')

class
n02106662    50
n03888257    50
n03584829    50
n03590841    50
n03709823    50
n03773504    50
n03775071    50
n03792782    50
n03792972    50
n03982430    50
n02124075    50
n04044716    50
n04069434    50
n04086273    50
n04120489    50
n07753592    50
n07873807    50
n11939491    50
n03452741    50
n03445777    50
n02951358    50
n02492035    50
n03272562    50
n03272010    50
n03197337    50
n03180011    50
n03100240    50
n03063599    50
n02992529    50
n02281787    50
n03297495    50
n02389026    50
n02690373    50
n02607072    50
n02510455    50
n02504458    50
n03376595    49
n02906734    49
n03877472    49
n13054560    49
Name: count, dtype: int64

In [93]:
# Map each class to the list of its image filenames. Keys follow the order
# returned by value_counts("class").
dict_imagenet = {
    class_name: df_imagenet.loc[
        df_imagenet["class"] == class_name, "image_filename"
    ].to_list()
    for class_name in df_imagenet.value_counts("class").index
}

# Extract images from ImageNet dataset

In [98]:
# Directory listed below for per-class sub-folders and used as the source of
# the image copies. NOTE(review): machine-specific absolute path; adjust per host.
imagenet_dir = r"/media/titan/AI Research1/Data/imagenet/ILSVRC/Data/CLS-LOC/train"

In [94]:

# Verify that every class referenced by the EEG dataset has a folder in
# imagenet train/, and report the ones that do not.
imagenet_classes = set(os.listdir(imagenet_dir))
print(f"Number of classes in imagenet train/: {len(imagenet_classes)}")

missing_classes = [name for name in dict_imagenet if name not in imagenet_classes]
for missing in missing_classes:
    print(f"Class {missing} not found in imagenet classes")

count = len(missing_classes)
print(f"Num of not found classes: {count}")

Number of classes in imagenet train/: 747
Class n04044716 not found in imagenet classes
Class n04069434 not found in imagenet classes
Class n04086273 not found in imagenet classes
Class n04120489 not found in imagenet classes
Class n07753592 not found in imagenet classes
Class n07873807 not found in imagenet classes
Class n11939491 not found in imagenet classes
Class n13054560 not found in imagenet classes
Num of not found classes: 8


In [95]:
# Flat destination folder that receives the copied images.
# NOTE(review): machine-specific absolute path; adjust per host.
extract_dir = '/home/titan/GithubClonedRepo/EEG-Research/Dataset/imagenet'

In [99]:
# Copy every EEG-referenced image whose class folder exists under imagenet_dir
# into the flat extract_dir. Files already present are skipped, so the cell
# can be re-run to resume after an interruption.
os.makedirs(extract_dir, exist_ok=True)  # shutil.copy does not create the destination folder
for cls, filenames in dict_imagenet.items():
    if cls not in imagenet_classes:
        continue  # no folder for this class in the local ImageNet copy
    for filename in filenames:
        dest_path = os.path.join(extract_dir, filename)
        if os.path.exists(dest_path):
            continue
        # Copy under a temporary name and rename only on success. An
        # interrupted copy then never leaves a truncated file at dest_path
        # that the existence check above would skip on the next run.
        tmp_path = dest_path + ".part"
        try:
            shutil.copy(os.path.join(imagenet_dir, cls, filename), tmp_path)
            os.replace(tmp_path, dest_path)
        finally:
            # Remove the leftover temporary file if the copy did not complete.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

KeyboardInterrupt: 

In [100]:
# Number of entries currently in extract_dir. The recorded value comes from a
# run in which the copy cell above was interrupted.
print(len(os.listdir(extract_dir)))

1222
