In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import gc
from tqdm import tqdm

# Select the compute device once at import time and report what was chosen.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# GPU details are only meaningful (and only queryable) when CUDA was selected.
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Initial GPU Memory: {torch.cuda.memory_allocated()/1024**2:.2f} MB")

class ChestXrayDataset(Dataset):
    """Folder-per-class image dataset for four chest X-ray categories.

    ``data_dir`` must contain one sub-directory per entry in ``self.classes``.
    Directory names are matched exactly first and, failing that,
    case-insensitively (so ``NORMAL`` satisfies the ``Normal`` class).

    Args:
        data_dir: Root directory holding the class sub-directories.
        transform: Optional callable applied to each loaded RGB PIL image.

    Raises:
        ValueError: If no directory can be found for one of the classes.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.classes = ['COVID-19', 'Normal', 'Pneumonia', 'Tuberculosis']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.images = []
        self.labels = []

        # Collect every image path together with its integer class label.
        for class_name in self.classes:
            class_path = self._resolve_class_dir(class_name)
            # Sort so sample order does not depend on filesystem listing order,
            # and keep regular files only so nested folders are not opened as images.
            files = sorted(
                name for name in os.listdir(class_path)
                if os.path.isfile(os.path.join(class_path, name))
            )
            print(f"Found {len(files)} images in {class_name} class")
            label = self.class_to_idx[class_name]
            self.images.extend(os.path.join(class_path, name) for name in files)
            self.labels.extend([label] * len(files))

    def _resolve_class_dir(self, class_name):
        """Return the directory for ``class_name``, ignoring case if needed."""
        exact = os.path.join(self.data_dir, class_name)
        if os.path.isdir(exact):
            return exact

        # Fall back to a case-insensitive match among the root's sub-directories.
        if os.path.isdir(self.data_dir):
            wanted = class_name.casefold()
            for entry in sorted(os.listdir(self.data_dir)):
                candidate = os.path.join(self.data_dir, entry)
                if entry.casefold() == wanted and os.path.isdir(candidate):
                    return candidate

        raise ValueError(f"Path does not exist: {exact}")

    def __len__(self):
        """Return the number of collected image paths."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is opened as RGB and passed through ``self.transform`` when
        one is set; the label is a scalar ``torch.long`` tensor. Any failure is
        logged with the offending path and then re-raised unchanged.
        """
        img_path = self.images[idx]
        try:
            image = Image.open(img_path).convert('RGB')
            label = self.labels[idx]

            if self.transform:
                image = self.transform(image)

            return image, torch.tensor(label, dtype=torch.long)
        except Exception as e:
            print(f"Error loading image {img_path}: {str(e)}")
            raise

def create_data_loaders(base_dir, batch_size=32, seed=None):
    """Build train/validation/test loaders from a folder-per-class dataset.

    The samples are split 70% / 15% / remainder. The training subset is served
    with augmentation; the validation and test subsets are served with a plain
    resize + normalise transform.

    Args:
        base_dir: Root directory passed to ``ChestXrayDataset``.
        batch_size: Batch size used by all three loaders.
        seed: Optional integer seed for the random split. ``None`` keeps the
            previous behaviour of drawing from the global torch RNG.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    import copy

    print(f"Checking directory: {base_dir}")
    print(f"Directory exists: {os.path.exists(base_dir)}")

    # Enhanced data augmentation for training
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.2),
        transforms.RandomRotation(30),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomAutocontrast(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Validation/Testing transform
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Create datasets
    full_dataset = ChestXrayDataset(base_dir, transform=train_transform)
    print(f"Total dataset size: {len(full_dataset)}")

    # All subsets returned by random_split point at the SAME dataset object, so
    # overwriting its transform would also strip augmentation from training.
    # A shallow copy shares the path/label lists but carries its own transform.
    eval_dataset = copy.copy(full_dataset)
    eval_dataset.transform = val_transform

    # Split dataset
    train_size = int(0.7 * len(full_dataset))
    val_size = int(0.15 * len(full_dataset))
    test_size = len(full_dataset) - train_size - val_size

    print(f"Train size: {train_size}, Val size: {val_size}, Test size: {test_size}")

    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)

    train_dataset, val_split, test_split = torch.utils.data.random_split(
        full_dataset, [train_size, val_size, test_size], **split_kwargs
    )

    # Re-point the held-out index sets at the non-augmenting dataset copy.
    val_dataset = torch.utils.data.Subset(eval_dataset, val_split.indices)
    test_dataset = torch.utils.data.Subset(eval_dataset, test_split.indices)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    return train_loader, val_loader, test_loader

