In [1]:
from datasets import load_dataset
import os
import requests
import cv2
import numpy as np
import torch
import torchvision.transforms
from datasets import load_dataset
from transformers import T5Tokenizer

dataset = load_dataset("ChristophSchuhmann/MS_COCO_2017_URL_TEXT")

# 初始化分词器
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# 定义图像尺寸
img_size = 224  # 或适合您模型的图像尺寸

# 定义设备（GPU或CPU）
device = "cuda" if torch.cuda.is_available() else "cpu"

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [2]:
dataset = dataset['train']
dataset

Dataset({
    features: ['URL', 'TEXT'],
    num_rows: 591753
})

In [3]:
len(dataset)

591753

In [55]:
dataset[0]

{'URL': 'http://images.cocodataset.org/train2017/000000391895.jpg',
 'TEXT': 'A man with a red helmet on a small moped on a dirt road. '}

In [56]:
example = dataset[:100]
example

{'URL': ['http://images.cocodataset.org/train2017/000000391895.jpg',
  'http://images.cocodataset.org/train2017/000000391895.jpg',
  'http://images.cocodataset.org/train2017/000000391895.jpg',
  'http://images.cocodataset.org/train2017/000000391895.jpg',
  'http://images.cocodataset.org/train2017/000000391895.jpg',
  'http://images.cocodataset.org/train2017/000000522418.jpg',
  'http://images.cocodataset.org/train2017/000000522418.jpg',
  'http://images.cocodataset.org/train2017/000000522418.jpg',
  'http://images.cocodataset.org/train2017/000000522418.jpg',
  'http://images.cocodataset.org/train2017/000000522418.jpg',
  'http://images.cocodataset.org/train2017/000000184613.jpg',
  'http://images.cocodataset.org/train2017/000000184613.jpg',
  'http://images.cocodataset.org/train2017/000000184613.jpg',
  'http://images.cocodataset.org/train2017/000000184613.jpg',
  'http://images.cocodataset.org/train2017/000000184613.jpg',
  'http://images.cocodataset.org/train2017/000000318219.jpg',
 

In [57]:
def slice_datadict(dct, start_idx, end_idx):
    slice_dict = {}
    keys = list(dct.keys())
    for key in keys:
        slice_dict[key] = dct[key][start_idx:end_idx]
    return slice_dict

In [58]:
minibatch = slice_datadict(example,0,5)
type(minibatch)

dict

In [59]:
minibatch

{'URL': ['http://images.cocodataset.org/train2017/000000391895.jpg',
  'http://images.cocodataset.org/train2017/000000391895.jpg',
  'http://images.cocodataset.org/train2017/000000391895.jpg',
  'http://images.cocodataset.org/train2017/000000391895.jpg',
  'http://images.cocodataset.org/train2017/000000391895.jpg'],
 'TEXT': ['A man with a red helmet on a small moped on a dirt road. ',
  'Man riding a motor bike on a dirt road on the countryside.',
  'A man riding on the back of a motorcycle.',
  'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ',
  'A man in a red shirt and a red hat is on a motorcycle on a hill side.']}

In [60]:
def process_batch(minibatch, tokenizer, img_size, device):
    augmented_imgs = []
    captions = []
    value_list = list(minibatch.values())
    url_list = value_list[0]
    cap_list = value_list[1]
    assert len(url_list) == len(cap_list)

    for url,cap in zip(url_list,cap_list):
        print(f"url: {url}")
        print(f"caption: {cap}")
        response = requests.get(url)
        if response.status_code == 200:
            img_data = response.content
        else:
            print(f"Failed to fetch image from URL. Status code: {response.status_code}")
            continue 
        # img_data = response.content
        img = cv2.imdecode(np.frombuffer(img_data, np.uint8), -1)
        resize_shape = (img_size, img_size)
        img = cv2.resize(img, resize_shape, interpolation=cv2.INTER_LINEAR)
        img = np.float32(img) / 255
        img = torch.tensor(img)
        img = img.permute(2, 1, 0)  # [w, h, c] -> [c, h, w]
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize(int(1.25 * img_size)),  # image_size + 1/4 * image_size
            torchvision.transforms.RandomResizedCrop(resize_shape, scale=(0.8, 1.0)),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # zero mean, unit std
        ])
        img = transforms(img)
        augmented_imgs.append(img)

        caption_tokens_dict = tokenizer(cap, return_tensors='pt', padding=True, truncation=True)
        captions.append(caption_tokens_dict)
    
    return augmented_imgs, captions




