CLIPDataSet.py
import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from tqdm import tqdm
from transformers import GPT2Tokenizer
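
# Both datasets take a pandas DataFrame plus a root image directory (origin_file_path):
#   - CLIPDataSet expects `image_name` and `comment` columns (image filename + caption).
#   - CLIPZeroShotDataSet expects `image_name`, `text` (class name), and `label` columns.
# With load_first=True images are decoded and transformed up front; otherwise only the
# paths are stored and images are loaded lazily in __getitem__.
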
class CLIPDataSet(Dataset):
    """Image-caption pairs for CLIP-style contrastive training."""

    def __init__(self, train_data: pd.DataFrame, origin_file_path: str, load_first: bool = True):
        # Text encoder tokenizer (GPT-2 BPE); reuse EOS as the padding token.
        self.bpe = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
        self.bpe.pad_token = self.bpe.eos_token
        self.origin_file_path = origin_file_path
        self.load_first = load_first
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4441, 0.4212, 0.3847], std=[0.2613, 0.2547, 0.2656])
        ])
        # Preload images (or just their paths) and wrap every caption in BOS/EOS tokens.
        self.image_data = []
        self.text_data = []
        for d in tqdm(train_data.itertuples(), total=len(train_data)):
            if self.load_first:
                image_ = self.transform(Image.open(os.path.join(self.origin_file_path, d.image_name)).convert('RGB'))
                self.image_data.append(image_)
            else:
                self.image_data.append(os.path.join(self.origin_file_path, d.image_name))
            text_ = self.bpe.bos_token + d.comment + self.bpe.eos_token
            self.text_data.append(text_)
        assert len(self.image_data) == len(self.text_data)

    def n_vocab(self) -> int:
        return self.bpe.vocab_size

    def __getitem__(self, i):
        if self.load_first:
            image_ = self.image_data[i]
        else:
            image_ = self.transform(Image.open(self.image_data[i]).convert('RGB'))
        text_ = self.text_data[i]
        return image_, text_

    def __len__(self):
        return len(self.image_data)
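

# A minimal collate sketch (not part of the original file; the function name and signature
# are illustrative assumptions). __getitem__ returns raw caption strings, so a DataLoader
# needs a collate_fn like this to turn a batch into padded GPT-2 token ids.
def collate_clip_batch(batch, tokenizer):
    images, captions = zip(*batch)
    # Tokenize the whole caption batch at once; pad_token is already set on the tokenizer.
    tokens = tokenizer(list(captions), padding=True, return_tensors="pt")
    return torch.stack(images), tokens["input_ids"], tokens["attention_mask"]
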
class CLIPZeroShotDataSet(Dataset):
    """Zero-shot evaluation set: labelled images plus one prompt per unique class text."""

    def __init__(self, train_data: pd.DataFrame, origin_file_path: str, prefix_phrase: str = '*', load_first: bool = True):
        self.bpe = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
        self.bpe.pad_token = self.bpe.eos_token
        self.origin_file_path = origin_file_path
        self.load_first = load_first
        self.prefix_phrase = prefix_phrase
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4441, 0.4212, 0.3847], std=[0.2613, 0.2547, 0.2656])
        ])
        # Build one prompt per unique class by substituting the class text into the template,
        # then collect the images (or their paths) together with their labels.
        self.image_data = []
        self.text_data = []
        self.label = []
        for text_ in list(train_data['text'].unique()):
            self.text_data.append(prefix_phrase.replace('*', text_))
        for d in tqdm(train_data.itertuples(), total=len(train_data)):
            if self.load_first:
                image_ = self.transform(Image.open(os.path.join(self.origin_file_path, d.image_name)).convert('RGB'))
                self.image_data.append(image_)
            else:
                self.image_data.append(os.path.join(self.origin_file_path, d.image_name))
            self.label.append(d.label)
        assert len(self.image_data) == len(self.label)

    def n_vocab(self) -> int:
        return self.bpe.vocab_size

    def __getitem__(self, i):
        if self.load_first:
            image_ = self.image_data[i]
        else:
            im = Image.open(self.image_data[i])
            image_ = self.transform(im.convert(mode='RGB'))
        label_ = self.label[i]
        return image_, label_

    def __len__(self):
        return len(self.image_data)
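

# A minimal usage sketch (assumption: not part of the original file; the CSV names, image
# directory, and prompt template below are placeholders). CLIPDataSet yields (image, caption)
# pairs for contrastive training, while CLIPZeroShotDataSet yields (image, label) pairs and
# exposes one prompt per class in `text_data` for zero-shot scoring.
if __name__ == "__main__":
    from functools import partial

    from torch.utils.data import DataLoader

    train_df = pd.read_csv("train_captions.csv")  # placeholder; needs image_name and comment columns
    train_set = CLIPDataSet(train_df, origin_file_path="images/", load_first=False)
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True,
                              collate_fn=partial(collate_clip_batch, tokenizer=train_set.bpe))
    images, input_ids, attention_mask = next(iter(train_loader))
    print(images.shape, input_ids.shape)  # e.g. (32, 3, 224, 224) and (32, longest_caption_len)

    eval_df = pd.read_csv("eval_labels.csv")  # placeholder; needs image_name, text and label columns
    eval_set = CLIPZeroShotDataSet(eval_df, origin_file_path="images/",
                                   prefix_phrase="a photo of a *", load_first=False)
    print(eval_set.text_data[:3])  # one prompt per unique class, e.g. "a photo of a cat"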