In [1]:
import functools
import os
import random

import numpy as np
import yaml
import torch
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import transforms
from torchvision.utils import make_grid, draw_bounding_boxes

from Capsule.Classifier import CapsuleWrappingClassifier
from Capsule.Segment import CapsuleWrappingSegment
from Capsule.Detector import CapsuleWrappingDetector
from Capsule.ultis import *

from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping
from neptune.types import File
from pytorch_lightning.loggers import NeptuneLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import torchmetrics
from torchmetrics.detection.mean_ap import MeanAveragePrecision

  from .autonotebook import tqdm as notebook_tqdm
  from neptune.version import version as neptune_client_version
  from neptune import new as neptune


In [2]:
# Seed every random number generator the notebook relies on (Python, NumPy,
# torch CPU and all CUDA devices) so runs are repeatable.
seed = 666
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                 torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn

# Let float32 matmuls use faster reduced-precision internal math where supported.
torch.set_float32_matmul_precision('medium')

# Training module

In [3]:
def get_lr_scheduler_config(optimizer, settings):
    '''
    Set up the learning-rate scheduler in Lightning's lr_scheduler_config format.
    Args:
        optimizer: torch optimizer the scheduler drives
        settings: training hyperparameters. Reads 'lr_scheduler' and, depending on it:
            'step'              -> 'lr_step' (int step size), 'lr_decay' (gamma)
            'multistep'         -> 'lr_step' (list of milestones), 'lr_decay' (gamma)
            'reduce_on_plateau' -> optional 'plateau_factor' (default 0.1) and
                                   'plateau_patience' (default 10)
    Returns:
        dict: {'scheduler', 'monitor', 'interval', 'frequency'} for Lightning
    Raises:
        NotImplementedError: for an unknown 'lr_scheduler' name
    '''
    name = settings['lr_scheduler']
    if name == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=settings['lr_step'], gamma=settings['lr_decay'])
    elif name == 'multistep':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=settings['lr_step'], gamma=settings['lr_decay'])
    elif name == 'reduce_on_plateau':
        # The monitored quantity below is a *loss*, so improvement means it goes down:
        # mode must be 'min' (with 'max' the LR would be cut while training improves).
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min',
            factor=settings.get('plateau_factor', 0.1),
            patience=settings.get('plateau_patience', 10),
            threshold=0.0001)
    else:
        raise NotImplementedError(f"Unknown lr_scheduler: {name!r}")

    return {
        'scheduler': scheduler,
        # Only consulted by ReduceLROnPlateau; must match the key logged in training_step.
        'monitor': 'metrics/batch/train_loss',
        'interval': 'epoch',
        'frequency': 1,
    }

def get_optimizer(parameters, settings):
    '''
    Build the optimizer named by settings['optimizer'].
    Args:
        parameters: iterable of model parameters to optimize
        settings: training hyperparameters; reads 'optimizer', 'lr', 'weight_decay'
            and, for SGD, 'momentum'
    Returns:
        torch.optim.Optimizer: an Adam or SGD instance
    Raises:
        NotImplementedError: for any other optimizer name
    '''
    if settings['optimizer'] == 'adam':
        return torch.optim.Adam(parameters, lr=settings['lr'], weight_decay=settings['weight_decay'])

    if settings['optimizer'] == 'sgd':
        return torch.optim.SGD(
            parameters,
            lr=settings['lr'],
            weight_decay=settings['weight_decay'],
            momentum=settings['momentum'],
        )

    raise NotImplementedError()

def get_loss_function(type, n_classes=2):
    '''
    Build the loss criterion named by ``type``.
    Args:
        type: one of "ce", "nll", "bce", "spread", "margin", "mse", "none"
        n_classes: number of classes, forwarded to the capsule losses (spread / margin)
    Returns:
        loss: the criterion instance, or None for "none" (used when the model computes
            its own losses, i.e. task == detection)
    Raises:
        NotImplementedError: for an unknown loss name
    '''
    # Factories are lazy so that only the selected criterion is constructed.
    factories = {
        "ce": lambda: nn.CrossEntropyLoss(),
        "nll": lambda: nn.NLLLoss(),
        "bce": lambda: nn.BCELoss(),
        "spread": lambda: SpreadLoss(num_classes=n_classes),
        "margin": lambda: MarginLoss(num_classes=n_classes),
        "mse": lambda: nn.MSELoss(),
        "none": lambda: None,
    }
    if type not in factories:
        raise NotImplementedError()
    return factories[type]()

