In [1]:
import torch
from datasets import IndexedDataset
from dataloader import IndexedDataLoader
from torch.utils.data import DataLoader, DistributedSampler

from utils import get_args
from architectures import load_architecture

from samplers import InfoBatch, DataSchedule 
from losses import trades_loss
from tqdm.notebook import tqdm



args = get_args()
args.arch = 'LeNet5'
args.dataset = 'MNIST'
args.selection_method = 'none'

# Example Usage
dataset = IndexedDataset(args, train=True)
model, target_layers = load_architecture(args)

args.epochs = 5
args.ratio = 0.5
args.delta = 1

trainset = DataSchedule(dataset, args.epochs, args.ratio, args.delta)
# trainset = InfoBatch(dataset, args.epochs, args.ratio, args.delta)

sampler = trainset.sampler
train_shuffle = False

trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=train_shuffle, num_workers=0, sampler=sampler)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, )

def train_info_batch(epoch):
    # safe_print('\nEpoch: %d, iterations %d' % (epoch, len(trainloader)))
    model.train()

    for batch_idx, blobs in tqdm( enumerate(trainloader) ):
        # print(len(blobs) )
        inputs, targets, idxs = blobs 
    # for batch in trainloader:

        # print(len(batch))
        inputs, targets = inputs.to('cuda'), targets.to('cuda')
        optimizer.zero_grad()

        loss_values, clean_values, robust_values = trades_loss(model=model, x_natural=inputs, y=targets, optimizer=optimizer,)
        # 3. use <InfoBatch>.update(loss), all scoring/rescaling/getting mean is now conducted at the backend, see previous (research version) code for details.

        trainset.update_scores(clean_values, robust_values)
        # print(trainset.weights.shape)
        # print(loss_values.shape)
        # print(trainset.cur_batch_index)
        loss = trainset.update_loss_weights(loss_values)

        # print(loss)
        loss.backward()
        # print(loss)
        
    # safe_print('epoch:', epoch, '  Training Accuracy:', round(100. * correct /
    #     total, 3), '  Train loss:', round(train_loss / len(trainloader), 4))
    # train_acc.append(correct / total)


for epoch in range(args.epochs):
    print('epoch')
    # trainloader.sampler.set_epoch(epoch)
    train_info_batch(epoch) 

There are 0 well learned samples and 60000 still to learn. We sampled 0
For next epoch, there will be 60000 samples. Total dataset size was 60000
epoch


0it [00:00, ?it/s]

There are 0 well learned samples and 60000 still to learn. We sampled 0
For next epoch, there will be 60000 samples. Total dataset size was 60000
epoch


0it [00:00, ?it/s]

There are 19591 well learned samples and 40409 still to learn. We sampled 9795
For next epoch, there will be 50204 samples. Total dataset size was 60000
epoch


0it [00:00, ?it/s]

There are 19599 well learned samples and 40401 still to learn. We sampled 9799
For next epoch, there will be 50200 samples. Total dataset size was 60000
epoch


0it [00:00, ?it/s]

There are 19600 well learned samples and 40400 still to learn. We sampled 9800
For next epoch, there will be 50200 samples. Total dataset size was 60000
epoch


0it [00:00, ?it/s]

There are 19603 well learned samples and 40397 still to learn. We sampled 9801
For next epoch, there will be 50198 samples. Total dataset size was 60000


In [None]:
# world_size = 2
# sampler1 = DistributedSampler(dataset, num_replicas=world_size, rank=0, shuffle=True)
# sampler2 = DistributedSampler(dataset, num_replicas=world_size, rank=1, shuffle=True)
# # indexed_loader = DataLoader(custom_dataset, batch_size=5, shuffle=True)
# indexed_loader1 = IndexedDataLoader(dataset, batch_size=5, sampler=sampler1)
# indexed_loader2 = IndexedDataLoader(dataset, batch_size=5, sampler=sampler2)

# # Before reshuffling
# batch_5 = indexed_loader.get_batch(5)
# print("Before reshuffle:", batch_5)

# # Reshuffle for a new epoch
# indexed_loader.shuffle_indices()

# # After reshuffling
# batch_5 = indexed_loader.get_batch(5)
# print("After reshuffle:", batch_5)

In [13]:
import numpy as np
clean_scores = np.array([1,2,3,4,5] )
robust_scores = np.array([6,7,8,9,10])

well_learned_clean_mask = ( (clean_scores < clean_scores.mean()) & (robust_scores < robust_scores.mean()) )


In [5]:
torch.ones(len(dataset)) * 3

tensor([3., 3., 3.,  ..., 3., 3., 3.])

In [19]:

for epoch in range(2):
    # indexed_loader.shuffle_data()
    
    sampler1.set_epoch(epoch)
    sampler2.set_epoch(epoch)

    print('sampler 1')
    print()


    for batch_id in range( len(indexed_loader1) ):
        print(batch_id, indexed_loader1.get_batch(batch_id) )
    print()

    print('sampler 2')
    print()
    
    for batch_id in range( len(indexed_loader2) ):
        print(batch_id, indexed_loader2.get_batch(batch_id) )
    print()

