In [1]:
working_dir = "../."
dataset_h5_path = "/Users/andry/Documents/GitHub/lus-dl-framework/data/iclus/dataset.h5"
hospitaldict_path = "/Users/andry/Documents/GitHub/lus-dl-framework/data/iclus/hospitals-patients-dict.pkl"
libraries_dir = working_dir + "/libraries"


import warnings
import sys
import os
import glob
import pickle
import lightning as pl
from tabulate import tabulate
from torch.utils.data import Subset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch
import torchvision
from kornia import tensor_to_image
from transformers import ViTForImageClassification
from transformers import ViTImageProcessor
from transformers import ViTConfig


from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, DeviceStatsMonitor, ModelCheckpoint
from lightning.pytorch.tuner import Tuner


sys.path.append(working_dir)
from data_setup import HDF5Dataset, FrameTargetDataset, DataAugmentation

os.chdir(working_dir)
os.getcwd()


'/Users/andry/Documents/GitHub/lus-dl-framework'

# Dataset

In [13]:
import pickle
from torch.utils.data import Subset

train_ratio = 0.2
rseed = 21

In [14]:
dataset = HDF5Dataset(dataset_h5_path)

train_indices_path = os.path.dirname(dataset_h5_path) + f"/train_indices_{train_ratio}.pkl"
test_indices_path = os.path.dirname(dataset_h5_path) + f"/test_indices_{train_ratio}.pkl"

Serialized frame index map FOUND.

Loaded serialized data.


277 videos (58924 frames) loaded.


In [15]:
if os.path.exists(train_indices_path) and os.path.exists(test_indices_path):
    print("Loading pickled indices")
    with open(train_indices_path, 'rb') as train_pickle_file:
        train_indices = pickle.load(train_pickle_file)
    with open(test_indices_path, 'rb') as test_pickle_file:
        test_indices = pickle.load(test_pickle_file)
    # Create training and test subsets
    train_subset = Subset(dataset, train_indices)
    test_subset = Subset(dataset, test_indices)  
else:
    train_subset, test_subset, split_info, train_indices, test_indices = dataset.split_dataset(hospitaldict_path, 
                                                              rseed, 
                                                              train_ratio)
    print("Pickling sets...")
    
    # Pickle the indices
    with open(train_indices_path, 'wb') as train_pickle_file:
        pickle.dump(train_indices, train_pickle_file)
    with open(test_indices_path, 'wb') as test_pickle_file:
        pickle.dump(test_indices, test_pickle_file)

Loading pickled indices


In [16]:
test_subset_size = train_ratio/2
test_subset = Subset(test_subset, range(int(test_subset_size * len(test_indices))))
test_subset

<torch.utils.data.dataset.Subset at 0x1744d0710>

In [17]:
train_dataset = FrameTargetDataset(train_subset)
test_dataset = FrameTargetDataset(test_subset)

print(f"Train size: {len(train_dataset)}")
print(f"Test size: {len(test_dataset)}")

Train size: 8052
Test size: 5087


# Dataloader

In [18]:
def collate_fn(examples):
    frames = torch.stack([example[0] for example in examples])  # Extract the preprocessed frames
    scores = torch.tensor([example[1] for example in examples])  # Extract the scores
    return {"pixel_values": frames, "labels": scores}
    # return (frames, scores)



In [19]:
preprocess = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k', do_rescale=False)
transform = DataAugmentation()

# train_dataset.set_transform(preprocess)
# test_dataset.set_transform(preprocess)

In [22]:
batch_size = 16
train_dataloader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          num_workers=0,
                          pin_memory=True,
                          collate_fn=collate_fn, shuffle=False)

test_dataloader = DataLoader(test_dataset,
                          batch_size=batch_size,
                          pin_memory=True,
                          collate_fn=collate_fn)


In [21]:

def show_batch(data_loader, num_batches=1, win_size=(20, 20)):
    def _to_vis(data):
        # Ensure that pixel values are in the valid range [0, 1]
        data = torch.clamp(data, 0, 1)
        return tensor_to_image(torchvision.utils.make_grid(data, nrow=8))

    for batch_num, (imgs, labels) in enumerate(data_loader):
        if batch_num >= num_batches:
            break
        
        # Apply data augmentation to the batch (you need to define DataAugmentation function)
        imgs_aug = transform(imgs)

        # Create subplots for original and augmented images
        plt.figure(figsize=win_size)
        plt.subplot(1, 2, 1)
        plt.imshow(_to_vis(imgs))
        plt.title("Original Images")

        plt.subplot(1, 2, 2)
        plt.imshow(_to_vis(imgs_aug))
        plt.title("Augmented Images")

        plt.show()

