## Data Preparation

In [28]:
from datasets import load_dataset

### Load Dataset

In [29]:
dataset = load_dataset("imagefolder", data_dir="./datasets/chest_xray")

Resolving data files: 100%|██████████| 5216/5216 [00:00<00:00, 34797.59it/s]
Resolving data files: 100%|██████████| 624/624 [00:00<00:00, 311836.73it/s]


In [30]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 5216
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 16
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 624
    })
})


### Setup Labels

In [31]:
labels = labels = dataset["train"].features["label"].names
print(labels)

['NORMAL', 'PNEUMONIA']


In [32]:
label2id, id2label = dict(), dict()

for i, label in enumerate(labels):
  label2id[i] = label
  id2label[label] = i

In [33]:
print(label2id)
print(id2label)

{0: 'NORMAL', 1: 'PNEUMONIA'}
{'NORMAL': 0, 'PNEUMONIA': 1}


### Transforming Data

In [34]:
from transformers import AutoImageProcessor
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

In [35]:
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

In [36]:
size = (image_processor.size["height"], image_processor.size["width"])
resizer = RandomResizedCrop(size)
normalize = Normalize(image_processor.image_mean, image_processor.image_std)

In [37]:
_transforms = Compose([resizer, ToTensor(), normalize])

In [38]:
def transforms(examples):
  examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
  del examples["image"]
  return examples

In [39]:
dataset  = dataset.with_transform(transforms)

In [40]:
print(dataset['train'])

Dataset({
    features: ['image', 'label'],
    num_rows: 5216
})


### Preparing metrics for the model

###