In [None]:
def select_indexes(max_samples_per_class, trng_prob, in_frame):
    """Draw a random per-class subset of row positions from ``in_frame``.

    For every label in the module-level ``class_labels`` (in that order), the
    rows whose ``'damage'`` value equals the label are located, and
    ``int(min(max_samples_per_class, n_class_rows * trng_prob))`` of them are
    drawn without replacement via the module-level ``rng``.

    Args:
        max_samples_per_class: Cap on the number of rows taken from one class.
        trng_prob: Multiplier applied to each class's row count before capping.
        in_frame: DataFrame containing a ``'damage'`` column.

    Returns:
        list: 0-based positional row indices (suitable for ``.iloc``), grouped
        by class in ``class_labels`` order.
    """
    damage_col = in_frame['damage']

    def draw_for(label):
        # np.where(...)[0] gives positions, not DataFrame index labels.
        positions = np.where(damage_col == label)[0]
        quota = int(np.min([max_samples_per_class, len(positions) * trng_prob]))
        return rng.choice(positions, quota, replace=False).tolist()

    # Classes are visited in class_labels order, so RNG draws happen in the same sequence.
    return [pos for label in class_labels for pos in draw_for(label)]

In [None]:
def _build_image_dataset(paths, labels):
    """Return a ``Dataset`` with a decoded ``Image`` column and a 5-class ``ClassLabel`` column."""
    ds = Dataset.from_dict({"image": paths, "label": labels})
    ds = ds.cast_column("image", Image())  # by default decode = True
    return ds.cast_column("label", ClassLabel(num_classes=5, names=class_labels, id=None))


def generate_dataset(in_frame, start_idx, end_idx, other_path, nd_path, image_exist, save_suffix):
    """Build an image-classification ``Dataset`` from ``in_frame``, optionally
    prepending extra (augmented) 'other' and 'ND' images.

    Args:
        in_frame: DataFrame whose column 2 is substituted into the module-level
            ``image_path`` template to form each image path, and whose column 1
            is used as the label (presumably a class id or a name from
            ``class_labels`` -- confirm against the caller).
        start_idx: First numeric suffix of the extra images (inclusive).
        end_idx: Last numeric suffix of the extra images (exclusive); the
            ``end_idx - start_idx`` extra images are added for each of the
            'other' and 'ND' classes.
        other_path: Prefix for extra 'other' images, read as ``f"{other_path}{k}.jpg"``.
        nd_path: Prefix for extra 'ND' images, read as ``f"{nd_path}{k}.jpg"``.
            NOTE(review): both prefixes presumably point at the files
            ``augment_data`` writes -- verify.
        image_exist: When falsy, the extra images are first generated with
            ``augment_data`` and the chosen source file names are written to
            ``Other-Augment-<save_suffix>.txt`` and ``ND-Augment-<save_suffix>.txt``.
        save_suffix: Suffix used in those two log file names.

    Returns:
        Dataset: the extra images (if any) followed by the rows of ``in_frame``.

    Raises:
        ValueError: if ``end_idx < start_idx``, or if extra images are requested
            and 'other' or 'ND' is missing from ``class_labels``.
    """
    num_extra_images = end_idx - start_idx
    if num_extra_images < 0:
        raise ValueError(f"end_idx ({end_idx}) must be >= start_idx ({start_idx})")

    # Vectorised column extraction instead of a per-row .iloc loop.
    dataset = _build_image_dataset(
        in_frame.iloc[:, 2].map(image_path.format).tolist(),
        in_frame.iloc[:, 1].tolist(),
    )

    if num_extra_images == 0:
        return dataset

    # Derive the class ids from class_labels rather than hard-coding 3/4, so the
    # extra images stay correctly labelled if the label order ever changes.
    # Looked up before augmenting so a bad label list fails before any files are written.
    label_names = list(class_labels)
    other_label = label_names.index('other')
    nd_label = label_names.index('ND')

    if not image_exist:
        file_names = augment_data('other', in_frame, start_idx, num_extra_images)
        pd.DataFrame(file_names).to_csv(f"Other-Augment-{save_suffix}.txt", index=False, sep=",")

        file_names = augment_data('ND', in_frame, start_idx, num_extra_images)
        pd.DataFrame(file_names).to_csv(f"ND-Augment-{save_suffix}.txt", index=False, sep=",")

    suffixes = range(start_idx, end_idx)
    extra_images = [f"{other_path}{k}.jpg" for k in suffixes] + [f"{nd_path}{k}.jpg" for k in suffixes]
    extra_labels = [other_label] * num_extra_images + [nd_label] * num_extra_images

    extra_dataset = _build_image_dataset(extra_images, extra_labels)

    # Extra images come first, matching the original ordering.
    return concatenate_datasets([extra_dataset, dataset])

