In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import random
# --------------------------------------------------------------
# Config
# --------------------------------------------------------------
BATCH_SIZE = 512
SEED = 0  # fixed so that Restart & Run All reproduces the reported numbers

# Seed every RNG this notebook draws from: `random` drives the replay buffer
# (update_memory / sample_memory), torch drives weight init and DataLoader
# shuffling. GPU kernels may still be nondeterministic.
random.seed(SEED)
torch.manual_seed(SEED)

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------------------------------------------------------------
# Dataset transform
# --------------------------------------------------------------
# Only ToTensor() is applied: no normalization, no augmentation.
transform = transforms.Compose([
    transforms.ToTensor(),
])



In [17]:
# Task 1: MNIST. The training split is reshuffled every epoch. The test split is
# only used to measure accuracy, which does not depend on order, so it is
# iterated deterministically.
train_data = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)


In [18]:
# Task 2: FashionMNIST. Same loader setup as task 1: shuffled training split,
# unshuffled test split (accuracy does not depend on iteration order).
train_data2 = datasets.FashionMNIST(root="./data", train=True, download=True, transform=transform)
train_loader2 = DataLoader(train_data2, batch_size=BATCH_SIZE, shuffle=True)
test_data2 = datasets.FashionMNIST(root="./data", train=False, download=True, transform=transform)
test_loader2 = DataLoader(test_data2, batch_size=BATCH_SIZE, shuffle=False)


