In [None]:
import torch.optim as optim
from torchvision import datasets, transforms

from utils import *

# Step 5: Load Data and Run Experiments
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to
# [-1, 1] (one mean/std entry, i.e. single-channel images).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train = datasets.FashionMNIST(root="./data", train=True, transform=transform, download=True)
mnist_test = datasets.FashionMNIST(root="./data", train=False, transform=transform, download=True)

# Split Fashion-MNIST into class-based tasks for continual learning.
# NOTE(review): the number of tasks / classes per task (originally described as
# 5 tasks x 2 classes) is decided by split_mnist_by_classes in utils — confirm there.
tasks = split_mnist_by_classes(mnist_train)
# Evaluation tasks must come from the held-out test split; building them from
# mnist_train would report accuracy on samples the model was trained on.
tasks_test = split_mnist_by_classes(mnist_test)


# Create one shuffled data loader per training task
batch_size = 64
data_loaders = [DataLoader(task, batch_size=batch_size, shuffle=True) for task in tasks]

# Initialize the baseline model and optimizer, plus the loss shared by both runs.
# These stay at top level so later cells can still reference them.
model = SimpleCNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Toggle for the baseline (no experience replay) run. Its header is printed
# only when the run actually executes, so the log never announces a skipped run.
RUN_BASELINE = False

# Train without experience replay
if RUN_BASELINE:
    print("\nTraining without Experience Replay")
    train_model(model, data_loaders, tasks_test, optimizer, criterion, anim_title="FMNIST_forget")

# Train with experience replay: a fresh model and optimizer so this run does not
# start from the baseline's weights, plus a ReplayBufferReservoir (from utils)
# constructed with capacity=100.
# NOTE(review): both runs pass anim_title="FMNIST_forget"; if train_model uses the
# title to name its output, enabling the baseline makes the two runs collide —
# confirm and give this run a distinct title.
replay_model = SimpleCNN()
replay_optimizer = optim.Adam(replay_model.parameters(), lr=0.001)
replay_buffer = ReplayBufferReservoir(capacity=100)
print("\nTraining with Experience Replay")
train_model(replay_model, data_loaders, tasks_test, replay_optimizer, criterion, replay_buffer=replay_buffer, anim_title="FMNIST_forget")