In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
import os
from PIL import Image
import numpy as np
import time

# 1. BASELINE CNN MODEL 
class BaselineCNN(nn.Module):
    """Minimal two-convolution CNN used as a comparison baseline.

    Architecture: conv(3->8) -> ReLU -> maxpool(2) -> conv(8->16) -> ReLU ->
    maxpool(2) -> flatten -> linear(num_classes). No batch norm, no dropout.

    Args:
        num_classes: Number of output logits.
        input_size: Side length in pixels of the square RGB inputs. The fully
            connected layer is sized for this, so inputs of any other spatial
            size will fail at ``self.fc``. Defaults to 128 (the size this
            model was originally hard-coded for).

    Raises:
        ValueError: if ``input_size`` is too small to survive two 2x2 pools.
    """

    def __init__(self, num_classes=5, input_size=128):
        super().__init__()

        # Minimal feature extractor (3 layers -> 2 layers)
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)   # Reduced from 16 to 8
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)  # Reduced from 32 to 16
        self.pool = nn.MaxPool2d(2, 2)

        # kernel_size=3 with padding=1 keeps the spatial size; each 2x2 pool
        # floors it by half, so two pools give input_size // 4
        # (128 -> 64 -> 32 for the default).
        pooled_side = input_size // 4
        if pooled_side < 1:
            raise ValueError(
                f"input_size must be >= 4 to survive two 2x2 pools, got {input_size}"
            )
        self.fc = nn.Linear(16 * pooled_side * pooled_side, num_classes)  # Direct to classification

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        x = self.pool(torch.relu(self.conv1(x)))  # 8 x H/2 x W/2  (64 x 64 for 128 input)
        x = self.pool(torch.relu(self.conv2(x)))  # 16 x H/4 x W/4 (32 x 32 for 128 input)

        x = torch.flatten(x, 1)                   # Flatten all but the batch dim
        x = self.fc(x)                            # Direct classification
        return x

