In [43]:
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch import nn
import itertools

In [34]:
## Data loading for both the initial and online data
class SplitDataset(Dataset):
    """Contiguous slice ``[start, end)`` of the MNIST training set.

    Each instance builds its own ``torchvision.datasets.MNIST`` object
    (``train=True``, rooted at ``./data``, downloaded on demand) and exposes
    only the samples whose underlying index lies in ``[start, end)``.

    Args:
        start: Index of the first underlying sample (inclusive).
        end: Index one past the last underlying sample (exclusive).
        transform: Transform forwarded to the MNIST dataset.

    Raises:
        ValueError: If ``0 <= start <= end <= len(MNIST train set)`` does not hold.
    """

    def __init__(self, start, end, transform):
        self.train_dataset = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
        if not 0 <= start <= end <= len(self.train_dataset):
            raise ValueError(
                f"invalid split [{start}, {end}) for a dataset of size {len(self.train_dataset)}"
            )
        self.start = start
        self.end = end

    def __len__(self):
        """Number of samples in this split."""
        return self.end - self.start

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` within the split.

        Negative indices count from the end of the split. An out-of-range
        index raises ``IndexError`` instead of silently returning a sample
        that belongs outside ``[start, end)``; this also makes plain
        ``for sample in split`` iteration stop at the split boundary.
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            raise IndexError(f"index out of range for split of length {length}")
        return self.train_dataset[idx + self.start]

## Training Data
# Alternative preprocessing, kept for reference (not used below):
# transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,),)])
transform = transforms.ToTensor()
train_dataset = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)

## Initial and Online Data splitting
# First half of the training set is the "initial" (offline) data,
# second half is held back and streamed as "online" data.
split = len(train_dataset) // 2
last_idx = len(train_dataset)
initial_data = SplitDataset(0, split, transform)
online_data = SplitDataset(split, last_idx, transform)

## Initial and online dataloaders
initial_dataloader = DataLoader(initial_data, batch_size=32, shuffle=True)
# The online loader must draw from the held-back half (online_data), not from
# initial_data, otherwise the "online" phase only replays already-seen samples.
online_dataloader = DataLoader(online_data, batch_size=32, shuffle=True)


## Test Data
testset = torchvision.datasets.MNIST('/tmp', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)

print("Initial Data Size: ", len(initial_data))
print("Online Data Size: ", len(online_data))
print("Test Data Size: ", len(testset))

Initial Data Size:  30000
Online Data Size:  30000
Test Data Size:  10000


In [35]:
## Creating the model
class NeuralNetwork(nn.Module):
    """Fully connected classifier producing 10 logits per input.

    The input is flattened (all dimensions after the batch dimension) to
    28*28 features and passed through Linear(784, 512) -> ReLU ->
    Linear(512, 512) -> ReLU -> Linear(512, 10).
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # Layers are listed in order; their positions (0..4) become the
        # sub-module names inside the Sequential container.
        stack_layers = [
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*stack_layers)

    def forward(self, x):
        """Return the raw (un-normalized) class logits for ``x``."""
        return self.linear_relu_stack(self.flatten(x))

# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network and move its parameters to the selected device.
model = NeuralNetwork()
model = model.to(device)
print(model)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


In [36]:
## Training with 50% of the training data
def train(dataloader, model, loss_fn, optimizer, device):
    """Run one pass over ``dataloader``, updating ``model`` in place.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches.
        model: Module to optimize; switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer holding ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.
    """
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Clear gradients BEFORE the backward pass. Zeroing only after the
        # step would let gradients left behind by any earlier backward()
        # call outside this function accumulate into the first update.
        optimizer.zero_grad()

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss, correct


# Objective and optimizer shared by the initial and the online training phases.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3)

# Five passes over the initial half of the data, evaluating on the test set
# after every pass.
for epoch in range(5):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train(
        dataloader=initial_dataloader,
        model=model,
        loss_fn=loss_fn,
        optimizer=optimizer,
        device=device,
    )
    test(dataloader=testloader, model=model, loss_fn=loss_fn)


Epoch 1
-------------------------------
Test Error: 
 Accuracy: 36.0%, Avg loss: 2.260007 

Epoch 2
-------------------------------
Test Error: 
 Accuracy: 53.4%, Avg loss: 2.199481 

Epoch 3
-------------------------------
Test Error: 
 Accuracy: 59.9%, Avg loss: 2.098018 

Epoch 4
-------------------------------
Test Error: 
 Accuracy: 65.4%, Avg loss: 1.919036 

Epoch 5
-------------------------------
Test Error: 
 Accuracy: 69.5%, Avg loss: 1.640180 



In [41]:
def train_online(OnX, Ony, initial_dataloader, model, loss_fn, optimizer, device,
                 alpha=0.7, num_replay_batches=5):
    """Update ``model`` on one online batch, mixed with replayed initial data.

    The online batch ``(OnX, Ony)`` is concatenated with each of the first
    ``num_replay_batches`` batches drawn from ``initial_dataloader`` and one
    optimizer step is taken per concatenated batch. Afterwards every
    parameter is replaced by ``alpha * before + (1 - alpha) * after`` to
    limit drift from the pre-update weights (catastrophic forgetting).

    Args:
        OnX: Online input batch, already on ``device``.
        Ony: Online target batch, already on ``device``.
        initial_dataloader: Iterable yielding ``(X, y)`` batches of old data.
        model: Module to update in place; switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer holding ``model``'s parameters.
        device: Device the replayed batches are moved to.
        alpha: Weight given to the pre-update parameters, in ``[0, 1]``.
        num_replay_batches: Number of batches taken from ``initial_dataloader``.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]``.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")

    # The evaluation helpers leave the model in eval mode; switch back so
    # mode-dependent layers behave correctly during the update.
    model.train()

    # Snapshot the initial weights, detached so no autograd history is kept.
    weights_initial = {name: parameter.detach().clone()
                       for name, parameter in model.named_parameters()}

    # Train on the online batch mixed with a few batches of the initial data.
    for OfX, Ofy in itertools.islice(initial_dataloader, num_replay_batches):
        OfX, Ofy = OfX.to(device), Ofy.to(device)
        conX, cony = torch.cat((OnX, OfX)), torch.cat((Ony, Ofy))
        output = model(conX)
        optimizer.zero_grad()
        loss = loss_fn(output, cony)
        loss.backward()
        optimizer.step()

    # Blend old and new weights to avoid catastrophic forgetting. Iterate
    # over named_parameters (not state_dict) so buffers, which have no entry
    # in the snapshot, are left untouched instead of raising KeyError.
    with torch.no_grad():
        for name, parameter in model.named_parameters():
            parameter.copy_(alpha * weights_initial[name] + (1 - alpha) * parameter)
        

## Return accuracy and loss for the old dataset and new batches
def test_online(dataloader, new_batches, model, loss_fn, device, batch_size=32):
    size = len(dataloader.dataset) + len(new_batches) * batch_size
    num_batches = len(dataloader) + len(new_batches)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in itertools.chain(dataloader, new_batches):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"OldDataset and new batches Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss, correct



In [45]:
# Number of online mini-batches to stream through the model.
NUM_ONLINE_BATCHES = 2

# Online batches seen so far; re-evaluated together with the old data.
New_batches = []

# islice stops after NUM_ONLINE_BATCHES batches, so no extra batch is fetched
# from the loader only to be thrown away by a manual counter/break.
for j, (X, y) in enumerate(itertools.islice(online_dataloader, NUM_ONLINE_BATCHES)):
    print("\nSimulating the online learning with new batches: ", j + 1)
    X, y = X.to(device), y.to(device)
    train_online(X, y, initial_dataloader, model, loss_fn, optimizer, device)

    # reporting the accuracy of the model on the test data
    test(testloader, model, loss_fn)

    # reporting the accuracy of the model on the train_data and new batches
    New_batches.append((X, y))
    test_online(initial_dataloader, New_batches, model, loss_fn, device)


Simulating the online learning with new batches:  0
Test Error: 
 Accuracy: 69.8%, Avg loss: 1.637130 

OldDataset and new batches Error: 
 Accuracy: 68.9%, Avg loss: 1.644118 


Simulating the online learning with new batches:  1
Test Error: 
 Accuracy: 69.8%, Avg loss: 1.636620 

OldDataset and new batches Error: 
 Accuracy: 68.9%, Avg loss: 1.643776 


Simulating the online learning with new batches:  2
