In [1]:
import numpy as np
import yaml
import torch
import random
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.utils import make_grid

from new_Capsule.Model import *
from new_Capsule.ultis import *
from pytorch_lightning import LightningModule, Trainer
from neptune.types import File
from pytorch_lightning.loggers import NeptuneLogger
from pytorch_lightning.callbacks import ModelCheckpoint

  from .autonotebook import tqdm as notebook_tqdm
  from neptune.version import version as neptune_client_version
  from neptune import new as neptune


In [2]:
seed = 666
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.set_float32_matmul_precision('medium')

with open("new_Capsule/config.yaml", 'r') as stream:
    PARAMS = yaml.safe_load(stream)
    PARAMS = PARAMS['config2']
    print(PARAMS)

{'architect_settings': {'name': 'capsconv', 'reconstructed': True, 'n_cls': 10, 'convs': {'channels': [1, 64, 128], 'k': 5, 's': 2, 'p': 2}, 'primay_caps': {'channels': [128, 32], 'k': 3, 's': 2, 'p': 1}, 'caps': {'init': 'noisy_identity', 'cap_dims': 4, 'channels': [32, 16, 10], 'k': 3, 's': 1, 'p': 0, 'routing': {'type': 'fuzzy', 'params': [3]}}}, 'training_settings': {'loss': 'spread', 'CAM': False, 'lr': 0.0001, 'lr_step': 10, 'lr_decay': 0.8, 'momentum': 0.9, 'weight_decay': 0.005, 'n_epoch': 100, 'n_batch': 128, 'dataset': 'center-Mnist', 'ckpt_path': 'test'}}


# Training module

In [3]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

