In [29]:
import json
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torchinfo
from tqdm import tqdm
from timm.models.vision_transformer import VisionTransformer

from local_python.general_utils import (
    load_model,
    load_pd_from_json,
    load_values_from_previous_epochs,
    print_parameters,
    set_seed,
)
from local_python.dataset_util import (
    create_dataloaders,
)

from lora_vit.models import LoRA_ViT_timm
from ssl_library.src.models.fine_tuning.classifiers import LinearClassifier

In [30]:
configuration_csv_path = "./configs/finetune-configuration.csv"
learning_rate = 1e-4

In [31]:
df_config = pd.read_csv(configuration_csv_path)
df_config

Unnamed: 0,seed,strategy,dataset_path,checkpoint_path
0,1,lora_2,../data_splits/HAM10000_split.csv,../model_weights/vit_t16_v2/ViT_T16-ImageNet_1...
1,1,lora_4,../data_splits/HAM10000_split.csv,../model_weights/vit_t16_v2/ViT_T16-ImageNet_1...


In [1]:
def prepare_model(checkpoint_path, strategy, num_classes, image_shape):
    strategy_params = strategy.split("_")

    if "concat" == strategy_params[0]:
        rank = int(strategy_params[1])
        model = load_model(checkpoint_path, freeze=False, use_ssl_library=True)
        params = list(model.parameters())
        for param in params[: len(params) - rank]:
            param.requires_grad = False
        summary = torchinfo.summary(model, image_shape, batch_dim=0)
        last_output = summary.summary_list[-1].output_size[-1]
        model.head = LinearClassifier(
            last_output,
            num_labels=num_classes,
            use_dropout_in_head=True,
            large_head=False,
            use_bn=True,
        )
    elif "lora" == strategy_params[0]:
        rank = int(strategy_params[1])
        model = load_model(checkpoint_path, freeze=True, use_ssl_library=False)
        summary = torchinfo.summary(model, image_shape, batch_dim=0)
        last_output = summary.summary_list[-1].output_size[-1]
        assert hasattr(model, "blocks"), f"Unknown model type: {type(model)}"
        model.head = LinearClassifier(
            last_output,
            num_labels=num_classes,
            use_dropout_in_head=True,
            large_head=False,
            use_bn=True,
        )
        model = LoRA_ViT_timm(vit_model=model, r=rank, alpha=4)
    else:
        assert False, f"Unknown strategy: {strategy}"

    return model


def train_eval(
    model, optimizer, criterion, start_epoch, end_epoch, dataloaders, loss_file_path, best_loss = None
):
    model = model.to(device)
    for epoch in range(start_epoch, end_epoch):
        model.train()
        print(f"Training epoch {epoch}")
        with open(loss_file_path, "a") as detaillog:
            for i, (images, targets) in enumerate(tqdm(dataloaders["train"])):
                images = images.to(device)
                targets = torch.as_tensor(targets).to(device)

                outputs = model(images)
                loss = criterion(outputs, targets)

                line = {}
                line["epoch"] = epoch
                line["iteration"] = i
                line["loss"] = loss.item()
                line["set"] = "train"
                json.dump(line, detaillog, indent=2)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        torch.save(
            model.state_dict(),
            os.path.join(run_path, f"checkpoint_latest.pth"),
        )

        model.eval()
        valid_loss = []
        with torch.no_grad():
            with open(loss_file_path, "a") as detaillog:
                for i, (images, targets) in enumerate(tqdm(dataloaders["valid"])):
                    images = images.to(device)
                    targets = torch.as_tensor(targets).to(device)

                    outputs = model(images)
                    loss = criterion(outputs, targets)
                    valid_loss.append(loss.item())

                    line = {}
                    line["epoch"] = epoch
                    line["iteration"] = i
                    line["loss"] = valid_loss[-1]
                    line["set"] = "valid"
                    json.dump(line, detaillog, indent=2)
        
        mean_loss = np.array(valid_loss).mean()
        if best_loss is None or mean_loss < best_loss:
            best_loss = mean_loss
            torch.save(
                model.state_dict(),
                os.path.join(run_path, f"checkpoint_best.pth"),
            )


In [33]:
assert torch.cuda.is_available()
n_devices = torch.cuda.device_count()
for i in range(0, n_devices):
    print(torch.cuda.get_device_name(i))

device = torch.device("cuda")

NVIDIA GeForce GTX 960


