In [None]:

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F


#### Simple Neural Network          
class NeuralNetwork(nn.Module):
    """Small fully connected network with one half-width hidden layer.

    Layout: ``Linear(in, round(in / 2)) -> ReLU -> Dropout(p=0.5) ->
    Linear(round(in / 2), out)``.

    Args:
        net_params: mapping providing ``'in_features'`` (input width) and
            ``'out_features'`` (output width).
    """

    def __init__(self, net_params):
        super().__init__()

        n_in = net_params['in_features']
        n_out = net_params['out_features']
        # The hidden width uses Python's built-in round()
        n_hidden = round(n_in / 2)

        layers = [
            nn.Linear(n_in, n_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(inplace=False, p=0.5),
            nn.Linear(n_hidden, n_out),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stack and return the last linear layer's output."""
        return self.net(x)
    
#### Attention Neural Network
# Attention Units
class Attention(nn.Module):
    """Attention pooling over a set of feature vectors.

    Each input vector is scored by a small MLP (Linear -> Tanh -> Linear);
    the scores are soft-maxed across the vectors and used as weights to
    average the inputs, once per attention branch.

    Args:
        net_params: mapping with keys ``'in_features'`` (M, width of each
            input vector), ``'decom_space'`` (L, hidden width of the scoring
            MLP) and ``'ATTENTION_BRANCHES'`` (number of weight sets).
    """

    def __init__(self,net_params):
        super().__init__()
        self.M = net_params['in_features']
        self.L = net_params['decom_space']
        self.ATTENTION_BRANCHES = net_params['ATTENTION_BRANCHES']

        scoring_layers = [
            nn.Linear(self.M, self.L),                   # matrix V
            nn.Tanh(),
            nn.Linear(self.L, self.ATTENTION_BRANCHES),  # matrix w (a vector when there is one branch)
        ]
        self.attention = nn.Sequential(*scoring_layers)

    def forward(self, x):
        """Pool the input vectors.

        Args:
            x: tensor that becomes an (NV, M) matrix once a leading size-1
                dimension is squeezed; ``torch.mm`` below needs it to be 2-D.

        Returns:
            ``(Z, A)``: ``Z`` is the (ATTENTION_BRANCHES, M) pooled output and
            ``A`` the (ATTENTION_BRANCHES, NV) attention weights.
        """
        H = x.squeeze(0)                                # NV x M
        scores = self.attention(H).transpose(0, 1)      # ATTENTION_BRANCHES x NV
        A = F.softmax(scores, dim=1)                    # normalise across the NV vectors
        return torch.mm(A, H), A


class GatedAttention(nn.Module):
    """Gated attention pooling over a set of feature vectors.

    Two parallel projections of every input vector (a tanh branch and a
    sigmoid branch) are multiplied element-wise, scored by a linear layer,
    soft-maxed across the vectors and used to average the inputs.

    Args:
        net_params: mapping with keys ``'in_features'`` (M, width of each
            input vector), ``'decom_space'`` (L, width of both projections)
            and ``'ATTENTION_BRANCHES'`` (number of weight sets).
    """

    def __init__(self,net_params):
        super().__init__()
        self.M = net_params['in_features']
        self.L = net_params['decom_space']
        self.ATTENTION_BRANCHES = net_params['ATTENTION_BRANCHES']

        # Tanh projection (matrix V)
        self.attention_V = nn.Sequential(nn.Linear(self.M, self.L), nn.Tanh())
        # Sigmoid projection acting as the gate (matrix U)
        self.attention_U = nn.Sequential(nn.Linear(self.M, self.L), nn.Sigmoid())
        # Scoring layer (matrix w, a vector when there is a single branch)
        self.attention_w = nn.Linear(self.L, self.ATTENTION_BRANCHES)

    def forward(self, x):
        """Pool the input vectors.

        Args:
            x: tensor that becomes an (NV, M) matrix once a leading size-1
                dimension is squeezed; ``torch.mm`` below needs it to be 2-D.

        Returns:
            ``(Z, A)``: ``Z`` is the (ATTENTION_BRANCHES, M) pooled output and
            ``A`` the (ATTENTION_BRANCHES, NV) attention weights.
        """
        H = x.squeeze(0)                                    # NV x M
        gated = self.attention_V(H) * self.attention_U(H)   # NV x L, element-wise product
        scores = self.attention_w(gated).transpose(0, 1)    # ATTENTION_BRANCHES x NV
        A = F.softmax(scores, dim=1)                        # normalise across the NV vectors
        return torch.mm(A, H), A

# Libraries

In [None]:
import os
import glob
import random
import numpy as np
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import torchvision.transforms as transforms
import math
import itertools
from torch import nn as nn
import torch.nn.functional as F
from torch import Tensor
from numpy.matlib import repmat
import torch.optim as optim
from collections import OrderedDict
import cv2
import wandb
import matplotlib.pyplot as plt

# LOAD CROPPED

In [None]:
import pandas as pd
import numpy as np
import cv2
import os
import glob
import random


def _read_image_chw(image_path, size=256):
    """Read one image as a channel-first RGB array scaled by 1/255.

    Returns:
        ``(image, resized)``. ``image`` is ``None`` when OpenCV could not read
        the file; ``resized`` tells whether the image had to be resized to
        ``size`` x ``size``.
    """
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        # cv2.imread reports an unreadable file by returning None, not by raising
        return None, False

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    resized = image_rgb.shape[0] != size or image_rgb.shape[1] != size
    if resized:
        image_rgb = cv2.resize(image_rgb, (size, size))

    # HWC -> CHW, then scale; dividing by a Python float yields a float64 array
    return np.transpose(image_rgb, (2, 0, 1)) / 255.0, resized


def load_cropped(folder_path, csv_path, patient_list=None, sample_size=200):
    """Load a random sample of cropped PNG images per patient with binary labels.

    Patient folders are expected to be named ``<patient_id>_*`` inside
    ``folder_path``. The CSV must contain the columns ``CODI`` (patient id)
    and ``DENSITAT``; a density of ``"NEGATIVA"`` maps to label 0, anything
    else to label 1.

    Args:
        folder_path: directory holding one sub-folder per patient.
        csv_path: path of the CSV with the patient metadata.
        patient_list: ids of the patients to load. When empty or ``None``,
            every entry found in ``folder_path`` is used. The argument is
            never modified.
        sample_size: maximum number of images sampled per patient.

    Returns:
        ``(images_list, patients_data)``: a list of (3, 256, 256) arrays and
        a list with one label per loaded image, in the same order.
    """
    df = pd.read_csv(csv_path)

    # Dictionary for fast id -> density lookups
    patient_metadata = dict(zip(df['CODI'], df['DENSITAT']))

    patients_data = []
    images_list = []

    # Build a fresh list instead of appending to the argument: the previous
    # mutable default kept the ids of earlier calls and altered the caller's list.
    if not patient_list:
        patient_list = [
            os.path.basename(patient_folder).split("_")[0]
            for patient_folder in glob.glob(os.path.join(folder_path, "*"))
        ]

    incorrect_shape = 0
    for i, patient_id in enumerate(patient_list):
        print('--- Loading data from patient:', patient_id, f'(--- {i+1}/{len(patient_list)}) ---')

        try:
            patient_folder = glob.glob(os.path.join(folder_path, f"{patient_id}_*"))[0]
        except IndexError:
            print(f'Patient {patient_id} not found in the dataset folder')
            continue

        # Patients without metadata cannot be labelled
        if patient_id not in patient_metadata:
            continue

        images = glob.glob(os.path.join(patient_folder, "*.png"))
        if not images:
            continue

        # The shuffle is redundant with random.sample but is kept so that
        # seeded runs pick the same images as before.
        random.shuffle(images)
        images_sampled = random.sample(images, min(sample_size, len(images)))

        loaded = 0
        for j, image_path in enumerate(images_sampled):
            if j % 100 == 0:
                print(f'------ Loading image {j}/{len(images_sampled)} ---')

            image_chw, resized = _read_image_chw(image_path)
            if image_chw is None:
                print(f'Could not read image {image_path}, skipping')
                continue

            incorrect_shape += resized
            images_list.append(image_chw)
            loaded += 1

        # Binarise the density and emit one label per image actually loaded,
        # so images and labels stay aligned even if some files were skipped
        dens = 0 if patient_metadata[patient_id] == "NEGATIVA" else 1
        patients_data.extend([dens] * loaded)

    print(f'Resized images: {incorrect_shape}/{len(images_list)}')

    return images_list, patients_data

# Step 2: Define the Standard_Dataset class
class Standard_Dataset(data.Dataset):
    """Dataset over in-memory NumPy images with optional labels.

    Args:
        X: indexable collection of NumPy arrays (one per sample).
        Y: optional indexable collection of labels aligned with ``X``.
        transformation: optional callable applied to each array before it is
            converted to a tensor; it must return a NumPy array because the
            result is passed to ``torch.from_numpy``.
    """

    def __init__(self, X, Y=None, transformation=None):
        super().__init__()
        self.X = X
        self.y = Y
        self.transformation = transformation

    def __len__(self):
        """Number of samples, taken from ``X``."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the float32 image tensor at ``idx``, plus its float32 label when labels exist."""
        sample = self.X[idx]
        if self.transformation:
            sample = self.transformation(sample)
        sample = torch.from_numpy(sample).float()

        # Unlabelled datasets yield the image alone
        if self.y is None:
            return sample
        return sample, torch.tensor(self.y[idx], dtype=torch.float32)

# Step 3: Load the data
folder_path = "C:/Users/mirvi/Desktop/mii/UAB/4.1/PSIV2/detect mateicules/repte3/psiv-repte3/data/Cropped_sample"
csv_path = "C:/Users/mirvi/Desktop/mii/UAB/4.1/PSIV2/detect mateicules/repte3/psiv-repte3/data/PatientDiagnosis.csv"
patient_list = ['B22-25']
sample_size = 1

# Load up to sample_size images per listed patient together with their labels; no filtering by label happens in this call
images, labels = load_cropped(folder_path, csv_path, patient_list, sample_size)



# AUTOENCODER


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


def init_weights_transformer(module,InitParams):
    """Initialise ``module`` with the scheme selected in ``InitParams``.

    Args:
        module: model to initialise in place.
        InitParams: mapping with keys ``'Type'`` and ``'U'``. ``'Type'``
            selects the scheme: ``'ori'`` -> ``init_weights_transformer_ori``,
            ``'all'`` -> ``init_weights_transformer_all``, any other value ->
            ``init_weights_transformer_exceptnorm``. ``'U'`` is the
            ``[low, high]`` range handed to the last two; it is read even for
            ``'ori'``, which does not use it.
    """
    scheme = InitParams['Type']
    bounds = InitParams['U']

    if scheme == 'ori':
        init_weights_transformer_ori(module)
        return
    if scheme == 'all':
        init_weights_transformer_all(module, bounds)
        return
    # Any other value falls back to the "everything except norm layers" scheme
    init_weights_transformer_exceptnorm(module, bounds)
    
def init_weights_transformer_all(module,U=[0,1]):
    """Fill every state-dict tensor whose name lacks ``'norm'`` with U(U[0], U[1]).

    The tensors are modified in place, so the module's own values change.
    Every entry of ``state_dict()`` is visited, not only weights.

    Args:
        module: model to initialise.
        U: ``[low, high]`` bounds of the uniform distribution.
    """
    for name, tensor in module.state_dict().items():
        if 'norm' in name:
            # Entries belonging to anything named "...norm..." keep their values
            continue
        nn.init.uniform_(tensor, a=U[0], b=U[1])

                
def init_weights_transformer_exceptnorm(module,U=(0, 1)):
    """Initialise weights uniformly and biases to zero, skipping ``'norm'`` entries.

    For every state-dict entry whose name does not contain ``'norm'``:
    names containing ``'weight'`` are drawn from U(U[0], U[1]) and names
    containing ``'bias'`` are set to 0. Tensors are modified in place.

    Args:
        module: model to initialise.
        U: ``(low, high)`` bounds of the uniform distribution for the weights.
    """
    for name, tensor in module.state_dict().items():
        if 'norm' in name:
            continue
        # Substring membership instead of ``find(...) > 0``: the old test
        # skipped keys that *start* with the word, e.g. the plain 'weight' and
        # 'bias' keys of a bare nn.Linear, which were left uninitialised.
        if 'weight' in name:
            nn.init.uniform_(tensor, a=U[0], b=U[1])
        elif 'bias' in name:
            nn.init.constant_(tensor, 0)
    #        
def init_weights_transformer_ori(module):
    """Initialise the encoder and the two decoder heads of ``module``.

    ``module`` must expose ``encoder``, ``decoder1`` (indexable; element 0 is
    used) and ``decoder2``. Weights are drawn from U(-0.1, 0.1) and the two
    decoder biases are zeroed.
    """
    initrange = 0.1
    decoder_heads = (module.decoder1[0], module.decoder2)

    nn.init.uniform_(module.encoder.weight, -initrange, initrange)
    for head in decoder_heads:
        nn.init.zeros_(head.bias)
    for head in decoder_heads:
        nn.init.uniform_(head.weight, -initrange, initrange)
    

def init_weights_xavier_normal(module):
    """Initialise ``module`` and all its sub-modules in place (Xavier normal).

    * ``Conv1d/2d/3d`` and ``Linear``: Xavier-normal weights, zero bias.
    * ``BatchNorm1d/2d/3d``: weights from N(0, 0.01), zero bias.
    * ``LSTM``: orthogonal init for tensors with 2+ dims, standard normal for
      the rest.

    Module types not listed above are left untouched. Missing parameters
    (``bias=False`` layers, ``affine=False`` batch norms) are skipped instead
    of raising.
    """
    xavier_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    for m in module.modules():
        if isinstance(m, xavier_types):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, norm_types):
            # weight and bias are None when the layer was built with affine=False
            if m.weight is not None:
                nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LSTM):
            for param in m.parameters():
                if len(param.shape) >= 2:
                    nn.init.orthogonal_(param.data)
                else:
                    nn.init.normal_(param.data)

