In [1]:
import io
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

from tqdm import tqdm
from torch import nn
from zipfile import ZipFile

# Passed as `inplace=` to the activations built by pt_get_activation that accept it.
INPLACE=False 
# Passed as `padding_mode=` to nn.Conv2d inside the Conv wrapper.
CONVPAD='zeros'

def pt_get_activation(activation) -> nn.Module:
    """Resolve an activation specification into an ``nn.Module``.

    Args:
        activation: ``None``, an already constructed ``nn.Module`` (returned
            unchanged), or one of the lowercase names listed in the table below.

    Returns:
        The activation module, or ``None`` when ``activation`` is ``None`` or
        does not match any known name.
    """
    if activation is None:
        return None
    if isinstance(activation, nn.Module):
        return activation

    # Factories are called lazily so that only the requested module is built.
    factories = {
        'relu':       lambda: nn.ReLU(inplace=INPLACE),
        'elu':        lambda: nn.ELU(inplace=INPLACE),
        'leakyrelu':  lambda: nn.LeakyReLU(inplace=INPLACE),
        'sigmoid':    nn.Sigmoid,
        'logsigmoid': nn.LogSigmoid,
        'softmax':    lambda: nn.Softmax(dim=1),
        'softmin':    lambda: nn.Softmin(dim=1),
        'logsoftmax': lambda: nn.LogSoftmax(dim=1),
        'prelu':      nn.PReLU,
        'relu6':      lambda: nn.ReLU6(inplace=INPLACE),
        'rrelu':      lambda: nn.RReLU(inplace=INPLACE),
        'selu':       lambda: nn.SELU(inplace=INPLACE),
        'celu':       lambda: nn.CELU(inplace=INPLACE),
        'gelu':       lambda: nn.GELU(approximate='tanh'),
        'silu':       lambda: nn.SiLU(inplace=INPLACE),
        'mish':       lambda: nn.Mish(inplace=INPLACE),
        'softplus':   nn.Softplus,
        'softsign':   nn.Softsign,
        'softshrink': nn.Softshrink,
        'tanh':       nn.Tanh,
    }
    factory = factories.get(activation) if isinstance(activation, str) else None
    if factory is None:
        return None
    return factory()

class Conv(nn.Module):
    """2-D convolution optionally followed by an activation, a scaling callable and a skip connection.

    Args:
        ic: Number of input channels.
        oc: Number of output channels.
        k, s, p: Kernel size, stride and padding forwarded to ``nn.Conv2d``.
        bias: Whether the convolution has a bias term.
        activation: Name or module resolved through ``pt_get_activation``;
            ``None`` disables the activation.
        scale: Optional callable applied to the (activated) convolution output.
        residual: When true, the block input is added to the result.
    """

    def __init__(self, ic:int, oc:int, k=3, s=1, p=1, bias=True, activation=None, scale=None, residual=False):
        super().__init__()
        self.conv = nn.Conv2d(
            ic, oc,
            kernel_size=k, stride=s, padding=p,
            bias=bias, padding_mode=CONVPAD,
        )
        self.activation = pt_get_activation(activation)
        self.scale = scale
        self.residual = residual

    def forward(self, x):
        """Apply convolution, then activation and scale when configured, then the optional skip."""
        out = self.conv(x)
        # Optional post-processing stages, applied in this fixed order.
        for stage in (self.activation, self.scale):
            if stage is not None:
                out = stage(out)
        if self.residual:
            return x + out
        return out
    
