In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input mount recursively and print the full path of every file found.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # For normalization
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Subset
import torchvision.datasets as datasets
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sklearn.metrics import f1_score, roc_curve, auc, confusion_matrix
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import StratifiedShuffleSplit
import optuna
from optuna.samplers import TPESampler
from optuna.pruners import HyperbandPruner

# ---------------------------------
# Dataset Setup and Augmentation
# ---------------------------------
# Root folder handed to ImageFolder below; fail fast with a clear message
# if the path does not exist in this environment.
train_dir = "/kaggle/input/asccsasca/Medicinal plant dataset"
if not os.path.exists(train_dir):
    raise FileNotFoundError(f"Dataset not found: {train_dir}")

def stratified_split_dataset(dataset, val_ratio=0.2, test_ratio=0.1):
    """
    Split a dataset into train, validation, and test subsets using stratified sampling.

    Args:
        dataset: Object exposing a ``targets`` sequence of class labels, one
            per sample.
        val_ratio: Fraction of the *whole* dataset to place in the validation split.
        test_ratio: Fraction of the *whole* dataset to place in the test split.

    Returns:
        Tuple ``(train_subset, val_subset, test_subset)`` of ``Subset`` objects
        that all wrap ``dataset`` and index disjoint samples.

    Raises:
        ValueError: If the ratios are not in (0, 1) or leave no room for a
            training split.
    """
    # Validate up front so a bad ratio is reported clearly instead of
    # surfacing as an error from inside the second split.
    if not (0 < test_ratio < 1) or not (0 < val_ratio < 1):
        raise ValueError("val_ratio and test_ratio must both lie in (0, 1)")
    if val_ratio + test_ratio >= 1:
        raise ValueError("val_ratio + test_ratio must be < 1 to leave training data")

    targets = np.array(dataset.targets)
    indices = np.arange(len(targets))

    # First, split off the test set (n_splits=1, so take the single split).
    sss_test = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=42)
    train_val_idx, test_idx = next(sss_test.split(indices, targets))

    # Now split train_val into train and validation sets. The validation
    # ratio is re-expressed relative to the samples that remain.
    train_val_targets = targets[train_val_idx]
    val_relative_ratio = val_ratio / (1 - test_ratio)
    sss_val = StratifiedShuffleSplit(n_splits=1, test_size=val_relative_ratio, random_state=42)
    train_pos, val_pos = next(
        sss_val.split(np.arange(len(train_val_targets)), train_val_targets)
    )

    # The second split yields positions inside train_val_idx; map them back
    # to indices into the original dataset.
    train_indices = train_val_idx[train_pos]
    val_indices = train_val_idx[val_pos]

    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)
    test_subset = Subset(dataset, test_idx)
    return train_subset, val_subset, test_subset

