# 1. Import libraries

In [63]:
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
import timm
from PIL import Image
import os
import torch
import pandas as pd
from typing import List, Tuple
import pydicom

# 2. Config

In [64]:
config = {
    "Label": ['BI-RADS 1', 'BI-RADS 2', 'BI-RADS 3', 'BI-RADS 4', 'BI-RADS 5'],
    "extraction_feature_model_name": "resnet50",
    "extraction_feature_model_path": "best.pt",
    "embedding_dim": 512,
    "annotation_file_path": "data.csv",
    "dataset_dir": "data",
    "phase": "training",
    "input_size": (224,224),

}

# 3. Dataset

In [65]:
df = pd.read_csv("breast-level_annotations.csv")
df["breast_birads"].unique()

array(['BI-RADS 2', 'BI-RADS 1', 'BI-RADS 3', 'BI-RADS 4', 'BI-RADS 5'],
      dtype=object)

In [66]:
class SeverityLevelDataset(Dataset):
    def __init__(self, annotation_file_path: str, dataset_dir: str, phase: str = "training", input_size: Tuple = (224, 224), label_list: List = []):
        super(SeverityLevelDataset, self).__init__()
        self.dataset_dir = dataset_dir
        annotation_df = pd.read_csv(annotation_file_path)
        # Filter data based on the specified phase
        data = annotation_df[annotation_df["split"] == phase]
        # Concatenate study_id and image_id to get image paths
        image_paths_df = data["study_id"] + "/" + data["image_id"] +".dicom"
        self.image_path_list = image_paths_df.tolist()
    
        # Get labels
        labels_df = data["breast_birads"]
        self.label_name_list = labels_df.to_list()
        self.input_size = input_size
        self.label_id_list = label_list
    
    def __len__(self):
        return len(self.label_list)
    
    def __getitem__(self, index):
        image_path = self.image_path_list[index]
        image_tensor = self._read_resize_dicom(os.path.join(self.dataset_dir, image_path), self.input_size)

        label_name = self.label_name_list[index]
        label = self.label_id_list.index(label_name)
        return image_tensor, torch.tensor(label, dtype=torch.long)

    def _read_resize_dicom(self, filepath, new_size):
         # Đọc file DICOM
        dicom_data = pydicom.dcmread(filepath)
        
        # Chuyển đổi dữ liệu DICOM thành mảng numpy
        image_array = dicom_data.pixel_array
        
        # Chuyển đổi mảng numpy thành ảnh PIL
        image_pil = Image.fromarray(image_array)
        
        # Kiểm tra chế độ của ảnh
        if image_pil.mode != 'L':
            image_pil = image_pil.convert('L')  # Chuyển đổi sang chế độ 'L' (grayscale) nếu cần thiết
        
        # Tạo ảnh RGB từ ảnh đơn kênh bằng cách sao chép giá trị của kênh đó vào cả ba kênh
        image_pil = Image.merge('RGB', (image_pil, image_pil, image_pil))
        
        # Resize ảnh
        transform = transforms.Compose([
            transforms.Resize(new_size),
            transforms.ToTensor()
        ])
        resized_image = transform(image_pil)
        resized_image = resized_image.to(torch.float)
        
        return resized_image


# 4. Model

In [67]:
class Extractionmodel(nn.Module):
    def __init__(self, model_name: str, embed_dim: int):
        """
        A custom model for Setting 2, which uses different pre-trained models
        based on the specified `model_name`.

        Args:
        - model_name: Name of the pre-trained model to be used
        - embed_dim: Dimension of the output embeddings
        """
        super(Extractionmodel, self).__init__()

        # Load the specified pre-trained model
        if model_name.startswith('resnet'):
            if model_name == 'resnet50':
                self.model = models.resnet50(pretrained=True)
            elif model_name == 'resnet101':
                self.model = models.resnet101(pretrained=True)
            elif model_name == 'resnet152':
                self.model = models.resnet152(pretrained=True)
            else:
                raise ValueError(f"Unsupported ResNet model: {model_name}")
                
            num_features = self.model.fc.in_features
            self.model.fc = nn.Linear(num_features, embed_dim)
        
        elif model_name.startswith('densenet'):
            if model_name == 'densenet121':
                self.model = models.densenet121(pretrained=True)
            else:
                raise ValueError(f"Unsupported DenseNet model: {model_name}")
                
            num_features = self.model.classifier.in_features
            self.model.classifier = nn.Linear(num_features, embed_dim)
        
        elif model_name.startswith('vit'):
            self.model = timm.create_model(model_name, pretrained=True)

            num_features = self.model.head.in_features
            self.model.head = nn.Linear(num_features, embed_dim)
        
        else:
            raise ValueError(f"Unsupported model: {model_name}")
    
    def forward(self, image):
        return self.model(image)

In [68]:
class SeverityClassificationModel(nn.Module):
    def __init__(self, feature_extraction_model, num_class):
        super(SeverityClassificationModel, self).__init__()
        self.feature_extractor = feature_extraction_model
        self.fc = nn.Linear(512, num_class)

    def forward(self, image):
        feature = self.feature_extractor(image)
        out = self.fc(feature)
        return out

In [71]:
# test
checkpoint = torch.load(config["extraction_feature_model_path"], map_location=torch.device("cpu"))
extractionmodel = Extractionmodel(model_name=config["extraction_feature_model_name"], embed_dim=config["embedding_dim"])
extractionmodel.load_state_dict(checkpoint["model_state_dict"])
cls_model = SeverityClassificationModel(extractionmodel, 5)

## dataset test
datset = SeverityLevelDataset(annotation_file_path=config["annotation_file_path"], 
                              dataset_dir=config["dataset_dir"], 
                              phase=config["phase"], 
                              input_size=config["input_size"], 
                              label_list=config["Label"]
                              )

sample_0 = datset[0]
out = cls_model(sample_0[0].unsqueeze(0))
print(out)




tensor([[-0.0180,  0.0904, -0.0035,  0.0287, -0.0462]],
       grad_fn=<AddmmBackward0>)


# 5. Training model