In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = str(2)
import torch
import torch.nn.functional as F
from dataset import MultiResolutionDataset
from torchvision import transforms, models, utils
from model import Generator, Encoder
import argparse
from tqdm import tqdm
from PIL import Image, ImageDraw, ImageFont
import os
import numpy as np
from io import BytesIO
from torch.utils.data import DataLoader
from torchvision import utils
import requests
import att_find_code.att_find_functions as att_find_func
%load_ext autoreload
%autoreload 2

## Set the arguments for the models

In [None]:
# Run configuration shared by the cells below.
args = att_find_func.Args()

# Execution settings. The loading cell only builds the encoder and the dataset
# when source is "data".
args.device = "cuda"
args.batch = 16
args.source = "data"

# Checkpoints: the GAN checkpoint (read for its "args", "g_ema" and "e" entries)
# and the classifier weights.
args.ckpt = "./output_cardio/checkpoint/045000.pt"
args.classifier_ckpt = "../Classifier/model_cardiomegaly.pth"

# Input dataset and the directory created for artifacts.
args.dataset_path = "../Classifier/mdb/test"
args.output_path = "att_find_artifacts"

## Load the Generator, Encoder and Classifier

In [None]:
# Create the artifact directory. makedirs(exist_ok=True) also creates missing
# parent directories and does not fail when the directory already exists.
os.makedirs(args.output_path, exist_ok=True)

# Inference only: switch autograd off globally for the rest of the notebook.
torch.set_grad_enabled(False)


# load generator & encoder
# The checkpoint is deserialised on the CPU so loading does not depend on the
# GPU index it was saved from; modules are moved to args.device afterwards.
# NOTE(review): ckpt["args"] is a pickled Python object, so recent PyTorch
# releases that default to weights_only=True may refuse it - confirm.
ckpt = torch.load(args.ckpt, map_location="cpu")
params = ckpt["args"]

generator = Generator(
    params.size,
    params.latent,
    params.n_mlp,
    channel_multiplier=params.channel_multiplier,
    # Optional hyper-parameters fall back to False when the checkpoint lacks them.
    conditional_gan=params.cgan if 'cgan' in params else False,
    nof_classes=params.classifier_nof_classes if 'classifier_nof_classes' in params else False,
    embedding_size=params.embedding_size,
).to(args.device)

# strict=False tolerates key mismatches between the checkpoint and the module.
generator.load_state_dict(ckpt["g_ema"], strict=False)
generator.eval()

# load classifier
# No ImageNet download: load_state_dict below is strict (the default), so every
# parameter and buffer is taken from the classifier checkpoint anyway.
classifier = models.densenet121(pretrained=False)
classifier.classifier = torch.nn.Linear(1024, 2)  # two-logit head
classifier.load_state_dict(torch.load(args.classifier_ckpt, map_location="cpu"))
classifier.to(args.device)
classifier.eval()

if args.source == "data":
    # Validate before building anything. Raising (rather than exit()) stops the
    # cell with a visible error.
    if args.dataset_path is None:
        raise ValueError("Dataset path is required in data mode")

    encoder = Encoder(
        params.size,
        channel_multiplier=params.channel_multiplier,
        output_channels=params.latent,
    ).to(args.device)
    encoder.load_state_dict(ckpt["e"], strict=False)
    encoder.eval()

    # load dataset
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True),
        ]
    )
    # NOTE(review): the filter label is "Pleural Effusion" although the checkpoints
    # used in this notebook are named after cardiomegaly - confirm this is intended.
    dataset = MultiResolutionDataset(args.dataset_path, transform, params.size, labels=True, filter_label="Pleural Effusion")
    args.n_sample = len(dataset)
else:
    dataset = None
    encoder = None


In [None]:
# get the parts of the generator that generate style
# `affines` is later handed to get_style_for_dlantent to turn dlatents into style
# vectors, and is bundled with LAYER_SHAPES and the device into the tuple passed
# to the visualisation helper.
affines = att_find_func.get_affines(generator)
# NOTE(review): presumably the per-layer sizes of the style vectors - confirm in
# att_find_functions.get_layer_shapes.
LAYER_SHAPES = att_find_func.get_layer_shapes(affines)

## Load the preprocessed data

Load and process the data produced by running `att_find_calculate_styles.py`.

In [None]:
# Load the dictionary with all the preprocessed data. Each entry is read later
# through its "dlatent", "result", "base_prob" and "label" fields.
loaded_features = torch.load("styles/features_style_cardio_no_finding_1128.pt", map_location="cpu")

