In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset, random_split
import numpy as np
from sklearn.metrics import accuracy_score
import timm
from tqdm import tqdm
import matplotlib.pyplot as plt

# Configuration for fast training
DATA_PATH =r"C:\Users\balaji\.cache\kagglehub\datasets\paultimothymooney\chest-xray-pneumonia\versions\2\chest_xray"
BATCH_SIZE = 64  # Increased batch size
IMG_SIZE = 224
INITIAL_LABELED = 500  # Reduced initial samples
ACTIVE_STEPS = 3  # Reduced active learning steps
SAMPLES_PER_STEP = 100  # Reduced samples per step
EPOCHS = 5  # Reduced epochs
MC_ITERATIONS = 5  # Reduced MC iterations

# Simplified transform for faster processing
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load dataset
full_dataset = datasets.ImageFolder(DATA_PATH + '/train', transform=transform)
test_dataset = datasets.ImageFolder(DATA_PATH + '/test', transform=transform)

# Reduced dataset size
indices = torch.randperm(len(full_dataset))[:2000]  # Use only 2000 images
full_dataset = Subset(full_dataset, indices)

# Split into labeled and unlabeled
labeled_indices = torch.randperm(len(full_dataset))[:INITIAL_LABELED]
unlabeled_indices = list(set(range(len(full_dataset))) - set(labeled_indices))

In [8]:
class PneumoniaViT(nn.Module):
    def __init__(self):
        super().__init__()
        self.vit = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=2)
        # Enhanced dropout configuration
        for block in self.vit.blocks:
            block.attn.proj_drop = nn.Dropout(0.2)
            block.mlp.dropout = nn.Dropout(0.3)
        self.vit.head_drop = nn.Dropout(0.4)
        # Add warmup scheduler in training
    
    def forward(self, x):
        return self.vit.head_drop(self.vit(x))

# CNN Model (simplified)
class PneumoniaCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*56*56, 2)
        )
    
    def forward(self, x):
        return self.net(x)


In [9]:
def train_model(model, train_loader, val_loader, model_type='cnn'):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4 if model_type == 'vit' else 1e-3)
    criterion = nn.CrossEntropyLoss()
    
    # Warmup scheduler for ViT
    if model_type == 'vit':
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 
            lambda epoch: min(1.0, (epoch + 1) / 3))  # 3-epoch warmup
    else:
        scheduler = None
    
    for epoch in range(EPOCHS):
        model.train()
        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS}'):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        
        if scheduler:
            scheduler.step()
        
        # Quick validation
        val_acc = evaluate(model, val_loader)['accuracy']
        print(f"Val Acc: {val_acc:.2%}")
    
    return model

def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    return {'accuracy': correct / total}

def mc_dropout_predict(model, loader):
    model.train()
    probs = []
    with torch.no_grad():
        for _ in range(MC_ITERATIONS):
            batch_probs = []
            for images, _ in loader:
                images = images.to(device)
                outputs = model(images)
                batch_probs.append(F.softmax(outputs, dim=1).cpu())
            probs.append(torch.cat(batch_probs))
    return torch.stack(probs)