In [34]:
for _, row in df_config.iterrows():
    seed = row["seed"]
    set_seed(seed)

    checkpoint_path = row["checkpoint_path"]
    model_name = os.path.splitext(os.path.basename(checkpoint_path))[0].replace(
        "_headless", ""
    )

    dataset_path = row["dataset_path"]
    dataset_name = os.path.splitext(os.path.basename(dataset_path))[0].replace(
        "_split", ""
    )

    strategy = row["strategy"]

    run_path = os.path.join("../runs/", dataset_name, model_name, f"{strategy}_{seed}")
    if not os.path.exists(run_path):
        os.makedirs(run_path)
    print(f"Results will be saved to {run_path}")
    

    batch_size = 64
    if strategy.startswith("lora"):
        batch_size = 32

    dataloaders = create_dataloaders(dataset_path, batch_size=batch_size)
    train_class_counts = dataloaders["train"].dataset.get_class_counts()
    print(f"Train class (im)balance: {train_class_counts}")
    num_classes = len(train_class_counts)

    images, _ = next(iter(dataloaders["valid"]))
    image_shape = images.shape[1:]

    model = prepare_model(checkpoint_path, strategy, num_classes, image_shape)
    print_parameters(model)

    class_weights_tensor = torch.tensor(
        1.0 / np.array(list(train_class_counts.values())), dtype=torch.float
    )
    class_weights_tensor = class_weights_tensor.to(device)
    loss_function = nn.CrossEntropyLoss(weight=class_weights_tensor, reduction="mean")
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    (loss_file_path, latest_epoch, best_loss) = load_values_from_previous_epochs(run_path)

    checkpoint_path = os.path.join(run_path, f"checkpoint_latest.pth")
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        model.load_state_dict(checkpoint, strict=True)

    checkpoint_path = os.path.join(run_path, f"checkpoint_best.pth")
    if not os.path.exists(checkpoint_path):
        best_loss = None # reset    

    start_epoch = latest_epoch + 1
    end_epoch = start_epoch + 30

    train_eval(
        model,
        optimizer,
        loss_function,
        start_epoch,
        end_epoch,
        dataloaders,
        loss_file_path,
        best_loss = best_loss,
    )

Setting seed to 1
Results will be saved to ../runs/HAM10000\ViT_T16-ImageNet_1k_SSL_Dino\lora_2_1
Set train size: 8908
Set valid size: 1103
Train class (im)balance: {'bkl': 971, 'nv': 5969, 'df': 103, 'mel': 992, 'vasc': 128, 'bcc': 455, 'akiec': 290}
Loading vit_tiny_patch16_224 from timm-library
Ignoring prefix 'model.'
Trainable parameters: 18432/5544583
Read 19171 entries from loss.txt
Latest epoch: 62
Best epoch: 61 with 1.2361385055950709
Training epoch 63


100%|██████████| 279/279 [01:59<00:00,  2.33it/s]
100%|██████████| 35/35 [00:13<00:00,  2.51it/s]


Training epoch 64


100%|██████████| 279/279 [02:00<00:00,  2.32it/s]
100%|██████████| 35/35 [00:13<00:00,  2.51it/s]


Training epoch 65


100%|██████████| 279/279 [01:59<00:00,  2.34it/s]
100%|██████████| 35/35 [00:14<00:00,  2.46it/s]


Training epoch 66


100%|██████████| 279/279 [02:01<00:00,  2.29it/s]
100%|██████████| 35/35 [00:15<00:00,  2.29it/s]


Training epoch 67


100%|██████████| 279/279 [02:02<00:00,  2.28it/s]
100%|██████████| 35/35 [00:14<00:00,  2.44it/s]


Training epoch 68


100%|██████████| 279/279 [01:59<00:00,  2.33it/s]
100%|██████████| 35/35 [00:13<00:00,  2.54it/s]


Training epoch 69


100%|██████████| 279/279 [02:00<00:00,  2.31it/s]
100%|██████████| 35/35 [00:15<00:00,  2.27it/s]


Training epoch 70


100%|██████████| 279/279 [02:03<00:00,  2.27it/s]
100%|██████████| 35/35 [00:14<00:00,  2.38it/s]


Training epoch 71


100%|██████████| 279/279 [01:56<00:00,  2.39it/s]
100%|██████████| 35/35 [00:13<00:00,  2.54it/s]


Training epoch 72


100%|██████████| 279/279 [01:58<00:00,  2.36it/s]
100%|██████████| 35/35 [00:14<00:00,  2.49it/s]


Training epoch 73


100%|██████████| 279/279 [01:59<00:00,  2.33it/s]
100%|██████████| 35/35 [00:16<00:00,  2.17it/s]