class ResID07(nn.Module):
    """Residual block: BN -> act -> conv -> BN -> act -> (dropout) -> conv, plus a shortcut.

    Args:
        ic: Input channels.
        oc: Output channels. When it differs from ``ic`` the shortcut goes
            through a 1x1 convolution so both branches can be added.
        activation: Name or module resolved through ``pt_get_activation``.
        dropout: Dropout probability before the second convolution; no dropout
            layer is created when it is not greater than 0.
        expansion: The hidden channel count is ``int(ic * expansion)``.
        resample: Optional ``scale_factor`` for bilinear interpolation applied
            to both branches between the two convolutions.
    """
    def __init__(self, ic:int, oc:int, activation='relu', dropout=0.0, expansion=1, resample=None):
        super().__init__()
        hidden = int(ic*expansion)

        # First stage: normalise, activate, convolve to the hidden width.
        self.norm1 = torch.nn.BatchNorm2d(ic, momentum=0.01)
        self.act1 = pt_get_activation(activation)
        self.conv1 = Conv(ic, hidden, k=3, s=1, p=1)

        self.resample = resample

        # Second stage: normalise, activate, optional dropout, convolve to the output width.
        self.norm2 = torch.nn.BatchNorm2d(hidden, momentum=0.01)
        self.act2 = pt_get_activation(activation)
        if dropout > 0.0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None
        self.conv2 = Conv(hidden, oc, k=3, s=1, p=1)

        # 1x1 projection of the shortcut, only needed when the channel count changes.
        self.conv3 = None if ic == oc else Conv(ic, oc, k=1, s=1, p=0)

    def forward(self, x, emb=None):
        """Run the block; ``emb``, when given, is added to the output of the first convolution."""
        shortcut = x

        main = x
        for layer in (self.norm1, self.act1):
            if layer is not None:
                main = layer(main)
        main = self.conv1(main)

        if emb is not None:
            main = main + emb

        if self.resample is not None:
            # Both branches are resized identically so they can still be summed.
            main = F.interpolate(main, scale_factor=self.resample, mode='bilinear')
            shortcut = F.interpolate(shortcut, scale_factor=self.resample, mode='bilinear')

        for layer in (self.norm2, self.act2, self.dropout):
            if layer is not None:
                main = layer(main)
        main = self.conv2(main)

        if self.conv3 is not None:
            shortcut = self.conv3(shortcut)  # (b, oc, h, w): match the channel count of the main branch

        return shortcut + main
    
class ResID07N(nn.Module):
    """Stack of ``n`` channel-preserving ``ResID07`` blocks applied one after another."""
    def __init__(self, ic:int, n=2, activation='relu', dropout=0.0, expansion=2):
        super().__init__()
        blocks = [
            ResID07(ic, ic, activation=activation, dropout=dropout, expansion=expansion)
            for _ in range(n)
        ]
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        """Feed ``x`` through every block in order."""
        out = self.layers(x)
        return out

class UNetSam(nn.Module):
    """U-Net style encoder/decoder built from ``ResID07`` blocks.

    ``filters`` lists the channel count of every level, starting with the input
    channels. Levels up to (and including) the largest entry form the first
    section: each is a ``ResID07`` with ``resample=0.5`` followed by a
    ``ResID07N``. The remaining levels form the second section: each is a
    ``ResID07`` with ``resample=2`` followed by a ``ResID07N``, plus a
    ``ResID07N`` "link" that processes an output of the first section before it
    is added to the current level.

    ``forward`` returns a list whose first element is the input and whose
    following elements are the outputs of every level; the last element is the
    final output.
    """

    def __init__(self, filters=(1,16,32,64,128,256,128,64,32,16), n=1, activation='relu', dropout=0.0, expansion=1):
        super().__init__()
        c = len(filters)        # total number of entries in `filters`
        j = np.argmax(filters)  # index of the widest level
        shared = dict(activation=activation, dropout=dropout, expansion=expansion)

        self.convs = nn.ModuleList()      # downsampling ResID07 blocks
        self.convs_res = nn.ModuleList()  # ResID07N refinement after each downsampling block
        for f_in, f_out in zip(filters[:j], filters[1:j+1]):
            self.convs.append(ResID07(f_in, f_out, resample=0.5, **shared))
            self.convs_res.append(ResID07N(f_out, n=n, **shared))

        self.convts = nn.ModuleList()      # upsampling ResID07 blocks
        self.convts_res = nn.ModuleList()  # ResID07N refinement after each upsampling block
        self.links = nn.ModuleList()       # ResID07N applied to the first-section output being linked
        for f_in, f_out in zip(filters[j:], filters[j+1:]):
            self.convts.append(ResID07(f_in, f_out, resample=2, **shared))
            self.convts_res.append(ResID07N(f_out, n=n, **shared))
            self.links.append(ResID07N(f_out, n=n, **shared))

        self.c = c
        self.j = j

    def forward(self, x):
        """x: (b, c, h, w). Returns the list of per-level outputs, input first."""
        outputs = [x]

        # First section: downsample, refine, remember the result.
        for down, refine in zip(self.convs, self.convs_res):
            x = refine(down(x))
            outputs.append(x)

        # Second section: upsample, refine, then add the linked first-section output.
        stages = zip(self.convts, self.convts_res, self.links)
        for step, (up, refine, link) in enumerate(stages, start=1):
            x = refine(up(x))
            skip_idx = self.j - step
            if skip_idx >= 0:
                x = x + link(outputs[skip_idx])
            outputs.append(x)

        return outputs
    

