In [1]:
import cv2
import torch
import random
import numpy as np
from PIL import Image
from torch import optim
from collections import deque
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from dataset import Conv_Att_MCI_Dataset
from models import BaseModel, VGG16GradCAM, ConvAttnModel

In [2]:
SEED = 1234
# Seed every RNG in play: Python's `random`, NumPy's legacy global RNG and PyTorch (CPU).
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.manual_seed_all(SEED)
    # Prefer reproducible cuDNN kernels over autotuned (benchmark) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
device = torch.device('cuda' if use_cuda else 'cpu')
print(device)

cuda


In [3]:
# Training schedule
batch_size, n_epochs = 64, 100

# Adam optimizer settings
lr = 1e-5
beta_1, beta_2 = 0.9, 0.99
eps = 1e-7

# Early stopping: size of the window of recent validation losses that is checked
es_size = 4

In [4]:
dataset = Conv_Att_MCI_Dataset('clock')
dataset_trn, dataset_val, dataset_test = dataset.split_trn_val_test()

# Only the training loader is shuffled. The evaluation loaders keep a fixed
# order so the validation loss (averaged over per-batch means, and used for
# early stopping) does not depend on which samples happen to fall into the
# last, smaller batch. Identical weights then give an identical loss.
dataloader_trn = DataLoader(dataset_trn, batch_size=batch_size, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

8541285991236481346 tensor([[0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
        [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
        [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 0.0039, 0.0039, 0.0039],
        [1.0000, 1.0000, 1.0000,  ..., 0.0039, 0.0039, 0.0039],
        [1.0000, 1.0000, 1.0000,  ..., 0.0039, 0.0039, 0.0039]])
9059304616452982690 tensor([[0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
        [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
        [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 0.0039, 0.0039, 0.0039],
        [1.0000, 1.0000, 1.0000,  ..., 0.0039, 0.0039, 0.0039],
        [1.0000, 1.0000, 1.0000,  ..., 0.0039, 0.0039, 0.0039]])
3947543481172177635 tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 0.0039, 0.0039],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 0.0039, 0.0039],
        [1.0000,

In [5]:
sample_key = 8541285991236481346  # one of the sample ids printed while the dataset was loading
dataset.dataset_raw[sample_key]['images']['augmented']

tensor([[[0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
         [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
         [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 0.0039, 0.0039, 0.0039],
         [1.0000, 1.0000, 1.0000,  ..., 0.0039, 0.0039, 0.0039],
         [1.0000, 1.0000, 1.0000,  ..., 0.0039, 0.0039, 0.0039]],

        [[0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
         [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
         [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 0.0039, 0.0039, 0.0039],
         [1.0000, 1.0000, 1.0000,  ..., 0.0039, 0.0039, 0.0039],
         [1.0000, 1.0000, 1.0000,  ..., 0.0039, 0.0039, 0.0039]],

        [[0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
         [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
         [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.

In [None]:
# from PIL import Image
# from torchvision import transforms

# img = Image.fromarray(np.random.rand(256,256))

# transform_aug = transforms.Compose([
#     transforms.Pad([12,12,12,12], fill=1), # left, top, right, bottom, fill=1 to match the background color of original images
#     transforms.RandomCrop([256,256])
# ])

# transform = transforms.Compose([
#             transforms.ToTensor()
#         ])
# transform(transform_aug(img))

In [None]:
# Other architectures available from `models`: BaseModel(), VGG16GradCAM().

# Hyper-parameters passed positionally to ConvAttnModel (see models.py for their meaning).
h_dim_attn, n_heads = 128, 1
h_dim_fc, n_layers = 512, 1

model = ConvAttnModel(h_dim_attn, n_heads, h_dim_fc, n_layers)
model = model.to(device)

In [None]:
# Adam over all model parameters, using the settings from the configuration cell.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=lr,
    betas=[beta_1, beta_2],
    eps=eps,
)

In [None]:
def _run_epoch(loader, train=False):
    """Run one full pass of `model` over `loader`.

    When `train` is True the model is put in training mode and `optimizer`
    takes a step after every batch. Otherwise the model is in eval mode and
    autograd is disabled, so no graphs are built and dropout/batch-norm
    layers (if any) behave deterministically.

    Returns:
        (mean of per-batch losses, number of correct predictions, number of samples)
    """
    model.train(train)
    batch_losses, n_correct, n_seen = [], 0, 0
    with torch.set_grad_enabled(train):
        for x, y, info in loader:
            x, y = x.to(device), y.to(device)

            y_pred = model(x)
            # Probability of class 1 from the 2-way softmax, scored with BCE.
            y_prob = y_pred.softmax(-1)[:, 1]
            loss = F.binary_cross_entropy(y_prob, y)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            n_correct += int(((y_prob.detach().cpu() >= 0.5) == info['labels']).sum())
            batch_losses.append(loss.item())
            n_seen += len(x)

    mean_loss = float(np.mean(batch_losses)) if batch_losses else float('nan')
    return mean_loss, n_correct, n_seen


results = {
    'losses': {'trn': [], 'val': [], 'test': []},
    'corrects': {'trn': [], 'val': [], 'test': []},
    'accs': {'trn': [], 'val': [], 'test': []},
}
# Evaluation order per epoch: train, then validation, then test.
# The test split is only monitored; it does not influence early stopping.
split_loaders = {'trn': dataloader_trn, 'val': dataloader_val, 'test': dataloader_test}
split_row_labels = {'trn': 'Train     ', 'val': 'Validation', 'test': 'Test      '}

es_queue, es_flag = deque(maxlen=es_size), False
for epoch in range(n_epochs):
    for split, loader in split_loaders.items():
        mean_loss, n_correct, n_seen = _run_epoch(loader, train=(split == 'trn'))
        results['losses'][split].append(mean_loss)
        results['corrects'][split].append(n_correct)
        results['accs'][split].append(n_correct / n_seen if n_seen else float('nan'))

    # Early stopping: stop once the validation loss has risen strictly on every
    # step inside the last `es_size` epochs.
    # NOTE(review): the weights kept are the last epoch's, not the best-validation ones.
    es_queue.append(results['losses']['val'][-1])
    if len(es_queue) >= es_size and (np.diff(list(es_queue)) > 0).all():
        es_flag = True

    print("| Epoch %d/%d |" % (epoch + 1, n_epochs), end=' Early stopping!\n' if es_flag else '\n')
    for split, row_label in split_row_labels.items():
        print("| %s | Loss %6.2f | Acc. %6.2f |"
              % (row_label, results['losses'][split][-1], results['accs'][split][-1]))

    if es_flag:
        break

In [None]:
# Learning curves per split. The right panel plots accuracies; it previously
# plotted the validation/test *losses* by mistake.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=[10, 5])
split_labels = {'trn': 'Trn.', 'val': 'Val.', 'test': 'Test'}
for split, legend_label in split_labels.items():
    # float() accepts Python numbers as well as 0-d tensors.
    ax_loss.plot([float(v) for v in results['losses'][split]], label=legend_label)
    ax_acc.plot([float(v) for v in results['accs'][split]], label=legend_label)

ax_loss.set(title='Loss', xlabel='Epoch', ylabel='Loss')
ax_acc.set(title='Accuracy', xlabel='Epoch', ylabel='Accuracy')
ax_loss.legend()
ax_acc.legend()
plt.show()

In [None]:
def patch_attention(m):
    """Make module `m` always return per-head attention weights.

    Replaces `m.forward` with a wrapper that forces the keyword arguments
    `need_weights=True` and `average_attn_weights=False` on every call,
    overriding whatever the caller passed. For `nn.MultiheadAttention`, this
    makes the weights come back un-averaged over heads. Returns None; `m` is
    modified in place.
    """
    original_forward = m.forward

    def forward_with_head_weights(*args, **kwargs):
        kwargs.update(need_weights=True, average_attn_weights=False)
        return original_forward(*args, **kwargs)

    m.forward = forward_with_head_weights
# Patch the self-attention module of the last attention layer so it reports per-head weights.
last_attn_block = model.attn.layers[-1]
attn_layer = last_attn_block.self_attn
patch_attention(attn_layer)  # modifies attn_layer.forward in place

In [None]:
# Buffers appended to by the forward hooks below:
#   vgg16_features - output of the last module in model.vgg16_model
#   attn_inputs    - first positional input to the patched attention layer
#   attn_outputs   - full return value of that layer (presumably (output, weights))
vgg16_features, attn_inputs, attn_outputs = [], [], []


def _capture_vgg16(module, inputs, output):
    vgg16_features.append(output)


def _capture_attn_input(module, inputs, output):
    attn_inputs.append(inputs[0])


def _capture_attn_output(module, inputs, output):
    attn_outputs.append(output)


# Handles are kept so the hooks can be removed after the forward pass of interest.
hooks = [
    model.vgg16_model[-1].register_forward_hook(_capture_vgg16),
    attn_layer.register_forward_hook(_capture_attn_input),
    attn_layer.register_forward_hook(_capture_attn_output),
]

# # propagate through the model
# outputs = model(x)

# for hook in hooks:
#     hook.remove()

In [None]:
# Quick look at a raw test image. This cell is self-contained: it does not rely
# on `img`, which is only defined in the visualisation cell below.
preview_idx = 150
preview_img = dataset_test.dataset['images'][preview_idx]
plt.imshow(preview_img.cpu().permute(1, 2, 0))  # CHW -> HWC for matplotlib
plt.axis('off')
plt.show()

In [None]:
idx = 150
img = dataset_test.dataset['images'][idx:idx+1].to(device)
score = dataset_test.dataset['scores'][idx:idx+1]
label = dataset_test.dataset['labels'][idx:idx+1]
print(img.shape, score, label)

# Grad-CAM alternative (VGG16GradCAM):
# class_idx = y_pred.argmax(-1).item()
# heatmap = model.generate_cam(img, class_idx).cpu()

# --- Self-attention map ---
# Empty the capture buffers first, so index 0 below is guaranteed to come from
# *this* forward pass rather than from a stale earlier run.
for buffer in (vgg16_features, attn_inputs, attn_outputs):
    buffer.clear()

was_training = model.training
model.eval()  # inference behaviour (e.g. dropout off) for the displayed prediction
try:
    with torch.no_grad():
        y_pred = model(img)
finally:
    # Detach the hooks even if the forward fails, and restore the previous mode.
    for hook in hooks:
        hook.remove()
    model.train(was_training)

if not vgg16_features or not attn_outputs:
    raise RuntimeError(
        'Forward hooks captured nothing: re-run the hook-registration cell '
        'before this one (hooks are removed after each visualisation).'
    )

B, C, H, W = vgg16_features[0].shape
# Element 1 of the patched attention layer's return value holds the per-head weights.
# NOTE(review): assumes the weights are (B, n_heads, L, L), with query 0 a summary
# token and keys 1: the H*W feature-map positions. Confirm against ConvAttnModel.
# Only head 0 is visualised.
attn_weights = attn_outputs[0][1]
heatmap = attn_weights[:, 0, 0, 1:].reshape(B, H, W).float().cpu()
heatmap = torch.clamp(heatmap, min=0)
heatmap = heatmap / heatmap.max().clamp_min(1e-12)  # scale to [0, 1]; guard an all-zero map

# Upsample the H x W grid to the image's (height, width). Keeping torch's
# (rows, cols) convention avoids the (width, height) order that PIL.resize expects.
heatmap = F.interpolate(heatmap[:, None], size=tuple(img.shape[-2:]),
                        mode='bilinear', align_corners=False)[0, 0].numpy()
print(heatmap.shape)

# Let matplotlib apply the colormap to the scalar map. cv2.applyColorMap would
# return BGR (wrong colours under imshow), and a colorbar is only meaningful
# for a scalar image.
fig, ax = plt.subplots()
ax.imshow(img[0].detach().cpu().permute(1, 2, 0))
overlay = ax.imshow(heatmap, cmap='jet', alpha=0.5, vmin=0, vmax=1)
fig.colorbar(overlay, ax=ax, label='Normalised attention')
ax.axis('off')
ax.set_title('Pred: %d vs. GT: %d' % (y_pred.argmax(-1).item(), label.item()))
plt.show()