In [13]:
from rail1.callbacks.checkpoint import checkpoint, load_checkpoint
from rail1.utils.seed import set_seed

from torch import nn
import random
import numpy as np
import torch
from torch import optim
import os
import numpy


In [14]:
def generate_data(steps):
    all_inputs = []
    all_targets = []
    for i in range(steps):
        
        inputs = torch.randn(32, 10).cuda()  # Batch size of 32
        targets = torch.randn(32, 1).cuda()  # Corresponding targets
        
        all_inputs.append(inputs)
        all_targets.append(targets)
        
    return torch.stack(all_inputs), torch.stack(all_targets)
    
    

In [15]:
set_seed(0, deterministic=True)
data_1 = generate_data(32)

random_state_dict = {
    "torch": torch.get_rng_state(),
    "numpy": numpy.random.get_state(),
    "random": random.getstate(),
    "cuda": torch.cuda.get_rng_state(),
    "cuda_all": torch.cuda.get_rng_state_all(),
}

set_seed(0)

torch.set_rng_state(random_state_dict["torch"])
torch.cuda.set_rng_state(random_state_dict["cuda"])
torch.cuda.set_rng_state_all(random_state_dict["cuda_all"])
numpy.random.set_state(random_state_dict["numpy"])
random.setstate(random_state_dict["random"])


data_2 = generate_data(32)

set_seed(0, deterministic=True)


data_1_ = generate_data(32)
data_2_ = generate_data(32)


assert torch.allclose(data_1[0], data_1_[0])

In [16]:
set_seed(0)

In [17]:
# Simple feed-forward neural network
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 50)  # Assuming input features are of size 10
        self.fc2 = nn.Linear(50, 1)   # Output layer (e.g., for regression or binary classification)
        
    def reset_parameters(self):
        for m in self.children():
            if hasattr(m, 'reset_parameters'):
                m.reset_parameters()

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [18]:
def train_model(model, optimizer, steps):
    model.train()
    for _ in range(steps):
        inputs = torch.randn(32, 10).cuda()  # Batch size of 32
        targets = torch.randn(32, 1).cuda()  # Corresponding targets

        # Forward pass
        outputs = model(inputs)
        loss = torch.nn.functional.mse_loss(outputs, targets)  # Mean Squared Error Loss for simplicity

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


In [19]:
def _print_models_equal(model1, model2):
    for p in model1.state_dict():
        v1 = model1.state_dict()[p]
        v2 = model2.state_dict()[p]

        print(torch.equal(v1, v2))

In [20]:
set_seed(0, deterministic=True)

# Initialize models and optimizer
model1 = SimpleModel().cuda()
optimizer1 = optim.Adam(model1.parameters())

# Train model1 for a few steps and checkpoint
train_model(model1, optimizer1, steps=64)


set_seed(0, deterministic=True)
model2 = SimpleModel().cuda()
optimizer2 = optim.Adam(model2.parameters())


# Train model1 for a few steps and checkpoint
train_model(model2, optimizer2, steps=32)
random_state_dict = {
    "torch": torch.get_rng_state(),
    "numpy": numpy.random.get_state(),
    "random": random.getstate(),
    "cuda": torch.cuda.get_rng_state(),
    "cuda_all": torch.cuda.get_rng_state_all(),
}

model2_state_dict = model2.state_dict()
optimizer2_state_dict = optimizer2.state_dict()

checkpoint = {
    "model": model2_state_dict,
    "optimizer": optimizer2_state_dict,
    "random_state": random_state_dict,
}
torch.save(checkpoint, './checkpoint.pt')
del checkpoint


set_seed(0)
model2.reset_parameters()

checkpoint = torch.load('./checkpoint.pt')
model2_state_dict = checkpoint['model']
model2.load_state_dict(model2_state_dict)
optimizer2_state_dict = checkpoint['optimizer']
optimizer2.load_state_dict(optimizer2_state_dict)
random_state_dict = checkpoint['random_state']

torch.set_rng_state(random_state_dict["torch"])
torch.cuda.set_rng_state(random_state_dict["cuda"])
torch.cuda.set_rng_state_all(random_state_dict["cuda_all"])
numpy.random.set_state(random_state_dict["numpy"])
random.setstate(random_state_dict["random"])

train_model(model2, optimizer2, steps=32)

_print_models_equal(model1, model2)





# train_model(model1, optimizer1, steps=32)


True
True
True
True


# Deterministic train loader

In [21]:
from torchvision import datasets, transforms

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])



In [22]:
from torch.utils.data import Sampler

In [23]:
class InfiniteRandomSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        while True:
            index = random.randint(0, len(self.data_source) - 1)
            yield index

    def __len__(self):
        return torch.iinfo(torch.int64).max

