In [None]:
# 1. Imports & Setup
import os
import random
import copy
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models, transforms
from PIL import Image

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators with `seed`.

    Args:
        seed: integer seed handed to every generator (default 42).

    When CUDA is available, the generators of all CUDA devices are seeded too.
    """
    # Seed in a fixed order: stdlib `random`, then NumPy, then torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed every RNG once with 42 at the end of the setup cell.
set_seed(42)

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {DEVICE}")

In [None]:
# 2. Model Definition (Identical to Master for Compatibility)

class UpsampleBlock(nn.Module):
    """Two 3x3 Conv -> BatchNorm -> ReLU stages followed by 2x bilinear upsampling.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The layers are created in execution order; their positions (0-5)
        # inside `self.conv` determine the state_dict key names.
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Run `x` through the conv stack, then double its height and width."""
        refined = self.conv(x)
        return self.upsample(refined)

class OSCCMultiTaskModel(nn.Module):
    """DenseNet169 backbone shared by five pooled-feature heads and two decoders.

    Attribute names and layer positions define the state_dict keys, so every
    head and both decoders are always constructed, regardless of which output
    a caller actually consumes.
    """

    def __init__(self):
        super().__init__()
        # Backbone: DenseNet169, built through the `weights=` API
        # (the replacement for the deprecated `pretrained=` flag).
        self.backbone = models.densenet169(weights=models.DenseNet169_Weights.DEFAULT)
        num_ftrs = self.backbone.classifier.in_features
        # The stock classifier becomes a pass-through; forward() only calls
        # `backbone.features` directly.
        self.backbone.classifier = nn.Identity()

        # --- HEADS (all must exist, under these names, to match state_dict) ---
        self.head_tvnt = self._dropout_head(num_ftrs, 2)
        self.head_poi = self._dropout_head(num_ftrs, 5)
        self.head_pni = self._dropout_head(num_ftrs, 2)
        self.head_tb = self._scalar_head(num_ftrs)
        self.head_mi = self._scalar_head(num_ftrs)

        # --- Decoders: two structurally identical upsampling stacks ---
        self.decoder = self._mask_decoder(num_ftrs)
        self.decoder_pni = self._mask_decoder(num_ftrs)

    @staticmethod
    def _dropout_head(in_features, out_features):
        """Linear(in, 256) -> ReLU -> Dropout(0.2) -> Linear(256, out)."""
        return nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, out_features),
        )

    @staticmethod
    def _scalar_head(in_features):
        """Linear(in, 128) -> ReLU -> Linear(128, 1)."""
        return nn.Sequential(
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    @staticmethod
    def _mask_decoder(in_channels):
        """Five UpsampleBlocks (in -> 512 -> 256 -> 128 -> 64 -> 32), then a 1x1 conv to 1 channel."""
        widths = (in_channels, 512, 256, 128, 64, 32)
        blocks = [UpsampleBlock(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
        return nn.Sequential(*blocks, nn.Conv2d(32, 1, kernel_size=1))

    def forward(self, x):
        """Return a dict with one entry per head/decoder, keyed by task name."""
        features = self.backbone.features(x)
        # The ReLU is in-place, so `features` itself is rectified here and the
        # two decoders further down receive the rectified map as well.
        pooled = torch.flatten(
            F.adaptive_avg_pool2d(F.relu(features, inplace=True), (1, 1)), 1
        )

        vector_heads = (
            ('tvnt', self.head_tvnt),
            ('poi', self.head_poi),
            ('pni', self.head_pni),
            ('tb', self.head_tb),
            ('mi', self.head_mi),
        )
        outputs = {}
        for task, head in vector_heads:
            outputs[task] = head(pooled)
        outputs['doi'] = self.decoder(features)
        outputs['pni_seg'] = self.decoder_pni(features)
        return outputs

# Visible confirmation that the model-definition cell ran to completion.
print("Model Architecture Defined (Compatible with Inference).")

In [None]:
# 3. Dataset Loader (Folder-Based for Kaggle Dataset)

class OSCCBinaryDataset(Dataset):
    """Binary image dataset (0 = Normal, 1 = Cancer/OSCC) discovered from folder names.

    Every image file below `root_dir` is labelled by the name of its immediate
    parent folder: the first entry of `class_keywords` that occurs in the
    lower-cased folder name decides the label. Images in folders that match no
    keyword are ignored.

    Directories and files are visited in sorted order, so `samples` (and any
    index-based split built on top of it) is identical across instances and
    across runs on the same directory tree.

    Args:
        root_dir: directory that is walked recursively for images.
        transform: optional callable applied to each PIL image in `__getitem__`.
    """

    # File suffixes (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []  # list of (image_path, label) tuples

        # 1. Define Class Mapping
        # The code looks for these keywords in the folder names to assign labels.
        # 0 = Normal, 1 = Cancer (OSCC)
        self.class_keywords = {
            'normal': 0,
            'oscc': 1,
            'tumor': 1,
            'cancer': 1
        }

        if not os.path.exists(root_dir):
            print(f"‚ùå Dataset root '{root_dir}' not found!")
            return

        # 2. Auto-Discovery of Images
        print(f"Scanning '{root_dir}' for images...")
        for root, dirs, files in os.walk(root_dir):
            # Sorting `dirs` in place makes os.walk descend in a deterministic
            # order instead of the filesystem's arbitrary listing order.
            dirs.sort()

            # The label depends only on the folder, so resolve it once per
            # directory rather than once per file.
            label = self._label_for_folder(os.path.basename(root))
            if label is None:
                continue  # sub-directories are still visited by os.walk

            for file in sorted(files):
                if file.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(root, file), label))

        # Summary
        if len(self.samples) == 0:
            print("‚ö†Ô∏è No images found! Check your folder structure.")
        else:
            # Count classes
            labels = [s[1] for s in self.samples]
            num_normal = labels.count(0)
            num_cancer = labels.count(1)
            print(f"‚úÖ Loaded {len(self.samples)} images.")
            print(f"   - Normal: {num_normal}")
            print(f"   - Cancer (OSCC): {num_cancer}")

    def _label_for_folder(self, folder_name):
        """Return the label of the first keyword found in `folder_name`, or None."""
        folder_name = folder_name.lower()
        for keyword, val in self.class_keywords.items():
            if keyword in folder_name:
                return val
        return None

    def __len__(self):
        """Number of labelled images that were discovered."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return `(image, label)` for sample `idx`.

        The image is converted to RGB and passed through `self.transform` when
        one is set; the label is a scalar `torch.long` tensor. If the file
        cannot be read, the error is printed and a black 224x224 RGB image is
        substituted (best-effort behaviour: the sample keeps its label).
        """
        img_path, label = self.samples[idx]

        try:
            # The context manager closes the underlying file as soon as the
            # pixel data has been converted, instead of leaving the handle
            # open until garbage collection.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            image = Image.new('RGB', (224, 224)) # Fallback black image

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

# --- Configuration ---
# Folder that holds the unzipped Kaggle dataset. Set the OSCC_DATASET_ROOT
# environment variable to point somewhere else without editing this cell;
# when it is not set, the kagglehub cache path below is used.
DATASET_ROOT = os.environ.get(
    "OSCC_DATASET_ROOT",
    "/root/.cache/kagglehub/datasets/ashenafifasilkebede/dataset/versions/1",
)
# Number of images per batch for both data loaders.
BATCH_SIZE = 16

# --- Advanced Augmentation ---
def _finishing_steps():
    """Return a fresh ToTensor + Normalize pair that ends every pipeline."""
    # NOTE(review): these mean/std values are presumably the ImageNet channel
    # statistics matching the pretrained backbone — confirm.
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]


# Random augmentations used for training only, applied in this order.
_train_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
]

# Training: resize -> random augmentations -> tensor + normalisation.
train_transform = transforms.Compose(
    [transforms.Resize((224, 224)), *_train_augmentations, *_finishing_steps()]
)

# Validation: the same resize and normalisation, with no random steps.
val_transform = transforms.Compose(
    [transforms.Resize((224, 224)), *_finishing_steps()]
)

# --- Data Splitting ---
# Scan the folder tree once. The train/val datasets below are shallow copies
# of this instance, so they share its `samples` list and an index always
# refers to the same image in both of them.
full_dataset = OSCCBinaryDataset(DATASET_ROOT, transform=None)

if len(full_dataset) > 0:
    # 80 % of the images go to training, the remainder to validation.
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    # Create indices (shuffled with Python's `random` module)
    indices = list(range(len(full_dataset)))
    random.shuffle(indices)
    train_idx = indices[:train_size]
    val_idx = indices[train_size:]

    # Same samples, different transform per split
    train_ds = copy.copy(full_dataset)
    train_ds.transform = train_transform
    val_ds = copy.copy(full_dataset)
    val_ds.transform = val_transform

    # Loaders
    # Training draws its indices in a new random order every epoch.
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=torch.utils.data.SubsetRandomSampler(train_idx))
    # Validation needs no shuffling: iterating a fixed Subset keeps the order
    # stable and does not consume random numbers from the global generator.
    val_loader = DataLoader(torch.utils.data.Subset(val_ds, val_idx), batch_size=BATCH_SIZE, shuffle=False)

    print(f"‚úÖ Data Split: {len(train_idx)} Training, {len(val_idx)} Validation")
else:
    # Nothing was found; the loaders are left undefined in this case.
    print("‚ö†Ô∏è Dataset empty. Please upload the Kaggle dataset folder.")

In [None]:
# 4. Advanced Training Loop (With Validation, Best Model Saving & Early Stopping)
from tqdm.auto import tqdm  # Import tqdm for progress bars


def _run_epoch(model, loader, criterion, desc, optimizer=None):
    """Run one full pass over `loader` using the model's 'tvnt' output.

    Args:
        model: network whose forward returns a dict containing 'tvnt' logits.
        loader: iterable yielding (images, labels) batches.
        criterion: loss applied to the 'tvnt' logits and the labels.
        desc: text shown on the tqdm progress bar.
        optimizer: when given, the model is put in train mode and updated
            after every batch; when None, the pass runs in eval mode with
            gradients disabled.

    Returns:
        (mean_loss, accuracy) as Python floats, both averaged per sample.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc=desc, leave=False)
    with torch.set_grad_enabled(is_training):
        for images, labels in progress:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            if is_training:
                optimizer.zero_grad()

            logits = model(images)['tvnt']
            loss = criterion(logits, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            batch_loss = loss.item()
            batch_count = labels.size(0)
            _, preds = torch.max(logits, 1)

            # Accumulate plain Python numbers so the epoch metrics (and the
            # best-accuracy bookkeeping built on them) are ordinary floats.
            loss_sum += batch_loss * batch_count
            n_correct += (preds == labels).sum().item()
            n_seen += batch_count

            if is_training:
                # Update progress bar with current loss
                progress.set_postfix(loss=batch_loss)

    if n_seen == 0:
        raise ValueError(f"{desc}: the data loader produced no samples.")
    return loss_sum / n_seen, n_correct / n_seen


model = OSCCMultiTaskModel().to(DEVICE)

# Resume if exists
if os.path.exists("model_a.pth"):
    try:
        # NOTE(review): torch.load unpickles the file; only resume from
        # checkpoints that come from a trusted source.
        model.load_state_dict(torch.load("model_a.pth", map_location=DEVICE))
        print("‚úÖ Loaded existing weights.")
    except Exception as exc:
        # Report why the checkpoint was rejected instead of hiding it.
        # NOTE(review): load_state_dict may already have copied some matching
        # tensors before raising — confirm whether a partial load is acceptable.
        print(f"Could not load 'model_a.pth': {exc}")
        print("üÜï Starting fresh.")

optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Scheduler: cut the LR by 10x once validation accuracy has stopped improving
# for more than 5 consecutive epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5)
criterion = nn.CrossEntropyLoss()

NUM_EPOCHS = 100
EARLY_STOPPING_PATIENCE = 15 # Stop if no improvement for 15 epochs

best_val_acc = 0.0
best_model_wts = copy.deepcopy(model.state_dict())
epochs_no_improve = 0

print(f"üöÄ Starting Training for {NUM_EPOCHS} Epochs (Early Stopping: {EARLY_STOPPING_PATIENCE})...")

for epoch in range(NUM_EPOCHS):
    # --- TRAIN ---
    epoch_train_loss, epoch_train_acc = _run_epoch(
        model, train_loader, criterion,
        desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]",
        optimizer=optimizer,
    )

    # --- VALIDATE ---
    epoch_val_loss, epoch_val_acc = _run_epoch(
        model, val_loader, criterion,
        desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Val]",
    )

    # Read the LR before the scheduler step so the log shows the rate that
    # was actually used during this epoch.
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
          f"Train Loss: {epoch_train_loss:.4f} Acc: {epoch_train_acc:.4f} | "
          f"Val Loss: {epoch_val_loss:.4f} Acc: {epoch_val_acc:.4f} | LR: {current_lr:.1e}")

    # Scheduler Step
    scheduler.step(epoch_val_acc)

    # Save Best Model & Early Stopping Logic
    if epoch_val_acc > best_val_acc:
        best_val_acc = epoch_val_acc
        best_model_wts = copy.deepcopy(model.state_dict())
        torch.save(model.state_dict(), "model_a_best.pth")
        print(f"  üåü New Best Model Saved! (Acc: {best_val_acc:.4f})")
        epochs_no_improve = 0 # Reset counter
    else:
        epochs_no_improve += 1
        print(f"  ‚è≥ No improvement for {epochs_no_improve}/{EARLY_STOPPING_PATIENCE} epochs.")

    if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
        print(f"\nüõë Early Stopping triggered! No improvement for {EARLY_STOPPING_PATIENCE} epochs.")
        break

print(f"üèÅ Training Complete. Best Validation Accuracy: {best_val_acc:.4f}")

# Load best weights before final save
model.load_state_dict(best_model_wts)
torch.save(model.state_dict(), "model_a.pth")
print("‚úÖ Final 'model_a.pth' updated with best weights.")

In [None]:
# 5. Save Model
# Write the weights currently held by `model` to model_a.pth.
final_state = model.state_dict()
torch.save(final_state, "model_a.pth")
print("‚úÖ Model saved to model_a.pth")