In [None]:
import time

import numpy as np
import torch
import torch.nn.functional as F
from gl_clip import clip
from torch.utils.data import DataLoader
from tqdm import tqdm

from modules.classifier import MWSC
from modules.data_loader import WeatherDataset
from modules.feature_extract import Feature_Extractor
from modules.metric import Metric

# --- Runtime settings --------------------------------------------------------
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

ablation_mode = 1              # forwarded to the MWSC constructor
batch_size = 64                # batch size of the test DataLoader
clip_base_model = 'ViT-B/32'   # CLIP backbone identifier

# --- Filesystem locations ----------------------------------------------------
image_path = '/data/MWSC/data/'                                                # image root directory
test_label_path = '/data/MWSC/data/label/test_data.csv'                        # test-split label CSV
state_dict_path = '/data/MWSC/result/mwsc_ablation_mode_1_ViT_B_32.pth'        # trained weights to evaluate

In [None]:
# Feature extractor built on the chosen CLIP backbone; its `preprocess`
# attribute is reused below as the dataset's image transform.
feature = Feature_Extractor(device, clip_base_model)

# Test split, read from the label CSV with images resolved under `image_path`.
test_dataset = WeatherDataset(
    test_label_path,
    transform=feature.preprocess,
    data_dir=image_path,
)

# Class vocabularies exposed by the dataset.
weather_types = test_dataset.weather_types
severity_levels = test_dataset.severity_levels
metric = Metric(weather_types, severity_levels)

# Tokenize every text prompt once and keep the concatenated tokens on the
# target device so they can be reused for every batch.
prompts = test_dataset.prompts
prompt_tokens = [clip.tokenize(p) for p in prompts]
text_inputs = torch.cat(prompt_tokens).to(device)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=8,
)

In [None]:
# Build the classifier, restore the trained weights and switch to inference mode.
model = MWSC(clip_base_model, len(weather_types), len(severity_levels), ablation_mode)
# map_location remaps the stored tensors onto the selected device, so the load
# also works on a CPU-only host if the checkpoint was written from a GPU.
model.load_state_dict(torch.load(state_dict_path, map_location=device))
model = model.to(device)
model.eval()
trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Inference timing. CUDA events are only valid when running on a GPU; on the
# CPU fallback a wall-clock timer is used instead.
# NOTE: for a per-frame latency measurement set batch_size to 1. The first
# iteration is included in the average (no warm-up is excluded).
use_cuda_timing = str(device).startswith("cuda")
if use_cuda_timing:
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
total_time = 0   # accumulated inference time in milliseconds
idx = 0          # number of timed iterations (batches)
num_images = 0   # number of images seen, used for the FPS figure

all_weather_probs = []
all_weather_labels = []
all_severity_probs = []
all_severity_labels = []

with torch.no_grad():
    for images, weather_labels, severity_labels in tqdm(test_loader):
        images = images.to(device)
        weather_labels = weather_labels.to(device)
        severity_labels = severity_labels.to(device)

        # Time only feature extraction + classification, not the data transfer
        # above or the post-processing below.
        if use_cuda_timing:
            starter.record()
        else:
            start_time = time.perf_counter()

        global_feat, local_feat, text_features = feature(images, text_inputs)

        weather_out, severity_out = model(global_feat, local_feat, text_features)

        if use_cuda_timing:
            ender.record()
            # Wait for the GPU so the elapsed time between the events is final.
            torch.cuda.synchronize()
            infer_time = starter.elapsed_time(ender)  # milliseconds
        else:
            infer_time = (time.perf_counter() - start_time) * 1e3  # s -> ms

        # Weather head is scored as multi-label: independent sigmoid per class.
        weather_probs = torch.sigmoid(weather_out).cpu().numpy()
        all_weather_probs.extend(weather_probs)
        all_weather_labels.extend(weather_labels.cpu().numpy())

        # Severity head is scored as single-label: softmax over dim 1.
        severity_probs = F.softmax(severity_out, dim=1).cpu().numpy()
        all_severity_probs.extend(severity_probs)
        all_severity_labels.extend(severity_labels.cpu().numpy())

        total_time += infer_time
        idx += 1
        # assumes dim 0 of `images` is the batch dimension (default DataLoader collation)
        num_images += images.size(0)