class CapsuleModel(LightningModule):
    def __init__(self, PARAMS):
        super().__init__()

        self.reconstructed = False
        self.segment = False
       
        self.architect_settings = PARAMS["architect_settings"]
        self.train_settings = PARAMS["training_settings"]
        assert self.architect_settings['name'] in ["conv", "Unet", "capsconv", "capsUnet", "eff-capsconv", "eff-capsUnet"], \
                    "Model is not implemented yet"
        
        # Model selection
        if(self.architect_settings['name'] == "conv"):
            self.model = ConvNeuralNet(model_configs=self.architect_settings)
        elif(self.architect_settings['name'] in ["capsconv", "eff-capsconv"]):
            self.model = CapNets(model_configs=self.architect_settings)
        elif(self.architect_settings['name'] in ["capsUnet", "eff-capsUnet"]):
            self.segment = True
            self.model = CapConvUNet(model_configs=self.architect_settings)
        elif(self.architect_settings['name'] == "Unet"):
            self.segment = True
            self.model = ConvUNet(model_configs=self.architect_settings)

        assert self.train_settings['loss']  in ["ce", "spread", "bce", "mse", "margin", "dice"], \
                                                "Loss is not implemented yet"
        # Loss selection
        if(self.train_settings['loss'] == "ce"):                                                                                                                
            self.loss = CrossEntropyLoss(num_classes=self.architect_settings['n_cls'])
        elif(self.train_settings['loss'] == "spread"):
            self.loss = SpreadLoss(num_classes=self.architect_settings['n_cls'])
        elif(self.train_settings['loss'] == "margin"):
            self.loss = MarginLoss(num_classes=self.architect_settings['n_cls'])
        elif(self.train_settings['loss'] == "bce"):
            self.loss = BCE()
        elif(self.train_settings['loss'] == "mse"):
            self.loss = MSE()
        elif(self.train_settings['loss'] == "dice"):
            self.loss = DiceLoss()
      
        # For reconstruction purposes (regularizer)
        if(self.architect_settings['reconstructed']):
            self.reconstructed = True
            self.reconstruct_loss = MSE()

        # For visualizaiton purposes
        if(self.train_settings['CAM']):
            self.features_blobs = []
            # for CAM visualization
            if("caps" in self.architect_settings['name']):
                def hook_feature(module, input, output):
                    self.features_blobs.append(np.array(output[0].tolist()))
                self.model.caps_layers[0].register_forward_hook(hook_feature)
                self.model.caps_layers[1].register_forward_hook(hook_feature)
            else:
                def hook_feature(module, input, output):
                    self.features_blobs.append(np.array(output.tolist()))
                self.model.caps_layers[0].register_forward_hook(hook_feature)

                
        # For logging purposes
        # self.log("metrics/no_parameters", count_parameters(self.model))
        self.training_step_outputs = []
        self.validation_step_outputs = []
        self.test_step_outputs = []

    def forward(self, x, y=None):
        return self.model(x, y)
        
    def training_step(self, batch, batch_idx):
        x, y = batch
        if(self.segment):
            y_hat = self(x)
            loss = self.loss(y_hat, y)
            y_true = y.cpu().detach()
            y_pred = y_hat.cpu().detach()
            
        else:
            if(self.reconstructed):
                y_hat, reconstructions = self(x, y)
                loss = self.loss(y_hat, y) + 0.1 * self.reconstruct_loss(reconstructions, x)
            else:
                y_hat = self(x)
                loss = self.loss(y_hat, y)
            y_true = y.tolist()
            y_pred = y_hat.argmax(axis=1).tolist()
            self.log("metrics/batch/acc", accuracy_score(y_true, y_pred), prog_bar=False)

        self.log("metrics/batch/loss", loss, prog_bar=False)
        self.training_step_outputs.append({"loss": loss.item(), "y_true": y_true, "y_pred": y_pred})
        return loss

    def on_train_epoch_end(self):
        outputs = self.training_step_outputs
        loss = np.array([])
        y_true = np.array([])
        y_pred = np.array([])
        for results_dict in outputs:
            loss = np.append(loss, results_dict["loss"])
            y_true = np.append(y_true, results_dict["y_true"])
            y_pred = np.append(y_pred, results_dict["y_pred"])
       
        self.log("metrics/epoch/loss", loss.mean())
        if(self.segment == False):
            self.log("metrics/epoch/acc", accuracy_score(y_true, y_pred))
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        x, y = batch
       
        if(self.segment):
            y_hat = self(x)
            loss = self.loss(y_hat, y)
            y_true = y.cpu().detach()
            y_pred = torch.sigmoid(y_hat).cpu().detach()
            self.validation_step_outputs.append({"loss": loss.item(), "y_true": y_true, "y_pred": y_pred})

        else:
            if(self.reconstructed):
                y_hat, reconstructions = self(x, y)
                loss = self.loss(y_hat, y) + 0.1 * self.reconstruct_loss(reconstructions, x)
                y_true = y.tolist()
                y_pred = y_hat.argmax(axis=1).tolist()
                self.validation_step_outputs.append({"loss": loss.item(), "y_true": y_true, "y_pred": y_pred, "reconstructions": reconstructions})

            else:
                y_hat = self(x)
                loss = self.loss(y_hat, y)
                y_true = y.tolist()
                y_pred = y_hat.argmax(axis=1).tolist()
                self.validation_step_outputs.append({"loss": loss.item(), "y_true": y_true, "y_pred": y_pred})
    

    def on_validation_epoch_end(self):
        outputs = self.validation_step_outputs
        loss = np.array([])
        y_true = np.array([])
        y_pred = np.array([])
        for results_dict in outputs:
            loss = np.append(loss, results_dict["loss"])
            y_true = np.append(y_true, results_dict["y_true"])
            y_pred = np.append(y_pred, results_dict["y_pred"])
        self.log("val/epoch/loss", loss.mean())
        if(self.segment):
            y_true = make_grid(outputs[0]["y_true"], nrow=int(self.train_settings["n_batch"] ** 0.5))
            y_pred = make_grid(outputs[0]["y_pred"], nrow=int(self.train_settings["n_batch"] ** 0.5))
            y_true = y_true.numpy().transpose(1, 2, 0)
            y_pred = y_pred.numpy().transpose(1, 2, 0)
            
            self.logger.experiment["val/gt_images"].append(File.as_image(y_true))
            self.logger.experiment["val/outputs"].append(File.as_image(y_pred))
        else:
            self.log("val/epoch/acc", accuracy_score(y_true, y_pred))
            if(self.reconstructed):
                reconstructions = make_grid(outputs[0]["reconstructions"], nrow=int(self.train_settings["n_batch"] ** 0.5))
                reconstructions = reconstructions.cpu().numpy().transpose(1, 2, 0)
                self.logger.experiment["val/reconstructions"].append(File.as_image(reconstructions))

        self.validation_step_outputs.clear()


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.train_settings['lr'])
        # optimizer = torch.optim.SGD(self.parameters(), lr=self.train_settings['lr'], 
        #                             weight_decay=self.train_settings['weight_decay'])
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.train_settings['lr_step'], 
                                                    gamma=self.train_settings['lr_decay'])
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