class Unet(nn.Module):
    """Two-headed network on top of a ``UNetSam`` backbone.

    ``forward`` subtracts a 101x101 local mean from the input, runs the
    backbone, and feeds the last backbone output to two stride-2 transposed
    convolutions:

    * ``output_class``: 3 channels, softmax over the channel dimension.
    * ``output_pretil``: 1 channel, raw (no activation).
    """
    def __init__(self):
        super().__init__()
        self.unet = UNetSam(filters=(1,16,32,64,128,256,128,64,32,16), n=1, activation='relu', dropout=0.0, expansion=1)
        self.convT_o1 = nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.convT_o2 = nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        """x: (b, 1, h, w). Returns ``(output_class, output_pretil)``."""
        # Out-of-place subtraction: `x -= ...` would overwrite the caller's tensor
        # and raises for a leaf tensor that requires grad.
        # kernel 101 / stride 1 / padding 50 keeps the spatial size unchanged.
        x = x - F.avg_pool2d(x, 101, stride=1, padding=50)
        features = self.unet(x)[-1]
        output_class = F.softmax(self.convT_o1(features), dim=1)
        output_pretil = self.convT_o2(features)
        return output_class, output_pretil

# Smoke test: build the network on the GPU and check the output shapes for a
# single-channel 512x512 input.
model = Unet()
model.cuda()

x = torch.randn(1, 1, 512, 512).cuda()
# This is only a shape check. Without no_grad the global `y` would keep the
# whole autograd graph (and its activations) alive on the GPU.
with torch.no_grad():
    y = model(x)
y[0].shape, y[1].shape

(torch.Size([1, 3, 512, 512]), torch.Size([1, 1, 512, 512]))

In [2]:
import os

