In [1]:
import sys
import os
sys.path.append('../..')
import torch
import dotenv
import torchvision
import pandas as pd
from src.models import ResNet18
from torchvision.datasets.folder import default_loader
from tqdm.notebook import tqdm
from src.transforms import LabelMapper

  warn(


In [2]:
model = ResNet18(n_classes=3) # the weights should be loaded from a file

In [3]:
DATA_DIR = os.path.join(dotenv.get_key(dotenv.find_dotenv(), "DATA_DIR"))
TARGET_DIR = os.path.join(dotenv.get_key(dotenv.find_dotenv(), "TARGET_DIR"))

print(DATA_DIR)
print(TARGET_DIR)

/home/abdelnour/Documents/4eme_anne/S2/projet/data/roi-dataset/BRACS_RoI/latest_version /home/abdelnour/Documents/4eme_anne/S2/projet/data/patched


In [4]:
from typing import Any, Tuple


class PatchedRoIDataset(torchvision.datasets.ImageFolder):

    def __init__(self, root: str,
        transform = None,
        target_transform = None,
        loader=default_loader,
        is_valid_file = None
    ):
        super().__init__(
            root=root,
            transform=transform,
            loader=loader,
            is_valid_file=is_valid_file,
            target_transform=target_transform
        )

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        path,label =  self.samples[index]
        img, label = super().__getitem__(index)
        return path,img,label

In [5]:
label_mapper = LabelMapper({
    "0_N":"benign", # 0 is the label for benign (BY)
    "1_PB":"benign", 
    "2_UHD":"benign",
    "3_FEA":"atypical",
    "4_ADH":"atypical", # 1 is the label for atypical (AT)
    "5_DCIS":"malignant",
    "6_DCIS":"malignant", # 2 is the label for malignant (MT)
})

In [7]:
dataset = PatchedRoIDataset(
    root=os.path.join(TARGET_DIR, "val"),
    transform=torchvision.transforms.ToTensor(),
    target_transform=label_mapper
)

In [8]:
dataloader = torch.utils.data.DataLoader(dataset=dataset,batch_size=8,shuffle=True)

In [9]:
def get_original_img_name(patch_name: str) -> str:
    name,extendtion = patch_name.split('.')
    return '_'.join(name.split('_')[:-1])+'.'+extendtion

In [10]:
def predict(
    dataloader : torch.utils.data.DataLoader, 
    model : torch.nn.Module
) -> pd.DataFrame :
    
    result = {
        "patch_name" : [],
        "patch_label" : [],
        "benign_prob" : [],
        "atypical_prob" : [],
        "malignant_prob" : []
    }
    
    for paths, x, labels in tqdm(dataloader):
        
        y_hat = model(x)
        y_hat = torch.nn.functional.softmax(y_hat, dim=1)

        result["benign_prob"].extend(y_hat[:,0].tolist())
        result["atypical_prob"].extend(y_hat[:,1].tolist())
        result["malignant_prob"].extend(y_hat[:,2].tolist())

        images_names = [os.path.basename(path) for path in paths]

        result["patch_name"].extend(images_names)
        result["patch_label"].extend(labels.tolist())

    df =  pd.DataFrame(result)

    df['original_img_name'] = df["patch_name"].apply(get_original_img_name)

    return df



In [11]:
predictions_df = predict(dataloader,model)

  0%|          | 0/8 [00:00<?, ?it/s]

In [12]:
predictions_df.head()

Unnamed: 0,patch_name,patch_label,benign_prob,atypical_prob,malignant_prob,original_img_name
0,BRACS_1286_N_40_20.png,0,0.165316,0.419924,0.414761,BRACS_1286_N_40.png
1,BRACS_1286_N_40_26.png,0,0.207326,0.258828,0.533847,BRACS_1286_N_40.png
2,BRACS_1286_N_40_50.png,0,0.251948,0.287024,0.461028,BRACS_1286_N_40.png
3,BRACS_1286_N_40_12.png,0,0.48269,0.173686,0.343624,BRACS_1286_N_40.png
4,BRACS_1286_N_40_22.png,0,0.23711,0.439699,0.323191,BRACS_1286_N_40.png


In [None]:
class SoftVoter(torch.nn.Module):

    def __init__(self, base : torch.nn.Module) -> None:

        super().__init__()

        self.base = base

    def forward(self) -> None:
        pass