def init_weights_xavier_uniform(module):
    """Initialise ``module`` and all its sub-modules in place (Xavier uniform).

    * ``Conv1d/2d/3d`` and ``Linear``: Xavier-uniform weights, zero bias.
    * ``BatchNorm1d/2d/3d``: weights from U(0, 1), zero bias.
    * ``LSTM``: orthogonal init for tensors with 2+ dims, U(0, 1) for the rest.

    Module types not listed above are left untouched. Missing parameters
    (``bias=False`` layers, ``affine=False`` batch norms) are skipped instead
    of raising.
    """
    xavier_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    for m in module.modules():
        if isinstance(m, xavier_types):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, norm_types):
            # weight and bias are None when the layer was built with affine=False
            if m.weight is not None:
                nn.init.uniform_(m.weight, a=0, b=1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LSTM):
            for param in m.parameters():
                if len(param.shape) >= 2:
                    nn.init.orthogonal_(param.data)
                else:
                    nn.init.uniform_(param.data)

def init_weights_kaiming_uniform(module):
    """Initialise ``module`` and all its sub-modules in place (Kaiming uniform).

    * ``Conv1d/2d/3d`` and ``Linear``: Kaiming-uniform weights
      (``mode='fan_in'``, ``nonlinearity='relu'``), zero bias.
    * ``BatchNorm1d/2d/3d``: weights from U(0, 1), zero bias.

    Module types not listed above are left untouched. Missing parameters
    (``bias=False`` layers, ``affine=False`` batch norms) are skipped instead
    of raising.
    """
    kaiming_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    for m in module.modules():
        if isinstance(m, kaiming_types):
            nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, norm_types):
            # weight and bias are None when the layer was built with affine=False
            if m.weight is not None:
                nn.init.uniform_(m.weight, a=0, b=1)
            if m.bias is not None:
                nn.init.constant_(m.bias, val=0.)

