# Vision and Cognition Systems - Project


In [6]:
#libraries to import
#known
import pandas as pd
import numpy as np
import os
import re
import torchvision
import torch
import PIL
from PIL import Image
from PIL import ImageFile
import sys
import time
from math import ceil



#Unknown
from typing import Union
from io import BytesIO
import random
from argparse import Namespace, ArgumentParser
from pathlib import Path
from multiprocessing import Pool
from functools import partial
import requests
import logging
import json
import yaml
from tqdm.auto import tqdm
#from classification.train_base import MultiPartitioningClassifier
#from classification.dataset import FiveCropImageDataset

#to divide
from classification import utils_global
from classification.s2_utils import Partitioning, Hierarchy
from classification.dataset import MsgPackIterableDatasetMultiTargetWithDynLabels


## Functions in the main GitHub folder

This is from ```msgpack_viewer.py```

In [None]:
class MsgPackIterableDataset(torch.utils.data.IterableDataset):
    """Stream ``(image, image_id)`` pairs from directories of msgpack shards.

    Every directory in ``path`` is scanned for files named exactly
    ``shard_<N>.msg``. Records are read with ``msgpack.Unpacker(raw=True)``,
    so record keys are ``bytes``; the configured key names are UTF-8 encoded
    to match.

    Args:
        path: A shard directory, or a list/tuple/set of shard directories.
        key_img_id: Record key holding the image id.
        key_img_encoded: Record key holding the encoded image bytes.
        transformation: Optional callable applied to every decoded PIL image.
        shuffle: Shuffle the shard order and the in-memory record cache.
        cache_size: Number of raw records buffered before they are emitted;
            values ``<= 0`` emit every record immediately.

    Raises:
        ValueError: If no shard file is found in any of the directories.
    """

    # Anchored pattern with a literal dot: stray files such as
    # ``shard_1.msg.bak`` must not be mistaken for (and duplicate) shard 1.
    _SHARD_RE = re.compile(r"shard_(\d+)\.msg")

    def __init__(
        self,
        path: str,
        key_img_id: str = "id",
        key_img_encoded: str = "image",
        transformation=None,
        shuffle=False,
        cache_size=6 * 4096,
    ):
        super().__init__()
        self.path = path
        self.cache_size = cache_size
        self.transformation = transformation
        self.shuffle = shuffle
        # One seed per dataset instance, reused in every __iter__ call;
        # presumably so that every DataLoader worker copy draws the same shard
        # permutation and the per-worker split below stays disjoint.
        self.seed = random.randint(1, 100)
        self.key_img_id = key_img_id.encode("utf-8")
        self.key_img_encoded = key_img_encoded.encode("utf-8")

        if not isinstance(self.path, (list, tuple, set)):
            self.path = [self.path]

        self.shards = self.__init_shards(self.path)

    @staticmethod
    def __init_shards(path: "Iterable[Union[str, Path]]") -> list:
        """List the shard files of every directory, sorted by shard number.

        Returns a list of dicts with keys ``path_index``, ``path``,
        ``shard_index`` and ``shard_path``.
        """
        shards = []
        for path_index, directory in enumerate(path):
            found = []
            for name in os.listdir(directory):
                match = MsgPackIterableDataset._SHARD_RE.fullmatch(name)
                if match:
                    found.append((int(match.group(1)), name))
            # Deterministic order instead of whatever os.listdir returns.
            found.sort()
            shards.extend(
                {
                    "path_index": path_index,
                    "path": directory,
                    "shard_index": shard_index,
                    "shard_path": os.path.join(directory, name),
                }
                for shard_index, name in found
            )
        if not shards:
            raise ValueError("No shards found")

        return shards

    def _process_sample(self, x):
        """Decode one raw record into ``(image, image_id)``.

        The image is converted to RGB; when both edges exceed 320 px it is
        resized with ``torchvision.transforms.Resize(320)`` before the optional
        user transformation is applied.
        """
        img = Image.open(BytesIO(x[self.key_img_encoded]))
        if img.mode != "RGB":
            img = img.convert("RGB")

        if img.width > 320 and img.height > 320:
            img = torchvision.transforms.Resize(320)(img)

        # apply all user specified image transformations
        if self.transformation is not None:
            img = self.transformation(img)

        _id = x[self.key_img_id].decode("utf-8")
        return img, _id

    def __iter__(self):
        shard_indices = list(range(len(self.shards)))

        if self.shuffle:
            random.seed(self.seed)
            random.shuffle(shard_indices)

        # In a multi-worker DataLoader, each worker reads a contiguous,
        # non-overlapping slice of the (identically permuted) shard list.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            total = len(shard_indices)
            workers = worker_info.num_workers
            wid = worker_info.id
            shard_indices = shard_indices[
                wid * total // workers : (wid + 1) * total // workers
            ]

        cache = []
        for shard_index in shard_indices:
            shard = self.shards[shard_index]
            # NOTE(review): `msgpack` is never imported in this notebook; add
            # `import msgpack` to the import cell or this raises NameError.
            with open(shard["shard_path"], "rb") as f:
                unpacker = msgpack.Unpacker(
                    f, max_buffer_size=1024 * 1024 * 1024, raw=True
                )
                for x in unpacker:
                    if x is None:
                        continue
                    cache.append(x)
                    # Emit the buffer once it is full (shuffled if requested).
                    if len(cache) >= self.cache_size:
                        if self.shuffle:
                            random.shuffle(cache)
                        while cache:
                            yield self._process_sample(cache.pop())

        # Flush whatever is left after the last shard.
        if self.shuffle:
            random.shuffle(cache)
        while cache:
            yield self._process_sample(cache.pop())



