## Setup Device & Environment

In [1]:
from datasets import load_dataset
import numpy as np

import torch
from transformers import AutoImageProcessor
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

import evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda')

device

device(type='cuda')

## Data Preparation

### Load Dataset

In [3]:
raw_dataset = load_dataset("imagefolder", data_dir="./datasets/chest_xray")

Resolving data files: 100%|██████████| 5216/5216 [00:00<00:00, 18982.56it/s]
Resolving data files: 100%|██████████| 624/624 [00:00<?, ?it/s]


In [4]:
print(raw_dataset)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 5216
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 16
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 624
    })
})


### Setup Labels

In [5]:
labels = labels = raw_dataset["train"].features["label"].names
print(labels)

['NORMAL', 'PNEUMONIA']


In [6]:
label2id, id2label = dict(), dict()

for i, label in enumerate(labels):
  label2id[i] = label
  id2label[label] = i

In [7]:
print(label2id)
print(id2label)

{0: 'NORMAL', 1: 'PNEUMONIA'}
{'NORMAL': 0, 'PNEUMONIA': 1}


### Transforming Data

In [8]:
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

In [9]:
size = (image_processor.size["height"], image_processor.size["width"])
resizer = RandomResizedCrop(size)
normalize = Normalize(image_processor.image_mean, image_processor.image_std)

In [10]:
_transforms = Compose([resizer, ToTensor(), normalize])

In [11]:
def transforms(examples):
  examples["image"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
  # del examples["image"]
  return examples

In [12]:
dataset = raw_dataset.with_transform(transforms)

In [13]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 5216
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 16
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 624
    })
})


### Preparing metrics for the model

In [14]:
accuracy = evaluate.load("accuracy")

In [15]:
def compute_metrics(eval_pred):
  predictions = np.argmax(eval_pred.predictions, axis=1)
  return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)

## Loading Best Model

In [16]:
from own_model import CompositeModel
from training_own import get_model_params
from torch import nn

own_layer = nn.Sequential(
  nn.LayerNorm(1000),
  nn.Linear(1000, 64),
  nn.ReLU(),
  nn.Linear(64, 128),
  nn.Dropout(0.1),
  nn.ReLU(),
  nn.Linear(128, 64),
  nn.Dropout(0.25),
  nn.ReLU(),
  nn.Linear(64, 2)
)

model = CompositeModel(own_layer)
model.eval()
# saved
get_model_params(model.additional_layers)

82770

## Training Setup

### Data Loaders

In [17]:
training_loader = torch.utils.data.DataLoader(dataset['train'], batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset['test'], batch_size=16, shuffle=True)
validation_loader = torch.utils.data.DataLoader(dataset['test'], batch_size=16, shuffle=True)

### Loss Function

In [18]:
loss_fn = torch.nn.CrossEntropyLoss()

### Optimizer

In [19]:
# Optimizers specified in the torch.optim package
optimizer = torch.optim.Adam(model.additional_layers.parameters(), lr=1e-5)

In [20]:
from torchmetrics import Accuracy

accuracy_metric = Accuracy(task='multiclass', num_classes=2).to(device)

## Training Model

In [21]:
from training_own import train_one_epoch, train_many_epochs

In [22]:
from training_own import evaluate_model

N_EXAMPLES = 25

testing_fragment = dataset['test'].shuffle(seed=1)[:N_EXAMPLES]

evaluate_model(model, testing_fragment=testing_fragment)

index 0: true/predicted: 0/0
index 1: true/predicted: 1/0
index 2: true/predicted: 1/0
index 3: true/predicted: 1/0
index 4: true/predicted: 1/0
index 5: true/predicted: 0/0
index 6: true/predicted: 0/0
index 7: true/predicted: 0/0
index 8: true/predicted: 0/0
index 9: true/predicted: 1/0
index 10: true/predicted: 0/0
index 11: true/predicted: 1/0
index 12: true/predicted: 1/0
index 13: true/predicted: 1/0
index 14: true/predicted: 0/0
index 15: true/predicted: 1/0
index 16: true/predicted: 1/0
index 17: true/predicted: 1/0
index 18: true/predicted: 1/0
index 19: true/predicted: 0/0
index 20: true/predicted: 1/0
index 21: true/predicted: 1/0
index 22: true/predicted: 0/0
index 23: true/predicted: 1/0
index 24: true/predicted: 1/0
testing accuracy: 0.36


In [24]:
train_many_epochs(
  10,
  model=model,
  training_loader=training_loader,
  validation_loader=validation_loader,
  optimizer=optimizer,
  loss_fn=loss_fn,
  accuracy_metric=accuracy_metric,
  cuda_device=device,
  epoch_index=0,
  logging_frequency=50
)

EPOCH 1:
  batch 50 loss: 0.5416188043355942 training_accuracy: 0.7400000095367432
  batch 100 loss: 0.5228772228956222 training_accuracy: 0.75
  batch 150 loss: 0.5333236253261566 training_accuracy: 0.7450000047683716
  batch 200 loss: 0.5165081006288529 training_accuracy: 0.75
  batch 250 loss: 0.5390014296770096 training_accuracy: 0.7362499833106995
  batch 300 loss: 0.529529277086258 training_accuracy: 0.7324999570846558
LOSS train 0.529529277086258 valid 0.6311175227165222 ACCURACY validation 0.625
EPOCH 2:
  batch 50 loss: 0.49056779861450195 training_accuracy: 0.7649999856948853
  batch 100 loss: 0.5032666021585465 training_accuracy: 0.7437499761581421
  batch 150 loss: 0.4890257292985916 training_accuracy: 0.7437499761581421
  batch 200 loss: 0.51743632376194 training_accuracy: 0.7262499928474426
  batch 250 loss: 0.49279205620288846 training_accuracy: 0.7475000023841858
  batch 300 loss: 0.4918600368499756 training_accuracy: 0.7462499737739563
LOSS train 0.4918600368499756 val

In [None]:
train_many_epochs(
  2,
  model=model,
  training_loader=training_loader,
  validation_loader=validation_loader,
  optimizer=optimizer,
  loss_fn=loss_fn,
  accuracy_metric=accuracy_metric,
  cuda_device=device,
  epoch_index=0,
  logging_frequency=50
)

###