def init_weights_kaiming_normal(module):
    """Initialise ``module`` and all its sub-modules in place (Kaiming normal).

    * ``Conv1d/2d/3d`` and ``Linear``: Kaiming-normal weights
      (``mode='fan_in'``, ``nonlinearity='relu'``), zero bias.
    * ``BatchNorm1d/2d/3d``: weights from N(0, 0.01), zero bias.

    Module types not listed above are left untouched. Missing parameters
    (``bias=False`` layers, ``affine=False`` batch norms) are skipped instead
    of raising.
    """
    kaiming_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    for m in module.modules():
        # A single if/elif chain; the previous version restarted the chain at
        # the Conv3d test with a stray ``if``.
        if isinstance(m, kaiming_types):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, norm_types):
            # weight and bias are None when the layer was built with affine=False
            if m.weight is not None:
                nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, val=0.)

In [None]:
### Linear FC Blocks
def linear_block(n_inputs_loc, hidden_loc,
                 activ_config=None,batch_config=None,p_drop_loc=0.1):
    """Build a stack of fully connected layers as an ``nn.Sequential``.

    For each entry of ``hidden_loc`` the block contains, in this order:
    ``linear<k>``, an optional activation (``relu<k>``, ``tanh<k>`` or
    ``relu6<k>``), an optional ``batch<k>`` (``BatchNorm1d``) and ``drop<k>``
    (``Dropout``), with ``k`` starting at 1.

    Args:
        n_inputs_loc: width of the input features.
        hidden_loc: output width of each linear layer. The list is not
            modified.
        activ_config: per-layer activation name (``'relu'``, ``'tanh'``,
            ``'relu6'``; anything else adds no activation). ``None`` means no
            activation anywhere.
        batch_config: per-layer flag; ``'batch'`` adds a ``BatchNorm1d``
            after the activation. ``None`` means no batch norm anywhere.
        p_drop_loc: dropout probability used after every layer.

    Returns:
        The assembled ``nn.Sequential``.
    """
    # Work on a new list: inserting into hidden_loc would silently change the
    # caller's list and grow it on every call.
    layer_sizes = [n_inputs_loc] + list(hidden_loc)
    n_layers = len(layer_sizes) - 1

    if activ_config is None:
        activ_config = ['no_activ'] * n_layers
    if batch_config is None:
        batch_config = ['no_batch'] * n_layers

    BlockArchitecture = []
    for i in range(n_layers):
        tag = str(i + 1)
        BlockArchitecture.append(('linear' + tag,
                                  nn.Linear(layer_sizes[i], layer_sizes[i + 1])))

        if activ_config[i] == 'relu':
            BlockArchitecture.append(('relu' + tag, nn.ReLU(inplace=True)))
        elif activ_config[i] == 'tanh':
            BlockArchitecture.append(('tanh' + tag, nn.Tanh()))
        elif activ_config[i] == 'relu6':
            BlockArchitecture.append(('relu6' + tag, nn.ReLU6(inplace=True)))

        if batch_config[i] == 'batch':
            BlockArchitecture.append(('batch' + tag, nn.BatchNorm1d(layer_sizes[i + 1])))

        BlockArchitecture.append(('drop' + tag, nn.Dropout(p_drop_loc)))

    return nn.Sequential(OrderedDict(BlockArchitecture))


