In [38]:
import numpy as np
import sys
import os
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import datasets, transforms
from torchvision.models import resnet50
from transformers import ViTModel

sys.path.append("ssl_library")
from src.pkg.embedder import Embedder
from src.pkg.wrappers import ViTWrapper, Wrapper

# sys.path.append("local_python")
from local_python.dataframe_image_dataset import DataframeImageDataset

In [39]:
# Seed passed to the torch and numpy random number generators
seed = 19
# Side length used for the centre crop of each input image
img_size = 224
# Number of images per DataLoader batch
batch_size = 16
# Per-channel statistics used to normalise the input images
normalise_mean = (0.485, 0.456, 0.406)  # ImageNet
normalise_std = (0.229, 0.224, 0.225)  # ImageNet

# Task list; the columns dataset_path, architecture, weigths_path and feature_path are read from it
configuration_csv_path = "configs/tasks-configuration-new.csv"

In [40]:
# Seed every random number generator used by this notebook with the same value.
np.random.seed(seed)  # numpy global generator
torch.manual_seed(seed)  # torch default generator
torch.cuda.manual_seed_all(seed)  # generators of all CUDA devices

In [41]:
# Feature extraction in this notebook requires a CUDA device; stop early otherwise.
assert torch.cuda.is_available()

# List the name of every visible CUDA device.
n_devices = torch.cuda.device_count()
for i in range(n_devices):
    print(torch.cuda.get_device_name(i))

# Device handle used when moving the model and the image batches.
device = torch.device("cuda")

NVIDIA GeForce GTX 960


In [42]:
def load_model(model_architecture, checkpoint_path):
    """Build a frozen feature extractor and load its weights from a checkpoint.

    Args:
        model_architecture: Architecture identifier. "ResNet50" and "ViT_T16"
            are implemented; "swin_t" and "vit_b16" are recognised but not
            implemented yet.
        checkpoint_path: Path of the state-dict file to load. For "ViT_T16"
            the path must contain "vit_t16_v1", "vit_t16_v2" or "vit_t16_v3",
            which selects how the model is constructed.

    Returns:
        The model with the checkpoint loaded (strict key matching) and every
        parameter set to ``requires_grad=False``, or ``None`` if
        ``model_architecture`` is unknown.

    Raises:
        NotImplementedError: If the architecture is recognised but has no
            implementation yet.
        ValueError: If the "ViT_T16" variant cannot be derived from
            ``checkpoint_path``.
    """
    if model_architecture == "ResNet50":
        backbone = resnet50(weights=None)
        # Drop the last child module so the network outputs features instead of class scores
        model = nn.Sequential(*list(backbone.children())[:-1])
        model = Wrapper(model=model)
    elif model_architecture in ("swin_t", "vit_b16"):
        raise NotImplementedError(
            f"Model architecture not implemented yet: {model_architecture}"
        )
    elif model_architecture == "ViT_T16":
        if "vit_t16_v1" in checkpoint_path:
            model = Embedder.load_pretrained("imagenet_vit_tiny")
        elif "vit_t16_v2" in checkpoint_path:
            model = Embedder.load_pretrained("vit_tiny_random")
        elif "vit_t16_v3" in checkpoint_path:
            # NOTE: VisionTransformer from timm needs to be wrapped to get intermediate results
            model = timm.create_model("vit_tiny_patch16_224", pretrained=False)
            model = ViTWrapper(model)
        else:
            # Without this guard the head replacement below would fail on None
            raise ValueError(
                f"Cannot determine ViT_T16 variant from checkpoint path: {checkpoint_path}"
            )
        # Replace the head with an empty module (identity)
        model.head = nn.Sequential()
    else:
        print(f"Unknown model architecture: {model_architecture}")
        return None

    print(f"Loading {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint, strict=True)

    # Freeze all weights; the model is used for feature extraction only
    for param in model.parameters():
        param.requires_grad = False

    return model

In [43]:
def load_dataloader(data_dir):
    """Create a sequential DataLoader over the images listed in a csv file.

    Args:
        data_dir: Path of the csv file that is read with ``pd.read_csv``. The
            columns "filepath", "target_code" and "set" are handed to the
            dataset.

    Returns:
        A DataLoader that yields batches of ``batch_size`` items in file
        order (no shuffling, loading happens in the main process).
    """
    # NOTE: ResNet50_Weights.IMAGENET1K_V1 also uses these resize and crop values
    preprocessing_steps = [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(normalise_mean, normalise_std),
    ]

    # NOTE: DataframeImageDataset uses pil_loader as default, which executes Image.convert("RGB") implicitly
    dataset = DataframeImageDataset(
        pd.read_csv(data_dir),
        filepath_column="filepath",
        label_columns=["target_code", "set"],
        transform=transforms.Compose(preprocessing_steps),
    )

    return data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )


# type(df["filepath"].iloc[0]) # str
# type(df["filepath"].iloc[[0]]) # pandas.core.series.Series
# type(df[["filepath"]].iloc[0]) # pandas.core.series.Series
# type(df[["filepath", "filepath"]].iloc[0]) # pandas.core.series.Series
# type(df[["filepath", "filepath"]].iloc[[0]]) # pandas.core.frame.DataFrame
# type(df[label_columns].iloc[0].values) # numpy.ndarray

