In [1]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import json



from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score
from spdnetwork.optimizers import  MixOptimizer 
from spdnetwork.nn import LogEig
from Utils import get_fold_of_data
from DatasetManagement import DatasetManagement
from Models import Contrastive_CB3, SPDnet

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# Read as a module-level global by generate_vectors (model construction,
# model placement and input transfer).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

%load_ext autoreload
%autoreload 2

  return torch._C._cuda_getDeviceCount() > 0


In [6]:




'''


'''

def generate_vectors(info_run):
    """Turn the conv-branch activations of a trained model into one Gram matrix per sample.

    A ``Contrastive_CB3`` model is built, its weights are loaded, and every batch of
    ``info_run['data']`` is pushed through it. Forward hooks capture the output of
    layer 17 of the ``t2``, ``adc`` and ``bval`` convolutional branches; the three
    outputs are concatenated along the channel axis and, for every sample, reshaped
    to ``(positions, channels)`` so that ``X.T @ X`` yields a
    ``(channels, channels)`` symmetric positive semi-definite matrix.

    Parameters
    ----------
    info_run : dict
        ``'target_shape'``                 forwarded to ``Contrastive_CB3``.
        ``'sequence_embedding_features'``  forwarded to ``Contrastive_CB3``.
        ``'weights_path'``                 path of the state dict to load.
        ``'data'``                         iterable yielding ``(inputs, labels)`` batches,
                                           where ``inputs`` is a sequence of tensors and
                                           ``labels`` is a tensor.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        ``embbedings``: stacked Gram matrices, shape ``(n_samples, channels, channels)``.
        ``final_labels``: the batch labels concatenated along dim 0.
        Both tensors live on the CPU so they can be saved and reloaded on a machine
        without CUDA.

    Raises
    ------
    ValueError
        If ``info_run['data']`` yields no batches.
    """
    # Index of the layer tapped inside each conv branch.
    # NOTE(review): which layer index 17 corresponds to depends on Contrastive_CB3 — confirm.
    hooked_layer = 17
    branch_names = ('t2_conv_branch', 'adc_conv_branch', 'bval_conv_branch')

    model = Contrastive_CB3(
        device,
        info_run['target_shape'],
        info_run['sequence_embedding_features'],
        'contrastive',
    )
    # Load on CPU first so checkpoints written on a GPU also open on CPU-only hosts.
    model.load_state_dict(torch.load(info_run['weights_path'], map_location=torch.device('cpu')))
    model.eval()
    model.to(device)

    data_loader = info_run['data']

    # Most recent output of each hooked layer, refreshed on every forward pass.
    latest_output = {}

    def make_hook(name):
        def hook(module, hook_inputs, output):
            # Move to CPU immediately: keeping every batch on the GPU makes memory
            # grow with the dataset size for no benefit.
            latest_output[name] = output.detach().cpu()
        return hook

    hook_handles = [
        getattr(model, name)[hooked_layer].register_forward_hook(make_hook(name))
        for name in branch_names
    ]

    concatenated_activations = []  # one (batch, channels, depth, height, width) tensor per batch
    all_labels = []
    try:
        with torch.no_grad():
            for inputs, labels in data_loader:
                inputs = [modality.to(device).type(torch.float) for modality in inputs]
                model(inputs)  # only the hooked activations are needed, not the output

                all_labels.append(labels.cpu())
                concatenated_activations.append(
                    torch.cat([latest_output[name] for name in branch_names], dim=1)
                )
    finally:
        # Always detach the hooks, even if a forward pass fails.
        for handle in hook_handles:
            handle.remove()

    print(len(concatenated_activations))
    if not concatenated_activations:
        raise ValueError("info_run['data'] yielded no batches; cannot build SPD matrices")

    all_labels_tensor = torch.cat(all_labels, dim=0)
    print(all_labels_tensor.shape)  # total number of labels collected

    all_spds = []
    for batch in concatenated_activations:
        batch_size, channels = batch.shape[0], batch.shape[1]
        # (batch, C, D, H, W) -> (batch, D*H*W, C): one row per spatial position.
        flattened = batch.permute(0, 2, 3, 4, 1).reshape(batch_size, -1, channels).numpy()
        all_spds.extend(_compute_spd(sample, kind="gramm") for sample in flattened)

    # Stack in NumPy first; torch.tensor() on a list of arrays is very slow.
    all_spds_tensor = torch.from_numpy(np.stack(all_spds))
    print(all_spds_tensor.shape)
    print(all_spds_tensor[0].shape)

    embbedings = all_spds_tensor
    final_labels = all_labels_tensor

    return embbedings, final_labels


