# 1. Import libraries

In [121]:
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
import torch.optim as opt
import timm
from PIL import Image
import os
import torch
import pandas as pd
from typing import List, Tuple
import pydicom

# 2. Config

In [122]:
config = {
    "Label": ['BI-RADS 1', 'BI-RADS 2', 'BI-RADS 3', 'BI-RADS 4', 'BI-RADS 5'],
    "extraction_feature_model_name": "resnet50",
    "extraction_feature_model_path": "best.pt",
    "embedding_dim": 512,
    "annotation_file_path": "data.csv",
    "dataset_dir": "data",
    "phase": "training",
    "input_size": (224,224),
    "batch_size": 2,
    "lr": 1e-4,
    "num_epoch": 50,
}

# 3. Dataset

In [123]:
df = pd.read_csv("breast-level_annotations.csv")
df["breast_birads"].unique()

array(['BI-RADS 2', 'BI-RADS 1', 'BI-RADS 3', 'BI-RADS 4', 'BI-RADS 5'],
      dtype=object)

In [124]:
class SeverityLevelDataset(Dataset):
    def __init__(self, annotation_file_path: str, dataset_dir: str, phase: str = "training", input_size: Tuple = (224, 224), label_list: List = [""]):
        super(SeverityLevelDataset, self).__init__()
        self.dataset_dir = dataset_dir
        annotation_df = pd.read_csv(annotation_file_path)
        # Filter data based on the specified phase
        data = annotation_df[annotation_df["split"] == phase]
        # Concatenate study_id and image_id to get image paths
        image_paths_df = data["study_id"] + "/" + data["image_id"] +".dicom"
        self.image_path_list = image_paths_df.tolist()
    
        # Get labels
        labels_df = data["breast_birads"]
        self.label_name_list = labels_df.to_list()
        self.input_size = input_size
        self.label_id_list = label_list
    
    def __len__(self):
        return len(self.label_name_list)
    
    def __getitem__(self, index):
        image_path = self.image_path_list[index]
        image_tensor = self._read_resize_dicom(os.path.join(self.dataset_dir, image_path), self.input_size)

        label_name = self.label_name_list[index]
        label = self.label_id_list.index(label_name)
        return image_tensor, torch.tensor(label, dtype=torch.long)

    def _read_resize_dicom(self, filepath, new_size):
         # Đọc file DICOM
        dicom_data = pydicom.dcmread(filepath)
        
        # Chuyển đổi dữ liệu DICOM thành mảng numpy
        image_array = dicom_data.pixel_array
        
        # Chuyển đổi mảng numpy thành ảnh PIL
        image_pil = Image.fromarray(image_array)
        
        # Kiểm tra chế độ của ảnh
        if image_pil.mode != 'L':
            image_pil = image_pil.convert('L')  # Chuyển đổi sang chế độ 'L' (grayscale) nếu cần thiết
        
        # Tạo ảnh RGB từ ảnh đơn kênh bằng cách sao chép giá trị của kênh đó vào cả ba kênh
        image_pil = Image.merge('RGB', (image_pil, image_pil, image_pil))
        
        # Resize ảnh
        transform = transforms.Compose([
            transforms.Resize(new_size),
            transforms.ToTensor()
        ])
        resized_image = transform(image_pil)
        resized_image = resized_image.to(torch.float)
        
        return resized_image


# 4. Model