In [None]:
if __name__ == "__main__":
    # Only `ArgumentParser` is imported from argparse (not the module itself),
    # so use that name directly; `argparse.ArgumentParser()` raises NameError.
    parser = ArgumentParser()
    parser.add_argument("--data", type=str, default="resources/images/mp16")
    args = parser.parse_args()

    # The user transformation only converts decoded images to tensors.
    tfm = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
        ]
    )

    dataset = MsgPackIterableDataset(path=args.data, transformation=tfm)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        num_workers=6,
        pin_memory=False,
    )

    # Walk the whole dataset once; print the first sample's shape and id.
    num_images = 0
    for x, image_id in dataloader:
        if num_images == 0:
            print(x.shape, image_id)
        num_images += 1

    print(f"{num_images=}")


This is from ```filter_by_downloaded_images.py```, which we may not need since we will use a different dataset.

In [None]:
def main():
    """Drop label-mapping entries whose image is missing from the msgpack shards.

    For the ``train`` and ``val`` splits in turn: load the JSON label mapping
    named in ``config``, iterate the split's metadata dataset (built with
    ``ignore_image=True``), keep only the entries whose ``img_id`` occurs in
    it, and overwrite the mapping file with the filtered result.

    Reads the module-level ``config`` dict assigned by the script entry point.
    """
    for split in ("train", "val"):
        mapping_file = config[f"{split}_label_mapping"]
        with open(mapping_file) as fh:
            mapping = json.load(fh)

        logging.info(f"Expected dataset size: {len(mapping)}")
        shard_dir = config[f"msgpack_{split}_dir"]
        meta_file = config[f"{split}_meta_path"]
        # NOTE(review): the same metadata path is passed for two positional
        # parameters; confirm against MsgPackIterableMetaDataset's signature.
        dataset = MsgPackIterableMetaDataset(
            shard_dir,
            meta_file,
            meta_file,
            key_img_id=config["key_img_id"],
            key_img_encoded=config["key_img_encoded"],
            ignore_image=True,
        )

        kept = {}
        for _, meta in dataset:
            img_id = meta["img_id"]
            if img_id in mapping:
                kept[img_id] = mapping[img_id]
        logging.info(f"True dataset size: {len(kept)}")

        # Rewrite the mapping in place with only the downloaded images.
        with open(config[f"{split}_label_mapping"], "w") as fh:
            json.dump(kept, fh)


def parse_args():
    """Parse ``-c/--config``: path to the YAML config (default ``config/baseM.yml``)."""
    cli = ArgumentParser()
    cli.add_argument(
        "-c",
        "--config",
        type=Path,
        default="config/baseM.yml",
    )
    return cli.parse_args()


if __name__ == "__main__":
    # Timestamped INFO-level logging for the progress messages of main().
    logging.basicConfig(
        level=logging.INFO,
        datefmt="%d-%m-%Y %H:%M:%S",
        format="%(asctime)s %(levelname)s: %(message)s",
    )

    args = parse_args()

    # main() reads the module-level `config`; only the `model_params`
    # section of the YAML file is kept.
    # NOTE(review): FullLoader can construct some Python objects from YAML
    # tags; yaml.safe_load is sufficient for a plain config file.
    with open(args.config) as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)["model_params"]

    main()


Finally this is from ```download_images.py```.

In [None]:
# Let PIL decode truncated image files (e.g. incomplete downloads) instead of
# raising an error while loading them.
setattr(ImageFile, "LOAD_TRUNCATED_IMAGES", True)


