# Writing a Dataset to AIS in WebDataset (WDS) Format

In this notebook we will download and store the following datasets in [WebDataset](https://github.com/webdataset/webdataset) format in AIS:

- [The Oxford-IIIT Pet Dataset](https://academictorrents.com/details/b18bbd9ba03d50b0f7f479acc9f4228a408cecc1)
- [Flickr Image dataset](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset)

In [None]:
# Explicit %pip magic: installs into the environment of the running kernel and
# does not depend on IPython's automagic being enabled.
%pip install aistore

## Setting Up Client

In [None]:
import os
from aistore.client import Client

# Endpoint of the AIS cluster: the AIS_ENDPOINT environment variable wins when
# it is set; otherwise fall back to the local default below.
ais_url = os.environ.get("AIS_ENDPOINT", "http://localhost:8080")
client = Client(ais_url)

## The Oxford-IIIT Pet Dataset

### Downloading the Dataset

In [None]:
import requests
import tarfile
import os


def download_and_extract(url, dest_path):
    """Download a tar archive, unpack it next to ``dest_path`` and delete it.

    The archive is streamed to ``dest_path`` in chunks instead of being read
    into memory in one piece, extracted into the directory that contains
    ``dest_path``, and then removed.

    Args:
        url: HTTP(S) location of the tar archive.
        dest_path: Local file path used to store the archive temporarily.

    Raises:
        requests.HTTPError: If the server answers with an error status.
            Failing here is deliberate: silently skipping the download would
            only surface later as missing files.
    """
    target_dir = os.path.dirname(dest_path)

    # The context manager releases the connection even if writing fails; the
    # timeout keeps a stalled server from hanging the notebook forever.
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(dest_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)

    with tarfile.open(dest_path) as tar:
        # Prefer the "data" extraction filter where this Python version has
        # it, so archive members cannot be written outside ``target_dir``.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=target_dir, filter="data")
        else:
            # NOTE: without the filter, member paths are trusted as-is; only
            # use this with archives from a trusted source.
            tar.extractall(path=target_dir)

    os.remove(dest_path)  # Clean up the tar file after extraction

In [None]:
# Remote locations of the two Oxford-IIIT Pet archives.
base_url = "http://www.robots.ox.ac.uk/~vgg/data/pets/data"
images_url = f"{base_url}/images.tar.gz"
annotations_url = f"{base_url}/annotations.tar.gz"

# Local directory the archives are downloaded to and extracted in.
data_dir = "/data"
images_path = os.path.join(data_dir, "images.tar.gz")
annotations_path = os.path.join(data_dir, "annotations.tar.gz")

# exist_ok=True creates the directory only when it is missing, without a
# separate existence check that could race with another process.
os.makedirs(data_dir, exist_ok=True)

download_and_extract(images_url, images_path)
download_and_extract(annotations_url, annotations_path)

### Creating a bucket and writing the dataset

In [None]:
from pathlib import Path
from aistore.sdk.dataset.dataset_config import DatasetConfig
from aistore.sdk.dataset.data_attribute import DataAttribute
from aistore.sdk.dataset.label_attribute import LabelAttribute

# Root directory against which the pets dataset paths are resolved.
base_path = Path("/data")

# Get a handle for the "pets-dataset" bucket, then create it. exist_ok=True
# presumably makes an already existing bucket acceptable (confirm in SDK docs).
bucket = client.bucket("pets-dataset")
bucket = bucket.create(exist_ok=True)

In [None]:
# Function to get label from the annotation file


def get_class_dict(path: Path):
    """Parse an annotation list file into a ``{file_name: label}`` mapping.

    Every data line is split on whitespace; the first field is the file name
    and the second field is its label (kept as a string). Comment lines
    (starting with ``#``), blank lines and lines without a label field are
    skipped.

    Args:
        path: Location of the annotation list file (UTF-8 text).

    Returns:
        dict mapping each file name to its label string.
    """
    parsed_dict = {}
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            # split() without arguments drops the trailing newline and
            # tolerates tabs or repeated spaces between fields.
            fields = line.split()
            if len(fields) < 2 or fields[0].startswith("#"):
                continue
            file_name, label = fields[:2]
            parsed_dict[file_name] = label

    return parsed_dict


# Dataset-specific name on purpose: the lookup function below reads this
# global at call time, so a generic name that another cell rebinds would
# silently change the labels it returns.
pet_labels = get_class_dict(base_path / "annotations" / "list.txt")


def get_label_for_filename(filename):
    """Return the label recorded for ``filename``, or ``None`` if there is none."""
    return pet_labels.get(filename)

In [None]:
# Describe one sample of the pets dataset: the jpg image is the primary
# attribute; the png trimap and a label looked up via get_label_for_filename
# are attached as secondary attributes.
dataset_config = DatasetConfig(
    primary_attribute=DataAttribute(
        name="image",
        file_type="jpg",
        path=base_path / "images",
    ),
    secondary_attributes=[
        DataAttribute(
            name="trimap",
            file_type="png",
            path=base_path / "annotations" / "trimaps",
        ),
        LabelAttribute(label_identifier=get_label_for_filename, name="cls"),
    ],
)

# NOTE(review): pattern / maxcount are presumably the shard name pattern and
# the maximum number of samples per shard — confirm against the SDK docs.
bucket.write_dataset(maxcount=1000, pattern="img_dataset", config=dataset_config)

## Flickr Image dataset

### Downloading the Dataset

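**NOTE:** We use the [Kaggle API](https://github.com/Kaggle/kaggle-api/blob/main/docs/README.md) to download the dataset. The `kaggle` command-line tool needs Kaggle API credentials to be configured before the download cell will work; see the linked README for setup.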

In [None]:
# Explicit %pip magic: installs into the environment of the running kernel and
# does not depend on IPython's automagic being enabled.
%pip install kaggle

In [None]:
# Download hsankesara/flickr-image-dataset with the kaggle CLI into /data (-p)
# and unpack the archive there (--unzip).
!kaggle datasets download -d hsankesara/flickr-image-dataset -p /data --unzip

### Creating a bucket and writing the dataset

In [None]:
from pathlib import Path
from aistore.sdk.dataset.dataset_config import DatasetConfig
from aistore.sdk.dataset.data_attribute import DataAttribute
from aistore.sdk.dataset.label_attribute import LabelAttribute

# Root directory against which the Flickr dataset paths are resolved.
base_path = Path("/data")

# Get a handle for the "flickr-dataset" bucket, then create it. exist_ok=True
# presumably makes an already existing bucket acceptable (confirm in SDK docs).
bucket = client.bucket("flickr-dataset")
bucket = bucket.create(exist_ok=True)

In [None]:
# Function to get the caption from results.csv
def parse_csv(path: Path):
    """Parse a ``|``-separated caption file into a ``{file_name: caption}`` dict.

    Each row is expected to hold at least three fields: the image file name,
    a second field that is ignored here, and the caption. The key is the file
    name up to its first ``.``; the value is the caption with surrounding
    whitespace removed. Rows with fewer than three fields are skipped. When a
    file name occurs on several rows, the caption of the last row wins.

    Args:
        path: Location of the caption file (UTF-8 text).

    Returns:
        dict mapping each file name (without extension) to one caption.
    """
    parsed_dict = {}
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            # maxsplit=2 keeps the caption in one piece even when the caption
            # text itself contains a "|" character.
            fields = line.split("|", 2)
            if len(fields) < 3:
                continue
            filename = fields[0].strip().split(".")[0]
            parsed_dict[filename] = fields[2].strip()
    return parsed_dict


# Dataset-specific name on purpose: the lookup function below reads this
# global at call time, so a generic name that another cell rebinds would
# silently change the captions it returns.
flickr_captions = parse_csv(base_path / "flickr30k_images/results.csv")


def get_caption_for_filename(filename):
    """Return the caption recorded for ``filename``, or ``None`` if there is none."""
    return flickr_captions.get(filename)

In [None]:
# Describe one sample of the Flickr dataset: the jpg image is the primary
# attribute and a caption looked up via get_caption_for_filename is attached
# as the only secondary attribute.
dataset_config = DatasetConfig(
    primary_attribute=DataAttribute(
        name="image",
        file_type="jpg",
        path=base_path / "flickr30k_images/flickr30k_images",
    ),
    secondary_attributes=[
        LabelAttribute(label_identifier=get_caption_for_filename, name="caption"),
    ],
)

# NOTE(review): pattern / maxcount are presumably the shard name pattern and
# the maximum number of samples per shard — confirm against the SDK docs.
bucket.write_dataset(maxcount=1000, pattern="flickr_dataset", config=dataset_config)