In [1]:
import lightning as L
import torch
from torch import nn
# Trade a little float32 matmul precision for speed on hardware that supports it.
torch.set_float32_matmul_precision(precision="medium")

# Data Prep

In [2]:
# Device-agnostic setup: prefer the GPU whenever PyTorch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

In [3]:
import os
import shutil
from sklearn.model_selection import train_test_split

In [4]:
# Seed the global random number generators (via Lightning) at the start of the run.
L.seed_everything(seed=123)

Seed set to 123


123

In [5]:
# Prepare dataset for training: split every class folder 80/20 into
# <dest_dir>/train/<class> and <dest_dir>/test/<class>.
# NOTE: files left in train/ or test/ by an earlier run are NOT removed, so if the
# split changes (ratio, random_state, source files) clear those folders first to
# avoid the same image ending up in both splits.
image_path = '../data/CombinedAll'
dest_dir = '../data/CombinedAll'
categories = ['parkinson', 'sehat']

for category in categories:
    category_path = os.path.join(image_path, category)
    # os.listdir returns names in arbitrary order; sort so the split depends only
    # on random_state. Skip sub-directories, which shutil.copy cannot copy.
    files = sorted(
        name for name in os.listdir(category_path)
        if os.path.isfile(os.path.join(category_path, name))
    )
    train_files, test_files = train_test_split(files, test_size=0.2, random_state=42)

    for split, split_files in (('train', train_files), ('test', test_files)):
        split_dir = os.path.join(dest_dir, split, category)
        os.makedirs(split_dir, exist_ok=True)
        for file in split_files:
            shutil.copy(os.path.join(category_path, file), os.path.join(split_dir, file))

print("Dataset split into training and test sets successfully.")

Dataset split into training and test sets successfully.


In [6]:
from pathlib import Path

In [7]:
# Dataset root, relative to the notebook's working directory.
data_path = Path("..", "data")
image_path = data_path.joinpath("CombinedAll")

In [8]:
# Train/test folders produced by the data-prep split.
train_dir, test_dir = (image_path / split_name for split_name in ("train", "test"))
train_dir, test_dir

(WindowsPath('../data/CombinedAll/train'),
 WindowsPath('../data/CombinedAll/test'))

# Data Management

In [9]:
import torchvision
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import os

In [10]:
# Re-seed the global random number generators (via Lightning) before building the data pipeline.
L.seed_everything(seed=123)

Seed set to 123


123

In [11]:
# class DataModule(L.LightningDataModule):
#     def __init__(
#         self,
#         data_dir: str,
#         batch_size: int = 32,
#         num_workers: int = 8,
#         image_size: int = 224
#     ):
#         super().__init__()
#         self.data_dir = data_dir
#         self.batch_size = batch_size
#         self.num_workers = num_workers
#         self.image_size = image_size
#         self.transform = transforms.Compose([
#             transforms.Resize((self.image_size, self.image_size)),
#             transforms.ToTensor()
#         ])

#         self.train_transform = transforms.Compose([
#             transforms.Resize((image_size, image_size)),
#             transforms.RandomHorizontalFlip(),
#             transforms.RandomRotation(10),
#             transforms.ToTensor(),
#             transforms.Normalize(
#                 mean=[0.485, 0.456, 0.406],
#                 std=[0.229, 0.224, 0.225]
#             )
#         ])
        
#         self.val_transforms = transforms.Compose([
#             transforms.Resize((image_size, image_size)),
#             transforms.ToTensor(),
#             transforms.Normalize(
#                 mean=[0.485, 0.456, 0.406],
#                 std=[0.229, 0.224, 0.225]
#             )
#         ])
        
#     def setup(self, stage=None):
#         if stage == 'fit' or stage is None:
#             self.train_dataset = datasets.ImageFolder(
#                 root=os.path.join(self.data_dir, 'train'),
#                 transform=self.train_transforms
#             )
#             self.val_dataset = datasets.ImageFolder(
#                 root=os.path.join(self.data_dir, 'test'),
#                 transform=self.val_transforms
#             )
    
