In [None]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
from UCroma import PretrainedCROMA
import copy

from main import CRLoader as Loader, FExtractor, Trainer, Tester

# ---- Configuration & model setup -------------------------------------------
sample_size = 120   # side length (pixels) of the square patches fed to CROMA
batch_size = 16

# Hidden width and output size of the classification head that is trained on the
# backbone features extracted below. num_classes = 2 -> two-class problem.
head_hidden = 1000
num_classes = 2

# Probe the pretrained backbone once with random inputs (2 SAR channels, 12 optical
# channels) to discover the width of its 'joint_GAP' embedding. This is only a shape
# probe, so run it under no_grad instead of recording an autograd graph through the
# whole backbone.
Croma = PretrainedCROMA(pretrained_path='CR.pt', size='base', modality='both', image_resolution=sample_size)
with torch.no_grad():
    num_features = Croma(SAR_images=torch.randn(1, 2, sample_size, sample_size),
                         optical_images=torch.randn(1, 12, sample_size, sample_size))['joint_GAP'].shape[1]

# MLP head; every CV fold trains a deep copy of this same initialisation.
head = nn.Sequential(
    nn.Linear(num_features, head_hidden), nn.ReLU(),
    nn.Linear(head_hidden, head_hidden), nn.ReLU(),
    nn.Linear(head_hidden, num_classes),
)

# Patch directories (optical and SAR). Only the machine-specific root needs editing
# when the data moves.
patch_root = "/home/stagiaire/D/R/patchs"
opt_root_dir = f"{patch_root}/90"
sar_root_dir = f"{patch_root}/90R"

num_cv = 5      # number of repeated cross-validation runs
num_folds = 5   # folds per run
metrics = [[] for _ in range(num_cv)]   # metrics[n] collects the per-fold reports of run n

num_epochs = 100   # training epochs for each fold's head

for n in range(num_cv):
    # Rebuild the fold split for every CV repetition. Seeds are None, so presumably each
    # repetition gets a different fold assignment (depends on CRLoader — confirm).
    processor = Loader(opt_root_dir=opt_root_dir, sar_root_dir=sar_root_dir, num_folds=num_folds)
    processor.load_data(sample_size=sample_size, batch_size=batch_size, data_seed=None, split_seed=None)
    loaders = processor.loaders

    # Extract features once per fold and keep them as in-memory tensor datasets, so the
    # head can be trained and evaluated on them repeatedly.
    fold_datasets = []
    for loader in loaders:
        feature_extractor = FExtractor(dataloader=loader, use_8_bit=True)
        features, labels = feature_extractor.extract_features()
        fold_datasets.append(TensorDataset(torch.from_numpy(features), torch.from_numpy(labels)))

    for i in range(num_folds):
        # Every fold model starts from the same initial weights.
        head_i = copy.deepcopy(head)

        # Hold out fold (i - 1) mod num_folds; the remaining num_folds - 1 folds are training data.
        test_idx = (i + num_folds - 1) % num_folds
        train_sets = [fold_datasets[(j + i) % num_folds] for j in range(num_folds - 1)]

        # Train once on the pooled training folds. Training num_epochs on each fold in turn
        # would fit the head fold-by-fold and bias it toward the last fold it saw; shuffling
        # the concatenated set mixes all training folds in every epoch.
        train_loader = DataLoader(torch.utils.data.ConcatDataset(train_sets), batch_size=batch_size, shuffle=True)
        trainer = Trainer(head_i, train_loader)
        trainer.train(num_epochs=num_epochs)

        # Evaluation needs no shuffling; iterate the held-out fold in a fixed order.
        test_loader = DataLoader(fold_datasets[test_idx], batch_size=batch_size, shuffle=False)
        tester = Tester(head_i, test_loader, processor.combined_dataset)
        tester.evaluate()

        metrics[n].append(tester.report)

In [None]:
from R import RGenerator

# Pool the per-fold reports of every CV repetition into one flat list, so that a
# single RGenerator summarises all num_cv * num_folds fold results together.
new_metrics = [fold_report for run_reports in metrics for fold_report in run_reports]

reporter = RGenerator(new_metrics).report()

In [None]:
# One report per CV repetition: each run_reports holds that run's per-fold reports.
for run_reports in metrics:
    reporter = RGenerator(run_reports).report()