In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from PIL import Image
import os
import pickle
from skimage.metrics import structural_similarity as ssim
from sklearn.model_selection import train_test_split

from transpose_attack.brain.data import MRIMemDataset
from transpose_attack.brain.model import BrainMRIModel, BrainViT

# Load Data

In [None]:
dataset_path = "./data/brain_tumor_dataset"

# Sub-folder name -> integer class label. Scans under "yes" get 1 (shown as
# "Tumor" later), scans under "no" get 0 ("No Tumor"). Dict insertion order keeps
# "yes" first, so `paths` is built in the same order as the folder list ['yes', 'no'].
# NOTE(review): os.walk does not sort directory entries, so the order of `paths`
# (and therefore the seeded train/test split) may differ between filesystems.
folder_labels = {'yes': 1, 'no': 0}

paths = []
labels = []
for folder, folder_label in folder_labels.items():
    for root, _, files in os.walk(os.path.join(dataset_path, folder)):
        paths.extend(os.path.join(root, name) for name in files)
        labels.extend([folder_label] * len(files))

len(paths), len(labels)

In [None]:
# Re-create the stratified 80/20 split with the fixed seed (random_state=42).
# This is intended to match the split the models were trained on, which only
# holds if `paths` is enumerated in the same order as at training time.
X_train, X_test, y_train, y_test = train_test_split(
    paths,
    labels,
    test_size=0.2,
    stratify=labels,
    shuffle=True,
    random_state=42,
)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# `percentage`: fraction of X_train placed in each memorization chunk
# (0.1 -> chunks of 10% of the training set).
# `num_classes`: number of output classes (tumor / no tumor).
percentage, num_classes = 0.1, 2

In [None]:
# Split the memorization dataset into consecutive, equally sized chunks.
def split_to_chunks(data: list, labels: list, n: int):
    """Yield consecutive ``(data_chunk, labels_chunk)`` pairs of ``n`` items each.

    The last pair is shorter when ``len(data)`` is not a multiple of ``n``.

    Args:
        data: Sequence of samples (in this notebook, image file paths).
        labels: Labels aligned index-by-index with ``data``.
        n: Chunk size; must be a positive integer.

    Yields:
        ``(data[i:i + n], labels[i:i + n])`` for ``i = 0, n, 2n, ...``.

    Raises:
        ValueError: if ``n`` is not positive, or if ``data`` and ``labels``
            differ in length. Slicing would otherwise silently misalign or
            truncate the label chunks. Because this is a generator, the error
            is raised when iteration starts (e.g. inside ``list(...)``).
    """
    if n <= 0:
        # range() with step 0 fails with an opaque message, and a negative
        # step yields nothing at all; fail with a clear reason instead.
        raise ValueError(f"chunk size n must be positive, got {n}")
    if len(data) != len(labels):
        raise ValueError(
            f"data and labels must have the same length, got {len(data)} and {len(labels)}"
        )
    for start in range(0, len(data), n):
        yield data[start:start + n], labels[start:start + n]

# Each chunk holds `percentage` of the training set (size truncated to an int).
chunk_size = int(len(X_train) * percentage)
mem_data_chunks = list(split_to_chunks(X_train, y_train, chunk_size))

In [None]:
# Which memorization chunk to inspect. It is also interpolated into the
# `..._chunk_{chunk_index}.pt` checkpoint paths used by the model-loading cells.
chunk_index: int = 5

In [None]:
# Memorization dataset built from the selected chunk; later cells unpack each
# item as (code, label, img).
mem_chunk = mem_data_chunks[chunk_index]
train_mem_dataset = MRIMemDataset(
    mem_data_chunk=mem_chunk,
    num_classes=num_classes,
    device=device,
    base=3,  # NOTE(review): meaning of `base` is defined inside MRIMemDataset — not visible here
)

# Load CNN Model

In [None]:
model_path = f"./models/brain_cnn_32_64_epoch_100_memorize_True_p_10_loss_mse_chunk_{chunk_index}.pt"
model = BrainMRIModel()
# map_location remaps tensors saved on another device (e.g. CUDA) onto `device`.
# Without it, torch.load restores them to their original device and fails on a
# CPU-only machine, even though `device` deliberately falls back to the CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint)
model = model.to(device)

In [None]:
# Display the loaded CNN's module structure (the bare last expression is rendered as the cell output).
model