def get_gpu_settings(gpu_ids, n_gpu):
    '''
    Get accelerator/devices/strategy arguments for the pytorch-lightning Trainer.
    Args:
        gpu_ids (list[int] | None): explicit GPU indices; takes precedence when non-empty
        n_gpu (int | None): number of GPUs to use when gpu_ids is not given
    Returns:
        tuple[str, int | list[int] | str, str]: accelerator, devices, strategy.
            Without CUDA this is ("cpu", "auto", "auto"); Lightning 2.x rejects
            None for `strategy` and for CPU `devices`, so "auto" is returned instead.
            More than one GPU selects "ddp"; otherwise "auto".
    '''
    if not torch.cuda.is_available():
        return "cpu", "auto", "auto"

    if gpu_ids:  # an empty list behaves like "not specified"
        devices = gpu_ids
        strategy = "ddp" if len(gpu_ids) > 1 else 'auto'
    elif n_gpu is not None:
        devices = n_gpu
        strategy = "ddp" if n_gpu > 1 else 'auto'
    else:
        devices = 1
        strategy = 'auto'

    return "gpu", devices, strategy

def get_basic_callbacks():
    '''
    Get the basic callbacks for the pytorch-lightning Trainer.
    Args:
        None
    Returns:
        list: [last checkpoint, best checkpoint, LR monitor, early stopping] callbacks
    '''
    # Filename placeholders must use the exact metric keys CapsuleModel logs
    # ('metrics/epoch/val_loss', 'metrics/epoch/val_acc'); keys a task does not log
    # (e.g. val_acc outside classification) are filled with a default by ModelCheckpoint.
    filename_suffix = '{epoch:03d}-{metrics/epoch/val_loss:.4f}-{metrics/epoch/val_acc:.4f}'

    lr_callback = LearningRateMonitor(logging_interval='epoch')
    last_ckpt_callback = ModelCheckpoint(
        filename='last_model_' + filename_suffix,
        # Metric names contain '/', which would otherwise be inserted into the file name
        # and create extra sub-directories.
        auto_insert_metric_name=False,
        save_top_k=1,
        monitor=None,
    )
    # NOTE(review): best-checkpoint selection and early stopping both track the training
    # loss logged per step in training_step (i.e. the last batch's loss), which is noisy;
    # consider 'metrics/epoch/val_loss' (not logged for the detection task).
    best_ckpt_callback = ModelCheckpoint(
        filename='best_model_' + filename_suffix,
        auto_insert_metric_name=False,
        save_top_k=1,
        monitor='metrics/batch/train_loss',
        mode='min',
        verbose=True,
    )
    early_stopping_callback = EarlyStopping(
        monitor='metrics/batch/train_loss',  # metric to monitor for improvement
        mode='min',  # lower loss is better
        patience=10,  # epochs with no improvement before stopping
    )
    return [last_ckpt_callback, best_ckpt_callback, lr_callback, early_stopping_callback]

