In [None]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

from torch import nn
import torchvision
from torchvision.models import resnet50, efficientnet_b5, densenet161
from torchvision import transforms
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, roc_curve, RocCurveDisplay
# sensitivity, specificity, precision, recall, f1_score, roc_auc_score
from sklearn.metrics import multilabel_confusion_matrix, classification_report, hamming_loss, zero_one_loss
from tqdm import tqdm
import wandb

import albumentations as A

In [2]:
# Run configuration for this evaluation notebook.
# TEST selects which CSV the evaluation dataset is built from:
# False -> validation CSV, True -> test CSV.
TEST = False

# Pick the best available accelerator and fall back to the CPU.
# The previous expression chose 'mps' whenever CUDA was missing, which fails
# on machines that have neither CUDA nor an Apple-silicon MPS backend.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

# NOTE(review): path_dir is not referenced by any code visible in this notebook.
path_dir = "./rpc/test_images/chandigarh"
# NOTE(review): "best_mopdel.pth" looks like a typo for "best_model.pth" -
# confirm against the file name on disk before changing it.
model_path = "./AUG-New-IISc/best_model/models/best_mopdel.pth"

In [3]:
# Dataset layout. Dataset joins the image folder and the CSV name onto ROOT.
ROOT = "./data-iisc"
# Folder (relative to ROOT) that holds the image files.
images_dir = "train_data_dir"
# Per-split CSV files (relative to ROOT).
train_csv, val_csv, test_csv = "train_data.csv", "val_data.csv", "test_data.csv"
class Dataset(BaseDataset):
    """Multi-label image dataset described by a CSV file.

    The CSV is read from ``os.path.join(root, csv)``. The first entry of
    ``column_list`` names the column holding the image file name (resolved
    against ``os.path.join(root, images_dir)``); every remaining entry names
    one label column.

    Args:
        root: Base directory containing both the image folder and the CSV.
        images_dir: Image folder, relative to ``root``.
        csv: CSV file name, relative to ``root``.
        aug_fn: Optional callable. It is called as ``aug_fn(image.shape)`` and
            must return a transform that is invoked as ``transform(image=image)``
            and returns a mapping with an ``'image'`` entry.
        preprocessing: Optional callable invoked as ``preprocessing(image=image)``
            that returns a mapping with an ``'image'`` entry.
        column_list: Image-name column followed by the label columns. Any
            number of label columns is accepted.

    Attributes:
        ids: One tuple per CSV row: ``(image name, label 1, ..., label N)``.
        images: Full path of every image, in CSV order.
        labels: The label part of every ``ids`` tuple.
    """

    def __init__(
            self,
            root,
            images_dir,
            csv,
            aug_fn=None,
            preprocessing=None,
            # A tuple rather than a list so the default cannot be mutated.
            column_list=("Image_ID", "Class 1", "Class 2", "Class 3", "Class 4", "Class 5", "Class 6")
    ):
        images_dir = os.path.join(root, images_dir)
        df = pd.read_csv(os.path.join(root, csv))

        # Select the requested columns once and walk the rows as plain tuples;
        # this works for any number of label columns and avoids the per-row
        # Series construction of DataFrame.iterrows().
        columns = list(column_list)
        self.ids = list(df[columns].itertuples(index=False, name=None))

        self.images = [os.path.join(images_dir, item[0]) for item in self.ids]
        self.labels = [item[1:] for item in self.ids]

        self.aug_fn = aug_fn
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return ``(image, label, image_name)`` for row ``i``.

        The image is loaded with OpenCV and converted from BGR to RGB, then
        passed through ``aug_fn`` and ``preprocessing`` when they are set.
        The label is always returned as a float32 tensor so that batches
        collate to a single tensor whether or not ``preprocessing`` is given.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        path = self.images[i]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.aug_fn:
            # aug_fn builds a transform from the image shape; apply it right away.
            image = self.aug_fn(image.shape)(image=image)['image']

        if self.preprocessing:
            image = self.preprocessing(image=image)['image']

        label = torch.tensor(self.labels[i], dtype=torch.float32)

        return image, label, self.ids[i][0]

    def __len__(self):
        """Number of rows (images) in the CSV."""
        return len(self.images)