#     def train_dataloader(self):
#         return DataLoader(
#             self.train_dataset,
#             batch_size=self.batch_size,
#             shuffle=True,
#             num_workers=self.num_workers,
#             pin_memory=True
#             persistent_workers=True
#         )

#     def val_dataloader(self):
#         return DataLoader(
#             self.val_dataset,
#             batch_size=self.batch_size,
#             shuffle=False,
#             num_workers=self.num_workers,
#             pin_memory=True
#             persistent_workers=True
#         )

#     def test_dataloader(self):
#         return self.val_dataloader()

In [12]:
# Leave 8 cores for the rest of the system but always keep at least one worker:
# os.cpu_count() may return None, and num_workers <= 0 would break the
# DataLoaders (negative values are rejected; persistent_workers needs > 0).
num_workers = max(1, (os.cpu_count() or 1) - 8)
batch_size = 32
num_workers, batch_size

In [13]:
# Reference preprocessing shipped with the pretrained ViT-B/16 weights; printed so
# the hand-written transforms can be checked against it.
vit_weight_enum = models.ViT_B_16_Weights
vitb16_0_weights = vit_weight_enum.DEFAULT
vitb16_0_transforms = vitb16_0_weights.transforms()
print(vitb16_0_transforms)

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


In [14]:
# ImageNet channel statistics, the same values the pretrained ViT-B/16 weights use.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline. The augmentations are currently disabled, so it performs
# exactly the same steps as the evaluation pipeline.
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    # transforms.RandomApply([transforms.RandomRotation(10)], p=0.5),
    # transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation pipeline: deterministic resize + center crop, then normalisation.
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


In [15]:
# Train/validation datasets read from the split folders. ImageFolder numbers the
# classes by sorted sub-folder name, so both splits must contain the same class
# folders for a label index to mean the same thing in train and validation.
train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
val_dataset = datasets.ImageFolder(test_dir, transform=test_transform)
assert train_dataset.class_to_idx == val_dataset.class_to_idx, (
    f"train/test class mismatch: {train_dataset.class_to_idx} vs {val_dataset.class_to_idx}"
)
train_dataset, val_dataset

(Dataset ImageFolder
     Number of datapoints: 1011
     Root location: ..\data\CombinedAll\train
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=True)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 Dataset ImageFolder
     Number of datapoints: 253
     Root location: ..\data\CombinedAll\test
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=True)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ))

In [16]:
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline

In [17]:
# L.seed_everything(123)
# NUM_IMAGES = 4
# images = [train_dataset[idx][0] for idx in range(NUM_IMAGES)]
# orig_images = [Image.open(train_dataset.samples[idx][0]) for idx in range(NUM_IMAGES)]
# orig_images = [test_transform(img) for img in orig_images]

# img_grid = torchvision.utils.make_grid(torch.stack(images, dim=0), nrow=4, normalize=True, pad_value=0.5)
# img_grid = img_grid.permute(1, 2, 0)

# plt.figure(figsize=(8, 8))
# plt.title("Augmentation examples")
# plt.imshow(img_grid)
# plt.axis("off")
# plt.show()
# plt.close()

In [18]:
from torchvision.transforms import functional as F
import random
import numpy as np

In [19]:

# folder_path = "../data/CombinedAll/sehat"


# image_files = [f for f in os.listdir(folder_path) if f.endswith(('.jpg', '.png', '.jpeg'))]
# selected_images = random.sample(image_files, 4)


# transform_list = [
#     ("Horizontal Flip", transforms.RandomHorizontalFlip(p=1.0)),
#     ("Rotate 45°", transforms.RandomRotation(degrees=(45, 45))),
#     ("Vertical Flip", transforms.RandomVerticalFlip(p=1.0)),
#     ("Color Jitter", transforms.ColorJitter(brightness=0.5))
# ]


# original_images = []
# transformed_images = []
# titles = []

# for idx, img_file in enumerate(selected_images):
#     img_path = os.path.join(folder_path, img_file)
#     img = Image.open(img_path).convert('RGB')
    
#     original_images.append(img)
    
#     transform = transform_list[idx][1]
#     transformed_img = transform(img)
    
#     transformed_images.append(transformed_img)
#     titles.append(transform_list[idx][0])