def imagenp_pad_img_multiplo(img,multiplo=64):
    """Zero-pad an image at the bottom and right so both spatial sizes are multiples of ``multiplo``.

    Args:
        img: NumPy array of shape (h, w, c).
        multiplo: Value the padded height and width must be a multiple of.

    Returns:
        Array of shape (th, tw, c) with the dtype of ``img``. ``th`` and ``tw``
        are the smallest multiples of ``multiplo`` that are >= h and >= w, and
        never smaller than ``multiplo`` itself (an empty axis is padded up to
        one full multiple). The input object is returned unchanged when no
        padding is needed.
    """
    h, w, c = img.shape
    # Same rule for both axes: ceil-divide, use at least one multiple, scale back.
    # (-(-a // b) is ceil division without going through floats.)
    th = int(max(-(-h // multiplo), 1) * multiplo)
    tw = int(max(-(-w // multiplo), 1) * multiplo)
    if th == h and tw == w:
        return img

    padded = np.zeros((th, tw, c), dtype=img.dtype)
    padded[:h, :w, :] = img
    return padded

def testing(real_img, mask, epoch_idx, device, path):
    """Run the global ``model`` on one single-channel image and write its predictions to ``path``.

    The image is zero-padded to a multiple of 32, passed through ``model``, and
    the two outputs are cropped back to the original size and multiplied by
    ``mask``. Three files are written:

    * ``pretil_epoch_{epoch_idx}.exr``: masked float32 "pretil" output.
    * ``pretil_epoch_{epoch_idx}.png``: the same, clipped to [0, 1] and scaled to uint8.
    * ``class_epoch_{epoch_idx}.png``: masked class output, clipped and scaled
      to uint8, with the channel order reversed before writing.

    Args:
        real_img: 2-D NumPy array (h, w).
        mask: NumPy array (h, w) multiplied into both outputs.
        epoch_idx: Value used in the output file names.
        device: Torch device for the input tensor; ``model`` must already be on it.
        path: Output directory (created if missing).

    Raises:
        OSError: If OpenCV reports that a file could not be written.
    """
    # NOTE(review): some OpenCV builds may only honour this when it is set before
    # `import cv2` - confirm it takes effect here.
    os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
    if path:
        os.makedirs(path, exist_ok=True)

    def _write(filename, image):
        """Write ``image`` with cv2.imwrite and fail loudly if OpenCV reports failure."""
        if not cv2.imwrite(filename, image):
            raise OSError(f'cv2.imwrite failed for {filename}')

    h, w = real_img.shape[:2]

    # (h, w) -> (h, w, 1) -> padded -> (1, H, W, 1) -> (1, 1, H, W)
    padded = imagenp_pad_img_multiplo(np.expand_dims(real_img, -1), 32)
    batch = torch.tensor(np.expand_dims(padded, 0), dtype=torch.float32, device=device)
    batch = batch.permute(0,3,1,2)

    print(torch.min(batch), torch.max(batch))

    # Inference only: do not record an autograd graph for the full-size image.
    with torch.no_grad():
        class_, pretil = model(batch)

    pretil_img = pretil[0].cpu().detach().numpy().transpose(1, 2, 0)
    class_img = class_[0].cpu().detach().numpy().transpose(1, 2, 0)

    # Drop the rows/columns that were only added as padding.
    pretil_img = pretil_img[:h, :w, :]
    class_img = class_img[:h, :w, :]

    mask = np.expand_dims(mask, -1)
    class_img *= mask
    pretil_img *= mask

    pretil_img = pretil_img.astype(np.float32)
    print(pretil_img)
    _write(f'{path}/pretil_epoch_{epoch_idx}.exr', pretil_img)

    pretil_img = (np.clip(pretil_img,0,1) * 255).astype(np.uint8)
    class_img = (np.clip(class_img,0,1) * 255).astype(np.uint8)
    # Reverse the channel order; make it contiguous so OpenCV gets a plain buffer.
    class_img = np.ascontiguousarray(class_img[...,::-1])

    _write(f'{path}/pretil_epoch_{epoch_idx}.png', pretil_img)
    _write(f'{path}/class_epoch_{epoch_idx}.png', class_img)

In [3]:
# import copy 
# from torch.export import Dim

# def pt_module_export_onnx(module, input_shapes=[(1,3,32,32),(1,3,32,32)], filename="module.onnx", input_names=('input0','input1'), output_names=('output0','output1'), eval=True, float16=False, dynamic_axes_names={0:'b'}):
#     """ export model to onnx 
#     Args:
#         module: module to export
#         input_shapes: shapes of the inputs
#         filename: filename to save the onnx model.
#         input_names: Names of the inputs for the onnx module
#         output_names: Names of the outputs for the onnx module
#         dynamic_axes_names={0:'b'} para batch dinamico
#         dynamic_axes_names={0:'b',1:'h',2:'w'} para batch,height and width dinamicos
#     """
#     dtype = torch.float16 if float16 else torch.float32
#     inputs = [torch.randn(input_shape, device='cuda', dtype=dtype) for input_shape in input_shapes] #dummy inputs
#     inputs = tuple(inputs)
    
#     module = copy.deepcopy(module)
    
#     module.cuda()
#     if eval: module.eval()
#     if float16: module.half()
#     else: module.float()
    
#     axes_names = dynamic_axes_names
#     dynamic_axes = {}
#     for name in input_names:  dynamic_axes[name] = axes_names
#     for name in output_names: dynamic_axes[name] = axes_names
#     print(dynamic_axes)
#     torch.onnx.export(module, 
#                       inputs, 
#                       f=filename, 
#                       input_names=input_names, 
#                       output_names=output_names, 
#                       dynamic_axes=dynamic_axes, 
#                       report=True,
#                       optimize=True,
#                       verify=True,
#                       profile=True,
#                       dump_exported_program=True,
#                       fallback=True,
#                       verbose=True, 
#                       dynamo=True,
#                       external_data=False)
    
#     del module

# model = Unet()
# pt_module_export_onnx(model, input_shapes=[(1,1,32,32)], input_names=('x',), output_names=('class_output', 'pretil_output'),  dynamic_axes_names={0:'b',1:'h',2:'w'})

In [4]:
# Record the PyTorch build used to produce the outputs in this notebook.
import torch
print(torch.__version__)

2.6.0+cu124


In [5]:
# model = Unet().cuda().eval()
# dummy_input = torch.randn(1, 1, 1984, 1472, device='cuda') 

# torch.onnx.export(
#     model,
#     dummy_input,
#     'unet_dynamic.onnx',
#     input_names=['input'],
#     output_names=['output_class', 'output_pretil'],
#     #dynamic_axes=dynamic_axes,
#     opset_version=12,
#     verbose=True,
#     export_params=True,  # Store the trained parameter weights inside the model file
#     do_constant_folding=True,  # Whether to execute constant folding for optimization
#     dynamo=True
# )

In [6]:
# import onnxruntime as ort
# import numpy as np
# import torch

# # Cargar el modelo ONNX
# onnx_model_path = 'module.onnx'
# ort_session = ort.InferenceSession(onnx_model_path)

# # Crear una entrada de prueba (igual que la que usaste para exportar el modelo)
# dummy_input = torch.randn(1, 1, 256, 256, device='cuda').cpu().numpy()

# # Ejecutar el modelo ONNX
# outputs = ort_session.run(None, {'x': dummy_input})

# # Imprimir las salidas
# print("Output class:", outputs[0].shape)
# print("Output pretil:", outputs[1].shape)

In [7]:
import numpy as np

dem_file = 'dem3.npz'
# Open the archive once and close it when done; the arrays are read into memory
# inside the `with` block, so they stay valid afterwards.
with np.load(dem_file) as dem_data:
    real_image = dem_data['dem']
    mask = dem_data['mask']
print(real_image.shape)

model = Unet()
model.cuda()
model.eval()

# Only the first checkpoint is evaluated (the original loop stopped after one
# iteration); widen the range to sweep more epochs.
# NOTE(review): checkpoint `model_epoch_{i+1}` is written with output tag `i` -
# confirm the off-by-one in the output file names is intended.
for i in range(1):
    # weights_only=True: only tensors are unpickled from the checkpoint file.
    state_dict = torch.load(f'weigths/model_epoch_{i+1}.pth', map_location='cuda', weights_only=True)
    model.load_state_dict(state_dict)
    testing(real_image, mask, i, 'cuda', 'testing/inference_dem_3')

(1954, 1464)
tensor(-14.2266, device='cuda:0') tensor(13.3157, device='cuda:0')
[[[ 0.       ]
  [-0.       ]
  [ 0.       ]
  ...
  [ 5.446713 ]
  [-1.340947 ]
  [ 5.721717 ]]

 [[-0.       ]
  [ 0.       ]
  [-0.       ]
  ...
  [ 2.6523829]
  [ 5.258266 ]
  [ 2.5021825]]

 [[ 0.       ]
  [ 0.       ]
  [ 0.       ]
  ...
  [ 3.9652572]
  [ 0.9163634]
  [ 5.240374 ]]

 ...

 [[-0.       ]
  [-0.       ]
  [ 0.       ]
  ...
  [ 0.       ]
  [-0.       ]
  [ 0.       ]]

 [[ 0.       ]
  [ 0.       ]
  [-0.       ]
  ...
  [-0.       ]
  [-0.       ]
  [-0.       ]]

 [[-0.       ]
  [-0.       ]
  [ 0.       ]
  ...
  [ 0.       ]
  [-0.       ]
  [ 0.       ]]]
