In [None]:
import os
import cv2
import torch
import random
import numpy as np
from PIL import Image
from torch import optim
from collections import deque
import torch.nn.functional as F
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

from torch.utils.data import DataLoader
from dataset import Conv_Att_MCI_Dataset
from models import BaseModel, VGG16GradCAM, ConvAttnModel

In [None]:
# Seed every RNG the notebook uses (Python, NumPy, PyTorch CPU and CUDA) so runs repeat.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.manual_seed_all(SEED)
    # Prefer deterministic cuDNN kernels over autotuned (possibly non-deterministic) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
device = torch.device('cuda' if use_cuda else 'cpu')
print(device)

In [None]:
# ---- Dataset ----
img_type = 'all'  # passed to Conv_Att_MCI_Dataset; presumably selects every image type (TODO confirm)

# ---- Model ----
backbone_freezing = True
# Conv-Att head sizes
h_dim_attn, n_heads = 128, 1
h_dim_fc, n_layers = 512, 1

# ---- Training ----
batch_size = 32  # 64 was also tried
n_epochs = 100
# Best-validation bookkeeping, updated by the training loop.
best_loss, best_epoch, best_flag = np.inf, 0, False

# ---- Adam optimizer ----
lr, beta_1, beta_2, eps = 1e-5, 0.9, 0.99, 1e-7

# ---- Early stopping: number of recent epochs inspected ----
es_size = 4

# ---- Output directory for checkpoints and figures ----
savedir = './checkpoint'
os.makedirs(savedir, exist_ok=True)

In [None]:
# Build the dataset, split it, and wrap each split in a DataLoader.
dataset = Conv_Att_MCI_Dataset(img_type)
dataset_trn, dataset_val, dataset_test = dataset.split_trn_val_test()
dataloader_trn = DataLoader(dataset_trn, batch_size=batch_size, shuffle=True)
# Evaluation loaders are not shuffled. Order is irrelevant for evaluation, and the
# epoch loss is a mean of per-batch means (last batch may be smaller). A fixed order
# therefore keeps validation loss comparable across epochs for early stopping and
# best-model selection. It also avoids drawing from the global torch RNG.
dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

In [None]:
# Alternative models tried with this notebook:
#   BaseModel(img_type, backbone_freezing)
#   VGG16GradCAM(img_type, backbone_freezing)
model = ConvAttnModel(
    img_type,
    h_dim_attn,
    n_heads,
    h_dim_fc,
    n_layers,
    backbone_freezing,
)
model = model.to(device)

In [None]:
# Adam over all model parameters, configured from the hyper-parameter cell.
optimizer = optim.Adam(model.parameters(), lr=lr, betas=[beta_1, beta_2], eps=eps)

In [None]:
# Train with early stopping. Every epoch also evaluates the validation and test splits,
# keeps the checkpoint with the lowest validation loss, and saves the latest model.
results = {
    'losses':{'trn':[], 'val':[], 'test':[]},
    'corrects':{'trn':[], 'val':[], 'test':[]},
    'accs':{'trn':[], 'val':[], 'test':[]},
}


def run_epoch(dataloader, train=False):
    """Run one pass over `dataloader`.

    Args:
        dataloader: yields `(x, y, info)`, where `x` is a list of batch tensors (one per
            image type), `y` is the BCE target, and `info['labels']` holds the labels
            that thresholded predictions are compared against.
        train: if True, put the model in train mode and step `optimizer` on every
            batch. Otherwise put it in eval mode with autograd disabled, so no
            computation graph is built for evaluation batches.

    Returns:
        (mean of per-batch losses, number of correct predictions, accuracy)
    """
    model.train(train)
    losses, n_correct, n_tot = [], 0, 0
    with torch.set_grad_enabled(train):
        for x, y, info in dataloader:
            x = [x_i.to(device) for x_i in x]
            y = y.to(device)

            y_pred = model(x)
            # Class-1 probability from the softmax over the model outputs, scored with BCE.
            y_prob = y_pred.softmax(-1)[:, 1]
            loss = F.binary_cross_entropy(y_prob, y)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Store plain Python numbers, not tensors, in the results.
            n_correct += ((y_prob.detach().cpu() >= 0.5) == info['labels']).sum().item()
            losses.append(loss.item())
            n_tot += len(x[-1])
    mean_loss = float(np.mean(losses)) if losses else float('nan')
    acc = n_correct / n_tot if n_tot else float('nan')
    return mean_loss, n_correct, acc


