In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import random
# --------------------------------------------------------------
# Config
# --------------------------------------------------------------
BATCH_SIZE = 512

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------------------------------------------------------------
# Dataset + Masking Module
# --------------------------------------------------------------
transform = transforms.Compose([
    transforms.ToTensor(),
])



In [20]:
train_data = datasets.FashionMNIST(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = datasets.FashionMNIST(root="./data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True)


In [21]:
train_data2 = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader2 = DataLoader(train_data2, batch_size=BATCH_SIZE, shuffle=True)
test_data2 = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
test_loader2 = DataLoader(test_data2, batch_size=BATCH_SIZE, shuffle=True)


In [22]:
memory_buffer=[]
Memory_size=10000
def update_memory(x,y):
  for xi,yi in zip(x,y):
    if len(memory_buffer)<Memory_size:
      memory_buffer.append((xi.cpu(),yi.cpu()))
    else:
      idx=random.randint(0,Memory_size-1)
      memory_buffer[idx]=(xi.cpu(),yi.cpu())

def sample_memory(batch_size,device):
  if len(memory_buffer)==0:
    return None,None
  else:
    batch=random.sample(memory_buffer,min(len(memory_buffer),Memory_size))
    x_sample,y_sample=zip(*batch)
    return torch.stack(x_sample).to(device),torch.stack(y_sample).to(device)

In [23]:
class SimpleCNN(nn.Module):
  def __init__(self,dim,out_feat):
    super().__init__()
    self.net=nn.Sequential(
        nn.Conv2d(1,32,kernel_size=3,stride=2,padding=1),
        nn.ReLU(),
        nn.Conv2d(32,64,kernel_size=3,stride=2,padding=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(64*7*7,dim),
        nn.ReLU(),
        nn.Linear(dim,out_feat)
    )
  def forward(self,x):
    return self.net(x)

In [24]:
model=SimpleCNN(dim=768,out_feat=10).to(device)
opt=torch.optim.Adam(model.parameters(),lr=1e-3)
crit=nn.CrossEntropyLoss()

In [25]:
def train_model(loader,epochs,update_mem):
  for epoch in range(epochs):
    for x,y in tqdm(loader):
      x,y=x.to(device),y.to(device)
      x_m,y_m=sample_memory(20,device)
      if x_m is not None:
        x_b,y_b=torch.concat([x,x_m]),torch.concat([y,y_m])
      else:
        x_b,y_b=x,y
      loss=crit(model(x_b),y_b)
      opt.zero_grad()
      loss.backward()
      opt.step()

      if update_mem:
        update_memory(x,y)

In [26]:
train_model(train_loader,2,True)

100%|██████████| 118/118 [00:22<00:00,  5.23it/s]
100%|██████████| 118/118 [00:23<00:00,  5.03it/s]


In [27]:
train_model(train_loader2,2,False)

100%|██████████| 118/118 [00:14<00:00,  8.22it/s]
100%|██████████| 118/118 [00:14<00:00,  8.04it/s]


In [28]:
def validation(loader):
  total=0
  correct=0
  for x,y in loader:
    x,y=x.to(device),y.to(device)
    preds=model(x)
    score=torch.argmax(preds,dim=1)
    correct+=(score==y).sum().item()
    total+=len(y)
  return correct*100/total


In [29]:
acc=validation(test_loader)
acc

87.68

In [30]:
acc=validation(test_loader2)
acc

93.82