# In [None]:
import copy
import pickle
from pathlib import Path

import timm
import torch
import torch.nn as nn
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

from main import Loader, Trainer, Tester, FeatureExtractor, RGenerator

# Catalogue of candidate timm backbones. It is kept under its own name instead
# of being silently overwritten, so a full sweep can be re-enabled with
# `timm_models = all_timm_models`.
all_timm_models = [
    "resnet18.a1_in1k",
    "resnet152.a1_in1k",
    "densenet201.tv_in1k",
    "resnext50_32x4d.ra_in1k",
    "efficientnet_b0.ra_in1k",
    "vit_small_patch16_224.dino",
    "vit_base_patch16_224.dino",
    "vit_base_patch32_224.sam_in1k",
    "vit_base_patch16_224.sam_in1k",
]

# Backbone(s) actually evaluated in this run.
timm_models = ["vit_base_patch14_dinov2.lvd142m"]

# Side length of the square input images fed to the backbone.
# 518 = 37 * 14, i.e. a whole number of 14-pixel patches. This presumably
# matches the DINOv2 patch-14 checkpoint. The *_224 ViTs in all_timm_models
# likely need 224 instead (TODO confirm before switching models).
sample_size = 518
root_dir = "/home/stagiaire/D/R/patchs/70"  # dataset root handed to Loader
cross_validations = 5  # number of repeated k-fold runs
num_folds = 5          # k in k-fold cross-validation
batch_size = 16

# Per-backbone metrics are pickled here. Create the directory before any
# training so a bad or missing path surfaces immediately, not after the run.
metrics_dir = Path("/home/stagiaire/D/B/metrics")
metrics_dir.mkdir(parents=True, exist_ok=True)

for model_name in timm_models:
    # Load the backbone only to probe the width of its pooled output, so the
    # head's first layer can match it. num_classes=0 drops timm's classifier.
    model = timm.create_model(model_name, pretrained=True, num_classes=0)
    model.eval()
    with torch.no_grad():  # shape probe only: no autograd graph needed
        num_features = model(torch.randn(1, 3, sample_size, sample_size)).shape[1]

    # MLP classification head trained on the extracted features (2 logits).
    head = nn.Sequential(
        nn.Linear(num_features, 1000),
        nn.ReLU(),
        nn.Linear(1000, 1000),
        nn.ReLU(),
        nn.Linear(1000, 2)
    )

    # metrics[n][i] is the report for repetition n, held-out fold i.
    metrics = [[] for _ in range(cross_validations)]

    for n in range(cross_validations):
        # A fresh Loader per repetition; presumably it re-draws the fold
        # assignment each time (TODO confirm in main.Loader).
        processor = Loader(root_dir=root_dir, num_folds=num_folds)
        processor.load_data(sample_size=sample_size, batch_size=batch_size)
        loaders = processor.loaders  # assumes one loader per fold

        # Extract backbone features for every fold once, so each head below
        # trains on cached tensors instead of re-running the backbone.
        fe = FeatureExtractor(model_name)
        fdatasets = []
        for loader in loaders:
            features, labels = fe.get_features(loader)  # numpy arrays
            fdatasets.append(
                TensorDataset(torch.from_numpy(features), torch.from_numpy(labels))
            )

        for i in range(num_folds):
            # Every head starts from the same initial weights.
            head_i = copy.deepcopy(head)

            # Cyclic split: fold (i - 1) mod k is held out and the other
            # k - 1 folds form the training set.
            train_sets = [fdatasets[(j + i) % num_folds] for j in range(num_folds - 1)]
            test_set = fdatasets[(i + num_folds - 1) % num_folds]

            # Train once on the union of the training folds. Training on each
            # fold in turn would bias the head toward the last fold it saw.
            train_loader = DataLoader(
                ConcatDataset(train_sets), batch_size=batch_size, shuffle=True
            )
            test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

            trainer = Trainer(head_i, train_loader)
            trainer.train(num_epochs=40)

            tester = Tester(head_i, test_loader, processor.dataset)
            tester.evaluate()

            metrics[n].append(tester.report)

    with open(metrics_dir / f"{model_name}.pkl", 'wb') as f:
        pickle.dump(metrics, f)

    # Flatten to a single list of cross_validations * num_folds reports.
    new_metrics = [report for run in metrics for report in run]
    print(f"{model_name}")
    reporter = RGenerator(new_metrics).report()