def save_checkpoint(path, epoch, loss):
    """Save model and optimizer state to `path`.

    `loss` is stored as a plain float rather than a NumPy scalar, because
    `torch.load` in weights-only mode may refuse to unpickle NumPy scalars.
    """
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': float(loss),
    }, path)


splits = (
    ('trn', dataloader_trn, True),
    ('val', dataloader_val, False),
    ('test', dataloader_test, False),
)
es_queue, es_flag = deque(maxlen=es_size), False
for epoch in range(n_epochs):
    for split, loader, is_train in splits:
        split_loss, split_correct, split_acc = run_epoch(loader, train=is_train)
        results['losses'][split].append(split_loss)
        results['corrects'][split].append(split_correct)
        results['accs'][split].append(split_acc)

    # Early stop once validation loss has not decreased over the last `es_size` epochs
    # while training loss kept decreasing over the same window.
    es_queue.append(results['losses']['val'][-1])
    if len(es_queue) >= es_size:
        if (np.diff(es_queue) >= 0).all() and (np.diff(results['losses']['trn'][-es_size:]) < 0).all():
            es_flag = True

    if best_loss >= results['losses']['val'][-1]:
        best_loss = results['losses']['val'][-1]
        best_epoch = epoch
        best_flag = True
        save_checkpoint(os.path.join(savedir, 'model_best.ckpt'), epoch, best_loss)

    print("| Epoch %d/%d |"%(epoch+1, n_epochs), end=' Early stopping!\n' if es_flag else '\n')
    print("| Train      | Loss %6.2f | Acc. %6.2f |"%(results['losses']['trn'][-1], results['accs']['trn'][-1]))
    print("| Validation | Loss %6.2f | Acc. %6.2f |"%(results['losses']['val'][-1], results['accs']['val'][-1]), end=' Best\n' if best_flag else '\n')
    print("| Test       | Loss %6.2f | Acc. %6.2f |"%(results['losses']['test'][-1], results['accs']['test'][-1]))
    best_flag = False
    # Latest model; as before, the stored 'loss' is the best validation loss so far.
    save_checkpoint(os.path.join(savedir, 'model.ckpt'), epoch, best_loss)

    if es_flag:
        break

In [None]:
# Learning curves (loss and accuracy per epoch) for each split, saved next to the checkpoints.
split_labels = {'trn': 'Trn.', 'val': 'Val.', 'test': 'Test'}
fig, axes = plt.subplots(1, 2, figsize=[10, 5])
for ax, (metric, title) in zip(axes, [('losses', 'Loss'), ('accs', 'Accuracy')]):
    for split, split_label in split_labels.items():
        ax.plot(results[metric][split], label=split_label)
    ax.set_xlabel('Epoch')
    ax.set_title(title)
    ax.legend()

savepath_lc = os.path.join(savedir, 'learning_curve.png')
fig.savefig(savepath_lc)

In [None]:
# Restore the weights with the lowest validation loss written by the training loop.
loadpath = os.path.join(savedir, 'model_best.ckpt')
# map_location remaps stored tensors onto the current device, so a checkpoint written
# on a GPU also loads on a CPU-only machine (and vice versa).
ckpt = torch.load(loadpath, map_location=device)

model.load_state_dict(ckpt['model_state_dict'])

In [None]:
def patch_attention(m):
    """Wrap `m.forward` so every call requests per-head attention weights.

    The wrapper overrides any caller-supplied `need_weights` / `average_attn_weights`
    with True / False. A caller that asked for no weights therefore still gets them.
    For `nn.MultiheadAttention`, the weights are the second element of the returned tuple.
    """
    wrapped_forward = m.forward

    def forward_with_weights(*args, **kwargs):
        kwargs.update(need_weights=True, average_attn_weights=False)
        return wrapped_forward(*args, **kwargs)

    m.forward = forward_with_weights


# Patch the self-attention of the last encoder layer of each per-image-type attention stack.
attn_layers = [model.attns[i].layers[-1].self_attn for i in range(len(dataset.img_type))]
for attn_layer in attn_layers:
    patch_attention(attn_layer)

