In [None]:
# Mount Google Drive and move into the MAE project folder; the relative archive
# name used by the unzip cell below (data_au_opg) is resolved from this directory.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/My Drive/au_opg/mae_pytorch

In [None]:
!unzip -q data_au_opg -d /content/data

In [None]:
# NOTE(review): nothing in this notebook imports `lightning` directly; presumably it is
# installed for its torchmetrics dependency (Accuracy/F1Score are imported below) —
# confirm, and consider installing (and pinning) torchmetrics explicitly instead.
!pip install -q lightning

In [None]:
import numpy as np
import pandas as pd
import cv2
import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import math
from pathlib import Path
from typing import Literal
import wandb
from torch import optim
from torchmetrics import Accuracy, F1Score
from torch.cuda.amp import autocast, GradScaler
from tqdm.auto import tqdm
import timm
from torchvision.transforms import v2
# Switch the working directory to the framework folder; relative paths set afterwards
# (e.g. save_path = "./saved_runs") resolve against it.
%cd /content/drive/My Drive/au_opg/complete_framework

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Annotation CSVs are resolved relative to `dataset_path`; checkpoints go under `save_path`.
train_annotations = "train_annotations.csv"
test_annotations = "test_annotations.csv"
val_annotations = "val_annotations.csv"
dataset_path = "/content/drive/My Drive/au_opg/mae_pytorch"
save_path = "./saved_runs"
run_number = 1

# Label vocabularies: a sample's class label is its index in these lists.
conditions = [
    "Cavitated",
    "Retained Root",
    "Crowned",
    "Filled",
    "Impacted",
    "Implant",
]
treatments = ["Filling", "Root Canal", "Extraction", "None"]