In [61]:
type(minibatch)

dict

In [62]:
def process_batch_(minibatch, img_size):
    """process the url

    Parameters
    ----------
    minibatch: Dict
        key: ['URL','text']
        value: list[URL],list[text]

    img_size: int
        the size of image
    
    
    Returns
    -------
    augmented_imgs: List
        length of augmented_imgs: batch
    captions: List
        length of caption: batch
    """
    value_list = list(minibatch.values())
    url_list = value_list[0]
    captions = value_list[1]
    augmented_imgs = []
    #processing 
    for url,cap in zip(url_list,captions):
        print(f"processing url: {url}")
        # print(f"caption: {cap}")
        response = requests.get(url)
        if response.status_code == 200:
            img_data = response.content
        else:
            print(f"Failed to fetch image from URL. Status code: {response.status_code}")
            continue 
        # img_data = response.content
        img = cv2.imdecode(np.frombuffer(img_data, np.uint8), -1)
        resize_shape = (img_size, img_size)
        img = cv2.resize(img, resize_shape, interpolation=cv2.INTER_LINEAR)
        img = np.float32(img) / 255
        img = torch.tensor(img)
        img = img.permute(2, 1, 0)  # [w, h, c] -> [c, h, w]
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize(int(1.25 * img_size)),  # image_size + 1/4 * image_size
            torchvision.transforms.RandomResizedCrop(resize_shape, scale=(0.8, 1.0)),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # zero mean, unit std
        ])
        img = transforms(img)
        augmented_imgs.append(img)
        
    return augmented_imgs, captions
    


In [63]:
augmented_imgs, captions = process_batch_(minibatch,img_size)

processing url: http://images.cocodataset.org/train2017/000000391895.jpg
processing url: http://images.cocodataset.org/train2017/000000391895.jpg
processing url: http://images.cocodataset.org/train2017/000000391895.jpg
processing url: http://images.cocodataset.org/train2017/000000391895.jpg
processing url: http://images.cocodataset.org/train2017/000000391895.jpg


In [64]:
augmented_imgs