def get_trainer(settings, task) -> "tuple[Trainer, NeptuneLogger]":
    '''
    Build the pytorch-lightning Trainer and its Neptune logger.
    Args:
        settings: training hyperparameters; reads 'gpu_ids', 'n_gpu', 'n_epoch', 'ckpt_path'
        task: task name, attached to the Neptune run as a tag
    Returns:
        trainer: configured Trainer
        logger: the NeptuneLogger attached to the trainer
    Environment:
        NEPTUNE_API_TOKEN: Neptune API token (never hard-code credentials in the notebook)
        NEPTUNE_PROJECT: optional project override (default "kaori/Capsule-wrap")
    '''
    callbacks = get_basic_callbacks()
    accelerator, devices, strategy = get_gpu_settings(settings['gpu_ids'], settings['n_gpu'])

    # SECURITY: credentials are read from the environment so they are not committed with
    # the notebook. Any token that was ever saved in this file should be treated as
    # leaked and revoked in the Neptune UI.
    neptune_logger = NeptuneLogger(
        project=os.environ.get("NEPTUNE_PROJECT", "kaori/Capsule-wrap"),
        api_token=os.environ.get("NEPTUNE_API_TOKEN"),
        log_model_checkpoints=False,
        tags=[task],
    )

    trainer = Trainer(
        logger=[neptune_logger],
        max_epochs=settings['n_epoch'],
        default_root_dir=settings['ckpt_path'],
        accelerator=accelerator,
        devices=devices,
        strategy=strategy,
        callbacks=callbacks,
    )
    return trainer, neptune_logger

In [4]:
class DataModule(LightningDataModule):
    '''
    Data module that builds the train/val/test datasets and their data loaders.

    Supported data_settings['name'] values: 'CIFAR10', 'LungCT-Scan', 'PennFudan'.
    Args:
        data_settings: dataset settings; reads 'name', 'path', 'img_size' and the
            optional 'train_fraction' (default 0.8, used by the non-CIFAR datasets)
        training_settings: training settings; reads 'n_batch', 'num_workers'
    Raises:
        NotImplementedError (from setup): for an unsupported dataset name
    '''
    def __init__(self, data_settings, training_settings):
        super().__init__()

        self.dataset = data_settings['name']
        self.root_dir = data_settings['path']
        self.img_size = data_settings['img_size']
        self.train_fraction = data_settings.get('train_fraction', 0.8)
        self.batch_size = training_settings['n_batch']
        self.num_workers = training_settings['num_workers']
        self.class_list = None
        self.transform = None
        self.collate_fn = None

    def _build_dataset(self, mode=None):
        '''Instantiate the dataset named by self.dataset; `mode` is only used by CIFAR10.'''
        if self.dataset == 'CIFAR10':
            return CIFAR10read(mode=mode, data_path=self.root_dir,
                               transform=self.transform, imgsize=self.img_size)
        if self.dataset == 'LungCT-Scan':
            return LungCTscan(data_dir=self.root_dir, transform=self.transform, imgsize=self.img_size)
        if self.dataset == 'PennFudan':
            # Detection samples need the project-provided collate_fn for batching.
            self.collate_fn = collate_fn
            return PennFudanDataset(root=self.root_dir, transform=self.transform, imgsize=self.img_size)
        raise NotImplementedError(f"Unsupported dataset: {self.dataset!r}")

    def _split_train_val(self, dataset):
        '''Deterministic contiguous split: first train_fraction for training, rest for validation.'''
        n_train = int(len(dataset) * self.train_fraction)
        return Subset(dataset, range(n_train)), Subset(dataset, range(n_train, len(dataset)))

    def setup(self, stage: str):
        # "validate" (trainer.validate) needs Val_dataset just like "fit".
        if stage in ("fit", "validate"):
            if self.dataset == 'CIFAR10':
                self.Train_dataset = self._build_dataset(mode="train")
                self.Val_dataset = self._build_dataset(mode="val")
            else:
                self.Train_dataset, self.Val_dataset = self._split_train_val(self._build_dataset())

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            if self.dataset == 'CIFAR10':
                self.Test_dataset = self._build_dataset(mode="test")
            else:
                # NOTE(review): this is the *whole* dataset, including the samples used
                # for training above, so test metrics are optimistic. Consider testing on
                # the held-out split only.
                dataset = self._build_dataset()
                self.Test_dataset = Subset(dataset, range(len(dataset)))

    def train_dataloader(self):
        return DataLoader(self.Train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.Val_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, collate_fn=self.collate_fn)

    def test_dataloader(self):
        return DataLoader(self.Test_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, collate_fn=self.collate_fn)

