In [None]:
# Smoke-test cell: prints a fixed marker string.
# NOTE(review): presumably used to check that nbstripout clears cell output — confirm.
print("test nbstripout")

In [None]:
# Relative path to the image directory of the raw hateful_memes data.
# NOTE(review): resolved against the notebook's working directory — confirm it is run from the expected folder.
images = "../data/01_raw/hateful_memes/img"

In [None]:
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
from typing import Optional

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms


class Memes_PL(pl.LightningDataModule):
    """Lightning data module serving MNIST images resized to 224x224.

    NOTE(review): despite the name, every dataset built here is torchvision's
    ``MNIST``; presumably a placeholder until a memes dataset is wired in.

    Args:
        data_dir: Directory MNIST is downloaded to and read from.
        batch_size: Batch size used by the train, val and test dataloaders.
    """

    def __init__(self, data_dir: str = "./", batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        # Resize expects a spatial (height, width) pair; the channel axis is
        # not part of the size. The earlier 3-element tuple (224, 224, 3) is
        # not a valid size and made construction fail.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
        ])

        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        # NOTE(review): MNIST is presumably single-channel, so samples would
        # be (1, 224, 224) rather than (3, 224, 224) — confirm before relying
        # on this value.
        self.dims = (3, 224, 224)

    def prepare_data(self):
        """Download the MNIST train and test splits into ``data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Build the datasets needed for ``stage``.

        Args:
            stage: ``"fit"`` builds the train/val split, ``"test"`` builds the
                test set, ``None`` builds both.
        """
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # The two lengths passed to random_split must add up to len(mnist_full).
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return a DataLoader over the training split (requires ``setup("fit")``)."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return a DataLoader over the validation split (requires ``setup("fit")``)."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return a DataLoader over the test split (requires ``setup("test")``)."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

In [None]:
import torch
# NOTE(review): DataLoader's first argument is the dataset; here it is a plain
# path string, so this presumably iterates over the string's characters rather
# than over image files — confirm intent / swap in a real Dataset.
data_loader = torch.utils.data.DataLoader("../data/01_raw/hateful_memes/img")

In [None]:
# A DataLoader is iterable but not itself an iterator, so it has to be wrapped
# in iter() before next() can pull the first batch.
next(iter(data_loader))

In [None]:
from typing import Union
from pathlib import Path
import pandas as pd
from skimage import io, transform


class Memes(torch.utils.data.Dataset):
    """Map-style dataset over the records listed in ``<root_dir>/train.jsonl``.

    Each item is a dict with keys ``image`` (the array returned by
    ``io.imread``), ``text`` and ``label``, taken from the ``img``, ``text``
    and ``label`` columns of the metadata file.

    Args:
        root_dir: Directory containing ``train.jsonl``; the ``img`` column is
            resolved relative to it.
        transform: Optional callable applied to the sample dict.
    """

    def __init__(self, root_dir: str, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        # One JSON record per line -> one DataFrame row per sample.
        self.info = pd.read_json(self.root_dir / "train.jsonl", lines=True)

    def __len__(self):
        """Return the number of metadata records."""
        return len(self.info)

    def __getitem__(self, idx):
        """Load and return the sample at position ``idx``.

        Args:
            idx: Row position; a tensor index is converted with ``tolist()``.

        Returns:
            The sample dict, after ``transform`` if one was given.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        record = self.info.iloc[idx]
        image = io.imread(self.root_dir / record['img'])

        sample = dict(
            image=image,
            text=record['text'],
            label=record['label'],
        )

        if self.transform:
            # Use the transform's return value; previously it was discarded,
            # so transforms that return a new sample had no effect.
            transformed = self.transform(sample)
            # Transforms that mutate the dict in place and return None keep
            # working as before.
            if transformed is not None:
                sample = transformed
        return sample
            


In [None]:
import matplotlib.pyplot as plt
# Show a couple of dataset samples side by side as a quick visual sanity check.
first_index = 10  # dataset index of the first sample shown
num_shown = 2     # number of samples, and therefore subplot columns

fig = plt.figure()
ds = Memes("../data/01_raw/hateful_memes")
# Iterate over exactly the samples to display instead of looping over the
# whole dataset and breaking out by hand.
for i in range(num_shown):
    sample = ds[first_index + i]

    print(i, sample['image'].shape, sample['text'])

    ax = plt.subplot(1, num_shown, i + 1)
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    ax.imshow(sample['image'])

# Lay out and render once, after every subplot exists.
plt.tight_layout()
plt.show()

In [None]:
import pandas as pd
# Load the training metadata (JSON Lines: one record per line) into a DataFrame.
df = pd.read_json("../data/01_raw/hateful_memes/train.jsonl", lines=True)

In [None]:
# Preview the first rows of the metadata table.
df.head()

In [None]:
# Try to read every image referenced in the metadata and collect the relative
# paths of those that fail, printing the total and the failure count.
bad_images = []
print(len(df))
for img_rel_path in df['img']:
    try:
        io.imread(f"../data/01_raw/hateful_memes/{img_rel_path}")
    except Exception:
        # Exception rather than a bare except: read/decode errors are still
        # recorded, but Ctrl-C (KeyboardInterrupt) and SystemExit can stop
        # the scan.
        bad_images.append(img_rel_path)
print(len(bad_images))

In [None]:
# Show at most the first ten unreadable image paths.
print(bad_images[0:10])