In [24]:
set_seed(0, deterministic=True)

# # Download and load the training data
mnist = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transforms.ToTensor())
# dataset = torch.arange(128)
# trainloader1 = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=InfiniteRandomSampler(dataset))


# for i, batch in enumerate(trainloader1):
#     images = batch
#     print(images.sum())
    
#     if i == 16:
#         break
# dataset = torch.arange(128)
sampler = InfiniteRandomSampler(dataset)
sampler_iter = iter(sampler)

# trainloader1 = torch.utils.data.DataLoader(mnist, batch_size=32, sampler=InfiniteRandomSampler(dataset))

random_state_dict_a = {
    "torch": torch.get_rng_state(),
    "numpy": numpy.random.get_state(),
    "random": random.getstate(),
    "cuda": torch.cuda.get_rng_state(),
    "cuda_all": torch.cuda.get_rng_state_all(),
}

for i, batch in enumerate(dataset):
    batch = [next(sampler_iter) for i in range(8)]
    batch = [mnist[i] for i in batch]
    images, labels = zip(*batch)
    print(torch.stack(images).sum())
    if i == 2:
        break

random_state_dict_b = {
    "torch": torch.get_rng_state(),
    "numpy": numpy.random.get_state(),
    "random": random.getstate(),
    "cuda": torch.cuda.get_rng_state(),
    "cuda_all": torch.cuda.get_rng_state_all(),
}

set_seed(0)

torch.set_rng_state(random_state_dict_b["torch"])
torch.cuda.set_rng_state(random_state_dict_b["cuda"])
torch.cuda.set_rng_state_all(random_state_dict_b["cuda_all"])
numpy.random.set_state(random_state_dict_b["numpy"])
random.setstate(random_state_dict_b["random"])



dataset = torch.arange(128)
sampler = InfiniteRandomSampler(dataset)
# trainloader2 = torch.utils.data.DataLoader(mnist, batch_size=32, sampler=InfiniteRandomSampler(dataset))

sampler_iter = iter(sampler)

print(torch.allclose(random_state_dict_b['torch'], torch.get_rng_state()))

for i, batch in enumerate(dataset):
    batch = [next(sampler_iter) for i in range(8)]
    batch = [mnist[i] for i in batch]
    images, labels = zip(*batch)
    print(torch.stack(images).sum())
    if i == 2:
        break


random_state_dict_c = {
    "torch": torch.get_rng_state(),
    "numpy": numpy.random.get_state(),
    "random": random.getstate(),
    "cuda": torch.cuda.get_rng_state(),
    "cuda_all": torch.cuda.get_rng_state_all(),
}


print(torch.allclose(random_state_dict_b['torch'],random_state_dict_c['torch']))

# dataset = torch.arange(128)
# trainloader2 = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=InfiniteRandomSampler(dataset))

# print()
# for i, batch in enumerate(trainloader2):
#     images = batch
#     print(images.sum())
    
#     if i == 16:
#         break


NameError: name 'dataset' is not defined

In [None]:
set_seed(0)
print(torch.allclose(torch.get_rng_state(), random_state_dict_a['torch']))
print(torch.allclose(torch.get_rng_state(), random_state_dict_b['torch']))
print(torch.allclose(torch.get_rng_state(), random_state_dict_c['torch']))


dataset = torch.arange(128)
sampler = InfiniteRandomSampler(dataset)
sampler_iter = iter(sampler)
# trainloader3 = torch.utils.data.DataLoader(mnist, batch_size=32, sampler=InfiniteRandomSampler(dataset))

for i, batch in enumerate(dataset):

    batch = [next(sampler_iter) for i in range(8)]
    batch = [mnist[i] for i in batch]
    images, labels = zip(*batch)
    print(torch.stack(images).sum())
    if i == 5:
        break

print(torch.allclose(torch.get_rng_state(), random_state_dict_c['torch']))
print(torch.allclose(random_state_dict_b['torch'], random_state_dict_c['torch']))


# dataset = torch.arange(128)
# trainloader3 = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=InfiniteRandomSampler(dataset))

# print()
# for i, batch in enumerate(trainloader3):
#     images = batch
    
#     print(images.sum())
    
#     if i == 32:
#         break



In [None]:
set_seed(0)
print(torch.allclose(torch.get_rng_state(), random_state_dict_a['torch']))
print(torch.allclose(torch.get_rng_state(), random_state_dict_b['torch']))
print(torch.allclose(torch.get_rng_state(), random_state_dict_c['torch']))


dataset = torch.arange(128)
sampler = InfiniteRandomSampler(dataset)
trainloader3 = torch.utils.data.DataLoader(mnist, batch_size=32, sampler=InfiniteRandomSampler(dataset))
for i, batch in enumerate(trainloader3):

    print(batch[0].sum())
    if i == 4:
        break

