In [None]:
import os
import sys
import glob
sys.path.append(os.path.dirname(os.getcwd()))

import torch
import numpy as np
from torchvision.transforms import transforms
from efficientnet_pytorch import EfficientNet

from core.dataset import ImageDataset

In [None]:
# Inference setup: image preprocessing pipeline and the fine-tuned EfficientNet-B3
# binary classifier (single output logit) restored from the best checkpoint.
DEVICE = "cuda:3"  # GPU used for inference; change here to match the machine
CHECKPOINT_PATH = "../data/checkpoints/efficient_net_best_weight.pt"
IMAGE_SIZE = 512  # must match the image_size the checkpoint was trained with

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean/std.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
model = EfficientNet.from_pretrained(
    "efficientnet-b3", num_classes=1, image_size=IMAGE_SIZE
)
# The checkpoint is loaded through a DataParallel wrapper -- presumably because it was
# saved from a DataParallel model (state-dict keys prefixed with "module."); TODO confirm.
# nn.Module.to() moves parameters in place, so `model` and `dd_model.module` are the
# same module and `model` below ends up on DEVICE with the checkpoint weights.
dd_model = torch.nn.DataParallel(model).to(DEVICE)
# map_location remaps every stored tensor onto DEVICE; without it torch.load tries to
# restore tensors onto the device they were saved from, which may not exist here.
state_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
dd_model.load_state_dict(state_dict)
# Inference mode for the wrapped module (dropout off, BatchNorm uses running stats).
_ = model.eval()

In [None]:
from core.evaluation import forward, plot_confusion_matrix
from core.evaluation import save_gradcam_by_dataset
from matplotlib import pyplot as plt 
from core.dataset import ElasticDistDataset
from core.transform import crop_usimage



# Per-camera evaluation: for every camera sub-folder of the test set, plot and save a
# confusion matrix (Performance.png) and Grad-CAM outputs under SAVE_ROOT/<camera>/.
ALL_DATASET = "/home/oem/repositories/BreastImplantRupture/data/test_rupture"
SAVE_ROOT = "performance.cv_crop"
DECISION_THRESHOLD = 0.5  # probability cut-off above which a sample is predicted "Rupture"
# Build tensors on whatever device the model's parameters live on, so the dataset and
# the model can never disagree (previously a second hard-coded "cuda:3").
eval_device = str(next(model.parameters()).device)

# Only sub-directories are cameras (skip stray files such as .DS_Store); sort so the
# evaluation order is reproducible across runs and machines.
cameras = sorted(
    entry for entry in os.listdir(ALL_DATASET)
    if os.path.isdir(os.path.join(ALL_DATASET, entry))
)
for camera in cameras:
    print(camera)
    subtest_dir = os.path.join(ALL_DATASET, camera)
    # glob order is filesystem-dependent; sort for deterministic datasets/outputs.
    test_normal_paths = sorted(glob.glob(os.path.join(subtest_dir, "normal/*.jpg")))
    test_rupture_paths = sorted(glob.glob(os.path.join(subtest_dir, "rupture/*.jpg")))
    # Build the path list once: save_gradcam_by_dataset receives both the paths and the
    # dataset and presumably pairs them by index, so both must use the identical list.
    test_paths = test_rupture_paths + test_normal_paths
    if not test_paths:
        print(f"  skipping {camera}: no .jpg files under normal/ or rupture/")
        continue

    test_dataset = ElasticDistDataset(
        test_paths,
        transform=transform,
        label_hint="/rupture/",  # presumably: label is positive when the path contains this; TODO confirm
        crop=crop_usimage,
        distortion=None,
        device=eval_device,
    )
    y_trues, y_probs = forward(model, test_dataset, return_probs=True)
    y_preds = y_probs >= DECISION_THRESHOLD
    plot_confusion_matrix(y_trues, y_preds, label_names=["Normal", "Rupture"])

    save_dir = os.path.join(SAVE_ROOT, camera)
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, "Performance.png"))
    # Close (not just clear) the current figure so open figures don't pile up per camera.
    plt.close()

    save_gradcam_by_dataset(model, test_paths, test_dataset, crop=crop_usimage, save_dir=save_dir)