# Training pipeline: resize to 224x224, random 90-degree rotation, horizontal
# flip, brightness/contrast jitter and coarse dropout (each applied with
# p=0.5), then normalisation with the mean/std below and conversion to a tensor.
train_transform = A.Compose([
    A.Resize(224, 224),
    A.RandomRotate90(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.CoarseDropout(max_holes=8, max_height=8, max_width=8, fill_value=0, p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), 
                std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])
# Validation/test pipeline: deterministic resize + the same normalisation,
# with no random augmentation.
val_transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=(0.485, 0.456, 0.406), 
                std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

class AlbumentationsTransform:
    """Adapter that lets a keyword-style transform be used as a plain callable.

    The wrapped ``transform`` is invoked as ``transform(image=<ndarray>)`` and
    must return a mapping with an ``'image'`` entry; that entry is returned.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, image):
        # Convert whatever was passed in to a NumPy array before handing it on.
        pixels = np.array(image)
        augmented = self.transform(image=pixels)
        return augmented['image']

# Create the full dataset using the training transform.
full_dataset = datasets.ImageFolder(root=train_dir, transform=AlbumentationsTransform(train_transform))
# Perform stratified splitting.
train_dataset, val_dataset, test_dataset = stratified_split_dataset(full_dataset, val_ratio=0.2, test_ratio=0.1)

# All three Subsets returned above wrap the *same* ImageFolder object, so
# assigning `subset.dataset.transform` changes the transform for every split
# (the last assignment wins, which silently disabled training augmentation).
# Give validation/test their own ImageFolder over the same root instead;
# ImageFolder lists files in sorted order, so the indices refer to the same
# samples in both dataset objects.
eval_dataset = datasets.ImageFolder(root=train_dir, transform=AlbumentationsTransform(val_transform))
val_dataset = Subset(eval_dataset, val_dataset.indices)
test_dataset = Subset(eval_dataset, test_dataset.indices)

# Standard loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# ============================================================
# Additional Improvements for Highly Imbalanced Datasets
# ============================================================
# 1️⃣ Reweighted Loss Functions: Define a custom Focal Loss.
class FocalLoss(nn.Module):
    """Multi-class focal loss with optional per-class weights.

    Per sample: ``weight[target] * (1 - p_t) ** gamma * (-log p_t)`` where
    ``p_t`` is the softmax probability of the true class.

    Args:
        gamma: Focusing exponent; ``0`` reduces to (weighted) cross-entropy.
        weight: Optional tensor of shape (num_classes,) with class weights.
        reduction: ``"mean"``, ``"sum"``, or anything else for per-sample losses.
    """

    def __init__(self, gamma=2.0, weight=None, reduction="mean"):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight  # tensor of shape (num_classes,)
        self.reduction = reduction

    def forward(self, inputs, targets):
        # p_t must come from the *unweighted* cross-entropy: feeding the class
        # weight into cross_entropy first would make exp(-ce) equal
        # p_t ** weight, which distorts the (1 - p_t) ** gamma modulation.
        ce = F.cross_entropy(inputs, targets, reduction="none")
        pt = torch.exp(-ce)
        loss = ((1 - pt) ** self.gamma) * ce
        if self.weight is not None:
            per_class = self.weight.to(device=inputs.device, dtype=loss.dtype)
            loss = loss * per_class[targets]
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            return loss

# 2️⃣ Oversampling / Undersampling: Compute sample weights and create a WeightedRandomSampler.
# Get targets for the training subset.
# Class labels of the samples that landed in the training subset.
train_targets = np.array([full_dataset.targets[i] for i in train_dataset.indices])
# Distinct labels present and how many training samples each one has.
unique_classes, class_sample_count = np.unique(train_targets, return_counts=True)
# Inverse-frequency weight per class (indexed by position in unique_classes).
class_weights = 1. / class_sample_count
# Look up the weight of every training sample in one vectorised step.
samples_weight = torch.from_numpy(class_weights[train_targets]).float()
# Sampler drawing len(train) samples with replacement, proportional to weight.
from torch.utils.data.sampler import WeightedRandomSampler
oversample_sampler = WeightedRandomSampler(
    samples_weight,
    num_samples=len(samples_weight),
    replacement=True,
)
# Loader that uses the sampler instead of plain shuffling.
train_loader_balanced = DataLoader(
    train_dataset,
    batch_size=32,
    sampler=oversample_sampler,
    num_workers=4,
)

# 4️⃣ Logits Adjustment During Inference: Function to adjust logits using class priors.
def adjust_logits(logits, class_priors, eps=1e-12):
    """Subtract the log of the class priors from the logits.

    Args:
        logits: Tensor whose last dimension indexes classes.
        class_priors: Sequence/array/tensor of per-class prior probabilities.
        eps: Lower bound applied to the priors before the log so that a class
            with prior 0 does not yield an infinite adjustment.

    Returns:
        Tensor with the same dtype and device as ``logits``.
    """
    # Build the priors in the logits' dtype: a list of NumPy float64 scalars
    # would otherwise promote the result to float64.
    priors = torch.as_tensor(class_priors, dtype=logits.dtype, device=logits.device)
    log_priors = torch.log(priors.clamp_min(eps))
    return logits - log_priors

# ============================================================
# SimAM Attention Module
# ============================================================
class SimAM(nn.Module):
    """Parameter-free SimAM attention over the spatial dimensions.

    For every channel the per-position weight is
    ``sigmoid((x - mu)^2 / (4 * (var + e_lambda)) + 0.5)``, following the
    SimAM formulation. The previous gate ``1 / (var + e_lambda)`` was
    unbounded (up to 1/e_lambda for flat feature maps) and produced NaN for
    1x1 feature maps.

    Args:
        e_lambda: Regulariser added to the variance for numerical stability.
    """

    def __init__(self, e_lambda=1e-4):
        super(SimAM, self).__init__()
        self.e_lambda = e_lambda

    def forward(self, x):
        # x: [B, C, H, W]
        # Guard the degenerate 1x1 map, where H*W - 1 would be zero.
        n = max(x.size(2) * x.size(3) - 1, 1)
        mu = x.mean(dim=[2, 3], keepdim=True)
        sq_dev = (x - mu) ** 2
        var = sq_dev.sum(dim=[2, 3], keepdim=True) / n
        energy_inv = sq_dev / (4 * (var + self.e_lambda)) + 0.5
        # Sigmoid keeps the gate in (0, 1) so activations cannot blow up.
        return x * torch.sigmoid(energy_inv)

# ============================================================
# Multi-Scale Attention Module using Depthwise Convolutions + SimAM
# ============================================================
class MultiScaleAttention(nn.Module):
    """Three parallel depthwise convolutions (3x3, 5x5, 7x7) fused by a 1x1
    convolution, followed by ReLU and SimAM.

    Input and output both have ``channels`` channels; the paddings keep the
    spatial size unchanged.
    """

    def __init__(self, channels):
        super().__init__()
        # Depthwise branches, registered as conv3 / conv5 / conv7 in that order.
        for k in (3, 5, 7):
            branch = nn.Conv2d(channels, channels, kernel_size=k, padding=k // 2, groups=channels)
            setattr(self, f"conv{k}", branch)
        # 1x1 convolution that brings the concatenated branches back to `channels`.
        self.fuse_conv = nn.Conv2d(channels * 3, channels, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.simam = SimAM()

    def forward(self, x):
        branches = [conv(x) for conv in (self.conv3, self.conv5, self.conv7)]
        fused = self.fuse_conv(torch.cat(branches, dim=1))
        return self.simam(self.relu(fused))

# ============================================================
# Squeeze-and-Excitation (SE) Block for Gated Fusion
# ============================================================
class SEBlock(nn.Module):
    """Squeeze-and-Excitation gate: rescales each channel by a learned factor.

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck ratio; the hidden width is
            ``channels // reduction``, but never less than 1.
    """

    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        # channels // reduction is 0 when channels < reduction, which would
        # create an empty bottleneck; clamp to at least one hidden channel.
        hidden = max(1, channels // reduction)
        self.fc1 = nn.Conv2d(channels, hidden, 1)
        self.fc2 = nn.Conv2d(hidden, channels, 1)

    def forward(self, x):
        # Squeeze: global average over H, W -> [B, C, 1, 1].
        y = F.adaptive_avg_pool2d(x, 1)
        # Excite: bottleneck MLP (as 1x1 convs) ending in a (0, 1) gate.
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y))
        return x * y

# ============================================================
# Hybrid Model (ConvNeXt Base + EfficientNetV2-S) with Dynamic Feature Fusion, Self-Distillation & Contrastive Projection
# ============================================================
class HybridPlantClassifier(nn.Module):
    """Two-backbone classifier with gated feature fusion and auxiliary heads.

    The last feature map of a ConvNeXt-Base and of an EfficientNetV2-S
    backbone are each projected to ``fused_channels`` channels, concatenated,
    fused by a 1x1 convolution + SE gate, passed through multi-scale
    attention, globally average-pooled and classified.

    Args:
        num_classes: Size of the classification (and distillation) output.
    """

    def __init__(self, num_classes):
        super(HybridPlantClassifier, self).__init__()
        # Feature extractors. NOTE(review): only the ConvNeXt branch requests
        # pretrained weights; the EfficientNetV2-S branch is created with
        # pretrained=False and therefore starts from random initialisation —
        # confirm this is intended.
        self.convnext = timm.create_model('convnext_base', pretrained=True, features_only=True)
        self.efficientnet = timm.create_model('efficientnetv2_s', pretrained=False, features_only=True)
        self.fused_channels = 512

        # Project features to a common channel dimension.
        self.convnext_proj = nn.Conv2d(self.convnext.feature_info[-1]['num_chs'],
                                       self.fused_channels, kernel_size=1, bias=False)
        self.efficientnet_proj = nn.Conv2d(self.efficientnet.feature_info[-1]['num_chs'],
                                           self.fused_channels, kernel_size=1, bias=False)
        # Dynamic fusion by concatenation and a gated (SEBlock) 1x1 convolution.
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(self.fused_channels * 2, self.fused_channels, kernel_size=1, bias=False),
            SEBlock(self.fused_channels)
        )

        # Multi-Scale Attention.
        self.attention = MultiScaleAttention(self.fused_channels)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(self.fused_channels, num_classes)
        # Contrastive projection head (128-d embedding).
        self.contrastive_head = nn.Sequential(
            nn.Linear(self.fused_channels, self.fused_channels),
            nn.ReLU(inplace=True),
            nn.Linear(self.fused_channels, 128)
        )
        # Self-distillation branch (auxiliary head).
        self.distill_head = nn.Sequential(
            nn.Linear(self.fused_channels, self.fused_channels),
            nn.ReLU(inplace=True),
            nn.Linear(self.fused_channels, num_classes)
        )

    def forward(self, x, return_features=False):
        """Run the model.

        Args:
            x: Image batch fed to both backbones.
            return_features: When True, also return the L2-normalised
                contrastive embedding and the distillation-head logits.

        Returns:
            ``logits`` or, if ``return_features`` is True,
            ``(logits, embedding, distill_logits)``.
        """
        feat_convnext = self.convnext_proj(self.convnext(x)[-1])
        feat_effnet   = self.efficientnet_proj(self.efficientnet(x)[-1])
        # The two backbones may round spatial sizes differently for inputs
        # that are not a multiple of their stride; align before concatenation
        # so torch.cat does not fail. This is a no-op when the sizes agree.
        if feat_effnet.shape[-2:] != feat_convnext.shape[-2:]:
            feat_effnet = F.interpolate(feat_effnet, size=feat_convnext.shape[-2:],
                                        mode="bilinear", align_corners=False)
        # Concatenate and fuse features.
        fused = torch.cat([feat_convnext, feat_effnet], dim=1)
        fused = self.fusion_conv(fused)
        fused = self.attention(fused)
        pooled = self.avgpool(fused).view(fused.size(0), -1)
        logits = self.classifier(pooled)
        if return_features:
            # Generate contrastive embeddings (unit length along dim 1).
            embedding = self.contrastive_head(pooled)
            embedding = F.normalize(embedding, dim=1)
            # Obtain distillation logits.
            distill_logits = self.distill_head(pooled)
            return logits, embedding, distill_logits
        return logits

# ============================================================
# MixUp and CutMix Functions (with tunable alpha)
# ============================================================
def mixup_data(x, y, alpha=1.0):
    """Blend each sample with a randomly chosen partner from the same batch.

    Args:
        x: Input batch; the first dimension is the batch dimension.
        y: Labels aligned with ``x``.
        alpha: Beta(alpha, alpha) parameter for the mixing coefficient;
            ``alpha <= 0`` disables mixing (coefficient fixed at 1.0).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a`` are the original labels and ``y_b`` the partners' labels.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    perm = torch.randperm(x.size(0)).to(x.device)
    partner = x[perm]
    blended = x * lam + partner * (1 - lam)
    return blended, y, y[perm], lam

def cutmix_data(x, y, alpha=1.0):
    """Paste a random rectangle from a shuffled copy of the batch into each image.

    Unlike the previous version this does NOT modify ``x`` in place: callers
    that reuse the clean batch afterwards (e.g. for a second forward pass)
    keep their original images.

    Args:
        x: Image batch of shape [B, C, H, W].
        y: Labels aligned with ``x``.
        alpha: Beta(alpha, alpha) parameter for the initial mixing ratio;
            ``alpha <= 0`` yields an empty patch (no mixing).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``lam`` is the fraction of each
        image that still comes from the original sample, recomputed from the
        clipped patch area.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.0
    batch_size, _, H, W = x.size()
    index = torch.randperm(batch_size).to(x.device)
    # Patch centre, then side lengths so that the patch covers ~(1 - lam) of the area.
    r_x = np.random.randint(W)
    r_y = np.random.randint(H)
    r_w = int(W * np.sqrt(1 - lam))
    r_h = int(H * np.sqrt(1 - lam))
    x1 = int(np.clip(r_x - r_w // 2, 0, W))
    y1 = int(np.clip(r_y - r_h // 2, 0, H))
    x2 = int(np.clip(r_x + r_w // 2, 0, W))
    y2 = int(np.clip(r_y + r_h // 2, 0, H))
    # Work on a copy so the caller's tensor is left untouched.
    mixed_x = x.clone()
    mixed_x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]
    # Clipping at the border can shrink the patch, so derive lam from the real area.
    lam = 1.0 - ((x2 - x1) * (y2 - y1) / (W * H))
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the ``lam``-weighted combination of the loss against both label sets.

    Computes ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

# ============================================================
# Supervised Contrastive Loss (SupConLoss)
# ============================================================
class SupConLoss(nn.Module):
    """
    Supervised Contrastive Loss as described in "Supervised Contrastive Learning" (Khosla et al.)
    Expects features of shape [batch_size, n_views, feature_dim]. For a single view, unsqueeze dimension 1.

    Args:
        temperature: Divisor applied to the pairwise similarities.
    """
    def __init__(self, temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """Compute the loss.

        Args:
            features: Tensor [batch_size, n_views, feature_dim]; the dot
                product between rows is used as the similarity.
            labels: Tensor with one class label per sample (batch_size entries).

        Returns:
            Scalar tensor: mean over all anchors of the negative average
            log-probability of their positives. Anchors without any positive
            contribute 0 to the mean.
        """
        device = features.device
        batch_size = features.shape[0]
        n_views = features.shape[1]  # For a single view, n_views = 1
        total = batch_size * n_views
        # Flattening is sample-major: [s0v0, s0v1, ..., s1v0, s1v1, ...].
        features = features.reshape(total, -1)
        # Labels must follow the same sample-major order, i.e. each label is
        # repeated n_views times consecutively. `labels.repeat(n_views)` would
        # tile the whole label vector instead and pair rows with wrong labels
        # whenever n_views > 1.
        labels = labels.reshape(-1).repeat_interleave(n_views)
        similarity_matrix = torch.div(torch.matmul(features, features.T), self.temperature)
        # Subtract the row max for numerical stability (no gradient through it).
        logits_max, _ = torch.max(similarity_matrix, dim=1, keepdim=True)
        logits = similarity_matrix - logits_max.detach()
        mask = torch.eq(labels.unsqueeze(1), labels.unsqueeze(0)).float().to(device)
        # Remove self-comparisons: zeros on the diagonal, ones elsewhere.
        logits_mask = torch.ones_like(mask) - torch.eye(total, device=device, dtype=mask.dtype)
        mask = mask * logits_mask
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)
        mask_sum = mask.sum(1)
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask_sum + 1e-12)
        loss = -mean_log_prob_pos.mean()
        return loss

# ============================================================
# Test-Time Augmentation (TTA) Function
# ============================================================
def predict_with_tta(model, image, tta_transforms, device):
    """Average softmax predictions of ``model`` over several transformed views.

    Args:
        model: Callable module returning logits of shape [1, num_classes]
            for a single-image batch.
        image: Raw image passed to every transform as ``transform(image=image)``.
        tta_transforms: Iterable of callables returning a mapping whose
            ``'image'`` entry is a tensor without a batch dimension.
        device: Device the transformed tensor is moved to before the forward pass.

    Returns:
        NumPy array: element-wise mean of the per-view softmax outputs.
    """
    model.eval()

    def _probs_for(view_fn):
        # One view: transform -> add batch dim -> forward -> class probabilities.
        batch = view_fn(image=image)['image'].unsqueeze(0).to(device)
        return torch.softmax(model(batch), dim=1).cpu().numpy()

    with torch.no_grad():
        per_view = [_probs_for(view_fn) for view_fn in tta_transforms]
    return np.mean(per_view, axis=0)

# Test-time views. Each pipeline applies the same normalisation used for
# training/validation before ToTensorV2; without it the model would receive
# raw, un-normalised uint8 tensors that do not match its training inputs.
tta_transforms = [
    A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]),
    A.Compose([
        A.Resize(224, 224),
        A.HorizontalFlip(p=1.0),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]),
    # Add more TTA variations as desired.
]

# ============================================================
# Training Function with Contrastive Loss, Self-Distillation & Dynamic LR
# ============================================================
def train_model(config, train_loader, val_loader, model=None):
    """Train a HybridPlantClassifier with MixUp/CutMix, SupCon and self-distillation.

    Args:
        config: Dict with required keys ``lr``, ``weight_decay``, ``epochs`` and
            optional keys ``loss_type`` ("ce" | "focal" | "weighted_ce"),
            ``contrastive_weight``, ``distill_weight``, ``mixup_alpha``,
            ``cutmix_alpha``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        model: Optional already-built model to continue training; when omitted
            a new HybridPlantClassifier is created.

    Returns:
        The trained model (on the device used for training).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = len(full_dataset.classes)
    if model is None:
        model = HybridPlantClassifier(num_classes=num_classes)
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
    scheduler = CosineAnnealingLR(optimizer, T_max=config['epochs'])

    # Choose loss function based on config.
    # Options: "ce", "focal", "weighted_ce"
    loss_type = config.get("loss_type", "ce")
    if loss_type in ("focal", "weighted_ce"):
        class_weights_tensor = torch.as_tensor(class_weights, dtype=torch.float32, device=device)
        # Inverse-frequency weights are tiny (1 / count). Rescale them to mean 1
        # so a weighted loss stays on the same scale as the contrastive and
        # distillation terms it is summed with.
        class_weights_tensor = class_weights_tensor / class_weights_tensor.mean()
        if loss_type == "focal":
            criterion = FocalLoss(gamma=2.0, weight=class_weights_tensor, reduction="mean")
        else:
            criterion = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=0.1)
    else:
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    supcon_loss_fn = SupConLoss(temperature=0.07)
    contrastive_weight = config.get("contrastive_weight", 1.0)
    distill_weight = config.get("distill_weight", 0.5)
    mixup_alpha = config.get("mixup_alpha", 1.0)
    cutmix_alpha = config.get("cutmix_alpha", 1.0)
    num_epochs = config['epochs']

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        train_preds = []
        train_labels = []
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            # --- Classification Branch with MixUp/CutMix ---
            # MixUp and CutMix are chosen with equal probability per batch.
            if np.random.rand() < 0.5:
                mixed_images, y_a, y_b, lam = mixup_data(images, labels, alpha=mixup_alpha)
            else:
                # Pass a copy: the clean `images` are needed again below and
                # must not be altered by the patch pasting.
                mixed_images, y_a, y_b, lam = cutmix_data(images.clone(), labels, alpha=cutmix_alpha)
            outputs = model(mixed_images)
            loss_ce = mixup_criterion(criterion, outputs, y_a, y_b, lam)

            # --- Contrastive Branch and Self-Distillation using Original Images ---
            clean_logits, embeddings, distill_logits = model(images, return_features=True)
            embeddings = embeddings.unsqueeze(1)  # Shape: [batch, 1, feature_dim]
            loss_contrast = supcon_loss_fn(embeddings, labels)
            # The teacher distribution comes from the main classifier on the
            # SAME clean images the distillation head saw, and is detached so
            # that only the auxiliary head is pulled towards it.
            teacher_probs = F.softmax(clean_logits.detach(), dim=1)
            loss_distill = F.kl_div(F.log_softmax(distill_logits, dim=1),
                                    teacher_probs, reduction='batchmean')

            total_loss = loss_ce + contrastive_weight * loss_contrast + distill_weight * loss_distill
            total_loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            train_loss += total_loss.item() * batch_size
            preds = outputs.argmax(dim=1)
            # Using proxy ground truth from mixup/cutmix: predictions on mixed
            # images are scored against y_a only, so train accuracy/F1 are
            # rough indicators rather than exact metrics.
            train_correct += (preds == y_a).sum().item()
            train_total += batch_size
            train_preds.extend(preds.cpu().numpy())
            train_labels.extend(y_a.cpu().numpy())
            pbar.set_postfix(loss=f"{total_loss.item():.4f}")

        # One scheduler step per epoch (T_max is the number of epochs).
        scheduler.step()
        avg_train_loss = train_loss / train_total
        train_acc = train_correct / train_total
        train_f1 = f1_score(train_labels, train_preds, average="weighted")
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}, Train F1: {train_f1:.4f}")

        # --- Validation Phase (without MixUp/CutMix) ---
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_preds = []
        val_labels = []
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                batch_size = labels.size(0)
                val_loss += loss.item() * batch_size
                preds = outputs.argmax(dim=1)
                val_correct += (preds == labels).sum().item()
                val_total += batch_size
                val_preds.extend(preds.cpu().numpy())
                val_labels.extend(labels.cpu().numpy())
        avg_val_loss = val_loss / val_total
        val_acc = val_correct / val_total
        val_f1 = f1_score(val_labels, val_preds, average="weighted")
        print(f"  Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}\n")
    return model