class LinearBlock(nn.Module):
    """Multi-layer perceptron assembled with ``linear_block``.

    Constructor parameters:
        inputmodule_params: mapping with key ``'n_inputs'`` (width of the
            input features).
        net_params: mapping with keys
            ``'hidden'``: list with the output width of each linear layer;
            ``'dropout'``: dropout probability (0.5 is used when it is None);
            ``'activ_config'`` (optional): per-layer activation names;
            ``'batch_config'`` (optional): per-layer batch-norm flags.

    All weights are re-initialised with ``init_weights_xavier_normal`` after
    the layers are created.
    """

    def __init__(self, inputmodule_params,net_params):
        super().__init__()

        # Input parameters
        self.n_inputs = inputmodule_params['n_inputs']

        # Architecture parameters
        self.hidden = net_params['hidden']
        self.dropout = net_params['dropout']
        if net_params['dropout'] is None:
            self.dropout = 0.5
        self.nlayers = len(self.hidden)

        # Optional per-layer configurations default to None
        has_activ = 'activ_config' in net_params.keys()
        has_batch = 'batch_config' in net_params.keys()
        self.activ_config = net_params['activ_config'] if has_activ else None
        self.batch_config = net_params['batch_config'] if has_batch else None

        # A copy of the hidden sizes is handed over so self.hidden stays intact
        self.linear_block0 = linear_block(
            self.n_inputs,
            self.hidden.copy(),
            activ_config=self.activ_config,
            batch_config=self.batch_config,
            p_drop_loc=self.dropout,
        )

        # Weight initialisation
        init_weights_xavier_normal(self)

    def forward(self, x):
        """Apply the stacked linear layers to ``x``."""
        return self.linear_block0(x)