# Seed NumPy and PyTorch (torch.manual_seed also seeds CUDA) so head initialisation,
# DataLoader shuffling and dropout masks repeat across runs. cuDNN kernels may still
# be non-deterministic.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
class CustomDataset(Dataset):
    """Per-tooth dataset built from an annotations CSV of rotated bounding boxes.

    Each CSV row names an image (``image``), a box given as ``x``, ``y``, ``width``,
    ``height`` in percent (0-100) of the image size plus ``rotation`` in degrees,
    and the ``condition`` / ``treatment`` labels. Images are read from
    ``image_dir / <filename prefix before the first "_"> / <image>.jpg``.

    ``__getitem__`` returns ``(data, coords, *labels)`` where

    * ``data`` is a 3x224x224 float tensor stacking the rotation-corrected box crop,
      a wider context crop, and the box's binary mask over that same context crop.
      Each channel is resized to 256x256, center-cropped to 224, min-max scaled to
      [0, 1] and then normalized to [-1, 1];
    * ``coords`` is a float32 tensor of the 8 box-corner coordinates normalized by
      image width/height (TL, TR, BR, BL as x/y pairs);
    * ``labels`` are indices into the module-level ``conditions`` and/or
      ``treatments`` lists, depending on ``task``.
    """

    _TASKS = ("condition", "treatment", "both")

    def __init__(
        self,
        annotations_file: str,
        base_dir: str,
        image_dir: str,
        task: Literal["condition", "treatment", "both"],
    ):
        if task not in self._TASKS:
            # Without this check __getitem__ would silently treat any typo as "both".
            raise ValueError(f"task must be one of {self._TASKS}, got {task!r}")
        self.base_dir = Path(base_dir)
        self.image_dir = Path(image_dir)
        annotations = pd.read_csv(self.base_dir / annotations_file, na_values=[None])
        # Missing values become the string "None", which is a valid entry of the
        # module-level `treatments` list. NOTE(review): pandas may also parse a literal
        # "None" label as NaN by default; this maps it back — confirm against the CSVs.
        self.annotations = annotations.replace(np.nan, "None")
        self.task = task
        self.transforms = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                self.min_max_normalize,
                transforms.Normalize(mean=[0.5], std=[0.5]),
            ]
        )

    def __len__(self):
        return len(self.annotations)

    @staticmethod
    def _rotated_corners(image, x, y, width, height, rotation):
        """Pixel corners of a box rotated about its top-left corner.

        Returns ``(corners, img_width, img_height)`` where ``corners`` is
        ``[(x1, y1), (x2, y2), (x3, y3), (x4, y4)]`` = top-left, top-right,
        bottom-right, bottom-left. A falsy ``rotation`` (0/None) means no rotation.
        """
        img_height, img_width = image.shape[:2]

        w, h = width * img_width / 100, height * img_height / 100
        a = math.pi * (rotation / 180.0) if rotation else 0.0
        cos_a, sin_a = math.cos(a), math.sin(a)

        x1, y1 = x * img_width / 100, y * img_height / 100  # top left
        x2, y2 = x1 + w * cos_a, y1 + w * sin_a  # top right
        x3, y3 = x2 - h * sin_a, y2 + h * cos_a  # bottom right
        x4, y4 = x1 - h * sin_a, y1 + h * cos_a  # bottom left
        return [(x1, y1), (x2, y2), (x3, y3), (x4, y4)], img_width, img_height

    @staticmethod
    def get_binary_mask(image, x, y, width, height, rotation):
        """Returns a uint8 mask (image-sized) with 255 inside the rotated box and 0 elsewhere."""
        corners, img_width, img_height = CustomDataset._rotated_corners(
            image, x, y, width, height, rotation
        )
        mask = np.zeros((img_height, img_width), dtype=np.uint8)
        cv2.fillPoly(mask, np.array([corners], dtype=np.int32), 255)
        return mask

    @staticmethod
    def get_sub_image(image, x, y, width, height, rotation):
        """Extract the rotated box as an upright (int(width_px), int(height_px)) image.

        The two translations below net out to moving the box's top-left corner to
        the origin, so the box is un-rotated about its top-left corner — the same
        convention used by ``_rotated_corners``.
        """
        rotation *= -1
        original_height, original_width = image.shape[:2]
        pixel_x = x / 100.0 * original_width
        pixel_y = y / 100.0 * original_height
        pixel_width = width / 100.0 * original_width
        pixel_height = height / 100.0 * original_height
        center_x = pixel_x + pixel_width / 2
        center_y = pixel_y + pixel_height / 2

        rotation_matrix = np.array(
            [
                [np.cos(np.radians(rotation)), -np.sin(np.radians(rotation)), 0],
                [np.sin(np.radians(rotation)), np.cos(np.radians(rotation)), 0],
                [0, 0, 1],
            ]
        )

        translation_matrix_to_origin = np.array(
            [[1, 0, -center_x], [0, 1, -center_y], [0, 0, 1]]
        )
        translation_matrix_back = np.array(
            [[1, 0, pixel_width / 2], [0, 1, pixel_height / 2], [0, 0, 1]]
        )

        affine_matrix = np.dot(rotation_matrix, translation_matrix_to_origin)
        affine_matrix = np.dot(affine_matrix, translation_matrix_back)

        rotated = cv2.warpPerspective(
            image, affine_matrix, (int(pixel_width), int(pixel_height))
        )
        return rotated

    @staticmethod
    def get_normalized_teeth_coords(image, x, y, width, height, rotation):
        """Return the 8 box-corner coordinates (x1, y1, ..., x4, y4) divided by image width/height."""
        corners, img_width, img_height = CustomDataset._rotated_corners(
            image, x, y, width, height, rotation
        )
        return tuple(
            value
            for px, py in corners
            for value in (px / img_width, py / img_height)
        )

    @staticmethod
    def get_cropped_image(image, x, y, width, height, zoom_factor=5):
        """Crop a context window near the box; inputs x, y, width, height are in percent (0-100).

        The window starts half a box width/height before (x, y) and spans
        ``zoom_factor`` box widths/heights, clipped to the image bounds. Rotation
        is ignored.
        """
        img_height, img_width = image.shape[:2]
        x = int(x / 100.0 * img_width)
        y = int(y / 100.0 * img_height)
        width = int(width / 100.0 * img_width)
        height = int(height / 100.0 * img_height)
        new_x = max(x - width // 2, 0)
        new_y = max(y - height // 2, 0)
        new_width = int(min(width * zoom_factor, img_width - new_x))
        new_height = int(min(height * zoom_factor, img_height - new_y))
        return image[new_y : new_y + new_height, new_x : new_x + new_width]

    def _to_model_input(self, array):
        """Wrap a 2-D array as a 1xHxW float32 tensor and apply ``self.transforms``."""
        return self.transforms(torch.from_numpy(array).unsqueeze(0).float())

    def process_row(self, row):
        """Load one annotated tooth and build its model input, corner coords and labels.

        Returns ``(data, coords, label_condition, label_treatment)``.

        Raises:
            FileNotFoundError: if the image cannot be read.
            ValueError: if a label is not in ``conditions`` / ``treatments``.
        """
        img_name = row["image"]
        img_name = img_name + ".jpg" if "jpg" not in img_name else img_name
        label_condition = conditions.index(row["condition"])
        label_treatment = treatments.index(row["treatment"])
        # Images live in sub-folders named after the filename prefix.
        img_path = self.image_dir / img_name.split("_")[0] / img_name

        grayscale_img = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
        if grayscale_img is None:
            # cv2.imread reports unreadable/missing files by returning None, not raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")

        box = (row["x"], row["y"], row["width"], row["height"], row["rotation"])

        sub_image = self.get_sub_image(grayscale_img, *box)
        coords = torch.tensor(
            self.get_normalized_teeth_coords(grayscale_img, *box), dtype=torch.float32
        )
        binary_mask = self.get_binary_mask(grayscale_img, *box)

        # Context crop around the tooth (or prosthetic) and the matching mask crop.
        context_crop = self.get_cropped_image(grayscale_img, *box[:4], zoom_factor=3)
        mask_crop = self.get_cropped_image(binary_mask, *box[:4], zoom_factor=3)

        data = torch.cat(
            [
                self._to_model_input(sub_image),
                self._to_model_input(context_crop),
                self._to_model_input(mask_crop),
            ],
            dim=0,
        )
        return data, coords, label_condition, label_treatment

    @staticmethod
    def min_max_normalize(tensor):
        """Rescale values to [0, 1]; a constant tensor maps to all zeros instead of NaN (0/0)."""
        min_val = tensor.min()
        value_range = tensor.max() - min_val
        if value_range == 0:
            return torch.zeros_like(tensor)
        return (tensor - min_val) / value_range

    def __getitem__(self, idx):
        data, coords, label_condition, label_treatment = self.process_row(
            self.annotations.iloc[idx]
        )
        if self.task == "condition":
            return data, coords, label_condition
        if self.task == "treatment":
            return data, coords, label_treatment
        return data, coords, label_condition, label_treatment

In [None]:
image_dir = "/content/data/data_au_opg"

train_dataset = CustomDataset(train_annotations, dataset_path, image_dir, "both")
val_dataset = CustomDataset(val_annotations, dataset_path, image_dir, "both")

# drop_last=True: the model's classifier heads use BatchNorm1d, which raises in training
# mode on a batch of size 1. Without it, a training set whose size is one more than a
# multiple of 256 would crash on the last batch of every epoch.
train_dataloader = DataLoader(
    train_dataset, batch_size=256, shuffle=True, drop_last=True
)
val_dataloader = DataLoader(val_dataset, batch_size=256, shuffle=False)

In [None]:
import timm
class CompleteModel(nn.Module):
    """timm ResNet-50-D backbone shared by two MLP heads (tooth condition, treatment).

    ``forward(x)`` returns ``(condition_logits, treatment_logits)``.

    Args:
        num_conditions: size of the condition head; defaults to ``len(conditions)``.
        num_treatments: size of the treatment head; defaults to ``len(treatments)``.
        hidden_dim: width of the two hidden layers in each head.
        dropout: dropout probability in each head.
        pretrained: load pretrained backbone weights via timm.
    """

    def __init__(
        self,
        num_conditions=None,
        num_treatments=None,
        hidden_dim=2048,
        dropout=0.5,
        pretrained=True,
    ):
        super().__init__()
        if num_conditions is None:
            num_conditions = len(conditions)
        if num_treatments is None:
            num_treatments = len(treatments)

        self.network = timm.create_model(
            "resnet50d.ra4_e3600_r224_in1k", pretrained=pretrained
        )
        # Replace the ImageNet classifier with Identity so the backbone emits pooled features.
        feature_dim = self.network.fc.in_features
        self.network.fc = nn.Identity()

        self.condition_classifier = self._make_head(
            feature_dim, hidden_dim, num_conditions, dropout
        )
        self.treatment_classifier = self._make_head(
            feature_dim, hidden_dim, num_treatments, dropout
        )

    @staticmethod
    def _make_head(in_dim, hidden_dim, out_dim, dropout):
        """Two Linear -> Dropout -> ReLU -> BatchNorm1d stages plus a final Linear.

        The layer order matches the original hand-written heads, so existing
        checkpoints' state_dict keys (``...classifier.0`` ... ``...classifier.8``) still
        load. BatchNorm1d needs more than one sample per batch in training mode.
        """
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        """Return ``(condition_logits, treatment_logits)`` for an image batch ``x``."""
        features = self.network(x)
        return self.condition_classifier(features), self.treatment_classifier(features)

In [None]:
# some_input = torch.randn(10, 3, 224, 224)
# model = CompleteModel()
# output = model(some_input)

In [None]:
# nn.Module.to() moves parameters in place and returns the module itself.
model = CompleteModel().to(device)

In [None]:
def _mean(values):
    """Arithmetic mean of ``values``; NaN (rather than ZeroDivisionError) for an empty list."""
    return sum(values) / len(values) if values else float("nan")


def _make_task_metrics(num_conditions, num_treatments, device):
    """Create weighted multiclass accuracy/F1 metrics for both heads, moved to ``device``."""

    def build(metric_cls, num_classes):
        return metric_cls(
            task="multiclass", num_classes=num_classes, average="weighted"
        ).to(device)

    return {
        "condition_acc": build(Accuracy, num_conditions),
        "condition_f1": build(F1Score, num_conditions),
        "treatment_acc": build(Accuracy, num_treatments),
        "treatment_f1": build(F1Score, num_treatments),
    }


def _reset_task_metrics(metrics):
    """Clear accumulated state of every metric in a dict from ``_make_task_metrics``."""
    for metric in metrics.values():
        metric.reset()


def _update_task_metrics(
    metrics, condition_logits, condition_labels, treatment_logits, treatment_labels
):
    """Feed one batch into every metric of a dict from ``_make_task_metrics``."""
    # Cast to float32: under autocast the logits may come out in reduced precision.
    condition_logits = condition_logits.float()
    treatment_logits = treatment_logits.float()
    metrics["condition_acc"].update(condition_logits, condition_labels)
    metrics["condition_f1"].update(condition_logits, condition_labels)
    metrics["treatment_acc"].update(treatment_logits, treatment_labels)
    metrics["treatment_f1"].update(treatment_logits, treatment_labels)


def _compute_task_metrics(metrics, prefix):
    """Return ``{f"{prefix}_{name}": value}`` for every metric in ``metrics``."""
    return {f"{prefix}_{name}": metric.compute() for name, metric in metrics.items()}


def train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=50,
    learning_rate=1e-4,
    l1_lambda=1e-4,
    device="cuda",
    project_name="au-opg",
    experiment_name="multi-task-training",
    num_conditions=6,
    num_treatments=4,
    save_dir=save_path,
    lr_milestones=(100, 150),
    lr_gamma=0.2,
):
    """Jointly train the condition and treatment heads with AMP, L1 + AdamW regularisation.

    Each epoch trains on ``train_loader``, evaluates on ``val_loader``, logs losses and
    weighted accuracy/F1 to wandb, and — when the validation loss (sum of both
    cross-entropies, no L1 term) improves — writes a checkpoint to
    ``<save_dir>/<project_name>/best_ckpt.pth`` and uploads it with ``wandb.save``.

    Loaders must yield ``(images, coords, condition_labels, treatment_labels)``;
    ``coords`` is ignored.

    Args:
        lr_milestones: epochs at which MultiStepLR multiplies the LR by ``lr_gamma``.

    Returns:
        A snapshot (deep copy) of the model's state_dict from the epoch with the lowest
        validation loss, or None if no epoch produced a finite validation loss.
    """
    import copy  # local import keeps this cell self-contained

    wandb.init(project=project_name, name=experiment_name)
    wandb.config.update(
        {
            "learning_rate": learning_rate,
            "epochs": num_epochs,
            "batch_size": train_loader.batch_size,
        }
    )
    run_dir = Path(save_dir) / project_name
    run_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = run_dir / "best_ckpt.pth"

    train_metrics = _make_task_metrics(num_conditions, num_treatments, device)
    val_metrics = _make_task_metrics(num_conditions, num_treatments, device)

    scaler = GradScaler()  # loss scaling for mixed-precision training
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=learning_rate,
        weight_decay=0.01,
    )
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=list(lr_milestones), gamma=lr_gamma
    )

    best_val_loss = float("inf")
    best_model_state = None

    def compute_l1_loss(model):
        """Sum of |param| over trainable parameters, always returned as a 0-dim tensor."""
        trainable = [p for p in model.parameters() if p.requires_grad]
        if not trainable:
            return torch.zeros((), device=device)
        return sum(p.abs().sum() for p in trainable)

    for epoch in range(num_epochs):
        # ---- Training phase ----
        model.train()
        train_losses = []
        train_task_losses = []  # task loss tracked separately from regularization
        train_l1_losses = []
        _reset_task_metrics(train_metrics)

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        for img_data, _coords, condition_labels, treatment_labels in progress_bar:
            img_data = img_data.to(device)
            condition_labels = condition_labels.to(device)
            treatment_labels = treatment_labels.to(device)

            with autocast():
                condition_logits, treatment_logits = model(img_data)
                condition_loss = criterion(condition_logits, condition_labels)
                treatment_loss = criterion(treatment_logits, treatment_labels)
                task_loss = condition_loss + treatment_loss
                l1_loss = compute_l1_loss(model)
                total_loss = task_loss + l1_lambda * l1_loss

            optimizer.zero_grad()
            scaler.scale(total_loss).backward()
            # Unscale first so gradient clipping sees true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            train_losses.append(total_loss.item())
            train_task_losses.append(task_loss.item())
            train_l1_losses.append(l1_loss.item())
            _update_task_metrics(
                train_metrics,
                condition_logits,
                condition_labels,
                treatment_logits,
                treatment_labels,
            )

            # Running means over the last 100 batches.
            progress_bar.set_postfix(
                {
                    "total_loss": _mean(train_losses[-100:]),
                    "task_loss": _mean(train_task_losses[-100:]),
                    "l1_loss": _mean(train_l1_losses[-100:]),
                }
            )

        # ---- Validation phase ----
        model.eval()
        val_losses = []
        _reset_task_metrics(val_metrics)

        with torch.no_grad():
            for img_data, _coords, condition_labels, treatment_labels in val_loader:
                img_data = img_data.to(device)
                condition_labels = condition_labels.to(device)
                treatment_labels = treatment_labels.to(device)

                with autocast():
                    condition_logits, treatment_logits = model(img_data)
                    batch_val_loss = criterion(
                        condition_logits, condition_labels
                    ) + criterion(treatment_logits, treatment_labels)

                val_losses.append(batch_val_loss.item())
                _update_task_metrics(
                    val_metrics,
                    condition_logits,
                    condition_labels,
                    treatment_logits,
                    treatment_labels,
                )

        # ---- Metrics & logging ----
        train_results = {
            "train_total_loss": _mean(train_losses),
            "train_task_loss": _mean(train_task_losses),
            "train_l1_loss": _mean(train_l1_losses),
            **_compute_task_metrics(train_metrics, "train"),
            "l1_lambda": l1_lambda,
            "learning_rate": optimizer.param_groups[0]["lr"],
        }
        val_results = {
            "val_loss": _mean(val_losses),
            **_compute_task_metrics(val_metrics, "val"),
        }

        wandb.log(
            {
                **train_results,
                **val_results,
                "epoch": epoch,
                "learning_rate": optimizer.param_groups[0]["lr"],
            }
        )

        scheduler.step()

        # ---- Checkpoint the best model ----
        if val_results["val_loss"] < best_val_loss:
            best_val_loss = val_results["val_loss"]
            # state_dict() returns references to the live parameter tensors, which keep
            # changing as training continues; snapshot them so the returned state really
            # is this epoch's weights.
            best_model_state = copy.deepcopy(model.state_dict())
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": best_model_state,
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scaler_state_dict": scaler.state_dict(),
                    "best_val_loss": best_val_loss,
                },
                str(ckpt_path),
            )
            wandb.save(str(ckpt_path))

        print(f"\nEpoch {epoch + 1}/{num_epochs} Summary:")
        print(f"Training Loss: {train_results['train_task_loss']:.4f}")
        print(f"Validation Loss: {val_results['val_loss']:.4f}")
        print(f"Training Condition F1: {train_results['train_condition_f1']:.4f}")
        print(f"Validation Condition F1: {val_results['val_condition_f1']:.4f}")
        print(f"Training Treatment F1: {train_results['train_treatment_f1']:.4f}")
        print(f"Validation Treatment F1: {val_results['val_treatment_f1']:.4f}")

    wandb.finish()

    return best_model_state

