In [1]:
from functools import partial
import gc
import logging
import nltk
import numpy as np
import pandas as pd
import os
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional
import math
import json
from flax.serialization import to_bytes, from_bytes

import shutil
import torch
from transformers.file_utils import PushToHubMixin
from datasets import load_metric
from torchvision.datasets import VisionDataset
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize, GaussianBlur
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm

import jax
import jax.numpy as jnp
import optax
import transformers
from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, shard, shard_prng_key, get_metrics, onehot
from flax_clip_vision_marian.modeling_clip_vision_marian import FlaxCLIPVisionMarianForConditionalGeneration
from transformers import MarianTokenizer,MBart50TokenizerFast, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed

class Transform(torch.nn.Module):
    def __init__(self, image_size):
        super().__init__()

        self.transforms = torch.nn.Sequential(
                    Resize([image_size], interpolation=InterpolationMode.BICUBIC),
                    CenterCrop(image_size),
                    ConvertImageDtype(torch.float),
                    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            x = self.transforms(x)
        return x

class ImageTextDataset(VisionDataset):
    def __init__(
        self,
        root: str,
        file_path: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        max_samples: int = None
    ):
        super().__init__(root, transforms, transform, target_transform)

        # self.captions = []
        # self.image_paths = []
        # self.lang = []

        examples = pd.read_csv(file_path, sep='\t')
        gc.collect()

        self.map_lang_code = {
            "en": "en_XX",
            "de": "de_DE",
            "fr": "fr_XX",
            "es": "es_XX",
        }

        # for idx,img_file in enumerate(examples["image_file"].values):
        #     if os.path.exists(os.path.join(self.root,img_file)):
        #     self.image_paths.append(img_file)
        #     self.captions.append(examples["caption"].values[idx])
        #     self.lang.append(examples["lang_id"].values[idx])

        self.image_paths = examples["image_path"].values
        self.captions = examples["captions"].values

        if max_samples is None:
            max_samples = len(self.image_paths)

        self.image_paths = self.image_paths[:max_samples]
        self.captions = self.captions[:max_samples]

        # with open(file_path, encoding="utf-8") as fd:
        #     examples = csv.DictReader(fd, delimiter="\t", quotechar='"')
        #     for row in examples:
        #         self.image_paths.append(os.path.join(self.root,row["image_file"]))
        #         self.captions.append(row["caption"])
        #         self.lang.append(row["lang_id"])


    def _load_image(self, idx: int):
        path = self.image_paths[idx]
        return read_image(os.path.join(self.root,path), mode=ImageReadMode.RGB)

    def _load_target(self, idx):
        return self.captions[idx]

    def __getitem__(self, index: int):
        image = self._load_image(index)
        target = self._load_target(index)

        if self.transforms is not None:
            image, target = self.transforms(image, target)
            
        print(image)
        return image, target,

    def __len__(self) -> int:
        return len(self.captions)
model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained('munggok/image-captioning-marian')
config = model.config
preprocess = Transform(config.clip_vision_config.image_size)
preprocess = torch.jit.script(preprocess)
tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-id')


INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.


In [22]:
image_size = config.clip_vision_config.image_size

In [23]:
transforms = torch.nn.Sequential(
                    Resize([image_size], interpolation=InterpolationMode.BICUBIC),
                    CenterCrop(image_size),
                    ConvertImageDtype(torch.float),
                    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
                )

In [24]:
read_trans = read_image('000000512774.jpg', mode=ImageReadMode.RGB)

In [25]:
read_trans = transforms(read_trans)

In [26]:
img = torch.stack([read_trans]).permute(0, 2, 3, 1).numpy()

In [28]:
import numpy as np
pixel_values = np.load('data.npy')

In [29]:
img == pixel_values

array([[[[ True,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True],
         ...,
         [ True,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True]],

        [[ True,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True],
         ...,
         [ True,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True]],

        [[ True,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True],
         ...,
         [ True,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True]],

        ...,

        [[ True,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True],
         ...,
         [ True,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True]],

        [[ True,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True],
         ...,
         [ True,  True,  True],
         [ T

In [2]:
train_dataset = ImageTextDataset(
    'data',
    'data/dev.tsv',
    transform=preprocess
)

In [3]:
train_dataset

Dataset ImageTextDataset
    Number of datapoints: 3
    Root location: data
    StandardTransform
Transform: RecursiveScriptModule(
             original_name=Transform
             (transforms): RecursiveScriptModule(
               original_name=Sequential
               (0): RecursiveScriptModule(original_name=Resize)
               (1): RecursiveScriptModule(original_name=CenterCrop)
               (2): RecursiveScriptModule(original_name=ConvertImageDtype)
               (3): RecursiveScriptModule(original_name=Normalize)
             )
           )

In [17]:
def collate_fn_val(examples):
        pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
        captions = [example[1] for example in examples]

        # tokenizer = map_tokenizer_lang[lang_id[0]]
          # every validation loader has same language
        with tokenizer.as_target_tokenizer():
            tokens = tokenizer(captions, max_length=16, padding="max_length", return_tensors="np", truncation=True)

        # had to create another enum of sorts for lang_id
        # lang_id = np.array([map_lang_num[lang] for lang in lang_id])  # str of type <class 'numpy.ndarray'> is not a valid JAX type
        decoder_input_ids = shift_tokens_right(tokens["input_ids"], config.marian_config.pad_token_id)
        print(captions)
        batch = {
            "pixel_values": pixel_values,
            "input_ids": tokens["input_ids"],
            "attention_mask": tokens["attention_mask"],
            "decoder_input_ids": decoder_input_ids,
            # "lang": lang_id,
        }
        return batch

In [18]:
def shift_tokens_right(input_ids: np.array, pad_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = np.zeros(input_ids.shape, dtype=np.int64)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = pad_token_id
    return shifted_input_ids

In [19]:
eval_loader = None

In [20]:
eval_loader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        num_workers=0,
        drop_last=True,
        collate_fn=collate_fn_val,
    )

In [21]:
for asa in eval_loader:
    print(asa)

tensor([[[ 1.6092,  1.7260,  1.1128,  ...,  1.9303,  1.9303,  1.9303],
         [ 1.5800,  1.7990,  1.7114,  ...,  1.9303,  1.9303,  1.9303],
         [ 1.6238,  1.5654,  1.8573,  ...,  1.9303,  1.9303,  1.9303],
         ...,
         [-0.8142, -1.0039, -0.9602,  ...,  0.4267,  0.4559,  0.4267],
         [-0.5222, -0.7704, -1.0769,  ...,  0.4121,  0.3975,  0.2661],
         [-0.6536, -0.9018, -0.9893,  ...,  0.4559,  0.3829,  0.1201]],

        [[ 1.7747,  1.8498,  1.2945,  ...,  2.0749,  2.0749,  2.0749],
         [ 1.8047,  1.9698,  1.9398,  ...,  2.0749,  2.0749,  2.0749],
         [ 1.8047,  1.6547,  2.0299,  ...,  2.0749,  2.0749,  2.0749],
         ...,
         [-0.9117, -1.1218, -1.2869,  ..., -0.8967, -0.8816, -0.9267],
         [-0.3864, -0.7616, -1.0918,  ..., -0.9417, -0.8816, -0.9717],
         [-0.7466, -1.1368, -1.0467,  ..., -0.9117, -0.8666, -1.0167]],

        [[ 1.7762,  1.7904,  1.2358,  ...,  2.1459,  2.1459,  2.1459],
         [ 1.8331,  1.9610,  1.8615,  ...,  2