In [None]:
# For each image type, capture during the next forward pass:
# - the output of the backbone's last module (used later for the feature-map size);
# - the input and output of the patched self-attention layer.
vgg16_features = {_type: [] for _type in dataset.img_type}
attn_inputs = {_type: [] for _type in dataset.img_type}
attn_outputs = {_type: [] for _type in dataset.img_type}
hooks = []
# Iterate dataset.img_type, not a hard-coded type list. This keeps model index i, the
# dict keys above, `attn_layers` and the order in which `imgs` is built/fed all aligned.
# A hard-coded list raised KeyError/IndexError or mislabelled maps when they differed.
for i, _type in enumerate(dataset.img_type):
    hooks.append(model.vgg16_models[i][-1].register_forward_hook(
        lambda module, inputs, output, _type=_type: vgg16_features[_type].append(output)
    ))

    # One hook records both the attention input and output (_type bound per iteration).
    def _record_attn(module, inputs, output, _type=_type):
        attn_inputs[_type].append(inputs[0])
        attn_outputs[_type].append(output)

    hooks.append(attn_layers[i].register_forward_hook(_record_attn))

In [None]:
# Run one test sample through the model. The hooks registered above capture backbone
# features and attention weights during this forward pass.
idx = 130  # arbitrary example index into dataset_test.dataset
imgs = [dataset_test.dataset[_type][idx:idx+1].to(device) for _type in dataset.img_type]
score = dataset_test.dataset['scores'][idx:idx+1]
label = dataset_test.dataset['labels'][idx:idx+1]
print(imgs[-1].shape, score, label)

# Set eval mode explicitly instead of relying on whichever mode an earlier cell left.
# Autograd stays enabled, as before.
# NOTE(review): with autograd disabled, nn.TransformerEncoderLayer in eval mode may take
# a fused fast path that bypasses self_attn, so these hooks would not fire. Confirm
# before wrapping this in torch.no_grad().
model.eval()
try:
    y_pred = model(imgs)
finally:
    # Always detach the hooks, even if the forward pass fails. Otherwise re-running the
    # hook cell would stack duplicate hooks.
    for hook in hooks:
        hook.remove()

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Build one attention heatmap per image type, upsampled to that type's own image size.
heatmaps = []
for i, _type in enumerate(dataset.img_type):
    B, C, H, W = vgg16_features[_type][0].shape
    img_h, img_w = imgs[i].shape[-2:]

    # attn_outputs[_type][0] is the output of the patched self-attention on the first
    # forward pass after hooking; element [1] holds the attention weights. With
    # average_attn_weights=False these are per head, i.e. (B, heads, L, S).
    attn_weights = attn_outputs[_type][0][1]
    # Head 0, query token 0 -> remaining H*W tokens. Token 0 is presumably a
    # summary/class token (TODO confirm in ConvAttnModel).
    heatmap = attn_weights[:, 0, 0, 1:].reshape(B, H, W).detach().cpu()
    heatmap = torch.clamp(heatmap, min=0)
    # Normalise to [0, 1]; the clamp on the divisor guards against an all-zero map.
    heatmap = heatmap / heatmap.max().clamp(min=1e-12)

    heatmap = np.uint8(255 * heatmap.numpy())
    # PIL's resize takes (width, height), while tensors are laid out (..., H, W).
    heatmap = np.asarray(Image.fromarray(heatmap[0]).resize((int(img_w), int(img_h))))
    heatmaps.append(heatmap)

n_types = len(dataset.img_type)
fig = plt.figure(figsize=(5*n_types, 5))
fig.suptitle('Pred: %d vs. GT: %d'%(y_pred.argmax(-1).item(), label.item()))
for i, _type in enumerate(dataset.img_type):
    ax = fig.add_subplot(1, n_types, i+1)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)

    ax.imshow(imgs[i][0].detach().cpu().permute(1,2,0))
    ax.set_title(_type)
    # Colour the single-channel map with a matplotlib colormap, so the colorbar really
    # describes the overlay. cv2.applyColorMap yields BGR, which matplotlib would
    # misread as RGB, and an RGB image carries no scale for the colorbar.
    im = ax.imshow(heatmaps[i], cmap='jet', vmin=0, vmax=255, alpha=0.5)
    fig.colorbar(im, cax=cax, orientation='vertical')
    ax.axis('off')
savepath_fig_res = os.path.join(savedir, 'results.png')
fig.savefig(savepath_fig_res)