In [5]:
class CapsuleModel(LightningModule):
    '''
    LightningModule wrapping a capsule-network classifier, segmenter or detector.

    Args:
        PARAMS: dict with 'architect_settings', 'training_settings' and
            'dataset_settings' sub-dicts (e.g. loaded from Capsule/config.yaml)
        task: one of 'classification', 'segmentation', 'detection'
    Raises:
        NotImplementedError: for any other task name
    '''
    def __init__(self, PARAMS, task='classification'):
        super().__init__()
        self.save_hyperparameters()

        self.architect_settings = PARAMS['architect_settings']
        self.train_settings = PARAMS['training_settings']
        self.dataset_settings = PARAMS['dataset_settings']
        self.task = task

        # Model + metric selection. Metrics are assigned as sub-module attributes.
        if self.task == 'classification':
            self.model = CapsuleWrappingClassifier(model_configs=self.architect_settings)
            self.train_metrics = torchmetrics.Accuracy(task='multiclass', num_classes=self.architect_settings['n_cls'])
            self.valid_metrics = torchmetrics.Accuracy(task='multiclass', num_classes=self.architect_settings['n_cls'])
        elif self.task == 'segmentation':
            self.model = CapsuleWrappingSegment(model_configs=self.architect_settings)
            self.train_metrics = torchmetrics.Dice(num_classes=self.architect_settings['n_cls'])
            self.valid_metrics = torchmetrics.Dice(num_classes=self.architect_settings['n_cls'])
        elif self.task == 'detection':
            self.model = CapsuleWrappingDetector(model_configs=self.architect_settings)
            self.train_metrics = MeanAveragePrecision()
            self.valid_metrics = MeanAveragePrecision()
        else:
            raise NotImplementedError()

        # Loss selection ('none' for detection: the detector returns its own loss dict)
        self.loss = get_loss_function(self.train_settings['loss'], self.architect_settings['n_cls'])

        # Per-batch validation records; consumed and cleared in on_validation_epoch_end.
        self.validation_step_outputs = []

    def forward(self, x, y=None):
        return self.model(x, y)

    def training_step(self, batch, batch_idx):
        x, y = batch
        if self.task == 'detection':
            # In training mode the detector is called with targets and its output is
            # treated as a dict of component losses (summed here).
            loss_dict = self(x, y)
            loss = sum(loss_dict.values())
        else:
            y_hat = self(x)
            loss = self.loss(y_hat, y)
            # dim=1 is the class axis for channel-first logits ((B, C) or (B, C, H, W)).
            # Inputs stay on the model's device, where the metric sub-modules live too.
            self.train_metrics.update(torch.softmax(y_hat, dim=1), y)

        self.log("metrics/batch/train_loss", loss, prog_bar=False)
        return loss

    def on_train_epoch_end(self):
        if self.task == 'classification':
            self.log("metrics/epoch/train_acc", self.train_metrics.compute())
        elif self.task == 'segmentation':
            self.log("metrics/epoch/train_dice", self.train_metrics.compute())

        self.train_metrics.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)

        if self.task == 'detection':
            # Shallow per-image copies of prediction / target dicts for the mAP metric.
            y_pred = [dict(t) for t in y_hat]
            targets = [dict(t) for t in y]
            self.valid_metrics.update(y_pred, targets)
            self.validation_step_outputs.append({"image": x[0], "predictions": y_pred[0], "targets": targets[0]})
            return

        loss = self.loss(y_hat, y)
        # Normalise over the class axis (dim=1), matching training_step; dim=-1 would be
        # the width axis for (B, C, H, W) segmentation logits.
        self.valid_metrics.update(torch.softmax(y_hat, dim=1), y)

        record = {"loss": loss.item()}
        if self.task == 'segmentation':
            record["predictions"] = torch.argmax(y_hat, dim=1).unsqueeze(1)
        self.validation_step_outputs.append(record)

        self.log('metrics/batch/val_loss', loss)

    def on_validation_epoch_end(self):
        outputs = self.validation_step_outputs

        if self.task == 'classification':
            self.log('metrics/epoch/val_acc', self.valid_metrics.compute())
            self._report_mean_val_loss(outputs)
        elif self.task == 'segmentation':
            self.log("metrics/epoch/val_dice", self.valid_metrics.compute())
            self._report_mean_val_loss(outputs)
            if outputs:
                self._report_segmentation_preview(outputs[0])
        elif self.task == 'detection':
            self.log('metrics/epoch/val_mAP', self.valid_metrics.compute()['map'])
            # no validation loss for detection
            if outputs:
                self._report_detection_preview(outputs[-1])

        # Clear for every task so each epoch's mean loss covers only its own batches
        # and the list does not grow across epochs.
        self.validation_step_outputs.clear()
        self.valid_metrics.reset()

    def _report_mean_val_loss(self, outputs):
        '''Log the mean of this epoch's per-batch validation losses (skipped if there are none).'''
        losses = [record['loss'] for record in outputs]
        if losses:
            self.log('metrics/epoch/val_loss', sum(losses) / len(losses))

    def _report_segmentation_preview(self, record):
        '''Upload a grid of one batch of predicted masks to the Neptune run.'''
        reconstructions = make_grid(record["predictions"], nrow=int(self.train_settings["n_batch"] ** 0.5))
        reconstructions = reconstructions.cpu().numpy().transpose(1, 2, 0)
        self.logger.experiment["val/reconstructions"].append(File.as_image(reconstructions))

    def _report_detection_preview(self, record):
        '''Upload one image with up to 5 predicted (red) and 5 target (blue) boxes to Neptune.'''
        image, predictions, targets = record["image"], record["predictions"], record["targets"]
        reconstructions = draw_bounding_boxes((image * 255.).to(torch.uint8),
                                              boxes=predictions["boxes"][:5],
                                              colors="red",
                                              width=5, font_size=20)
        reconstructions = draw_bounding_boxes(reconstructions,
                                              boxes=targets["boxes"][:5],
                                              colors="blue",
                                              width=5, font_size=20)
        reconstructions = reconstructions.cpu().numpy().transpose(1, 2, 0) / 255.
        self.logger.experiment["val/reconstructions"].append(File.as_image(reconstructions))

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model.parameters(), self.train_settings)
        lr_scheduler_config = get_lr_scheduler_config(optimizer, self.train_settings)

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