# 2. DATASET
class SimpleCancerDataset(Dataset):
    """Image-folder dataset for five lung/colon tissue classes.

    Expects ``root_dir/<class_name>/<image files>`` for each name in
    ``self.classes``. Class folders that do not exist are skipped with a
    warning. At most ``max_samples // len(self.classes)`` images are taken
    per class. File names are sorted first, so the chosen subset is the
    same on every machine.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded PIL image.
        max_samples: Total sample budget, split evenly across classes.
    """

    # File extensions (matched case-insensitively) treated as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')

    def __init__(self, root_dir, transform=None, max_samples=1000):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['colon_aca', 'colon_n', 'lung_aca', 'lung_n', 'lung_scc']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.samples = []  # list of (image_path, integer_label) tuples

        print(f"Loading data from: {root_dir}")

        # Limit samples per class for baseline
        per_class = max_samples // len(self.classes)
        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_dir):
                print(f"Warning: class directory not found, skipping: {class_dir}")
                continue
            # os.listdir returns entries in arbitrary, filesystem-dependent
            # order; sort so the capped subset is reproducible.
            images = sorted(
                f for f in os.listdir(class_dir)
                if f.lower().endswith(self._IMAGE_EXTENSIONS)
            )
            label = self.class_to_idx[class_name]
            self.samples.extend(
                (os.path.join(class_dir, img_name), label)
                for img_name in images[:per_class]
            )

        print(f"Total samples: {len(self.samples)}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image, apply ``transform``, return (image, label)."""
        img_path, label = self.samples[idx]
        # Close the file handle promptly. convert() returns a new, fully
        # loaded image that does not depend on the open file.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# 3. BASELINE TRAINING FUNCTION
def _run_train_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``loader``; return (mean batch loss, accuracy %)."""
    model.train()

    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Statistics: predicted class = index of the largest logit
        running_loss += loss.item()
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Progress output every 10 batches
        if (batch_idx + 1) % 10 == 0:
            avg_loss = running_loss / (batch_idx + 1)
            current_acc = 100 * correct / total
            print(f"Epoch {epoch+1}, Batch {batch_idx+1}: Loss={avg_loss:.4f}, Acc={current_acc:.1f}%")

    # Guard against an empty loader (e.g. tiny dataset where the 80% split is 0)
    num_batches = len(loader)
    epoch_loss = running_loss / num_batches if num_batches else 0.0
    epoch_acc = 100 * correct / total if total else 0.0
    return epoch_loss, epoch_acc


def _evaluate_accuracy(model, loader, device):
    """Return classification accuracy (%) of ``model`` over ``loader``; 0.0 if it is empty."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


def train_baseline(data_dir="lungcolon", num_epochs=3, max_samples=1000,
                   batch_size=16, lr=0.01, save_path='baseline_cnn_model.pth'):
    """Train ``BaselineCNN`` on an image folder, print metrics, and save a checkpoint.

    Uses a seeded 80/20 random split, plain SGD (no momentum) and
    cross-entropy loss. After every epoch it evaluates test accuracy and
    prints a summary. The final checkpoint (state dict, history, class names,
    parameter count) is written to ``save_path``.

    Args:
        data_dir: Root folder with one sub-directory per class.
        num_epochs: Number of training epochs (must be >= 1).
        max_samples: Total image budget passed to the dataset.
        batch_size: Mini-batch size for both loaders.
        lr: SGD learning rate.
        save_path: Where the checkpoint is written.

    Returns:
        ``(model, history)``. ``history`` maps 'train_loss', 'train_acc' and
        'test_acc' to per-epoch lists.

    Raises:
        ValueError: if ``num_epochs < 1`` or no images were found.
        FileNotFoundError: if ``data_dir`` does not exist.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    print("BASELINE CNN")

    # Check the dataset path before anything else. Raising (instead of a bare
    # ``return``) spares callers that unpack ``model, history = ...`` a
    # confusing "cannot unpack NoneType" TypeError.
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"Dataset not found at '{data_dir}'")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Basic transforms only, no normalization (deliberately minimal baseline).
    # 128x128 matches the BaselineCNN fully connected layer (16 * 32 * 32 inputs).
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])

    # Load dataset
    dataset = SimpleCancerDataset(root_dir=data_dir, transform=transform, max_samples=max_samples)
    if len(dataset) == 0:
        raise ValueError(f"No images found under '{data_dir}'")

    # Simple 80/20 split, seeded so train/test membership is reproducible
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(
        dataset, [train_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    print("\nDataset split:")
    print(f"  Training: {len(train_dataset)}")
    print(f"  Test: {len(test_dataset)}")

    # Simple data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Create baseline model
    model = BaselineCNN(num_classes=len(dataset.classes)).to(device)

    # Model info
    total_params = sum(p.numel() for p in model.parameters())
    print("\nModel Architecture:")
    print(model)
    print(f"\nTotal parameters: {total_params:,}")

    # Default training setup
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)  # Simple SGD

    print(f"\nTraining for {num_epochs} epochs...")

    history = {'train_loss': [], 'train_acc': [], 'test_acc': []}

    for epoch in range(num_epochs):
        epoch_start = time.time()

        epoch_loss, epoch_acc = _run_train_epoch(
            model, train_loader, criterion, optimizer, device, epoch
        )
        test_acc = _evaluate_accuracy(model, test_loader, device)
        epoch_time = time.time() - epoch_start

        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Train Loss: {epoch_loss:.4f}")
        print(f"  Train Acc: {epoch_acc:.1f}%")
        print(f"  Test Acc: {test_acc:.1f}%")
        print(f"  Time: {epoch_time:.1f}s")
        print("-" * 40)

        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)
        history['test_acc'].append(test_acc)

    # Return the model in training mode, as the original loop did.
    model.train()

    # Save baseline model
    torch.save({
        'model_state_dict': model.state_dict(),
        'history': history,
        'classes': dataset.classes,
        'params': total_params
    }, save_path)

    print("BASELINE TRAINING COMPLETE!")
    print(f"Final Test Accuracy: {history['test_acc'][-1]:.1f}%")
    print(f"Model saved as '{save_path}'")
    print(f"Total parameters: {total_params:,}")

    return model, history


# 5. MAIN EXECUTION
if __name__ == "__main__":
    run_start = time.time()

    # Banner describing the configuration of this comparison baseline.
    print("BASE CnN for comparison")
    print("\nFeatures:")
    for feature_line in (
        "2 convolutional layers (8, 16 channels)",
        "1 fully connected layer",
        "No batch normalization",
        "No dropout",
        "SGD optimizer (no momentum)",
        "3 epochs training",
        "1000 total samples",
    ):
        print(f"- {feature_line}")

    # Top-level boundary: report any training failure with a full traceback
    # rather than letting the script die silently mid-notebook.
    try:
        model, history = train_baseline()

        elapsed = time.time() - run_start
        print(f"\nTotal execution time: {elapsed:.1f} seconds")

        # Summary statistics over all recorded epochs
        print("BASELINE PERFORMANCE SUMMARY")
        best_train_acc = max(history['train_acc'])
        best_test_acc = max(history['test_acc'])
        last_train_loss = history['train_loss'][-1]
        print(f"Best training accuracy: {best_train_acc:.1f}%")
        print(f"Best test accuracy: {best_test_acc:.1f}%")
        print(f"Final training loss: {last_train_loss:.4f}")

    except Exception as err:
        print(f"\nError during training: {err}")
        import traceback
        traceback.print_exc()


MINIMAL BASELINE CNN FOR CANCER CLASSIFICATION

Features:
- 2 convolutional layers (8, 16 channels)
- 1 fully connected layer
- No batch normalization
- No dropout
- SGD optimizer (no momentum)
- 3 epochs training
- 1000 total samples

Quick baseline test...
Input shape: torch.Size([2, 3, 128, 128])
Output shape: torch.Size([2, 5])
Model works correctly!
Baseline model parameters: 83,317

BASELINE CNN TRAINING
Device: cuda
Loading data from: lungcolon
Total samples: 1000

Dataset split:
  Training: 800
  Test: 200

Model Architecture:
BaselineCNN(
  (conv1): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=16384, out_features=5, bias=True)
)

Total parameters: 83,317

Training for 3 epochs...
Epoch 1, Batch 10: Loss=1.5833, Acc=26.2%
Epoch 1, Batch 20: Loss=1.5942, Acc=26.6%
Epoch 1, B