MNIST sequence's first two numbers are randomly picked #9

Merged (1 commit, Apr 25, 2024)
ambiguous/dataset/dataset.py: 155 changes (81 additions, 74 deletions)
@@ -18,8 +18,8 @@
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset # For custom datasets
# For sequential emnist
-# from english_words import english_words_lower_alpha_set as dictionary
-# from abl_expectation.utils.expectation_clamp import extract_n_letter_words
+from english_words import english_words_lower_alpha_set as dictionary
+from abl_expectation.utils.expectation_clamp import extract_n_letter_words
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -97,9 +97,12 @@ def __getitem__(self, index, img_size=28):
single_image_path = self.image_list[index]
im_as_np = np.load(single_image_path).astype(np.float64)/255.
im_as_ten = torch.from_numpy(im_as_np)
+#clean1 = self.transform(im_as_ten[:, :, :img_size])
+#amb = self.transform(im_as_ten[:, :, img_size:2*img_size])
+#clean2 = self.transform(im_as_ten[:, :, 2*img_size:3*img_size])
clean1 = im_as_ten[:, :, :img_size]
amb = im_as_ten[:, :, img_size:2*img_size]
clean2 = im_as_ten[:, :, 2*img_size:3*img_size]
label = torch.from_numpy(np.load(self.label_list[index]))
return (clean1, amb, clean2), label
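
Note: the loader above stores each example as one array holding three 28-pixel-wide panels (clean digit, ambiguous blend, clean digit) and slices them apart on access. A minimal sketch of that layout, using a random stand-in for the .npy file:

import numpy as np
import torch

img_size = 28
# Hypothetical stored triplet: three 28x28 panels concatenated along the width axis.
im_as_np = np.random.rand(1, img_size, 3 * img_size).astype(np.float64)
im_as_ten = torch.from_numpy(im_as_np)

clean1 = im_as_ten[:, :, :img_size]                   # first clean digit
amb = im_as_ten[:, :, img_size:2 * img_size]          # ambiguous blend in the middle
clean2 = im_as_ten[:, :, 2 * img_size:3 * img_size]   # second clean digit
assert clean1.shape == amb.shape == clean2.shape == (1, img_size, img_size)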

@@ -144,14 +147,18 @@ def __getitem__(self, index, img_size=28):
return the sequence as a tuple and the target as a tensor
"""
if self.cache:
-(clean1, _, clean2), label = self.triplet_dataset[index]
-target = (label[0] + label[1]) % 10
+#(clean1, _, clean2), label = self.triplet_dataset[index]
+label1 = random.randint(0,9)
+label2 = random.randint(0,9)
+clean1, label1 = self.sample(label1)
+clean2, label2 = self.sample(label2)
+target = (label1 + label2) % 10
cleansum1, sum_label = self.sample(target)
if self.include_irrelevant and self.ambiguous:
ambsum1, pair_label = self.sample(target, ambiguous=True) # 2nd label in the pair
ambsum2, _ = self.sample(pair_label, ambiguous=True)
cleansum2, _ = self.sample(pair_label)
-clean3, _ = self.sample((pair_label-label[0]) % 10) # 0/6 .. 2 0-2=-2 % 10 = 8
+clean3, _ = self.sample((pair_label-label1) % 10) # 0/6 .. 2 0-2=-2 % 10 = 8
img_seq = torch.stack([clean1, clean2, cleansum1, ambsum1, clean3, ambsum2, cleansum2]) # clean1 + clean2 = cleansum1 (ambsum1). clean1 + clean3 = cleansum2 (ambsum2). irrelevant = clean2 + clean3
else:
img_seq = torch.stack([clean1, clean2, cleansum1])
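
Note: this hunk is the substance of the PR. Previously the first two digits came from the cached triplet and the target was (label[0] + label[1]) % 10; now both labels are drawn uniformly at random and fresh images are sampled for them. A minimal sketch of the new target construction, with self.sample stubbed out as a hypothetical helper:

import random
import torch

def sample(label):
    # Stand-in for self.sample: return a dummy 28x28 image plus its label.
    return torch.zeros(1, 28, 28), label

label1 = random.randint(0, 9)      # first digit, now picked at random
label2 = random.randint(0, 9)      # second digit, now picked at random
clean1, label1 = sample(label1)
clean2, label2 = sample(label2)
target = (label1 + label2) % 10    # modular sum, e.g. 7 + 5 -> 2
cleansum1, _ = sample(target)

# The irrelevant clue uses the same modular arithmetic: clean3's class is chosen so
# that label1 + (pair_label - label1) sums to pair_label mod 10,
# e.g. pair_label=0, label1=2 gives (0 - 2) % 10 == 8, matching the source comment.
pair_label = random.randint(0, 9)
clean3, _ = sample((pair_label - label1) % 10)

img_seq = torch.stack([clean1, clean2, cleansum1])
assert img_seq.shape == (3, 1, 28, 28)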
@@ -179,77 +186,77 @@ def sample(self, target, ambiguous=False):
label *= 2 # 0 or 1 -
return dataset[idx][0][label], dataset[idx][1][label]

-# class SequenceEMNIST(Dataset):
-# def __init__(self, root, download=False, split='train', transform=None, n_cls=26, ambiguous=False, word_length=3,
-# cache=False,cache_dir=None, include_irrelevant=False):
-# """
-# A dataset where the input is a sequence of letters that make a 3 letter word
-# """
+class SequenceEMNIST(Dataset):
+def __init__(self, root, download=False, split='train', transform=None, n_cls=26, ambiguous=False, word_length=3,
+cache=False,cache_dir=None, include_irrelevant=False):
+"""
+A dataset where the input is a sequence of letters that make a 3 letter word
+"""

-# #self.unambiguous_EMNIST = torchvision.datasets.EMNIST(root=emnist_root, split='letters', download=download) # try with dataloader ?
-# self.dataset = DatasetTriplet(root, download, split, transform)
-# self.partitioned_dataset = partition_datasetV2(self.dataset, n_cls)
-# self.word_dict = self.dictionary_by_last_letter(extract_n_letter_words(dictionary, word_length), n_cls)
-# self.ambiguous = ambiguous
-# self.cache = cache
-# self.cache_dir = cache_dir
-# self.split = split
-# self.data_len = sum([len(dataset) for dataset in self.partitioned_dataset])
-# self.include_irrelevant = include_irrelevant

-# def __getitem__(self, index, img_size=28):
-# """
-# return the sequence as a tuple and the target as a tensor
-# """
-# if self.cache:
-# triplet, label = self.dataset[index]
-# gt_idx = random.randint(0, 1) #ground truth index
-# correct_label = label[gt_idx]
-# clean_img = triplet[gt_idx*2]
-# word = random.choice(self.word_dict[correct_label])
-# letter1 = self.sample(idx(word[0]))
-# letter2 = self.sample(idx(word[1]))
-# if self.include_irrelevant and self.ambiguous:
-# opposite_label = label[int(not gt_idx)]
-# opposite_img = triplet[int(not gt_idx)*2]
-# opposite_word = random.choice(self.word_dict[opposite_label])
-# opposite_clue1 = self.sample(idx(opposite_word[0]))
-# opposite_clue2 = self.sample(idx(opposite_word[1]))
-# img_seq = torch.stack([letter1, letter2, clean_img, triplet[1], opposite_clue1, opposite_clue2, opposite_img])
-# else:
-# img_seq = torch.stack([letter1, letter2, clean_img])
-# torch.save(img_seq, f'{self.cache_dir}/{self.split}/img_seq_{index}.pt')
-# torch.save(correct_label, f'{self.cache_dir}/{self.split}/sum_label_{index}.pt')
-# else:
-# img_seq = torch.load(f'{self.cache_dir}/{self.split}/img_seq_{index}.pt')
-# correct_label = torch.load(f'{self.cache_dir}/{self.split}/sum_label_{index}.pt')
-# return img_seq, correct_label

-# def __len__(self):
-# return self.data_len
+#self.unambiguous_EMNIST = torchvision.datasets.EMNIST(root=emnist_root, split='letters', download=download) # try with dataloader ?
+self.dataset = DatasetTriplet(root, download, split, transform)
+self.partitioned_dataset = partition_datasetV2(self.dataset, n_cls)
+self.word_dict = self.dictionary_by_last_letter(extract_n_letter_words(dictionary, word_length), n_cls)
+self.ambiguous = ambiguous
+self.cache = cache
+self.cache_dir = cache_dir
+self.split = split
+self.data_len = sum([len(dataset) for dataset in self.partitioned_dataset]) - 1
+self.include_irrelevant = include_irrelevant

+def __getitem__(self, index, img_size=28):
+"""
+return the sequence as a tuple and the target as a tensor
+"""
+if self.cache:
+triplet, label = self.dataset[index]
+gt_idx = random.randint(0, 1) #ground truth index
+correct_label = label[gt_idx]
+clean_img = triplet[gt_idx*2]
+word = random.choice(self.word_dict[correct_label])
+letter1 = self.sample(idx(word[0]))
+letter2 = self.sample(idx(word[1]))
+if self.ambiguous:
+opposite_label = label[int(not gt_idx)]
+opposite_img = triplet[int(not gt_idx)*2]
+opposite_word = random.choice(self.word_dict[opposite_label])
+opposite_clue1 = self.sample(idx(opposite_word[0]))
+opposite_clue2 = self.sample(idx(opposite_word[1]))
+img_seq = torch.stack([letter1, letter2, clean_img, triplet[1], opposite_clue1, opposite_clue2, opposite_img])
+else:
+img_seq = torch.stack([letter1, letter2, clean_img])
+torch.save(img_seq, f'{self.cache_dir}/{self.split}/img_seq_{index}.pt')
+torch.save(correct_label, f'{self.cache_dir}/{self.split}/sum_label_{index}.pt')
+else:
+img_seq = torch.load(f'{self.cache_dir}/{self.split}/img_seq_{index}.pt')
+correct_label = torch.load(f'{self.cache_dir}/{self.split}/sum_label_{index}.pt')
+return img_seq, correct_label

+def __len__(self):
+return self.data_len
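
Note: the cache flag in the uncommented __getitem__ acts as a "build the cache" switch. When set, a sequence is generated and written to disk with torch.save; when unset, a previously written tensor is loaded back by index. A rough sketch of this per-index save/load pattern (paths and shapes are illustrative):

import os
import torch

def get_item(index, build_cache, cache_dir='cache', split='train'):
    os.makedirs(f'{cache_dir}/{split}', exist_ok=True)
    if build_cache:
        img_seq = torch.zeros(3, 1, 28, 28)  # placeholder for a generated sequence
        label = torch.tensor(4)              # placeholder target label
        torch.save(img_seq, f'{cache_dir}/{split}/img_seq_{index}.pt')
        torch.save(label, f'{cache_dir}/{split}/sum_label_{index}.pt')
    else:
        img_seq = torch.load(f'{cache_dir}/{split}/img_seq_{index}.pt')
        label = torch.load(f'{cache_dir}/{split}/sum_label_{index}.pt')
    return img_seq, label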

-# def dictionary_by_last_letter(self, dictionary, n_cls):
-# '''
-# Sort word dictionary by last letter
-# '''
-# sorted_dict = [[] for _ in range(n_cls)]
-# for word in dictionary:
-# sorted_dict[idx(word[-1])].append(word)
-# return sorted_dict
+def dictionary_by_last_letter(self, dictionary, n_cls):
+'''
+Sort word dictionary by last letter
+'''
+sorted_dict = [[] for _ in range(n_cls)]
+for word in dictionary:
+sorted_dict[idx(word[-1])].append(word)
+return sorted_dict
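
Note: dictionary_by_last_letter buckets the word list by final letter so that, given a target class, __getitem__ can pick a word whose last letter matches it. A self-contained sketch, assuming the module's idx helper (defined elsewhere) maps 'a'..'z' to 0..25:

import string

def idx(letter):
    # Assumed behaviour of the module's idx helper: 'a' -> 0, ..., 'z' -> 25.
    return string.ascii_lowercase.index(letter)

def dictionary_by_last_letter(dictionary, n_cls=26):
    sorted_dict = [[] for _ in range(n_cls)]
    for word in dictionary:
        sorted_dict[idx(word[-1])].append(word)
    return sorted_dict

buckets = dictionary_by_last_letter(['cat', 'dog', 'sit', 'big'])
assert buckets[idx('t')] == ['cat', 'sit']
assert buckets[idx('g')] == ['dog', 'big']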

-# def sample(self, target, ambiguous=False):
-# """
-# sample n images from the dataset
-# """
-# dataset = self.partitioned_dataset[target]
-# idx = torch.randint(0, len(dataset), (1,))
-# label = torch.where(dataset[idx][1] == target)[0]
-# other = torch.where(dataset[idx][1] != target)[0]
-# if ambiguous:
-# label = 1
-# else:
-# label *= 2 # 0 or 1 -
-# return dataset[idx][0][label]
+def sample(self, target, ambiguous=False):
+"""
+sample n images from the dataset
+"""
+dataset = self.partitioned_dataset[target]
+idx = torch.randint(0, len(dataset), (1,))
+label = torch.where(dataset[idx][1] == target)[0]
+other = torch.where(dataset[idx][1] != target)[0]
+if ambiguous:
+label = 1
+else:
+label *= 2 # 0 or 1 -
+return dataset[idx][0][label]
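
Note: in sample above, each partitioned entry appears to hold the (clean1, amb, clean2) images together with a two-element label, so index 0 or 2 selects a clean image of the requested class and index 1 the ambiguous blend; that is what label *= 2 encodes. A sketch of that indexing under those assumptions:

import torch

# Hypothetical partitioned entry: a (clean1, amb, clean2) stack and its two labels.
images = torch.zeros(3, 1, 28, 28)
labels = torch.tensor([3, 8])  # the two classes blended in this triplet

target = 8
pos = torch.where(labels == target)[0]  # 0 or 1: which slot matches the target
pos *= 2                                # clean images sit at indices 0 and 2
ambiguous = False
img = images[1] if ambiguous else images[pos.item()]
assert img.shape == (1, 28, 28)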

def save_dataset_to_file(dataset_name, og_root, new_root, blend, pairs=None, batch_size=100, n_train=60000, n_test=10000):
os.makedirs(new_root+'/train/')