Training epoch 74


100%|██████████| 279/279 [02:06<00:00,  2.20it/s]
100%|██████████| 35/35 [00:14<00:00,  2.44it/s]


Training epoch 75


100%|██████████| 279/279 [01:56<00:00,  2.40it/s]
100%|██████████| 35/35 [00:13<00:00,  2.51it/s]


Training epoch 76


100%|██████████| 279/279 [01:55<00:00,  2.41it/s]
100%|██████████| 35/35 [00:18<00:00,  1.92it/s]


Training epoch 77


100%|██████████| 279/279 [02:03<00:00,  2.25it/s]
100%|██████████| 35/35 [00:13<00:00,  2.60it/s]


Training epoch 78


100%|██████████| 279/279 [02:06<00:00,  2.21it/s]
100%|██████████| 35/35 [00:15<00:00,  2.22it/s]


Training epoch 79


100%|██████████| 279/279 [02:03<00:00,  2.26it/s]
100%|██████████| 35/35 [00:14<00:00,  2.35it/s]


Training epoch 80


100%|██████████| 279/279 [02:13<00:00,  2.09it/s]
100%|██████████| 35/35 [00:15<00:00,  2.32it/s]


Training epoch 81


100%|██████████| 279/279 [02:09<00:00,  2.16it/s]
100%|██████████| 35/35 [00:17<00:00,  2.02it/s]


Training epoch 82


100%|██████████| 279/279 [02:04<00:00,  2.25it/s]
100%|██████████| 35/35 [00:15<00:00,  2.32it/s]


Training epoch 83


100%|██████████| 279/279 [02:02<00:00,  2.28it/s]
100%|██████████| 35/35 [00:13<00:00,  2.55it/s]


Training epoch 84


100%|██████████| 279/279 [01:56<00:00,  2.40it/s]
100%|██████████| 35/35 [00:16<00:00,  2.07it/s]


Training epoch 85


100%|██████████| 279/279 [01:58<00:00,  2.35it/s]
100%|██████████| 35/35 [00:14<00:00,  2.44it/s]


Training epoch 86


100%|██████████| 279/279 [02:06<00:00,  2.21it/s]
100%|██████████| 35/35 [00:14<00:00,  2.35it/s]


Training epoch 87


100%|██████████| 279/279 [02:03<00:00,  2.26it/s]
100%|██████████| 35/35 [00:14<00:00,  2.35it/s]


Training epoch 88


100%|██████████| 279/279 [02:05<00:00,  2.23it/s]
100%|██████████| 35/35 [00:13<00:00,  2.51it/s]


Training epoch 89


100%|██████████| 279/279 [02:03<00:00,  2.26it/s]
100%|██████████| 35/35 [00:13<00:00,  2.52it/s]


Training epoch 90


100%|██████████| 279/279 [01:55<00:00,  2.42it/s]
100%|██████████| 35/35 [00:14<00:00,  2.47it/s]


Training epoch 91


100%|██████████| 279/279 [01:55<00:00,  2.43it/s]
100%|██████████| 35/35 [00:13<00:00,  2.53it/s]


Training epoch 92


100%|██████████| 279/279 [01:55<00:00,  2.41it/s]
100%|██████████| 35/35 [00:13<00:00,  2.54it/s]


Setting seed to 1
Results will be saved to ../runs/HAM10000\ViT_T16-ImageNet_1k_SSL_Dino\lora_4_1
Set train size: 8908
Set valid size: 1103
Train class (im)balance: {'bkl': 971, 'nv': 5969, 'df': 103, 'mel': 992, 'vasc': 128, 'bcc': 455, 'akiec': 290}
Loading vit_tiny_patch16_224 from timm-library
Ignoring prefix 'model.'
Trainable parameters: 36864/5563015
Read 18840 entries from loss.txt
Latest epoch: 59
Best epoch: 57 with 1.1211333206721714
Training epoch 60


100%|██████████| 279/279 [01:55<00:00,  2.42it/s]
100%|██████████| 35/35 [00:14<00:00,  2.38it/s]


Training epoch 61


100%|██████████| 279/279 [01:55<00:00,  2.42it/s]
100%|██████████| 35/35 [00:14<00:00,  2.43it/s]


Training epoch 62


100%|██████████| 279/279 [01:57<00:00,  2.38it/s]
100%|██████████| 35/35 [00:14<00:00,  2.47it/s]


Training epoch 63


100%|██████████| 279/279 [01:58<00:00,  2.36it/s]
100%|██████████| 35/35 [00:13<00:00,  2.59it/s]