class ChestXrayModel(nn.Module):
    """DenseNet169 backbone (``IMAGENET1K_V1`` weights) with a custom MLP head.

    Every backbone parameter except the last 60 (in ``parameters()`` order) is
    frozen before the stock classifier is swapped for a
    BatchNorm -> 1024 -> 512 -> ``num_classes`` head, so the new head is
    created trainable.

    Args:
        num_classes: Number of output logits produced by the head.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        print("Initializing model...")

        backbone = models.densenet169(weights='IMAGENET1K_V1')

        # Freezing must happen before the classifier is replaced: the slice is
        # taken over the original parameter list, not the new head's.
        pretrained_params = list(backbone.parameters())
        for frozen in pretrained_params[:-60]:
            frozen.requires_grad = False

        in_dim = backbone.classifier.in_features
        head_layers = [
            nn.BatchNorm1d(in_dim),
            nn.Linear(in_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        backbone.classifier = nn.Sequential(*head_layers)

        self.densenet = backbone
        print("Model initialization completed")

    def forward(self, x):
        """Return the class logits the wrapped DenseNet produces for ``x``."""
        logits = self.densenet(x)
        return logits

def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    Each batch is moved to ``device``, run forward and backward, and followed
    by one optimizer step. A tqdm bar shows the running mean loss and the
    running accuracy in percent.

    Returns:
        Tuple ``(mean_batch_loss, accuracy)`` where the loss is the sum of
        per-batch losses divided by the number of batches and the accuracy is
        a fraction in ``[0, 1]``.

    Raises:
        Exception: Any error raised while processing a batch is logged with
            its batch index and re-raised.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(train_loader, desc='Training')
    for step, (inputs, labels) in enumerate(progress, start=1):
        try:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            batch_loss = criterion(logits, labels)

            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.item()
            guesses = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += guesses.eq(labels).sum().item()

            progress.set_postfix({'loss': loss_sum/step, 'acc': 100.*n_correct/n_seen})

        except Exception as e:
            # ``step`` is 1-based; report the 0-based batch index.
            print(f"Error in batch {step - 1}: {str(e)}")
            raise e

    epoch_loss = loss_sum/len(train_loader)
    epoch_acc = n_correct/n_seen
    return epoch_loss, epoch_acc

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Returns:
        Tuple ``(mean_batch_loss, accuracy)``: the sum of per-batch losses
        divided by the number of batches, and the fraction of samples whose
        arg-max prediction equals the label.
    """
    model.eval()
    loss_total = 0.0
    hits = 0
    seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            loss_total += criterion(logits, batch_labels).item()

            guesses = logits.max(1)[1]
            seen += batch_labels.size(0)
            hits += (guesses == batch_labels).sum().item()

    mean_loss = loss_total / len(val_loader)
    accuracy = hits / seen
    return mean_loss, accuracy

def plot_confusion_matrix(model, test_loader, device, class_names=None):
    """Plot a confusion-matrix heatmap and print a classification report.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        class_names: Display names, ordered by class index. Defaults to the
            four chest X-ray categories used by this script.
    """
    if class_names is None:
        class_names = ['COVID-19', 'NORMAL', 'PNEUMONIA', 'TUBERCULOSIS']
    # Pin the label set explicitly: without it, a class that is absent from the
    # test split shrinks the matrix and makes ``target_names`` mismatch.
    label_ids = list(range(len(class_names)))

    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)

            all_predictions.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Calculate and plot confusion matrix
    cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

    # Print classification report
    print("\nClassification Report:")
    print(classification_report(all_labels, all_predictions,
                                labels=label_ids,
                                target_names=class_names,
                                zero_division=0))

