In [1]:
import numpy as np
import sys
import os
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import datasets, transforms
from torchvision.models import resnet50
from transformers import ViTModel

sys.path.append("ssl_library")
from src.pkg.embedder import Embedder
from src.pkg.wrappers import ViTWrapper

In [2]:
seed = 19
img_size = 224
batch_size = 16
normalise_mean = (0.485, 0.456, 0.406)  # ImageNet
normalise_std = (0.229, 0.224, 0.225)  # ImageNet

configuration_csv_path = "configs/tasks-configuration-new.csv"

In [3]:
torch.manual_seed(seed)

<torch._C.Generator at 0x24845202cb0>

In [4]:
assert torch.cuda.is_available()
n_devices = torch.cuda.device_count()
for i in range(0, n_devices):
    print(torch.cuda.get_device_name(i))

device = torch.device("cuda")

NVIDIA GeForce GTX 960


In [5]:
def load_model(model_architecture, checkpoint_path):
    model = None

    if model_architecture == "resnet50":
        print(f"Implement!")
    elif model_architecture == "swin_t":
        print(f"Implement!")
    elif model_architecture == "vit_b16":
        print(f"Implement!")
    elif model_architecture == "ViT_T16":
        if "vit_t16_v1" in checkpoint_path:
            model = Embedder.load_pretrained("imagenet_vit_tiny")
        elif "vit_t16_v2" in checkpoint_path:
            model = Embedder.load_pretrained("vit_tiny_random")
        elif "vit_t16_v3" in checkpoint_path:
            # NOTE: VisionTransformer from timm neds to be wrapped to get intermediate results
            model = timm.create_model("vit_tiny_patch16_224", pretrained=False)
            model = ViTWrapper(model)
        model.head = nn.Sequential()
    else:
        print(f"Unknown model architecture: {model_architecture}")
        return None

    assert model != None
    print(f"Loading {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint, strict=True)

    for param in model.parameters():
        param.requires_grad = False

    return model

In [6]:
class ImageFolderExtented(datasets.ImageFolder):
    # NOTE: ImageFolder uses pil_loader as default, which executes Image.convert("RGB") implicitly
    def __getitem__(self, index: int):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, os.path.basename(path)


def load_dataloader(data_dir):
    transform = transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(normalise_mean, normalise_std),
        ]
    )
    ds_full = ImageFolderExtented(data_dir, transform=transform)

    dl_full = data.DataLoader(
        ds_full,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )
    return dl_full

In [7]:
def calculate_features(model, dl_full):
    model = model.to(device)
    model.eval()
    target_list = []
    features_list = []
    filename_list = []

    for i, (images, targets, filenames) in enumerate(dl_full):
        # print(f"Batch {i}")
        filename_list.extend(filenames)

        numpy_targets = targets.numpy()
        target_list.append(numpy_targets)

        outputs = numpy_targets
        images = images.to(device)

        # NOTE: VitWrapper automatically returns results of last for blocks
        outputs = model(images)

        # numpy_outputs = None
        # if "BaseModelOutputWithPooling" in str(type(outputs)):
        #     preferred_shape = [numpy_targets.shape[0], -1] # Shape must be 2D
        #     numpy_outputs = outputs.last_hidden_state.cpu().numpy().reshape(preferred_shape)
        # else:
        numpy_outputs = outputs.cpu().numpy()
        features_list.append(numpy_outputs)

    print(f"Number of batches: {len(target_list)}")
    np_features = np.concatenate(features_list)
    np_target = np.concatenate(target_list)  # .reshape((np_features.shape[0], 1))
    df_full = pd.DataFrame(np_features)
    df_full["target_num"] = pd.Series(np_target)
    df_full["filename"] = pd.Series(filename_list)
    return df_full

In [8]:
def split_and_save_dataframe(df_features, csv_path):
    df_split = pd.read_csv(os.path.join(dataset_path, "split.csv"), index_col=0)
    assert (
        df_split["target_code"].unique().size == df_features["target_num"].unique().size
    )
    assert df_split["filename"].unique().size == df_features["filename"].unique().size
    df_split["filename"] = df_split["filename"].apply(lambda x: os.path.splitext(x)[0])
    df_features["filename"] = df_features["filename"].apply(
        lambda x: os.path.splitext(x)[0]
    )
    df_merged = pd.merge(df_features, df_split, on="filename")
    groups = df_merged.groupby(["target_code", "target_num"])
    assert groups["set"].count().size == df_split["target_code"].unique().size

    df_merged.drop(columns=["filename", "target_code"], inplace=True)
    cols = df_merged.columns.tolist()
    cols = cols[-2:] + cols[:-2]
    df_merged = df_merged[cols]
    df_merged.to_csv(csv_path)
    print(f"Csv file saved: {feature_path}")