# ============================================================
# Evaluation Function (Test Set) with Optional Logits Adjustment
# ============================================================
def evaluate_model(model, loader, adjust_logits_flag=False, class_priors=None, return_acc=False):
    """Evaluate ``model`` on ``loader``; optionally plot a confusion matrix and ROC curve.

    Args:
        model: Module returning logits of shape [batch, num_classes].
        loader: Iterable of ``(images, labels)`` batches.
        adjust_logits_flag: When True and ``class_priors`` is given, logits are
            passed through ``adjust_logits`` before softmax/argmax.
        class_priors: Per-class prior probabilities used for the adjustment.
        return_acc: When True, return the accuracy and skip printing/plotting.

    Returns:
        Accuracy if ``return_acc`` is True (NaN for an empty loader), else None.
    """
    # Run on whatever device the model already lives on, so a CPU model is not
    # fed CUDA tensors just because a GPU happens to be available.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            if adjust_logits_flag and (class_priors is not None):
                outputs = adjust_logits(outputs, class_priors)
            probs = torch.softmax(outputs, dim=1).cpu().numpy()
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs)
    if all_labels:
        acc = (np.array(all_preds) == np.array(all_labels)).mean()
    else:
        # Nothing was evaluated; report NaN explicitly rather than relying on
        # the mean of an empty array.
        acc = float("nan")
    if return_acc:
        return acc
    print(f"\nTest Accuracy: {acc:.4f}\n")
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix")
    plt.show()
    # Micro-averaged ROC: flatten the one-hot labels and the probability
    # matrix into matching 1-D arrays. The one-hot matrix is built by hand so
    # it always has one column per class (a binarizer that collapses the
    # two-class case to a single column would not line up with the scores).
    probs_matrix = np.array(all_probs)
    num_classes = probs_matrix.shape[1]
    all_labels_bin = np.eye(num_classes, dtype=int)[np.array(all_labels)]
    fpr, tpr, _ = roc_curve(all_labels_bin.ravel(), probs_matrix.ravel())
    roc_auc = auc(fpr, tpr)
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, label=f'AUC = {roc_auc:.4f}')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.legend()
    plt.show()