weather_metrics = metric.calculate(np.array(all_weather_probs), np.array(all_weather_labels), is_multilabel=True)
severity_metrics = metric.calculate(np.array(all_severity_probs), np.array(all_severity_labels), is_multilabel=False)

print(f"Trainable Parameters : {trainable_parameters}")
print(f'Weather Accuracy: {weather_metrics["accuracy"]:.4f}')
print(f'Weather mAP:{weather_metrics["map"]:.4f}')
print(f'Severity mAP:{severity_metrics["map"]:.4f}')
print('Weather Classification Report:')
print(weather_metrics['classification_report'])
print('Severity Classification Report:')
print(severity_metrics['classification_report'])
print('Weather Confusion Matrix:')
print(weather_metrics['confusion_matrix'])
print('\nSeverity Confusion Matrix:')
print(severity_metrics['confusion_matrix'])
print('-' * 50)

if idx > 0 and total_time > 0:
    avg_time_per_iter = total_time / idx
    avg_time_per_iter_sec = avg_time_per_iter * 1e-3
    # Frames per second = images processed / total seconds. With batch_size == 1
    # this equals 1 / avg_time_per_iter_sec; with larger batches the reciprocal
    # of the per-iteration time would only give batches per second.
    avg_fps = num_images / (total_time * 1e-3)

    print(f"Avg time per iteration (ms): {avg_time_per_iter:.2f} ms")
    print(f"Avg time per iteration (sec): {avg_time_per_iter_sec:.4f} s")
    print(f"FPS: {avg_fps:.2f}")
else:
    print("No timed iterations; timing statistics are unavailable.")

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay

# Draw one confusion matrix per weather type. The number of plots follows
# `weather_types` instead of a hard-coded count, so the cell keeps working if
# the label set changes.
# assumes weather_metrics['confusion_matrix'] holds one matrix per weather
# type, indexed in the same order as `weather_types` - TODO confirm in Metric
for i, weather_name in enumerate(weather_types):
    cm = weather_metrics['confusion_matrix'][i]

    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot()
    disp.ax_.set_title(f"Weather Confusion Matrix - {weather_name}")


In [None]:
from sklearn.metrics import ConfusionMatrixDisplay

# Plot the severity confusion matrix, using the severity level names as the
# tick labels on both axes.
cm = severity_metrics['confusion_matrix']
disp = ConfusionMatrixDisplay(cm, display_labels=severity_levels)
disp.plot()
disp.ax_.set_title("Severity Confusion Matrix")


In [None]:
from torchmetrics.functional.classification import multiclass_calibration_error
from torch import tensor

# Convert the accumulated per-sample results to tensors. Going through a single
# NumPy array avoids the slow list-of-ndarrays tensor construction, and the
# conversion is idempotent: re-running this cell when the names already hold
# (CPU) tensors yields the same tensors again.
all_severity_probs = torch.as_tensor(np.asarray(all_severity_probs))
all_severity_labels = torch.as_tensor(np.asarray(all_severity_labels))

# The class count follows the label vocabulary the model was built with rather
# than a hard-coded 3.
num_severity_classes = len(severity_levels)

# Calibration error of the severity head with the L1 norm.
# Other norms accepted by this call: norm='l2' and norm='max'.
print(multiclass_calibration_error(all_severity_probs, all_severity_labels, num_classes=num_severity_classes, n_bins=3, norm='l1'))

In [None]:
from torch import randn, randint
from torchmetrics.classification import MulticlassCalibrationError

# Use a dedicated name: `metric` already refers to the project Metric instance
# used by the evaluation cell, and rebinding it here would break a re-run of
# that cell.
calibration_metric = MulticlassCalibrationError(num_classes=len(severity_levels), n_bins=3, norm='l1')

# Work on local tensor views so this cell does not depend on whether the
# accumulated results are still Python lists or have already been converted.
severity_probs_t = torch.as_tensor(np.asarray(all_severity_probs))
severity_labels_t = torch.as_tensor(np.asarray(all_severity_labels))

# The metric is evaluated once over the full test set. Re-evaluating it on the
# same inputs once per sample only reproduces the same number at quadratic cost.
values = [calibration_metric(severity_probs_t, severity_labels_t)]
fig_, ax_ = calibration_metric.plot(values)