From 66cf5dbf77fc63553ab7e0082a065e861b589cd9 Mon Sep 17 00:00:00 2001 From: "mashbayar.tugsbayar" Date: Thu, 7 Dec 2023 15:32:12 -0500 Subject: [PATCH] MNIST sequence's first two numbers randomly picked --- ambiguous/dataset/dataset.py | 155 ++++++++++++++++++----------------- 1 file changed, 81 insertions(+), 74 deletions(-) diff --git a/ambiguous/dataset/dataset.py b/ambiguous/dataset/dataset.py index 38cdd65..516f68a 100644 --- a/ambiguous/dataset/dataset.py +++ b/ambiguous/dataset/dataset.py @@ -18,8 +18,8 @@ from torch.utils.data import DataLoader from torch.utils.data.dataset import Dataset # For custom datasets # For sequential emnist -# from english_words import english_words_lower_alpha_set as dictionary -# from abl_expectation.utils.expectation_clamp import extract_n_letter_words +from english_words import english_words_lower_alpha_set as dictionary +from abl_expectation.utils.expectation_clamp import extract_n_letter_words import random device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -97,9 +97,12 @@ def __getitem__(self, index, img_size=28): single_image_path = self.image_list[index] im_as_np = np.load(single_image_path).astype(np.float64)/255. im_as_ten = torch.from_numpy(im_as_np) + #clean1 = self.transform(im_as_ten[:, :, :img_size]) + #amb = self.transform(im_as_ten[:, :, img_size:2*img_size]) + #clean2 = self.transform(im_as_ten[:, :, 2*img_size:3*img_size]) clean1 = im_as_ten[:, :, :img_size] amb = im_as_ten[:, :, img_size:2*img_size] - clean2 = im_as_ten[:, :, 2*img_size:3*img_size] + clean2 = im_as_ten[:, :, 2*img_size:3*img_size] label = torch.from_numpy(np.load(self.label_list[index])) return (clean1, amb, clean2), label @@ -144,14 +147,18 @@ def __getitem__(self, index, img_size=28): return the sequence as a tuple and the target as a tensor """ if self.cache: - (clean1, _, clean2), label = self.triplet_dataset[index] - target = (label[0] + label[1]) % 10 + #(clean1, _, clean2), label = self.triplet_dataset[index] + label1 = random.randint(0,9) + label2 = random.randint(0,9) + clean1, label1 = self.sample(label1) + clean2, label2 = self.sample(label2) + target = (label1 + label2) % 10 cleansum1, sum_label = self.sample(target) if self.include_irrelevant and self.ambiguous: ambsum1, pair_label = self.sample(target, ambiguous=True) # 2nd label in the pair ambsum2, _ = self.sample(pair_label, ambiguous=True) cleansum2, _ = self.sample(pair_label) - clean3, _ = self.sample((pair_label-label[0]) % 10) # 0/6 .. 2 0-2=-2 % 10 = 8 + clean3, _ = self.sample((pair_label-label1) % 10) # 0/6 .. 2 0-2=-2 % 10 = 8 img_seq = torch.stack([clean1, clean2, cleansum1, ambsum1, clean3, ambsum2, cleansum2]) # clean1 + clean2 = cleansum1 (ambsum1). clean1 + clean3 = cleansum2 (ambsum2). irrelevant = clean2 + clean3 else: img_seq = torch.stack([clean1, clean2, cleansum1]) @@ -179,77 +186,77 @@ def sample(self, target, ambiguous=False): label *= 2 # 0 or 1 - return dataset[idx][0][label], dataset[idx][1][label] -# class SequenceEMNIST(Dataset): -# def __init__(self, root, download=False, split='train', transform=None, n_cls=26, ambiguous=False, word_length=3, -# cache=False,cache_dir=None, include_irrelevant=False): -# """ -# A dataset where the input is a sequence of letters that make a 3 letter word -# """ +class SequenceEMNIST(Dataset): + def __init__(self, root, download=False, split='train', transform=None, n_cls=26, ambiguous=False, word_length=3, + cache=False,cache_dir=None, include_irrelevant=False): + """ + A dataset where the input is a sequence of letters that make a 3 letter word + """ -# #self.unambiguous_EMNIST = torchvision.datasets.EMNIST(root=emnist_root, split='letters', download=download) # try with dataloader ? -# self.dataset = DatasetTriplet(root, download, split, transform) -# self.partitioned_dataset = partition_datasetV2(self.dataset, n_cls) -# self.word_dict = self.dictionary_by_last_letter(extract_n_letter_words(dictionary, word_length), n_cls) -# self.ambiguous = ambiguous -# self.cache = cache -# self.cache_dir = cache_dir -# self.split = split -# self.data_len = sum([len(dataset) for dataset in self.partitioned_dataset]) -# self.include_irrelevant = include_irrelevant - -# def __getitem__(self, index, img_size=28): -# """ -# return the sequence as a tuple and the target as a tensor -# """ -# if self.cache: -# triplet, label = self.dataset[index] -# gt_idx = random.randint(0, 1) #ground truth index -# correct_label = label[gt_idx] -# clean_img = triplet[gt_idx*2] -# word = random.choice(self.word_dict[correct_label]) -# letter1 = self.sample(idx(word[0])) -# letter2 = self.sample(idx(word[1])) -# if self.include_irrelevant and self.ambiguous: -# opposite_label = label[int(not gt_idx)] -# opposite_img = triplet[int(not gt_idx)*2] -# opposite_word = random.choice(self.word_dict[opposite_label]) -# opposite_clue1 = self.sample(idx(opposite_word[0])) -# opposite_clue2 = self.sample(idx(opposite_word[1])) -# img_seq = torch.stack([letter1, letter2, clean_img, triplet[1], opposite_clue1, opposite_clue2, opposite_img]) -# else: -# img_seq = torch.stack([letter1, letter2, clean_img]) -# torch.save(img_seq, f'{self.cache_dir}/{self.split}/img_seq_{index}.pt') -# torch.save(correct_label, f'{self.cache_dir}/{self.split}/sum_label_{index}.pt') -# else: -# img_seq = torch.load(f'{self.cache_dir}/{self.split}/img_seq_{index}.pt') -# correct_label = torch.load(f'{self.cache_dir}/{self.split}/sum_label_{index}.pt') -# return img_seq, correct_label - -# def __len__(self): -# return self.data_len + #self.unambiguous_EMNIST = torchvision.datasets.EMNIST(root=emnist_root, split='letters', download=download) # try with dataloader ? + self.dataset = DatasetTriplet(root, download, split, transform) + self.partitioned_dataset = partition_datasetV2(self.dataset, n_cls) + self.word_dict = self.dictionary_by_last_letter(extract_n_letter_words(dictionary, word_length), n_cls) + self.ambiguous = ambiguous + self.cache = cache + self.cache_dir = cache_dir + self.split = split + self.data_len = sum([len(dataset) for dataset in self.partitioned_dataset]) - 1 + self.include_irrelevant = include_irrelevant + + def __getitem__(self, index, img_size=28): + """ + return the sequence as a tuple and the target as a tensor + """ + if self.cache: + triplet, label = self.dataset[index] + gt_idx = random.randint(0, 1) #ground truth index + correct_label = label[gt_idx] + clean_img = triplet[gt_idx*2] + word = random.choice(self.word_dict[correct_label]) + letter1 = self.sample(idx(word[0])) + letter2 = self.sample(idx(word[1])) + if self.ambiguous: + opposite_label = label[int(not gt_idx)] + opposite_img = triplet[int(not gt_idx)*2] + opposite_word = random.choice(self.word_dict[opposite_label]) + opposite_clue1 = self.sample(idx(opposite_word[0])) + opposite_clue2 = self.sample(idx(opposite_word[1])) + img_seq = torch.stack([letter1, letter2, clean_img, triplet[1], opposite_clue1, opposite_clue2, opposite_img]) + else: + img_seq = torch.stack([letter1, letter2, clean_img]) + torch.save(img_seq, f'{self.cache_dir}/{self.split}/img_seq_{index}.pt') + torch.save(correct_label, f'{self.cache_dir}/{self.split}/sum_label_{index}.pt') + else: + img_seq = torch.load(f'{self.cache_dir}/{self.split}/img_seq_{index}.pt') + correct_label = torch.load(f'{self.cache_dir}/{self.split}/sum_label_{index}.pt') + return img_seq, correct_label + + def __len__(self): + return self.data_len -# def dictionary_by_last_letter(self, dictionary, n_cls): -# ''' -# Sort word dictionary by last letter -# ''' -# sorted_dict = [[] for _ in range(n_cls)] -# for word in dictionary: -# sorted_dict[idx(word[-1])].append(word) -# return sorted_dict + def dictionary_by_last_letter(self, dictionary, n_cls): + ''' + Sort word dictionary by last letter + ''' + sorted_dict = [[] for _ in range(n_cls)] + for word in dictionary: + sorted_dict[idx(word[-1])].append(word) + return sorted_dict -# def sample(self, target, ambiguous=False): -# """ -# sample n images from the dataset -# """ -# dataset = self.partitioned_dataset[target] -# idx = torch.randint(0, len(dataset), (1,)) -# label = torch.where(dataset[idx][1] == target)[0] -# other = torch.where(dataset[idx][1] != target)[0] -# if ambiguous: -# label = 1 -# else: -# label *= 2 # 0 or 1 - -# return dataset[idx][0][label] + def sample(self, target, ambiguous=False): + """ + sample n images from the dataset + """ + dataset = self.partitioned_dataset[target] + idx = torch.randint(0, len(dataset), (1,)) + label = torch.where(dataset[idx][1] == target)[0] + other = torch.where(dataset[idx][1] != target)[0] + if ambiguous: + label = 1 + else: + label *= 2 # 0 or 1 - + return dataset[idx][0][label] def save_dataset_to_file(dataset_name, og_root, new_root, blend, pairs=None, batch_size=100, n_train=60000, n_test=10000): os.makedirs(new_root+'/train/')