def objective(trial):
    """Optuna objective: train for a short schedule and return validation accuracy.

    Samples learning rate, weight decay, loss weights, MixUp/CutMix alphas and
    the loss type from ``trial``, trains with ``train_model`` on the standard
    (unbalanced) loader, and scores the result with ``evaluate_model``.
    """
    config = {
        "lr": trial.suggest_float("lr", 1e-5, 1e-3, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True),
        "epochs": 6,
        "contrastive_weight": trial.suggest_float("contrastive_weight", 0.5, 2.0),
        "distill_weight": trial.suggest_float("distill_weight", 0.1, 1.0),
        "mixup_alpha": trial.suggest_float("mixup_alpha", 0.5, 2.0),
        "cutmix_alpha": trial.suggest_float("cutmix_alpha", 0.5, 2.0),
        "loss_type": trial.suggest_categorical("loss_type", ["ce", "focal", "weighted_ce"])
    }
    # For tuning, we use the standard train_loader.
    model = train_model(config, train_loader, val_loader)
    val_acc = float(evaluate_model(model, val_loader, return_acc=True))
    # Keep the value visible as an intermediate result. No prune check follows:
    # training has already finished at this point, so pruning could not save
    # any compute and would only throw away a completed measurement.
    trial.report(val_acc, step=0)
    # Drop this trial's model before the next trial builds its own, so GPU
    # memory does not accumulate across trials.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return val_acc

# Run the hyperparameter search: maximise the value returned by `objective`
# over 8 trials with a TPE sampler, then print the best parameter set.
# NOTE(review): the sampler is created without a seed, so the suggested
# values will differ between runs.
study = optuna.create_study(direction="maximize", sampler=TPESampler(), pruner=HyperbandPruner())
study.optimize(objective, n_trials=8)
print("Best Hyperparameters:", study.best_params)


# Use the best hyperparameters from Optuna and increase epochs.
# best_params only holds the values suggested through `trial`; "epochs" was a
# fixed entry in the tuning config, so it is added here explicitly.
best_config = study.best_params.copy()
best_config["epochs"] = 30  # Increase epochs as desired.

# ----- Two-Stage Training -----
# Stage 1: Train on balanced (oversampled) data.
print("Stage 1 Training on Balanced Data (Oversampled):")
model_stage1 = train_model(best_config, train_loader_balanced, val_loader)

# Stage 2: Fine-tune on the original (imbalanced) distribution.
# NOTE(review): model_stage1 is not handed to this call, and train_model
# constructs its own HybridPlantClassifier, so stage 2 starts from freshly
# initialised weights rather than continuing from stage 1. As written, the
# stage-1 result is unused afterwards.
print("Stage 2 Fine-Tuning on Original Distribution:")
model_final = train_model(best_config, train_loader, val_loader)

# Compute class priors from training targets for logits adjustment:
# the fraction of training samples carrying each label 0..K-1.
total_samples = len(train_targets)
class_priors = [np.sum(train_targets == i) / total_samples for i in range(len(unique_classes))]
print("Class Priors:", class_priors)

