In [None]:
import os
import random
import zipfile
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from pathlib import Path
from zipfile import ZipFile

import albumentations as A
import cv2
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from clearml import Task
from dotenv import load_dotenv
from IPython.display import Image, display

from mglyph_ml.dataset.glyph_dataset import GlyphDataset
from mglyph_ml.dataset.manifest import DatasetManifest, ManifestSample
from mglyph_ml.experiment.e1.experiment import ExperimentConfig
from mglyph_ml.experiment.e1.train_model import train_and_test_model
from mglyph_ml.experiment.e1.util import load_image_into_ndarray

In [None]:
# Load .env first so any environment variables it defines (ClearML settings, MGML_DEVICE, ...)
# are already visible when ClearML initialises the task and when later cells read them.
load_dotenv()

config = ExperimentConfig(
    task_name="Experiment 1.2.1",
    task_tag="exp-1.2.1",
    dataset_path=Path("../data/uni.mglyph"),
    gap_start_x=40.0,
    gap_end_x=60.0,
    quick=True,
    seed=420,
    max_iterations=5,
    offline=True,
)

Task.set_offline(config.offline)
task: Task = Task.init(project_name="mglyph-ml", task_name=config.task_name)
task.add_tags(config.task_tag)
task.connect(config)

# Read derived values only after task.connect, which may override config fields
# (e.g. when the task is re-executed remotely).
# Single source of truth for the dataset location (was a duplicated literal).
path = config.dataset_path

# Seed every global RNG the notebook relies on so shuffles/augmentations are reproducible.
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)

In [None]:
# Load the entire .mglyph archive (a zip file) into memory once; later member reads are
# served from RAM instead of disk. This cell takes the longest time.
with open(config.dataset_path, "rb") as f:
    temp_archive = ZipFile(BytesIO(f.read()))

manifest_data = temp_archive.read("manifest.json")
manifest = DatasetManifest.model_validate_json(manifest_data)

samples_0 = manifest.samples["0"]  # source of the training and gap (held-out x range) data
samples_1 = manifest.samples["1"]  # source of the test data

# Split samples_0 by x: samples with gap_start_x <= x < gap_end_x form the "gap" subset,
# every other sample is used for training.
indices_train = [
    i for i, sample in enumerate(samples_0) if sample.x < config.gap_start_x or sample.x >= config.gap_end_x
]
indices_gap = [
    i for i, sample in enumerate(samples_0) if sample.x >= config.gap_start_x and sample.x < config.gap_end_x
]
indices_test = list(range(len(samples_1)))

# Dedicated RNG seeded from the config: the global `random` module was never seeded, so the
# subsets below used to differ on every run.
rng = random.Random(config.seed)
rng.shuffle(indices_train)
rng.shuffle(indices_gap)
rng.shuffle(indices_test)