# Load the dictionary with all the changes that made the classifier cross its decision boundary.
# The keys are the style coordinate numbers; each value is a list.
# The list contains the images that were changed by that style coordinate, together with the change value.
change = torch.load("styles/change_cardio_no_finding_1128.pt", map_location="cpu")


In [None]:
# Unpack loaded_features into aligned numpy arrays (only partially used for now).
num_classes = 2

# Walk the entries in sorted-key order so that every array shares one image order.
features_in_order = [loaded_features[key] for key in sorted(loaded_features.keys())]

dlatents = np.array([np.array(feature["dlatent"]) for feature in features_in_order])
# Each "result" is viewed as (n_styles, 2, num_classes) and then transposed so the
# axis of size 2 comes first.
style_change_effect = np.array([
    np.array(feature['result']).reshape((-1, 2, num_classes)).transpose([1, 0, 2])
    for feature in features_in_order
])
base_probs = np.array([np.array(feature['base_prob']) for feature in features_in_order])
labels = np.array([np.array(feature['label']) for feature in features_in_order])
W_values = dlatents

# style_change_effect = att_find_func.filter_unstable_images(style_change_effect, effect_threshold=4)
dlatents_torch = torch.tensor(dlatents).to(args.device)

# Style vectors of every image: the per-layer outputs joined along axis 1.
per_layer_styles = att_find_func.get_style_for_dlantent(dlatents_torch, affines)
all_style_vectors = torch.concat(per_layer_styles, axis=1).cpu().numpy()
style_min = all_style_vectors.min(axis=0)
style_max = all_style_vectors.max(axis=0)

# Distance of every style coordinate to the observed minimum ([..., 0]) and
# maximum ([..., 1]); broadcasting lines the per-coordinate extrema up with each row.
all_style_vectors_distances = np.zeros(all_style_vectors.shape + (2,))
all_style_vectors_distances[..., 0] = all_style_vectors - style_min
all_style_vectors_distances[..., 1] = style_max - all_style_vectors

In [None]:
# Copy of the dlatents as a tensor on the configured device.
# NOTE(review): built exactly like `dlatents_torch` in the preprocessing cell and
# not referenced by the cells shown in this notebook - possibly redundant.
dlatents_tensor = torch.tensor(dlatents).to(args.device)

In [None]:
# Rank the style coordinates by how many entries `change` recorded for them
# (the most "influential" coordinates changed the most images).
ch = {sindex: len(vals) for sindex, vals in change.items()}
ranked_by_count = sorted(ch.items(), key=lambda item: item[1], reverse=True)

# Report at most the first 11 coordinates and stop at the first one that
# changed fewer than 10 images.
for sindex, leng in ranked_by_count[:11]:
    if leng < 10:
        break
    print(sindex, ":", leng, "images changed")


In [None]:
# Split every per-image array by predicted class.
# The class of an image is the arg-max of its base classifier probabilities.
all_labels = np.argmax(base_probs, axis=1)
style_effect_classes = {}
W_classes = {}
labels_classes = {}
style_vectors_distances_classes = {}
all_style_vectors_classes = {}
label_size = 2
for img_ind in range(label_size):
    # flatnonzero always returns an integer index array, so a class without any
    # image yields empty per-class arrays. An empty np.array([]) is float-typed
    # and raises IndexError when used as an index.
    img_inx = np.flatnonzero(all_labels == img_ind)

    labels_classes[img_ind] = all_labels[img_inx]
    # Fancy indexing copies the selected rows in one step; the float64 cast keeps
    # the dtype of the np.zeros buffers these arrays used to be copied into.
    style_effect_classes[img_ind] = style_change_effect[img_inx].astype(np.float64, copy=False)
    W_classes[img_ind] = W_values[img_inx].astype(np.float64, copy=False)
    style_vectors_distances_classes[img_ind] = all_style_vectors_distances[img_inx].astype(np.float64, copy=False)
    all_style_vectors_classes[img_ind] = all_style_vectors[img_inx]
    print(f'Class {img_ind}, {len(img_inx)} images.')

In [None]:
label_size_clasifier = 2 #@param
num_indices = 18 #@param
effect_threshold = 0.2 #@param
use_discriminator = False #@param {type: 'boolean'}
# discriminator_model = discriminator if use_discriminator else None

