In [None]:
import os
import time
import pydicom
from tqdm import tqdm
import pandas as pd
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import logging


class LumbarSpineDataset(Dataset):
    def __init__(self, image_dir, metadata_dir, transform=None, load_fraction=1):
        """
        Args:
            image_dir (string): Directory with images organized by study_id, series_id.
            metadata_dir (string): Directory containing the CSV files.
            transform (callable, optional): Optional transform to be applied on a sample.
            load_fraction (float, optional): Fraction of data to load for debugging (default 1).
        """
        self.image_dir = image_dir
        self.transform = transform
        self.load_fraction = load_fraction
        
        # Load the coordinates and severity data from CSV files
        self.coordinates = pd.read_csv(os.path.join(metadata_dir, 'train_label_coordinates.csv'))
        self.metadata = pd.read_csv(os.path.join(metadata_dir, 'train.csv'))

        # Define severity mapping for severity conditions
        self.severity_mapping = {
            'Normal/Mild': 0,
            'Moderate': 1,
            'Severe': 2
        }

        self.severity_columns = [
            'spinal_canal_stenosis_l1_l2', 'spinal_canal_stenosis_l2_l3', 'spinal_canal_stenosis_l3_l4', 
            'spinal_canal_stenosis_l4_l5', 'spinal_canal_stenosis_l5_s1',
            'left_neural_foraminal_narrowing_l1_l2', 'left_neural_foraminal_narrowing_l2_l3', 'left_neural_foraminal_narrowing_l3_l4', 
            'left_neural_foraminal_narrowing_l4_l5', 'left_neural_foraminal_narrowing_l5_s1',
            'right_neural_foraminal_narrowing_l1_l2', 'right_neural_foraminal_narrowing_l2_l3', 'right_neural_foraminal_narrowing_l3_l4', 
            'right_neural_foraminal_narrowing_l4_l5', 'right_neural_foraminal_narrowing_l5_s1',
            'left_subarticular_stenosis_l1_l2', 'left_subarticular_stenosis_l2_l3', 'left_subarticular_stenosis_l3_l4', 
            'left_subarticular_stenosis_l4_l5', 'left_subarticular_stenosis_l5_s1',
            'right_subarticular_stenosis_l1_l2', 'right_subarticular_stenosis_l2_l3', 'right_subarticular_stenosis_l3_l4', 
            'right_subarticular_stenosis_l4_l5', 'right_subarticular_stenosis_l5_s1'
        ]

        self.severity_levels = ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']
        
        # Load only a fraction of the data for debugging
        self.data = self.load_data()

    def __len__(self):
        return len(self.data)

    def load_data(self):
        data = []
        start_time = time.perf_counter()

        # Calculate the number of items to load based on the load_fraction
        num_items = int(len(self.coordinates) * self.load_fraction)

        for idx, row in tqdm(self.coordinates.iterrows(), total=num_items, desc="Loading images"):
            if len(data) >= num_items:
                break

            study_id = row['study_id']
            series_id = row['series_id']
            instance_number = row['instance_number']
            condition_level = row['level']

            severity = self.get_severity_for_level(study_id, condition_level)
            
            img_path = os.path.join(self.image_dir, f"{study_id}/{series_id}/{instance_number}.dcm")

            coordinates = (row['x'], row['y'])
            sample = {
                'image_path': img_path,
                'severity': severity,
                'coordinates': coordinates
            }
            data.append(sample)

        end_time = time.perf_counter()
        print(f"Time taken to load data: {end_time - start_time:.2f} seconds")
        return data

    def get_severity_for_level(self, study_id, level):
        severity_column = self.severity_columns[self.severity_levels.index(level)]
        severity_row = self.metadata[self.metadata['study_id'] == study_id]
        
        if not severity_row.empty:
            severity_value = severity_row[severity_column].values[0]
            return self.severity_mapping.get(severity_value, -1)
        return -1

    def __getitem__(self, idx):
        sample = self.data[idx]
        img_path = sample["image_path"]
        dicom_image = pydicom.dcmread(img_path)
        image = dicom_image.pixel_array
        image = image.astype(np.float32) / np.max(image)
        image = Image.fromarray(image)
        image = self.transform(image)
        
        sample['image'] = image
        return sample

# Set a manual seed for reproduction
manual_seed = 110
torch.manual_seed(manual_seed)
print(f"manual seed: {manual_seed}")
# Initialize the dataset
image_dir = r"Project\train_images"
metadata_dir = r"Project"
transform = transforms.Compose([ 
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5))])

dataset = LumbarSpineDataset(image_dir=image_dir, metadata_dir=metadata_dir, transform=transform, load_fraction=1)

# Create DataLoader with tqdm for progress bar
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


manual seed: 110


Loading images: 100%|██████████| 48692/48692 [00:04<00:00, 10075.86it/s]


Time taken to load data: 4.84 seconds


In [2]:
import torch
import torch.nn as nn
import math

# Patch embedding for ViT
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=1, embed_dim=768):
        super(PatchEmbedding, self).__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # Shape: (batch_size, embed_dim, num_patches_sqrt, num_patches_sqrt)
        x = x.flatten(2)  # Shape: (batch_size, embed_dim, num_patches)
        x = x.transpose(1, 2)  # Shape: (batch_size, num_patches, embed_dim)
        return x