def _plot_training_curves(train_metrics, val_metrics):
    """Plot per-epoch loss and accuracy curves for training and validation."""
    epochs = range(1, len(train_metrics['loss']) + 1)
    _, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 5))

    loss_ax.plot(epochs, train_metrics['loss'], label='Train')
    loss_ax.plot(epochs, val_metrics['loss'], label='Validation')
    loss_ax.set_title('Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    acc_ax.plot(epochs, train_metrics['acc'], label='Train')
    acc_ax.plot(epochs, val_metrics['acc'], label='Validation')
    acc_ax.set_title('Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()

    plt.tight_layout()
    plt.show()


def main():
    """Train, evaluate and save the chest X-ray classifier end to end.

    Steps: build the data loaders, train for ``EPOCHS`` epochs while
    checkpointing the lowest-validation-loss weights to ``best_model.pth``,
    plot the training curves, report test-set metrics using that best
    checkpoint, and save the last-epoch weights to ``final_model.pth``.

    Raises:
        Exception: Any failure is logged and re-raised; GPU cache and garbage
            are released in all cases.
    """
    import copy

    # Hyperparameters
    BATCH_SIZE = 32
    EPOCHS = 20
    LEARNING_RATE = 1e-4
    BASE_DIR = "/kaggle/input/3-diseases-dataset/Dataset"  # Update this path

    try:
        print("\n1. Creating data loaders...")
        train_loader, val_loader, test_loader = create_data_loaders(
            BASE_DIR,
            batch_size=BATCH_SIZE
        )

        print("\n2. Creating model...")
        model = ChestXrayModel(num_classes=4).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.2)

        print("\n3. Training model...")
        train_metrics = {'loss': [], 'acc': []}
        val_metrics = {'loss': [], 'acc': []}
        best_val_loss = float('inf')
        # Tracks whether THIS run wrote a checkpoint, so a stale file from an
        # earlier run is never loaded by mistake.
        best_saved = False

        for epoch in range(EPOCHS):
            print(f'\nEpoch {epoch+1}/{EPOCHS}')

            train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
            val_loss, val_acc = validate(model, val_loader, criterion, device)

            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

            train_metrics['loss'].append(train_loss)
            train_metrics['acc'].append(train_acc)
            val_metrics['loss'].append(val_loss)
            val_metrics['acc'].append(val_acc)

            scheduler.step(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), 'best_model.pth')
                best_saved = True

        print("\n4. Plotting metrics...")
        # The original call targeted ``plot_metrics``, which this file never
        # defines; use the local helper so a finished run does not crash here.
        _plot_training_curves(train_metrics, val_metrics)

        print("\n5. Generating confusion matrix...")
        # Keep the last-epoch weights for step 6, then evaluate the checkpoint
        # that actually achieved the best validation loss.
        final_state = copy.deepcopy(model.state_dict())
        if best_saved:
            model.load_state_dict(torch.load('best_model.pth', map_location=device))
        plot_confusion_matrix(model, test_loader, device)

        print("\n6. Saving final model...")
        torch.save(final_state, 'final_model.pth')

    except Exception as e:
        print(f"\nError occurred: {str(e)}")
        raise

    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Entry point: run the full pipeline only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()

Using device: cuda
GPU: Tesla T4
Initial GPU Memory: 0.00 MB

1. Creating data loaders...
Checking directory: /kaggle/input/3-diseases-dataset/Dataset
Directory exists: True
Found 4450 images in COVID-19 class

Error occurred: Path does not exist: /kaggle/input/3-diseases-dataset/Dataset/Normal


ValueError: Path does not exist: /kaggle/input/3-diseases-dataset/Dataset/Normal