def active_learning_step(model, labeled, unlabeled):
    # Train
    train_loader = DataLoader(labeled, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
    model = train_model(model, train_loader, test_loader, 'vit')
    
    # Get uncertainties
    unlabeled_loader = DataLoader(unlabeled, batch_size=BATCH_SIZE)
    probs = mc_dropout_predict(model, unlabeled_loader)
    uncertainties = probs.std(dim=0).mean(dim=1)
    
    # Select uncertain samples
    idx = np.argsort(-uncertainties)[:SAMPLES_PER_STEP]
    return idx

def run_experiment():
    # CNN Baseline
    print("\n1. Training CNN Baseline")
    cnn = PneumoniaCNN().to(device)
    train_loader = DataLoader(Subset(full_dataset, labeled_indices), batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
    cnn = train_model(cnn, train_loader, test_loader, 'cnn')
    cnn_acc = evaluate(cnn, test_loader)['accuracy']
    
    # ViT Baseline
    print("\n2. Training ViT Baseline")
    vit = PneumoniaViT().to(device)
    vit = train_model(vit, train_loader, test_loader, 'vit')
    vit_acc = evaluate(vit, test_loader)['accuracy']
    
    # ViT with Active Learning
    print("\n3. Training ViT with Active Learning")
    vit_active = PneumoniaViT().to(device)
    current_labeled = Subset(full_dataset, labeled_indices)
    current_unlabeled = Subset(full_dataset, unlabeled_indices)
    
    active_accs = []
    for step in range(ACTIVE_STEPS):
        print(f"\nActive Learning Step {step+1}")
        idx = active_learning_step(vit_active, current_labeled, current_unlabeled)
        
        # Update datasets
        new_labeled = Subset(current_unlabeled.dataset, [current_unlabeled.indices[i] for i in idx])
        current_labeled = torch.utils.data.ConcatDataset([current_labeled, new_labeled])
        remaining = [i for i in range(len(current_unlabeled)) if i not in idx]
        current_unlabeled = Subset(current_unlabeled.dataset, [current_unlabeled.indices[i] for i in remaining])
        
        # Evaluate
        acc = evaluate(vit_active, test_loader)['accuracy']
        active_accs.append(acc)
        print(f"Labeled: {len(current_labeled)}, Unlabeled: {len(current_unlabeled)}")
        print(f"Test Accuracy: {acc:.2%}")
    
    # Results
    print("\n=== Final Results ===")
    print(f"CNN Baseline Accuracy: {cnn_acc:.2%}")
    print(f"ViT Baseline Accuracy: {vit_acc:.2%}")
    print(f"ViT Active Learning Accuracy: {active_accs[-1]:.2%}")
    
    # Plot
    plt.figure(figsize=(8, 4))
    plt.bar(['CNN', 'ViT', 'ViT+AL'], [cnn_acc, vit_acc, active_accs[-1]])
    plt.ylabel('Accuracy')
    plt.title('Model Comparison')
    plt.savefig('results.png', bbox_inches='tight')
    plt.close()

if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    run_experiment()

Using device: cuda

1. Training CNN Baseline


Epoch 1/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:14<00:00,  1.78s/it]


Val Acc: 62.50%


Epoch 2/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:14<00:00,  1.80s/it]


Val Acc: 65.87%


Epoch 3/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:14<00:00,  1.77s/it]


Val Acc: 73.88%


Epoch 4/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:12<00:00,  1.52s/it]


Val Acc: 83.33%


Epoch 5/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:05<00:00,  1.36it/s]


Val Acc: 72.92%

2. Training ViT Baseline


Epoch 1/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.17it/s]


Val Acc: 81.25%


Epoch 2/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.32it/s]


Val Acc: 87.02%


Epoch 3/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.33it/s]


Val Acc: 87.34%


Epoch 4/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.31it/s]


Val Acc: 77.40%


Epoch 5/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.31it/s]


Val Acc: 87.98%

3. Training ViT with Active Learning

Active Learning Step 1


Epoch 1/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.28it/s]


Val Acc: 73.08%


Epoch 2/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.33it/s]


Val Acc: 74.20%


Epoch 3/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:05<00:00,  1.34it/s]


Val Acc: 86.38%


Epoch 4/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.30it/s]


Val Acc: 88.14%


Epoch 5/5: 100%|█████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.24it/s]


Val Acc: 87.34%
Labeled: 600, Unlabeled: 1900
Test Accuracy: 87.34%

Active Learning Step 2


Epoch 1/5: 100%|███████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.22it/s]


Val Acc: 83.49%


Epoch 2/5: 100%|███████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.20it/s]


Val Acc: 89.10%


Epoch 3/5: 100%|███████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.21it/s]


Val Acc: 88.14%


Epoch 4/5: 100%|███████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.20it/s]


Val Acc: 89.26%


Epoch 5/5: 100%|███████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.18it/s]


Val Acc: 86.38%
Labeled: 700, Unlabeled: 1800
Test Accuracy: 86.38%

Active Learning Step 3


Epoch 1/5: 100%|███████████████████████████████████████████████████████████████████████| 11/11 [00:09<00:00,  1.16it/s]


Val Acc: 89.10%


Epoch 2/5: 100%|███████████████████████████████████████████████████████████████████████| 11/11 [00:09<00:00,  1.13it/s]


Val Acc: 89.58%


Epoch 3/5: 100%|███████████████████████████████████████████████████████████████████████| 11/11 [00:09<00:00,  1.15it/s]


Val Acc: 90.06%


Epoch 4/5: 100%|███████████████████████████████████████████████████████████████████████| 11/11 [00:09<00:00,  1.15it/s]


Val Acc: 88.46%


Epoch 5/5: 100%|███████████████████████████████████████████████████████████████████████| 11/11 [00:10<00:00,  1.10it/s]


Val Acc: 89.26%
Labeled: 800, Unlabeled: 1700
Test Accuracy: 89.26%

=== Final Results ===
CNN Baseline Accuracy: 72.92%
ViT Baseline Accuracy: 87.98%
ViT Active Learning Accuracy: 89.26%