# Positional encoding
class PositionalEncoding(nn.Module):
    def __init__(self, num_patches, embed_dim):
        super(PositionalEncoding, self).__init__()
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))

    def forward(self, x):
        return x + self.pos_embedding

# Multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super(MultiHeadAttention, self).__init__()
        assert embed_dim % num_heads == 0

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.o_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.size()
        qkv = self.qkv_proj(x)  # Shape: (batch_size, seq_len, 3 * embed_dim)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # Shape: (3, batch_size, num_heads, seq_len, head_dim)

        q, k, v = qkv[0], qkv[1], qkv[2]  # Each has shape: (batch_size, num_heads, seq_len, head_dim)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # Shape: (batch_size, num_heads, seq_len, seq_len)
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)  # Shape: (batch_size, num_heads, seq_len, head_dim)

        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        attn_output = self.o_proj(attn_output)
        attn_output = self.dropout(attn_output)
        return attn_output

# Feedforward network
class FeedForward(nn.Module):
    def __init__(self, embed_dim, mlp_dim, dropout=0.0):
        super(FeedForward, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.network(x)

# Transformer encoder layer
class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.0):
        super(TransformerEncoderLayer, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = FeedForward(embed_dim, mlp_dim, dropout)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # Multi-head attention
        x = x + self.ffn(self.norm2(x))  # Feed-forward network
        return x

# Vision Transformer with dual output: classification and regression
class VisionTransformerWithCoordinates(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=1,
                 num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_dim=3072, dropout=0.0):
        super(VisionTransformerWithCoordinates, self).__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Classification token
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # Positional encoding
        self.pos_embed = PositionalEncoding(num_patches, embed_dim)

        # Transformer encoder layers
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, mlp_dim, dropout) for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.cls_head = nn.Linear(embed_dim, num_classes)  # Classification head
        self.coord_head = nn.Linear(embed_dim, 2)  # Coordinates regression head (2 values for x, y)

    def forward(self, x):
        x = self.patch_embed(x)  # Shape: (batch_size, num_patches, embed_dim)

        batch_size = x.size(0)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # Shape: (batch_size, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # Shape: (batch_size, num_patches + 1, embed_dim)

        x = self.pos_embed(x)

        for layer in self.encoder_layers:
            x = layer(x)

        x = self.norm(x)
        cls_output = x[:, 0]  # Extract the CLS token output for classification

        # Predict severity (classification)
        severity_logits = self.cls_head(cls_output)

        # Predict coordinates (regression)
        coords_output = self.coord_head(cls_output)  # Predict (x, y) coordinates

        return severity_logits, coords_output


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Assuming the model and dataloader are already defined as in your code:
model = VisionTransformerWithCoordinates(
    img_size=224,
    patch_size=16,
    in_channels=1,
    num_classes=3,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_dim=3072,
    dropout=0.1
)

# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Define loss functions
severity_criterion = nn.CrossEntropyLoss()  # For severity classification
coordinate_criterion = nn.MSELoss()  # For coordinate regression

optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 5  # Set the number of epochs
print_frequency = 10  # Print loss after every 10 batches

model.train()
for epoch in range(num_epochs):
    epoch_severity_loss = 0.0
    epoch_coord_loss = 0.0
    epoch_total_loss = 0.0

    for batch_idx, batch in enumerate(dataloader):
        images = batch['image'].to(device)
        severity = batch['severity'].long().to(device)
        
        # Assuming batch['coordinates'] is a list of (x, y) for each sample
        coords_list = [list(c) for c in batch['coordinates']]
        coordinates = torch.tensor(coords_list, dtype=torch.float32).to(device)  # Shape (B, 2)
        coordinates = coordinates.transpose(0, 1)

        severity_logits, coord_preds = model(images)  # Ensure this is (B, 2)
        
        severity_loss = severity_criterion(severity_logits, severity)
        coord_loss = coordinate_criterion(coord_preds, coordinates)
        total_loss = severity_loss + 0.001 * coord_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()


        # Accumulate epoch losses for reporting
        epoch_severity_loss += severity_loss.item()
        epoch_coord_loss += coord_loss.item()
        epoch_total_loss += total_loss.item()

        # Print interim results
        if (batch_idx + 1) % print_frequency == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}] "
                f"Batch [{batch_idx+1}/{len(dataloader)}] "
                f"- Severity Loss: {severity_loss.item():.4f}, "
                f"Coord Loss: {coord_loss.item():.4f}, "
                f"Total Loss: {total_loss.item():.4f}",
                end='\r'
            )

    # Print epoch results
    avg_severity_loss = epoch_severity_loss / len(dataloader)
    avg_coord_loss = epoch_coord_loss / len(dataloader)
    avg_total_loss = epoch_total_loss / len(dataloader)

    print(
        f"\nEpoch [{epoch+1}/{num_epochs}] "
        f"- Avg Severity Loss: {avg_severity_loss:.4f}, "
        f"Avg Coord Loss: {avg_coord_loss:.4f}, "
        f"Avg Total Loss: {avg_total_loss:.4f}"
    )

Epoch [1/5] Batch [620/1522] - Severity Loss: 0.2940, Coord Loss: 37516.0391, Total Loss: 37.8101

KeyboardInterrupt: 