In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import timm

from modules.data_loader import WeatherDataset
from modules.classifier_finetune import Finetune_model
from modules.metric import Metric

import os

# Target device for the model and input batches: GPU when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Whether timm downloads pretrained backbone weights; also forwarded to
# Finetune_model. For evaluation the fine-tuned checkpoint is loaded on top.
pretrain = False
batch_size = 64
timm_model_name = 'crossvit_18_240'

# Root of the MWSC data/result tree. Override with the MWSC_ROOT environment
# variable instead of editing absolute paths in the notebook; the default keeps
# the original location.
mwsc_root = os.environ.get('MWSC_ROOT', '/data/MWSC').rstrip('/')
image_path = f'{mwsc_root}/data/'
test_label_path = f'{mwsc_root}/data/label/test_data.csv'
fine_tune_state_dict_path = f'{mwsc_root}/result/finetune_{timm_model_name}.pth'

In [None]:
# Instantiate the timm backbone; in this notebook `m` is only used to resolve
# its preprocessing config (input size, normalisation, interpolation).
m = timm.create_model(timm_model_name, pretrained=pretrain)
data_config = timm.data.resolve_model_data_config(m)
# is_training=False selects timm's evaluation transforms (no random augmentation).
transforms = timm.data.create_transform(**data_config, is_training=False)

test_dataset = WeatherDataset(test_label_path, transform=transforms, data_dir=image_path)
# Fail fast: an empty dataset would otherwise surface later as an obscure
# error inside the metric computation.
assert len(test_dataset) > 0, f"No test samples found in {test_label_path}"
# DataLoader defaults to shuffle=False, so evaluation order is deterministic.
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Class vocabularies as exposed by the project dataset class.
weather_types = test_dataset.weather_types
severity_levels = test_dataset.severity_levels

metric = Metric(weather_types, severity_levels)

In [None]:
# Rebuild the fine-tuned model (it returns weather and severity logits) and load its weights.
model = Finetune_model(timm_model_name, len(weather_types), len(severity_levels), pretrain)
# Load the checkpoint onto the CPU first: without map_location, tensors saved
# from a GPU run are restored onto that GPU, which fails on CPU-only machines.
# load_state_dict copies the values into the model's own parameters.
# NOTE(review): consider weights_only=True (torch>=1.13) if the checkpoint may
# come from an untrusted source.
model.load_state_dict(torch.load(fine_tune_state_dict_path, map_location="cpu"))
model = model.to(device)
model.eval()
trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

all_weather_probs = []
all_weather_labels = []
all_severity_probs = []
all_severity_labels = []

with torch.no_grad():
    for images, weather_labels, severity_labels in tqdm(test_loader):
        # Only the inputs need to live on the model's device; labels are just
        # collected for scoring, so they stay on the CPU.
        weather_out, severity_out = model(images.to(device))

        # Weather is scored as multi-label: independent sigmoid per class.
        all_weather_probs.extend(torch.sigmoid(weather_out).cpu().numpy())
        all_weather_labels.extend(weather_labels.cpu().numpy())

        # Severity is scored as single-label: softmax over the class dimension.
        all_severity_probs.extend(F.softmax(severity_out, dim=1).cpu().numpy())
        all_severity_labels.extend(severity_labels.cpu().numpy())

if not all_weather_probs:
    raise RuntimeError("test_loader yielded no samples; nothing to evaluate")

weather_metrics = metric.calculate(np.array(all_weather_probs), np.array(all_weather_labels), is_multilabel=True)
severity_metrics = metric.calculate(np.array(all_severity_probs), np.array(all_severity_labels), is_multilabel=False)
print(f"Trainable Parameters : {trainable_parameters}")
print(f'Weather Accuracy: {weather_metrics["accuracy"]:.4f}')
print(f'Weather mAP:{weather_metrics["map"]:.4f}')
print(f'Severity mAP:{severity_metrics["map"]:.4f}')
print('Weather Classification Report:')
print(weather_metrics['classification_report'])
print('Severity Classification Report:')
print(severity_metrics['classification_report'])
print('Weather Confusion Matrix:')
print(weather_metrics['confusion_matrix'])
print('-' * 50)