In [None]:
import copy
import os
import pickle

import timm
import torch
import torch.nn as nn
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

from main import CRLoader as Loader, FExtractor, Trainer, Tester

sample_size = 120
batch_size = 16

num_features = 768
head = nn.Sequential(nn.Linear(num_features, 1000), nn.ReLU(), 
                     nn.Linear(1000, 1000), nn.ReLU(), 
                     nn.Linear(1000, 2))

for p in range(30, 91, 10):

    cross_validations = 5
    metrics = [[] for _ in range(cross_validations)]
    opt_root_dir = f"/home/stagiaire/D/R/patchs/{p}"
    sar_root_dir = f"/home/stagiaire/D/R/patchs/{p}R"

    print(f"Recouvrement: {p}%")
    
    for n in range(cross_validations):

        num_folds = 5
        processor = Loader(opt_root_dir=opt_root_dir, sar_root_dir=sar_root_dir, num_folds=num_folds)
        processor.load_data(sample_size=sample_size, batch_size=batch_size)
        loaders = processor.loaders

        floaders = []
        for loader in loaders:
            feature_extractor = FExtractor(dataloader=loader, use_8_bit=True)
            features, labels = feature_extractor.extract_features()
            dataset = TensorDataset(torch.from_numpy(features), torch.from_numpy(labels))
            floaders.append(DataLoader(dataset, batch_size=batch_size, shuffle=True))

        for i in range(num_folds):
            
            head_i = copy.deepcopy(head)

            train_loaders = [floaders[(j + i) % num_folds] for j in range(num_folds - 1)]
            test_loader = floaders[(i + num_folds - 1) % num_folds]
            
            for train_loader in train_loaders:
                trainer = Trainer(head_i, train_loader)
                trainer.train(num_epochs=40)
                
            tester = Tester(head_i, test_loader, processor.combined_dataset)
            tester.evaluate()
            metrics[n].append(tester.report)
         
    with open(f"/home/stagiaire/D/R/metrics/{p}.pkl", 'wb') as f:
        pickle.dump(metrics, f)