################################################################################
# NEW                                                                          #
################################################################################

In [1]:
import os
import random
import numpy as np
import pandas as pd
import plotly.graph_objects as go

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

def get_transforms(image_size: int):
    """
    Build the training and evaluation image pipelines.

    Training images go through random augmentation: a resized crop covering
    80-100% of the area, a horizontal flip with probability 0.5, a rotation
    of up to 15 degrees, and colour jitter. After that they are converted to
    tensors and normalised. Evaluation images are only resized to a square,
    converted and normalised.

    Args:
        image_size (int): Side length, in pixels, of the produced images.
    Returns:
        tuple: ``(train_transform, eval_transform)``, both
        ``transforms.Compose`` pipelines.
    """
    # Standard ImageNet per-channel mean / std.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def to_normalized_tensor():
        # Fresh ToTensor/Normalize instances for each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ]

    augmentation_steps = [
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    ]
    resize_steps = [transforms.Resize((image_size, image_size))]

    train_pipeline = transforms.Compose(augmentation_steps + to_normalized_tensor())
    eval_pipeline = transforms.Compose(resize_steps + to_normalized_tensor())
    return train_pipeline, eval_pipeline

def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss."""
    model.train()
    total_loss = 0.0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        out = model(imgs)
        loss = criterion(out, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The criterion averages over the batch. Re-weight by batch size so the
        # epoch value is a true per-sample mean even when the last batch is short.
        total_loss += loss.item() * imgs.size(0)
    return total_loss / len(loader.dataset)


def _evaluate(model, loader, criterion, device):
    """Return ``(mean_loss, accuracy)`` of ``model`` over every sample in ``loader``.

    Puts the model in eval mode (inference behaviour for layers such as
    BatchNorm) and disables gradient tracking.
    """
    model.eval()
    total_loss, total_correct = 0.0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            out = model(imgs)
            total_loss += criterion(out, labels).item() * imgs.size(0)
            total_correct += (out.argmax(1) == labels).sum().item()
    n_samples = len(loader.dataset)
    return total_loss / n_samples, total_correct / n_samples


def train_model_for_condition(
    data_dir: str,
    condition: str,
    num_epochs: int = 5,
    batch_size: int = 32,
    lr: float = 1e-4,
    image_size: int = 224,
    device: str = "cpu",
    num_workers: int = 2,
    save_dir: str = ".",
):
    """
    Train a ResNet-18 model to classify a target weather condition (vs. clear).

    Data is read with ``datasets.ImageFolder`` from
    ``<data_dir>/<condition>/{train,val,test}``. Each split must contain one
    subfolder per class. The head has 2 outputs, so presumably each split
    holds exactly two class folders (verify against the dataset).

    Args:
        data_dir (str): Root directory of the ACDC dataset.
        condition (str): Weather condition to classify (e.g., 'fog', 'rain', etc.).
        num_epochs (int): Number of epochs to train.
        batch_size (int): Batch size for DataLoader.
        lr (float): Learning rate for Adam optimizer.
        image_size (int): Image resize/crop size.
        device (str | torch.device): 'cpu' or 'cuda' for GPU acceleration.
        num_workers (int): DataLoader worker processes per loader.
        save_dir (str): Directory in which ``model_<condition>_new.pth`` is saved.
    Returns:
        train_losses: List of training loss values per epoch.
        val_losses: List of validation loss values per epoch.
        test_loss: Final loss value on the test set.
        test_acc: Final accuracy value on the test set.
    """
    train_tf, eval_tf = get_transforms(image_size)

    def make_loader(split, transform, shuffle):
        # One ImageFolder-backed loader per split of the current condition.
        dataset = datasets.ImageFolder(
            os.path.join(data_dir, condition, split), transform=transform
        )
        return DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
        )

    train_loader = make_loader("train", train_tf, shuffle=True)
    val_loader = make_loader("val", eval_tf, shuffle=False)
    test_loader = make_loader("test", eval_tf, shuffle=False)

    # Pretrained ResNet-18 with the final FC replaced by a 2-way classifier.
    # NOTE: `pretrained=True` is deprecated since torchvision 0.13 in favour of
    # `weights=models.ResNet18_Weights.DEFAULT`. It is kept here so older
    # torchvision installs keep working.
    model = models.resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 2)
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses, val_losses = [], []
    for epoch in range(1, num_epochs + 1):
        train_losses.append(
            _train_one_epoch(model, train_loader, criterion, optimizer, device)
        )
        val_loss, _val_acc = _evaluate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        print(f"{condition} Epoch {epoch}/{num_epochs}: "
              f"Train L {train_losses[-1]:.4f}, Val L {val_losses[-1]:.4f}")

    # _evaluate switches to eval mode explicitly. The original relied on the
    # last validation epoch having done so, which fails when num_epochs == 0.
    test_loss, test_acc = _evaluate(model, test_loader, criterion, device)

    torch.save(model.state_dict(),
               os.path.join(save_dir, f"model_{condition}_new.pth"))
    return train_losses, val_losses, test_loss, test_acc

def dark_cell_style(df, caption):
    """
    Helper function to apply dark-themed styling to pandas DataFrame (for HTML output).
    Args:
        df (pd.DataFrame): DataFrame to style.
        caption (str): Table caption.
    Returns:
        pd.io.formats.style.Styler: Styled DataFrame for display or saving as HTML.
    """
    dark_bg, white_txt = '#222', '#fff'
    fmt = {col: '{:.3f}' for col in df.columns if col != 'condition'}
    styler = (
        df.style
          .format(fmt, na_rep='–')
          .set_properties(**{'background-color': dark_bg, 'color': white_txt})
          .set_table_styles([
              {'selector': 'th', 'props': [('background-color', dark_bg), ('color', white_txt)]},
              {'selector': 'caption', 'props': [('caption-side', 'top'), ('font-size', '1.1em'), ('color', white_txt)]},
          ])
          .set_caption(caption)
    )
    return styler

def main(
    data_root: str = r"D:\praca_magisterska\Adverse_weather_detection\ACDC\rgb_anon_concat",
    conditions=("fog", "rain", "snow", "night"),
    num_epochs: int = 10,
):
    """
    Train one classifier per weather condition and export the results.

    Each condition gets its own model from ``train_model_for_condition``.
    These files are then written to the working directory:
    ``metrics_summary.csv`` (test loss/accuracy per condition),
    ``accuracy_comparison.html`` (bar chart), ``val_loss_comparison.html``
    (validation-loss curves) and ``metrics_table.html`` (styled table).

    Args:
        data_root: Root of the ACDC data, laid out as
            ``<data_root>/<condition>/{train,val,test}``. The default keeps
            the original hard-coded location.
        conditions: Weather conditions to classify, each against clear weather.
        num_epochs: Training epochs per condition.
    """
    # Materialise once: the sequence is iterated twice (training and plotting).
    conditions = list(conditions)
    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    all_histories, summary = {}, []
    # Loop over each weather condition and train a separate model
    for cond in conditions:
        tr_h, val_h, t_l, t_a = train_model_for_condition(
            data_root, cond, num_epochs=num_epochs,
            batch_size=32, lr=1e-4,
            image_size=224, device=device
        )
        all_histories[cond] = {'train': tr_h, 'val': val_h}
        summary.append({'condition': cond, 'test_loss': t_l, 'test_acc': t_a})

    # Store summary results as a CSV file for later analysis
    df = pd.DataFrame(summary)
    df.to_csv('metrics_summary.csv', index=False)

    # Shared dark Plotly layout. Each figure below overrides the title.
    layout = go.Layout(
        title="Loss VS step",
        xaxis=dict(title="Epoch", gridwidth=0.5, gridcolor='rgb(60, 60, 60)', tickfont=dict(size=14)),
        yaxis=dict(title="Value", gridwidth=0.5, gridcolor='rgb(60, 60, 60)', tickfont=dict(size=14)),
        plot_bgcolor='rgb(30, 30, 30)',
        paper_bgcolor='rgb(30, 30, 30)',
        font=dict(color='white'),
        legend=dict(font=dict(size=16))
    )

    # Plot: Test Accuracy bar chart for all weather conditions
    fig_acc = go.Figure(
        data=[go.Bar(x=df['condition'], y=df['test_acc'], name='Test Accuracy')],
        layout=layout
    )
    fig_acc.update_traces(
        text=[f"{v:.3f}" for v in df['test_acc']],
        textposition='auto',
        textfont=dict(size=18)
    )
    fig_acc.update_layout(
        title='Test Accuracy Comparison Across Conditions',
        title_font_size=30,
        xaxis=dict(
            title='Condition',
            title_font_size=24,
            tickfont_size=20
        ),
        yaxis=dict(
            title='Accuracy',
            title_font_size=24,
            tickfont_size=20
        ),
        legend=dict(font=dict(size=18)),
        plot_bgcolor='rgb(30, 30, 30)',
        paper_bgcolor='rgb(30, 30, 30)',
        font=dict(color='white')
    )
    fig_acc.write_html('accuracy_comparison.html')

    # Plot: Validation loss curves for all conditions
    fig_val = go.Figure(layout=layout)
    for cond in conditions:
        fig_val.add_trace(go.Scatter(
            x=list(range(1, len(all_histories[cond]['val']) + 1)),
            y=all_histories[cond]['val'],
            mode='lines', name=cond
        ))
    fig_val.update_layout(title='Validation Loss Curves for All Conditions')
    fig_val.write_html('val_loss_comparison.html')

    # Styled HTML table with metrics summary
    styled = dark_cell_style(df, 'Quantitative Metrics Summary')
    html = styled.to_html()
    with open('metrics_table.html', 'w', encoding='utf-8') as f:
        f.write(html)

# Entry point: run the whole pipeline with its default settings when this
# code is executed directly rather than imported for its helpers.
if __name__ == "__main__":
    main()




fog Epoch 1/10: Train L 0.1288, Val L 0.0013
fog Epoch 2/10: Train L 0.0070, Val L 0.0011
fog Epoch 3/10: Train L 0.0050, Val L 0.0020
fog Epoch 4/10: Train L 0.0014, Val L 0.0006
fog Epoch 5/10: Train L 0.0010, Val L 0.0005
fog Epoch 6/10: Train L 0.0012, Val L 0.0003
fog Epoch 7/10: Train L 0.0007, Val L 0.0002
fog Epoch 8/10: Train L 0.0004, Val L 0.0002
fog Epoch 9/10: Train L 0.0005, Val L 0.0002
fog Epoch 10/10: Train L 0.0005, Val L 0.0001
rain Epoch 1/10: Train L 0.1515, Val L 0.0292
rain Epoch 2/10: Train L 0.0199, Val L 0.0135
rain Epoch 3/10: Train L 0.0125, Val L 0.0819
rain Epoch 4/10: Train L 0.0102, Val L 0.0164
rain Epoch 5/10: Train L 0.0041, Val L 0.0108
rain Epoch 6/10: Train L 0.0037, Val L 0.0138
rain Epoch 7/10: Train L 0.0025, Val L 0.0170
rain Epoch 8/10: Train L 0.0020, Val L 0.0068
rain Epoch 9/10: Train L 0.0008, Val L 0.0181
rain Epoch 10/10: Train L 0.0087, Val L 0.0147
snow Epoch 1/10: Train L 0.1435, Val L 0.0009
snow Epoch 2/10: Train L 0.0110, Val L 0.0