# To display one batch from the training DataLoader
show_batch(train_dataloader, num_batches=3)

AttributeError: 'str' object has no attribute 'shape'

# Models

In [23]:
pretrained = True

configuration = {
    "num_labels": 4,
    "num_attention_heads": 4,
    "num_hidden_layers":4
}

id2label = {0: 'no', 1: 'yellow', 2: 'orange', 3: 'red'}
label2id = {"no": 0, "yellow": 1, "orange": 2, "red": 3}

if pretrained == True:
    vit = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224",
                                                            num_labels=4,
                                                            id2label=id2label,
                                                            label2id=label2id,
                                                            ignore_mismatched_sizes = True)
else:
    config = ViTConfig(**configuration)
    vit = ViTForImageClassification(config=config)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([4]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([4, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Model configuration

In [25]:
selected_model="google_vit"

hyperparameters = {
  "train_dataset": train_dataset,
  "test_dataset": test_dataset,
  "batch_size": 16,
  "lr": 0.0001,
  "optimizer": "sgd",
  "num_workers": 0,
  "pretrained": False,
  "configuration": configuration
}
# Instantiate lightning model
if selected_model == "google_vit":
  model = vit
elif selected_model == "resnet18":
  model =  ResNet18LightningModule(**hyperparameters)
elif selected_model == "beit": 
  model =  BEiTLightningModule(**hyperparameters)
else:
  raise ValueError("Invalid model name. Please choose either 'google_vit' or 'resnet18'.")

table_data = []
table_data.append(["MODEL HYPERPARAMETERS"])
table_data.append(["model", selected_model])
for key, value in hyperparameters.items():
    if key not in ["train_dataset", "test_dataset"]:
      table_data.append([key, value])

table = tabulate(table_data, headers="firstrow", tablefmt="fancy_grid")
print(table)

╒═══════════════╤═════════════════════════════════════════════════════════════════════╕
│               │ MODEL HYPERPARAMETERS                                               │
╞═══════════════╪═════════════════════════════════════════════════════════════════════╡
│ model         │ google_vit                                                          │
├───────────────┼─────────────────────────────────────────────────────────────────────┤
│ batch_size    │ 16                                                                  │
├───────────────┼─────────────────────────────────────────────────────────────────────┤
│ lr            │ 0.0001                                                              │
├───────────────┼─────────────────────────────────────────────────────────────────────┤
│ optimizer     │ sgd                                                                 │
├───────────────┼─────────────────────────────────────────────────────────────────────┤
│ num_workers   │ 0             

In [26]:
model.config

ViTConfig {
  "_name_or_path": "google/vit-base-patch16-224",
  "architectures": [
    "ViTForImageClassification"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "no",
    "1": "yellow",
    "2": "orange",
    "3": "red"
  },
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "no": 0,
    "orange": 2,
    "red": 3,
    "yellow": 1
  },
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.33.3"
}

# Trainer configuration

In [27]:
from transformers import TrainingArguments, Trainer
import numpy as np

# Logger configuration
name_trained = "pretrained_" if pretrained==True else ""
model_name = f"{name_trained}{selected_model}/{hyperparameters['optimizer']}/{hyperparameters['lr']}_{hyperparameters['batch_size']}"
logger = TensorBoardLogger("tb_logs", name=model_name)

args = TrainingArguments(
    model_name,
    remove_unused_columns=False,
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=hyperparameters['lr'],
    per_device_train_batch_size=hyperparameters['batch_size'],
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=hyperparameters['batch_size'],
    num_train_epochs=3,
    optim=hyperparameters['optimizer'],
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

In [28]:
from datasets import load_metric

metric = load_metric("accuracy")
# the compute_metrics function takes a Named Tuple as input:
# predictions, which are the logits of the model as Numpy arrays,
# and label_ids, which are the ground-truth labels as Numpy arrays.
def compute_metrics(eval_pred):
    """Computes accuracy on a batch of predictions"""
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)


  metric = load_metric("accuracy")


In [29]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

# Train

In [30]:
train_results = trainer.train()
# rest is optional but nice to have
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

  0%|          | 0/378 [00:00<?, ?it/s]



KeyboardInterrupt: 