def _compute_spd(acts, kind="gramm"):
    """Return the ``(features, features)`` matrix of a ``(observations, features)`` array.

    ``kind`` selects the Gram matrix (``"gramm"``), the correlation matrix (``"corr"``)
    or the covariance matrix (``"cov"``); any other value raises ``ValueError``.
    """
    if kind == "gramm":
        return acts.T @ acts
    if kind == "corr":
        return np.corrcoef(acts.T)
    if kind == "cov":
        return np.cov(acts.T)
    raise ValueError(f"unknown SPD kind {kind!r}; expected 'gramm', 'corr' or 'cov'")


    



In [88]:
    # =======================================================  RUN  =======================================================
    # =======================================================  RUN  ======================================================= 
    # =======================================================  RUN  ======================================================= 
    # =======================================================  RUN  ======================================================= 
    # =======================================================  RUN  ======================================================= 

        
    #DEFINO LOS DATOS
# Experiment settings. These values describe one experiment and must be changed
# together (they were previously repeated inline and marked "CAMBIAR").
TRAIN_FRACTION = 0.80
WEIGHTS_DIR = '/data/ExperimentsPercentTRIPLETBaseline/Experiments_with_80.0%/models'
EMBEDDINGS_DIR = '/data/Embeddings80%/embeddingsTripletSPD'
N_FOLDS = 5
BATCH_SIZE = 32


def _make_loader(x, y):
    """Wrap one data split in an unshuffled DataLoader so sample order is preserved."""
    return torch.utils.data.DataLoader(
        dataset=DatasetManagement(x, y),
        shuffle=False,
        batch_size=BATCH_SIZE,
        pin_memory=False,
    )


def _save_split(embeddings, labels, destination_path, split):
    """Write the embeddings and labels of one split ('train' or 'val') under destination_path."""
    os.makedirs(destination_path, exist_ok=True)

    embeddings_file = os.path.join(destination_path, f"embeddingsSPD_{split}.pt")
    labels_file = os.path.join(destination_path, f"labelsSPD_{split}.pt")

    torch.save(embeddings, embeddings_file)
    torch.save(labels, labels_file)

    print(f"Embeddings guardados en: {embeddings_file}")
    print(f"Labels guardados en: {labels_file}")


for k in range(N_FOLDS):
    x_train, x_validation, y_train, y_validation, ids_train, ids_val, indexdes = get_fold_of_data(k, TRAIN_FRACTION)

    val_loader = _make_loader(x_validation, y_validation)
    train_loader = _make_loader(x_train, y_train)

    # Fold files on disk are numbered from 1, while k runs from 0.
    fold_number = k + 1
    # Settings shared by the train and validation runs of this fold.
    base_info_run = {
        'target_shape': (12, 32, 32),
        'sequence_embedding_features': 18432,
        'weights_path': f'{WEIGHTS_DIR}/mertash_contrastive_fold_{fold_number}.pt',
    }

    # ================================= TRAIN =================================
    embeddings_train, labels_train = generate_vectors({**base_info_run, 'data': train_loader})

    print(embeddings_train.shape)
    print(labels_train.shape)

    _save_split(
        embeddings_train,
        labels_train,
        f"{EMBEDDINGS_DIR}/fold_{fold_number}/train",
        "train",
    )

    # =============================== VALIDATION ===============================
    embeddings_val, labels_val = generate_vectors({**base_info_run, 'data': val_loader})

    _save_split(
        embeddings_val,
        labels_val,
        f"{EMBEDDINGS_DIR}/fold_{fold_number}/val",
        "val",
    )





