In [None]:
import os, sys

# Make the project root and the two in-repo module folders importable from this
# notebook. The project root is resolved with abspath so the path added to
# sys.path is canonical (no '..' segments) and the membership check below is not
# fooled by an equivalent spelling of the same directory.
project_dir = os.path.abspath(os.path.join(os.getcwd(), '../..'))
attention_dir = os.path.join(project_dir, 'modules/AttentionMap')
sparse_dir = os.path.join(project_dir, 'modules/Sparse')

# Idempotent: re-running the cell does not append duplicates.
for extra_dir in (project_dir, attention_dir, sparse_dir):
    if extra_dir not in sys.path:
        sys.path.append(extra_dir)

import numpy as np
import torch, config
from torch import nn
import os

# Dataset

In [None]:
from derma.dataset import Derma, get_samples_weight
from torchvision.transforms import Compose, ToTensor, Normalize, RandomHorizontalFlip, RandomVerticalFlip, RandomRotation, Resize
from torch.utils.data import DataLoader, Subset, random_split

# Per-channel statistics used to normalise the 3-channel inputs (ImageNet values).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]
# Fraction of the images held out for validation.
VAL_FRACTION = 0.1
# Seed for the train/validation split. Without it the split changes on every run,
# so after reloading saved weights the "validation" images could be training images.
SPLIT_SEED = 0

# Training transform. Resize(128) scales the shorter image side to 128 px and keeps
# the aspect ratio; flips and rotation are train-time augmentation only.
transform = Compose([
    Resize(128),
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    RandomRotation(25),
    ToTensor(),
    Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Validation transform: same resizing and normalisation, no augmentation.
val_transform = Compose([
    Resize(128),
    ToTensor(),
    Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Two Derma instances over the same directory that differ only in their transform.
dataset = Derma(config.DATASET_DIR, transform=transform)
val_dataset = Derma(config.DATASET_DIR, transform=val_transform)

val_size = int(VAL_FRACTION * len(dataset))
train_size = len(dataset) - val_size

train_set, val_split = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Re-use the validation indices on the non-augmented instance instead of mutating
# the Subset returned by random_split.
# NOTE(review): assumes both Derma instances list the same samples in the same
# order (the index-sharing approach depends on it) — confirm in derma.dataset.
val_set = Subset(val_dataset, val_split.indices)

# Alternative (disabled): class-balanced sampling via get_samples_weight.
# train_sampler, _ = get_samples_weight(train_set)
# train_loader = DataLoader(train_set, batch_size=64, shuffle=False, sampler=train_sampler)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)

# val_sampler, _ = get_samples_weight(val_set)
# val_loader = DataLoader(val_set, batch_size=512, shuffle=False, sampler=val_sampler)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False)

# Model

In [None]:
from derma.architecture import InvertedResidual
from torchvision.models import MobileNetV2

# MobileNetV2 stage configuration (same values as torchvision's default). Each row is
# t: expansion factor, c: output channels, n: number of repeated blocks,
# s: stride of the first block in the stage.
inverted_residual_setting = [
        # t, c, n, s
        [1, 16, 1, 1],
        [6, 24, 2, 2],
        [6, 32, 3, 2],
        [6, 64, 4, 2],
        [6, 96, 3, 1],
        [6, 160, 3, 2],
        [6, 320, 1, 1],
    ]

# 2-class MobileNetV2 built from the project's own InvertedResidual block
# instead of torchvision's default block.
model = MobileNetV2(num_classes=2, inverted_residual_setting=inverted_residual_setting, block=InvertedResidual)
# Variant with torchvision's built-in block:
# model = MobileNetV2(num_classes=2, inverted_residual_setting=inverted_residual_setting)

from derma.doc.utils import summary
# Stored as `model_summary` rather than `sum` so the builtin sum() is not shadowed
# for the rest of the notebook session.
# NOTE(review): Resize(128) in the data pipeline only fixes the shorter side, so real
# batches may not be 128x128 — confirm the summary input size is representative.
model_summary = summary(model, input_size=(1, 3, 128, 128))
print(model_summary)

# Loading pretrained layers
# model.features.load_state_dict(torch.load(os.path.join(config.RESULT_DIR, 'weights/encoder/cifar/encoder.pth')))

# Train

In [None]:
from derma.utils import train
from torch.utils.tensorboard import SummaryWriter

# Training hyper-parameters.
NUM_EPOCHS = 25
LOG_DIR = os.path.join(config.RESULT_DIR, 'log/classification/derma')

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-6)
tb_writer = SummaryWriter(log_dir=LOG_DIR)

# `train` receives the loaders as a [train, validation] list.
loaders = [train_loader, val_loader]
train(model, loaders, optimizer, criterion, NUM_EPOCHS, tb_writer)

In [None]:
# Persist the trained weights (state_dict only) under RESULT_DIR.
save_dir = os.path.join(config.RESULT_DIR, 'weights/classifier/derma')
# exist_ok=True makes this a no-op when the directory already exists and avoids
# the race between a separate exists() check and makedirs().
os.makedirs(save_dir, exist_ok=True)

torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))

# Testing Grad-CAM

In [None]:
# Reload the trained weights. map_location='cpu' lets a checkpoint saved from a
# GPU-resident model be read on a CPU-only machine; load_state_dict then copies the
# tensors onto whatever device the model's parameters currently live on.
model.load_state_dict(torch.load(os.path.join(config.RESULT_DIR, 'weights/classifier/derma/model.pth'), map_location='cpu'))

In [None]:
from captum.attr import GuidedGradCam, LayerGradCam, LayerAttribution

# Attributions are computed on the CPU with the network in inference mode.
model.eval()
model.cpu()

# Number of validation examples to explain.
N_EXAMPLES = 18

# First batch of the (unshuffled) validation loader, truncated to N_EXAMPLES.
batch_inputs, batch_targets = next(iter(val_loader))
# Inputs are flagged as requiring grad — presumably needed by the attribution
# helper used below; confirm.
inputs = batch_inputs[:N_EXAMPLES].requires_grad_()
targets = batch_targets[:N_EXAMPLES]

In [None]:
from derma.doc.utils import GradCamAttribute, plot_attribution
from derma.utils import UnNormalize
from torchvision.transforms import ToPILImage

# Inverse of the input normalisation: undo the per-channel mean/std scaling, then
# turn the tensor back into a PIL image for plotting.
inv_transform = Compose([
    UnNormalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToPILImage(),
])

# Grad-CAM attributions taken at the last block of `model.features`.
raw_attribution = GradCamAttribute(model, model.features[-1], inputs, targets)
# Average over axis 1, then take magnitudes.
# NOTE(review): abs() folds negative attributions into positive ones rather than
# discarding them; canonical Grad-CAM applies ReLU (clamp(min=0)) — confirm intent.
attribution = raw_attribution.mean(axis=1).abs()

# Visualization KDE

In [None]:
# Index of the example (within `inputs`) to visualise.
idx = 0
att, img = (
    attribution.detach()[idx].numpy(),      # attribution map for example `idx`
    np.array(inv_transform(inputs[idx])),   # matching de-normalised image
)

fig = plot_attribution(att, img)

# Save Image: crop surrounding whitespace; written to the current working directory.
fig.savefig('attribution.pdf', bbox_inches='tight', pad_inches=0)