In [None]:
import os, io, time
from pathlib import Path
from glob import glob

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from PIL import Image

from tfrecord.tools import tfrecord2idx
from tfrecord.torch.dataset import TFRecordDataset

# Modelo Transformer ViT

In [None]:
# Working directory of the notebook and its parent folder; the "dataset"
# directory is looked up under the parent.
ruta_carpeta_actual = os.getcwd()
ruta_carpeta_raiz = os.path.split(ruta_carpeta_actual)[0]

def ensure_index(tfr_path):
    """Return the ``.index`` sidecar path for a TFRecord file, creating it if missing.

    Args:
        tfr_path: Path (string) of the TFRecord file.

    Returns:
        ``tfr_path + ".index"`` for uncompressed files (the index is generated
        with ``tfrecord2idx.create_index`` when it does not exist yet), or
        ``None`` when the file name ends in ``.gz``.
    """
    # The index stores byte offsets found by parsing the file as it sits on
    # disk; for a gzip archive those offsets would not address records in the
    # decompressed stream, so compressed shards get no index at all.
    if tfr_path.lower().endswith(".gz"):
        return None
    idx_path = tfr_path + ".index"
    if not os.path.exists(idx_path):
        tfrecord2idx.create_index(tfr_path, idx_path)
    return idx_path

# All shards live in the same folder under the project root.
shard_dir = os.path.join(ruta_carpeta_raiz, "dataset", "PuntosMuestra_CR_2023_patches_images")
shard_names = (
    "patches11x11_CR_shard_0.tfrecord.gz",
    "patches11x11_CR_shard_1.tfrecord.gz",
    "patches11x11_CR_shard_2.tfrecord.gz",
)
files = [os.path.join(shard_dir, shard_name) for shard_name in shard_names]

# One index entry per shard, in the same order as `files`.
index_files = list(map(ensure_index, files))


## Limpieza de datos

In [None]:
# Feature schema handed to TFRecordDataset: a bytes field and an integer field.
# NOTE(review): the record dump printed further down in this notebook lists keys
# such as 'CATEGORIA' and band names, not 'image'/'label' — confirm this schema
# matches the shards that are actually loaded.
description = {
    "image": "byte",
    "label": "int",
}

# Per-channel statistics given to Normalize (presumably the usual ImageNet
# mean/std — confirm against the pretrained weights in use).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Preprocessing applied to each decoded PIL image: resize, centre-crop to 224,
# convert to a tensor, then normalise per channel.
vit_tfms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Original class ids that are kept for training.
keep_labels = [1, 3, 4, 6, 7, 10]
# Re-index the kept ids to contiguous positions 0 .. len(keep_labels) - 1.
label_map = dict(zip(keep_labels, range(len(keep_labels))))

def decode_and_transform(sample):
    """Turn one parsed record into an ``(image_tensor, class_index)`` pair.

    Args:
        sample: Mapping with a ``"label"`` entry (a plain scalar, or an
            array/tensor holding exactly one element) and an ``"image"`` entry
            with a bytes-like encoded image that PIL can open.

    Returns:
        ``(x, y)`` where ``x`` is ``vit_tfms`` applied to the RGB image and
        ``y`` is ``label_map[label]``; ``None`` when the label is 2 or is not a
        key of ``label_map`` (``collate_skip_none`` filters those out).

    Raises:
        ValueError: If the label is a NumPy array with more than one element.
    """
    raw_label = sample["label"]
    # Unwrap single-element arrays/tensors explicitly: calling int() directly
    # on a 1-D NumPy array is deprecated since NumPy 1.25.
    # (The tfrecord reader presumably returns "int" features as such arrays — confirm.)
    if hasattr(raw_label, "item"):
        lbl = int(raw_label.item())
    else:
        lbl = int(raw_label)

    if lbl == 2 or lbl not in label_map:
        return None

    img = Image.open(io.BytesIO(sample["image"])).convert("RGB")
    return vit_tfms(img), label_map[lbl]