class MsgPackWriter:
    """Append msgpack-encoded records to numbered shard files ``shard_<N>.msg``.

    A new shard is started every ``chunk_size`` records. Shard files already
    present in ``path`` are detected on construction, so new shards are
    numbered after them instead of overwriting them.

    Intended to be used as a context manager; the current shard file is closed
    on exit.

    Args:
        path: Output directory (created if missing).
        chunk_size: Maximum number of records per shard file.
    """

    # Anchored pattern with a literal dot, matched against the file *name*.
    _SHARD_RE = re.compile(r"shard_(\d+)\.msg")

    def __init__(self, path, chunk_size=4096):
        self.path = Path(path).absolute()
        self.path.mkdir(parents=True, exist_ok=True)
        self.chunk_size = chunk_size

        # Existing shards are regular files: match on `entry.name` (re needs a
        # str, not a Path) so numbering resumes after them rather than
        # restarting at shard_0 and truncating existing data.
        self.shards_index = []
        for entry in self.path.iterdir():
            match = self._SHARD_RE.fullmatch(entry.name)
            if match and entry.is_file():
                self.shards_index.append(int(match.group(1)))
        self.shard_open = None
        self.count = 0

    def open_next(self):
        """Close the current shard (if open) and open the next-numbered one."""
        next_index = max(self.shards_index, default=-1) + 1
        self.shards_index.append(next_index)

        if self.shard_open is not None and not self.shard_open.closed:
            self.shard_open.close()

        self.count = 0
        self.shard_open = open(self.path / f"shard_{next_index}.msg", "wb")

    def __enter__(self):
        self.open_next()
        return self

    def __exit__(self, type, value, tb):
        if self.shard_open is not None:
            self.shard_open.close()

    def write(self, data):
        """Pack ``data`` with msgpack and append it, rolling over to a new shard when full."""
        if self.shard_open is None or self.count >= self.chunk_size:
            self.open_next()

        # NOTE(review): `msgpack` is never imported in this notebook; add
        # `import msgpack` to the import cell or this raises NameError.
        self.shard_open.write(msgpack.packb(data))
        self.count += 1


def _thumbnail(img: PIL.Image, size: int) -> PIL.Image:
    # resize an image maintaining the aspect ratio
    # the smaller edge of the image will be matched to 'size'
    w, h = img.size
    if (w <= size) or (h <= size):
        return img
    if w < h:
        ow = size
        oh = int(size * h / w)
        return img.resize((ow, oh), PIL.Image.BILINEAR)
    else:
        oh = size
        ow = int(size * w / h)
        return img.resize((ow, oh), PIL.Image.BILINEAR)


def flickr_download(x, size_suffix="z", min_edge_size=None, timeout=60):
    """Download one Flickr image and return it re-encoded as JPEG bytes.

    Args:
        x: Mapping with keys ``"image_id"`` and ``"url"``.
        size_suffix: Flickr size suffix inserted before the file extension
            (e.g. ``"z"``) so the full-resolution original is not downloaded;
            ``""`` keeps the original URL.
            See https://www.flickr.com/services/api/misc.urls.html
        min_edge_size: If given, downscale so the shorter edge equals this.
        timeout: Seconds to wait for the server before the request fails.

    Returns:
        ``{"image": <JPEG bytes>, "id": image_id}``, or ``None`` when the
        request failed or the payload could not be decoded as an image.
    """
    image_id = x["image_id"]
    url = x["url"]
    if size_suffix != "":
        # Insert the suffix before the last occurrence of the extension;
        # `url.split(ext)` breaks on earlier occurrences and raises
        # ValueError for an empty extension.
        ext = Path(url).suffix
        base = url.rsplit(ext, 1)[0] if ext else url
        url = f"{base}_{size_suffix}{ext}"

    # Retry in a loop (not recursion) while the server rate-limits us.
    while True:
        try:
            r = requests.get(url, timeout=timeout)
        except requests.exceptions.RequestException as e:
            logger.error(f"{image_id} : {url}: {e}")
            return None
        if r.status_code == 429:  # HTTP 429 Too Many Requests
            logger.warning("To many requests, sleep for 60s...")
            time.sleep(60)
            continue
        if not r:
            logger.error(f"{image_id} : {url}: {r.status_code}")
            return None
        break

    try:
        image = PIL.Image.open(BytesIO(r.content))
        if image.mode != "RGB":
            image = image.convert("RGB")
        # resize if necessary
        if min_edge_size is not None:
            image = _thumbnail(image, min_edge_size)
        # convert to jpeg
        fp = BytesIO()
        image.save(fp, "JPEG")
    except OSError as e:  # includes PIL.UnidentifiedImageError
        logger.error(f"{image_id} : {url}: {e}")
        return None

    return {"image": fp.getvalue(), "id": image_id}


