In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from matplotlib import pyplot as plt
import numpy as np
import random
import os
import time
from datetime import datetime
import copy

## Set device

In [4]:
def select_device(preferred_index=7):
    """Return the torch device to run on.

    Uses ``cuda:<preferred_index>`` when CUDA is available and that index exists.
    Falls back to ``cuda:0`` when the machine has fewer GPUs (or the index is
    negative), and to CPU when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    if 0 <= preferred_index < torch.cuda.device_count():
        return torch.device(f'cuda:{preferred_index}')
    return torch.device('cuda:0')


is_cuda = torch.cuda.is_available()
# The GPU index can be overridden without editing the notebook, e.g. CUDA_DEVICE_INDEX=0.
device = select_device(int(os.environ.get('CUDA_DEVICE_INDEX', 7)))
print('Current cuda device is', device)
torch.set_num_threads(2)

Current cuda device is cuda:7


## Data preprocessing

In [5]:
# Per-channel mean / std used to standardise CIFAR-100 inputs after ToTensor().
CIFAR100_MEAN = [0.5071, 0.4867, 0.4408]
CIFAR100_STD = [0.2675, 0.2565, 0.2761]
normalize = transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD)

# Train and test share the same deterministic pipeline. Data augmentation
# (RandomCrop(32, padding=4) and RandomHorizontalFlip) is disabled in this notebook.
transform_train = transforms.Compose([transforms.ToTensor(), normalize])
transform_test = transforms.Compose([transforms.ToTensor(), normalize])

In [7]:
# Both splits live under the same root; with download=True torchvision fetches
# the archive only if it is missing (otherwise it just verifies it).
CIFAR100_ROOT = '../CIFAR-100/data/02/'


def _load_cifar100(train, transform):
    """Load one CIFAR-100 split from CIFAR100_ROOT with the given transform."""
    return datasets.CIFAR100(root=CIFAR100_ROOT, train=train, download=True, transform=transform)


train_data = _load_cifar100(True, transform_train)
test_data = _load_cifar100(False, transform_test)

for caption, split in (('number of training data : ', train_data),
                       ('number of test data : ', test_data)):
    print(caption, len(split))

Files already downloaded and verified
Files already downloaded and verified
number of training data :  50000
number of test data :  10000


In [8]:
num_task = 10
num_class = 100
assert num_class % num_task == 0, 'num_class must split evenly across num_task tasks'
per_num_class = num_class // num_task


def _dataset_labels(dataset):
    """Return the integer label of every sample in ``dataset``, in index order.

    torchvision's CIFAR datasets expose ``.targets``. Reading it avoids decoding
    and transforming every image just to obtain its label. Datasets without
    ``.targets`` fall back to iterating the samples. Note that ``.targets`` does
    not apply a ``target_transform``; none is used in this notebook.
    """
    targets = getattr(dataset, 'targets', None)
    if targets is not None:
        return [int(y) for y in targets]
    return [int(y) for _, y in dataset]


def _indices_by_class(labels, n_classes):
    """Map each class id in ``range(n_classes)`` to the dataset indices carrying that label."""
    by_class = {c: [] for c in range(n_classes)}
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    return by_class


train_class_idx = _indices_by_class(_dataset_labels(train_data), num_class)
test_class_idx = _indices_by_class(_dataset_labels(test_data), num_class)

print(len(train_class_idx[0]))

# Task i owns the contiguous class range [i*per_num_class, (i+1)*per_num_class).
# Each entry is a list of materialised (transformed) (x, y) samples, grouped by class.
# NOTE(review): train_task is not used by the cells shown here and holds every
# transformed training image in memory — consider dropping it if nothing else needs it.
train_task = {}
test_task = {}
for i in range(num_task):
    task_classes = range(i * per_num_class, (i + 1) * per_num_class)
    train_task[i] = [train_data[j] for c in task_classes for j in train_class_idx[c]]
    test_task[i] = [test_data[j] for c in task_classes for j in test_class_idx[c]]

500


# T-CIL: calibration evaluation (ECE / AECE) per task

In [18]:
# from utils import evaluation, tune_temp, tune_temp_batch, adaptive_evaluation, tune_temp_batch_efficient
from utils import cal_ece, cal_aece, tune_temp_batch_efficient, set_seed
from TCIL import AdversarialTrainer, find_optimal_epsilon
from model import resnet32
# Experiment configuration. method_type, the seed, num_tasks and val_size are all
# part of the checkpoint / buffer-index file names loaded by the evaluation cell,
# so they must match the training run that produced those files.
method_type = 'ER'
num_tasks, num_classes = 10, 100
num_class_per_task = num_classes // num_tasks
batch_size = 128
epochs = 100       # passed as `epochs` to the temperature-tuning / epsilon-search helpers
val_size = 100
seed_list = [1]    # full multi-seed run: [1, 2, 3, 4, 5]

In [22]:
# For each seed and each task t: load the artifacts saved after training on task t,
# choose best_epsilon with find_optimal_epsilon, build adversarial data from the
# replay buffer, obtain a temperature from trainer.get_temperature on that data, and
# report ECE / AECE on the union of test sets of tasks 0..t. Metrics are averaged
# over tasks per seed, then mean/std are taken over seeds.

# Labels of all training samples. They are used to find current-task samples in the buffer
# without decoding/transforming each buffered image (torchvision CIFAR datasets
# expose `.targets`).
train_targets = np.asarray(train_data.targets)

all_seeds_ece_overall = []
all_seeds_ece_task = []   # NOTE(review): not populated in this cell
all_seeds_aece_overall = []
all_seeds_aece_task = []  # NOTE(review): not populated in this cell

for seed in seed_list:
    print(f"\n=== Running experiment with seed {seed} ===")
    ece_overall_hist = []
    ece_task_hist = []    # NOTE(review): not populated in this cell
    aece_overall_hist = []
    aece_task_hist = []   # NOTE(review): not populated in this cell
    set_seed(seed)
    for t in range(num_tasks):
        num_seen_classes = (t + 1) * num_class_per_task

        # Union of the test sets of tasks 0..t. A fresh list is built each time, so
        # test_task[0] is never extended in place. The old deepcopy protected only
        # task 0; tasks 1..t were already shared by reference.
        test_task_total = [sample for i in range(t + 1) for sample in test_task[i]]

        # File-name suffix of the artifacts saved for this seed/task. It uses num_tasks
        # from the config cell (not num_task from the data cell) so the two values
        # cannot silently drift apart.
        run_tag = f'{seed}_seed_{num_tasks}_tasks_{val_size}_val_{t}_task'

        model = resnet32().to(device)
        # map_location lets checkpoints saved on another GPU (or on GPU, when running
        # on CPU) load onto the current `device`.
        state_dict = torch.load(f'./saved_model/{method_type}_{run_tag}.pt',
                                map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()

        buffer_indices = np.load(f'./saved_buffer_indices/{method_type}_buffer_indices_{run_tag}.npy')
        buffer = torch.utils.data.Subset(train_data, buffer_indices)
        valid_new_task_indices = np.load(f'./saved_buffer_indices/{method_type}_valid_indices_{run_tag}.npy')
        valid_new_task = torch.utils.data.Subset(train_data, valid_new_task_indices)
        temperature_new_task_opt = tune_temp_batch_efficient(
            model, valid_new_task, num_seen_classes, epochs, batch_size, device).item()

        # Positions j within `buffer` whose class belongs to the current task t.
        # buffer[j] is train_data[buffer_indices[j]], so its label is train_targets[buffer_indices[j]].
        buffer_labels = train_targets[np.asarray(buffer_indices)]
        new_task_idx = np.flatnonzero(buffer_labels // num_class_per_task == t).tolist()
        new_task_data = torch.utils.data.Subset(buffer, new_task_idx)

        trainer = AdversarialTrainer(model, device, method_type)

        best_epsilon = find_optimal_epsilon(
            trainer=trainer,
            buffer_data=buffer,
            valid_data=new_task_data,
            target_temp=temperature_new_task_opt,
            num_class=num_seen_classes,
            num_task=t + 1,
            epochs=epochs,
            batch_size=batch_size
        )

        adv_data, labels = trainer.generate_adversarial_data(
            buffer, buffer, num_seen_classes, batch_size, best_epsilon, num_class_per_task)
        adv_dataset = torch.utils.data.TensorDataset(adv_data, labels)
        adv_loader = torch.utils.data.DataLoader(adv_dataset, batch_size=batch_size, shuffle=True)

        temperature = trainer.get_temperature(adv_loader, num_seen_classes, t + 1, epochs, batch_size).item()

        ece_overall, Bm, acc, conf, _, _ = cal_ece(model, test_task_total, num_seen_classes, num_class_per_task,
                                                   batch_size, n_bins=10, temperature=temperature, device=device)
        aece_overall, _, _, _ = cal_aece(model, test_task_total, num_seen_classes, num_class_per_task,
                                         batch_size, n_bins=10, temperature=temperature, device=device)
        ece_overall_hist.append(ece_overall)
        aece_overall_hist.append(aece_overall)
        print(f"[Task {t}] ece_overall: {ece_overall*100:.2f}, aece_overall: {aece_overall*100:.2f}")

    print(f"[Seed {seed} Avg] ece_overall: {np.mean(ece_overall_hist)*100:.2f}, aece_overall: {np.mean(aece_overall_hist)*100:.2f}")
    all_seeds_ece_overall.append(np.mean(ece_overall_hist)*100)
    all_seeds_aece_overall.append(np.mean(aece_overall_hist)*100)

# np.std is the population std (ddof=0); it is 0 when only one seed is run.
final_ece_overall_mean = np.mean(all_seeds_ece_overall)
final_ece_overall_std = np.std(all_seeds_ece_overall)

final_aece_overall_mean = np.mean(all_seeds_aece_overall)
final_aece_overall_std = np.std(all_seeds_aece_overall)

print("\n=== Final Results ===")
print(f"ECE Overall - Mean: {final_ece_overall_mean:.2f}%, Std: {final_ece_overall_std:.2f}%, AECE Overall - Mean: {final_aece_overall_mean:.2f}%, Std: {final_aece_overall_std:.2f}%")


=== Running experiment with seed 1 ===
[Task 0] ece_overall: 3.01, aece_overall: 2.64
[Task 1] ece_overall: 11.11, aece_overall: 11.10
[Task 2] ece_overall: 5.66, aece_overall: 5.88
[Task 3] ece_overall: 4.49, aece_overall: 4.67
[Task 4] ece_overall: 3.43, aece_overall: 3.48
[Task 5] ece_overall: 4.32, aece_overall: 4.34
[Task 6] ece_overall: 5.90, aece_overall: 5.73
[Task 7] ece_overall: 3.92, aece_overall: 4.08
[Task 8] ece_overall: 6.19, aece_overall: 6.19
[Task 9] ece_overall: 9.35, aece_overall: 9.35
[Seed 1 Avg] ece_overall: 5.74, aece_overall: 5.75

=== Final Results ===
ECE Overall - Mean: 5.74%, Std: 0.00%, AECE Overall - Mean: 5.75%, Std: 0.00%