Training epoch 64


100%|██████████| 279/279 [02:02<00:00,  2.28it/s]
100%|██████████| 35/35 [00:18<00:00,  1.92it/s]


Training epoch 65


100%|██████████| 279/279 [01:55<00:00,  2.41it/s]
100%|██████████| 35/35 [00:13<00:00,  2.51it/s]


Training epoch 66


100%|██████████| 279/279 [01:55<00:00,  2.42it/s]
100%|██████████| 35/35 [00:13<00:00,  2.58it/s]


Training epoch 67


100%|██████████| 279/279 [01:55<00:00,  2.41it/s]
100%|██████████| 35/35 [00:13<00:00,  2.56it/s]


Training epoch 68


100%|██████████| 279/279 [01:56<00:00,  2.40it/s]
100%|██████████| 35/35 [00:13<00:00,  2.54it/s]


Training epoch 69


100%|██████████| 279/279 [01:54<00:00,  2.43it/s]
100%|██████████| 35/35 [00:13<00:00,  2.51it/s]


Training epoch 70


100%|██████████| 279/279 [01:54<00:00,  2.43it/s]
100%|██████████| 35/35 [00:13<00:00,  2.56it/s]


Training epoch 71


100%|██████████| 279/279 [01:53<00:00,  2.47it/s]
100%|██████████| 35/35 [00:13<00:00,  2.61it/s]


Training epoch 72


100%|██████████| 279/279 [02:02<00:00,  2.28it/s]
100%|██████████| 35/35 [00:13<00:00,  2.52it/s]


Training epoch 73


100%|██████████| 279/279 [01:55<00:00,  2.42it/s]
100%|██████████| 35/35 [00:13<00:00,  2.51it/s]


Training epoch 74


100%|██████████| 279/279 [01:57<00:00,  2.37it/s]
100%|██████████| 35/35 [00:14<00:00,  2.47it/s]


Training epoch 75


100%|██████████| 279/279 [02:00<00:00,  2.31it/s]
100%|██████████| 35/35 [00:14<00:00,  2.45it/s]


Training epoch 76


100%|██████████| 279/279 [01:55<00:00,  2.42it/s]
100%|██████████| 35/35 [00:13<00:00,  2.63it/s]


Training epoch 77


100%|██████████| 279/279 [01:57<00:00,  2.37it/s]
100%|██████████| 35/35 [00:14<00:00,  2.47it/s]


Training epoch 78


100%|██████████| 279/279 [01:56<00:00,  2.40it/s]
100%|██████████| 35/35 [00:13<00:00,  2.53it/s]


Training epoch 79


100%|██████████| 279/279 [01:55<00:00,  2.42it/s]
100%|██████████| 35/35 [00:13<00:00,  2.57it/s]


Training epoch 80


100%|██████████| 279/279 [01:56<00:00,  2.39it/s]
100%|██████████| 35/35 [00:13<00:00,  2.51it/s]


Training epoch 81


100%|██████████| 279/279 [01:55<00:00,  2.42it/s]
100%|██████████| 35/35 [00:13<00:00,  2.61it/s]


Training epoch 82


100%|██████████| 279/279 [01:53<00:00,  2.47it/s]
100%|██████████| 35/35 [00:15<00:00,  2.32it/s]


Training epoch 83


100%|██████████| 279/279 [01:54<00:00,  2.43it/s]
100%|██████████| 35/35 [00:14<00:00,  2.49it/s]


Training epoch 84


100%|██████████| 279/279 [01:54<00:00,  2.45it/s]
100%|██████████| 35/35 [00:13<00:00,  2.59it/s]


Training epoch 85


100%|██████████| 279/279 [01:53<00:00,  2.46it/s]
100%|██████████| 35/35 [00:13<00:00,  2.52it/s]


Training epoch 86


100%|██████████| 279/279 [01:56<00:00,  2.40it/s]
100%|██████████| 35/35 [00:15<00:00,  2.32it/s]


Training epoch 87


100%|██████████| 279/279 [01:56<00:00,  2.39it/s]
100%|██████████| 35/35 [00:13<00:00,  2.57it/s]


Training epoch 88


100%|██████████| 279/279 [01:55<00:00,  2.41it/s]
100%|██████████| 35/35 [00:13<00:00,  2.53it/s]


Training epoch 89


100%|██████████| 279/279 [01:55<00:00,  2.42it/s]
100%|██████████| 35/35 [00:13<00:00,  2.58it/s]