In [125]:
class Extractionmodel(nn.Module):
    def __init__(self, model_name: str, embed_dim: int):
        """
        A custom model for Setting 2, which uses different pre-trained models
        based on the specified `model_name`.

        Args:
        - model_name: Name of the pre-trained model to be used
        - embed_dim: Dimension of the output embeddings
        """
        super(Extractionmodel, self).__init__()

        # Load the specified pre-trained model
        if model_name.startswith('resnet'):
            if model_name == 'resnet50':
                self.model = models.resnet50(pretrained=True)
            elif model_name == 'resnet101':
                self.model = models.resnet101(pretrained=True)
            elif model_name == 'resnet152':
                self.model = models.resnet152(pretrained=True)
            else:
                raise ValueError(f"Unsupported ResNet model: {model_name}")
                
            num_features = self.model.fc.in_features
            self.model.fc = nn.Linear(num_features, embed_dim)
        
        elif model_name.startswith('densenet'):
            if model_name == 'densenet121':
                self.model = models.densenet121(pretrained=True)
            else:
                raise ValueError(f"Unsupported DenseNet model: {model_name}")
                
            num_features = self.model.classifier.in_features
            self.model.classifier = nn.Linear(num_features, embed_dim)
        
        elif model_name.startswith('vit'):
            self.model = timm.create_model(model_name, pretrained=True)

            num_features = self.model.head.in_features
            self.model.head = nn.Linear(num_features, embed_dim)
        
        else:
            raise ValueError(f"Unsupported model: {model_name}")
    
    def forward(self, image):
        return self.model(image)

In [126]:
class SeverityClassificationModel(nn.Module):
    def __init__(self, feature_extraction_model, num_class):
        super(SeverityClassificationModel, self).__init__()
        self.feature_extractor = feature_extraction_model
        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 64),
            nn.LeakyReLU(),
            nn.Linear(64, num_class)
        )

        # Đóng băng các tham số trong feature extractor
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    def forward(self, image):
        feature = self.feature_extractor(image)
        out = self.fc(feature)
        return out


In [127]:
# test
checkpoint = torch.load(config["extraction_feature_model_path"], map_location=torch.device("cpu"))
extractionmodel = Extractionmodel(model_name=config["extraction_feature_model_name"], embed_dim=config["embedding_dim"])
extractionmodel.load_state_dict(checkpoint["model_state_dict"])
cls_model = SeverityClassificationModel(extractionmodel, 5)

## dataset test
datset = SeverityLevelDataset(annotation_file_path=config["annotation_file_path"], 
                              dataset_dir=config["dataset_dir"], 
                              phase=config["phase"], 
                              input_size=config["input_size"], 
                              label_list=config["Label"]
                              )

sample_0 = datset[0]
out = cls_model(sample_0[0].unsqueeze(0))
print(out)
print(sample_0[1])



tensor([[ 0.1051, -0.0875,  0.0559,  0.1010,  0.0531]],
       grad_fn=<AddmmBackward0>)
tensor(0)


# 5. Training model

In [128]:
def train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device, save_best_model=True, save_last_model=True):
    best_loss = float('inf')  # Khởi tạo best_loss với giá trị vô cùng lớn

    for epoch in range(num_epochs):
        model.train()  # Chuyển mô hình sang chế độ training
        running_loss = 0.0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)  # Chuyển dữ liệu vào thiết bị (ví dụ: GPU)
            labels = labels.to(device)

            optimizer.zero_grad()  # Đặt gradients về zero

            # Tính toán output của mô hình và loss
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Lan truyền ngược và cập nhật trọng số
            loss.backward()
            optimizer.step()

            # Cập nhật tổng loss
            running_loss += loss.item() * inputs.size(0)

        # Tính loss trung bình cho mỗi epoch
        epoch_loss = running_loss / len(train_loader.dataset)

        # In ra loss của epoch hiện tại
        print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}')

        # Kiểm tra model với dữ liệu validation
        model.eval()  # Chuyển mô hình sang chế độ evaluation
        valid_loss = 0.0

        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                valid_loss += loss.item() * inputs.size(0)

        # Tính loss trung bình cho dữ liệu validation
        valid_loss = valid_loss / len(valid_loader.dataset)
        print(f'Validation Loss: {valid_loss:.4f}')

        # Lưu lại model tốt nhất
        if save_best_model and valid_loss < best_loss:
            torch.save(model.state_dict(), '_best.pt')
            best_loss = valid_loss
            print('Best model saved.')

    # Lưu lại model cuối cùng
    if save_last_model:
        torch.save(model.state_dict(), '_last.pt')
        print('Last model saved.')