# Keep a random half of each subset. TODO: make this fraction a config field instead of a magic 1/2.
indices_train = indices_train[: len(indices_train) // 2]
indices_gap = indices_gap[: len(indices_gap) // 2]
indices_test = indices_test[: len(indices_test) // 2]

# Training-time augmentation: random rotation in [-5°, 5°] and translation of up to ±10% per
# axis, with exposed borders filled with 255 (white for 8-bit images). The output size is kept.
affine = A.Affine(
    rotate=(-5, 5),
    translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
    fit_output=False,
    keep_ratio=True,
    border_mode=cv2.BORDER_CONSTANT,
    fill=255,
    p=1.0,
)
# Per-image min-max normalisation, then HWC ndarray -> CHW torch tensor.
normalize = A.Normalize(normalization="min_max")
to_tensor = ToTensorV2()
# Seed from the config instead of a duplicated literal so a single setting controls reproducibility.
pipeline = A.Compose([affine, normalize, to_tensor], seed=config.seed)
# Evaluation pipeline: deterministic, no augmentation.
normalize_pipeline = A.Compose([normalize, to_tensor])


def affine_and_normalize(image: np.ndarray) -> torch.Tensor:
    """Run ``image`` through the augmenting ``pipeline`` (random affine, normalise, to-tensor).

    Returns the transformed image produced by the pipeline's ``"image"`` output.
    """
    transformed = pipeline(image=image)
    return transformed["image"]


def just_normalize(image: np.ndarray) -> torch.Tensor:
    """Run ``image`` through ``normalize_pipeline`` only (normalise + to-tensor, no augmentation).

    Returns the transformed image produced by the pipeline's ``"image"`` output.
    """
    transformed = normalize_pipeline(image=image)
    return transformed["image"]


print("loading images")


def _load_subset(executor, samples, indices, transform):
    """Decode ``samples[i].filename`` for each ``i`` in ``indices`` and apply ``transform``.

    Decoding runs in the executor's worker threads. ``transform`` is applied on the calling thread,
    in ``indices`` order, so a seeded augmentation pipeline consumes its random draws
    deterministically instead of in thread-scheduling order.
    Returns the list of transformed images.
    """
    decoded = executor.map(lambda i: load_image_into_ndarray(temp_archive, samples[i].filename), indices)
    return [transform(image) for image in decoded]


try:
    # Reads are served from the in-memory archive (much faster than disk access).
    with ThreadPoolExecutor(max_workers=32) as executor:
        images_train = _load_subset(executor, samples_0, indices_train, affine_and_normalize)
        # Evaluation subsets are only normalised — no random augmentation.
        images_gap = _load_subset(executor, samples_0, indices_gap, just_normalize)
        images_test = _load_subset(executor, samples_1, indices_test, just_normalize)
finally:
    # Release the archive even if decoding fails part-way.
    temp_archive.close()

# Regression targets: each sample's x value, aligned with the image lists above.
labels_train = [samples_0[i].x for i in indices_train]
labels_gap = [samples_0[i].x for i in indices_gap]
labels_test = [samples_1[i].x for i in indices_test]

In [None]:
# Sanity-check the loaded training tensors and preview one of them.
assert images_train, "no training images were loaded"
print(f"Training image count: {len(images_train)}")
print(f"First image type: {type(images_train[0])}")
print(f"First image dtype: {images_train[0].dtype}")
print(f"First image shape: {images_train[0].shape}")

# ToTensorV2 yields CHW; transpose back to HWC and rescale the min-max-normalised values to
# uint8 for PNG encoding. No colour conversion is done here: cv2.imencode treats 3-channel data
# as BGR, so if load_image_into_ndarray returns RGB the preview shows red/blue swapped — TODO confirm.
preview_index = min(21, len(images_train) - 1)  # clamp so small runs don't raise IndexError
preview = images_train[preview_index].numpy().transpose(1, 2, 0)
preview_u8 = np.clip(preview * 255.0, 0, 255).round().astype(np.uint8)
ok, png = cv2.imencode(".png", preview_u8)
assert ok, "PNG encoding of the preview image failed"
display(Image(data=png.tobytes()))

In [None]:
dataset_train = GlyphDataset(images=images_train, labels=labels_train)
dataset_gap = GlyphDataset(images=images_gap, labels=labels_gap)
dataset_test = GlyphDataset(images=images_test, labels=labels_test)

# Preview one training sample through the dataset. Clamp the index so small (e.g. quick) runs
# don't raise IndexError; assumes dataset_train indexes images_train 1:1 — TODO confirm.
sample_index = min(2000, len(images_train) - 1)
image = dataset_train[sample_index][0].numpy().transpose(1, 2, 0) * 255.0
print(image.shape)
display(Image(data=cv2.imencode(".png", np.clip(image, 0, 255).round().astype(np.uint8))[1].tobytes()))

In [None]:
# Device string comes from the environment (populated from .env by load_dotenv in the config cell);
# presumably a torch device such as "cuda" or "cpu".
device = os.environ.get("MGML_DEVICE")
if device is None:
    raise KeyError("MGML_DEVICE is not set; define it in the environment or in .env (e.g. MGML_DEVICE=cuda)")

# NOTE(review): quick=False and max_epochs=20 disagree with config.quick=True (and possibly with
# config.max_iterations=5), so the hyperparameters logged via task.connect(config) may not
# describe this run — confirm which values are intended.
train_and_test_model(
    device=device,
    dataset_train=dataset_train,
    dataset_gap=dataset_gap,
    dataset_test=dataset_test,
    seed=config.seed,
    data_loader_num_workers=32,
    batch_size=256,
    quick=False,
    max_epochs=20,
    model_save_path=Path("../models/exp1.pt"),
)