In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable
import numpy as np
from torch.utils.data import DataLoader
import torchvision

In [3]:
def seed_everything(seed: int = 42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(0)

In [4]:
print('STEP 1: LOADING DATASET')

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_dataset = dsets.CIFAR10(
    root='./data/CIFAR10/',
    train=True,
    transform=transform_train,
    download=True
)

test_dataset = dsets.CIFAR10(
    root='./data/CIFAR10/',
    train=False,
    transform=transform_test
)

STEP 1: LOADING DATASET
Files already downloaded and verified


In [5]:
def remove_samples(dataset, l=[]):
    targets = []
    data = []
    for d, t in zip(dataset.data, dataset.targets):
        if t in l:
            continue
        else:
            targets.append(t)
            data.append(d)
    dataset.targets = targets
    dataset.data = data
    return dataset

train_dataset = remove_samples(train_dataset, [5,6,7,8,9])

In [6]:
print('STEP 2: MAKING DATASET ITERABLE')
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=100,
    shuffle=False,
    num_workers=4,
)

STEP 2: MAKING DATASET ITERABLE


In [7]:
print('STEP 3: CREATE MODEL CLASS')
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, temp=1.0):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.temp = temp

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.fc(out) / self.temp  
        return out


def resnet18(temp=1.0, **kwargs):
    model = ResNet(BasicBlock, [2, 2, 2, 2], temp=temp, **kwargs)
    return model

STEP 3: CREATE MODEL CLASS


In [8]:
print('STEP 4: INSTANTIATE MODEL CLASS')
model = resnet18(num_classes=5)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

STEP 4: INSTANTIATE MODEL CLASS


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [9]:
print('STEP 5: INSTANTIATE LOSS CLASS')
criterion = nn.CrossEntropyLoss()

STEP 5: INSTANTIATE LOSS CLASS


In [10]:
print('STEP 6: INSTANTIATE OPTIMIZER CLASS')
num_epochs = 20

learning_rate = 1e-2
momentum = 0.9
weight_decay = 5e-4

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum = momentum, weight_decay = weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

STEP 6: INSTANTIATE OPTIMIZER CLASS


In [11]:
# Codes below were commented for simplicity of project(We are using pretrained weights)

# print('STEP 7: TRAIN THE MODEL')
# # Train
# for epoch in range(num_epochs):
#     model.train()
#     loss_avg = 0.
#     for i, (images, labels) in enumerate(train_loader):
#         images = images.to(device)
#         labels = labels.to(device)

#         outputs = model(images)

#         loss = criterion(outputs, labels)
        
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
#         loss_avg += loss.item()
        
#     scheduler.step()
    
#     print(f'Epochs: {epoch}. Loss: {loss_avg / len(train_loader):.5f}.')

In [12]:
# model.load_state_dict(torch.load('./resnet18_base_weight/resnet18_base.pth'))

In [13]:
# # Test
# class_name = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# correct = {f"{i}": 0 for i in range(10)}
# total = {f"{i}": 0 for i in range(10)}

# model.eval()
# with torch.no_grad():
#     for images, labels in test_loader:
#         images = images.to(device)
#         outputs = model(images)
#         _, predicted = torch.max(outputs.data, 1)
        
#         for label, prediction in zip(labels, predicted):
#             if label == prediction.cpu():
#                 correct[str(label.item())] += 1
#             total[str(label.item())] += 1
    
#     total_ = 0.
#     correct_ = 0.
#     for idx, correct_count in correct.items():
#         total_ += total[idx]
#         correct_ += correct_count
#     accuracy = 100 * float(correct_) / total_
#     print(f'Total Accuracy: {accuracy:3.1f} %')
    
#     total_ = 0.
#     correct_ = 0.
#     for idx, correct_count in correct.items():
#         if int(idx) in [0,1,2,3,4]:
#             total_ += total[idx]
#             correct_ += correct_count
#     try:
#         accuracy = 100 * float(correct_) / total_
#     except ZeroDivisionError:
#         print(f'Old classes Accuracy: 00.0 %')
#     else:
#         print(f'Old classes Accuracy: {accuracy:3.1f} %')
    
#     total_ = 0.
#     correct_ = 0.
#     for idx, correct_count in correct.items():
#         if int(idx) in [5,6,7,8,9]:
#             total_ += total[idx]
#             correct_ += correct_count
#     try:
#         accuracy = 100 * float(correct_) / total_
#     except ZeroDivisionError:
#         print(f'Novel classes Accuracy: 00.0 %')
#     else:
#         print(f'Novel classes Accuracy: {accuracy:3.1f} %')
    
#     print('Accuracy for class')
#     for idx, correct_count in correct.items():
#         accuracy = 100 * float(correct_count) / total[idx]
#         print(f'\t{class_name[int(idx)]:10s} : {accuracy:3.1f} %')

In [14]:
# torch.save(model.state_dict(), 'resnet18_base.pth')

In [15]:
# Create the ResNet model for 10 classes
model = resnet18(num_classes=10)