# Training

In [6]:
%%script false --no-raise-error
# NOTE: the `%%script false` magic above feeds this cell to the `false` command, so the
# Python below is NOT executed. Remove the magic line to load the classifier config.
# Loads the 'classifier' section of Capsule/config.yaml into PARAMS.

with open("Capsule/config.yaml", 'r') as stream:
    PARAMS = yaml.safe_load(stream)
    PARAMS = PARAMS['classifier']
    print(PARAMS)

In [7]:
%%script false --no-raise-error
# NOTE: disabled by `%%script false` (body not executed). Requires PARAMS from the
# config cell above.

data = DataModule(PARAMS['dataset_settings'], PARAMS['training_settings'])

In [8]:
%%script false --no-raise-error
# NOTE: disabled by `%%script false` (body not executed). Requires PARAMS and `data`
# from the cells above; trains a classification CapsuleModel and logs to Neptune.

model = CapsuleModel(PARAMS=PARAMS, task = 'classification')
trainer, logger = get_trainer(PARAMS['training_settings'], task = 'classification')
logger.log_hyperparams(params=PARAMS)

trainer.fit(model, data)

# Testing

In [9]:
# test_model = CapsuleModel.load_from_checkpoint("models/deeplab.ckpt")
# test_model.eval()
# print(test_model.task)
# x = torch.randn(1, 3, 224, 224)
# transform = test_model.model.preprocess
# x_preprocessed = transform(x)
# with torch.no_grad():
#     y_hat = test_model(x)
#     print(y_hat)

In [10]:
with open("models/class_name.txt", "r", encoding='utf-8') as f:
    class_names = f.read().splitlines()


@functools.lru_cache(maxsize=None)
def _load_model(model_choice):
    '''Load models/<model_choice>.ckpt once per name and keep it in eval mode.'''
    model = CapsuleModel.load_from_checkpoint(f"models/{model_choice}.ckpt")
    model.eval()
    return model