### affine Mnist dataset

In [4]:
%%script false --no-raise-error

Train_transform = transforms.Compose([
            transforms.ToPILImage(),
            # transforms.RandomAffine(degrees=0, translate=(0.125, 0.125)),
            transforms.ToTensor(),# default : range [0, 255] -> [0.0,1.0]
            # transforms.Normalize(mean = (0.5,), std = (0.5,))
        ])
    
Train_data = affNistread(mode="train", data_path="data/affMnist", transform=Train_transform)
Val_data =  affNistread(mode="val", data_path="data/affMnist", transform=Train_transform)

### center Mnist dataset

In [5]:
# %%script false --no-raise-error

Train_transform = transforms.Compose([
            transforms.ToPILImage(),
            # transforms.RandomAffine(degrees=0, translate=(0.125, 0.125)),
            transforms.ToTensor(),# default : range [0, 255] -> [0.0,1.0]
            transforms.Normalize(mean = (0.5,), std = (0.5,))
        ])
Test_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomAffine(degrees=30, translate=(0.2, 0.2)),
            transforms.ToTensor(),# default : range [0, 255] -> [0.0,1.0]
            transforms.Normalize(mean = (0.5,), std = (0.5,))
        ])

Train_data = affNistread(mode="train", data_path="data/centerMnist", aff=False, transform=Train_transform)
Val_data =  affNistread(mode="val", data_path="data/centerMnist", aff=False, transform=Train_transform)

### smallNorb

In [6]:
%%script false --no-raise-error

Train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomRotation(degrees=10),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32),
            transforms.ColorJitter(brightness=0.5),
            transforms.ToTensor(),# default : range [0, 255] -> [0.0,1.0]
            # transforms.Normalize(mean = (0.5,), std = (0.5,))
            ])
            
Test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.CenterCrop(32),
    transforms.ToTensor(),# default : range [0, 255] -> [0.0,1.0]
    # transforms.Normalize(mean = (0.5,), std = (0.5,))
])

Train_data = SmallNorbread(mode="train", data_path="data/smallNorb", transform=Train_transform)
Val_data = SmallNorbread(mode="val", data_path="data/smallNorb", transform=Test_transform)
Test_data = SmallNorbread(mode="test", data_path="data/smallNorb", transform=Test_transform)


### Lung CT Scan

In [7]:
%%script false --no-raise-error

Train_transform = transforms.Compose([
            # transforms.ToPILImage(),
            transforms.ToTensor(),# default : range [0, 255] -> [0.0,1.0]
            # transforms.Normalize((0.5), (0.5))
        ])

Test_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomAffine(degrees=30, translate=(0.1, 0.1)),
            transforms.ToTensor(),# default : range [0, 255] -> [0.0,1.0]
            # transforms.Normalize((0.5), (0.5))
        ])

train_ds = LungCTscan(data_dir="data/CT-scan-dataset", transform=Train_transform)
val_ds = LungCTscan(data_dir="data/CT-scan-dataset", transform=Train_transform)
Train_data, _ = random_split(train_ds, [200, 67])
_, Val_data = random_split(val_ds, [200, 67])
_, Test_data = random_split(val_ds, [200, 67])

# Training

In [8]:
# %%script false --no-raise-error

neptune_logger = NeptuneLogger(
        project="kaori/Capsule",
        api_key="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiIyZjZiMDA2YS02MDM3LTQxZjQtOTE4YS1jODZkMTJjNGJlMDYifQ==",
        log_model_checkpoints=False,
        # tags=["full-config"]
)

