In [45]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
# --------------------------------------------------------------
# Config
# --------------------------------------------------------------
BATCH_SIZE = 512

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------------------------------------------------------------
# Dataset + Masking Module
# --------------------------------------------------------------
transform = transforms.Compose([
    transforms.ToTensor(),
])



In [46]:
train_data = datasets.FashionMNIST(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = datasets.FashionMNIST(root="./data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True)


In [47]:
train_data2 = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader2 = DataLoader(train_data2, batch_size=BATCH_SIZE, shuffle=True)
test_data2 = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
test_loader2 = DataLoader(test_data2, batch_size=BATCH_SIZE, shuffle=True)


In [48]:
class SimpleCNN(nn.Module):
  def __init__(self,dim,out_feat):
    super().__init__()
    self.net=nn.Sequential(
        nn.Conv2d(1,32,kernel_size=3,stride=2,padding=1),
        nn.ReLU(),
        nn.Conv2d(32,64,kernel_size=3,stride=2,padding=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(64*7*7,dim),
        nn.ReLU(),
        nn.Linear(dim,out_feat)
    )
  def forward(self,x):
    return self.net(x)

In [49]:
model=SimpleCNN(dim=768,out_feat=10).to(device)
opt=torch.optim.Adam(model.parameters(),lr=1e-3)
crit=nn.CrossEntropyLoss()

In [59]:
for epoch in range(1):
  for x,y in tqdm(train_loader):
    x,y=x.to(device),y.to(device)
    preds=model(x)
    loss=crit(preds,y)
    opt.zero_grad()
    loss.backward()
    opt.step()


100%|██████████| 118/118 [00:45<00:00,  2.61it/s]


In [60]:
def validation(loader):
  total=0
  correct=0
  for x,y in loader:
    x,y=x.to(device),y.to(device)
    preds=model(x)
    score=torch.argmax(preds,dim=1)
    correct+=(score==y).sum().item()
    total+=len(y)
  return correct*100/total


In [61]:
acc=validation(test_loader)
acc

86.17

# ***EWC***

In [62]:
model.eval()
fisher={n:torch.zeros_like(p) for n,p in model.named_parameters()}
for x,y in tqdm(train_loader):
  x,y=x.to(device),y.to(device)
  loss=crit(model(x),y)
  loss.backward()
  for n,p in model.named_parameters():
    if p.grad is not None:
      fisher[n]+=(p.grad.detach()**2)
for n,p in model.named_parameters():
  fisher[n]/=len(train_loader)

old_params={n:p.clone().detach() for n,p in model.named_parameters()}

100%|██████████| 118/118 [00:40<00:00,  2.95it/s]


In [63]:
ewc_lambda=10000

In [64]:
for epoch in range(1):
  for x,y in tqdm(train_loader2):
    x,y=x.to(device),y.to(device)
    preds=model(x)
    loss=crit(preds,y)
    ewc_loss=0
    for n,p in model.named_parameters():
      if n in fisher:
        ewc_loss+=(fisher[n]*(p-old_params[n])**2).sum()
    loss=loss+(ewc_lambda/2)*ewc_loss
    opt.zero_grad()
    loss.backward()
    opt.step()


100%|██████████| 118/118 [00:47<00:00,  2.50it/s]


In [65]:
acc=validation(test_loader)
acc

81.56

In [66]:
acc=validation(test_loader2)
acc

96.89