In [9]:
df_config = pd.read_csv(configuration_csv_path)
counter = 0

df_config

for index, row in df_config.iterrows():
    dataset_path = str(row["dataset_path"])
    architecture = str(row["architecture"])
    weigths_path = str(row["weigths_path"])
    feature_path = str(row["feature_path"])

    target_filename = os.path.basename(feature_path)
    name_parts = target_filename.split(".")
    assert len(name_parts) == 2
    assert name_parts[1] == "csv"
    name_parts = name_parts[0].split("-")
    assert len(name_parts) == 3
    assert name_parts[0].lower() in dataset_path.replace("-", "_").lower()
    assert name_parts[1] == architecture
    assert name_parts[2] in weigths_path

    if os.path.exists(feature_path):
        print(f"Csv file already exists: {feature_path}")
        counter += 1
    else:
        if 0 < counter:
            print(f"Skipped {counter} existing files")
            counter = 0
        model = load_model(architecture, weigths_path)
        dataloader = load_dataloader(dataset_path)
        df_features = calculate_features(model, dataloader)
        split_and_save_dataframe(df_features, feature_path)

Csv file already exists: ../datasets/intermediate-features/PAD_UFES_20-ViT_T16-Derma.csv
Csv file already exists: ../datasets/intermediate-features/PAD_UFES_20-ViT_T16-ImageNet_1k_SL_WinKawaks.csv
Csv file already exists: ../datasets/intermediate-features/PAD_UFES_20-ViT_T16-ImageNet_1k_SSL_Dino.csv
Csv file already exists: ../datasets/intermediate-features/PAD_UFES_20-ViT_T16-ImageNet_AugReg.csv
Csv file already exists: ../datasets/intermediate-features/PAD_UFES_20-ViT_T16-Plant.csv
Csv file already exists: ../datasets/intermediate-features/PAD_UFES_20-ViT_T16-Random.csv
Csv file already exists: ../datasets/intermediate-features/DDI-ViT_T16-Derma.csv
Csv file already exists: ../datasets/intermediate-features/DDI-ViT_T16-ImageNet_1k_SL_WinKawaks.csv
Csv file already exists: ../datasets/intermediate-features/DDI-ViT_T16-ImageNet_1k_SSL_Dino.csv
Csv file already exists: ../datasets/intermediate-features/DDI-ViT_T16-ImageNet_AugReg.csv
Csv file already exists: ../datasets/intermediate-fea

