# Integrate Asger's heads

In [None]:
from itertools import chain

import torch
from torch import nn as nn
import torchvision

# from mini_trainer.classifier import Classifier
from mini_trainer.hierarchical.integration import sparse_masks_from_labels, HierarchicalBuilder
from mini_trainer.hierarchical.model import HierarchicalClassifier
from mini_trainer.hierarchical.loss import MultiLevelWeightedCrossEntropyLoss, MultiLevelLoss

from pathlib import Path
import pandas as pd
from collections import OrderedDict
import yaml

from fastai.learner import Learner
from fastai.vision.all import (
    DataBlock,
    ImageBlock,
    MultiCategoryBlock,
    CategoryBlock,
    ColSplitter,
    ColReader,
    Pipeline,
    Resize,
    aug_transforms,
    vision_learner,
    # partial,
    # F1ScoreMulti,
    # accuracy_multi,
    # ShowGraphCallback,
    # CSVLogger,
    # EarlyStoppingCallback,
    # ImageDataLoaders,
    # SaveModelCallback,
    DisplayedTransform,
)
# from fastai.transform import Transform

from time import time

In [None]:
# Location of the CSV describing the taxonomy hierarchy (read in the next cell).
hier_path = Path("/home/george/codes/lepinet/data/global_lepi/hierarchy.csv")

In [None]:
# Load the hierarchy table; every column is treated as one hierarchy level.
hier = pd.read_csv(hier_path)

# speciesKey (as str) -> all values of that row, stringified, in column order.
labels = {
    str(row['speciesKey']): row.values.astype(str).tolist()
    for _, row in hier.iterrows()
}

# Column position (as str) -> {class value (as str): index in order of first appearance}.
cls2idx = {
    str(level): {str(value): index for index, value in enumerate(column.unique())}
    for level, (_, column) in enumerate(hier.items())
}

In [None]:
# Build the sparse masks from the per-species label rows and the per-level
# class-index lookups; the result is handed to the model builder below.
sparse_masks=sparse_masks_from_labels(labels, cls2idx)

In [None]:
# Model: assemble the builder arguments, then create the hierarchical
# classifier together with the preprocessor it returns.
model_builder_kwargs = dict(
    model_type="efficientnet_v2_s",
    weights=None,
    hidden=512,
    droprate=0.1,
    normalized=True,
    sparse_masks=sparse_masks,
    # cls2idx['0'] is the lookup built from the first hierarchy column.
    num_classes=len(cls2idx['0']),
)
model, model_preprocessor = HierarchicalClassifier.build(**model_builder_kwargs)

In [None]:
# Display the preprocessor object returned by HierarchicalClassifier.build.
model_preprocessor

In [None]:
# loss function
class SumMultiLevelWeightedCrossEntropyLoss(torch.nn.modules.loss._Loss):
    def __init__(
            self, 
            weights : list[float | int] | torch.Tensor,
            device : torch._prims_common.DeviceLikeType, 
            dtype : torch.types._dtype, 
            class_weights : list[torch.Tensor] | None=None,
            label_smoothing : float = 0.0
        ):
        super().__init__()
        self.device = device
        self.dtype = dtype

        self.weights = torch.tensor(weights).to(device=device, dtype=dtype)
        self.n_levels = len(weights)
        self.label_smoothing = [1 - (1 - label_smoothing)**(1/(i+1)) for i in range(self.n_levels)]
        
        self._loss_fns = [
            nn.CrossEntropyLoss(
                # weight=None, #self.class_weights[i], 
                # reduction="none", 
                label_smoothing=label_smoothing
            ) for _ in range(self.n_levels)
        ]

    def __call__(
            self, 
            preds : torch.Tensor, 
            *targets
        ) -> "MultiLevelLoss":
        return sum(list(MultiLevelLoss(
            [
                self._loss_fns[i](preds[i], targets[i])
                for i in range(self.n_levels)
            ], 
             self.weights
        )))
# CPU instance for manual checks: three equally weighted levels (the original
# `1,0` was a typo for `1.0` and silently created a fourth level).
loss = SumMultiLevelWeightedCrossEntropyLoss(weights=[1.0, 1.0, 1.0], device='cpu', dtype=torch.float)

In [None]:
# Read the experiment configuration (YAML) into a plain Python object.
config_path = Path("/home/george/codes/lepinet/configs/20251106_1_test_ece.yaml")
with config_path.open() as f:
    config = yaml.safe_load(f)
# Disabled scaffold kept for reference: build the DataLoaders and a stock
# vision_learner from the config.
# gen_dls = getattr(importlib.import_module('011_lepi_large_prod_v2'), 'gen_dls')
# dls,hierarchy=gen_dls(**config['train'])
# model_arch = getattr(importlib.import_module('fastai.vision.all'), config['train']['model_arch_name'])
# learn = vision_learner(dls, model_arch)
# learn.model

In [None]:
# Load the image metadata table from its parquet file.
df_path = Path("/home/george/codes/lepinet/data/global_lepi/0032836-250426092105405_processing_metadata_postprocessed_quality_filtered.lepinet.parquet")
df = pd.read_parquet(df_path)

