In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import torch.nn.functional as F
from collections import deque
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import pandas as pd
import os
import math
from sklearn.metrics import davies_bouldin_score
import concurrent.futures
import matplotlib.pyplot as plt
import cv2

from RLEnviroment import RL_Agent, NEUEnvironment, Gym
from utils.Loader import NEUDataset
from utils.Perspectiver import Perspectiver
from source.Prototype1 import Prototype1

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from stable_baselines3 import SAC

# Load the SAC agent saved under "SAC_TEST" (pass `env=` to SAC.load if resuming training).
# It gets its own name: `loaded_model` is used further down for the Prototype1
# network, and sharing the name meant that running this cell out of order
# silently replaced the model that the inference cells call.
sac_model = SAC.load("SAC_TEST")

# Verify it loaded correctly
print("✅ Loaded SAC model:", sac_model)


In [2]:
def plot_barchartImage(image):
    """Render a 2-D array as a 3-D bar chart with one unit-square bar per element.

    The bar for ``image[row, col]`` is drawn at ``x = col``, ``y = row`` and its
    height is the element's value.

    Args:
        image: 2-D array-like of numeric values (anything ``np.asarray``
            accepts), e.g. a grayscale image of shape ``(rows, cols)``.

    Raises:
        ValueError: If ``image`` is not two-dimensional.
    """
    image = np.asarray(image)
    if image.ndim != 2:
        raise ValueError(
            f"plot_barchartImage expects a 2-D array, got shape {image.shape}"
        )

    n_rows, n_cols = image.shape

    # meshgrid (default 'xy' indexing) returns arrays of shape (len(y), len(x)).
    # x therefore has to span the columns and y the rows so that the flattened
    # coordinates line up element-for-element with image.flatten(); spanning
    # them the other way round scrambles the bars for non-square images.
    x, y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

    # Flatten arrays for plotting
    x = x.flatten()
    y = y.flatten()
    z = np.zeros_like(x)          # every bar starts at height 0
    dx = dy = np.ones_like(x)     # unit footprint per element
    dz = image.flatten()

    # Plot the 3D bar chart
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    ax.bar3d(x, y, z, dx, dy, dz, shade=True)

    # Add labels and title (the title reports the actual array shape)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Value')
    ax.set_title(f'3D Bar Chart of {image.shape} Array')

    plt.show()

In [3]:
dataset = NEUDataset(set="train", scale=0.5)

# Load the saved model checkpoint.
# weights_only=True limits unpickling to tensors and plain Python containers,
# so a tampered .pth file cannot execute arbitrary code on load; it also
# silences the FutureWarning torch emits for the implicit weights_only=False.
# NOTE(review): this assumes the checkpoint holds only the state dict plus
# primitive values (e.g. `num_attention_heads` as a plain int) — if loading
# fails on a custom type, allow-list it rather than reverting to full pickle.
checkpoint = torch.load(
    "h1.pth",
    map_location=torch.device("cpu"),
    weights_only=True,
)

# Recreate the model architecture (must match the one used during training)
loaded_model = Prototype1(num_attention_heads=checkpoint['num_attention_heads'])

# Load the saved weights into the model
loaded_model.load_state_dict(checkpoint['state_dict'])

# Set the model to evaluation mode
loaded_model.eval()

print("Modelo cargado correctamente:", loaded_model)


Modelo cargado correctamente: Prototype1(
  (cnn_block): CNNBlock(
    (conv1): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv5): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (attention_block): AttentionBlock(
    (qkv): Linear(in_features=144, out_features=432, bias=False)
    (out_proj): Linear(in_features=144, out_features=1, bias=True)
  )
  (experts): ModuleList(
    (0-31): 32 x FCExpert(
      (fc1): Linear(in_features=144, out_features=

  checkpoint = torch.load("h1.pth", map_location=torch.device("cpu"))


In [13]:
# Run the loaded network over every sample and print the two values it predicts.
# After the loop, `image`, `sp` and `sr` hold the LAST sample's data; the
# cells below rely on that.
# NOTE(review): the recorded run printed sr = 0.0 for every sample and an almost
# constant sp — worth checking the input scaling and the checkpoint.
with torch.no_grad():  # inference only: skip building the autograd graph
    for i in range(len(dataset)):
        image, label = dataset[i]

        # Rescale the sample to the uint8 range, then hand it to the model as float32.
        model_input = torch.tensor(
            Perspectiver.normalize_to_uint8(image.detach().cpu().numpy()),
            dtype=torch.float,
        )
        values = loaded_model(model_input)

        # Plain Python floats; these are later passed to Perspectiver.meanShift.
        sp = float(values[0][0])
        sr = float(values[0][1])

        print("sp", sp)
        print("sr", sr)

sp 0.08607348054647446
sr 0.0
sp 0.08607301861047745
sr 0.0
sp 0.08607195317745209
sr 0.0
sp 0.08607251942157745
sr 0.0
sp 0.08607520908117294
sr 0.0
sp 0.08607380092144012
sr 0.0
sp 0.08606449514627457
sr 0.0
sp 0.0860712006688118
sr 0.0
sp 0.08607058227062225
sr 0.0
sp 0.08606700599193573
sr 0.0
sp 0.08606716990470886
sr 0.0
sp 0.08606978505849838
sr 0.0
sp 0.08607392013072968
sr 0.0
sp 0.08607078343629837
sr 0.0
sp 0.08607298880815506
sr 0.0
sp 0.08607037365436554
sr 0.0
sp 0.08606404811143875
sr 0.0
sp 0.08606887608766556
sr 0.0
sp 0.08607126772403717
sr 0.0
sp 0.08606567978858948
sr 0.0
sp 0.08607321977615356
sr 0.0
sp 0.08607441931962967
sr 0.0
sp 0.0860724151134491
sr 0.0
sp 0.08606933057308197
sr 0.0
sp 0.08606626093387604
sr 0.0
sp 0.08607734739780426
sr 0.0
sp 0.08606968075037003
sr 0.0
sp 0.0860738530755043
sr 0.0
sp 0.08606355637311935
sr 0.0
sp 0.08606576919555664
sr 0.0
sp 0.08607008308172226
sr 0.0
sp 0.08607180416584015
sr 0.0
sp 0.08606593310832977
sr 0.0
sp 0.08606700

In [None]:
# Take element 0 along the first axis of the last sample left over from the
# inference loop, rescale it to uint8 and expand it to RGB.
# NOTE(review): `image` is rebound from the dataset tensor to the converted
# result, so this cell presumably cannot be re-run without first re-running
# the loop cell — confirm, and consider a separate name.
first_plane = image.detach().cpu().numpy()[0]
first_plane_u8 = Perspectiver.normalize_to_uint8(first_plane)
image = Perspectiver.grayscale_to_rgb(first_plane_u8)
image.shape

In [None]:
# Collapse the RGB image back to a single channel and show it as a 3-D bar chart.
image_gray = Perspectiver.rgb_to_grayscale(image)
plot_barchartImage(image_gray)

In [None]:
# Apply mean-shift to the RGB image using the two values predicted for it.
# `sp` and `sr` are already plain Python floats (converted in the inference
# loop), so calling `.detach()` on them raised AttributeError; float() accepts
# both a float and a single-element tensor.
clustered_image = Perspectiver.meanShift(image, float(sp), float(sr))
plot_barchartImage(Perspectiver.rgb_to_grayscale(clustered_image))