In [132]:
def test_model(model, criterion, test_loader, device):
    model.eval()  # Chuyển mô hình sang chế độ evaluation
    test_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            test_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Tính loss trung bình và độ chính xác trên dữ liệu kiểm tra
    test_loss = test_loss / len(test_loader.dataset)
    accuracy = correct / total

    print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2%}')

# 6. Run Pipeline

In [130]:
checkpoint = torch.load(config["extraction_feature_model_path"], map_location=torch.device("cpu"))
extractionmodel = Extractionmodel(model_name=config["extraction_feature_model_name"], embed_dim=config["embedding_dim"])
extractionmodel.load_state_dict(checkpoint["model_state_dict"])
cls_model = SeverityClassificationModel(extractionmodel, 5)

## dataset test
train_datset = SeverityLevelDataset(annotation_file_path=config["annotation_file_path"], 
                              dataset_dir=config["dataset_dir"], 
                              phase=config["phase"], 
                              input_size=config["input_size"], 
                              label_list=config["Label"]
                              )

valid_datset = SeverityLevelDataset(annotation_file_path=config["annotation_file_path"], 
                              dataset_dir=config["dataset_dir"], 
                              phase="valid", 
                              input_size=config["input_size"], 
                              label_list=config["Label"]
                              )

train_loader = DataLoader(train_datset, batch_size=config["batch_size"], shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_datset, batch_size=config["batch_size"], shuffle=False, drop_last=False)

criterion = nn.CrossEntropyLoss()
optimizer = opt.AdamW(cls_model.parameters(), lr=config["lr"])

train_model(cls_model, criterion, optimizer, train_loader, valid_loader, config["num_epoch"], torch.device("cuda" if torch.cuda.is_available() else "cpu"), True, True)




Epoch [1/50], Training Loss: 1.5440
Validation Loss: 1.4213
Best model saved.
Epoch [2/50], Training Loss: 1.4142
Validation Loss: 1.3039
Best model saved.
Epoch [3/50], Training Loss: 1.1681
Validation Loss: 1.0903
Best model saved.
Epoch [4/50], Training Loss: 0.9755
Validation Loss: 1.0754
Best model saved.
Epoch [5/50], Training Loss: 0.9166
Validation Loss: 1.0478
Best model saved.
Epoch [6/50], Training Loss: 0.8953
Validation Loss: 0.9935
Best model saved.
Epoch [7/50], Training Loss: 0.8815
Validation Loss: 0.9224
Best model saved.
Epoch [8/50], Training Loss: 0.8799
Validation Loss: 1.0078
Epoch [9/50], Training Loss: 0.8672
Validation Loss: 1.0056
Epoch [10/50], Training Loss: 0.8633
Validation Loss: 1.0284
Epoch [11/50], Training Loss: 0.8513
Validation Loss: 0.9420
Epoch [12/50], Training Loss: 0.8634
Validation Loss: 0.8851
Best model saved.
Epoch [13/50], Training Loss: 0.8428
Validation Loss: 1.3402
Epoch [14/50], Training Loss: 0.8276
Validation Loss: 0.9106
Epoch [15/5

In [133]:
test_datset = SeverityLevelDataset(annotation_file_path=config["annotation_file_path"], 
                              dataset_dir=config["dataset_dir"], 
                              phase="test", 
                              input_size=config["input_size"], 
                              label_list=config["Label"]
                              )
test_loader = DataLoader(test_datset, batch_size=config["batch_size"], shuffle=False, drop_last=False)

cls_model.load_state_dict(torch.load("_best.pt", map_location="cpu"))
test_model(cls_model, criterion, test_loader, torch.device("cuda" if torch.cuda.is_available() else "cpu"))


Test Loss: 1.1029, Accuracy: 69.23%