In [None]:
# NOTE: with num_epochs=100, train_model's MultiStepLR milestones (epochs 100 and 150)
# never take effect during training, so the learning rate stays at 1e-5 throughout.
best_model_state = train_model(
    # data
    model=model,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    device=device,
    # optimisation
    num_epochs=100,
    learning_rate=1e-5,
    # label spaces
    num_conditions=len(conditions),
    num_treatments=len(treatments),
    # logging / checkpointing
    project_name="au-opg",
    experiment_name="exp37",
    save_dir=save_path,
)

In [None]:
# Held-out test split; not shuffled so predictions keep the CSV order.
test_dataset = CustomDataset(test_annotations, dataset_path, image_dir, task="both")
test_dataloader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)

In [None]:
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt


def plot_confusion_matrix(y_true, y_pred, class_names, title):
    """Draw an annotated confusion-matrix heatmap with one row/column per class.

    Passing ``labels`` to ``confusion_matrix`` keeps the matrix at
    ``len(class_names)`` x ``len(class_names)`` even when a class is absent from both
    ``y_true`` and ``y_pred``; otherwise the matrix shrinks and the tick labels no
    longer line up with its rows/columns.
    """
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_title(title)
    ax.set_ylabel("True Label")
    ax.set_xlabel("Predicted Label")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    plt.show()