In [4]:
# get image, pass it through albumentations, get the predicted labels, and then show the image
def resize_image(image_shp, target_size=512):
    """
    Build a transform that pads an image to a square and resizes it.

    The function does not touch any image itself: it only reads the shape and
    returns an albumentations pipeline for the caller to apply.

    :param image_shp: Shape of the image to be transformed; only the first two
        entries (height, width) are used, so a 2-D shape is accepted too.
    :param target_size: Side length of the square output of the resize step.
    :return: An ``A.Compose`` that pads to ``max(height, width)`` on both axes
        with the constant value (255, 255, 255) and then resizes to
        ``target_size`` x ``target_size``.
    """
    h, w = image_shp[:2]

    max_size = max(h, w)

    # NOTE(review): newer albumentations releases may expect `fill=` instead of
    # `value=` on PadIfNeeded - confirm against the installed version.
    transform = A.Compose([
        A.PadIfNeeded(min_height=max_size, min_width=max_size, border_mode=cv2.BORDER_CONSTANT, value=(255, 255, 255)),
        # Honour target_size instead of the previously hard-coded 512.
        A.Resize(target_size, target_size, interpolation=cv2.INTER_AREA),
        # A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])

    return transform

# %%
def to_tensor(x, **kwargs):
    """Reorder a 3-axis array from (0, 1, 2) to (2, 0, 1) and cast it to float32.

    Extra keyword arguments are accepted and ignored.
    """
    reordered = x.transpose((2, 0, 1))
    return reordered.astype("float32")


def get_preprocessing(preprocessing_fn=None):
    """Construct preprocessing transform

    Args:
        preprocessing_fn (callable, optional): data normalization function
            (can be specific for each pretrained neural network). When given,
            it is applied to the image before the tensor conversion; it is
            wrapped in ``A.Lambda``, so it should accept the image plus
            arbitrary keyword arguments. When ``None`` (the default) only the
            tensor conversion is performed.
    Return:
        transform: albumentations.Compose

    """

    _transform = []
    if preprocessing_fn is not None:
        # Previously this argument was accepted but silently ignored.
        _transform.append(A.Lambda(image=preprocessing_fn))
    # Always finish by reordering the axes and casting to float32 via to_tensor.
    _transform.append(A.Lambda(image=to_tensor, mask=to_tensor))
    return A.Compose(_transform)

In [None]:
# Build a DenseNet-161 with a 6-output classification head and restore the
# fine-tuned weights from `model_path`.
#
# weights=None: load_state_dict below is strict by default, so every parameter
# is overwritten by the checkpoint. Fetching the pretrained ImageNet weights
# first would only cost a download without changing the result.
model = densenet161(weights=None)
# Size the new head from the existing one instead of hard-coding 2208.
model.classifier = torch.nn.Linear(model.classifier.in_features, 6)

model.to(DEVICE)

# NOTE(security): torch.load unpickles the file - only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(model_path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])

In [6]:
# Evaluation data: the test CSV when TEST is set, the validation CSV otherwise.
# Note that the names stay `val_*` in both cases.
val_dataset = Dataset(
    ROOT,
    images_dir,
    test_csv if TEST else val_csv,
    aug_fn=resize_image,
    preprocessing=get_preprocessing(),
)

# One image per batch, in CSV order.
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)

In [19]:
def epoch_runner(description:str, loader, model, loss, optimizer=None, device="cuda"):
    """Run one pass over ``loader`` and collect multi-label metrics.

    Args:
        description: ``"train"`` or ``"val"`` (case-insensitive). ``"train"``
            puts the model in training mode, enables gradients and steps the
            optimizer; ``"val"`` puts it in eval mode with gradients disabled.
            The title-cased value is also used as the progress-bar label.
        loader: Iterable yielding ``(images, labels, img_id)`` batches.
        model: Module called on each image batch. Its outputs are treated as
            logits: they are passed to ``loss`` as-is and through a sigmoid
            for the predictions.
        loss: Callable ``loss(outputs, labels)`` returning a scalar tensor.
        optimizer: Optimizer to step; required when ``description`` is "train".
        device: Device the image and label batches are moved to.

    Returns:
        A 7-tuple ``(mean_loss, classification_report_dict,
        multilabel_confusion_matrix, auc, hamming_loss, zero_one_loss,
        top_3_labels)``. ``auc`` is always ``None`` (its computation is
        disabled). ``top_3_labels`` holds, per sample, the int8 indices of the
        three highest-scoring classes.

    Raises:
        KeyError: If ``description`` is neither "train" nor "val".
        ValueError: If training is requested without an optimizer.
    """
    # Diagnostic reported as "Val Count": how many samples score at least
    # `watched_threshold` on class index `watched_class`.
    watched_class = 5
    watched_threshold = 0.9

    epoch_loss = []
    original_labels = []
    predicted_labels = []
    top_3_labels = []

    running_loss = 0.0
    count = 0

    val_count = 0

    # Running totals for the progress-bar accuracy (1 - Hamming loss), so the
    # metric is not recomputed over every label seen so far at each step.
    mismatched = 0
    compared = 0

    run_modes = {"train":True,"val":False}
    mode = run_modes[description.lower()]

    if mode and optimizer is None:
        raise ValueError("an optimizer is required when description is 'train'")

    if mode:
        model.train()
    else:
        model.eval()

    with torch.set_grad_enabled(mode):
        with tqdm(loader, desc=description.title()) as iterator:
            for images, labels, img_id in iterator:
                images = images.to(device)
                labels = labels.to(device)

                if mode:
                    optimizer.zero_grad()

                # Call the module rather than .forward() so registered hooks run.
                outputs = model(images)
                loss_value = loss(outputs, labels)

                if mode:
                    loss_value.backward()
                    optimizer.step()

                # One sigmoid per batch, detached: it is only used for metrics.
                probabilities = torch.sigmoid(outputs.detach())

                _, top_3 = torch.topk(probabilities, 3)
                top_3_labels.extend(top_3.cpu().numpy().astype("int8"))
                predicted = (probabilities >= 0.5).int()

                # Count every sample in the batch, not just the first one.
                val_count += int((probabilities[:, watched_class] >= watched_threshold).sum().item())

                labels_np = labels.cpu().numpy().astype("int8")
                predicted_np = predicted.cpu().numpy().astype("int8")

                batch_loss = loss_value.item()
                running_loss += batch_loss
                count += 1
                epoch_loss.append(batch_loss)
                original_labels.extend(labels_np)
                predicted_labels.extend(predicted_np)

                mismatched += int((labels_np != predicted_np).sum())
                compared += labels_np.size

                iterator.set_postfix({"loss":running_loss/count,"Accuracy":1-mismatched/max(compared, 1)})

    epoch_loss_value = np.mean(epoch_loss)

    epoch_classification_report = classification_report(original_labels, predicted_labels)

    print("Classification Report:\n", epoch_classification_report)
    epoch_cr_dictionary = classification_report(original_labels, predicted_labels, output_dict=True)

    epoch_mcm = multilabel_confusion_matrix(original_labels, predicted_labels)
    # AUC is currently disabled; the slot is kept so the return shape is stable.
    epoch_auc = None
    # epoch_auc = roc_auc_score(original_labels, predicted_labels,average=None)

    epoch_hamming_loss = hamming_loss(original_labels, predicted_labels)
    epoch_zero_one_loss = zero_one_loss(original_labels, predicted_labels)

    print("Val Count: ", val_count)
    return epoch_loss_value, epoch_cr_dictionary, epoch_mcm, epoch_auc, epoch_hamming_loss, epoch_zero_one_loss, top_3_labels

In [None]:
# Single evaluation pass over the selected split (no loop needed for one run).
# BCEWithLogitsLoss is used because epoch_runner passes the raw model outputs
# to the loss and applies the sigmoid itself.
val_loss, val_cr, val_mcm, val_auc, val_hamming_loss, val_zero_one_loss, top_3_labels = epoch_runner(
    "val", val_loader, model, nn.BCEWithLogitsLoss(), device=DEVICE
)
    


In [None]:
# Tally how often each ordered top-3 class combination was predicted.
print("top 3 labels: ", top_3_labels)
labels = ["Class 1","Class 2","Class 3","Class 4","Class 5","Class 6"]

# Maps a tuple of three class names (highest score first) to its frequency.
count = {}

for top_indices in top_3_labels:
    combo = tuple(labels[j] for j in top_indices)
    if combo in count:
        count[combo] += 1
    else:
        count[combo] = 1

In [None]:
# One line per distinct top-3 combination, then the number of combinations.
for combo, occurrences in count.items():
    print(combo, occurrences)

print(len(count))