def predict(image, model_choice):
    '''
    Run one demo prediction with the chosen checkpoint.
    Args:
        image: input image as passed by the gradio Image component (type="pil")
        model_choice: checkpoint name under models/ (e.g. 'swin', 'faster-rcnn', 'deeplab')
    Returns:
        (labels, segment, detection): only the entry matching the model's task is set.
            labels: {class name: probability}; segment: per-pixel argmax class ids;
            detection: image array in [0, 1] with up to 5 predicted boxes drawn in red.
    '''
    labels, segment, detection = None, None, None
    # Cached: loading a checkpoint on every request is slow.
    model = _load_model(model_choice)
    preprocess = model.model.preprocess
    tensor_image = preprocess(image)
    with torch.no_grad():
        # The checkpoint may be restored onto a GPU, so feed the input on the model's
        # device and bring results back to the CPU before converting to numpy.
        y_hat = model(tensor_image.unsqueeze(0).to(model.device))
        if model.task == "classification":
            preds = torch.softmax(y_hat, dim=-1).tolist()
            # NOTE(review): the last class score is dropped ([:-1]); confirm the model has
            # one more output than there are lines in class_name.txt.
            labels = {class_names[k]: float(v) for k, v in enumerate(preds[0][:-1])}
        elif model.task == "segmentation":
            y_pred = torch.argmax(y_hat, dim=1)
            segment = y_pred.squeeze(0).cpu().numpy()
        elif model.task == "detection":
            detection = draw_bounding_boxes((tensor_image * 255.).to(torch.uint8),
                                            boxes=y_hat[0]["boxes"][:5].cpu(),
                                            colors="red",
                                            width=5)
            detection = detection.numpy().transpose(1, 2, 0) / 255.

    return labels, segment, detection

In [13]:
import gradio as gr

title = "Capsule Network Application Demo "
description = "# A Demo of Capsule Network Application"
# Sorted so the example gallery order does not depend on filesystem listing order.
example_list = [[os.path.join("examples", example)] for example in sorted(os.listdir("examples"))]

with gr.Blocks() as demo:
    demo.title = title
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            # A default value ensures predict() never receives None as the model name
            # (which would try to load "models/None.ckpt"). Named distinctly so it does
            # not clobber a `model` variable from the training cells.
            model_dropdown = gr.Dropdown(['swin', 'faster-rcnn', 'deeplab'], value='swin',
                                         label="Select Model", interactive=True)
            im = gr.Image(type="pil", label="input image")
        with gr.Column():
            label_conv = gr.Label(label="Predictions", num_top_classes=4)
            im_segment = gr.Image(type="pil", label="Segment")
            im_detection = gr.Image(type="pil", label="Detection")
            btn = gr.Button(value="predict")
    btn.click(predict, inputs=[im, model_dropdown], outputs=[label_conv, im_segment, im_detection])
    # Each example row supplies only an image path, so bind examples to the image input.
    gr.Examples(examples=example_list, inputs=[im])
      

In [14]:
# share=True additionally publishes the demo on a temporary public gradio.live URL
# (see the output below); anyone with that link can run predictions on this machine.
demo.launch(share=True)

Running on local URL:  http://127.0.0.1:7861
Running on public URL: https://41e307cbef21d65ac3.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces




Traceback (most recent call last):
  File "/home/vips/anaconda3/envs/Capsule/lib/python3.9/site-packages/gradio/routes.py", line 394, in run_predict
    output = await app.get_blocks().process_api(
  File "/home/vips/anaconda3/envs/Capsule/lib/python3.9/site-packages/gradio/blocks.py", line 1075, in process_api
    result = await self.call_function(
  File "/home/vips/anaconda3/envs/Capsule/lib/python3.9/site-packages/gradio/blocks.py", line 884, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/home/vips/anaconda3/envs/Capsule/lib/python3.9/site-packages/anyio/to_thread.py", line 31, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/home/vips/anaconda3/envs/Capsule/lib/python3.9/site-packages/anyio/_backends/_asyncio.py", line 937, in run_sync_in_worker_thread
    return await future
  File "/home/vips/anaconda3/envs/Capsule/lib/python3.9/site-packages/anyio/_backends/_asyncio.py", line 867, in run
    result = context.run(func, *