class ImageDataloader:
    """Iterate ``{"image_id": ..., "url": ...}`` records from a headerless CSV.

    The CSV has two columns (image id, URL). Rows with any missing value are
    dropped; with ``shuffle=True`` rows are permuted using ``random_state=10``
    so the order is reproducible.
    """

    def __init__(self, url_csv: Path, shuffle=False, nrows=None):
        logger.info("Read dataset")
        table = pd.read_csv(
            url_csv, names=["image_id", "url"], header=None, nrows=nrows
        )
        # Rows without a URL cannot be downloaded.
        table = table.dropna()
        if shuffle:
            logger.info("Shuffle images")
            table = table.sample(frac=1, random_state=10)
        self.df = table
        logger.info(f"Number of URLs: {len(self)}")

    def __len__(self):
        return self.df.shape[0]

    def __iter__(self):
        ids = self.df["image_id"].values
        urls = self.df["url"].values
        yield from ({"image_id": i, "url": u} for i, u in zip(ids, urls))


def parse_args():
    """Parse the command-line options of the Flickr image downloader."""
    option_specs = [
        (
            "--threads",
            {
                "type": int,
                "default": 24,
                "help": "Number of threads to download and process images",
            },
        ),
        (
            "--output",
            {
                "type": Path,
                "default": Path("resources/images/mp16"),
                "help": "Output directory where images are stored",
            },
        ),
        (
            "--url_csv",
            {
                "type": Path,
                "default": Path("resources/mp16_urls.csv"),
                "help": "CSV with Flickr image id and URL for downloading",
            },
        ),
        (
            "--size",
            {
                "type": int,
                "default": 320,
                "help": "Rescale image to a minimum edge size of SIZE",
            },
        ),
        (
            "--size_suffix",
            {
                "type": str,
                "default": "z",
                "help": "Image size suffix according to the Flickr API; Empty string for original image",
            },
        ),
        ("--nrows", {"type": int}),
        (
            "--shuffle",
            {
                "action": "store_true",
                "help": "Shuffle list of URLs before downloading",
            },
        ),
    ]
    parser = ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def main():
    """Download every URL of ``args.url_csv`` in parallel into msgpack shards.

    Reads the module-level ``args`` (from ``parse_args``) and ``logger`` set up
    by the script entry point. Returns ``0`` as the process exit status.
    """
    image_loader = ImageDataloader(args.url_csv, nrows=args.nrows, shuffle=args.shuffle)
    download = partial(
        flickr_download,
        size_suffix=args.size_suffix,
        min_edge_size=args.size,
    )

    counter_successful = 0
    with Pool(args.threads) as p:
        with MsgPackWriter(args.output) as f:
            start = time.time()
            for i, x in enumerate(p.imap(download, image_loader)):
                # Report throughput every 1000 processed URLs (failed ones
                # included); skip i == 0, where no batch has completed yet.
                if i > 0 and i % 1000 == 0:
                    end = time.time()
                    logger.info(f"{i}: {1000 / (end - start):.2f} image/s")
                    start = end

                if x is None:
                    continue

                f.write(x)
                counter_successful += 1

    total = len(image_loader)
    # Guard against an empty URL list (ZeroDivisionError otherwise).
    ratio = counter_successful / total if total else 0.0
    logger.info(
        f"Sucesfully downloaded {counter_successful}/{total} images ({ratio:.3f})"
    )
    return 0


if __name__ == "__main__":
    args = parse_args()
    args.output.mkdir(parents=True, exist_ok=True)

    # Log to <output>/writer.log and to the console with a shared format;
    # the logger itself passes INFO and above to both handlers.
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler = logging.FileHandler(str(args.output / "writer.log"))
    file_handler.setLevel(logging.INFO)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)

    logger = logging.getLogger("ImageDownloader")
    logger.setLevel(logging.INFO)
    for handler in (file_handler, console_handler):
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    sys.exit(main())


Following the GitHub page, we now need:
- item 1
- item 2
- item 3
- item 4 