-------Cargando los datos del JSON revuelto con el 80.0 % en el fold 1-------
0.8
26
torch.Size([828, 1])
torch.Size([828, 96, 96])
torch.Size([96, 96])
torch.Size([828, 96, 96])
torch.Size([828, 1])
Embeddings guardados en: /data/Embeddings80%/embeddingsTripletSPD/fold_1/train/embeddingsSPD_train.pt
Labels guardados en: /data/Embeddings80%/embeddingsTripletSPD/fold_1/train/labelsSPD_train.pt
9
torch.Size([260, 1])
torch.Size([260, 96, 96])
torch.Size([96, 96])
Embeddings guardados en: /data/Embeddings80%/embeddingsTripletSPD/fold_1/val/embeddingsSPD_val.pt
Labels guardados en: /data/Embeddings80%/embeddingsTripletSPD/fold_1/val/labelsSPD_val.pt
-------Cargando los datos del JSON revuelto con el 80.0 % en el fold 2-------
0.8
26
torch.Size([824, 1])
torch.Size([824, 96, 96])
torch.Size([96, 96])
torch.Size([824, 96, 96])
torch.Size([824, 1])
Embeddings guardados en: /data/Embeddings80%/embeddingsTripletSPD/fold_2/train/embeddingsSPD_train.pt
Labels guardados en: /data/Embeddings80%/emb

In [97]:
# 1-based number of the fold to inspect; the following cell also reads `i`.
i = 5

# Directory holding the saved train/val tensors of the selected fold.
fold_dir = f'/data/Embeddings80%/embeddingsTripletSPD/fold_{i}'

train_embeddings_path = f'{fold_dir}/train/embeddingsSPD_train.pt'
train_labels_path = f'{fold_dir}/train/labelsSPD_train.pt'
val_embeddings_path = f'{fold_dir}/val/embeddingsSPD_val.pt'
val_labels_path = f'{fold_dir}/val/labelsSPD_val.pt'

# map_location='cpu' lets tensors that were saved from a GPU run be opened on a
# kernel without CUDA; move them to `device` explicitly where needed.
train_embeddings = torch.load(train_embeddings_path, map_location='cpu')
train_labels = torch.load(train_labels_path, map_location='cpu')

val_embeddings = torch.load(val_embeddings_path, map_location='cpu')
val_labels = torch.load(val_labels_path, map_location='cpu')

# Sanity check: embeddings and labels of each split must share their first dimension.
print(train_embeddings.shape)
print(train_labels.shape)
print('')
print(val_embeddings.shape)
print(val_labels.shape)



torch.Size([834, 96, 96])
torch.Size([834, 1])

torch.Size([252, 96, 96])
torch.Size([252, 1])


In [98]:
# Load the per-fold index lists; the context manager closes the file once parsed
# (the bare open() used before left the handle open).
with open('/data/json_index_shuffle.json', 'r') as fold_indexes:
    indexdes = json.load(fold_indexes)

# `i` is the 1-based fold number set in the previous cell, while the JSON keys
# are looked up with i-1.
# NOTE(review): 0.8 presumably mirrors the 80% training fraction used when the
# embeddings were generated — confirm.
print(int(len(indexdes[f'Fold_{i-1}_train']) * 0.8))
print("")
print(len(indexdes[f'Fold_{i-1}_val']))

# array = []

# for i in range (5):
#     array.append(len(indexdes[f'Fold_{i}_val']))
    
# print(np.sum(array))

834

252


In [2]:
# Load the per-fold index lists; the context manager closes the file once parsed
# (the bare open() used before left the handle open).
with open('/data/json_index_shuffle.json', 'r') as fold_indexes:
    indexdes = json.load(fold_indexes)

# 80% of the number of training indices listed for the first fold (displayed as the cell output).
len(indexdes['Fold_0_train']) * 0.8

828.0