In [1]:
import logging
import os
import random

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torchvision import models

from tqdm import tqdm

from continuum.scenarios import ClassIncremental
from continuum.datasets import CIFAR10, ImageFolderDataset


  from .autonotebook import tqdm as notebook_tqdm


In [17]:
class Config:
  """Settings shared by the cells of this notebook."""

  # Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
  device = 'cpu' if not torch.cuda.is_available() else 'cuda'

  # Number of classes added to the model head for the first task
  # (`initial_increment`) and for every later task (`increment`).
  initial_increment = increment = 2

  # Mini-batch sizes handed to the training / validation DataLoaders.
  batch_size_valid = batch_size_train = 32

  # Number of epochs in the training loop.
  num_epochs = 300


cfg = Config()


In [13]:
def seed_everything(seed=0):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), then configures cuDNN so that it neither
    uses non-deterministic algorithms nor auto-tunes its kernel choice.

    Args:
        seed: Integer seed applied to every generator. Defaults to 0.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    # With benchmarking on, cuDNN may select a different convolution
    # algorithm on each run, which undermines `deterministic` above.
    torch.backends.cudnn.benchmark = False

    # Hash randomisation of the running interpreter is fixed at start-up;
    # this only reaches processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)


In [7]:
class IncrementalResNet50(nn.Module):
  """ResNet-50 whose classification head grows as new classes arrive.

  The model starts with a zero-output fully connected head, so
  `adaptation` has to be called at least once before the logits are useful.

  Attributes:
    backbone: Network created by `models.resnet50()` (no weights argument
      is passed), with its `fc` layer replaced by the growing head.
    transforms: The `transforms` entry of `ResNet50_Weights.IMAGENET1K_V1`,
      stored exactly as torchvision exposes it.
      NOTE(review): torchvision documents this as a factory that must be
      called to obtain the preprocessing transform - confirm at call sites.
    num_classes: Number of output units currently in `backbone.fc`.
  """

  def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.backbone = models.resnet50()
    self.transforms = models.ResNet50_Weights.IMAGENET1K_V1.transforms

    # Start with an empty head; `adaptation` adds the output units.
    self.backbone.fc = nn.Linear(self.backbone.fc.in_features, 0)
    self.num_classes = 0

  def adaptation(self, increment: int) -> None:
    """Grow the classification head by `increment` output units.

    `backbone.fc` is replaced by a new linear layer with
    `num_classes + increment` outputs. Weights and biases of the existing
    units are copied over; the added units keep the default `nn.Linear`
    initialisation. The new layer is placed on the device, and given the
    dtype, of the layer it replaces.

    Args:
      increment: Number of output units to add; must not be negative.

    Raises:
      ValueError: If `increment` is negative.
    """
    if increment < 0:
      raise ValueError(f'increment must be non-negative, got {increment}')

    old_fc = self.backbone.fc
    num_old = self.num_classes

    new_fc = nn.Linear(old_fc.in_features, num_old + increment)
    # The model may already have been moved (e.g. to CUDA) before the head
    # is grown; a freshly built layer would otherwise stay on the CPU.
    new_fc = new_fc.to(device=old_fc.weight.device, dtype=old_fc.weight.dtype)

    with torch.no_grad():
      # Slice by the old unit count rather than `[:-increment]`, which
      # selects nothing when `increment` is 0.
      new_fc.weight[:num_old].copy_(old_fc.weight)
      if old_fc.bias is not None:
        new_fc.bias[:num_old].copy_(old_fc.bias)

    self.backbone.fc = new_fc
    self.num_classes = num_old + increment

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return the backbone's output for the input batch `x`."""
    return self.backbone(x)


In [None]:
# Seed every RNG before the network is built so that the random weight
# initialisation (and the shuffling done later) repeats across kernel restarts.
seed_everything()
model = IncrementalResNet50().to(cfg.device)




In [None]:
# CIFAR-10 train / test splits, stored under `input`.
dataset_train = CIFAR10(data_path='input', train=True, download=True)
dataset_valid = CIFAR10(data_path='input', train=False, download=True)

# `model.transforms` is the weight preset's transform factory: it has to be
# called to build the preprocessing, and Continuum takes a *list* of callables.
preprocessing = [model.transforms()]

# Use the same increments the training loop hands to `model.adaptation`, so
# the size of the model head and the class split cannot drift apart.
scenario_train = ClassIncremental(
  dataset_train,
  increment=cfg.increment,
  initial_increment=cfg.initial_increment,
  transformations=preprocessing,
)
scenario_valid = ClassIncremental(
  dataset_valid,
  increment=cfg.increment,
  initial_increment=cfg.initial_increment,
  transformations=preprocessing,
)


In [None]:
for task_id in range(len(scenario_valid)):
  logging.info(f'Train for task {task_id} has started.')
  model.adaptation(cfg.initial_increment if task_id == 0 else cfg.increment)

  # A scenario is a sequence of task sets, not of samples: train on the
  # current task and validate on every task seen so far (the head only has
  # outputs for those classes).
  taskset_train = scenario_train[task_id]
  taskset_valid = scenario_valid[:task_id + 1]

  dataloader_train = DataLoader(taskset_train, batch_size=cfg.batch_size_train, shuffle=True)
  dataloader_valid = DataLoader(taskset_valid, batch_size=cfg.batch_size_valid)

  # Built after `adaptation` so the optimizer sees the parameters of the
  # newly created head.
  optimizer = optim.AdamW(params=model.parameters())
  scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=cfg.num_epochs)

  for i_epoch in range(cfg.num_epochs):
    # Validation below switches to eval mode, so re-enter train mode here.
    model.train()

    tqdm_loader = tqdm(dataloader_train)

    for X, y, task_ids in tqdm_loader:
      X, y = X.to(cfg.device), y.to(cfg.device)

      # Clear gradients for every batch; otherwise they accumulate over the
      # whole epoch.
      optimizer.zero_grad()

      y_pred = model(X)

      loss = F.cross_entropy(y_pred, y)
      loss.backward()
      optimizer.step()

      tqdm_loader.set_description(f'Epoch: {i_epoch}/{cfg.num_epochs} | Loss: {loss.item()}')

    # T_max is expressed in epochs, so the schedule advances once per epoch.
    scheduler.step()

    model.eval()
    tqdm_loader = tqdm(dataloader_valid)
    for X, y, task_ids in tqdm_loader:
      X = X.to(cfg.device)

      # No gradients are needed for validation.
      with torch.no_grad():
        y_pred = model(X)
      