### Convolutional 
class _CNNLayer(nn.Module):
    """Convolutional layer: Conv2d -> BatchNorm2d -> [LeakyReLU] -> Dropout.

    Args:
        num_input_features: number of input channels.
        n_neurons: number of output channels.
        kernel_sze: convolution kernel size; padding is ``int((kernel_sze-1)/2)``.
        stride: convolution stride.
        drop_rate: dropout probability.
        Relu: when False the LeakyReLU is left out of the sequence.
    """

    def __init__(
        self, num_input_features: int, n_neurons: int, kernel_sze:int =3,
        stride:int=1,
        drop_rate: float=0,
        Relu=True
         ) -> None:
        super().__init__()

        batch_norm = nn.BatchNorm2d(n_neurons)
        convolution = nn.Conv2d(
            num_input_features,
            n_neurons,
            kernel_size=kernel_sze,
            stride=stride,
            padding=(int((kernel_sze-1)/2)),
        )
        activation = nn.LeakyReLU(inplace=True)
        dropout = nn.Dropout(drop_rate)

        stages = [convolution, batch_norm]
        if Relu:
            stages.append(activation)
        stages.append(dropout)
        self.cnn_layer = nn.Sequential(*stages)

        init_weights_xavier_normal(self)

    def forward(self, x: Tensor):
        """Apply the layer sequence to ``x``."""
        return self.cnn_layer(x)