In [None]:
# Preview the first rows of the metadata table.
df.head()

In [None]:
# Take the first (index, row) pair of the table for inspection.
i,r=next(iter(df.iterrows()))

In [None]:
# Show the three taxonomy keys stored on that row.
r['speciesKey'], r['genusKey'], r['familyKey']

In [None]:
# Resize target before augmentation, final image size, and batch size.
aug_img_size, img_size, batch_size = 460, 256, 64

img_dir = Path("/home/george/codes/lepinet/data/global_lepi/images")

# Flat vocabulary: the unique values of every hierarchy column, stringified
# and concatenated in column order.
vocab = list(chain.from_iterable(
    hier[column].unique().astype(str).tolist() for column in hier.columns
))

In [None]:
# Multi-label variant: a single MultiCategoryBlock over the flat vocabulary.
# The whole construction is timed and the elapsed seconds printed.
start = time()
datablock = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock(vocab=vocab)),
    # Column 0 is read as the image path, prefixed with img_dir.
    get_x=ColReader(0, pref=img_dir),
    # Column 1 is read as a space-delimited label string.
    get_y=ColReader(1, label_delim=' '),
    splitter=ColSplitter(),
    item_tfms=Resize(aug_img_size),
    batch_tfms=aug_transforms(size=img_size),
)
dls = datablock.dataloaders(df, bs=batch_size)
end = time()
print(end - start)

In [None]:
# Split the space-separated `hierarchy_labels` string into one column per level.
# NOTE(review): assumes every value has exactly three parts, ordered
# species, genus, family — confirm against the data.
df[['speciesKey','genusKey','familyKey']] = df['hierarchy_labels'].str.split(' ', expand=True)

In [None]:
# Check the table after the split.
df.head()

In [None]:
# Per-level vocabularies: the unique species, genus and family keys found in
# the metadata table, in order of first appearance.
vocab_3, vocab_4, vocab_5 = (
    df[key].unique().tolist()
    for key in ('speciesKey', 'genusKey', 'familyKey')
)

In [None]:
# Hierarchical variant: one image input and three single-label targets, one
# per level. The whole construction is timed and the elapsed seconds printed.
start = time()
datablock = DataBlock(
    blocks=(
        ImageBlock,
        CategoryBlock(vocab=vocab_3),
        CategoryBlock(vocab=vocab_4),
        CategoryBlock(vocab=vocab_5),
    ),
    n_inp=1,
    get_x=ColReader(0, pref=img_dir),
    # NOTE(review): targets are read by column position (3, 4, 5); presumably
    # these are the species/genus/family columns matching the vocabularies
    # above — confirm against df.columns.
    get_y=[ColReader(3), ColReader(4), ColReader(5)],
    splitter=ColSplitter(),
    item_tfms=[Resize(aug_img_size), model_preprocessor],
    batch_tfms=aug_transforms(size=img_size),
)
dls = datablock.dataloaders(df, bs=4)
end = time()
print(end - start)

In [None]:
# Pull a single batch from the training DataLoader; b[0] is the input and
# b[1:] are the targets (see the type/shape checks below).
b=dls.train.one_batch()

In [None]:
# NOTE(review): relies on the last batch element exposing a `.tensor()`
# method — confirm this exists on the target type.
b[-1].tensor()

In [None]:
# Shape of the image batch.
b[0].shape

In [None]:
# Raw values of the first target.
b[1].data

In [None]:
# Forward pass on CPU: the image batch is copied to CPU, passed through
# model_preprocessor, then through the model.
# NOTE(review): model_preprocessor is also listed in the DataBlock's
# item_tfms, so it may be applied twice here — confirm.
out=model(model_preprocessor(b[0].cpu()))

In [None]:
# Number of outputs returned by the model.
len(out)

In [None]:
# Shape of the first model output.
out[0].shape

In [None]:
# Shape of the first target, for comparison with the output above.
b[1].shape

In [None]:
# Types of the individual model outputs.
[type(i) for i in out]

In [None]:
# Types of the individual targets in the batch.
[type(i) for i in list(b[1:])]

In [None]:
# The loss takes the targets as separate positional arguments (*targets), so
# the batch labels are unpacked instead of being passed as one list. They are
# moved to CPU to match `out`, which was computed from a CPU copy of the images.
loss(out, *(t.cpu() for t in b[1:]))

In [None]:
# Check the type of a freshly built loss object (three levels; the original
# `1,0` was a typo for `1.0` and created a fourth level).
type(SumMultiLevelWeightedCrossEntropyLoss(weights=[1.0, 1.0, 1.0], device='cpu', dtype=torch.float))

In [None]:
# fastai Learner over the hierarchical DataLoaders and model. The loss is
# built on CUDA with a non-zero weight on the first level only.
learn = Learner(
    dls=dls,
    model=model,
    loss_func=SumMultiLevelWeightedCrossEntropyLoss(
        weights=[1.0,0.0,0.0],
        device='cuda',
        dtype=torch.float,
    ),
)

In [None]:
# Run fastai's learning-rate finder on the learner.
learn.lr_find()

In [None]:
# Train for one epoch with the one-cycle schedule.
learn.fit_one_cycle(1)