# Sample Image

In [None]:
# Index (within train_mem_dataset) of the example examined in the following cells.
img_index: int = 7

In [None]:
# The third element of a dataset item is the image tensor. Move it to the CPU so
# matplotlib can convert it, then move the first axis to the end for imshow
# (assumes channel-first (C, H, W) layout — TODO confirm against MRIMemDataset).
sample_img = train_mem_dataset[img_index][2].to('cpu')
fig, ax = plt.subplots(ncols=1, tight_layout=True)
ax.imshow(sample_img.permute(1, 2, 0))
plt.show()

# Check Primary Task

In [None]:
model.eval()
with torch.no_grad():
    # Fetch the image (item[2]) and its label tensor (item[1]), in that order.
    img = train_mem_dataset[img_index][2]
    y = train_mem_dataset[img_index][1]
    # Prepend a batch dimension of size 1 before sending the image to the model.
    batch = img.unsqueeze(0).to(device)
    y = y.to(device)
    logits = model(batch)
    # Index of the highest-scoring class for the single sample in the batch.
    predicted = logits.max(dim=1).indices.squeeze()
    print("Predicted Label =", predicted.item())
    print("Label =", torch.argmax(y).item())

# Check Memorization

In [None]:
model.eval()
code, label, img = train_mem_dataset[img_index]
# Compares two flattened 1-D vectors, hence dim=0.
cos0 = nn.CosineSimilarity(dim=0)
with torch.no_grad():
    # Decode the sample's code with the model's transposed pass; the result is
    # shown as the "Reconstruction" panel next to the original image.
    rec_image = model.forward_transposed(code.view(1, -1))
    rec_image = rec_image.view(-1, 224, 224).to("cpu")  # (C, 224, 224)
    img = img.to("cpu")
    # The label is one-hot style: argmax 0 -> "No Tumor", otherwise "Tumor".
    label = "No Tumor" if torch.argmax(label).item() == 0 else "Tumor"
    cosine_similarity = cos0(img.view(-1), rec_image.view(-1))

    fig, ax = plt.subplots(ncols=2, tight_layout=True)
    ax[0].imshow(img.permute(1, 2, 0))
    ax[0].set_title('Original')
    ax[1].imshow(rec_image.permute(1, 2, 0))
    ax[1].set_title("Reconstruction")
    # ".2f" = two decimal places; the previous "2f" only set a minimum field
    # width of 2 and printed six decimals.
    plt.suptitle("Label: {}\nCosine Similarity: {:.2f}\nCode: {}".format(label, cosine_similarity.item(), code))
    plt.show()

In [None]:
model.eval()
# One similarity module serves every example; it compares flattened 1-D vectors (dim=0).
cos0 = nn.CosineSimilarity(dim=0)
for idx in [2, 6]:
    code, label, img = train_mem_dataset[idx]
    with torch.no_grad():
        # Decode the sample's code with the model's transposed pass.
        rec_image = model.forward_transposed(code.view(1, -1))
        rec_image = rec_image.view(-1, 224, 224).to("cpu")  # (C, 224, 224)
        img = img.to("cpu")
        label = "No Tumor" if torch.argmax(label).item() == 0 else "Tumor"
        cosine_similarity = cos0(img.view(-1), rec_image.view(-1))

        fig, ax = plt.subplots(ncols=2, tight_layout=True)
        ax[0].imshow(img.permute(1, 2, 0))
        ax[0].set_title('Original')
        ax[1].imshow(rec_image.permute(1, 2, 0))
        ax[1].set_title("Reconstruction")
        # ".2f" = two decimal places ("2f" was a width of 2, not a precision).
        plt.suptitle("Label: {}\nCosine Similarity: {:.2f}\nCode: {}".format(label, cosine_similarity.item(), code))
        plt.show()
        # Release the figure so repeated plotting in the loop doesn't accumulate figures.
        plt.close(fig)

# Show Overlapped Images

**TODO:** this section is currently empty — no overlapped-image plot is produced here yet.

# Load ViT Model

In [None]:
# Checkpoint of the ViT trained on the same memorization chunk (`chunk_index`).
model_path = "./models/brain_vit_epoch_250_memorize_True_p_10_loss_mse_chunk_{}.pt".format(chunk_index)