neptune_logger.log_hyperparams(params=PARAMS)

  self._run_instance = neptune.init_run(**self._neptune_init_args)


https://app.neptune.ai/kaori/Capsule/e/CAP-272


/home/vips/anaconda3/envs/Capsule/lib/python3.9/site-packages/pytorch_lightning/loggers/neptune.py:406: NeptuneUnsupportedType: You're attempting to log a type that is not directly supported by Neptune (<class 'list'>).
        Convert the value to a supported type, such as a string or float, or use stringify_unsupported(obj)
        for dictionaries or collections that contain unsupported values.
        For more, see https://docs.neptune.ai/help/value_of_unsupported_type
  self.run[parameters_key] = params


In [9]:
# %%script false --no-raise-error

Train_dataloader = DataLoader(dataset=Train_data, batch_size = PARAMS['training_settings']['n_batch'], shuffle=True, num_workers=4)
Val_dataloader = DataLoader(dataset=Val_data, batch_size = PARAMS['training_settings']['n_batch'], shuffle=False, num_workers=4)
# Test_dataloader = DataLoader(dataset=Test_data, batch_size = PARAMS['training_settings']['n_batch'], shuffle=False, num_workers=4)

In [10]:
# %%script false --no-raise-error

model = CapsuleModel(PARAMS=PARAMS)



In [11]:
# %%script false --no-raise-error

checkpoint_callback = ModelCheckpoint(dirpath=os.path.join("model", PARAMS['training_settings']["ckpt_path"]), 
                                    save_top_k=1, monitor='metrics/batch/loss', mode="min")

trainer = Trainer(
    logger=neptune_logger,
    max_epochs=PARAMS['training_settings']["n_epoch"],
    accelerator="gpu",
    devices=[1],
    callbacks=[checkpoint_callback]
    )

trainer.fit(model=model, train_dataloaders=Train_dataloader, val_dataloaders=Val_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name             | Type       | Params
------------------------------------------------
0 | model            | CapNets    | 2.3 M 
1 | loss             | SpreadLoss | 0     
2 | reconstruct_loss | MSE        | 0     
------------------------------------------------
2.3 M     Trainable params
0         Non-trainable params
2.3 M     Total params
9.369     Total estimated model params size (MB)


Epoch 99: 100%|██████████| 391/391 [00:30<00:00, 13.02it/s, v_num=-272]    

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 391/391 [00:30<00:00, 13.02it/s, v_num=-272]


# Inference

In [16]:
def normalize(img):
    img = img - np.min(img)
    nor_img = img / np.max(img)
    nor_img = np.uint8(255 * nor_img)
    return nor_img

def returnCAM(img, feature_conv, weight_softmax, class_idx):
    # generate the class activation maps upsample to 256x256
    size_upsample = (40, 40)
    bz, nc, h, w = feature_conv.shape
    output_cam = []
   
    weight_softmax = weight_softmax.sum(axis=(2, 3))
    feature_conv = feature_conv[0]
    for idx in class_idx:
        cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h*w)))
        cam = cam.reshape(h, w)
        cam = normalize(cam)
       
        output_cam.append(cv2.resize(cam, size_upsample))

    img = cv2.cvtColor(np.array(img),cv2.COLOR_GRAY2RGB)
    height, width, _ = img.shape
    CAM = cv2.resize(output_cam[0], (width, height))
    heatmap = cv2.applyColorMap(CAM, cv2.COLORMAP_JET)
    result = heatmap * 0.3 + img * 0.5
    
    return normalize(result)