# fig, axes = plt.subplots(2, 4, figsize=(30, 20))

# for idx, img in enumerate(original_images):
#     axes[idx//2, idx%2*2].imshow(img)
#     axes[idx//2, idx%2*2].set_title('Original',fontsize=30, color='red')
#     axes[idx//2, idx%2*2].axis('off')

# for idx, (img, title) in enumerate(zip(transformed_images, titles)):
#     axes[idx//2, idx%2*2 + 1].imshow(img)
#     axes[idx//2, idx%2*2 + 1].set_title(f'Transformed: {title}', fontsize=30, color='green')
#     axes[idx//2, idx%2*2 + 1].axis('off')

# plt.tight_layout()
# plt.show()

In [20]:
# DataLoaders over the ImageFolder datasets. persistent_workers is only valid when
# worker processes exist (PyTorch raises ValueError otherwise), so tie it to
# num_workers > 0 and keep the loaders buildable with num_workers == 0.
keep_workers_alive = num_workers > 0
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    persistent_workers=keep_workers_alive,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    persistent_workers=keep_workers_alive,
)

# ViT Model Prep

In [21]:
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchmetrics.classification import BinaryConfusionMatrix
import io
import numpy as np
import optuna

In [22]:
# Re-seed the global random number generators (via Lightning) before defining and tuning the model.
L.seed_everything(seed=123)

Seed set to 123


123

