In [1]:
import os
import pydicom
from skimage.transform import resize
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import torch.nn as nn
import torch.optim as optim


In [None]:
# Parameters
# Root folder containing one sub-directory per patient. It can be supplied via the
# DICOM_DIR environment variable so the notebook runs without being edited.
dicom_dir_path = os.environ.get('DICOM_DIR', '')
num_classes = 2  # Adjust based on your dataset
img_width, img_height = 224, 224  # network input size in pixels

# Custom Dataset
class DICOMDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from DICOM files.

    Each entry of ``patient_ids`` names a sub-directory of ``dicom_dir_path``;
    every non-hidden file found beneath it (recursively) is one sample.

    Args:
        dicom_dir_path: Root directory holding one sub-directory per patient.
        patient_ids: Names of the patient sub-directories to include.
        transform: Optional callable applied to the ``(3, H, W)`` float tensor.
        label_fn: Optional callable ``(patient_id, path) -> int`` returning the
            class index of a file. When omitted every sample is labelled ``0``
            (a placeholder — see the TODO in ``_load_paths_and_labels``).
        image_size: Optional ``(height, width)`` output size. When omitted the
            notebook-level ``(img_height, img_width)`` is used at load time.

    Raises:
        FileNotFoundError: If a patient sub-directory does not exist.
    """

    def __init__(self, dicom_dir_path, patient_ids, transform=None,
                 label_fn=None, image_size=None):
        self.dicom_dir_path = dicom_dir_path
        self.patient_ids = patient_ids
        self.transform = transform
        self.label_fn = label_fn
        self.image_size = image_size
        self.paths, self.labels = self._load_paths_and_labels()

    def _load_paths_and_labels(self):
        """Return parallel lists of file paths and integer labels, in a stable order."""
        paths, labels = [], []
        for patient_id in self.patient_ids:
            patient_path = os.path.join(self.dicom_dir_path, patient_id)
            # os.walk silently yields nothing for a missing directory, so check explicitly.
            if not os.path.isdir(patient_path):
                raise FileNotFoundError(f"Patient directory not found: {patient_path}")
            # Walk recursively in sorted order: os.listdir order is filesystem
            # dependent, and DICOM exports often nest study/series folders.
            for root, dirnames, filenames in os.walk(patient_path):
                dirnames.sort()  # in-place so os.walk descends in sorted order
                for filename in sorted(filenames):
                    if filename.startswith('.'):  # skip .DS_Store and similar
                        continue
                    path = os.path.join(root, filename)
                    paths.append(path)
                    # TODO: supply real labels via label_fn; with every label 0 the
                    # classifier has nothing to learn.
                    label = self.label_fn(patient_id, path) if self.label_fn is not None else 0
                    labels.append(label)
        return paths, labels

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Load one file, resize it, scale it to [0, 1] and replicate it to 3 channels."""
        dicom_path = self.paths[idx]
        label = self.labels[idx]
        dicom = pydicom.dcmread(dicom_path)
        # skimage's output_shape is (rows, cols), i.e. (height, width).
        size = self.image_size if self.image_size is not None else (img_height, img_width)
        # NOTE(review): assumes a single-frame 2-D pixel_array — multi-frame files
        # would need a frame index.
        image = resize(dicom.pixel_array, size, anti_aliasing=True)
        image = self._scale_to_unit_range(image)
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0).repeat(3, 1, 1)  # grayscale -> 3 channels for ResNet
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    @staticmethod
    def _scale_to_unit_range(image):
        """Min-max scale an array to [0, 1]; a constant image becomes all zeros, not NaN."""
        lo, hi = image.min(), image.max()
        image = image - lo
        if hi > lo:
            image = image / (hi - lo)
        return image

# Split patient IDs (splitting by patient keeps all slices of a patient in one split)
if not os.path.isdir(dicom_dir_path):
    raise FileNotFoundError(
        f"DICOM root directory not found: {dicom_dir_path!r}; set dicom_dir_path first"
    )
# sorted(): os.listdir order is filesystem dependent, so without it the seeded
# split below would differ between machines.
patient_ids = sorted(
    dir_name for dir_name in os.listdir(dicom_dir_path)
    if os.path.isdir(os.path.join(dicom_dir_path, dir_name))
)
train_patient_ids, test_patient_ids = train_test_split(patient_ids, test_size=0.2, random_state=42)

# Dataset and DataLoader
# ImageNet channel statistics, matching the pretrained ResNet weights; applied to
# images the dataset has already scaled to [0, 1].
transform = transforms.Compose([
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
train_dataset = DICOMDataset(dicom_dir_path, train_patient_ids, transform=transform)
test_dataset = DICOMDataset(dicom_dir_path, test_patient_ids, transform=transform)

# Seed torch's global RNG so the shuffle order and the new head's initialisation
# are reproducible on re-run (some GPU kernels may still be non-deterministic).
torch.manual_seed(42)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)  # TODO: not yet used for evaluation

# Model setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `pretrained=True` is deprecated since torchvision 0.13; IMAGENET1K_V1 is the
# checkpoint it used to load.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)  # replace the 1000-class ImageNet head
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
def train_model(model, train_loader, criterion, optimizer, num_epochs=10, device=None):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    Args:
        model: Module to optimise; it is put into training mode.
        train_loader: Sized iterable of ``(inputs, targets)`` batches.
        criterion: Loss callable ``criterion(outputs, targets) -> scalar tensor``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of full passes over ``train_loader``.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if it has none), so no global is needed.

    Returns:
        List with the mean per-batch training loss of each epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    if len(train_loader) == 0:
        # Fail before training instead of dividing by zero after the first epoch.
        raise ValueError("train_loader is empty; check the dataset path and split")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
    return epoch_losses

# Run the training loop (10 epochs is train_model's default); `;` keeps any return value from being echoed.
train_model(model, train_loader, criterion, optimizer, num_epochs=10);