The main repository and paper that we need to follow are [here](https://github.com/TIBHannover/GeoEstimation), and the pretrained models are available [here](https://github.com/TIBHannover/GeoEstimation/releases/).

Davide ha trovato questo che forse è meglio [kaggle](https://www.kaggle.com/code/habedi/inspect-the-dataset/data)

Qui ci sono dei links che potrebbero essere usati con colab col comando [!wget](https://qualinet.github.io/databases/image/world_wide_scale_geotagged_image_dataset_for_automatic_image_annotation_and_reverse_geotagging/)

# Functions inside Classification folder

This is ```inference.py```

In [None]:
def parse_args():
    """Parse the command-line options of the inference script.

    Covers the model checkpoint and its hparams file, the image folder, and
    runtime settings (``--gpu``, ``--batch_size``, ``--num_workers``).
    """
    parser = ArgumentParser()
    specs = (
        (
            "--checkpoint",
            dict(
                type=Path,
                default=Path("models/base_M/epoch=014-val_loss=18.4833.ckpt"),
                help="Checkpoint to already trained model (*.ckpt)",
            ),
        ),
        (
            "--hparams",
            dict(
                type=Path,
                default=Path("models/base_M/hparams.yaml"),
                help="Path to hparams file (*.yaml) generated during training",
            ),
        ),
        (
            "--image_dir",
            dict(
                type=Path,
                default=Path("resources/images/im2gps"),
                help="Folder containing images. Supported file extensions: (*.jpg, *.jpeg, *.png)",
            ),
        ),
        # environment
        (
            "--gpu",
            dict(
                action="store_true",
                help="Use GPU for inference if CUDA is available",
            ),
        ),
        ("--batch_size", dict(type=int, default=64)),
        (
            "--num_workers",
            dict(
                type=int,
                default=4,
                help="Number of workers for image loading and pre-processing",
            ),
        ),
    )
    for flag, kwargs in specs:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()


args = parse_args()

# Use the GPU only when it was requested AND is available; the model and the
# input batches must end up on the same device.
use_gpu = args.gpu and torch.cuda.is_available()

print("Load model from ", args.checkpoint)
model = MultiPartitioningClassifier.load_from_checkpoint(
    checkpoint_path=str(args.checkpoint),
    hparams_file=str(args.hparams),
    map_location=None,
)
model.eval()
if use_gpu:
    model.cuda()

print("Init dataloader")
# The dataset presumably yields five crops per image (hence the division by
# 5, so one batch holds roughly `batch_size` crops) — TODO confirm.
dataloader = torch.utils.data.DataLoader(
    FiveCropImageDataset(meta_csv=None, image_dir=args.image_dir),
    batch_size=ceil(args.batch_size / 5),
    shuffle=False,
    num_workers=args.num_workers,
)
print("Number of images: ", len(dataloader.dataset))
if len(dataloader.dataset) == 0:
    raise RuntimeError(f"No images found in {args.image_dir}")

rows = []
# Pure inference: disable autograd so no computation graph is kept per batch.
with torch.no_grad():
    for X in tqdm(dataloader):
        if use_gpu:
            X[0] = X[0].cuda()
        img_paths, pred_classes, pred_latitudes, pred_longitudes = model.inference(X)
        # One output row per (image, partitioning/hierarchy key).
        for p_key in pred_classes:
            for img_path, pred_class, pred_lat, pred_lng in zip(
                img_paths,
                pred_classes[p_key].cpu().numpy(),
                pred_latitudes[p_key].cpu().numpy(),
                pred_longitudes[p_key].cpu().numpy(),
            ):
                rows.append(
                    {
                        "img_id": Path(img_path).stem,
                        "p_key": p_key,
                        "pred_class": pred_class,
                        "pred_lat": pred_lat,
                        "pred_lng": pred_lng,
                    }
                )
df = pd.DataFrame.from_records(rows)
df.set_index(keys=["img_id", "p_key"], inplace=True)
print(df)
fout = Path(args.checkpoint).parent / f"inference_{args.image_dir.stem}.csv"
print("Write output to", fout)
df.to_csv(fout)


This is ```train_base.py```, which is huge.

In [8]:
# NOTE(review): `pl` (pytorch_lightning) is never imported in this notebook, so
# this class definition raises `NameError: name 'pl' is not defined`; add
# `import pytorch_lightning as pl` to the import cell.
class MultiPartitioningClassifier(pl.LightningModule):
    """Image geolocation classifier with one linear head per geo-partitioning.

    A backbone built by ``utils_global.build_base_model(hparams.arch)`` yields a
    feature vector. That vector is fed to one ``torch.nn.Linear`` head per
    partitioning listed in ``hparams.partitionings``. With more than one
    partitioning, a ``Hierarchy`` is built. Its matrix ``M`` combines the heads'
    outputs into an extra ``"hierarchy"`` prediction over the classes of the
    last partitioning.

    Args:
        hparams: Model hyper-parameters (e.g. the ``model_params`` section of
            the training config) as an ``argparse.Namespace``.
    """

    # Per-channel normalisation used by both dataloaders (the standard
    # ImageNet mean/std values).
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, hparams: Namespace):
        super().__init__()
        # Recent PyTorch Lightning releases reject `self.hparams = ...` (the
        # property has no setter). save_hyperparameters() stores the Namespace
        # and exposes its fields as `self.hparams.<name>`.
        self.save_hyperparameters(hparams)

        self.partitionings, self.hierarchy = self.__init_partitionings()
        self.model, self.classifier = self.__build_model()

    def __init_partitionings(self):
        """Load every partitioning file; build a Hierarchy if there is more than one."""
        partitionings = [
            Partitioning(Path(path), shortname, skiprows=2)
            for shortname, path in zip(
                self.hparams.partitionings["shortnames"],
                self.hparams.partitionings["files"],
            )
        ]

        if len(self.hparams.partitionings["files"]) == 1:
            return partitionings, None

        return partitionings, Hierarchy(partitionings)

    def __build_model(self):
        """Create the backbone and one linear head per partitioning.

        When ``hparams.weights`` is set, both are passed through
        ``utils_global.load_weights_if_available``.
        """
        logging.info("Build model")
        model, nfeatures = utils_global.build_base_model(self.hparams.arch)

        classifier = torch.nn.ModuleList(
            [torch.nn.Linear(nfeatures, len(p)) for p in self.partitionings]
        )

        if self.hparams.weights:
            logging.info("Load weights from pre-trained model")
            model, classifier = utils_global.load_weights_if_available(
                model, classifier, self.hparams.weights
            )

        return model, classifier

    def forward(self, x):
        """Return one raw score tensor per partitioning head for the batch ``x``."""
        fv = self.model(x)
        return [self.classifier[i](fv) for i in range(len(self.partitionings))]

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _as_target_list(target):
        """Wrap a single 1-D target tensor in a list so it can be indexed per head."""
        if not isinstance(target, list) and len(target.shape) == 1:
            return [target]
        return target

    @staticmethod
    def _partitioning_losses(output, target):
        """Cross-entropy loss of every head against its own target."""
        return [
            torch.nn.functional.cross_entropy(output[i], target[i])
            for i in range(len(output))
        ]

    def _hierarchy_scores(self, yhats):
        """Combine the heads' outputs along ``hierarchy.M`` by an elementwise product."""
        stacked = torch.stack(
            [yhat[:, self.hierarchy.M[:, i]] for i, yhat in enumerate(yhats)],
            dim=-1,
        )
        return torch.prod(stacked, dim=-1)

    def _iter_predicted_classes(self, yhats, hierarchy_preds):
        """Yield ``(name, partitioning_index, predicted_class_indexes)`` per head.

        One entry per partitioning, followed by a ``"hierarchy"`` entry when a
        hierarchy exists; its class indexes refer to the last partitioning.
        """
        for i, partitioning in enumerate(self.partitionings):
            yield partitioning.shortname, i, torch.argmax(yhats[i], dim=1)
        if self.hierarchy is not None:
            yield (
                "hierarchy",
                len(self.partitionings) - 1,
                torch.argmax(hierarchy_preds, dim=1),
            )

    def _class_indexes_to_latlng(self, partitioning_index, class_indexes):
        """Map predicted class indexes to float ``(latitudes, longitudes)`` tensors."""
        partitioning = self.partitionings[partitioning_index]
        latlngs = [partitioning.get_lat_lng(c) for c in class_indexes.tolist()]
        pred_lats = torch.tensor([lat for lat, _ in latlngs], dtype=torch.float)
        pred_lngs = torch.tensor([lng for _, lng in latlngs], dtype=torch.float)
        return pred_lats, pred_lngs

    # ------------------------------------------------------- lightning hooks

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        images, target = batch
        target = self._as_target_list(target)

        output = self(images)

        # individual losses per partitioning
        losses = self._partitioning_losses(output, target)
        loss = sum(losses)

        losses_stats = {
            f"loss_train/{p.shortname}": l
            for p, l in zip(self.partitionings, losses)
        }
        for metric_name, metric_value in losses_stats.items():
            self.log(metric_name, metric_value, prog_bar=True, logger=True)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, **losses_stats}

    def validation_step(self, batch, batch_idx):
        images, target, true_lats, true_lngs = batch
        target = self._as_target_list(target)

        output = self(images)

        losses = self._partitioning_losses(output, target)
        loss = sum(losses)

        pnames = [p.shortname for p in self.partitionings]
        # top-k accuracy and loss for each partitioning
        individual_accuracy_dict = utils_global.accuracy(output, target, pnames)
        individual_loss_dict = {
            f"loss_val/{p}": l for p, l in zip(pnames, losses)
        }

        # GCD error per prediction head (partitionings + optional hierarchy)
        hierarchy_preds = (
            self._hierarchy_scores(output) if self.hierarchy is not None else None
        )
        distances_dict = {}
        for pname, p_idx, pred_class_indexes in self._iter_predicted_classes(
            output, hierarchy_preds
        ):
            pred_lats, pred_lngs = self._class_indexes_to_latlng(
                p_idx, pred_class_indexes
            )
            distances_dict[f"gcd_{pname}_val"] = utils_global.vectorized_gc_distance(
                pred_lats,
                pred_lngs,
                true_lats.type_as(pred_lats),
                true_lngs.type_as(pred_lats),
            )

        return {
            "loss_val/total": loss,
            **individual_accuracy_dict,
            **individual_loss_dict,
            **distances_dict,
        }

    def validation_epoch_end(self, outputs):
        pnames = [p.shortname for p in self.partitionings]

        # top-k accuracy and loss per partitioning
        loss_acc_dict = utils_global.summarize_loss_acc_stats(pnames, outputs)

        # GCD stats per partitioning
        gcd_dict = utils_global.summarize_gcd_stats(pnames, outputs, self.hierarchy)

        metrics = {
            "val_loss": loss_acc_dict["loss_val/total"],
            **loss_acc_dict,
            **gcd_dict,
        }
        for metric_name, metric_value in metrics.items():
            self.log(metric_name, metric_value, logger=True)

    def _multi_crop_inference(self, batch):
        """Score multi-crop images, keeping the max softmax score over the crops.

        ``batch`` is ``(images, meta_batch)`` with ``images`` shaped
        ``(batch, ncrops, ...)``. Returns ``(yhats, meta_batch, hierarchy_preds)``;
        ``hierarchy_preds`` is ``None`` when there is no hierarchy.
        """
        images, meta_batch = batch
        cur_batch_size = images.shape[0]
        ncrops = images.shape[1]

        # Fold the crop dimension into the batch for a single forward pass.
        images = torch.reshape(images, (cur_batch_size * ncrops, *images.shape[2:]))

        yhats = [torch.nn.functional.softmax(yhat, dim=1) for yhat in self(images)]

        # Unfold back to (batch, ncrops, ...) and take the max over the crops.
        yhats = [
            torch.reshape(yhat, (cur_batch_size, ncrops, *list(yhat.shape[1:])))
            for yhat in yhats
        ]
        yhats = [torch.max(yhat, dim=1)[0] for yhat in yhats]

        hierarchy_preds = None
        if self.hierarchy is not None:
            hierarchy_preds = self._hierarchy_scores(yhats)

        return yhats, meta_batch, hierarchy_preds

    def inference(self, batch):
        """Predict classes and coordinates for a multi-crop batch.

        Returns ``(meta_batch["img_path"], classes, lats, lngs)`` where the last
        three are dicts keyed by partitioning shortname (plus ``"hierarchy"``).
        """
        yhats, meta_batch, hierarchy_preds = self._multi_crop_inference(batch)

        pred_class_dict = {}
        pred_lat_dict = {}
        pred_lng_dict = {}
        for pname, p_idx, pred_classes in self._iter_predicted_classes(
            yhats, hierarchy_preds
        ):
            pred_lats, pred_lngs = self._class_indexes_to_latlng(p_idx, pred_classes)
            pred_lat_dict[pname] = pred_lats
            pred_lng_dict[pname] = pred_lngs
            pred_class_dict[pname] = pred_classes

        return meta_batch["img_path"], pred_class_dict, pred_lat_dict, pred_lng_dict

    def test_step(self, batch, batch_idx, dataloader_idx=None):
        """Return per-head distances between predictions and the true coordinates.

        Ground truth is read from ``meta_batch["latitude"]`` / ``["longitude"]``.
        """
        yhats, meta_batch, hierarchy_preds = self._multi_crop_inference(batch)

        distances_dict = {}
        for pname, p_idx, pred_classes in self._iter_predicted_classes(
            yhats, hierarchy_preds
        ):
            pred_lats, pred_lngs = self._class_indexes_to_latlng(p_idx, pred_classes)
            distances_dict[pname] = utils_global.vectorized_gc_distance(
                pred_lats,
                pred_lngs,
                meta_batch["latitude"].type_as(pred_lats),
                meta_batch["longitude"].type_as(pred_lngs),
            )

        return distances_dict

    def test_epoch_end(self, outputs):
        result = utils_global.summarize_test_gcd(
            [p.shortname for p in self.partitionings], outputs, self.hierarchy
        )
        return {**result}

    def configure_optimizers(self):
        """SGD over all parameters with a per-epoch ``MultiStepLR`` schedule.

        Keyword arguments come from ``hparams.optim["params"]`` and
        ``hparams.scheduler["params"]``.
        """
        optimizer = torch.optim.SGD(self.parameters(), **self.hparams.optim["params"])
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, **self.hparams.scheduler["params"]
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "name": "lr",
            },
        }

    def _make_dataloader(self, dataset):
        """Wrap ``dataset`` in a DataLoader using the batch/worker hyper-parameters."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers_per_loader,
            pin_memory=True,
        )

    def train_dataloader(self):
        with open(self.hparams.train_label_mapping, "r") as f:
            target_mapping = json.load(f)

        tfm = torchvision.transforms.Compose(
            [
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.RandomResizedCrop(224, scale=(0.66, 1.0)),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(self._NORM_MEAN, self._NORM_STD),
            ]
        )

        dataset = MsgPackIterableDatasetMultiTargetWithDynLabels(
            path=self.hparams.msgpack_train_dir,
            target_mapping=target_mapping,
            key_img_id=self.hparams.key_img_id,
            key_img_encoded=self.hparams.key_img_encoded,
            shuffle=True,
            transformation=tfm,
        )
        return self._make_dataloader(dataset)

    def val_dataloader(self):
        with open(self.hparams.val_label_mapping, "r") as f:
            target_mapping = json.load(f)

        tfm = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(256),
                torchvision.transforms.CenterCrop(224),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(self._NORM_MEAN, self._NORM_STD),
            ]
        )
        dataset = MsgPackIterableDatasetMultiTargetWithDynLabels(
            path=self.hparams.msgpack_val_dir,
            target_mapping=target_mapping,
            key_img_id=self.hparams.key_img_id,
            key_img_encoded=self.hparams.key_img_encoded,
            shuffle=False,
            transformation=tfm,
            meta_path=self.hparams.val_meta_path,
            cache_size=1024,
        )
        return self._make_dataloader(dataset)


def parse_args():
    """Parse the training CLI: ``-c/--config`` YAML path and the ``--progbar`` flag."""
    cli = ArgumentParser()
    cli.add_argument(
        "-c",
        "--config",
        type=Path,
        default=Path("config/baseM.yml"),
    )
    cli.add_argument("--progbar", action="store_true")
    parsed = cli.parse_args()
    return parsed


def main():
    """Train a MultiPartitioningClassifier from the YAML config given on the CLI.

    The config must provide ``model_params``, ``trainer_params`` and ``out_dir``.
    TensorBoard logs and checkpoints are written below a timestamped directory
    ``<out_dir>/<yymmdd-HHMM>``.
    """
    # `datetime` is not imported in this notebook's import cell; import it here
    # so building the timestamped output directory does not raise NameError.
    from datetime import datetime

    args = parse_args()
    logging.basicConfig(level=logging.INFO)

    # NOTE(review): FullLoader can construct some Python objects from YAML
    # tags; yaml.safe_load is sufficient for a plain config file.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    model_params = config["model_params"]
    trainer_params = config["trainer_params"]

    utils_global.check_is_valid_torchvision_architecture(model_params["arch"])

    out_dir = Path(config["out_dir"]) / datetime.now().strftime("%y%m%d-%H%M")
    out_dir.mkdir(exist_ok=True, parents=True)
    logging.info(f"Output directory: {out_dir}")

    # init classifier
    model = MultiPartitioningClassifier(hparams=Namespace(**model_params))

    tb_logger = pl.loggers.TensorBoardLogger(save_dir=str(out_dir), name="tb_logs")
    checkpoint_dir = out_dir / "ckpts" / "{epoch:03d}-{val_loss:.4f}"
    checkpointer = pl.callbacks.model_checkpoint.ModelCheckpoint(checkpoint_dir)

    # Progress bar only with --progbar (a refresh rate of 0 turns it off).
    progress_bar_refresh_rate = 1 if args.progbar else 0

    # NOTE(review): `checkpoint_callback=` and `progress_bar_refresh_rate=` are
    # PyTorch Lightning 1.x Trainer arguments; confirm the installed version.
    trainer = pl.Trainer(
        **trainer_params,
        logger=tb_logger,
        val_check_interval=model_params["val_check_interval"],
        checkpoint_callback=checkpointer,
        progress_bar_refresh_rate=progress_bar_refresh_rate,
    )

    trainer.fit(model)


if __name__ == "__main__":
    # Script entry point: start training only when executed directly.
    main()


NameError: name 'pl' is not defined