In [23]:
class VisionTransformerClassifier(L.LightningModule):
    """Binary image classifier built on torchvision's pretrained ViT-B/16.

    The stock head is replaced by ``Dropout -> Linear(768, 1)``, producing one
    logit per image. Training uses ``BCEWithLogitsLoss``, and class 1 is
    predicted when ``sigmoid(logit) > 0.5``.

    Args:
        trial: Optional Optuna trial. When given, it samples ``freeze_backbone``
            and ``dropout_rate`` here, and ``optimizer`` / ``lr`` in
            ``configure_optimizers``. When ``None``, the whole network is
            trained, the head uses no dropout, and Adam (lr=1e-3) with a
            cosine-annealing schedule is used.

    Side effects: each validation epoch logs a confusion-matrix image and each
    training epoch logs loss/accuracy curves, whenever the attached logger's
    ``experiment`` provides ``add_image`` (e.g. TensorBoard).
    """

    # ImageFolder numbers classes by sorted folder name: 'parkinson' -> 0 and
    # 'sehat' (Indonesian for "healthy") -> 1. Tick labels follow index order.
    CONFUSION_MATRIX_LABELS = ["Patient", "Healthy"]

    def __init__(self, trial=None):
        super().__init__()
        self.vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        self.trial = trial

        # Defaults used when no Optuna trial is supplied.
        freeze_backbone = False
        dropout_rate = 0.0
        if self.trial:
            freeze_backbone = trial.suggest_categorical("freeze_backbone", [True, False])
            dropout_rate = trial.suggest_float("dropout_rate", 0.0, 0.5)

        if freeze_backbone:
            # Freezes every pretrained parameter. The head assigned below is
            # created afterwards, so it stays trainable.
            for param in self.vit.parameters():
                param.requires_grad = False

        # 768 is the ViT-B/16 embedding width (input size of the original head).
        self.vit.heads = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(in_features=768, out_features=1),
        )

        self.loss_fn = nn.BCEWithLogitsLoss()

        # Filled once per validation epoch and reset right after plotting.
        self.confusion_matrix = BinaryConfusionMatrix()

        # Per-batch validation predictions/targets, concatenated and cleared at epoch end.
        self.val_preds = []
        self.val_labels = []
        # Detached per-step train loss/accuracy, cleared at each train epoch end.
        self.training_step_outputs = []
        # One value per completed epoch, used for the training-curve plots.
        self.training_epoch_losses = []
        self.training_epoch_accs = []
        self.validation_epoch_losses = []
        self.validation_epoch_accs = []

    def forward(self, x):
        """Return raw logits of shape (batch, 1) from the ViT."""
        return self.vit(x)

    def _shared_step(self, batch):
        """Return ``(loss, preds, targets)`` for an ``(images, labels)`` batch."""
        x, y = batch
        y = y.float()
        # squeeze(-1) rather than squeeze(): (B, 1) -> (B,) even when B == 1,
        # so the logits always match the target shape BCEWithLogitsLoss expects.
        logits = self(x).squeeze(-1)
        loss = self.loss_fn(logits, y)
        preds = torch.sigmoid(logits) > 0.5
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        """Compute and log train loss/accuracy; return the loss to optimise."""
        loss, preds, y = self._shared_step(batch)
        acc = (preds == y).float().mean()
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "train_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )

        # Detached so the list does not keep references into the autograd graph.
        self.training_step_outputs.append({'loss': loss.detach(), 'acc': acc})

        return loss

    def validation_step(self, batch, batch_idx):
        """Log val loss/accuracy and keep predictions for the confusion matrix."""
        loss, preds, y = self._shared_step(batch)
        acc = (preds == y).float().mean()
        self.log(
            "val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "val_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )

        self.val_preds.append(preds)
        self.val_labels.append(y)

    def test_step(self, batch, batch_idx):
        """Log test loss/accuracy."""
        loss, preds, y = self._shared_step(batch)
        acc = (preds == y).float().mean()
        self.log('test_loss', loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        """Build the optimizer and, when there is no trial, a cosine LR scheduler.

        With a trial: samples the optimizer class (Adam / RMSprop / SGD) and a
        log-scale learning rate in [1e-6, 1e-1]; no scheduler.
        Without a trial: Adam(lr=1e-3) with CosineAnnealingLR(T_max=10, eta_min=1e-6).
        """
        if self.trial:
            optimizer_name = self.trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
            lr = self.trial.suggest_float("lr", 1e-6, 1e-1, log=True)
            optimizer = getattr(torch.optim, optimizer_name)(self.parameters(), lr=lr)

            return optimizer
        else:
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-03)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=10, eta_min=1e-6
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val_loss",
                },
            }

    @staticmethod
    def _plot_history(ax, history, label):
        """Plot ``history[1:]`` on ``ax`` at x = 1, 2, ... (epoch 0 is skipped)."""
        ax.plot(range(1, len(history)), history[1:], label=label)

    def _log_figure(self, tag, fig):
        """Render ``fig`` to PNG, log it as an image under ``tag``, then close it."""
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        im = transforms.ToTensor()(Image.open(buf))
        plt.close(fig)

        # add_image is the TensorBoard SummaryWriter API. Skip quietly when
        # logging is disabled or the logger backend does not provide it.
        experiment = getattr(self.logger, "experiment", None)
        if hasattr(experiment, "add_image"):
            experiment.add_image(tag, im, global_step=self.current_epoch)

    def on_train_epoch_end(self):
        """Record this epoch's train loss/accuracy and log the training curves."""
        metrics = self.trainer.callback_metrics
        self.training_epoch_losses.append(metrics['train_loss'].item())
        self.training_epoch_accs.append(metrics['train_acc'].item())

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        # Epoch 0 is omitted because its large values squash the y-axis scale.
        # Every history list holds exactly one value per real epoch (sanity-check
        # values are not recorded), so the curves share the same epoch axis.
        self._plot_history(ax1, self.training_epoch_losses, 'Train Loss')
        if self.validation_epoch_losses:
            self._plot_history(ax1, self.validation_epoch_losses, 'Val Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()

        self._plot_history(ax2, self.training_epoch_accs, 'Train Accuracy')
        if self.validation_epoch_accs:
            self._plot_history(ax2, self.validation_epoch_accs, 'Val Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.legend()

        self._log_figure('training_curves', fig)

        self.training_step_outputs.clear()

    def on_validation_epoch_end(self):
        """Log this epoch's confusion matrix and record epoch-level val metrics."""
        if self.val_preds:
            all_preds = torch.cat(self.val_preds)
            all_labels = torch.cat(self.val_labels).long()
            self.confusion_matrix.update(all_preds, all_labels)

            fig, ax = plt.subplots(figsize=(8, 8))
            # plot() without explicit values draws compute(), i.e. everything
            # accumulated since the last reset. Reset right after plotting so
            # each image shows this epoch only.
            self.confusion_matrix.plot(ax=ax, labels=self.CONFUSION_MATRIX_LABELS)
            self.confusion_matrix.reset()
            self._log_figure("confusion_matrix", fig)

        self.val_preds.clear()
        self.val_labels.clear()

        # Lightning's pre-training sanity check also calls this hook. Keep it out
        # of the per-epoch history so it lines up with the training history.
        if not self.trainer.sanity_checking:
            metrics = self.trainer.callback_metrics
            self.validation_epoch_losses.append(metrics['val_loss'].item())
            self.validation_epoch_accs.append(metrics['val_acc'].item())

In [24]:
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

In [25]:
def objective(trial):
    """Optuna objective: fit one ViT-B/16 configuration and report its scores.

    Hyperparameters are sampled inside ``VisionTransformerClassifier`` through
    ``trial``. Training runs for at most 20 epochs and stops earlier when
    ``val_loss`` has not improved for 5 consecutive EarlyStopping checks.
    Each trial logs to its own TensorBoard run under ``../lightning_logs``.

    Returns:
        ``(val_loss, val_acc)`` as floats, read from the trainer's callback
        metrics once ``fit`` returns. This matches the study's
        ``["minimize", "maximize"]`` directions.
    """
    classifier = VisionTransformerClassifier(trial)

    stop_on_plateau = EarlyStopping(monitor="val_loss", patience=5, mode="min", verbose=False)
    tb_logger = TensorBoardLogger(
        save_dir="../lightning_logs",
        name=f"vitb16_tuning_trial_{trial.number}",
    )
    use_gpu = torch.cuda.is_available()

    trainer = L.Trainer(
        max_epochs=20,
        callbacks=[stop_on_plateau],
        logger=tb_logger,
        accelerator="gpu" if use_gpu else "cpu",
        devices="auto",
        log_every_n_steps=1,
        enable_progress_bar=False,
    )
    trainer.fit(classifier, train_loader, val_loader)

    final_metrics = trainer.callback_metrics
    return tuple(final_metrics[key].cpu().item() for key in ("val_loss", "val_acc"))

In [26]:
# Multi-objective search: minimise val_loss and maximise val_acc.
# No pruner is passed. Pruning relies on Trial.report(), which Optuna does not
# support for multi-objective studies, so a pruner would have no effect.
study = optuna.create_study(directions=["minimize", "maximize"])

# Every trial builds a fresh ViT-B/16 model and Trainer. Running garbage
# collection after each trial lets that memory be reclaimed before the next one.
study.optimize(objective, n_trials=10, gc_after_trial=True)

print("Number of finished trials: ", len(study.trials))
print("Best trials (Pareto front):")
for trial in study.best_trials:
    print(f"  Trial Number: {trial.number}")
    print(f"    Values (loss, accuracy): {trial.values}")
    print("    Params: ")
    for key, value in trial.params.items():
        print(f"      {key}: {value}")

[I 2024-12-24 23:36:04,851] A new study created in memory with name: no-name-07afc5b0-5c28-4368-82b6-fe47f95dd7e3
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                  | Params | Mode 
-------------------------------------------------------------------
0 | vit              | VisionTransformer     | 85.8 M | train
1 | loss_fn          | BCEWithLogitsLoss     | 0      | train
2 | confusion_matrix | BinaryConfusionMatrix | 0      | train
-------------------------------------------------------------------
85.8 M    Trainable params
0         Non-trainable params
85.8 M    Total params
343.198   Total estimated model params size (MB)
155       Modules in train mode
0         Modules in eval mode
[I 2024-12-24 23:41:23,834] Trial 0 finished with values: [0.3164321184158325, 0.9169960618019104] and parameters: {'freeze_backbone': False, 'dropout

Number of finished trials:  10
Best trials (Pareto front):
  Trial Number: 1
    Values (loss, accuracy): [0.14930643141269684, 0.9446640610694885]
    Params: 
      freeze_backbone: False
      dropout_rate: 0.09695059257094596
      optimizer: SGD
      lr: 0.0013561414893345516


In [27]:
# import optuna.visualization as vis

In [None]:
# vis.plot_pareto_front(study, target_names=["val_loss", "val_acc"])