def CapsuleAttentionMap(img, feature_conv, global_feature, class_idx):

    size_upsample = (40, 40)
    bz, nc, p, h, w = feature_conv.shape
    output_cam = []
    
    global_feature = global_feature[0].squeeze()
    feature_conv = feature_conv[0].squeeze().transpose(2, 3, 0, 1)
    # feature_conv = feature_conv
    
    for idx in class_idx:
        cam = feature_conv.dot(global_feature[idx])
        cam = cam.sum(axis=-1)
        cam = cam.reshape(h, w)
        cam = normalize(cam)
       
        output_cam.append(cv2.resize(cam, size_upsample))

    img = cv2.cvtColor(np.array(img),cv2.COLOR_GRAY2RGB)
    height, width, _ = img.shape
    CAM = cv2.resize(output_cam[0], (width, height))
    heatmap = cv2.applyColorMap(CAM, cv2.COLORMAP_JET)
    result = heatmap * 0.3 + img * 0.5
    
    return normalize(result)

In [17]:
%%script false --no-raise-error

with open("new_Capsule/config.yaml", 'r') as stream:
    try:
        PARAMS = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        print(exc)

conv_params = PARAMS["config2"]
conv_params["architect_settings"]["name"] = "conv"
conv_params["training_settings"]["CAM"] = True
conv_params["architect_settings"]["reconstructed"] = False
convmodel = CapsuleModel.load_from_checkpoint("model/demo/conv.ckpt", PARAMS=conv_params)
convmodel.eval()

caps_params = PARAMS["config2"]
caps_params["architect_settings"]["name"] = "capsconv"
caps_params["architect_settings"]["reconstructed"] = True
caps_params["training_settings"]["CAM"] = True
capsule = CapsuleModel.load_from_checkpoint("model/demo/fuzzy-caps.ckpt", PARAMS=caps_params)
capsule.eval()

Test_transform = transforms.Compose([
            transforms.ToTensor(),# default : range [0, 255] -> [0.0,1.0]
            transforms.Normalize(mean = (0.5,), std = (0.5,))
        ])

Test_data = affNistread(mode="test", data_path="data/affMnist", transform=Test_transform, aff=True)

def load_data():
    index = random.randint(0, 1000)
    x, y = Test_data[index]
    img = transforms.ToPILImage()(x)
    dummy = torch.zeros_like(x)
    x_batch = torch.stack([x, dummy])

    # Convolution
    cls = convmodel(x_batch)[0]
    cls = torch.softmax(cls, dim=-1).tolist()
    labels = {k: float(v) for k, v in enumerate(cls)}

    #visualize CAM
    params = list(convmodel.parameters())
    weight_softmax = np.squeeze(params[-4].tolist())
    features_map = np.array(convmodel.features_blobs[-1])
   
    activation_map = returnCAM(img, features_map, weight_softmax, class_idx=[np.argmax(cls)])
    
    # Capsule
    cls2, _ = capsule(x_batch)
    cls2 = cls2[0].tolist()
    # cls2 = torch.softmax(cls2[0], dim=-1).tolist()
    labels2 = {k: float(v) for k, v in enumerate(cls2)}

    features_map = np.array(capsule.features_blobs[-2])
    global_map = np.array(capsule.features_blobs[-1])

    activation_map2 = CapsuleAttentionMap(img, features_map, global_map, class_idx=[np.argmax(cls2)])

    return img, "class {}".format(y), labels, labels2, activation_map, activation_map2



In [18]:
%%script false --no-raise-error

import gradio as gr
import cv2
import numpy as np

def run():
    import time
    for i in range(0, 10):
        time.sleep(1)
        yield load_data()

with gr.Blocks() as demo:

    gr.Markdown("## Image Examples")
    with gr.Row():
        with gr.Column():
            im = gr.Image().style(width=200, height=200)
            conv_explain = gr.Image(label= "Activation map").style(width=300, height=300)
            capsule_explain = gr.Image(label="Capsule Activation map").style(width=300, height=300)
        with gr.Column():
            cls_box = gr.Textbox(label="True Image class")
            label_conv = gr.Label(label="Convolutional Model", num_top_classes=4)
            label_capsule = gr.Label(label="Capsule model", num_top_classes=4)
            btn = gr.Button(value="Load random image")
    btn.click(run, inputs=None, outputs=[im, cls_box, label_conv, label_capsule, conv_explain, capsule_explain])

In [19]:
%%script false --no-raise-error

demo.queue()
demo.launch(share=True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://a214d5282fd26301ea.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces


