In [10]:
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as data
from sklearn.model_selection import train_test_split
from torchvision import datasets, transforms
from torchvision.models import resnet50

In [11]:
# Reproducibility and preprocessing settings shared by the cells below.
seed, img_size, batch_size = 19, 224, 16  # img_size: side of the square center crop

# Two-stage stratified split:
#   stage 1 holds out (validation + test) from the full set,
#   stage 2 divides that held-out part between validation and test.
split_ratio_1 = 0.2  # (validation+test) / (train+validation+test)
split_ratio_2 = 0.5  # (test) / (validation+test)

# CSV listing, per row, the backbone / checkpoint / dataset / output-path combination.
configuration_csv_path = "configs/plant-tasks-configuration.csv"

In [12]:
torch.manual_seed(seed);  # seed torch RNG; hub models are built with pretrained=False and loaded with strict=False, so any parameter not covered by a checkpoint keeps this random init. Trailing ';' hides the Generator repr.

<torch._C.Generator at 0x1e1c8c6bfb0>

In [13]:
# This notebook is configured to run on a GPU; stop immediately if none is visible.
assert torch.cuda.is_available()
device = torch.device("cuda")

n_devices = torch.cuda.device_count()
for i in range(n_devices):
    print(torch.cuda.get_device_name(i))

# Recorded wall-clock durations for the cassava mini dataset
# (conditions of the two GPU runs were not recorded):
#   CPU:        3m 39.1s
#   GPU run 1:  0m 58.2s
#   GPU run 2:  0m 30.4s

NVIDIA GeForce GTX 960


In [14]:
def load_model(model_architecture, num_classes_weights, checkpoint_path):
    """Build a backbone, load its checkpoint and return it as a frozen feature extractor.

    Args:
        model_architecture: One of ``"resnet50"``, ``"swin_t"`` or ``"vit_b16"``.
        num_classes_weights: Size of the classification head stored in the
            checkpoint. It is used for ``"resnet50"`` and ``"swin_t"`` so the head
            shape matches the checkpoint before loading. For ``"resnet50"``, 0 means
            the checkpoint contains no ``fc`` weights. It is ignored for ``"vit_b16"``.
        checkpoint_path: Path to a state dict readable by ``torch.load``.

    Returns:
        The model with all parameters frozen, in eval mode. For ``"resnet50"`` and
        ``"swin_t"`` the classification head is replaced by an empty
        ``nn.Sequential()``, so the model outputs the pooled backbone features.

    Raises:
        ValueError: If ``model_architecture`` is not supported.
    """
    if model_architecture not in ("resnet50", "swin_t", "vit_b16"):
        # Fail fast with a clear message instead of crashing later on `None.parameters()`.
        raise ValueError(f"Unknown model architecture: {model_architecture}")

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
    # Only load trusted checkpoints (consider weights_only=True if the torch version allows it).
    def _read_checkpoint():
        return torch.load(checkpoint_path, map_location=torch.device("cpu"))

    def _warn_on_key_mismatch(result):
        # strict=False silently skips mismatched keys. Report them so that a checkpoint
        # whose names do not match the model (leaving random weights) gets noticed.
        if result.missing_keys or result.unexpected_keys:
            print(
                f"Warning: loading {checkpoint_path} left "
                f"{len(result.missing_keys)} missing and "
                f"{len(result.unexpected_keys)} unexpected keys"
            )

    if model_architecture == "resnet50":
        model = resnet50(weights=None)
        # The head must have the checkpoint's shape before loading.
        # int() normalises numpy / float scalars read from the CSV (e.g. 1000.0).
        n_classes = int(num_classes_weights)
        if n_classes == 0:
            model.fc = nn.Sequential()
        else:
            model.fc = nn.Linear(model.fc.in_features, n_classes)
        model.load_state_dict(_read_checkpoint())
        model.fc = nn.Sequential()
    elif model_architecture == "swin_t":
        model = torch.hub.load(
            "SharanSMenon/swin-transformer-hub:main",
            "swin_tiny_patch4_window7_224",
            pretrained=False,
        )
        model.head = nn.Linear(model.head.in_features, int(num_classes_weights))
        _warn_on_key_mismatch(model.load_state_dict(_read_checkpoint(), strict=False))
        model.head = nn.Sequential()
    else:  # "vit_b16"
        model = torch.hub.load(
            "facebookresearch/dino:main",
            "dino_vitb16",
            pretrained=False,
        )
        _warn_on_key_mismatch(model.load_state_dict(_read_checkpoint(), strict=False))

    for param in model.parameters():
        param.requires_grad = False
    # Inference only: use stored BatchNorm statistics and disable dropout.
    model.eval()
    return model

In [15]:
def load_dataloader(data_dir):
    """Return a deterministic-order DataLoader over the ImageFolder tree at ``data_dir``.

    Each image is resized so its shorter side is ``img_size``, center-cropped to
    ``img_size`` x ``img_size`` and converted to a tensor. Batches hold
    ``batch_size`` images and are produced in dataset order (no shuffling).
    """
    # NOTE(review): no transforms.Normalize(...) is applied. Confirm this matches how
    # each checkpoint was trained (ImageNet weights usually expect mean/std normalisation).
    preprocessing = transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ])
    image_dataset = datasets.ImageFolder(data_dir, transform=preprocessing)
    return data.DataLoader(
        image_dataset,
        batch_size=batch_size,
        shuffle=False,  # keep dataset order so extracted rows follow the sample order
        num_workers=0,
    )