Loading ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_headless.pth
Number of batches: 1338
Csv file saved: ../datasets/intermediate-features/Cassava-ViT_T16-Derma.csv


Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading ../model_weights/vit_t16_v1/ViT_T16-ImageNet_1k_SL_WinKawaks_headless.pth


  context_layer = torch.nn.functional.scaled_dot_product_attention(


Number of batches: 1338
Csv file saved: ../datasets/intermediate-features/Cassava-ViT_T16-ImageNet_1k_SL_WinKawaks.csv
Loading ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_headless.pth
Number of batches: 1338
Csv file saved: ../datasets/intermediate-features/Cassava-ViT_T16-ImageNet_1k_SSL_Dino.csv
Loading ../model_weights/vit_t16_v3/ViT_T16-ImageNet_AugReg_headless.pth
Number of batches: 1338
Csv file saved: ../datasets/intermediate-features/Cassava-ViT_T16-ImageNet_AugReg.csv
Loading ../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_headless.pth
Number of batches: 1338
Csv file saved: ../datasets/intermediate-features/Cassava-ViT_T16-Plant.csv
Loading ../model_weights/vit_t16_v2/ViT_T16-Random_headless.pth
Number of batches: 1338
Csv file saved: ../datasets/intermediate-features/Cassava-ViT_T16-Random.csv
Loading ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_headless.pth
Number of batches: 3395
Csv file saved: ../datasets/intermediate-features/PlantVillage-ViT_T16-D

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading ../model_weights/vit_t16_v1/ViT_T16-ImageNet_1k_SL_WinKawaks_headless.pth
Number of batches: 3395
Csv file saved: ../datasets/intermediate-features/PlantVillage-ViT_T16-ImageNet_1k_SL_WinKawaks.csv
Loading ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_headless.pth
Number of batches: 3395
Csv file saved: ../datasets/intermediate-features/PlantVillage-ViT_T16-ImageNet_1k_SSL_Dino.csv
Loading ../model_weights/vit_t16_v3/ViT_T16-ImageNet_AugReg_headless.pth
Number of batches: 3395
Csv file saved: ../datasets/intermediate-features/PlantVillage-ViT_T16-ImageNet_AugReg.csv
Loading ../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_headless.pth
Number of batches: 3395
Csv file saved: ../datasets/intermediate-features/PlantVillage-ViT_T16-Plant.csv
Loading ../model_weights/vit_t16_v2/ViT_T16-Random_headless.pth
Number of batches: 3395
Csv file saved: ../datasets/intermediate-features/PlantVillage-ViT_T16-Random.csv
Loading ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_he

FileNotFoundError: [Errno 2] No such file or directory: '../datasets/PlantDataset/split.csv'

In [None]:
# df_features[df_features["filename"].duplicated(keep=False)]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,760,761,762,763,764,765,766,767,target_num,filename
334,1.301772,-0.293454,-2.359711,-1.382079,4.100875,-0.692655,2.21866,1.059823,-1.12755,-1.072324,...,5.492403,3.125957,0.019723,0.181436,1.861225,1.345096,-1.938499,2.245877,4,0.jpg
337,1.087374,1.962794,-1.745766,1.045243,5.741971,-0.497438,1.147758,0.12976,-0.239011,-0.562781,...,5.302965,-0.37335,2.505411,0.973601,2.960911,2.721305,-2.103797,1.077753,4,0000.jpg
583,0.058927,0.814614,-2.640271,0.698918,3.428504,1.822236,4.023545,-1.998618,-0.761803,-1.238528,...,3.983012,0.59139,2.002811,1.094554,3.116507,3.457204,0.535622,0.127342,7,2013Corn_GrayLeafSpot_0815_0003.JPG.jpg
585,-1.236956,0.618194,-2.329068,0.269733,2.379933,1.960748,1.841969,-0.499793,-0.583335,-1.098202,...,4.521903,-1.595904,3.805128,1.32774,1.866429,3.395298,2.675997,-0.03894,7,2015070295153021.jpg
612,0.054682,0.043818,-2.291487,-1.388277,4.698354,0.387614,0.548864,0.947637,-1.153519,-1.007654,...,5.55998,-1.224414,4.743848,1.843997,0.960791,1.988337,0.347505,1.559715,7,IMG_42231.jpg
619,1.304623,1.165353,-1.979944,0.460565,1.741875,1.873505,3.239161,0.603222,0.803573,-1.630752,...,2.175456,-0.408724,0.178493,1.150051,1.234284,2.39303,-2.088743,1.368775,7,corn-gray-leaf-spot-f4.jpg
669,0.282563,0.150028,-2.894969,0.519314,3.362071,1.009444,3.90726,-1.501823,-0.865108,-1.103484,...,5.152362,0.288276,2.74234,1.35631,3.423897,3.233586,0.870671,-0.276661,8,2013Corn_GrayLeafSpot_0815_0003.JPG.jpg
672,-1.236956,0.618194,-2.329068,0.269733,2.379933,1.960748,1.841969,-0.499793,-0.583335,-1.098202,...,4.521903,-1.595904,3.805128,1.32774,1.866429,3.395298,2.675997,-0.03894,8,2015070295153021.jpg
721,0.054682,0.043818,-2.291487,-1.388277,4.698354,0.387614,0.548864,0.947637,-1.153519,-1.007654,...,5.55998,-1.224414,4.743848,1.843997,0.960791,1.988337,0.347505,1.559715,8,IMG_42231.jpg
774,1.304623,1.165353,-1.979944,0.460565,1.741875,1.873505,3.239161,0.603222,0.803573,-1.630752,...,2.175456,-0.408724,0.178493,1.150051,1.234284,2.39303,-2.088743,1.368775,8,corn-gray-leaf-spot-f4.jpg
