In [1]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
import requests
from PIL import Image
import numpy as np
from io import BytesIO

In [2]:
# Mini-batch size passed to the DataLoaders.
batch_size = 32
# Size every downloaded image is resized to before being stored.
target_size = (1024, 1024)

# Both splits are read from the same Hugging Face datasets-server "first-rows"
# endpoint; only the value of the trailing `split` query parameter differs.
_first_rows_url = "https://datasets-server.huggingface.co/first-rows?dataset=osunlp%2FMagicBrush&config=default&split="
trainDB_url = _first_rows_url + "train"
devDB_url = _first_rows_url + "dev"

In [3]:
# Request the rows of the train split. A timeout keeps the notebook from
# hanging forever on a stalled connection, and raise_for_status() turns an
# HTTP error into an immediate, readable exception instead of a confusing
# KeyError/JSON error on the lines below.
res = requests.get(trainDB_url, timeout=30)
res.raise_for_status()

# Decode the JSON body once; both fields are read from the same payload.
payload = res.json()
features = payload['features']
data = payload['rows']

In [4]:
# Three parallel containers for the train split: entry i of each list belongs
# to the same sample (edit instruction, source image, edited image).
train_instructions, train_input_imgs, train_output_imgs = [], [], []

In [5]:
# Start from empty lists so that re-running this cell rebuilds the train
# split instead of appending a second copy of every sample.
train_instructions = []
train_input_imgs = []
train_output_imgs = []

for d in data:
  row = d['row']

  # downloading the source image from the url
  res = requests.get(row['source_img']['src'], timeout=30)
  # Fail with a clear HTTP error rather than handing an error page to PIL.
  res.raise_for_status()
  # Force 3-channel RGB so every array ends up with the same shape; arrays of
  # different shapes cannot be stacked into one batch by the DataLoader.
  img = Image.open(BytesIO(res.content)).convert('RGB')
  img = img.resize(target_size)
  input_img = np.array(img)

  # downloading the target image from the url
  res = requests.get(row['target_img']['src'], timeout=30)
  res.raise_for_status()
  img = Image.open(BytesIO(res.content)).convert('RGB')
  img = img.resize(target_size)
  output_img = np.array(img)

  # Append all three fields only once both downloads have succeeded, so the
  # parallel lists always stay index-aligned even if a request raises.
  train_instructions.append(row['instruction'])
  train_input_imgs.append(input_img)
  train_output_imgs.append(output_img)

In [6]:
# Request the rows of the dev split. A timeout keeps the notebook from
# hanging forever on a stalled connection, and raise_for_status() turns an
# HTTP error into an immediate, readable exception instead of a confusing
# KeyError/JSON error on the lines below.
res = requests.get(devDB_url, timeout=30)
res.raise_for_status()

# Decode the JSON body once; both fields are read from the same payload.
# Note that `features` and `data` now refer to the dev split.
payload = res.json()
features = payload['features']
data = payload['rows']

In [7]:
# Three parallel containers for the dev split: entry i of each list belongs
# to the same sample (edit instruction, source image, edited image).
dev_instructions, dev_input_imgs, dev_output_imgs = [], [], []

In [8]:
# Start from empty lists so that re-running this cell rebuilds the dev
# split instead of appending a second copy of every sample.
dev_instructions = []
dev_input_imgs = []
dev_output_imgs = []

for d in data:
  row = d['row']

  # downloading the source image from the url
  res = requests.get(row['source_img']['src'], timeout=30)
  # Fail with a clear HTTP error rather than handing an error page to PIL.
  res.raise_for_status()
  # Force 3-channel RGB so every array ends up with the same shape; arrays of
  # different shapes cannot be stacked into one batch by the DataLoader.
  img = Image.open(BytesIO(res.content)).convert('RGB')
  img = img.resize(target_size)
  input_img = np.array(img)

  # downloading the target image from the url
  res = requests.get(row['target_img']['src'], timeout=30)
  res.raise_for_status()
  img = Image.open(BytesIO(res.content)).convert('RGB')
  img = img.resize(target_size)
  output_img = np.array(img)

  # Append all three fields only once both downloads have succeeded, so the
  # parallel lists always stay index-aligned even if a request raises.
  dev_instructions.append(row['instruction'])
  dev_input_imgs.append(input_img)
  dev_output_imgs.append(output_img)

In [9]:
# The per-split lists are parallel (index i of each describes one sample).
# Verify that before pairing them up, so a mismatch is reported here with a
# clear message instead of surfacing as an IndexError or dropped samples.
assert len(train_input_imgs) == len(train_instructions) == len(train_output_imgs), \
    "train lists are out of sync"

# Each sample is an (input image, instruction, output image) triple.
train_dataset = list(
    zip(train_input_imgs, train_instructions, train_output_imgs)
)


assert len(dev_input_imgs) == len(dev_instructions) == len(dev_output_imgs), \
    "dev lists are out of sync"

dev_dataset = list(zip(dev_input_imgs, dev_instructions, dev_output_imgs))

In [10]:
# Training batches are drawn in random order.
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
)

# Validation batches are drawn in the stored order.
dev_sampler = SequentialSampler(dev_dataset)
validation_dataloader = DataLoader(
    dev_dataset,
    batch_size=batch_size,
    sampler=dev_sampler,
)