In [None]:
import shutil, os
from glob import glob
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Few-shot split configuration
WAYS = 5            # number of target classes drawn for the few-shot task
SHOTS = 5           # images per target class kept in addition to the query images
SRC_SAMPLES = 20    # images copied per source class
SRC_THR = 30        # minimum image count for a class to be kept at all
TGT_SAMPLES = 55    # minimum image count for a class to be a target candidate

In [None]:
# glob returns entries in arbitrary (filesystem-dependent) order. Every class
# index used later is a position in this list, so sort it to make the indices
# -- and the seeded random selections built on them -- reproducible.
all_classes = sorted(glob('data_/PokemonData/*'))
print(f"total {len(all_classes)} classes")

In [None]:
# Number of .jpg images found in each class directory, in the same order as all_classes
data_per_class = [len(glob(class_path + '/*.jpg')) for class_path in all_classes]
fewest_images = min(data_per_class)
most_images = max(data_per_class)
print("Min", fewest_images, ",Max", most_images)

In [None]:
# Build (image_count, class_index) rows ordered by ascending image count
# (ties broken by class index); class_index is a position in all_classes.
count_index_pairs = zip(data_per_class, range(len(data_per_class)))
sorted_classID = np.array(sorted(count_index_pairs))

In [None]:
# How many classes have at least SRC_THR images
meets_src_thr = sorted_classID[:, 0] >= SRC_THR
sum(meets_src_thr)

In [None]:
# How many classes have at least TGT_SAMPLES images, i.e. can serve as target
meets_tgt_thr = sorted_classID[:, 0] >= TGT_SAMPLES
sum(meets_tgt_thr)

In [None]:
# Drop every class that has fewer than SRC_THR images
has_enough_images = sorted_classID[:, 0] >= SRC_THR
sorted_classID = sorted_classID[has_enough_images]
print(f"total {len(sorted_classID)} classes")

In [None]:
# Partition the remaining class indices by image count:
# fewer than TGT_SAMPLES -> source, otherwise -> target candidate.
image_counts = sorted_classID[:, 0]
class_indices = sorted_classID[:, 1]
is_target_candidate = image_counts >= TGT_SAMPLES
sourceID = class_indices[~is_target_candidate]
targetID = class_indices[is_target_candidate]
print(f"total {len(sorted_classID)} classes=source {len(sourceID)} + target {len(targetID)} classes")


In [None]:
# Draw WAYS distinct target candidates (seeded); the candidates that are not
# drawn are collected separately in other_targetID.
np.random.seed(2022)
candidate_positions = range(len(targetID))
chosen_ones = np.random.choice(candidate_positions, size=WAYS, replace=False)
final_targetID = set(targetID[chosen_ones])
other_targetID = set(targetID).difference(final_targetID)

In [None]:
# Map class indices back to directory paths.
# Source = classes below the target threshold plus the unchosen target candidates.
class_dirs = np.array(all_classes)
source_indices = list(set(sourceID) | other_targetID)
target_indices = list(final_targetID)
source_classes = class_dirs[source_indices]
target_classes = class_dirs[target_indices]

In [None]:
# Display the selected target class directories for a visual check
target_classes

In [None]:
# Create the source dataset: copy SRC_SAMPLES randomly chosen images of every
# source class into source/<class name>/.
np.random.seed(2022)
for class_dir in source_classes:
    label_name = os.path.basename(class_dir)
    # glob order is arbitrary; sort so the seeded draw below picks the same
    # files on every machine and every run.
    label_pics = sorted(glob(os.path.join(class_dir, "*.jpg")))
    # Pick some pictures (raises ValueError if the class has fewer than SRC_SAMPLES)
    picked_pics = np.random.choice(label_pics, size=SRC_SAMPLES, replace=False)
    # Copy those pictures; a separate name keeps the loop variable intact
    dest_dir = os.path.join("source", label_name)
    os.makedirs(dest_dir, exist_ok=True)
    for pic_path in picked_pics:
        shutil.copy(src=pic_path, dst=dest_dir)

In [None]:
# Number of query images per target class
Q_SAMPLES = 50
print(Q_SAMPLES)

In [None]:
# Create the target dataset: copy Q_SAMPLES + SHOTS randomly chosen images of
# every target class into target_support/<class name>/.

np.random.seed(2022)
for class_dir in target_classes:
    label_name = os.path.basename(class_dir)
    # glob order is arbitrary; sort so the seeded draw below picks the same
    # files on every machine and every run.
    label_pics = sorted(glob(os.path.join(class_dir, "*.jpg")))
    # Raises ValueError if the class has fewer than Q_SAMPLES + SHOTS images
    picked_pics = np.random.choice(label_pics, size=Q_SAMPLES + SHOTS, replace=False)
    # A separate name keeps the loop variable intact
    dest_dir = os.path.join("target_support", label_name)
    os.makedirs(dest_dir, exist_ok=True)
    for pic_path in picked_pics:
        shutil.copy(src=pic_path, dst=dest_dir)

In [None]:
# Move Q_SAMPLES images of every class in target_support/ into the flat
# target_query/ directory and record (filename, class id) for each moved file.
#
# Rows are collected in a list and turned into a DataFrame once:
# DataFrame.append was removed in pandas 2.0, and growing a frame row by row
# is quadratic.
answer_rows = []
# Sorted so that cls_id is a stable function of the class directory names and
# the seeded draw below is reproducible (glob order is arbitrary).
all_tgt_classes = sorted(glob('target_support/*'))
os.makedirs("target_query", exist_ok=True)
np.random.seed(1234)
for cls_id, class_dir in enumerate(all_tgt_classes):
    label_pics = sorted(glob(os.path.join(class_dir, "*.jpg")))
    query_pics = np.random.choice(label_pics, size=Q_SAMPLES, replace=False)
    for pic_path in query_pics:
        pic_path = str(pic_path)
        # NOTE(review): target_query/ is flat, so two classes containing files
        # with the same basename would make shutil.move raise -- confirm that
        # filenames are unique across classes.
        shutil.move(pic_path, "target_query")
        answer_rows.append({"filename": os.path.basename(pic_path), "ans": cls_id})
target_ans = pd.DataFrame(answer_rows, columns=["filename", "ans"])
target_ans = target_ans.reset_index(drop=True)

In [None]:
# Shuffle the answer rows (seeded through the global NumPy state) and write them out
np.random.seed(2022)
shuffled_ans = target_ans.sample(n=len(target_ans))
shuffled_ans.to_csv("ans.csv", index=False)