[tensor([[[ 1.0000,  1.0000,  1.0000,  ..., -0.2563, -0.1477, -0.0945],
          [ 1.0000,  1.0000,  1.0000,  ..., -0.0134, -0.0268, -0.0477],
          [ 1.0000,  1.0000,  1.0000,  ...,  0.1766,  0.0571, -0.0625],
          ...,
          [ 1.0000,  1.0000,  1.0000,  ..., -0.7622, -0.6470, -0.5478],
          [ 0.9999,  0.9979,  0.9964,  ..., -0.7698, -0.6579, -0.5655],
          [ 0.9949,  0.9908,  0.9906,  ..., -0.7948, -0.7225, -0.6816]],
 
         [[ 1.0000,  1.0000,  1.0000,  ...,  0.0276,  0.2057,  0.2529],
          [ 1.0000,  1.0000,  1.0000,  ...,  0.2291,  0.2554,  0.2250],
          [ 1.0000,  1.0000,  1.0000,  ...,  0.4288,  0.2953,  0.1307],
          ...,
          [ 1.0000,  1.0000,  1.0000,  ..., -0.4311, -0.4137, -0.3255],
          [ 0.9999,  0.9979,  0.9964,  ..., -0.4335, -0.3547, -0.3042],
          [ 0.9973,  0.9923,  0.9906,  ..., -0.5055, -0.4235, -0.3838]],
 
         [[ 1.0000,  1.0000,  1.0000,  ...,  0.2502,  0.3998,  0.4102],
          [ 1.0000,  1.0000,

In [65]:
captions

['A man with a red helmet on a small moped on a dirt road. ',
 'Man riding a motor bike on a dirt road on the countryside.',
 'A man riding on the back of a motorcycle.',
 'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ',
 'A man in a red shirt and a red hat is on a motorcycle on a hill side.']

In [98]:
def load_img2cap(batch_size,dataset,tokenizer,img_size, device):
    """load the image-caption dataset and return torch.Tensor

    Parameters
    ----------
    batch_size: int

    dataset: DataDict

    tokenizer: T5

    device: string -- torch.device
        cuda or cpu

    Returns
    -------
    img_tenosr: torch.Tensor
    caption_tenosr: torch.Tensor

    """
    img_list = []
    caption_list = []
    n = len(list(dataset.values())[0])
    for i in range(0, n, batch_size):
        minibatch = slice_datadict(dataset, i, i+batch_size)
        augmented_imgs, captions = process_batch_(minibatch, img_size)
        img_list.extend(augmented_imgs)
        caption_list.extend(captions)
        print("-----------------------")
    
    img_tensor = torch.stack(img_list, dim=0).to(device)
    # caption = tokenizer(caption_list, padding=True, truncation=True, return_tensors="pt")
    # caption = {key: val.to(device) for key, val in caption.items()}
    caption = {
        key: val.to(device) if isinstance(val, torch.Tensor) else val
        for key, val in tokenizer(caption_list, padding=True, truncation=True, return_tensors="pt").items()
    }
    
    return img_tensor,caption 
    

In [99]:
example = dataset[:30]
img, caption = load_img2cap(
    batch_size=5,
    dataset=example,
    tokenizer=tokenizer,
    img_size=img_size,
    device=device,
)

processing url: http://images.cocodataset.org/train2017/000000391895.jpg
processing url: http://images.cocodataset.org/train2017/000000391895.jpg
processing url: http://images.cocodataset.org/train2017/000000391895.jpg
processing url: http://images.cocodataset.org/train2017/000000391895.jpg
processing url: http://images.cocodataset.org/train2017/000000391895.jpg
-----------------------
processing url: http://images.cocodataset.org/train2017/000000522418.jpg
processing url: http://images.cocodataset.org/train2017/000000522418.jpg
processing url: http://images.cocodataset.org/train2017/000000522418.jpg
processing url: http://images.cocodataset.org/train2017/000000522418.jpg
processing url: http://images.cocodataset.org/train2017/000000522418.jpg
-----------------------
processing url: http://images.cocodataset.org/train2017/000000184613.jpg
processing url: http://images.cocodataset.org/train2017/000000184613.jpg
processing url: http://images.cocodataset.org/train2017/000000184613.jpg
pro

In [100]:
img.shape

torch.Size([30, 3, 224, 224])

In [101]:
type(caption)

dict

In [102]:
caption

{'input_ids': tensor([[  71,  388,   28,  ...,    0,    0,    0],
         [1140, 7494,    3,  ...,    0,    0,    0],
         [  71,  388, 7494,  ...,    0,    0,    0],
         ...,
         [   3,    9, 2335,  ...,    0,    0,    0],
         [  71, 2335,   19,  ...,    0,    0,    0],
         [  71, 2335,   19,  ...,    0,    0,    0]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]])}

In [103]:
input_ids = caption["input_ids"]
input_ids.shape

torch.Size([30, 45])

In [104]:
caption.input_ids.shape

AttributeError: 'dict' object has no attribute 'input_ids'

In [105]:
print(f"pad: {tokenizer.pad_token}")

pad: <pad>


In [106]:
start_token_id = tokenizer(tokenizer.pad_token, return_tensors='pt', padding=False, truncation=True).input_ids
start_token_id

tensor([[0, 1]])

In [107]:
start_token_id = start_token_id[:, 0]
start_token_id

tensor([0])

In [108]:
start_token_id = start_token_id.expand(5, -1).to(device)
start_token_id

tensor([[0],
        [0],
        [0],
        [0],
        [0]])