Skip to content

Commit

Permalink
Merge pull request #9 from ABL-Lab/emnist_seq
Browse files Browse the repository at this point in the history
MNIST sequence's first two numbers are randomly picked
  • Loading branch information
NizarIslah committed Apr 25, 2024
2 parents cc47c1e + 66cf5db commit 12b891d
Showing 1 changed file with 81 additions and 74 deletions.
155 changes: 81 additions & 74 deletions ambiguous/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset # For custom datasets
# For sequential emnist
# from english_words import english_words_lower_alpha_set as dictionary
# from abl_expectation.utils.expectation_clamp import extract_n_letter_words
from english_words import english_words_lower_alpha_set as dictionary
from abl_expectation.utils.expectation_clamp import extract_n_letter_words
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -97,9 +97,12 @@ def __getitem__(self, index, img_size=28):
single_image_path = self.image_list[index]
im_as_np = np.load(single_image_path).astype(np.float64)/255.
im_as_ten = torch.from_numpy(im_as_np)
#clean1 = self.transform(im_as_ten[:, :, :img_size])
#amb = self.transform(im_as_ten[:, :, img_size:2*img_size])
#clean2 = self.transform(im_as_ten[:, :, 2*img_size:3*img_size])
clean1 = im_as_ten[:, :, :img_size]
amb = im_as_ten[:, :, img_size:2*img_size]
clean2 = im_as_ten[:, :, 2*img_size:3*img_size]
clean2 = im_as_ten[:, :, 2*img_size:3*img_size]
label = torch.from_numpy(np.load(self.label_list[index]))
return (clean1, amb, clean2), label