# dl_demo = load_dataloader("../datasets/demo/split.csv")
# for i, (images, target_codes, sets) in enumerate(dl_demo):
#     print(f"{i}: {len(images)}, {len(target_codes)}, {len(sets)}")

In [44]:
def calculate_features(model, dl_full):
    """Run the model over every batch and collect the flattened outputs.

    Args:
        model: Module that maps an image batch to an output tensor.
        dl_full: Iterable yielding ``(images, target_codes, sets)`` batches.

    Returns:
        A DataFrame with one row per sample. The feature columns are numbered
        from 0; the columns "target_code" and "set" are appended at the end.

    Raises:
        ValueError: If ``dl_full`` yields no batches.
    """
    model = model.to(device)
    model.eval()
    features_list = []
    target_code_list = []
    set_list = []

    # Inference only: without autograd tracking the outputs can always be converted
    # to numpy, even if the model still has trainable parameters.
    with torch.no_grad():
        for images, target_codes, sets in dl_full:
            target_code_list.append(target_codes)
            set_list.append(sets)
            images = images.to(device)

            # NOTE: VitWrapper automatically returns results of last 4 blocks
            outputs = model(images)
            numpy_outputs = outputs.cpu().numpy()

            # Flatten everything after the batch axis so that each sample becomes one row
            features_list.append(numpy_outputs.reshape(numpy_outputs.shape[0], -1))

    print(f"Number of batches: {len(features_list)}")
    if not features_list:
        raise ValueError("Cannot calculate features: the dataloader yielded no batches")

    np_features = np.concatenate(features_list)
    df_full = pd.DataFrame(np_features)
    df_full["target_code"] = pd.Series(np.concatenate(target_code_list))
    df_full["set"] = pd.Series(np.concatenate(set_list))
    return df_full

In [45]:
def save_dataframe(df_features, csv_path):
    """Write the feature table to a csv file with the last two columns moved to the front.

    Args:
        df_features: DataFrame whose last two columns are placed before all
            other columns in the written file. The passed frame is not modified.
        csv_path: Destination path of the csv file. A missing parent directory
            is created so that already computed features are not lost.
    """
    cols = df_features.columns.tolist()
    cols = cols[-2:] + cols[:-2]

    parent_dir = os.path.dirname(csv_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    df_features[cols].to_csv(csv_path)
    # Report the path that was actually written, not a variable of the calling cell
    print(f"Csv file saved: {csv_path}")

In [46]:
# Process every task of the configuration file: one row = one (dataset, model weights) pair.
df_config = pd.read_csv(configuration_csv_path)
counter = 0  # number of consecutive tasks skipped because their csv file already exists

for index, row in df_config.iterrows():
    dataset_path = str(row["dataset_path"])
    architecture = str(row["architecture"])
    weigths_path = str(row["weigths_path"])  # the column name is spelled like this in the csv
    feature_path = str(row["feature_path"])

    # Sanity check: the target file must be named "<dataset>-<architecture>-<weights>.csv"
    # and each part must agree with the other columns of the row.
    target_filename = os.path.basename(feature_path)
    name_parts = target_filename.split(".")
    assert len(name_parts) == 2
    assert name_parts[1] == "csv"
    name_parts = name_parts[0].split("-")
    assert len(name_parts) == 3
    assert name_parts[0].lower() in dataset_path.replace("-", "_").lower()
    assert name_parts[1] == architecture
    assert name_parts[2] in weigths_path

    if os.path.exists(feature_path):
        print(f"Csv file already exists: {feature_path}")
        counter += 1
        continue

    if 0 < counter:
        print(f"Skipped {counter} existing files")
        counter = 0

    model = load_model(architecture, weigths_path)
    if model is None:
        # load_model has already reported the unknown architecture; go on with the next task
        continue
    dataloader = load_dataloader(dataset_path)
    df_features = calculate_features(model, dataloader)
    save_dataframe(df_features, feature_path)

# Report a trailing run of skipped files that no later task flushed
if 0 < counter:
    print(f"Skipped {counter} existing files")
    counter = 0

Csv file already exists: ../datasets/intermediate-features/PAD_UFES_20-ResNet50-Derma_SSL_SimCLR.csv
Csv file already exists: ../datasets/intermediate-features/PAD_UFES_20-ResNet50-ImageNet_1k_SL_V1.csv
Csv file already exists: ../datasets/intermediate-features/PAD_UFES_20-ResNet50-ImageNet_1k_SSL_SimCLR.csv
Csv file already exists: ../datasets/intermediate-features/PAD_UFES_20-ResNet50-PDDD.csv
Csv file already exists: ../datasets/intermediate-features/PAD_UFES_20-ResNet50-Random.csv
Csv file already exists: ../datasets/intermediate-features/DDI-ResNet50-Derma_SSL_SimCLR.csv
Csv file already exists: ../datasets/intermediate-features/DDI-ResNet50-ImageNet_1k_SL_V1.csv
Csv file already exists: ../datasets/intermediate-features/DDI-ResNet50-ImageNet_1k_SSL_SimCLR.csv
Csv file already exists: ../datasets/intermediate-features/DDI-ResNet50-PDDD.csv
Csv file already exists: ../datasets/intermediate-features/DDI-ResNet50-Random.csv
Csv file already exists: ../datasets/intermediate-features/H