# Load pre-trained weights for the initial 5 classes
pretrained_weights_path = './resnet18_base_weight/resnet18_base.pth'
pretrained_dict = torch.load(pretrained_weights_path)

# Get the current state dict
model_dict = model.state_dict()

# Filter out the final layer weights from the pre-trained dict
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'fc' not in k}

# Update the current model dict with the pre-trained weights
model_dict.update(pretrained_dict)

# Load the updated state dict into the model
model.load_state_dict(model_dict)

# Manually initialize the final layer weights
with torch.no_grad():
    # Get the pre-trained weights for the original 5 classes
    pretrained_fc_weights = torch.load(pretrained_weights_path)['fc.weight']
    pretrained_fc_bias = torch.load(pretrained_weights_path)['fc.bias']

    # Initialize the first 5 weights with pre-trained values
    model.fc.weight[:5] = pretrained_fc_weights
    model.fc.bias[:5] = pretrained_fc_bias

    # Initialize the remaining 5 weights with Xavier initialization
    nn.init.xavier_uniform_(model.fc.weight[5:])
    model.fc.bias[5:].zero_()

# Move the model to the correct device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
old_model = model.to(device)

# Initialize the new model
new_model = resnet18(num_classes=10).to(device)
with torch.no_grad():
    nn.init.xavier_uniform_(new_model.fc.weight[5:])
    new_model.fc.bias[5:].zero_()

In [16]:
# Test, Checking process for whether weight load is properly done, used only for initialization purposes
# Set to "old_model"

# class_name = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# correct = {f"{i}": 0 for i in range(10)}
# total = {f"{i}": 0 for i in range(10)}

# old_model.eval()
# with torch.no_grad():
#     for images, labels in test_loader:
#         images = images.to(device)
#         outputs = old_model(images)
#         _, predicted = torch.max(outputs.data, 1)
        
#         for label, prediction in zip(labels, predicted):
#             if label == prediction.cpu():
#                 correct[str(label.item())] += 1
#             total[str(label.item())] += 1
    
#     total_ = 0.
#     correct_ = 0.
#     for idx, correct_count in correct.items():
#         total_ += total[idx]
#         correct_ += correct_count
#     accuracy = 100 * float(correct_) / total_
#     print(f'Total Accuracy: {accuracy:3.1f} %')
    
#     total_ = 0.
#     correct_ = 0.
#     for idx, correct_count in correct.items():
#         if int(idx) in [0,1,2,3,4]:
#             total_ += total[idx]
#             correct_ += correct_count
#     try:
#         accuracy = 100 * float(correct_) / total_
#     except ZeroDivisionError:
#         print(f'Old classes Accuracy: 00.0 %')
#     else:
#         print(f'Old classes Accuracy: {accuracy:3.1f} %')
    
#     total_ = 0.
#     correct_ = 0.
#     for idx, correct_count in correct.items():
#         if int(idx) in [5,6,7,8,9]:
#             total_ += total[idx]
#             correct_ += correct_count
#     try:
#         accuracy = 100 * float(correct_) / total_
#     except ZeroDivisionError:
#         print(f'Novel classes Accuracy: 00.0 %')
#     else:
#         print(f'Novel classes Accuracy: {accuracy:3.1f} %')
    
#     print('Accuracy for class')
#     for idx, correct_count in correct.items():
#         accuracy = 100 * float(correct_count) / total[idx]
#         print(f'\t{class_name[int(idx)]:10s} : {accuracy:3.1f} %')

In [17]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_dataset = dsets.CIFAR10(
    root='./data/CIFAR10/',
    train=True,
    transform=transform_train,
    download=True
)

# New dataset for 0~4
train_new_dataset = remove_samples(train_dataset, [0,1,2,3,4])

train_new_loader = torch.utils.data.DataLoader(
    dataset=train_new_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4
)

Files already downloaded and verified


In [18]:
# Evaluation set as a function, for simplicity in implementing and finding acc for future models

def evaluate_model(model, test_loader, device):
    class_name = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    correct = {str(i): 0 for i in range(10)}
    total = {str(i): 0 for i in range(10)}
    
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            
            for label, prediction in zip(labels, predicted):
                if label == prediction.cpu():
                    correct[str(label.item())] += 1
                total[str(label.item())] += 1
    
    # Calculate total accuracy
    total_count = sum(total.values())
    correct_count = sum(correct.values())
    total_accuracy = 100 * float(correct_count) / total_count
    print(f'Total Accuracy: {total_accuracy:.1f} %')
    
    # Calculate old classes accuracy
    old_classes = [0, 1, 2, 3, 4]
    old_total = sum(total[str(i)] for i in old_classes)
    old_correct = sum(correct[str(i)] for i in old_classes)
    if old_total > 0:
        old_accuracy = 100 * float(old_correct) / old_total
        print(f'Old classes Accuracy: {old_accuracy:.1f} %')
    else:
        print(f'Old classes Accuracy: 00.0 %')
    
    # Calculate novel classes accuracy
    novel_classes = [5, 6, 7, 8, 9]
    novel_total = sum(total[str(i)] for i in novel_classes)
    novel_correct = sum(correct[str(i)] for i in novel_classes)
    if novel_total != 0:
        novel_accuracy = 100 * float(novel_correct) / novel_total
        print(f'Novel classes Accuracy: {novel_accuracy:.1f} %')
    else:
        print(f'Novel classes Accuracy: 00.0 %')
    
    # Accuracy for each class
    print('Accuracy for each class:')
    for idx, correct_count in correct.items():
        accuracy = 100 * float(correct_count) / total[idx]
        print(f'\t{class_name[int(idx)]:10s} : {accuracy:.1f} %')


