In [126]:
import json

# import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd

# import re
import torch
import torch.nn as nn
import torchinfo
from tqdm import tqdm

# from IPython.display import display_html
# from scipy.special import softmax
# from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# from torch.optim.lr_scheduler import LambdaLR
# from torchvision import transforms
from timm.models.vision_transformer import VisionTransformer

from local_python.local_utils import (
    load_model,
    load_pd_from_json,
    print_parameters,
    set_seed,
)
from local_python.dataset_util import (
    create_dataloaders,
)

# from local_python.feature_evaluation import calculate_scores
from lora_vit.models import LoRA_ViT_timm
from ssl_library.src.models.fine_tuning.classifiers import LinearClassifier

In [127]:
configuration_csv_path = "./configs/finetune-configuration.csv"

batch_size = 64
learning_rate = 1e-4

In [128]:
df_config = pd.read_csv(configuration_csv_path)
df_config

Unnamed: 0,seed,strategy,dataset_path,checkpoint_path
0,1,lora_2,../data_splits/HAM10000_split.csv,../model_weights/vit_t16_v2/ViT_T16-ImageNet_1...
1,1,lora_4,../data_splits/HAM10000_split.csv,../model_weights/vit_t16_v2/ViT_T16-ImageNet_1...
2,1,lora_6,../data_splits/HAM10000_split.csv,../model_weights/vit_t16_v2/ViT_T16-ImageNet_1...


In [129]:
def prepare_model(checkpoint_path, strategy, num_classes, image_shape):
    strategy_params = strategy.split("_")

    if "concat" == strategy_params[0]:
        rank = int(strategy_params[1])
        model = load_model(checkpoint_path, freeze=False, use_ssl_library=True)
        params = list(model.parameters())
        for param in params[: len(params) - rank]:
            param.requires_grad = False
        summary = torchinfo.summary(model, image_shape, batch_dim=0)
        last_output = summary.summary_list[-1].output_size[-1]
        model.head = LinearClassifier(
            last_output,
            num_labels=num_classes,
            use_dropout_in_head=True,
            large_head=False,
            use_bn=True,
        )
    elif "lora" == strategy_params[0]:
        rank = int(strategy_params[1])
        model = load_model(checkpoint_path, freeze=True, use_ssl_library=False)
        summary = torchinfo.summary(model, image_shape, batch_dim=0)
        last_output = summary.summary_list[-1].output_size[-1]
        assert hasattr(model, "blocks"), f"Unknown model type: {type(model)}"
        model.head = LinearClassifier(
            last_output,
            num_labels=num_classes,
            use_dropout_in_head=True,
            large_head=False,
            use_bn=True,
        )
        model = LoRA_ViT_timm(vit_model=model, r=rank, alpha=4)
    else:
        assert False, f"Unknown strategy: {strategy}"
    return model


def train_eval(
    model, optimizer, criterion, start_epoch, end_epoch, dataloaders, loss_file_path, best_loss = None
):
    model = model.to(device)
    for epoch in range(start_epoch, end_epoch):
        model.train()
        print(f"Training epoch {epoch}")
        with open(loss_file_path, "a") as detaillog:
            for i, (images, targets) in enumerate(tqdm(dataloaders["train"])):
                images = images.to(device)
                targets = torch.as_tensor(targets).to(device)

                outputs = model(images)
                loss = criterion(outputs, targets)

                line = {}
                line["epoch"] = epoch
                line["iteration"] = i
                line["loss"] = loss.item()
                line["set"] = "train"
                json.dump(line, detaillog, indent=2)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        torch.save(
            model.state_dict(),
            os.path.join(run_path, f"checkpoint_latest.pth"),
        )

        model.eval()
        valid_loss = []
        with torch.no_grad():
            with open(loss_file_path, "a") as detaillog:
                for i, (images, targets) in enumerate(tqdm(dataloaders["valid"])):
                    images = images.to(device)
                    targets = torch.as_tensor(targets).to(device)

                    outputs = model(images)
                    loss = criterion(outputs, targets)
                    valid_loss.append(loss.item())

                    line = {}
                    line["epoch"] = epoch
                    line["iteration"] = i
                    line["loss"] = valid_loss[-1]
                    line["set"] = "valid"
                    json.dump(line, detaillog, indent=2)
        
        mean_loss = np.array(valid_loss).mean()
        if best_loss is None or best_loss < mean_loss:
            best_loss = mean_loss
            torch.save(
                model.state_dict(),
                os.path.join(run_path, f"checkpoint_best.pth"),
            )

