In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from matplotlib import pyplot as plt
import numpy as np
import random
import os
import time
from datetime import datetime
import copy

In [2]:
is_cuda = torch.cuda.is_available()
# The GPU ordinal is configurable via the CUDA_DEVICE_INDEX environment
# variable (default 7, the previously hard-coded value). If that ordinal does
# not exist on this machine, fall back to cuda:0 rather than failing later at
# the first `.to(device)` call.
_requested_gpu = int(os.environ.get('CUDA_DEVICE_INDEX', '7'))
if is_cuda and not 0 <= _requested_gpu < torch.cuda.device_count():
    _requested_gpu = 0
device = torch.device(f'cuda:{_requested_gpu}' if is_cuda else 'cpu')

print('Current cuda device is', device)

Current cuda device is cuda:7


In [3]:
# Per-channel CIFAR-100 statistics (mean, std) used to standardise inputs.
normalize = transforms.Normalize(
    mean=[0.5071, 0.4867, 0.4408],
    std=[0.2675, 0.2565, 0.2761],
)

# Evaluation pipeline: tensor conversion + normalisation only (no randomness).
transform_test = transforms.Compose([transforms.ToTensor(), normalize])

# Training pipeline: pad-by-4 random 32x32 crop and a random horizontal flip,
# followed by the same tensor conversion + normalisation used at test time.
transform_train = transforms.Compose(
    [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    + [transforms.ToTensor(), normalize]
)

In [5]:
# Both splits come from the same pre-downloaded root (download=False); only
# the split flag and the transform pipeline differ. Train is built first.
train_data, test_data = (
    datasets.CIFAR100(root='../CIFAR-100/data/02/',
                      train=is_train,
                      download=False,
                      transform=split_transform)
    for is_train, split_transform in ((True, transform_train),
                                      (False, transform_test))
)
print('number of training data : ', len(train_data))
print('number of test data : ', len(test_data))

number of training data :  50000
number of test data :  10000


In [6]:
def set_seed(seed):
    """Seed every RNG this notebook touches and request deterministic cuDNN.

    Seeds Python's `random`, NumPy's global RNG, torch's CPU generator and
    torch's CUDA generators (current device and all devices). It then sets
    cuDNN to deterministic mode and turns cuDNN benchmarking off.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed,
               torch.cuda.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

In [7]:
num_task = 10
num_class = 100
num_class_per_task = num_class // num_task
train_task = {x: [] for x in range(num_task)}
test_task = {x: [] for x in range(num_task)}

# Per-class lists of dataset indices, in ascending index order.
train_class_idx = {x: [] for x in range(num_class)}
test_class_idx = {x: [] for x in range(num_class)}

# Read labels straight from the dataset's `targets` list. Iterating the
# dataset itself calls __getitem__, which runs the whole transform pipeline
# (including random crop/flip for the training split) on every image only to
# throw the image away.
for idx, label in enumerate(train_data.targets):
    train_class_idx[label].append(idx)

print(len(train_class_idx[0]))

for idx, label in enumerate(test_data.targets):
    test_class_idx[label].append(idx)

# Task i owns classes [i*num_class_per_task, (i+1)*num_class_per_task);
# indices are kept class-contiguous, in class order.
for i in range(num_task):
    task_classes = range(i * num_class_per_task, (i + 1) * num_class_per_task)
    curr_task_idx_train = [j for c in task_classes for j in train_class_idx[c]]
    curr_task_idx_test = [j for c in task_classes for j in test_class_idx[c]]
    # Lazy view: the random training augmentation is re-sampled on every
    # access instead of being applied once and frozen into a list.
    train_task[i] = torch.utils.data.Subset(train_data, curr_task_idx_train)
    # The test transform is deterministic, so materialising once is safe and
    # avoids re-decoding images on every evaluation pass.
    test_task[i] = [test_data[j] for j in curr_task_idx_test]

500


In [8]:
# Sanity check: size of two tasks and of one class, one value per line.
print(len(train_task[0]), len(train_task[4]), len(train_class_idx[0]), sep='\n')

5000
5000
500


# Experience Replay

In [9]:
# Optimisation settings; the SGD optimiser and LR schedule are re-created for
# every task in the training cell.
batch_size, learning_rate, num_epochs = 128, 0.1, 200
# Validation hold-out per task (split evenly across the task's classes) and
# total replay-buffer capacity shared by all seen classes.
val_size, total_memory_size = 100, 2000
# Method tag used as a prefix in saved checkpoint / index file names.
method_type = 'ER'

In [None]:
from torch.optim.lr_scheduler import MultiStepLR
from model import resnet32
from tqdm import tqdm


def _split_valid(class_sorted_idx, n_classes, per_class_valid):
    """Hold out `per_class_valid` random indices from each class.

    `class_sorted_idx` is expected to list the indices of `n_classes` classes
    back to back with the same count per class (as built from
    `train_class_idx`). Returns `(train_idx, valid_idx)`. `train_idx` keeps
    the original class-contiguous order with the held-out indices removed.
    """
    per_class = len(class_sorted_idx) // n_classes
    valid_idx = []
    for c in range(n_classes):
        valid_idx += random.sample(class_sorted_idx[c * per_class:(c + 1) * per_class],
                                   per_class_valid)
    # A set keeps this filter linear instead of O(len(train) * len(valid)).
    held_out = set(valid_idx)
    train_idx = [x for x in class_sorted_idx if x not in held_out]
    return train_idx, valid_idx


def _rebuild_buffer(old_buffer_idx, new_idx, n_known, n_total, memory_size):
    """Return a class-balanced replay buffer with `memory_size // n_total` indices per class.

    Old classes are subsampled from `old_buffer_idx`, assumed to hold
    `memory_size // n_known` contiguous indices per class, as produced by the
    previous call. New classes are sampled from `new_idx`, which must be
    class-contiguous with an equal count per class. Classes are emitted in
    order 0..n_total-1.
    """
    per_class_new = len(new_idx) // (n_total - n_known)
    per_class_buffer = memory_size // n_total
    per_class_prev = memory_size // n_known if n_known > 0 else 0
    new_buffer = []
    for c in range(n_total):
        if c < n_known:
            pool = old_buffer_idx[c * per_class_prev:(c + 1) * per_class_prev]
        else:
            k = c - n_known
            pool = new_idx[k * per_class_new:(k + 1) * per_class_new]
        new_buffer += random.sample(pool, per_class_buffer)
    return new_buffer


def _task_accuracy(net, dataset, n_visible_classes):
    """Top-1 accuracy (Python float) of `net` on `dataset`.

    Only the first `n_visible_classes` logits are considered. Uses the global
    `batch_size` and `device`.
    """
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            prediction = net(data)[:, :n_visible_classes].argmax(dim=1)
            correct += prediction.eq(target).sum().item()
    return correct / len(dataset)


print(datetime.now())
acc_list = []      # average accuracy over all seen tasks, appended after each task
forget_list = []   # NOTE(review): declared but never filled in this cell
task_order = np.arange(num_task)
milestones = [100, 150]
start_time = time.time()
PATH_model = './saved_model/'
PATH_buffer = './saved_buffer_indices/'
# torch.save / np.save do not create missing parent directories.
os.makedirs(PATH_model, exist_ok=True)
os.makedirs(PATH_buffer, exist_ok=True)

for seed in [1]:
    set_seed(seed)
    model = resnet32().to(device)
    criterion = nn.CrossEntropyLoss()
    final_avg_acc = 0
    buffer_idx = []  # replay-buffer indices into train_data; empty before task 0

    for t in task_order:
        print(f"---task {t:d}---")
        optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=2e-4)
        scheduler = MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=0.1)
        num_total_class = (t + 1) * num_class_per_task
        num_known_class = t * num_class_per_task

        # Class-contiguous indices of this task's classes, minus a per-class validation hold-out.
        new_task_data_idx = [j for c in range(num_known_class, num_total_class)
                             for j in train_class_idx[c]]
        new_task_data_idx, valid_idx = _split_valid(new_task_data_idx, num_class_per_task,
                                                     val_size // num_class_per_task)
        new_task_data = torch.utils.data.Subset(train_data, new_task_data_idx)
        print(f"train_size: {len(new_task_data)}, val_size: {len(valid_idx)}")

        # From the second task on, replay the buffer alongside the new data.
        train_set = new_task_data if t == 0 else new_task_data + buffer
        train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size,
                                                   shuffle=True, num_workers=4)

        for epoch in tqdm(range(num_epochs)):
            model.train()
            for data, target in train_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                # Only logits of classes introduced so far take part in the loss.
                loss = criterion(output[:, :num_total_class], target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            scheduler.step()

        torch.save(model.state_dict(), PATH_model + f'{method_type}_{seed}_seed_{num_task}_tasks_{val_size}_val_{t}_task.pt')

        buffer_idx = _rebuild_buffer(buffer_idx, new_task_data_idx, num_known_class,
                                     num_total_class, total_memory_size)
        buffer = torch.utils.data.Subset(train_data, buffer_idx)
        print(len(buffer))

        np.save(PATH_buffer + f'{method_type}_buffer_indices_{seed}_seed_{num_task}_tasks_{val_size}_val_{t}_task', np.array(buffer_idx))
        np.save(PATH_buffer + f'{method_type}_valid_indices_{seed}_seed_{num_task}_tasks_{val_size}_val_{t}_task', np.array(valid_idx))
        np.save(PATH_buffer + f'{method_type}_train_indices_{seed}_seed_{num_task}_tasks_{val_size}_val_{t}_task', np.array(new_task_data_idx))
        print(f"[Acc]")
        model.eval()
        avg_acc = 0
        for test_task_idx in range(t + 1):
            temp_acc = _task_accuracy(model, test_task[test_task_idx], num_total_class)
            avg_acc += temp_acc / (t + 1)
            print(f"task {test_task_idx:d}: {temp_acc*100:.2f}")
        acc_list.append(avg_acc)
        final_avg_acc += avg_acc / num_task
    print(f"SEED: {seed}, [Total] Acc: {final_avg_acc*100:.2f}\n")
print(f"runtime: {(time.time()-start_time)/60:.2f} mins")

2024-11-19 21:40:06.563346
---task 0---
train_size: 4900, val_size: 100


100%|███████████████████████████████████████████████████████████████████████████████████| 200/200 [02:36<00:00,  1.28it/s]


2000
[Acc]
task 0: 87.80
---task 1---
train_size: 4900, val_size: 100


100%|███████████████████████████████████████████████████████████████████████████████████| 200/200 [03:12<00:00,  1.04it/s]


2000
[Acc]
task 0: 73.90
task 1: 82.00
---task 2---
train_size: 4900, val_size: 100


100%|███████████████████████████████████████████████████████████████████████████████████| 200/200 [03:11<00:00,  1.05it/s]


1980
[Acc]
task 0: 61.30
task 1: 59.80
task 2: 87.10
---task 3---
train_size: 4900, val_size: 100


  1%|▊                                                                                    | 2/200 [00:02<03:30,  1.06s/it]

Val size: 0
SEED: 1, [Total] Acc: 56.75
SEED: 2, [Total] Acc: 56.38
SEED: 3, [Total] Acc: 56.47

In [None]:
import gc

# Collect unreachable Python objects first so any tensors they hold are freed.
# Only then can empty_cache() release the now-unoccupied cached CUDA blocks.
# Calling it before gc.collect() leaves those blocks cached.
gc.collect()
torch.cuda.empty_cache()