In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from gl_clip import clip
import ast
import numpy as np
from tqdm import tqdm
from PIL import Image

from modules.feature_extract import Feature_Extractor
from modules.data_loader import WeatherDataset
from modules.classifier_finetune import Finetune_model
from modules.classifier import MWSC
from modules.metric import Metric

# Inference device; falls back to CPU when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Experiment settings.
ablation_mode = 1                 # passed to MWSC; also part of the MWSC checkpoint name
batch_size = 64
pretrain = False                  # passed to Finetune_model
clip_base_model = 'ViT-B/32'
timm_model_name = 'vit_base_patch32_clip_224'

# Paths (all under one root so the notebook can be pointed elsewhere in one place).
DATA_ROOT = '/data/MWSC'
image_path = f'{DATA_ROOT}/data/'
test_label_path = f'{DATA_ROOT}/data/label/test_data.csv'
# Derive checkpoint names from the settings above so changing ablation_mode or
# clip_base_model loads the matching weights ('ViT-B/32' -> 'ViT_B_32').
_clip_tag = clip_base_model.replace('/', '_').replace('-', '_')
state_dict_path = f'{DATA_ROOT}/result/mwsc_ablation_mode_{ablation_mode}_{_clip_tag}.pth'
fine_tune_state_dict_path = f'{DATA_ROOT}/result/finetune_{timm_model_name}.pth'
image_save_path = f'{DATA_ROOT}/image/'

In [None]:
# Feature extractor for the chosen CLIP backbone; its `preprocess` transform is
# reused for the test images so both models see identically prepared inputs.
feature = Feature_Extractor(device, clip_base_model)
test_dataset = WeatherDataset(test_label_path, data_dir=image_path, transform=feature.preprocess)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Class vocabularies and text prompts are owned by the dataset.
weather_types, severity_levels, prompts = (
    test_dataset.weather_types,
    test_dataset.severity_levels,
    test_dataset.prompts,
)

# Tokenize every prompt, concatenate along dim 0, and move to the inference device.
text_inputs = torch.cat(list(map(clip.tokenize, prompts))).to(device)

metric = Metric(weather_types, severity_levels)

In [None]:
# Ground-truth label accumulators filled by the inference cell.
ww = []
ss = []

# MWSC classifier, driven by the CLIP feature extractor's outputs.
model = MWSC(clip_base_model, len(weather_types), len(severity_levels), ablation_mode)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine
# (the `device` fallback above); without it torch.load tries to restore CUDA tensors.
model.load_state_dict(torch.load(state_dict_path, map_location=device))
model = model.to(device)
model.eval()

# Accumulators for the MWSC model.
result = []
all_weather_probs = []
all_weather_labels = []
all_severity_probs = []
all_severity_labels = []
w_deep_features = []
s_deep_features = []

# Fine-tuned comparison model; it is called directly on the preprocessed images.
model2 = Finetune_model(timm_model_name, len(weather_types), len(severity_levels), pretrain)
model2.load_state_dict(torch.load(fine_tune_state_dict_path, map_location=device))
model2 = model2.to(device)
model2.eval()

# Accumulators for the fine-tuned model.
result2 = []
all_weather_probs2 = []
all_weather_labels2 = []
all_severity_probs2 = []
all_severity_labels2 = []
w_deep_features2 = []
s_deep_features2 = []


In [None]:
# Reset every accumulator here so this cell is idempotent. Otherwise re-running it
# appends a second copy of every prediction, and `ww += ...` turns into element-wise
# addition once the t-SNE cell has converted `ww` to a NumPy array.
ww, ss = [], []
result = []
all_weather_probs, all_weather_labels = [], []
all_severity_probs, all_severity_labels = [], []
w_deep_features, s_deep_features = [], []
result2 = []
all_weather_probs2, all_weather_labels2 = [], []
all_severity_probs2, all_severity_labels2 = [], []
w_deep_features2, s_deep_features2 = [], []

with torch.no_grad():
    for images, weather_labels, severity_labels in tqdm(test_loader):
        images = images.to(device)
        # Labels are only used for bookkeeping, so they never need to visit the GPU.
        weather_np = weather_labels.cpu().numpy()
        severity_np = severity_labels.cpu().numpy()
        ww += weather_np.tolist()
        ss += severity_np.tolist()

        # --- MWSC model on features from the CLIP extractor ---
        global_feat, local_feat, text_features = feature(images, text_inputs)
        weather_out, severity_out = model(global_feat, local_feat, text_features)

        # Weather: independent per-class sigmoid; severity: softmax across levels.
        all_weather_probs.extend(torch.sigmoid(weather_out).cpu().numpy())
        all_weather_labels.extend(weather_np)
        all_severity_probs.extend(F.softmax(severity_out, dim=1).cpu().numpy())
        all_severity_labels.extend(severity_np)
        # The same list objects are appended every batch, so result[-1] always
        # refers to the full accumulated lists.
        result.append([all_weather_probs, all_weather_labels, all_severity_probs, all_severity_labels])

        # Raw (pre-activation) outputs are kept for the t-SNE plots.
        w_deep_features += weather_out.cpu().numpy().tolist()
        s_deep_features += severity_out.cpu().numpy().tolist()

        # --- Fine-tuned model on the same images ---
        weather_out2, severity_out2 = model2(images)

        all_weather_probs2.extend(torch.sigmoid(weather_out2).cpu().numpy())
        all_weather_labels2.extend(weather_np)
        all_severity_probs2.extend(F.softmax(severity_out2, dim=1).cpu().numpy())
        all_severity_labels2.extend(severity_np)
        result2.append([all_weather_probs2, all_weather_labels2, all_severity_probs2, all_severity_labels2])

        w_deep_features2 += weather_out2.cpu().numpy().tolist()
        s_deep_features2 += severity_out2.cpu().numpy().tolist()