# For each target class, search for significant styles among the images that
# were predicted as the *other* class.
s_indices_and_signs_dict = {}
for class_index in (0, 1):
  split_ind = 1 - class_index
  all_s = style_effect_classes[split_ind]
  all_w = W_classes[split_ind]
  # Find s indicies
  s_indices_and_signs = att_find_func.find_significant_styles(
      style_change_effect=all_s,
      num_indices=num_indices,
      class_index=class_index,
      max_image_effect=effect_threshold*500,
      sindex_offset=0)
  s_indices_and_signs_dict[class_index] = s_indices_and_signs

# Combine the style indicies for the two classes: class-1 entries not already
# found for class 0 are added with their direction flipped, followed by all
# class-0 entries.
class_0_entries = s_indices_and_signs_dict[0]
sindex_class_0 = [sindex for _, sindex in class_0_entries]

all_sindex_joined_class_0 = [
    (1 - direction, sindex)
    for direction, sindex in s_indices_and_signs_dict[1]
    if sindex not in sindex_class_0
]
all_sindex_joined_class_0.extend(class_0_entries)

# Score = mean effect on output 0 in the entry's direction plus mean effect on
# output 1 in the opposite direction, taken over all images.
scores = [
    np.mean(style_change_effect[:, direction, sindex, 0])
    + np.mean(style_change_effect[:, int(direction == 0), sindex, 1])
    for direction, sindex in all_sindex_joined_class_0
]

# Highest score first.
descending_order = np.argsort(scores)[::-1]
s_indices_and_signs = [all_sindex_joined_class_0[i] for i in descending_order]

print('Directions and style indices for moving from class 1 to class 0 = ', s_indices_and_signs[:num_indices])
print('Use the other direction to move for class 0 to 1.')

In [None]:
#@title Visualize s-index {form-width: '20%'}
# Render the effect of shifting one style coordinate on a handful of images.

max_images = 10 #@param
sindex =   2387#@param
class_index = 1#@param {type: "integer"} 
shift_sign = "1" #@param [0, 1]
wsign_index = int(shift_sign)
print("Coordinate:",sindex)
if class_index == 0:
  print("No finding")
else:
  print("Cardiomegaly")
shift_size = 5#@param
effect_threshold =  0.2#@param
split_by_class = True #@param {type:"boolean"}
select_images_by_s_distance = True #@param {type:"boolean"}
draw_results_on_image = True #@param {type:"boolean"}

if split_by_class:
  # Work on the images predicted as the class opposite to `class_index`.
  split_ind = 1 if class_index == 0 else 0
  all_s = style_effect_classes[split_ind]
  all_w = W_classes[split_ind]
  all_l = labels_classes[split_ind]
  all_s_distances = style_vectors_distances_classes[split_ind]
else:
  all_s = style_change_effect
  all_w = W_values
  # The distance-based visualisation below needs labels in this mode as well;
  # labels_classes is itself sliced from all_labels.
  all_l = all_labels
  all_s_distances = all_style_vectors_distances

additional_data_tuple = affines, LAYER_SHAPES, args.device

# Font used when drawing on the images; downloaded once and cached on disk.
font_file = '/tmp/arialuni.ttf'
if not os.path.exists(font_file):
  r = requests.get('https://storage.googleapis.com/google-code-archive-downloads/v2/code.google.com/ipwn/arialuni.ttf',
                   timeout=60)
  # Fail loudly instead of caching an HTTP error page as the font file.
  r.raise_for_status()
  with open(font_file, 'wb') as font_handle:
    font_handle.write(r.content)

if not select_images_by_s_distance:
  # NOTE(review): the bare name `visualize_style` is not defined or imported in
  # this notebook; this assumes it lives in att_find_functions next to
  # visualize_style_by_distance_in_s - confirm its location and signature.
  yy = att_find_func.visualize_style(generator,
                                     classifier,
                                     all_w,
                                     all_s,
                                     style_min,
                                     style_max,
                                     sindex,
                                     wsign_index,
                                     max_images=max_images,
                                     shift_size=shift_size,
                                     font_file=font_file,
                                     label_size=label_size,
                                     class_index=class_index,
                                     effect_threshold=effect_threshold,
                                     draw_results_on_image=draw_results_on_image)
else:
  yy = att_find_func.visualize_style_by_distance_in_s(
    generator,
    classifier,
    all_w,
    all_l,
    additional_data_tuple,
    all_s_distances,
    style_min,
    style_max,
    sindex,
    wsign_index,
    max_images=max_images,
    shift_size=shift_size,
    font_file=font_file,
    label_size=label_size,
    class_index=class_index,
    effect_threshold=effect_threshold,
    draw_results_on_image=draw_results_on_image)

if yy.shape[0] > 0:
  att_find_func.show_image(yy)
else:
  print('no images found')