In [None]:
# ViT hyper-parameters; they must agree with the saved checkpoint's shapes.
hidden = 384               # presumably the token embedding width — confirm in BrainViT
mlp_hidden = hidden * 3    # 1152
num_layers = 7
head = 12                  # presumably the number of attention heads
# Flattened size of a 1 x 224 x 224 input and the class count. Neither value
# is passed to the BrainViT constructor call below.
input_size = 1 * 224 * 224
output_size = int(num_classes)

In [None]:
# Build the ViT with the hyper-parameters defined above; load_state_dict raises
# if they don't produce the parameter names/shapes stored in the checkpoint.
model = BrainViT(in_c=1,
                 num_classes=num_classes,
                 img_size=224,
                 patch=16,
                 hidden=hidden,
                 mlp_hidden=mlp_hidden,
                 num_layers=num_layers,
                 head=head)
# map_location remaps tensors saved on another device (e.g. CUDA) onto `device`,
# so the checkpoint also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint)
model = model.to(device)

In [None]:
# Display the loaded ViT's module structure (the bare last expression is rendered as the cell output).
model

In [None]:
# Same example index as the CNN section, re-set so the ViT cells can run on their own.
img_index: int = 7

In [None]:
model.eval()
with torch.no_grad():
    # Fetch the image (item[2]) and its label tensor (item[1]), in that order.
    img = train_mem_dataset[img_index][2]
    y = train_mem_dataset[img_index][1]
    # Prepend a batch dimension of size 1 before sending the image to the ViT.
    batch = img.unsqueeze(0).to(device)
    y = y.to(device)
    logits = model(batch)
    # Index of the highest-scoring class for the single sample in the batch.
    predicted = logits.max(dim=1).indices.squeeze()
    print("Predicted Label =", predicted.item())
    print("Label =", torch.argmax(y).item())

In [None]:
model.eval()
code, label, img = train_mem_dataset[img_index]
# Compares two flattened 1-D vectors, hence dim=0.
cos0 = nn.CosineSimilarity(dim=0)
with torch.no_grad():
    # Decode the sample's code with the ViT's transposed pass; the result is
    # shown as the "Reconstruction" panel next to the original image.
    rec_image = model.forward_transposed(code.view(1, -1))
    rec_image = rec_image.view(-1, 224, 224).to("cpu")  # (C, 224, 224)
    img = img.to("cpu")
    # argmax 0 -> "No Tumor", otherwise "Tumor".
    label = "No Tumor" if torch.argmax(label).item() == 0 else "Tumor"
    cosine_similarity = cos0(img.view(-1), rec_image.view(-1))

    fig, ax = plt.subplots(ncols=2, tight_layout=True)
    ax[0].imshow(img.permute(1, 2, 0))
    ax[0].set_title('Original')
    ax[1].imshow(rec_image.permute(1, 2, 0))
    ax[1].set_title("Reconstruction")
    # ".2f" = two decimal places; the previous "2f" only set a minimum field
    # width of 2 and printed six decimals.
    plt.suptitle("Label: {}\nCosine Similarity: {:.2f}\nCode: {}".format(label, cosine_similarity.item(), code))
    plt.show()

In [None]:
model.eval()
# One similarity module serves every example; it compares flattened 1-D vectors (dim=0).
cos0 = nn.CosineSimilarity(dim=0)
for idx in [2, 6]:
    code, label, img = train_mem_dataset[idx]
    with torch.no_grad():
        # Decode the sample's code with the ViT's transposed pass.
        rec_image = model.forward_transposed(code.view(1, -1))
        rec_image = rec_image.view(-1, 224, 224).to("cpu")  # (C, 224, 224)
        img = img.to("cpu")
        label = "No Tumor" if torch.argmax(label).item() == 0 else "Tumor"
        cosine_similarity = cos0(img.view(-1), rec_image.view(-1))

        fig, ax = plt.subplots(ncols=2, tight_layout=True)
        ax[0].imshow(img.permute(1, 2, 0))
        ax[0].set_title('Original')
        ax[1].imshow(rec_image.permute(1, 2, 0))
        ax[1].set_title("Reconstruction")
        # ".2f" = two decimal places ("2f" was a width of 2, not a precision).
        plt.suptitle("Label: {}\nCosine Similarity: {:.2f}\nCode: {}".format(label, cosine_similarity.item(), code))
        plt.show()
        # Release the figure so repeated plotting in the loop doesn't accumulate figures.
        plt.close(fig)