In [130]:
assert torch.cuda.is_available()
n_devices = torch.cuda.device_count()
for i in range(0, n_devices):
    print(torch.cuda.get_device_name(i))

device = torch.device("cuda")

NVIDIA GeForce GTX 960


In [131]:
for _, row in df_config.iterrows():
    seed = row["seed"]
    set_seed(seed)

    checkpoint_path = row["checkpoint_path"]
    model_name = os.path.splitext(os.path.basename(checkpoint_path))[0].replace(
        "_headless", ""
    )

    dataset_path = row["dataset_path"]
    dataset_name = os.path.splitext(os.path.basename(dataset_path))[0].replace(
        "_split", ""
    )

    strategy = row["strategy"]

    run_path = os.path.join("../runs/", dataset_name, model_name, f"{strategy}_{seed}")
    if not os.path.exists(run_path):
        os.makedirs(run_path)
    print(f"Results will be saved to {run_path}")
    loss_file_path = os.path.join(run_path, "loss.txt")

    dataloaders = create_dataloaders(dataset_path, batch_size=batch_size)
    train_class_counts = dataloaders["train"].dataset.get_class_counts()
    print(f"Train class (im)balance: {train_class_counts}")
    num_classes = len(train_class_counts)

    images, _ = next(iter(dataloaders["valid"]))
    image_shape = images.shape[1:]

    model = prepare_model(checkpoint_path, strategy, num_classes, image_shape)
    print_parameters(model)

    class_weights_tensor = torch.tensor(
        1.0 / np.array(list(train_class_counts.values())), dtype=torch.float
    )
    class_weights_tensor = class_weights_tensor.to(device)
    loss_function = nn.CrossEntropyLoss(weight=class_weights_tensor, reduction="mean")
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    latest_epoch = -1
    best_loss = None
    checkpoint_path = os.path.join(run_path, f"checkpoint_latest.pth")
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        model.load_state_dict(checkpoint, strict=True)
        df_metrics = load_pd_from_json(loss_file_path)
        latest_epoch = df_metrics["epoch"].max()
        print(f"Latest epoch: {latest_epoch}")
        checkpoint_path = os.path.join(run_path, f"checkpoint_best.pth")
        if os.path.exists(checkpoint_path):
            df_losses = df_metrics[df_metrics["set"] == "valid"].groupby(["epoch"])["loss"].mean()
            best_epoch = df_losses.argmin()
            best_loss = df_losses.min()
            print(f"Best epoch: {best_epoch} with {best_loss}")

    start_epoch = latest_epoch + 1
    end_epoch = start_epoch + 30

    train_eval(
        model,
        optimizer,
        loss_function,
        start_epoch,
        end_epoch,
        dataloaders,
        loss_file_path,
        best_loss = best_loss,
    )

Setting seed to 1
Results will be saved to ../runs/HAM10000\ViT_T16-ImageNet_1k_SSL_Dino\lora_2_1
Set train size: 8908
Set valid size: 1103
Train class (im)balance: {'bkl': 971, 'nv': 5969, 'df': 103, 'mel': 992, 'vasc': 128, 'bcc': 455, 'akiec': 290}


Loading vit_tiny_patch16_224 from timm-library
Ignoring prefix 'model.'
Trainable parameters: 18432/5544583
Training epoch 0


100%|██████████| 140/140 [12:47<00:00,  5.48s/it]
100%|██████████| 18/18 [00:15<00:00,  1.19it/s]


Training epoch 1


 86%|████████▋ | 121/140 [14:14<01:52,  5.90s/it]

In [None]:
# torchinfo.summary(model, image_shape, batch_dim=0)

In [None]:
# seed = 19

# checkpoint_path = "../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_headless.pth
# checkpoint_path = "../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_headless.pth"
# checkpoint_path = "../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_headless.pth"