def collate_skip_none(batch):
    """Collate a batch after discarding its ``None`` entries.

    Args:
        batch: List of samples, some of which may be ``None``.

    Returns:
        The result of ``torch.utils.data.default_collate`` on the remaining
        samples, or a pair of empty tensors (the second with dtype
        ``torch.long``) when nothing remains.
    """
    kept = [item for item in batch if item is not None]
    if not kept:
        return torch.empty(0), torch.empty(0, dtype=torch.long)
    return torch.utils.data.default_collate(kept)

# tfrecord's TFRecordDataset reads a single file per instance and accepts only
# compression_type "gzip" (lower case) or None, so build one dataset per shard
# and chain them into one iterable stream.
dataset = torch.utils.data.ChainDataset([
    TFRecordDataset(
        data_path=shard_path,
        # An index built from the compressed bytes cannot address records in
        # the decompressed stream, so gzip shards are read sequentially.
        index_path=None if shard_path.lower().endswith(".gz") else shard_index,
        description=description,
        compression_type="gzip" if shard_path.lower().endswith(".gz") else None,
        transform=decode_and_transform,  # yields (x, y), or None for dropped labels
        shuffle_queue_size=0,
    )
    for shard_path, shard_index in zip(files, index_files)
])

### Matriz de confusión

In [None]:
# `dataset` is an iterable-style (streaming) dataset: it has no len() and cannot
# be indexed, which random_split needs, and DataLoader rejects shuffle=True for
# such datasets. Read it once into a list, dropping the samples the transform
# rejected (None), and split that list instead.
# NOTE: this keeps every transformed sample in memory (one 3x224x224 float32
# tensor is about 0.6 MB); shrink the data or cache raw patches if that is too much.
samples = [sample for sample in dataset if sample is not None]
if not samples:
    raise RuntimeError("No samples left after label filtering; check the shards and keep_labels.")

n_total = len(samples)
n_train = int(0.7 * n_total)
train_ds, test_ds = random_split(
    samples, [n_train, n_total - n_train],
    generator=torch.Generator().manual_seed(42)  # fixed seed -> reproducible split
)

train_dl = DataLoader(train_ds, batch_size=64, shuffle=True,
                      num_workers=0, collate_fn=collate_skip_none)
test_dl  = DataLoader(test_ds, batch_size=64, shuffle=False,
                      num_workers=0, collate_fn=collate_skip_none)

# Smoke check: pull one batch and show its shapes.
xb, yb = next(iter(train_dl))
print(xb.shape, yb.shape)  # expected: [B,3,224,224], [B]

# Registrar información

In [None]:
import sys
sys.path.append("..")
import importlib, utils_log
importlib.reload(utils_log)
from utils_log import log_row

# Last path component of each folder, used as labels in the log row.
# os.path.basename handles the platform's separators (both "\\" and "/" on
# Windows, "/" elsewhere), unlike splitting on "\\" only; normpath strips any
# trailing separator first so the result is never empty for a real folder.
carpeta_actual = os.path.basename(os.path.normpath(ruta_carpeta_actual))
# NOTE(review): data_dir, num_classes, fit_s, pred_s, y_true, oa and f1m are not
# defined in the cells visible here — confirm they are set before this cell runs.
dataset_utilizado = os.path.basename(os.path.normpath(data_dir))

# Append one row describing this run to the experiment log.
log_row(
  script="20250901_PruebasEntrenamientoViT.ipynb",
  algoritmo="ViT_tiny_linear_probe",   # or "ViT_base_full"
  dataset=dataset_utilizado,
  clases_removidas=["02"],
  seed=42,
  n_train=len(train_ds), n_test=len(test_ds),
  n_features=None, num_classes=num_classes,
  fit_seconds=fit_s, pred_seconds=pred_s,
  ms_per_sample=(pred_s/len(y_true))*1000,  # average prediction time per sample, in ms
  OA=oa, F1_macro=f1m,
  carpeta=carpeta_actual
)

