forked from SiddhanthHegde/You-Need-to-Pay-More-Attention
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
56 lines (48 loc) · 1.49 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os

import torch
from PIL import Image
from torch.utils.data import DataLoader
class TamilDataset(torch.utils.data.Dataset):
    """Dataset pairing images with Tamil captions for binary classification.

    Each row of ``df`` supplies an image filename and its caption.  The label
    is derived from the filename: names starting with ``'N'`` map to class 0,
    everything else to class 1.

    Note: relies on ``os.path.join``, so ``os`` must be imported at module
    level (the original file omitted that import and crashed in
    ``__getitem__`` with a NameError).
    """

    def __init__(self, df, tokenizer, max_len, path, transforms=None):
        # df: DataFrame whose rows are (image_filename, caption) pairs
        #     (assumed exactly two columns — confirmed by the unpack in
        #     __getitem__).
        # tokenizer: HuggingFace-style tokenizer exposing encode_plus().
        # max_len: maximum token sequence length for the caption encoding.
        # path: directory containing the image files.
        # transforms: optional torchvision-style callable applied to the
        #     PIL image (None = keep the raw PIL image).
        self.data_dir = path
        self.df = df
        self.tokenizer = tokenizer
        self.transforms = transforms
        self.max_len = max_len

    def __len__(self):
        # Number of rows in the backing DataFrame.
        return self.df.shape[0]

    def __getitem__(self, index):
        """Return one sample: image tensor, caption text/encoding, label."""
        img_name, captions = self.df.iloc[index]
        img_path = os.path.join(self.data_dir, img_name)
        # Filename prefix 'N' marks the negative class (label 0).
        labels = 0 if img_name.startswith('N') else 1
        img = Image.open(img_path).convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        # Pad/truncate the caption to max_len; token_type_ids are not needed
        # for a single-sequence input.
        encoding = self.tokenizer.encode_plus(
            captions,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True,
        )
        # flatten() drops the batch dimension added by return_tensors='pt'.
        return {
            'image': img,
            'text': captions,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(labels, dtype=torch.float),
        }
def create_data_loader(df, tokenizer, max_len, batch_size, mytransforms, path, shuffle):
    """Build a DataLoader over a TamilDataset.

    Args:
        df: DataFrame of (image_filename, caption) rows.
        tokenizer: tokenizer forwarded to TamilDataset.
        max_len: maximum caption token length.
        batch_size: samples per batch.
        mytransforms: image transforms forwarded to TamilDataset.
        path: directory containing the image files.
        shuffle: whether to shuffle samples each epoch.

    Returns:
        torch.utils.data.DataLoader over the constructed dataset.

    Fix: the original hard-coded ``shuffle=False`` and silently ignored the
    ``shuffle`` argument; it is now honoured.
    """
    ds = TamilDataset(
        df,
        tokenizer,
        max_len,
        path,
        mytransforms,
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4,
    )