In [None]:
def select_indexes(max_samples_per_class, trng_prob, in_frame):
    """Draw a random per-class subset of row positions from ``in_frame``.

    For each label in the module-level ``class_labels``, the positional indices
    of rows whose ``'damage'`` column equals that label are found. Then
    ``int(min(max_samples_per_class, count * trng_prob))`` of them are sampled
    without replacement via the module-level ``rng``.

    Returns:
        A flat Python list of positional row indices, grouped in
        ``class_labels`` order. The positions come from ``np.where``, so they
        index the frame positionally (``.iloc``), not by index label.
    """
    damage_col = in_frame['damage']
    picked = []
    for label in class_labels:
        positions = np.where(damage_col == label)[0]
        # Cap the per-class draw at max_samples_per_class; int() truncates.
        quota = int(np.min((max_samples_per_class, positions.size * trng_prob)))
        picked.extend(rng.choice(positions, quota, replace=False).tolist())
    return picked

In [None]:
def generate_dataset(in_frame):
    """Build an image-classification ``Dataset`` from the rows of ``in_frame``.

    Column position 2 of each row is substituted into the module-level
    ``image_path`` template to form the image path. Column position 1 is used
    as the row's label.
    NOTE(review): these positional columns are assumed to be the image-id and
    label columns of the source frame; confirm against the caller.

    Returns:
        A ``datasets.Dataset`` with an ``"image"`` column cast to ``Image()``
        and a ``"label"`` column cast to a ``ClassLabel`` named by the
        module-level ``class_labels``.
    """
    # Pull whole columns at once instead of one .iloc lookup per row per column.
    img_path = [image_path.format(img_id) for img_id in in_frame.iloc[:, 2]]
    img_label = in_frame.iloc[:, 1].tolist()

    dataset = Dataset.from_dict({"image": img_path, "label": img_label})
    dataset = dataset.cast_column("image", Image())  # by default decode = True
    # Let ClassLabel infer num_classes from the names. Hard-coding 5 raised a
    # ValueError whenever class_labels did not contain exactly five entries.
    dataset = dataset.cast_column("label", ClassLabel(names=list(class_labels)))
    return dataset

In [None]:
def compute_metrics(eval_pred):
    """Score argmax class predictions against the ground-truth labels.

    ``eval_pred`` exposes ``predictions`` (the model logits as a NumPy array)
    and ``label_ids`` (the ground-truth labels as a NumPy array). Presumably
    this is a ``transformers`` ``EvalPrediction``; verify against the caller.

    Returns:
        Whatever the module-level ``metric.compute`` returns for the predicted
        class ids and the reference labels.
    """
    logits, references = eval_pred.predictions, eval_pred.label_ids
    # Predicted class = index of the highest logit along axis 1.
    predicted = np.argmax(logits, axis=1)
    return metric.compute(predictions=predicted, references=references)

In [None]:
def preprocess_trng(batch):
    """Add training-transformed images to ``batch`` under ``"pixel_values"``.

    Each entry of ``batch["image"]`` is converted to RGB and passed through the
    module-level ``trng_transform``. ``batch`` is mutated in place and also
    returned.
    """
    rgb_images = (img.convert("RGB") for img in batch["image"])
    batch["pixel_values"] = list(map(trng_transform, rgb_images))
    return batch

In [None]:
def preprocess_val(batch):
    """Add validation-transformed images to ``batch`` under ``"pixel_values"``.

    Each entry of ``batch["image"]`` is converted to RGB and passed through the
    module-level ``val_transform``. ``batch`` is mutated in place and also
    returned.
    """
    rgb_images = (img.convert("RGB") for img in batch["image"])
    batch["pixel_values"] = list(map(val_transform, rgb_images))
    return batch

In [None]:
def img_collate_fn(examples):
    """Collate a list of examples into one model-input batch.

    Data collators form a batch from a list of dataset elements
    (https://huggingface.co/docs/transformers/main_classes/data_collator).
    The returned keys match the inputs of an image-classification model.

    Args:
        examples: sequence of dicts, each with a ``"pixel_values"`` tensor and
            a ``"label"`` value. The tensors must share a shape so they can be
            stacked.

    Returns:
        ``{"pixel_values": <tensors stacked along a new dim 0>,
        "labels": <tensor of the labels>}``.
    """
    images = [ex["pixel_values"] for ex in examples]
    stacked = torch.stack(images)
    targets = torch.tensor([ex["label"] for ex in examples])
    return {"pixel_values": stacked, "labels": targets}