In [3]:
import os, io
import tensorflow as tf
from PIL import Image

# Locate the first shard under the parent of the working directory.
ruta_carpeta_actual = os.getcwd()
ruta_carpeta_raiz = os.path.dirname(ruta_carpeta_actual)
tfr_path = os.path.join(
    ruta_carpeta_raiz,
    "dataset",
    "PuntosMuestra_CR_2023_patches_images",
    # or "patches11x11_CR_shard_0.tfrecord" if it has already been decompressed
    "patches11x11_CR_shard_0.tfrecord.gz",
)

# Infer compression from the file extension.
compressed = tfr_path.lower().endswith(".gz")
if compressed:
    _tf_compression = "GZIP"
else:
    _tf_compression = None
ds = tf.data.TFRecordDataset(tfr_path, compression_type=_tf_compression)

def show_first(n=1, image_key="image_png"):
    """Print a summary of the first ``n`` records of ``ds`` and preview an image field.

    Args:
        n: Number of records to take from the module-level ``ds``.
        image_key: Name of the bytes feature expected to hold an encoded image
            (e.g. PNG/JPEG). Pass a different key ("image", "png", ...) if the
            records use another name. Records without this key, or with an
            empty value for it, are summarised without a preview.
    """
    for i, raw in enumerate(ds.take(n)):
        ex = tf.train.Example()
        ex.ParseFromString(raw.numpy())
        feat = ex.features.feature

        print(f"\n-- Registro {i} --")
        print("Claves:", list(feat.keys()))
        # One line per feature, showing only the first value of whichever
        # list (bytes / int64 / float) is populated; empty features print nothing.
        for k, v in feat.items():
            if v.bytes_list.value:
                b = v.bytes_list.value[0]
                print(f"  {k}: bytes[{len(b)}] {b[:16]!r}...")
            elif v.int64_list.value:
                print(f"  {k}: int {int(v.int64_list.value[0])}")
            elif v.float_list.value:
                print(f"  {k}: float {float(v.float_list.value[0])}")

        # Decode and open the image only when the requested key holds bytes.
        if image_key in feat and feat[image_key].bytes_list.value:
            img = Image.open(io.BytesIO(feat[image_key].bytes_list.value[0])).convert("RGB")
            print("Imagen:", img.size, img.mode)
            img.show()

# Dump the first three records of the shard.
show_first(3)



-- Registro 0 --
Claves: ['B2', 'CATEGORIA', 'B11', 'MNDWI', 'NDVI', 'EVI2', 'system:index', 'SR', 'NDMI', 'DEM', 'R35', 'R54', 'B8', 'B3', 'B4', 'GCVI']
  B2: float 0.04529999941587448
  CATEGORIA: float 1.0
  B11: float 0.22250999510288239
  MNDWI: float -0.4598506987094879
  NDVI: float 0.4820280373096466
  EVI2: float 0.2632536292076111
  system:index: bytes[22] b'0000000000000000'...
  SR: float 2.861212730407715
  NDMI: float 0.032487813383340836
  DEM: float 8.0
  R35: float 0.3470313251018524
  R54: float 1.0671573877334595
  B8: float 0.24530750513076782
  B3: float 0.08129750192165375
  B4: float 0.08055499941110611
  GCVI: float 2.1579105854034424

-- Registro 1 --
Claves: ['B2', 'CATEGORIA', 'B11', 'MNDWI', 'NDVI', 'EVI2', 'system:index', 'SR', 'NDMI', 'DEM', 'R35', 'R54', 'B8', 'B3', 'B4', 'GCVI']
  B2: float 0.03311749920248985
  CATEGORIA: float 1.0
  B11: float 0.20254500210285187
  MNDWI: float -0.4685955345630646
  NDVI: float 0.6857360601425171
  EVI2: float 0.44636