## t-SNE of weather and severity logits (MWSC vs. fine-tuned model)

In [None]:
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
# Weather labels were collected as one vector per sample (presumably multi-hot).
# Colour each t-SNE point by the index of its largest entry: for multi-label samples
# that is the first active class in `weather_types` order, and 0 if none is active.
ww = np.array(ww)
# Keep this 1-D. With an (N, 1) array, np.where returns a (rows, cols) tuple, and
# indexing wcluster with it also plots copies of sample 0 in every class colour.
ww_class = np.argmax(ww, axis=1)
# Legend names in model-output order (the order used to index the weather outputs).
weather = list(weather_types)

tsne = TSNE(n_components=2, random_state=0)
# w_deep_features holds the MWSC weather outputs before the sigmoid, one row per sample.
wcluster = np.array(tsne.fit_transform(np.array(w_deep_features)))

fig, ax = plt.subplots(figsize=(10, 10))
for i, label in enumerate(weather):
    mask = ww_class == i
    ax.scatter(wcluster[mask, 0], wcluster[mask, 1], marker='.', label=label)
ax.set_title('MWSC weather logits (t-SNE)')
ax.legend(fontsize=25)
plt.show()

In [None]:
tsne = TSNE(n_components=2, random_state=0)
# Same projection for the fine-tuned model's weather outputs (before the sigmoid).
wcluster2 = np.array(tsne.fit_transform(np.array(w_deep_features2)))

# ravel() accepts ww_class as either (N,) or (N, 1) and gives a 1-D boolean mask per
# class. Indexing with an np.where tuple of an (N, 1) array would also plot sample 0
# once per class.
ww_class_flat = np.ravel(ww_class)
fig, ax = plt.subplots(figsize=(10, 10))
for i, label in enumerate(weather_types):
    mask = ww_class_flat == i
    ax.scatter(wcluster2[mask, 0], wcluster2[mask, 1], marker='.', label=label)
ax.set_title('Fine-tuned model weather logits (t-SNE)')
ax.legend(fontsize=25)
plt.show()

In [None]:
ss = np.array(ss)
# Legend names in model-output order; still exposed as `severity` for any code reading it.
severity = list(severity_levels)

tsne = TSNE(n_components=2, random_state=0)
# s_deep_features holds the MWSC severity outputs before the softmax, one row per sample.
scluster = np.array(tsne.fit_transform(np.array(s_deep_features)))

fig, ax = plt.subplots(figsize=(10, 10))
for i, label in enumerate(severity_levels):
    # assumes ss holds one integer severity index per sample — TODO confirm
    mask = ss == i
    ax.scatter(scluster[mask, 0], scluster[mask, 1], marker='.', label=label)
ax.set_title('MWSC severity logits (t-SNE)')
ax.legend(fontsize=25)
plt.show()

In [None]:
tsne = TSNE(n_components=2, random_state=0)
# Same projection for the fine-tuned model's severity outputs (before the softmax).
scluster2 = np.array(tsne.fit_transform(np.array(s_deep_features2)))

# Read names from severity_levels rather than the mutable global `severity`, which other
# cells may rebind. asarray also makes the mask work if `ss` is still a plain list.
ss_arr = np.asarray(ss)
fig, ax = plt.subplots(figsize=(10, 10))
for i, label in enumerate(severity_levels):
    mask = ss_arr == i
    ax.scatter(scluster2[mask, 0], scluster2[mask, 1], marker='.', label=label)
ax.set_title('Fine-tuned model severity logits (t-SNE)')
ax.legend(fontsize=25)
plt.show()

## Per-image prediction charts for multi-label test images

In [None]:
# Positional indices of test images annotated with more than one weather type.
# The weather column stores list literals such as "['foggy', 'rainy']". Count the
# parsed entries rather than testing the raw string length, which only worked
# because every current class name happens to have five letters.
multi = [
    idx
    for idx, weather_str in enumerate(test_dataset.data['weather'])
    if len(ast.literal_eval(weather_str)) > 1
]

In [None]:
# Number of test-set indices collected in `multi` (the images rendered in the next cell).
len(multi)