print(torch.allclose(torch.get_rng_state(), random_state_dict_b['torch']))

print(torch.allclose(torch.get_rng_state(), random_state_dict_c['torch']))
print(torch.allclose(random_state_dict_b['torch'], random_state_dict_c['torch']))


In [None]:
import random

import torch
from torch.utils.data import Sampler
import math


class InfiniteRandomSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        while True:
            index = random.randint(0, len(self.data_source) - 1)  # type: ignore
            yield index

    def __len__(self):
        return torch.iinfo(torch.int64).max


In [None]:
set_seed(0)
dataset = torch.arange(128)
sampler = InfiniteRandomSampler(dataset)
# dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler)

class CustomDataLoader:
    def __init__(self, dataset, sampler, batch_size=1):
        self.dataset = dataset
        self.batch_size = batch_size
        self.sampler = sampler
        
    def __iter__(self):
        batch_indices = []
        for index in self.sampler:
            batch_indices.append(index)
            if len(batch_indices) == self.batch_size:
                yield [self.dataset[i] for i in batch_indices]
                batch_indices = []

        if batch_indices:
            yield [self.dataset[i] for i in batch_indices]


    def __len__(self):
        return len(self.sampler)
    

dataloader = CustomDataLoader(dataset, sampler=sampler, batch_size=3)


random_state_0 = torch.get_rng_state()

# for i, batch in enumerate(range(len(dataloader))):  # WORKS
for i, batch in enumerate(dataloader):  # DOES NOT WORK
# for i, batch in enumerate(dataset):  # DOES NOT WORK

# for i, batch in enumerate(range(10)):  # WORKS

    x = torch.randn(32)
    if i == 2:
        break
    
random_state_2 = torch.get_rng_state()
    


for i, batch in enumerate(range(len(dataloader))):  # WORKS
# for i, batch in enumerate(dataloader):  # DOES NOT WORK
# for i, batch in enumerate(range(10)):  # WORKS

    x = torch.randn(32)
    if i == 2:
        break
random_state_4 = torch.get_rng_state()



In [None]:
set_seed(0)
dataset = torch.arange(128)
sampler = iter(InfiniteRandomSampler(dataset))


# for i, batch in enumerate(range(len(dataloader))): # WORKS
# for i, batch in enumerate(dataset):  # DOES NOT WORK
for i, batch in enumerate(dataloader):  # DOES NOT WORK
# for i, batch in enumerate(range(10)):  # WORKS

    x = torch.randn(32)
    
    if i == 2:
        print(torch.allclose(torch.get_rng_state(), random_state_2))
        
    if i == 5:
        print(torch.allclose(torch.get_rng_state(), random_state_4))
        break
    

##### 

In [26]:
from rail1 import datasets

In [32]:
set_seed(0, deterministic=True)

train_loader = datasets.cifar10()['train_loader']
for i, batch in enumerate(train_loader):
    images, labels = batch
    print(images.sum())
    if i == 2:
        break

random_state_dict = {
    "torch": torch.get_rng_state(),
    "numpy": numpy.random.get_state(),
    "random": random.getstate(),
    "cuda": torch.cuda.get_rng_state(),
    "cuda_all": torch.cuda.get_rng_state_all(),
}

set_seed(0)

torch.set_rng_state(random_state_dict["torch"])
torch.cuda.set_rng_state(random_state_dict["cuda"])
torch.cuda.set_rng_state_all(random_state_dict["cuda_all"])
numpy.random.set_state(random_state_dict["numpy"])
random.setstate(random_state_dict["random"])



train_loader = datasets.cifar10()['train_loader']
for i, batch in enumerate(train_loader):
    images, labels = batch
    print(images.sum())
    if i == 2:
        break



random_state_dict = {
    "torch": torch.get_rng_state(),
    "numpy": numpy.random.get_state(),
    "random": random.getstate(),
    "cuda": torch.cuda.get_rng_state(),
    "cuda_all": torch.cuda.get_rng_state_all(),
}



Files already downloaded and verified
Files already downloaded and verified
tensor(716.1142)
tensor(711.8868)
tensor(730.3835)
Files already downloaded and verified
Files already downloaded and verified
tensor(734.9149)
tensor(731.8375)
tensor(735.6448)


In [33]:
set_seed(0)



train_loader = datasets.cifar10()['train_loader']
for i, batch in enumerate(train_loader):
    images, labels = batch
    print(images.sum())
    if i == 5:
        break



Files already downloaded and verified
Files already downloaded and verified
tensor(716.1142)
tensor(711.8868)
tensor(730.3835)
tensor(734.9149)
tensor(731.8375)
tensor(735.6448)