class _UnCNNLayer(nn.Module):
    def __init__(
        self, num_input_features: int, n_neurons: int, kernel_sze:int =3, 
        stride:int=2,
        drop_rate: float=0, 
        Relu=True
         ) -> None:
        super().__init__()
        
        self.stride=stride
        norm1 = nn.BatchNorm2d(n_neurons)
        conv1 = nn.ConvTranspose2d(num_input_features, n_neurons, kernel_size=kernel_sze,  
                               stride=stride, padding=(int((kernel_sze-1)/2)))

        
     #   relu1 = nn.ReLU(inplace=True)
        relu1 = nn.LeakyReLU(inplace=True)

        drop=nn.Dropout(drop_rate)
        if Relu:
            self.cnn_layer=nn.Sequential(conv1,norm1,relu1,drop)
        else:
            self.cnn_layer=nn.Sequential(conv1,norm1,drop)
        init_weights_xavier_normal(self)
        
    def forward(self, x: Tensor):
        
        if  self.stride>1:
            sze_enc=x.shape[-1]
            x=self.cnn_layer[0](x,output_size=(sze_enc*2,sze_enc*2))
            for k in np.arange(1,len(self.cnn_layer)):
                x=self.cnn_layer[k](x)
        else:
            x=self.cnn_layer(x)
            
        return(x)
    
    # def forward(self, x1, x2):
    #     x1 = self.cnn_layer(x1)
    #     # input is CHW
    #     diffY = x2.size()[2] - x1.size()[2]
    #     diffX = x2.size()[3] - x1.size()[3]

    #     x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
    #                     diffY // 2, diffY - diffY // 2])
    #     # if you have padding issues, see
    #     # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
    #     # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
    #     x = torch.cat([x2, x1], dim=1)
    #     return self.conv(x)
    
class _CNNBlock(nn.ModuleDict):
    """Sequence of ``_CNNLayer`` (encoder) or ``_UnCNNLayer`` (decoder) layers.

    Layers are registered as ``cnnlayer1``, ``cnnlayer2``, ... and applied in
    that order; each layer takes the previous layer's output channels as input.

    Args:
        num_input_channels: channels of the block input.
        drop_rate: dropout probability used in every layer.
        block_config: output channels of each layer.
        stride: per-layer strides; defaults to 1 for every layer.
        decoder: when True the block is built from ``_UnCNNLayer``.
        Relu: forwarded to ``_CNNLayer`` only; decoder layers always use
            their own default.
    """
    _version = 2

    def __init__(
        self,
        num_input_channels: int=1,
        drop_rate=0,
        block_config = (64,128),
        stride=None,
        decoder=False,
        Relu=True

    ) -> None:
        super().__init__()

        num_layers=len(block_config)
        self.num_input_channels=num_input_channels
        print('block inp ch',num_input_channels)

        if stride is None:
            # Integer strides: np.ones() would produce floats, which the
            # convolution layers do not accept as a stride.
            stride=[1]*num_layers

        for i in range(num_layers):
            if decoder:
                layer = _UnCNNLayer(
                    num_input_channels,
                    n_neurons=block_config[i],
                    stride=stride[i],
                    drop_rate=drop_rate
                )
            else:
                layer = _CNNLayer(
                    num_input_channels,
                    n_neurons=block_config[i],
                    stride=stride[i],
                    drop_rate=drop_rate,
                    Relu=Relu
                )
            self.add_module("cnnlayer%d" % (i + 1), layer)
            # The next layer consumes this layer's output channels
            num_input_channels=block_config[i]

    def forward(self, x: Tensor) -> Tensor:
        """Apply the layers in registration order."""
        for layer in self.values():
            x = layer(x)
        return x