In [None]:
weather_to_idx = {w: i for i, w in enumerate(weather_types)}
severity_to_idx = {s: i for i, s in enumerate(severity_levels)}

# Colours shared by every per-image bar chart.
COLOR_CORRECT = '#B7F0B1'  # predicted and correct
COLOR_WRONG = '#FFA7A7'    # predicted but wrong, or a ground-truth weather class that was missed
COLOR_OTHER = '#B2CCFF'    # everything else
WEATHER_THRESHOLD = 0.5    # sigmoid decision threshold for a weather class
SHOW_FIGURES = False       # True: also display each chart inline; False: only save it


def _finish(fig):
    """Display `fig` inline when SHOW_FIGURES is set; otherwise close it to free memory."""
    if SHOW_FIGURES:
        plt.show()
    else:
        plt.close(fig)


def _binarize_weather(probs, threshold=WEATHER_THRESHOLD):
    """Turn per-class weather probabilities into a 0/1 prediction list.

    A class is predicted when its probability exceeds `threshold`. If that predicts
    every class or none, only the highest-probability class(es) are marked instead.
    """
    pred = [1 if p > threshold else 0 for p in probs]
    if len(set(pred)) == 1:
        # np.max rather than builtin max: earlier runs of this notebook may have
        # rebound the global name `max`.
        top = np.max(probs)
        pred = [1 if p == top else 0 for p in probs]
    return pred


def _weather_colors(pred, gt):
    """Colour each weather bar by comparing the 0/1 prediction with the 0/1 ground truth."""
    colors = []
    for p, g in zip(pred, gt):
        if p == 1:
            colors.append(COLOR_CORRECT if g == 1 else COLOR_WRONG)
        else:
            colors.append(COLOR_WRONG if g == 1 else COLOR_OTHER)
    return colors


def _save_weather_chart(probs, gt, save_path):
    """Save a bar chart of weather probabilities with the decision threshold drawn in."""
    # rc_context scopes the font size to this figure instead of changing global rcParams.
    with plt.rc_context({'font.size': 16}):
        fig, ax = plt.subplots(figsize=(4, 4))
        ax.bar(weather_types, probs, color=_weather_colors(_binarize_weather(probs), gt))
        ax.set_axisbelow(True)
        ax.set_ylim(0, 1)
        ax.axhline(y=WEATHER_THRESHOLD, xmin=0, xmax=1, color='black', linestyle='dashed')
        fig.savefig(save_path)
        _finish(fig)


def _save_severity_chart(probs, gt_severity, save_path):
    """Save a bar chart of severity probabilities.

    Only the top-scoring level(s) are highlighted: green if equal to `gt_severity`,
    red otherwise. All other bars use the neutral colour.
    """
    top = np.max(probs)
    colors = []
    for level, p in zip(severity_levels, probs):
        if p != top:
            colors.append(COLOR_OTHER)
        elif level == gt_severity:
            colors.append(COLOR_CORRECT)
        else:
            colors.append(COLOR_WRONG)
    with plt.rc_context({'font.size': 17}):
        fig, ax = plt.subplots(figsize=(4, 4))
        ax.bar(severity_levels, probs, color=colors)
        ax.set_axisbelow(True)
        ax.set_ylim(0, 1)
        fig.savefig(save_path)
        _finish(fig)


for idx in multi:
    # `multi` holds positional indices (built with enumerate), so use iloc throughout.
    row = test_dataset.data.iloc[idx]

    # Input image titled with its raw ground-truth strings.
    fig, ax = plt.subplots(figsize=(8, 4))
    with Image.open(row['filepath']) as img:
        ax.imshow(img)
    ax.set_title("GT: " + row['weather'] + ", " + row['severity'], fontdict={'fontsize': 18})
    ax.axis("off")
    fig.savefig(f'{image_save_path}/result_image_{idx}')
    _finish(fig)

    # Ground-truth weather as a 0/1 vector in weather_types order; the column stores a list literal.
    weather_label = np.zeros(len(weather_types))
    for w in ast.literal_eval(row['weather']):
        weather_label[weather_to_idx[w.lower()]] = 1

    # Severity is stored as a quoted string; take the first quoted entry.
    # (Named gt_severity so the global `severity` used elsewhere is not clobbered.)
    gt_severity = row['severity'].lower().split('\'')[1]

    # result[-1] / result2[-1] hold [weather_probs, weather_labels, severity_probs, severity_labels].
    _save_weather_chart(result[-1][0][idx], weather_label, f'{image_save_path}/mwsc_result_weather_{idx}')
    _save_severity_chart(result[-1][2][idx], gt_severity, f'{image_save_path}/mwsc_result_severity_{idx}')
    _save_weather_chart(result2[-1][0][idx], weather_label, f'{image_save_path}/finetune_result_weather_{idx}')
    _save_severity_chart(result2[-1][2][idx], gt_severity, f'{image_save_path}/finetune_result_severity_{idx}')