In [19]:
# Fisher Information matrix
def compute_fisher_information(model, dataloader, criterion, device):
    model.eval()
    fisher_information = {n: torch.zeros(p.shape).to(device) for n, p in model.named_parameters()}
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        model.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher_information[n] += p.grad.data.pow(2) * len(labels)
    fisher_information = {n: p / len(dataloader.dataset) for n, p in fisher_information.items()}
    return fisher_information

# EWC loss function
def ewc_loss(model, fisher_information, old_parameters, lambda_ewc):
    loss_ewc = 0
    for n, p in model.named_parameters():
        if n in fisher_information:
            loss_ewc += (fisher_information[n] * (p - old_parameters[n]).pow(2)).sum()
    return lambda_ewc * loss_ewc

In [21]:
# Distillation loss for LwF
def distillation_loss(y, teacher_scores, T, alpha):
    p = F.log_softmax(y / T, dim=1)
    q = F.softmax(teacher_scores / T, dim=1)
    l_kl = F.kl_div(p, q, reduction='batchmean') * (T * T * 2.0 * alpha)
    l_ce = F.cross_entropy(y, teacher_scores.argmax(dim=1)) * (1. - alpha)
    return l_kl + l_ce

# LwF Training
def train_lwf_ewc(student_model, teacher_model, train_loader, optimizer, criterion, fisher_information, old_parameters, lambda_ewc, alpha, T, device, epochs=20):
    student_model.train()
    teacher_model.eval()
    for epoch in range(epochs):
        total_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = student_model(images)
            
            # Get soft targets from the teacher model
            with torch.no_grad():
                teacher_outputs = teacher_model(images)
            
            # Compute the classification loss, distillation loss, and EWC loss
            loss_ce = criterion(outputs, labels)
            loss_kd = distillation_loss(outputs, teacher_outputs, T, alpha)
            loss_ewc = ewc_loss(student_model, fisher_information, old_parameters, lambda_ewc)
            
            loss = loss_ce + loss_kd + 100*loss_ewc
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader)}')
        evaluate_model(student_model, test_loader, device)

# Compute Fisher Information after training the old model
lambda_ewc = 0.4  # Regularization strength for EWC
old_model.eval()
old_parameters = {n: p.clone().detach() for n, p in old_model.named_parameters()}
fisher_information = compute_fisher_information(old_model, train_loader, nn.CrossEntropyLoss(), device)

# Fine-tune the new model with combined LwF and EWC
alpha = 0.1  # LwF distillation parameter
T = 2.0     # Temperature for distillation
optimizer = torch.optim.SGD(new_model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
train_lwf_ewc(new_model, old_model, train_new_loader, optimizer, nn.CrossEntropyLoss(), fisher_information, old_parameters, lambda_ewc, alpha, T, device, epochs=50)

Epoch 1/50, Loss: 4.971718323474028
Total Accuracy: 37.8 %
Old classes Accuracy: 32.9 %
Novel classes Accuracy: 42.7 %
Accuracy for each class:
	airplane   : 43.5 %
	automobile : 15.6 %
	bird       : 1.9 %
	cat        : 63.9 %
	deer       : 39.4 %
	dog        : 30.5 %
	frog       : 41.9 %
	horse      : 32.1 %
	ship       : 27.7 %
	truck      : 81.2 %
Epoch 2/50, Loss: 5.477641361100333
Total Accuracy: 29.9 %
Old classes Accuracy: 20.1 %
Novel classes Accuracy: 39.6 %
Accuracy for each class:
	airplane   : 7.4 %
	automobile : 12.3 %
	bird       : 0.0 %
	cat        : 56.8 %
	deer       : 24.0 %
	dog        : 9.7 %
	frog       : 52.6 %
	horse      : 12.4 %
	ship       : 86.8 %
	truck      : 36.6 %
Epoch 3/50, Loss: 4.644251704216003
Total Accuracy: 38.4 %
Old classes Accuracy: 40.2 %
Novel classes Accuracy: 36.7 %
Accuracy for each class:
	airplane   : 55.9 %
	automobile : 48.6 %
	bird       : 4.4 %
	cat        : 68.0 %
	deer       : 24.2 %
	dog        : 28.0 %
	frog       : 57.9 %
	horse

In [22]:
torch.save(new_model.state_dict(), 'new_model.pth')