In [None]:

#  BACKBONE MODULES 
class Encoder(nn.Module):
    r"""Convolutional encoder built from a sequence of ``_CNNBlock`` blocks.

    Input Parameters:
        1. inputmodule_params: dictionary with keys ['num_input_channels']
            inputmodule_params['num_input_channels']=Channels of input images
        2. net_params: dictionary defining architecture:
            net_params['block_configs']: list of number of neurons for each
            convolutional block. A block can have more than one layer
            net_params['stride'] (optional): list of strides for each block's
            layers. When missing, every layer of a block uses stride 1 except
            the last one, which uses stride 2
            net_params['drop_rate']: value of the Dropout (equal for all blocks)
        Examples:
            1. Encoder with 4 blocks with one layer each
            net_params['block_configs']=[[32],[64],[128],[256]]
            net_params['stride']=[[2],[2],[2],[2]]
            2. Encoder with 2 blocks with two layers each
            net_params['block_configs']=[[32,32],[64,64]]
            net_params['stride']=[[1,2],[1,2]]

    """

    def __init__(self, inputmodule_params,net_params):
        super().__init__()

        num_input_channels=inputmodule_params['num_input_channels']

        drop_rate=net_params['drop_rate']
        block_configs=net_params['block_configs'].copy()
        n_blocks=len(block_configs)
        if 'stride' in net_params.keys():
            stride=net_params['stride']
        else:
            # Down-sample once per block, in its last layer
            stride=[[1]*(len(config)-1)+[2] for config in block_configs]

        # Encoder
        self.encoder=nn.Sequential()
        for i in range(n_blocks):
            print('block',i)
            block = _CNNBlock(
                num_input_channels=num_input_channels,
                drop_rate=drop_rate,
                block_config=block_configs[i],
                stride=stride[i]
            )
            self.encoder.add_module("cnnblock%d" % (i + 1), block)

            # The original code guarded a MaxPool2d behind ``stride == 1``.
            # ``stride`` is a list of per-block lists, so that comparison was
            # never true and no pooling layer was ever added; the dead branch
            # is dropped and down-sampling is left to the strided convolutions.

            # The next block consumes the channels of this block's last layer
            num_input_channels=block_configs[i][-1]

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through all encoder blocks."""
        x=self.encoder(x)
        return x

class Decoder(nn.Module):
    r"""Convolutional decoder built from transposed-convolution ``_CNNBlock`` blocks.

    Input Parameters:
        1. inputmodule_params: dictionary with keys ['num_input_channels']
            inputmodule_params['num_input_channels']=Channels of the decoder input
        2. net_params: dictionary defining architecture:
            net_params['block_configs']: list of number of neurons for each conv block
            net_params['stride'] (optional): list of strides for each block's
            layers. When missing, every layer of a block uses stride 1 except
            the last one, which uses stride 2
            net_params['drop_rate']: value of the Dropout (equal for all blocks)

    Blocks are applied in the order given by ``block_configs`` but are named
    in descending order (``cnnblock<n>`` first, ``cnnblock1`` last). The
    module at index 2 of the very last layer (its LeakyReLU) is replaced by
    ``nn.Identity``.
    """
    def __init__(self, inputmodule_params,net_params):
        super().__init__()

        in_channels = inputmodule_params['num_input_channels']

        self.upPoolMode = 'bilinear'

        drop_rate = net_params['drop_rate']
        block_configs = net_params['block_configs'].copy()
        self.n_blocks = len(block_configs)

        if 'stride' in net_params.keys():
            stride = net_params['stride']
        else:
            stride = [list(np.ones(len(config) - 1, dtype=int)) + [2]
                      for config in block_configs]

        # Decoder
        self.decoder = nn.Sequential()
        for position, config in enumerate(block_configs):
            # Names count down so the last block applied is "cnnblock1"
            block_number = self.n_blocks - position
            block = _CNNBlock(
                num_input_channels=in_channels,
                drop_rate=drop_rate,
                block_config=config,
                stride=stride[position],
                decoder=True,
            )
            self.decoder.add_module("cnnblock%d" % block_number, block)
            in_channels = config[-1]

        # Remove the activation of the final layer of the final block
        last_block = self.decoder[-1]
        last_layer_name = list(last_block.keys())[-1]
        last_block[last_layer_name].cnn_layer[2] = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through all decoder blocks."""
        return self.decoder(x)
    