# Evaluate the best-validation-loss weights returned by train_model rather than
# whatever the final epoch left in `model`.
if best_model_state is not None:
    model.load_state_dict(best_model_state)

condition_preds, condition_true = [], []
treatment_preds, treatment_true = [], []

model.eval()
with torch.no_grad():
    for img_data, _coords, condition_labels, treatment_labels in test_dataloader:
        with autocast():
            condition_logits, treatment_logits = model(img_data.to(device))
        condition_preds.extend(condition_logits.argmax(dim=1).cpu().numpy())
        treatment_preds.extend(treatment_logits.argmax(dim=1).cpu().numpy())
        condition_true.extend(condition_labels.cpu().numpy())
        treatment_true.extend(treatment_labels.cpu().numpy())

# Accuracy and weighted F1 for both heads.
accuracy = accuracy_score(condition_true, condition_preds)
weighted_f1 = f1_score(condition_true, condition_preds, average="weighted")
treatment_accuracy = accuracy_score(treatment_true, treatment_preds)
treatment_weighted_f1 = f1_score(treatment_true, treatment_preds, average="weighted")
print(f"Condition Accuracy: {accuracy:.4f}")
print(f"Condition Weighted F1-Score: {weighted_f1:.4f}")
print(f"Treatment Accuracy: {treatment_accuracy:.4f}")
print(f"Treatment Weighted F1-Score: {treatment_weighted_f1:.4f}")

plot_confusion_matrix(
    condition_true, condition_preds, conditions, "Confusion Matrix - Conditions"
)
plot_confusion_matrix(
    treatment_true, treatment_preds, treatments, "Confusion Matrix - Treatments"
)