In [19]:
class SimpleCNN(nn.Module):
  """Small convolutional classifier shared by both tasks.

  Two 3x3 stride-2 convolutions (32 then 64 channels) are followed by a
  hidden linear layer of width ``dim`` and a linear output layer producing
  ``out_feat`` logits.

  Args:
    dim: width of the hidden fully connected layer.
    out_feat: number of output logits (classes).
    in_channels: channels of the input images (1 for MNIST/FashionMNIST).
    img_size: side length of the square input images (28 for MNIST/FashionMNIST).
  """
  def __init__(self,dim,out_feat,in_channels=1,img_size=28):
    super().__init__()
    # Each conv (kernel 3, stride 2, padding 1) maps a side of length n to
    # (n - 1) // 2 + 1, so 28 -> 14 -> 7 and the flatten size defaults to 64*7*7.
    side=img_size
    for _ in range(2):
      side=(side-1)//2+1
    # Layer order/indices are kept fixed so state_dict keys stay net.0/2/5/7.
    self.net=nn.Sequential(
        nn.Conv2d(in_channels,32,kernel_size=3,stride=2,padding=1),
        nn.ReLU(),
        nn.Conv2d(32,64,kernel_size=3,stride=2,padding=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(64*side*side,dim),
        nn.ReLU(),
        nn.Linear(dim,out_feat)
    )
  def forward(self,x):
    """Return logits of shape (batch, out_feat) for x of shape (batch, in_channels, img_size, img_size)."""
    return self.net(x)

In [20]:
# A single model / optimizer / loss triple is shared by every training and
# evaluation cell below: hidden width 768, 10 output logits, Adam at lr 1e-3.
model = SimpleCNN(dim=768, out_feat=10)
model = model.to(device)
opt = torch.optim.Adam(params=model.parameters(), lr=1e-3)
crit = nn.CrossEntropyLoss()

In [21]:
# Replay memory: a module-level list of (input, label) pairs kept on the CPU,
# filled by update_memory() and read by sample_memory().
memory_buffer=[]
Memory_Size=10000

def update_memory(x,y):
  """Store every (x[i], y[i]) pair of a batch in the replay memory.

  Samples are appended until the buffer holds ``Memory_Size`` items. After
  that, each new sample overwrites a uniformly random slot. This is random
  replacement, which favours recently seen samples; it is not reservoir
  sampling.

  Args:
    x: batch of inputs, iterated along dim 0.
    y: batch of labels aligned with ``x``.
  """
  for xi,yi in zip(x,y):
    # Force a real copy. When x already lives on the CPU, a plain .cpu()
    # returns a view into the whole batch, so each stored sample would keep
    # the entire batch alive and alias the caller's tensor.
    item=(xi.detach().to("cpu",copy=True),yi.detach().to("cpu",copy=True))
    if len(memory_buffer)<Memory_Size:
      memory_buffer.append(item)
    else:
      idx=random.randint(0,Memory_Size-1)
      memory_buffer[idx]=item

def sample_memory(batch_size,device):
  """Draw up to ``batch_size`` distinct stored samples, without replacement.

  Args:
    batch_size: maximum number of samples to return; values <= 0 return nothing.
    device: device the stacked tensors are moved to.

  Returns:
    (x, y) stacked tensors on ``device``, or (None, None) when the memory is
    empty or ``batch_size`` <= 0.
  """
  if not memory_buffer or batch_size<=0:
    return None,None
  batch=random.sample(memory_buffer,min(len(memory_buffer),batch_size))
  x_b,y_b=zip(*batch)
  return torch.stack(x_b).to(device),torch.stack(y_b).to(device)

In [22]:
def get_fisher(train_loader:DataLoader,net=None,dev=None,criterion=None,progress=True):
  """Estimate a diagonal Fisher-information proxy for EWC and snapshot the weights.

  For each batch, the gradient of the batch-mean loss is squared element-wise
  and accumulated; the sum is then divided by the number of batches. This
  squared *batch* gradient is a cheaper (and much smaller) proxy than the
  average of per-sample squared gradients. EWC strengths tuned against it do
  not transfer to a per-sample Fisher.

  Args:
    train_loader: iterable of (x, y) batches from the task to protect.
    net: model to analyse; defaults to the notebook-global ``model``.
    dev: device batches are moved to; defaults to the global ``device``.
    criterion: loss function; defaults to the global ``crit``.
    progress: wrap the loader in a tqdm progress bar.

  Returns:
    (fisher, old_params): dicts keyed by parameter name, holding the Fisher
    proxy and a detached copy of the current parameter values.

  Raises:
    ValueError: if ``train_loader`` yields no batches.
  """
  net=model if net is None else net
  dev=device if dev is None else dev
  criterion=crit if criterion is None else criterion
  was_training=net.training
  net.eval()
  fisher={n:torch.zeros_like(p) for n,p in net.named_parameters()}
  n_batches=0
  try:
    for x,y in (tqdm(train_loader) if progress else train_loader):
      x,y=x.to(dev),y.to(dev)
      loss=criterion(net(x),y)
      net.zero_grad()
      loss.backward()
      for n,p in net.named_parameters():
        if p.grad is not None:
          fisher[n]+=p.grad.detach()**2
      n_batches+=1
  finally:
    # Do not leak the last batch's gradients or leave the model in eval mode.
    net.zero_grad()
    net.train(was_training)
  if n_batches==0:
    # Dividing by zero here would silently turn every Fisher entry into NaN.
    raise ValueError("get_fisher: train_loader yielded no batches")
  for n in fisher:
    fisher[n]/=n_batches
  old_params={n:p.detach().clone() for n,p in net.named_parameters()}
  return fisher,old_params

In [23]:
def train_model(loader,epochs,update_mem,ewc=False,fisher=None,old_params=None,ewc_lambda=5000,replay_batch_size=10):
  """Train the global ``model`` on ``loader`` with experience replay and optional EWC.

  Each step concatenates the current batch with up to ``replay_batch_size``
  samples from the replay memory. When ``ewc`` is set, the step also adds
  ``ewc_lambda * sum(fisher * (p - old_params) ** 2)`` to the loss.

  Args:
    loader: iterable of (x, y) batches for the current task.
    epochs: number of passes over ``loader``.
    update_mem: after each step, store the current batch in the replay memory.
    ewc: add the EWC penalty; requires ``fisher`` and ``old_params``.
    fisher: per-parameter importance weights keyed by parameter name
      (e.g. from ``get_fisher``).
    old_params: anchor parameter values keyed like ``fisher``.
    ewc_lambda: strength of the EWC penalty.
    replay_batch_size: maximum replayed samples per step; 0 disables replay.

  Raises:
    ValueError: if ``ewc`` is True but ``fisher`` or ``old_params`` is missing.
  """
  if ewc and (fisher is None or old_params is None):
    raise ValueError("train_model: ewc=True requires both fisher and old_params")
  # An earlier cell (e.g. get_fisher) may have switched the model to eval mode;
  # training must run with layers in training mode.
  model.train()
  for _ in range(epochs):
    for x,y in tqdm(loader):
      x,y=x.to(device),y.to(device)
      # Experience replay: mix a few previously stored samples into the batch.
      if replay_batch_size>0:
        x_m,y_m=sample_memory(replay_batch_size,device)
      else:
        x_m,y_m=None,None
      if x_m is not None:
        x_b,y_b=torch.cat([x,x_m]),torch.cat([y,y_m])
      else:
        x_b,y_b=x,y
      loss=crit(model(x_b),y_b)
      if ewc:
        # Quadratic pull of each parameter towards its anchored value,
        # weighted by its estimated importance.
        ewc_loss=0
        for n,p in model.named_parameters():
          if n in fisher:
            ewc_loss+=(fisher[n]*(p-old_params[n])**2).sum()
        loss=loss+ewc_lambda*ewc_loss
      opt.zero_grad()
      loss.backward()
      opt.step()

      # Only the current-task batch is stored, never the replayed samples.
      if update_mem:
        update_memory(x,y)

In [24]:
# Task 1: two epochs on MNIST, filling the replay memory as we go (no EWC yet).
train_model(loader=train_loader, epochs=2, update_mem=True)

100%|██████████| 118/118 [00:09<00:00, 11.83it/s]
100%|██████████| 118/118 [00:10<00:00, 11.65it/s]


In [25]:
# Importance weights and anchor parameters for the EWC penalty used on task 2.
fisher, old_params = get_fisher(train_loader)

100%|██████████| 118/118 [00:06<00:00, 19.01it/s]


In [26]:
def validation(loader,net=None,dev=None):
  """Return the top-1 accuracy, in percent, of a model over ``loader``.

  Runs in eval mode under ``torch.no_grad()`` so no autograd graph is built.
  The model's previous train/eval mode is restored afterwards.

  Args:
    loader: iterable of (x, y) batches.
    net: model to evaluate; defaults to the notebook-global ``model``.
    dev: device batches are moved to; defaults to the global ``device``.

  Returns:
    100 * correct / total as a float.

  Raises:
    ValueError: if ``loader`` yields no samples.
  """
  net=model if net is None else net
  dev=device if dev is None else dev
  was_training=net.training
  net.eval()
  total=0
  correct=0
  try:
    with torch.no_grad():
      for x,y in loader:
        x,y=x.to(dev),y.to(dev)
        pred=net(x).argmax(dim=1)
        correct+=(pred==y).sum().item()
        total+=len(y)
  finally:
    net.train(was_training)
  if total==0:
    raise ValueError("validation: loader yielded no samples")
  return correct*100/total


In [27]:
# MNIST test accuracy right after training on MNIST (baseline before task 2).
acc = validation(loader=test_loader)
acc

97.76

In [28]:
# Task 2: two epochs on FashionMNIST with replay from the MNIST memory plus the
# EWC penalty; the memory is not updated with FashionMNIST samples.
train_model(
    loader=train_loader2,
    epochs=2,
    update_mem=False,
    ewc=True,
    fisher=fisher,
    old_params=old_params,
)

100%|██████████| 118/118 [00:07<00:00, 16.63it/s]
100%|██████████| 118/118 [00:06<00:00, 18.40it/s]


In [29]:
# FashionMNIST test accuracy after task-2 training.
acc = validation(loader=test_loader2)
acc

87.29

In [30]:
# MNIST test accuracy re-measured after FashionMNIST training. The drop versus
# the earlier MNIST score (97.76 -> 92.86 in the recorded run) is the amount of
# forgetting left despite replay + EWC.
acc = validation(loader=test_loader)
acc

92.86