Expand Down Expand Up @@ -144,14 +147,18 @@ def __getitem__(self, index, img_size=28):
return the sequence as a tuple and the target as a tensor
"""
if self.cache:
(clean1, _, clean2), label = self.triplet_dataset[index]
target = (label[0] + label[1]) % 10
#(clean1, _, clean2), label = self.triplet_dataset[index]
label1 = random.randint(0,9)
label2 = random.randint(0,9)
clean1, label1 = self.sample(label1)
clean2, label2 = self.sample(label2)
target = (label1 + label2) % 10
cleansum1, sum_label = self.sample(target)
if self.include_irrelevant and self.ambiguous:
ambsum1, pair_label = self.sample(target, ambiguous=True) # 2nd label in the pair
ambsum2, _ = self.sample(pair_label, ambiguous=True)
cleansum2, _ = self.sample(pair_label)
clean3, _ = self.sample((pair_label-label[0]) % 10) # 0/6 .. 2 0-2=-2 % 10 = 8
clean3, _ = self.sample((pair_label-label1) % 10) # 0/6 .. 2 0-2=-2 % 10 = 8
img_seq = torch.stack([clean1, clean2, cleansum1, ambsum1, clean3, ambsum2, cleansum2]) # clean1 + clean2 = cleansum1 (ambsum1). clean1 + clean3 = cleansum2 (ambsum2). irrelevant = clean2 + clean3
else:
img_seq = torch.stack([clean1, clean2, cleansum1])
Expand Down Expand Up @@ -179,77 +186,77 @@ def sample(self, target, ambiguous=False):
label *= 2 # 0 or 1 -
return dataset[idx][0][label], dataset[idx][1][label]

# class SequenceEMNIST(Dataset):
# def __init__(self, root, download=False, split='train', transform=None, n_cls=26, ambiguous=False, word_length=3,
# cache=False,cache_dir=None, include_irrelevant=False):
# """
# A dataset where the input is a sequence of letters that make a 3 letter word
# """
class SequenceEMNIST(Dataset):
    def __init__(self, root, download=False, split='train', transform=None, n_cls=26, ambiguous=False, word_length=3,
                 cache=False, cache_dir=None, include_irrelevant=False):
        """
        A dataset where the input is a sequence of letters that make a 3 letter word.

        Each item is a stack of letter images: two "clue" letters (the first two
        letters of a dictionary word) followed by the image of the word's last
        letter; when `ambiguous` is set, an ambiguous version and a distractor
        word built from the other label of the underlying pair are appended.

        Args:
            root: root directory forwarded to DatasetTriplet.
            download: forwarded to DatasetTriplet.
            split: split name ('train', ...); also used as the cache sub-directory.
            transform: forwarded to DatasetTriplet.
            n_cls: number of letter classes (26 for a-z).
            ambiguous: if True, __getitem__ builds the long 7-image sequence
                including ambiguous/distractor images.
            word_length: length of dictionary words used to build sequences.
                NOTE(review): __getitem__ only reads word[0] and word[1], so it
                implicitly assumes word_length == 3 — confirm before changing.
            cache: if True, __getitem__ GENERATES an item and saves it under
                cache_dir; if False, the item is LOADED from cache_dir.
                NOTE(review): this polarity looks inverted relative to the
                name — confirm against callers before relying on it.
            cache_dir: directory holding the cached sequence tensors.
            include_irrelevant: stored but currently unused by __getitem__
                (only `ambiguous` gates the long sequence); kept for
                backward-compatible interface.
        """
        self.dataset = DatasetTriplet(root, download, split, transform)
        self.partitioned_dataset = partition_datasetV2(self.dataset, n_cls)
        # Words grouped by final letter, so a target letter class can be mapped
        # to a word whose leading letters serve as context clues.
        self.word_dict = self.dictionary_by_last_letter(
            extract_n_letter_words(dictionary, word_length), n_cls)
        self.ambiguous = ambiguous
        self.cache = cache
        self.cache_dir = cache_dir
        self.split = split
        # NOTE(review): the trailing -1 drops one index from the nominal total;
        # presumably keeps some downstream index in range — confirm.
        self.data_len = sum(len(ds) for ds in self.partitioned_dataset) - 1
        self.include_irrelevant = include_irrelevant

    def __getitem__(self, index, img_size=28):
        """
        Return (img_seq, correct_label) for `index`.

        When self.cache is True the sequence is built on the fly from the
        triplet dataset and written to cache_dir; otherwise it is loaded from
        the files previously written there (see the cache-polarity note in
        __init__).
        """
        if self.cache:
            triplet, label = self.dataset[index]
            gt_idx = random.randint(0, 1)  # which member of the label pair is ground truth
            correct_label = label[gt_idx]
            # triplet layout: index 0 and 2 are the clean images, 1 is ambiguous.
            clean_img = triplet[gt_idx * 2]
            word = random.choice(self.word_dict[correct_label])
            letter1 = self.sample(idx(word[0]))
            letter2 = self.sample(idx(word[1]))
            if self.ambiguous:
                # Distractor sequence built from the other label of the pair.
                opposite_label = label[int(not gt_idx)]
                opposite_img = triplet[int(not gt_idx) * 2]
                opposite_word = random.choice(self.word_dict[opposite_label])
                opposite_clue1 = self.sample(idx(opposite_word[0]))
                opposite_clue2 = self.sample(idx(opposite_word[1]))
                img_seq = torch.stack([letter1, letter2, clean_img, triplet[1],
                                       opposite_clue1, opposite_clue2, opposite_img])
            else:
                img_seq = torch.stack([letter1, letter2, clean_img])
            torch.save(img_seq, f'{self.cache_dir}/{self.split}/img_seq_{index}.pt')
            torch.save(correct_label, f'{self.cache_dir}/{self.split}/sum_label_{index}.pt')
        else:
            img_seq = torch.load(f'{self.cache_dir}/{self.split}/img_seq_{index}.pt')
            correct_label = torch.load(f'{self.cache_dir}/{self.split}/sum_label_{index}.pt')
        return img_seq, correct_label

    def __len__(self):
        """Number of items (see the data_len note in __init__)."""
        return self.data_len

    def dictionary_by_last_letter(self, dictionary, n_cls):
        '''
        Sort word dictionary by last letter.

        Returns a list of n_cls buckets; bucket i holds every word whose final
        letter maps to class index i via idx(). (The `dictionary` parameter
        intentionally shadows the module-level word set.)
        '''
        sorted_dict = [[] for _ in range(n_cls)]
        for word in dictionary:
            sorted_dict[idx(word[-1])].append(word)
        return sorted_dict

    def sample(self, target, ambiguous=False):
        """
        Sample one image of class `target` from its dataset partition.

        Picks a random item from the partition for `target`; each item holds a
        group of images at index 0 and their labels at index 1. Doubling the
        matching-label position skips the middle (ambiguous) image slot, while
        ambiguous=True selects that middle slot directly — presumed layout
        (clean, ambiguous, clean), mirroring __getitem__'s use of the triplet.
        """
        dataset = self.partitioned_dataset[target]
        rand_idx = torch.randint(0, len(dataset), (1,))
        item = dataset[rand_idx]
        if ambiguous:
            sel = 1  # middle slot holds the ambiguous image
        else:
            # position(s) of the matching label, doubled to skip the ambiguous slot
            sel = torch.where(item[1] == target)[0] * 2
        return item[0][sel]

def save_dataset_to_file(dataset_name, og_root, new_root, blend, pairs=None, batch_size=100, n_train=60000, n_test=10000):
os.makedirs(new_root+'/train/')
Expand Down

0 comments on commit 12b891d

Please sign in to comment.