class Attention(nn.Module):
    """Attention-weighted pooling of a set of feature vectors.

    A two-layer scoring network (Linear -> Tanh -> Linear) produces one score
    per input vector and attention branch; a softmax across the vectors turns
    the scores into weights used to average the inputs.

    Args:
        net_params: mapping with keys ``'in_features'`` (M, width of each
            input vector), ``'decom_space'`` (L, hidden width of the scoring
            network) and ``'ATTENTION_BRANCHES'`` (number of weight sets).
    """

    def __init__(self,net_params):
        super().__init__()
        self.M = net_params['in_features']
        self.L = net_params['decom_space']
        self.ATTENTION_BRANCHES = net_params['ATTENTION_BRANCHES']

        score_hidden = nn.Linear(self.M, self.L)                # matrix V
        score_out = nn.Linear(self.L, self.ATTENTION_BRANCHES)  # matrix w (a vector for one branch)
        self.attention = nn.Sequential(score_hidden, nn.Tanh(), score_out)

    def forward(self, x):
        """Return ``(Z, A)``: pooled vectors and attention weights.

        ``x`` must reduce to an (NV, M) matrix after a leading size-1
        dimension is squeezed, because ``torch.mm`` only accepts 2-D inputs.
        ``Z`` is (ATTENTION_BRANCHES, M) and ``A`` is (ATTENTION_BRANCHES, NV).
        """
        H = x.squeeze(0)                         # NV x M
        logits = self.attention(H)               # NV x ATTENTION_BRANCHES
        logits = logits.transpose(0, 1)          # ATTENTION_BRANCHES x NV
        A = F.softmax(logits, dim=1)             # weights sum to 1 across the NV vectors

        Z = torch.mm(A, H)                       # weighted average of the inputs
        return Z, A

##### GENERATIVE MODELS 
class AutoEncoderWithAttention(nn.Module):
    r"""Auto-encoder with an attention module between encoder and decoder.

    Args:
        inputmodule_paramsEnc, net_paramsEnc: parameters handed to ``Encoder``.
        inputmodule_paramsDec, net_paramsDec: parameters handed to ``Decoder``.
        attention_params: parameters handed to ``Attention``.
    """
    def __init__(self, inputmodule_paramsEnc, net_paramsEnc, inputmodule_paramsDec, net_paramsDec, attention_params):
        super().__init__()

        # Keep the configuration dictionaries on the instance
        self.inputmodule_paramsEnc = inputmodule_paramsEnc
        self.net_paramsEnc = net_paramsEnc
        self.inputmodule_paramsDec = inputmodule_paramsDec
        self.net_paramsDec = net_paramsDec

        # Sub-modules, created in the order encoder -> attention -> decoder
        self.encoder = Encoder(inputmodule_paramsEnc, net_paramsEnc)
        self.attention = Attention(attention_params)
        self.decoder = Decoder(inputmodule_paramsDec, net_paramsDec)

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x``, apply attention to the bottleneck and decode.

        Returns:
            ``(decoded, attention_weights)``.
        """
        bottleneck = self.encoder(x)

        # NOTE(review): Attention.forward ends in torch.mm, which only accepts
        # 2-D tensors, while the encoder is built from 2-D convolutions.
        # Confirm the bottleneck has the shape the attention module expects.
        pooled, attention_weights = self.attention(bottleneck)

        # Re-add the batch dimension before decoding
        decoded = self.decoder(pooled.unsqueeze(0))
        return decoded, attention_weights