# Final evaluation on the held-out test split with prior-adjusted logits.
evaluate_model(model_final, test_loader, adjust_logits_flag=True, class_priors=class_priors)


[I 2025-02-27 04:45:00,052] A new study created in memory with name: no-name-12a5740f-1d32-4576-b03e-9196fe126cd6
Epoch 1/6:   1%|          | 1/131 [01:39<3:34:44, 99.11s/batch, loss=8.9971]

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # For normalization
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Subset
import torchvision.datasets as datasets
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sklearn.metrics import f1_score, roc_curve, auc, confusion_matrix
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import StratifiedShuffleSplit
import optuna
from optuna.samplers import TPESampler
from optuna.pruners import HyperbandPruner

# ---------------------------------
# Dataset Setup and Augmentation
# ---------------------------------
# Root folder handed to ImageFolder below; fail fast with a clear message
# if the path does not exist in this environment.
train_dir = "/kaggle/input/asccsasca/Medicinal plant dataset"
if not os.path.exists(train_dir):
    raise FileNotFoundError(f"Dataset not found: {train_dir}")

def stratified_split_dataset(dataset, val_ratio=0.2, test_ratio=0.1):
    """
    Split a dataset into train, validation, and test subsets using stratified sampling.

    Args:
        dataset: Object exposing a ``targets`` sequence of class labels, one
            per sample.
        val_ratio: Fraction of the *whole* dataset to place in the validation split.
        test_ratio: Fraction of the *whole* dataset to place in the test split.

    Returns:
        Tuple ``(train_subset, val_subset, test_subset)`` of ``Subset`` objects
        that all wrap ``dataset`` and index disjoint samples.

    Raises:
        ValueError: If the ratios are not in (0, 1) or leave no room for a
            training split.
    """
    # Validate up front so a bad ratio is reported clearly instead of
    # surfacing as an error from inside the second split.
    if not (0 < test_ratio < 1) or not (0 < val_ratio < 1):
        raise ValueError("val_ratio and test_ratio must both lie in (0, 1)")
    if val_ratio + test_ratio >= 1:
        raise ValueError("val_ratio + test_ratio must be < 1 to leave training data")

    targets = np.array(dataset.targets)
    indices = np.arange(len(targets))

    # First, split off the test set (n_splits=1, so take the single split).
    sss_test = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=42)
    train_val_idx, test_idx = next(sss_test.split(indices, targets))

    # Now split train_val into train and validation sets. The validation
    # ratio is re-expressed relative to the samples that remain.
    train_val_targets = targets[train_val_idx]
    val_relative_ratio = val_ratio / (1 - test_ratio)
    sss_val = StratifiedShuffleSplit(n_splits=1, test_size=val_relative_ratio, random_state=42)
    train_pos, val_pos = next(
        sss_val.split(np.arange(len(train_val_targets)), train_val_targets)
    )

    # The second split yields positions inside train_val_idx; map them back
    # to indices into the original dataset.
    train_indices = train_val_idx[train_pos]
    val_indices = train_val_idx[val_pos]

    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)
    test_subset = Subset(dataset, test_idx)
    return train_subset, val_subset, test_subset

