In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found anywhere under the Kaggle input directory.
for input_path in (os.path.join(root, name)
                   for root, _subdirs, names in os.walk('/kaggle/input')
                   for name in names):
    print(input_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from tqdm import tqdm

# Seed NumPy and torch so the random split and augmentations repeat across runs.
np.random.seed(42)
torch.manual_seed(42)

# Input resolution and ImageNet channel statistics (matching the ImageNet-pretrained backbone).
IMAGE_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# --- Data Augmentation ---
# Training pipeline: resize, random geometric/photometric augmentation, then tensor + normalize.
transform_train = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(25),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.15, 0.15)),
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation pipeline: deterministic resize + normalize only (no augmentation).
transform_val = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# --- Custom Dataset Class ---
class PeriodontitisDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # 0 for 'no', 1 for 'yes'
        for label, folder in enumerate(['no', 'yes']):
            folder_path = os.path.join(root_dir, folder)
            for img_name in os.listdir(folder_path):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.image_paths.append(os.path.join(folder_path, img_name))
                    self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        label = torch.tensor(self.labels[idx], dtype=torch.float32)

        if self.transform:
            image = self.transform(image)

        return image, label

# --- Configuration and Device Setup ---
DATA_ROOT = '/kaggle/input/periodontitis-dataset/input'
BATCH_SIZE = 8
NUM_EPOCHS = 50
LEARNING_RATE = 0.0001
WEIGHT_DECAY = 1e-4
PATIENCE = 10  # Epochs without validation-F1 improvement before early stopping

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Create Dataset and Dataloaders ---
full_dataset = PeriodontitisDataset(root_dir=DATA_ROOT, transform=transform_train)

# Shallow copy: shares the file list and labels with full_dataset but carries its
# own non-augmenting transform. random_split returns Subsets that all point at the
# SAME dataset object, so assigning `val_dataset.dataset.transform` would also
# switch the training subset to transform_val and silently disable augmentation.
val_view = copy.copy(full_dataset)
val_view.transform = transform_val

# 80/20 random split of sample indices (drawn from torch's global RNG).
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])
# Same validation indices, read through the non-augmenting view.
val_dataset = torch.utils.data.Subset(val_view, val_dataset.indices)

# Class balance of the *training* subset, used for weighted sampling below.
train_labels_indices = train_dataset.indices
train_labels = [full_dataset.labels[i] for i in train_labels_indices]

no_count_train = train_labels.count(0)
yes_count_train = train_labels.count(1)
print(f"Training set - Healthy (no) cases: {no_count_train}")
print(f"Training set - Periodontitis (yes) cases: {yes_count_train}")

# Inverse-frequency weight per training sample; sampling with replacement then
# yields roughly class-balanced batches.
class_counts = np.bincount(train_labels)
class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
sample_weights = class_weights[train_labels]
sampler = torch.utils.data.WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

# The sampler takes the place of shuffle=True for the training loader.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=2
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # Deterministic order for evaluation
    num_workers=2
)

# --- Model Initialization (ResNet18 with transfer learning) ---
# ImageNet-pretrained backbone. Every pretrained parameter is frozen, so only the
# new classification head created below receives gradient updates.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
for param in model.parameters():
    param.requires_grad = False  # Freezes the entire pretrained network, not just early layers

# Replacement head: backbone features -> 512 hidden units -> 1 logit (binary task).
# Freshly constructed layers default to requires_grad=True.
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_ftrs, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, 1)
)
model = model.to(device)
# NOTE(review): model.train() also puts the frozen backbone's BatchNorm layers in
# training mode, so their running statistics still update each batch. Call .eval()
# on those layers after model.train() if fully frozen features are intended.

# --- Loss Function, Optimizer, and Scheduler ---
# Class imbalance is already handled by the WeightedRandomSampler on the training
# loader, which draws roughly balanced batches. Adding pos_weight = neg/pos on top
# would compensate twice, biasing predictions toward the positive class, and would
# divide by zero if the training split had no positives. So the loss is unweighted.
criterion = nn.BCEWithLogitsLoss()

# Give the optimizer only the parameters that can actually be updated (the head).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Cut the LR 10x when validation F1 (mode='max') stalls for 5 epochs.
# `verbose` is deprecated in recent PyTorch releases, so it is not passed.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.1, patience=5
)

# --- Training Loop ---
# F1 is always >= 0, so starting below it guarantees the first epoch writes a
# checkpoint. With 0 and a strict '>' test, a run whose F1 never rises above 0
# saved nothing, and the final torch.load of the checkpoint failed.
best_f1 = -1.0
no_improve = 0
history = {'train_loss': [], 'val_loss': [], 'val_f1': [], 'val_acc': [], 'val_precision': [], 'val_recall': []}

print(f"\n{'='*40}")
print(f"Starting training on the full dataset")
print(f"{'='*40}")