# dataset_path = "../datasets/ddi-diverse-dermatology-images/split.csv"
# dataset_path = "../datasets/PAD-UFES-20/images/split.csv"
# dataset_path = "../datasets/HAM10000/images/split.csv"

In [None]:
# def eval(model, data_loader_valid, sample=False, verbose=False):
#     model = model.to(device)
#     model.eval()

#     label_tensor = torch.zeros(0)
#     pred_tensor = torch.zeros(0)

#     with torch.no_grad():
#         for images, targets in tqdm(data_loader_valid):
#             target_tensor = torch.as_tensor([label_map[target] for target in targets])
#             label_tensor = torch.cat([label_tensor, target_tensor])

#             images = images.to(device)
#             outputs = model(images)

#             _, preds = torch.max(outputs, 1)
#             pred_tensor = torch.cat([pred_tensor, preds.view(-1).cpu()])
#             if verbose:
#                 for output, target in zip(outputs, targets):
#                     probs = softmax(output.cpu().detach().numpy())
#                     max_idx = np.argmax(probs)
#                     print(
#                         f"Actual class {label_map[target]}, predicted class {max_idx} with probability {probs[max_idx]} "
#                     )

#             if sample:
#                 break

#         # target_tensor = target_tensor.to(device)

#     labels = label_tensor.numpy()
#     preds = pred_tensor.numpy()

#     label_names = [k for k, v in label_map.items() if (v in labels or v in preds)]
#     cm = confusion_matrix(labels, preds)
#     cm_display = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=label_names)
#     cm_display.plot()
#     plt.ylabel("True label")
#     plt.xlabel("Predicted label")
#     plt.show()


# def load_checkpoint_by_epoch(model, epoch):
#     checkpoint = os.path.join(run_path, f"checkpoint_{epoch}.pth")
#     print(f"Loading {checkpoint}")
#     checkpoint = torch.load(checkpoint, map_location=torch.device("cpu"))
#     model.load_state_dict(checkpoint, strict=True)
#     return model


# def display_side_by_side(dfs, captions=[]):
#     html_string = ""
#     for i, df in enumerate(dfs):
#         styler = df.style.set_table_attributes("style='display:inline'")
#         if i < len(captions):
#             styler.set_caption(captions[i])
#         html_string += styler.to_html()
#     display_html(html_string, raw=True)

In [None]:
# available_epochs = []
# for name in os.listdir(path=run_path):
#     match = re.search("checkpoint_(\d+)\.pth", name)
#     if match:
#         available_epochs.append(int(match.group(1)))
# latest_epoch = -1
# if 0 < len(available_epochs):
#     latest_epoch = max(available_epochs)
#     print(f"{len(available_epochs)} checkpoints found. Latest epoch: {latest_epoch}")
#     model = load_checkpoint_by_epoch(model, latest_epoch)
# else:
#     print(f"No checkpoint found in {run_path}")

In [None]:
# eval(model, dl_valid)

In [None]:
# df_metrics = load_pd_from_json(loss_file_path)
# metric_columns = [
#     column
#     for column in df_metrics.columns.values
#     if column not in ["epoch", "iteration", "set"]
# ]
# df_metrics = (
#     df_metrics
#     .groupby(["set", "epoch"])[metric_columns]
#     .mean()
# ).reset_index()

# n = 3
# for metric_column in metric_columns:
#     df_temp = df_metrics[df_metrics["set"] == "valid"].sort_values(metric_column)
#     display_side_by_side(
#         [df_temp.head(n=n), df_temp.tail(n=n)],
#         [
#             f"{n} epochs with lowest validation {metric_column}:",
#             f"{n} epochs with highest validation {metric_column}:",
#         ],
#     )

In [None]:
# y_column = "loss"
# x_column = "epoch"
# fig, ax = plt.subplots()
# plt.figure()
# for set_name in df_metrics["set"].unique():
#     df_metrics[df_metrics["set"] == set_name].plot.line(
#         y=y_column, x=x_column, label=set_name, ax=ax
#     )
# plt.show()

In [None]:
# best_epoch = df_metrics["loss"].idxmin()
# print(f"Best epoch by validation loss is {best_epoch}")
# model = load_checkpoint_by_epoch(model, best_epoch)
# eval(model, dl_valid)