# Training pipeline: resize to 224x224, random 90-degree rotation, horizontal
# flip, brightness/contrast jitter and coarse dropout (each applied with
# p=0.5), then normalisation with the mean/std below and conversion to a tensor.
train_transform = A.Compose([
    A.Resize(224, 224),
    A.RandomRotate90(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.CoarseDropout(max_holes=8, max_height=8, max_width=8, fill_value=0, p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), 
                std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])
# Validation/test pipeline: deterministic resize + the same normalisation,
# with no random augmentation.
val_transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=(0.485, 0.456, 0.406), 
                std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

class AlbumentationsTransform:
    """Adapter that lets a keyword-style transform be used as a plain callable.

    The wrapped ``transform`` is invoked as ``transform(image=<ndarray>)`` and
    must return a mapping with an ``'image'`` entry; that entry is returned.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, image):
        # Convert whatever was passed in to a NumPy array before handing it on.
        pixels = np.array(image)
        augmented = self.transform(image=pixels)
        return augmented['image']

# Create the full dataset using the training transform.
full_dataset = datasets.ImageFolder(root=train_dir, transform=AlbumentationsTransform(train_transform))
# Perform stratified splitting.
train_dataset, val_dataset, test_dataset = stratified_split_dataset(full_dataset, val_ratio=0.2, test_ratio=0.1)

# All three Subsets returned above wrap the *same* ImageFolder object, so
# assigning `subset.dataset.transform` changes the transform for every split
# (the last assignment wins, which silently disabled training augmentation).
# Give validation/test their own ImageFolder over the same root instead;
# ImageFolder lists files in sorted order, so the indices refer to the same
# samples in both dataset objects.
eval_dataset = datasets.ImageFolder(root=train_dir, transform=AlbumentationsTransform(val_transform))
val_dataset = Subset(eval_dataset, val_dataset.indices)
test_dataset = Subset(eval_dataset, test_dataset.indices)

# Standard loaders (using train_loader without any balancing sampler)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# ============================================================
# Additional Modules (Attention, Fusion, etc.)
# ============================================================

# 1️⃣ Custom Focal Loss (kept here in case you want to try it)
class FocalLoss(nn.Module):
    """Unweighted multi-class focal loss.

    Per sample: ``(1 - p_t) ** gamma * (-log p_t)`` with ``p_t`` the softmax
    probability of the true class. ``reduction`` is ``"mean"``, ``"sum"``, or
    anything else to get the per-sample losses back.
    """

    def __init__(self, gamma=2.0, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Per-sample cross-entropy is -log p_t, so p_t = exp(-ce).
        ce = F.cross_entropy(inputs, targets, reduction="none")
        true_class_prob = torch.exp(-ce)
        focal = ((1 - true_class_prob) ** self.gamma) * ce
        if self.reduction == "mean":
            return focal.mean()
        if self.reduction == "sum":
            return focal.sum()
        return focal

# 3️⃣ Logits Adjustment During Inference: Function to adjust logits using class priors.
def adjust_logits(logits, class_priors, eps=1e-12):
    """Subtract the log of the class priors from the logits.

    Args:
        logits: Tensor whose last dimension indexes classes.
        class_priors: Sequence/array/tensor of per-class prior probabilities.
        eps: Lower bound applied to the priors before the log so that a class
            with prior 0 does not yield an infinite adjustment.

    Returns:
        Tensor with the same dtype and device as ``logits``.
    """
    # Build the priors in the logits' dtype: a list of NumPy float64 scalars
    # would otherwise promote the result to float64.
    priors = torch.as_tensor(class_priors, dtype=logits.dtype, device=logits.device)
    log_priors = torch.log(priors.clamp_min(eps))
    return logits - log_priors

# ============================================================
# SimAM Attention Module
# ============================================================
class SimAM(nn.Module):
    """Parameter-free SimAM attention over the spatial dimensions.

    For every channel the per-position weight is
    ``sigmoid((x - mu)^2 / (4 * (var + e_lambda)) + 0.5)``, following the
    SimAM formulation. The previous gate ``1 / (var + e_lambda)`` was
    unbounded (up to 1/e_lambda for flat feature maps) and produced NaN for
    1x1 feature maps.

    Args:
        e_lambda: Regulariser added to the variance for numerical stability.
    """

    def __init__(self, e_lambda=1e-4):
        super(SimAM, self).__init__()
        self.e_lambda = e_lambda

    def forward(self, x):
        # x: [B, C, H, W]
        # Guard the degenerate 1x1 map, where H*W - 1 would be zero.
        n = max(x.size(2) * x.size(3) - 1, 1)
        mu = x.mean(dim=[2, 3], keepdim=True)
        sq_dev = (x - mu) ** 2
        var = sq_dev.sum(dim=[2, 3], keepdim=True) / n
        energy_inv = sq_dev / (4 * (var + self.e_lambda)) + 0.5
        # Sigmoid keeps the gate in (0, 1) so activations cannot blow up.
        return x * torch.sigmoid(energy_inv)

# ============================================================
# Multi-Scale Attention Module using Depthwise Convolutions + SimAM
# ============================================================
class MultiScaleAttention(nn.Module):
    """Three parallel depthwise convolutions (3x3, 5x5, 7x7) fused by a 1x1
    convolution, followed by ReLU and SimAM.

    Input and output both have ``channels`` channels; the paddings keep the
    spatial size unchanged.
    """

    def __init__(self, channels):
        super().__init__()
        # Depthwise branches, registered as conv3 / conv5 / conv7 in that order.
        for k in (3, 5, 7):
            branch = nn.Conv2d(channels, channels, kernel_size=k, padding=k // 2, groups=channels)
            setattr(self, f"conv{k}", branch)
        # 1x1 convolution that brings the concatenated branches back to `channels`.
        self.fuse_conv = nn.Conv2d(channels * 3, channels, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.simam = SimAM()

    def forward(self, x):
        branches = [conv(x) for conv in (self.conv3, self.conv5, self.conv7)]
        fused = self.fuse_conv(torch.cat(branches, dim=1))
        return self.simam(self.relu(fused))

# ============================================================
# Squeeze-and-Excitation (SE) Block for Gated Fusion
# ============================================================
class SEBlock(nn.Module):
    """Squeeze-and-Excitation gate: rescales each channel by a learned factor.

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck ratio; the hidden width is
            ``channels // reduction``, but never less than 1.
    """

    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        # channels // reduction is 0 when channels < reduction, which would
        # create an empty bottleneck; clamp to at least one hidden channel.
        hidden = max(1, channels // reduction)
        self.fc1 = nn.Conv2d(channels, hidden, 1)
        self.fc2 = nn.Conv2d(hidden, channels, 1)

    def forward(self, x):
        # Squeeze: global average over H, W -> [B, C, 1, 1].
        y = F.adaptive_avg_pool2d(x, 1)
        # Excite: bottleneck MLP (as 1x1 convs) ending in a (0, 1) gate.
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y))
        return x * y

# ============================================================
# Hybrid Model (ConvNeXt Base + EfficientNetV2-S) with Dynamic Feature Fusion, Self-Distillation & Contrastive Projection
# ============================================================
class HybridPlantClassifier(nn.Module):
    """Two-backbone classifier with gated fusion and auxiliary heads.

    The last feature map of a ConvNeXt-Base and of an EfficientNetV2-S are
    each projected to ``fused_channels`` (512) channels, concatenated, reduced
    by a 1x1 convolution followed by an ``SEBlock`` gate, refined by
    ``MultiScaleAttention`` and globally average-pooled. The pooled vector
    feeds three heads: the main classifier, a contrastive projection head
    (128-d) and an auxiliary "distillation" classifier.

    Args:
        num_classes: Number of output classes for both classifier heads.
    """

    def __init__(self, num_classes):
        super(HybridPlantClassifier, self).__init__()
        # Feature extractors. Only ConvNeXt is created with pretrained
        # weights; EfficientNetV2-S is created with ``pretrained=False``.
        self.convnext = timm.create_model('convnext_base', pretrained=True, features_only=True)
        self.efficientnet = timm.create_model('efficientnetv2_s', pretrained=False, features_only=True)
        self.fused_channels = 512

        # Project features to a common channel dimension.
        self.convnext_proj = nn.Conv2d(self.convnext.feature_info[-1]['num_chs'],
                                       self.fused_channels, kernel_size=1, bias=False)
        self.efficientnet_proj = nn.Conv2d(self.efficientnet.feature_info[-1]['num_chs'],
                                           self.fused_channels, kernel_size=1, bias=False)
        # Dynamic fusion by concatenation and a gated (SEBlock) 1x1 convolution.
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(self.fused_channels * 2, self.fused_channels, kernel_size=1, bias=False),
            SEBlock(self.fused_channels)
        )

        # Multi-Scale Attention.
        self.attention = MultiScaleAttention(self.fused_channels)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(self.fused_channels, num_classes)
        # Contrastive projection head.
        self.contrastive_head = nn.Sequential(
            nn.Linear(self.fused_channels, self.fused_channels),
            nn.ReLU(inplace=True),
            nn.Linear(self.fused_channels, 128)
        )
        # Self-distillation branch (auxiliary head).
        self.distill_head = nn.Sequential(
            nn.Linear(self.fused_channels, self.fused_channels),
            nn.ReLU(inplace=True),
            nn.Linear(self.fused_channels, num_classes)
        )

    def forward(self, x, return_features=False):
        """Run the model on an image batch.

        Args:
            x: Image batch fed to both backbones.
            return_features: When True, also return the auxiliary outputs.

        Returns:
            ``logits`` when ``return_features`` is False; otherwise the tuple
            ``(logits, embedding, distill_logits)`` where ``embedding`` is the
            L2-normalised (along dim 1) output of the contrastive head.
        """
        feat_convnext = self.convnext_proj(self.convnext(x)[-1])
        feat_effnet = self.efficientnet_proj(self.efficientnet(x)[-1])
        # The two backbones are independent networks, so their final maps are
        # not guaranteed to share a spatial size for every input resolution.
        # Align to the ConvNeXt grid so the concatenation below cannot fail;
        # this is a no-op whenever the sizes already agree.
        if feat_effnet.shape[-2:] != feat_convnext.shape[-2:]:
            feat_effnet = F.interpolate(feat_effnet, size=feat_convnext.shape[-2:],
                                        mode='bilinear', align_corners=False)
        # Concatenate and fuse features.
        fused = torch.cat([feat_convnext, feat_effnet], dim=1)
        fused = self.fusion_conv(fused)
        fused = self.attention(fused)
        pooled = torch.flatten(self.avgpool(fused), 1)
        logits = self.classifier(pooled)
        if not return_features:
            return logits
        # Unit-length embedding for the contrastive loss.
        embedding = F.normalize(self.contrastive_head(pooled), dim=1)
        # Logits of the auxiliary (distillation) classifier.
        distill_logits = self.distill_head(pooled)
        return logits, embedding, distill_logits

# ============================================================
# MixUp and CutMix Functions (with tunable alpha)
# ============================================================
def mixup_data(x, y, alpha=1.0):
    """Blend every sample of a batch with a randomly chosen partner (MixUp).

    Args:
        x: Input batch; the first dimension is the batch dimension.
        y: Labels, indexable by a permutation of the batch.
        alpha: Beta(alpha, alpha) parameter for the mixing weight. With
            ``alpha <= 0`` the weight is fixed at 1.0 (no mixing).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x = lam * x + (1 - lam) *
        x[perm]``, ``y_a`` is ``y`` and ``y_b`` is ``y[perm]``.
    """
    # Draw order (Beta first, permutation second) matters for seeded runs.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    perm = torch.randperm(x.size(0)).to(x.device)
    partner = x[perm, :]
    mixed_x = lam * x + (1 - lam) * partner
    return mixed_x, y, y[perm], lam

def cutmix_data(x, y, alpha=1.0):
    """Paste a random rectangle from a shuffled copy of the batch (CutMix).

    A mixing weight is drawn from Beta(alpha, alpha), a box whose sides are
    ``sqrt(1 - lam)`` times the image sides is centred at a uniformly random
    pixel and clipped to the image, and that region of every sample is
    replaced by the same region of its randomly chosen partner.

    The input batch is NOT modified: the paste happens on a clone, so callers
    can keep using ``x`` as the clean batch afterwards.

    Args:
        x: Image batch of shape ``(batch, C, H, W)``.
        y: Labels, indexable by a permutation of the batch.
        alpha: Beta(alpha, alpha) parameter. With ``alpha <= 0`` the weight is
            1.0, the box is empty and the returned batch equals ``x``.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``y_a`` is ``y``, ``y_b`` is the
        partner labels and ``lam`` (a Python float) is the fraction of each
        image that was kept, recomputed from the clipped box area.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.0
    batch_size, _, H, W = x.size()
    index = torch.randperm(batch_size).to(x.device)

    # Box centre and nominal size.
    r_x = np.random.randint(W)
    r_y = np.random.randint(H)
    r_w = int(W * np.sqrt(1 - lam))
    r_h = int(H * np.sqrt(1 - lam))

    # Clip the box to the image; plain ints keep slicing and the area
    # arithmetic below free of NumPy scalar types.
    x1 = int(np.clip(r_x - r_w // 2, 0, W))
    y1 = int(np.clip(r_y - r_h // 2, 0, H))
    x2 = int(np.clip(r_x + r_w // 2, 0, W))
    y2 = int(np.clip(r_y + r_h // 2, 0, H))

    # Work on a copy so the caller's batch stays untouched.
    mixed_x = x.clone()
    mixed_x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]

    # Clipping may have shrunk the box, so derive lam from the real area.
    lam = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
    return mixed_x, y, y[index], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the ``lam``-weighted blend of the loss against both label sets.

    Args:
        criterion: Callable ``criterion(pred, target)`` returning a loss.
        pred: Predictions for the mixed batch.
        y_a: Targets weighted by ``lam``.
        y_b: Targets weighted by ``1 - lam``.
        lam: Mixing weight.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

# ============================================================
# Supervised Contrastive Loss (SupConLoss)
# ============================================================
class SupConLoss(nn.Module):
    """
    Supervised Contrastive Loss as described in "Supervised Contrastive Learning" (Khosla et al.)
    Expects features of shape [batch_size, n_views, feature_dim]. For a single view, unsqueeze dimension 1.

    Every (sample, view) pair acts as an anchor; all other rows that share its
    label are positives and every other row is a negative. Anchors without any
    positive contribute zero to the batch mean.

    Args:
        temperature: Divisor applied to the pairwise dot products.
    """
    def __init__(self, temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """Compute the loss.

        Args:
            features: Tensor of shape ``[batch_size, n_views, ...]``; trailing
                dimensions are flattened into the feature dimension. The loss
                uses raw dot products, so callers normally pass unit-length
                embeddings.
            labels: One class label per sample (``batch_size`` values).

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If ``features`` has fewer than 3 dimensions or the
                number of labels differs from ``batch_size``.
        """
        if features.dim() < 3:
            raise ValueError("features must have shape [batch_size, n_views, ...]; "
                             "unsqueeze dimension 1 for a single view")
        device = features.device
        batch_size = features.shape[0]
        n_views = features.shape[1]  # For a single view, n_views = 1
        labels = labels.reshape(-1)
        if labels.numel() != batch_size:
            raise ValueError("number of labels does not match the batch size of features")
        n_samples = batch_size * n_views

        # Flattening orders rows as (b0,v0), (b0,v1), ..., (b1,v0), ...;
        # repeat_interleave produces labels in exactly that order. (A plain
        # ``repeat`` would tile the whole label vector and mislabel rows
        # whenever n_views > 1.)
        features = features.reshape(n_samples, -1)
        labels = labels.repeat_interleave(n_views)

        similarity_matrix = torch.div(torch.matmul(features, features.T), self.temperature)
        # Subtract the row maximum for numerical stability; it cancels in the
        # log-softmax below, so it is detached from the graph.
        logits_max, _ = torch.max(similarity_matrix, dim=1, keepdim=True)
        logits = similarity_matrix - logits_max.detach()

        mask = torch.eq(labels.unsqueeze(1), labels.unsqueeze(0)).float().to(device)
        # Remove self-comparisons from the mask (zero diagonal).
        logits_mask = 1.0 - torch.eye(n_samples, device=device)
        mask = mask * logits_mask

        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)
        # Mean log-likelihood over each anchor's positives; the epsilon keeps
        # anchors with no positive at zero instead of dividing by zero.
        mask_sum = mask.sum(1)
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask_sum + 1e-12)
        loss = -mean_log_prob_pos.mean()
        return loss

# ============================================================
# Test-Time Augmentation (TTA) Function
# ============================================================
def predict_with_tta(model, image, tta_transforms, device):
    """Average softmax predictions over several augmented views of one image.

    The model is switched to eval mode (and left there) and gradients are
    disabled for the forward passes.

    Args:
        model: Network returning class logits for a batch.
        image: Single image, passed to every transform as ``image=``.
        tta_transforms: Iterable of callables whose result holds the
            transformed tensor under the ``'image'`` key.
        device: Device the augmented batch is moved to before the forward pass.

    Returns:
        NumPy array with the element-wise mean of the per-view probabilities;
        the leading batch dimension of size 1 is kept.
    """
    model.eval()

    def view_probabilities(transform):
        # One augmented view -> batch of one -> softmax probabilities on CPU.
        batch = transform(image=image)['image'].unsqueeze(0).to(device)
        return torch.softmax(model(batch), dim=1).cpu().numpy()

    with torch.no_grad():
        per_view = [view_probabilities(transform) for transform in tta_transforms]
    return np.mean(per_view, axis=0)

# TTA views: every pipeline resizes to 224x224 and ends with ToTensorV2; the
# optional middle step is what distinguishes one view from the next.
tta_transforms = [
    A.Compose([A.Resize(224, 224), *extra_ops, ToTensorV2()])
    for extra_ops in (
        [],                          # view 1: resize only
        [A.HorizontalFlip(p=1.0)],   # view 2: resize + forced horizontal flip
        # Add more TTA variations as desired.
    )
]

# ============================================================
# Training Function with Contrastive Loss, Self-Distillation & Dynamic LR
# ============================================================
def train_model(config, train_loader, val_loader):
    """Train a fresh HybridPlantClassifier and return it.

    Each training step runs two forward passes:

    1. a MixUp or CutMix batch (chosen with probability 0.5 each) through the
       main classifier, scored with label-smoothed cross-entropy;
    2. the clean batch with ``return_features=True``, giving the contrastive
       embedding (SupConLoss) and the auxiliary logits (KL term against the
       softmax of the main logits from pass 1).

    The three terms are summed with the configured weights. After every epoch
    the cosine LR schedule is stepped and the model is evaluated on
    ``val_loader`` without any mixing. Progress is printed; nothing is saved.

    Args:
        config: Dict with required keys ``'lr'``, ``'weight_decay'`` and
            ``'epochs'``, and optional keys ``'contrastive_weight'`` (1.0),
            ``'distill_weight'`` (0.5), ``'mixup_alpha'`` (1.0) and
            ``'cutmix_alpha'`` (1.0).
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.

    Returns:
        The trained model, on CUDA when available and on CPU otherwise.

    Note:
        The number of classes is read from the module-level ``full_dataset``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = len(full_dataset.classes)
    model = HybridPlantClassifier(num_classes=num_classes).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
    scheduler = CosineAnnealingLR(optimizer, T_max=config['epochs'])

    # Classification loss: standard CrossEntropyLoss with label smoothing
    # (fixed; not selected through the config).
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    supcon_loss_fn = SupConLoss(temperature=0.07)
    contrastive_weight = config.get("contrastive_weight", 1.0)
    distill_weight = config.get("distill_weight", 0.5)
    mixup_alpha = config.get("mixup_alpha", 1.0)
    cutmix_alpha = config.get("cutmix_alpha", 1.0)
    num_epochs = config['epochs']

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        train_preds = []
        train_labels = []
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            # --- Classification Branch with MixUp/CutMix ---
            if np.random.rand() < 0.5:
                mixed_images, y_a, y_b, lam = mixup_data(images, labels, alpha=mixup_alpha)
            else:
                # Hand CutMix its own copy: an in-place paste into `images`
                # would leave the second forward pass below with the mixed
                # batch instead of the clean one it is meant to see.
                mixed_images, y_a, y_b, lam = cutmix_data(images.clone(), labels, alpha=cutmix_alpha)
            outputs = model(mixed_images)
            loss_ce = mixup_criterion(criterion, outputs, y_a, y_b, lam)

            # --- Contrastive Branch and Self-Distillation using Original Images ---
            _, embeddings, distill_logits = model(images, return_features=True)
            embeddings = embeddings.unsqueeze(1)  # Shape: [batch, 1, feature_dim]
            loss_contrast = supcon_loss_fn(embeddings, labels)
            # NOTE(review): the KL target is the main head's prediction on the
            # *mixed* batch and is not detached, so gradients also flow into
            # the main head through this term. Left as-is; confirm whether a
            # detached, clean-batch target was intended.
            loss_distill = F.kl_div(F.log_softmax(distill_logits, dim=1),
                                    F.softmax(outputs, dim=1), reduction='batchmean')

            total_loss = loss_ce + contrastive_weight * loss_contrast + distill_weight * loss_distill
            total_loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            train_loss += total_loss.item() * batch_size
            # Train accuracy/F1 compare predictions on the mixed batch with
            # y_a (the un-permuted labels), so they are only indicative.
            preds = outputs.argmax(dim=1)
            train_correct += (preds == y_a).sum().item()
            train_total += batch_size
            train_preds.extend(preds.cpu().numpy())
            train_labels.extend(y_a.cpu().numpy())
            pbar.set_postfix(loss=f"{total_loss.item():.4f}")

        scheduler.step()
        avg_train_loss = train_loss / train_total
        train_acc = train_correct / train_total
        train_f1 = f1_score(train_labels, train_preds, average="weighted")
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}, Train F1: {train_f1:.4f}")

        # --- Validation Phase (without MixUp/CutMix) ---
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_preds = []
        val_labels = []
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                batch_size = labels.size(0)
                val_loss += loss.item() * batch_size
                preds = outputs.argmax(dim=1)
                val_correct += (preds == labels).sum().item()
                val_total += batch_size
                val_preds.extend(preds.cpu().numpy())
                val_labels.extend(labels.cpu().numpy())
        avg_val_loss = val_loss / val_total
        val_acc = val_correct / val_total
        val_f1 = f1_score(val_labels, val_preds, average="weighted")
        print(f"  Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}\n")
    return model

# ============================================================
# Evaluation Function (Test Set) with Optional Logits Adjustment
# ============================================================
def evaluate_model(model, loader, adjust_logits_flag=False, class_priors=None, return_acc=False):
    """Evaluate a classifier on a loader; optionally plot diagnostics.

    Args:
        model: Network returning class logits for an image batch.
        loader: Iterable of ``(images, labels)`` batches.
        adjust_logits_flag: When True and ``class_priors`` is given, the
            logits are passed through ``adjust_logits(outputs, class_priors)``
            before softmax/argmax. ``adjust_logits`` is not defined in this
            part of the file and is expected to exist at module level.
        class_priors: Per-class priors forwarded to ``adjust_logits``.
        return_acc: When True, return the accuracy and skip all printing and
            plotting.

    Returns:
        The accuracy (fraction of correct predictions) when ``return_acc`` is
        True; otherwise ``None`` after printing the accuracy and showing the
        confusion matrix and the micro-averaged ROC curve.
    """
    # Run on whatever device the model already lives on, so inputs and
    # weights can never end up on different devices. Parameter-free models
    # fall back to the CUDA-if-available choice.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            if adjust_logits_flag and (class_priors is not None):
                outputs = adjust_logits(outputs, class_priors)
            probs = torch.softmax(outputs, dim=1).cpu().numpy()
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs)
    acc = (np.array(all_preds) == np.array(all_labels)).mean()
    if return_acc:
        return acc
    print(f"\nTest Accuracy: {acc:.4f}\n")

    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix")
    plt.show()

    # Micro-averaged ROC: flatten one-hot labels against the probabilities.
    # The one-hot matrix is built explicitly, with its width taken from the
    # model output, because label_binarize collapses two-class problems to a
    # single column, which would no longer line up with the (n, 2) scores.
    prob_matrix = np.asarray(all_probs)
    num_classes = prob_matrix.shape[1]
    all_labels_bin = np.eye(num_classes, dtype=int)[np.asarray(all_labels)]
    fpr, tpr, _ = roc_curve(all_labels_bin.ravel(), prob_matrix.ravel())
    roc_auc = auc(fpr, tpr)
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, label=f'AUC = {roc_auc:.4f}')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.legend()
    plt.show()

def objective(trial):
    """Optuna objective: train for a few epochs and return validation accuracy.

    Samples the learning rate and weight decay on a log scale and the loss
    and augmentation weights on a linear scale, trains with a fixed budget of
    6 epochs on the module-level ``train_loader`` / ``val_loader``, and scores
    the result with ``evaluate_model``.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        Validation accuracy of the trained model (to be maximised).
    """
    config = {
        "lr": trial.suggest_float("lr", 1e-5, 1e-3, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True),
        "epochs": 6,
        "contrastive_weight": trial.suggest_float("contrastive_weight", 0.5, 2.0),
        "distill_weight": trial.suggest_float("distill_weight", 0.1, 1.0),
        "mixup_alpha": trial.suggest_float("mixup_alpha", 0.5, 2.0),
        "cutmix_alpha": trial.suggest_float("cutmix_alpha", 0.5, 2.0)
    }
    # For tuning, we use the standard train_loader.
    model = train_model(config, train_loader, val_loader)
    # No report()/should_prune() here: the only value available is the final
    # one, produced after training has already finished. Pruning at that
    # point cannot save any compute, and raising TrialPruned would merely
    # exclude a fully trained trial from study.best_params.
    return evaluate_model(model, val_loader, return_acc=True)

# ----- Hyperparameter search -----
study = optuna.create_study(direction="maximize", sampler=TPESampler(), pruner=HyperbandPruner())
study.optimize(objective, n_trials=8)
print("Best Hyperparameters:", study.best_params)

# Use the best hyperparameters from Optuna and increase epochs.
best_config = study.best_params.copy()
best_config["epochs"] = 30  # Increase epochs as desired.

# ----- Training (Single Stage) -----
print("Training on Original (Balanced) Data:")
model_final = train_model(best_config, train_loader, val_loader)

# Compute class priors from training targets for logits adjustment.
train_targets = np.asarray([full_dataset.targets[i] for i in train_dataset.indices])
unique_classes = np.unique(train_targets)
total_samples = len(train_targets)
# Count by class index over the full label space so that entry k always
# belongs to class k and the vector has one entry per model output, even if
# some class happens to be absent from the training split.
# NOTE(review): an absent class gets prior 0.0 here; check how adjust_logits
# treats zero priors.
class_counts = np.bincount(train_targets, minlength=len(full_dataset.classes))
class_priors = (class_counts / total_samples).tolist()
print("Class Priors:", class_priors)

evaluate_model(model_final, test_loader, adjust_logits_flag=True, class_priors=class_priors)
