In [1]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [2]:
# Root folder of the preprocessed dataset: one sub-directory per class.
dataset_dir = '3D_dataset'

# Class sub-directory names. A class's position in this list is used as its
# integer label when files are collected ('Dementia' -> 0, 'Not Dementia' -> 1).
dataset_classes = [
    'Dementia',
    'Not Dementia',
]

In [3]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class MRI_Dataset(Dataset):
    """Map-style dataset of single-channel 3D volumes stored as ``.npy`` files.

    Item ``i`` loads ``files[i]``, prepends a channel axis and returns
    ``(volume, label)`` as a float32 tensor and an int64 scalar tensor.
    """

    def __init__(self, files, labels, transform=None):
        """
        Args:
            files (list): Paths to the .npy volume files, in dataset order.
            labels (dict): Maps each path in ``files`` to its integer class label.
            transform (callable, optional): Applied to the channel-first array
                (after the channel axis is added) before tensor conversion.
        """
        self.files, self.labels, self.transform = files, labels, transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        # Prepend a channel axis: (D, H, W) -> (1, D, H, W), the layout Conv3d expects.
        sample = np.load(path)[np.newaxis, ...]
        target = self.labels[path]

        if self.transform:
            sample = self.transform(sample)

        volume_t = torch.tensor(sample, dtype=torch.float32)
        target_t = torch.tensor(target, dtype=torch.long)
        return volume_t, target_t

# Load files and assign labels: every .npy file under dataset_dir/<class>/ is
# labelled with the index of <class> in dataset_classes.
# Listings are sorted because os.listdir returns entries in arbitrary order;
# without sorting, train_test_split(random_state=42) could still produce a
# different split on another machine or filesystem. Non-.npy entries (e.g.
# .DS_Store, .ipynb_checkpoints) are skipped so np.load never receives them.
files = []
labels = {}
for i, class_dir in enumerate(dataset_classes):
    class_root = os.path.join(dataset_dir, class_dir)
    class_files = [
        os.path.join(class_root, f)
        for f in sorted(os.listdir(class_root))
        if f.lower().endswith('.npy')
    ]
    for f in class_files:
        labels[f] = i
    files.extend(class_files)

# Fail early with a clear message instead of an opaque sklearn error.
assert files, f'No .npy files found under {dataset_dir!r}'

# Stratified 80/20 split so both splits keep the class proportions.
train_files, val_files = train_test_split(files, test_size=0.2, stratify=[labels[f] for f in files], random_state=42)

# Create dataset objects (both share the same path -> label mapping).
train_dataset = MRI_Dataset(train_files, labels)
val_dataset = MRI_Dataset(val_files, labels)

In [4]:
# 3D volumes are memory-hungry, so batches stay small; raise this if your hardware allows.
batch_size = 2

# Shuffle only the training split; validation batches keep a fixed order.
train_loader, val_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=is_train)
    for split, is_train in ((train_dataset, True), (val_dataset, False))
)

In [5]:
# 3D resnet model
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.video import r3d_18

class ResNet3D(nn.Module):
    """torchvision's R3D-18 ResNet adapted to single-channel 3D volumes.

    The 3-channel stem convolution is replaced by a 1-input-channel Conv3d of
    the same geometry, and the final fully connected layer is resized to
    ``num_classes`` outputs.
    """

    def __init__(self, num_classes, pretrained=True):
        """
        Args:
            num_classes (int): Number of output logits.
            pretrained (bool): Load torchvision's default pretrained R3D-18
                weights. Defaults to True (the previous, hard-coded behaviour).
        """
        super(ResNet3D, self).__init__()
        # `weights=` replaces the deprecated `pretrained=` flag (torchvision >= 0.13).
        self.base_model = r3d_18(weights='DEFAULT' if pretrained else None)

        # Always swap the 3-channel stem conv for a 1-channel one; geometry
        # (out channels, kernel, stride, padding) is copied from the original.
        old_stem = self.base_model.stem[0]
        new_stem = nn.Conv3d(
            1, old_stem.out_channels,
            kernel_size=old_stem.kernel_size, stride=old_stem.stride,
            padding=old_stem.padding, bias=False,
        )
        if pretrained:
            # Keep the pretrained filters instead of a random re-init: summing
            # over the input-channel axis reproduces the filter response for a
            # grey input replicated across the 3 channels (up to per-channel
            # input normalisation).
            with torch.no_grad():
                new_stem.weight.copy_(old_stem.weight.sum(dim=1, keepdim=True))
        self.base_model.stem[0] = new_stem

        # Adjust the final fully connected layer to the number of classes.
        num_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return class logits; ``x`` is a batch shaped (N, 1, D, H, W)."""
        return self.base_model(x)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
from tqdm import tqdm


def run_epoch(model, loader, criterion, device, optimizer=None, desc=''):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Trains (zero_grad, backward, optimizer step) when ``optimizer`` is given;
    otherwise evaluates in eval mode with gradients disabled. Loss is weighted
    by batch size, so the mean is per sample even if the last batch is smaller.

    Args:
        model (nn.Module): Classifier whose outputs are logits with classes on dim 1.
        loader (DataLoader): Yields ``(inputs, targets)`` batches.
        criterion (callable): Loss function, e.g. ``nn.CrossEntropyLoss()``.
        device (torch.device): Device every batch is moved to.
        optimizer (torch.optim.Optimizer, optional): ``None`` means evaluation only.
        desc (str): Label shown on the tqdm progress bar.

    Returns:
        tuple[float, float]: Mean per-sample loss and accuracy in [0, 1].

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    prefix = 'Train' if training else 'Val'
    model.train(training)  # model.train(False) is equivalent to model.eval()

    total_loss, correct, seen = 0.0, 0, 0
    bar = tqdm(loader, desc=desc, position=0, leave=True)
    # Build the autograd graph only when we will backpropagate.
    with torch.set_grad_enabled(training):
        # `targets`, not `labels`: avoids clobbering the global path -> label dict.
        for inputs, targets in bar:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_n = targets.size(0)
            total_loss += loss.item() * batch_n
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            seen += batch_n

            bar.set_postfix({
                f'{prefix} Loss': f'{total_loss / seen:.4f}',
                f'{prefix} Acc': f'{correct / seen * 100:.2f}%',
            })

    if seen == 0:
        raise ValueError(f'{desc or prefix}: loader produced no samples')
    return total_loss / seen, correct / seen


# Seed torch's global RNG so the randomly initialised layers and the training
# shuffle order repeat across runs.
torch.manual_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model and move it to the device *before* building the
# optimizer, as the torch.optim documentation recommends.
num_classes = len(dataset_classes)
model = ResNet3D(num_classes).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 5  # Set a suitable number of epochs

# Per-epoch metric history
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []

for epoch in range(num_epochs):
    tag = f'Epoch {epoch + 1}/{num_epochs}'
    train_loss, train_acc = run_epoch(
        model, train_loader, criterion, device, optimizer, desc=f'Training {tag}'
    )
    val_loss, val_acc = run_epoch(
        model, val_loader, criterion, device, desc=f'Validation {tag}'
    )

    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    print(f'\n{tag} Summary:')
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc * 100:.2f}%')
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc * 100:.2f}%\n')

Training Epoch 1/5:  29%|██▉       | 73/252 [09:24<23:03,  7.73s/it, Train Loss=0.7183, Train Acc=52.74%]


KeyboardInterrupt: 