In [None]:
def augment_data(class_type, in_frame, start_idx, num_extra_images):
    """Write ``num_extra_images`` augmented images of class ``class_type`` to disk.

    Source images are drawn at random from the rows of ``in_frame`` whose
    ``'damage'`` value equals ``class_type``. ``rng.choice`` is called without
    ``replace=False``, so a source image may be drawn more than once. Each
    source is opened via the module-level ``image_path`` template, passed
    through ``aug_transform`` and converted back with ``topilImage``. It is then
    saved as JPEG (quality 95) to ``aug_image_path`` formatted with
    ``"<class_type>/<class_type>_<start_idx + i>.jpg"``.

    Args:
        class_type: Value of the ``'damage'`` column to augment (e.g. 'other', 'ND').
        in_frame: DataFrame with a ``'damage'`` column; column 2 holds the value
            substituted into ``image_path``.
        start_idx: Numeric suffix of the first output file.
        num_extra_images: Number of augmented images to write.

    Returns:
        The array of source file names that were drawn, in output order.

    Raises:
        ValueError: if images are requested but no row has ``class_type``.
    """
    class_idx = np.where(in_frame['damage'] == class_type)[0]
    if len(class_idx) == 0 and num_extra_images > 0:
        raise ValueError(f"no rows with damage == {class_type!r} to augment")

    img_files = rng.choice(in_frame.iloc[class_idx, 2], num_extra_images)

    for offset, file_name in enumerate(img_files):
        rel_path = f"{class_type}/{class_type}_{start_idx + offset}.jpg"
        # Context manager closes the source file handle; the original leaked one per image.
        with pil.open(image_path.format(file_name)) as img:
            # JPEG cannot store RGBA/P images; convert as preprocess_img does.
            mod_img = aug_transform(img.convert("RGB"))
        topilImage(mod_img).save(aug_image_path.format(rel_path), quality=95)

    return img_files

In [None]:
def compute_metrics(eval_pred):
    """Score model outputs with the module-level ``metric``.

    ``eval_pred`` exposes ``predictions`` (model logits, one row per sample)
    and ``label_ids`` (ground-truth labels). The predicted class for each
    sample is the arg-max over axis 1 of the logits. The result of
    ``metric.compute`` is returned unchanged.
    """
    logits = eval_pred.predictions
    predicted_classes = np.argmax(logits, axis=1)
    ground_truth = eval_pred.label_ids
    return metric.compute(predictions=predicted_classes, references=ground_truth)

In [None]:
def preprocess_img(batch):
    """Add a ``"pixel_values"`` entry to ``batch`` and return it.

    ``batch`` is a dict mapping column names to lists. Each element of
    ``batch["image"]`` is converted to RGB and passed through the module-level
    ``img_transform``. The results are stored, in order, under
    ``batch["pixel_values"]``. The dict is modified in place and also returned.
    """
    rgb_images = (picture.convert("RGB") for picture in batch["image"])
    batch["pixel_values"] = list(map(img_transform, rgb_images))
    return batch

In [None]:
def img_collate_fn(examples):
    """Collate a list of dataset rows into one model batch.

    Each row must provide a ``"pixel_values"`` tensor and a ``"label"``. The
    tensors are stacked along a new leading dimension, and the labels are
    gathered into a single tensor under the key ``"labels"``.
    """
    pixel_batch = torch.stack([sample["pixel_values"] for sample in examples])
    label_batch = torch.tensor([sample["label"] for sample in examples])
    return dict(pixel_values=pixel_batch, labels=label_batch)