for epoch in range(NUM_EPOCHS):
    # Training phase
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS} (Train)')

    for inputs, labels in progress_bar:
        inputs = inputs.to(device)
        labels = labels.to(device).unsqueeze(1)  # (B,) -> (B, 1) to match the single-logit output

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping

        optimizer.step()

        batch_loss = loss.item()  # Single host sync per batch
        running_loss += batch_loss * inputs.size(0)
        progress_bar.set_postfix(loss=batch_loss)

    # The sampler draws len(train_dataset) samples per epoch, so this is a per-sample mean.
    epoch_loss = running_loss / len(train_loader.dataset)
    history['train_loss'].append(epoch_loss)

    # Validation phase
    model.eval()
    val_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        progress_bar_val = tqdm(val_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS} (Validation)')
        for inputs, labels in progress_bar_val:
            inputs = inputs.to(device)
            labels = labels.to(device).unsqueeze(1)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)

            # Threshold sigmoid probabilities at 0.5; flatten to 1-D so sklearn
            # receives plain label vectors rather than a list of 1-element arrays.
            preds = (torch.sigmoid(outputs) > 0.5).float()
            all_preds.extend(preds.view(-1).cpu().tolist())
            all_labels.extend(labels.view(-1).cpu().tolist())

    val_loss = val_loss / len(val_loader.dataset)
    history['val_loss'].append(val_loss)

    # Calculate metrics. zero_division=0 returns the same 0.0 sklearn falls back to
    # (e.g. no positive predictions in early epochs) without emitting warnings.
    val_acc = accuracy_score(all_labels, all_preds)
    val_f1 = f1_score(all_labels, all_preds, zero_division=0)
    val_precision = precision_score(all_labels, all_preds, zero_division=0)
    val_recall = recall_score(all_labels, all_preds, zero_division=0)

    history['val_f1'].append(val_f1)
    history['val_acc'].append(val_acc)
    history['val_precision'].append(val_precision)
    history['val_recall'].append(val_recall)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - "
          f"Train Loss: {epoch_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | "
          f"Acc: {val_acc:.4f} | "
          f"F1: {val_f1:.4f} | "
          f"Precision: {val_precision:.4f} | "
          f"Recall: {val_recall:.4f}")

    # Step the plateau scheduler on validation F1 and report any LR reduction.
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_f1)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after < lr_before:
        print(f"Reduced learning rate to {lr_after:.2e}")

    # Early stopping check
    if val_f1 > best_f1:
        best_f1 = val_f1
        torch.save(model.state_dict(), 'best_model_full_dataset.pth')
        no_improve = 0
        print(f"Saved new best model (F1: {best_f1:.4f})")
    else:
        no_improve += 1
        if no_improve >= PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break

# --- Plotting Training History ---
# Each entry: (panel title, y-axis label, [(history key, legend label), ...]).
history_panels = [
    ('Loss History', 'Loss',
     [('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')]),
    ('Performance Metrics History', 'Score',
     [('val_f1', 'Validation F1'), ('val_acc', 'Validation Accuracy')]),
    ('Precision and Recall History', 'Score',
     [('val_precision', 'Validation Precision'), ('val_recall', 'Validation Recall')]),
]

plt.figure(figsize=(15, 5))
for panel_idx, (panel_title, panel_ylabel, series) in enumerate(history_panels, start=1):
    plt.subplot(1, len(history_panels), panel_idx)
    for hist_key, series_label in series:
        plt.plot(history[hist_key], label=series_label)
    plt.title(panel_title)
    plt.xlabel('Epoch')
    plt.ylabel(panel_ylabel)
    plt.legend()

plt.tight_layout()
# Save relative to the current working directory, then close the figure.
plt.savefig('training_history_full_dataset.png')
plt.close()

# --- Final Evaluation on the best model ---
# NOTE: this re-scores the same validation split that was used to select the
# checkpoint, so these numbers are optimistically biased. A separate held-out
# test split would give an unbiased estimate.
print("\n--- Final Evaluation on Best Model ---")
best_checkpoint_path = 'best_model_full_dataset.pth'
if os.path.exists(best_checkpoint_path):
    # map_location lets a checkpoint written on one device load on another.
    model.load_state_dict(torch.load(best_checkpoint_path, map_location=device))
else:
    # No checkpoint was written (e.g. validation F1 never improved); say so
    # explicitly instead of crashing, and score the in-memory model.
    print(f"Warning: {best_checkpoint_path} not found; evaluating the current in-memory model.")
model.eval()

final_all_preds = []
final_all_labels = []

with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model(inputs.to(device))
        # Threshold at 0.5 and flatten to 1-D label vectors for sklearn.
        preds = (torch.sigmoid(outputs) > 0.5).float()
        final_all_preds.extend(preds.view(-1).cpu().tolist())
        final_all_labels.extend(labels.view(-1).tolist())

# zero_division=0 yields the same 0.0 fallback as sklearn's default, without warnings.
final_acc = accuracy_score(final_all_labels, final_all_preds)
final_f1 = f1_score(final_all_labels, final_all_preds, zero_division=0)
final_precision = precision_score(final_all_labels, final_all_preds, zero_division=0)
final_recall = recall_score(final_all_labels, final_all_preds, zero_division=0)

# Build the report once so the console output and the saved file always agree.
final_report = [
    f"Final Validation Accuracy: {final_acc:.4f}",
    f"Final Validation F1 Score: {final_f1:.4f}",
    f"Final Validation Precision: {final_precision:.4f}",
    f"Final Validation Recall: {final_recall:.4f}",
]
for report_line in final_report:
    print(report_line)

# Save final results
with open('final_model_results.txt', 'w', encoding='utf-8') as f:
    f.writelines(f"{report_line}\n" for report_line in final_report)

print("Training and final evaluation complete!")