sampler 1

0 [tensor([[6],
        [1],
        [4],
        [7],
        [0]]), tensor([[6],
        [1],
        [4],
        [7],
        [0]]), tensor([4, 7, 3, 0, 6])]

sampler 2

0 [tensor([[0],
        [8],
        [3],
        [4],
        [6]]), tensor([[0],
        [8],
        [3],
        [4],
        [6]]), tensor([6, 2, 8, 3, 4])]

sampler 1

0 [tensor([[9],
        [2],
        [7],
        [5],
        [1]]), tensor([[9],
        [2],
        [7],
        [5],
        [1]]), tensor([5, 1, 0, 9, 7])]

sampler 2

0 [tensor([[0],
        [8],
        [3],
        [4],
        [6]]), tensor([[0],
        [8],
        [3],
        [4],
        [6]]), tensor([6, 2, 8, 3, 4])]



In [229]:
from torch.utils.data import Sampler, DataLoader
import random
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            sample = self.transform(sample)

        return sample, label, idx



class ResamplingSampler(Sampler):

    def __init__(self, data_source, batch_size, sample_size):

        self.sampled_indices = set()
        self.tosample_indices = set( range( len(data_source) ) )

        self.sample_size = sample_size
        self.batch_size = batch_size
        
    def __iter__(self):

        while len(self.tosample_indices) > 0:
            # Ensure sample size does not exceed the available to-sample indices
            current_sample_size = min(self.sample_size, len(self.tosample_indices))

            ### The function to create the batches : 
            sampled_set = set(random.sample(self.tosample_indices, current_sample_size))

            # Ensure batch size does not exceed the sampled set size
            current_batch_size = min(self.batch_size, len(sampled_set))
            batch_set = set(random.sample(sampled_set, current_batch_size))

            if not batch_set:
                break  # Exit if batch set is empty to avoid errors

            # Update sampled and tosample indices
            self.sampled_indices.update(batch_set)
            self.tosample_indices.difference_update(batch_set)

            yield list(batch_set)



class CustomDataLoader(DataLoader):
    def __init__(self, dataset, batch_size, sample_size, *args, **kwargs):
        sampler = ResamplingSampler(dataset, batch_size=batch_size, sample_size=subsample_size)
        super().__init__(dataset, batch_sampler=sampler, *args, **kwargs)


import torch
import numpy as np

# Parameters for synthetic data
num_samples = 100  # Total number of samples
input_dim = 10  # Dimensionality of each observation
num_classes = 5  # Number of classes

# Generate synthetic data
data = torch.randn(num_samples, input_dim)  # Randomly generated features
labels = torch.randint(0, num_classes, (num_samples,))  # Randomly generated labels


dataset = CustomDataset(data, labels)
batch_size = 10  # The size of the initial batch
subsample_size = 20  # The size of the subsample sent to the model

dataloader = CustomDataLoader(dataset, batch_size=batch_size, sample_size=subsample_size)

for epoch in range(1):
    for batch in dataloader:
        data, labels, idx = batch
        print(idx)
        
        # Pass batch_data and batch_labels to your model
        # Forward pass, backward pass, and optimization
        pass


tensor([64, 33,  0,  2, 37, 38,  9, 48, 57, 63])
tensor([98,  5, 11, 80, 81, 83, 20, 87, 94, 62])
tensor([68,  8, 60, 17, 18, 53, 58, 91, 28, 61])
tensor([34, 36, 70,  6, 72, 39, 74, 12, 84, 29])
tensor([96,  1, 41, 42, 14, 47, 79, 23, 59, 93])
tensor([97, 71, 40, 76, 45, 78, 46, 15, 50, 56])
tensor([73, 10, 75, 49, 54, 86, 90, 92, 30, 31])
tensor([65, 66,  3,  7, 43, 13, 85, 24, 25, 26])
tensor([32, 67, 99, 35, 44, 51, 52, 22, 89, 95])
tensor([ 4, 69, 77, 16, 82, 19, 21, 55, 88, 27])


In [19]:
a = set()
b = set([1,2,3])
a.update(b)
a

{1, 2, 3}

In [9]:
remaining_indices = [0,1,2,3,4,5,6,7,8,9]
sample_size = 10
batch_size = 5

sampled_indices = set( random.sample(remaining_indices, sample_size) )
print(sampled_indices)

selected_indices = set( random.sample(sampled_indices, batch_size) )
print(selected_indices)

unselected_indices = sampled_indices - selected_indices
print(unselected_indices)

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
{0, 3, 7, 8, 9}
{1, 2, 4, 5, 6}


In [5]:
for subset_indices in dataloader:
    print('hey')

hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey
hey


KeyboardInterrupt: 