In [2]:
import os
import glob
import re
import numpy as np
import pandas as pd
import torch
import cv2
import pydicom as dicom
from typing import List
from torchvision import transforms as T
from torch.utils.data import DataLoader
from torchvision.models import efficientnet_v2_s
from torchvision.models.feature_extraction import create_feature_extractor

In [16]:
# Configuración de dispositivo
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 1  # Batch pequeño para procesamiento de un solo paciente

# Configuración de modelos y paths
WEIGHTS = T.Compose([T.ToTensor(), T.Resize((224, 224))])  # Reemplaza con WEIGHTS.transforms() si es necesario
EFFNET_CHECKPOINTS_PATH = "./models"  # Reemplaza con la ruta correcta

# Lista de nombres de modelos
MODEL_NAMES = [f'effnetv2-f{i}' for i in range(5)]
FRAC_COLS = [f'C{i}_effnet_frac' for i in range(1, 8)]
VERT_COLS = [f'C{i}_effnet_vert' for i in range(1, 8)]
columns_to_transform = ['patient_overall'] + [f'C{i}' for i in range(1, 8)]

# Función para cargar modelos
def load_model(model, name, path='.') -> torch.nn.Module:
    data = torch.load(os.path.join(path, f'{name}.tph'), map_location=DEVICE)
    model.load_state_dict(data)
    return model

# Cargar imagen DICOM
def load_dicom(path):
    img = dicom.dcmread(path)
    img.PhotometricInterpretation = 'YBR_FULL'
    data = img.pixel_array
    data = data - np.min(data)
    if np.max(data) != 0:
        data = data / np.max(data)
    data = (data * 255).astype(np.uint8)
    return cv2.cvtColor(data, cv2.COLOR_GRAY2RGB), img

# Dataset personalizado para EfficientNet
class EffnetDataSet(torch.utils.data.Dataset):
    def __init__(self, df, path, transforms=None):
        super().__init__()
        self.df = df
        self.path = path
        self.transforms = transforms
        
    def __getitem__(self, i):
        path = os.path.join(self.path, self.df.iloc[i].StudyInstanceUID, f'{self.df.iloc[i].Slice}.dcm')        
        img = load_dicom(path)[0]
        img = np.transpose(img, (2, 0, 1))  # Convertir a (channels, height, width)
        img = self.transforms(torch.as_tensor(img)) if self.transforms else img
        return img
    
    def __len__(self):
        return len(self.df)

# Definición del modelo EfficientNet para predicción
class EffnetModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        effnet = efficientnet_v2_s()
        self.model = create_feature_extractor(effnet, {'flatten': 'flatten'})
        self.nn_fracture = torch.nn.Linear(1280, 7)
        self.nn_vertebrae = torch.nn.Linear(1280, 7)

    def forward(self, x):
        x = self.model(x)['flatten']
        return self.nn_fracture(x), self.nn_vertebrae(x)

    def predict(self, x):
        frac, vert = self.forward(x)
        return torch.sigmoid(frac), torch.sigmoid(vert)

# Predicción usando modelos EfficientNet
def predict_effnet(models: List[EffnetModel], ds) -> np.ndarray:
    dl_test = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    for m in models:
        m.eval()
    predictions = []
    with torch.no_grad():
        for X in dl_test:
            pred = torch.zeros(len(X), 14).to(DEVICE)
            for m in models:
                y1, y2 = m.predict(X.to(DEVICE))
                pred += torch.cat([y1, y2], dim=1) / len(models)
            predictions.append(pred)
    return torch.cat(predictions).cpu().numpy()

# Calcular predicción final del paciente
def patient_prediction(df):
    c1c7 = np.average(df[FRAC_COLS].values, axis=0, weights=df[VERT_COLS].values)
    pred_patient_overall = 1 - np.prod(1 - c1c7)
    return pd.Series(data=np.concatenate([[pred_patient_overall], c1c7]), index=['patient_overall'] + [f'C{i}' for i in range(1, 8)])

# Predicción para un paciente individual
def predict_single_patient(models: List[EffnetModel], patient_path: str):
    dicom_files = glob.glob(f'{patient_path}/*.dcm')
    slices = [(os.path.basename(patient_path), int(re.search(r'(\d+)\.dcm', f).group(1))) for f in dicom_files]
    df_patient_slices = pd.DataFrame(slices, columns=['StudyInstanceUID', 'Slice']).sort_values('Slice')

    ds_patient = EffnetDataSet(df_patient_slices, patient_path, WEIGHTS)
    effnet_pred = predict_effnet(models, ds_patient)

    df_effnet_pred = pd.DataFrame(data=effnet_pred, columns=FRAC_COLS + VERT_COLS)
    df_patient_pred = pd.concat([df_patient_slices, df_effnet_pred], axis=1)
    pred_final = patient_prediction(df_patient_pred)

    # Aplicar threshold de 0.6
    pred_final[columns_to_transform] = pred_final[columns_to_transform].applymap(lambda x: 1 if x > 0.6 else 0)
    return pred_final

# Cargar modelos y hacer predicción
effnet_models = [load_model(EffnetModel(), name, EFFNET_CHECKPOINTS_PATH).to(DEVICE) for name in MODEL_NAMES]
patient_path = '/train_images/1.2.826.0.1.3680043.10001'
df_patient_final = predict_single_patient(effnet_models, patient_path)

# Mostrar el resultado
print(df_patient_final)


  data = torch.load(os.path.join(path, f'{name}.tph'), map_location=DEVICE)


RuntimeError: Error(s) in loading state_dict for EffnetModel:
	Missing key(s) in state_dict: "nn_fracture.weight", "nn_fracture.bias", "nn_vertebrae.weight", "nn_vertebrae.bias". 
	Unexpected key(s) in state_dict: "nn_fracture.0.weight", "nn_fracture.0.bias", "nn_vertebrae.0.weight", "nn_vertebrae.0.bias". 