In [16]:
def calculate_features(model, dl_full, target_device=None):
    """Run ``model`` over every batch of ``dl_full`` and collect outputs and labels.

    Args:
        model: torch module mapping an image batch to a feature tensor. Its output
            must be 2-D ``(batch, n_features)`` for the DataFrame construction below.
        dl_full: Iterable of ``(images, targets)`` tensor pairs, e.g. a DataLoader.
        target_device: Device to run on. Defaults to the notebook-level ``device``.

    Returns:
        DataFrame with one column per feature (integer names ``0..n_features-1``)
        followed by a ``"target"`` column, with one row per sample in loader order.

    Raises:
        ValueError: If ``dl_full`` yields no batches.
    """
    run_device = device if target_device is None else target_device
    model = model.to(run_device)
    # Without eval(), BatchNorm normalises with per-batch statistics (and overwrites
    # its running stats) and dropout stays active. The features would then depend on
    # batch composition.
    model.eval()
    target_list = []
    features_list = []

    # No autograd graph is needed for feature extraction.
    with torch.no_grad():
        for images, targets in dl_full:
            outputs = model(images.to(run_device))
            target_list.append(targets.numpy())
            features_list.append(outputs.cpu().numpy())

    print(f"Number of batches: {len(target_list)}")
    if not target_list:
        raise ValueError("The dataloader yielded no batches; cannot build a feature table.")

    np_target = np.concatenate(target_list)
    np_features = np.concatenate(features_list)
    df_full = pd.DataFrame(np_features)
    df_full["target"] = pd.Series(np_target)
    return df_full

In [17]:
def split_and_save_dataframe(df_full, csv_path):
    """Stratified train/valid/test split of ``df_full``, written to ``csv_path``.

    The first split holds out a ``split_ratio_1`` fraction of rows. The second
    divides that held-out part into valid/test using ``split_ratio_2``. Both splits
    are stratified on ``"target"`` and seeded with ``seed``.

    Args:
        df_full: Feature table containing a ``"target"`` column.
        csv_path: Destination path for the CSV.

    Returns:
        The saved DataFrame: rows in original index order, with ``"target"`` and
        ``"set"`` (one of ``"train"``/``"valid"``/``"test"``) first, then the
        remaining columns.
    """
    df_train, df_valid_test = train_test_split(
        df_full,
        test_size=split_ratio_1,
        stratify=df_full["target"],
        random_state=seed,
    )
    df_valid, df_test = train_test_split(
        df_valid_test,
        test_size=split_ratio_2,
        stratify=df_valid_test["target"],
        random_state=seed,
    )

    # assign() returns new frames, avoiding SettingWithCopyWarning on the split outputs.
    df_split = pd.concat(
        [
            df_train.assign(set="train"),
            df_valid.assign(set="valid"),
            df_test.assign(set="test"),
        ]
    ).sort_index()

    # Put the metadata columns first, addressed by name rather than by position.
    meta_cols = ["target", "set"]
    ordered_cols = meta_cols + [c for c in df_split.columns if c not in meta_cols]
    df_split = df_split[ordered_cols]
    df_split.to_csv(csv_path)
    # Report the path actually written (previously a global `feature_path` was printed).
    print(f"Csv file saved: {csv_path}")
    return df_split

In [19]:
df_config = pd.read_csv(configuration_csv_path, index_col=0)

# One configuration row = one (backbone, checkpoint, dataset) combination.
# Features are computed only when the target CSV is not on disk yet.
# Column keys are spelled exactly as in the CSV header ("weigths_...").
for index, row in df_config.iterrows():
    (
        architecture,
        weigths_path,
        weigths_num_classes,
        feature_path,
        dataset_path,
    ) = (
        row["architecture"],
        row["weigths_path"],
        row["weigths_num_classes"],
        row["feature_path"],
        row["dataset_path"],
    )
    # Sanity check: the output file name must mention the backbone it was built from.
    assert architecture in feature_path.lower()

    if os.path.exists(feature_path):
        print(f"Csv file already exists: {feature_path}")
        continue

    model = load_model(architecture, weigths_num_classes, weigths_path)
    dataloader = load_dataloader(dataset_path)
    df_features = calculate_features(model, dataloader)
    split_and_save_dataframe(df_features, feature_path)

Csv file already exists: ../datasets/intermediate-features/Cassava_Mini-ResNet50-Random_19.csv
Csv file already exists: ../datasets/intermediate-features/Cassava_Mini-ResNet50-Random_20.csv
Csv file already exists: ../datasets/intermediate-features/Cassava_Mini-ResNet50-Random_21.csv
Csv file already exists: ../datasets/intermediate-features/Cassava_Mini-ResNet50-ImageNet_v1.csv
Csv file already exists: ../datasets/intermediate-features/Cassava_Mini-ResNet50-ImageNet_v2.csv
Csv file already exists: ../datasets/intermediate-features/Cassava_Mini-ResNet50-ImageNet_SSL.csv
Csv file already exists: ../datasets/intermediate-features/Cassava_Mini-ResNet50-PDDD.csv
Csv file already exists: ../datasets/intermediate-features/Cassava_Mini-ViT_B16-Random_19.csv
Csv file already exists: ../datasets/intermediate-features/Cassava_Mini-ViT_B16-Random_20.csv
Csv file already exists: ../datasets/intermediate-features/Cassava_Mini-ViT_B16-Random_21.csv
Csv file already exists: ../datasets/intermediate-f