In [1]:
import os
import time
import math
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from tqdm import tqdm
from torch import device
from torchvision import datasets
from torch.utils.data import DataLoader, Subset

from models import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# DEFINE SETTINGS

In [2]:
out_path = os.path.join('..', 'results', 'mnist') # remove or add 'convs'
strategy = "classIL"     # ["taskIL", "classIL"] classIL is harder 
lrs = [1e-2]
decays = [0.8]
epochss = [10]
models = [Efficient_KAN_Fix(strategy, device)]
longer_last_tasks = False
reverse_taks = False

reverse_path = ""
if reverse_taks:
    reverse_path = "reverse_tasks"
longer_last_path = ""
if longer_last_tasks:
    longer_last_path = "longer_last_tasks"

out_path = os.path.join(out_path, strategy, longer_last_path, reverse_path, 'trainings')
cfgs = []
for model in models[:1]:
    for lr in lrs:
        for decay in decays:
            for epochs in epochss:
                cfgs.append([model, epochs, lr, decay])

# Train and test sets

In [3]:
dataset = [datasets.MNIST, datasets.CIFAR10][0]
dataset_name = dataset.__name__.lower()
input_size = 28 * 28 if dataset == datasets.MNIST \
    else 3 * 32 * 32 if dataset == datasets.CIFAR10 \
    else -1

In [4]:
transform = transforms.Compose([transforms.ToTensor(),
                                # transforms.Normalize((0.5,), (0.5,))
                                ])
# Train set. Here we sort the MNIST by digits and disable data shuffling
train_dataset = dataset(root='../data', train=True, download=True, transform=transform)
sorted_indices = sorted(range(len(train_dataset) // 1), key=lambda idx: train_dataset.targets[idx])
train_subset = Subset(train_dataset, sorted_indices)
train_loader = DataLoader(train_subset, batch_size=64, shuffle=False)

# MultiTask training sets
train_loader_tasks = []
indices = []
for k in range(5):
    indices.append(list(
        filter(lambda idx: train_dataset.targets[idx] in range(k * 2, k * 2 + 2), range(len(train_dataset)))))
    train_loader_tasks.append(
        DataLoader(Subset(train_dataset, indices[-1]), batch_size=64, shuffle=True))

# Test set
test_dataset = dataset(root='../data', train=False, download=True, transform=transform)
test_subset = Subset(test_dataset, range(len(test_dataset) // 1))
test_loader = DataLoader(test_subset, batch_size=64, shuffle=False)

if reverse_taks:
    train_loader_tasks.reverse()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:15<00:00, 634492.07it/s] 


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 101712.56it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:05<00:00, 297456.45it/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 12641359.50it/s]


Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



In [5]:
# stats = [0 for i in range(10)]
# for sample in test_dataset:
#     stats[sample[1]] += 1
# print(stats)
# mean = sum(stats)/len(stats)
# variance = sum([((x - mean) ** 2) for x in stats]) / len(stats) 
# res = variance ** 0.5
# print(mean, res)

## Trainset visualizer
The following code prints the images of the 5 domain IL scenarios. This way we can clearly see that for the MNIST dataset each task contains a pair of digits (0-1, 2-3, etc.), while for CIFAR10 each task contains a pair of objects (car-airplane, bird-dog, deer-dog, frog-horse and truck-ship).

In [6]:
# import numpy as np
# def imshow(img):
#     # img = (img / 2 + 0.5).numpy()
#     img = img.numpy()
#     plt.imshow(np.transpose(img, (1, 2, 0)))
#     plt.axis('off')
#     plt.show()


# def show_images(class_index, num_images=16):
#     dataiter = iter(train_loader_tasks[class_index])
#     images, labels = next(dataiter)
#     imshow(utils.make_grid(images))


# for class_index in range(5):
#     print(f"TASK ID = {class_index}")
#     show_images(class_index)

# Train and test functions

In [7]:
def train(model, save_dir, optimizer, lr, on_epoch_end, start_epoch=0, epochs=5, isKAN=False):
    criterion = nn.NLLLoss()
    for epoch in range(start_epoch, epochs + start_epoch):
        if not isKAN:
            model.train()
            model.to(device)
        epoch_start = time.time_ns()
        with tqdm(train_loader) as pbar:
            for images, labels in pbar:
                labels = labels.to(device)
                images = images.to(device)
                optimizer.zero_grad()
                output = model(images)
                loss = criterion(output, labels)
                loss.backward()
                optimizer.step(closure=lambda: loss)
                accuracy = (output.argmax(dim=1) == labels).float().mean()
                pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])
        print(f'Epoch {epoch + 1}, Loss: {loss.item()}')
        epoch_duration = (time.time_ns() - epoch_start) // 1000000
        if on_epoch_end is not None:
            on_epoch_end(model, save_dir, epoch, loss.item(), epoch_duration, lr, isKAN)

In [8]:
def test(model, isKAN=False):
    if not isKAN:
        model.eval()
    criterion = nn.NLLLoss()
    predictions = []
    ground_truths = []
    val_accuracy = 0
    loss = 0
    with torch.no_grad():
        for images, labels in test_loader:
            labels = labels.to(device)  #(labels % 2 if model.layers[-1] == 2 else labels).to(device)
            images = images.to(device)
            output = model(images)
            loss = criterion(output, labels)
            predictions.extend(output.argmax(dim=1).to('cpu').numpy())
            ground_truths.extend(labels.to('cpu').numpy())
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_accuracy /= len(test_loader)
    print(f"Accuracy: {val_accuracy}")
    return loss.item(), ground_truths, predictions

In [9]:
class EpochStat:
    @staticmethod
    def loadModelStats(dir, name, subdir) -> list['EpochStat']:
        return sorted([pickle.load(open(os.path.join(dir, subdir, file), 'rb')) for file in
                       filter(lambda e: name == '_'.join(e.split('_')[:-1]), os.listdir(os.path.join(dir, subdir)))],
                      key=lambda e: e.epoch)

    def __init__(self, name, save_dir, epoch, train_loss=0, test_loss=0, labels=None, predictions=None, epoch_duration=0, lr=0):
        self.name = name
        self.save_dir = save_dir
        self.train_loss = train_loss
        self.test_loss = test_loss
        self.epoch = epoch
        self.predictions = predictions
        self.labels = labels
        self.epoch_duration = epoch_duration
        self.lr = lr
        self.train_losses = []
        self.train_accuracies = []

    def save(self):
        os.makedirs(self.save_dir, exist_ok=True)
        pickle.dump(self, open(os.path.join(self.save_dir, self.name + '_e' + str(self.epoch) + '.pickle'), 'wb'))

    def get_accuracy(self):
        accuracy = 0
        for label, prediction in zip(self.labels, self.predictions):
            if label == prediction:
                accuracy += 1
        return accuracy / len(self.labels)


def onEpochEnd(model, save_dir, epoch, train_loss, epoch_duration, lr, isKAN):
    test_loss, labels, predictions = test(model, isKAN)
    stat = EpochStat(model.__class__.__name__, save_dir, epoch, train_loss, test_loss, labels, predictions, epoch_duration, lr)
    stat.save()

# Domain IL - training

In [10]:
for cfg in cfgs:
    model = cfg[0]
    epochs = cfg[1]
    lr = cfg[2]
    decay_f = cfg[3]
    if decay_f == 1:
        lr_decay = False
    else:
        lr_decay = True
    start_epochs_list = [int(epochs + epochs*i[0]) for i in enumerate(train_loader_tasks)]
    start_epochs_list.insert(0, 0)
    naam = model.__class__.__name__
    isKAN = False
    print("\n\n", naam)
    print(epochs, lr, decay_f, "\n")
    if 'Py_KAN' in naam:
        isKAN = True
    for i, task in enumerate(train_loader_tasks):
        epochs_act = epochs
        if longer_last_tasks and i > 3:
            epochs_act = epochs + epochs

        str_print = f'\t\t\t\tTRAINING ON TASK {i}'
        str_print +=  f' for {epochs_act} epochs' 
        print(str_print)
        # str_epoch = f"ep{epochs}_10fin_"
        str_epoch = f"ep{epochs}"
        str_lr = f"_lr{round(math.log10(lr))}"
        str_decay = '_dec'+ str(decay_f) if lr_decay else ''
        lr_act = lr * decay_f**(i)
        train(model, os.path.join(out_path,f"{str_epoch}{str_lr}{str_decay}", naam), optimizer=optim.Adam(model.parameters()),
                lr=lr_act, on_epoch_end=onEpochEnd, start_epoch=start_epochs_list[i], epochs=epochs_act, isKAN=isKAN)
    torch.cuda.empty_cache()



 Efficient_KAN_Fix
10 0.01 0.8 

				TRAINING ON TASK 0 for 10 epochs


100%|██████████| 938/938 [00:13<00:00, 68.37it/s, accuracy=1, loss=0.0245, lr=0.001]   


Epoch 1, Loss: 0.024505821268227493
Accuracy: 0.10071656050955415


100%|██████████| 938/938 [00:12<00:00, 77.35it/s, accuracy=1, loss=0.024, lr=0.001]     


Epoch 2, Loss: 0.024010357680064084
Accuracy: 0.11146496815286625


100%|██████████| 938/938 [00:12<00:00, 72.99it/s, accuracy=1, loss=0.0197, lr=0.001]    


Epoch 3, Loss: 0.019666902333475355
Accuracy: 0.19098328025477707


100%|██████████| 938/938 [00:11<00:00, 79.15it/s, accuracy=1, loss=0.0196, lr=0.001]    


Epoch 4, Loss: 0.019637820211908588
Accuracy: 0.26383359872611467


100%|██████████| 938/938 [00:11<00:00, 78.31it/s, accuracy=0.969, loss=0.0303, lr=0.001]


Epoch 5, Loss: 0.030268004302943997
Accuracy: 0.28174761146496813


100%|██████████| 938/938 [00:12<00:00, 77.75it/s, accuracy=0.969, loss=0.0377, lr=0.001]


Epoch 6, Loss: 0.03773374510626884
Accuracy: 0.2864251592356688


100%|██████████| 938/938 [00:11<00:00, 78.96it/s, accuracy=0.969, loss=0.0538, lr=0.001]


Epoch 7, Loss: 0.05384498206819273
Accuracy: 0.3054339171974522


100%|██████████| 938/938 [00:11<00:00, 79.89it/s, accuracy=0.969, loss=0.0536, lr=0.001]


Epoch 8, Loss: 0.0535617967265714
Accuracy: 0.3213574840764331


100%|██████████| 938/938 [00:11<00:00, 80.63it/s, accuracy=1, loss=0.0213, lr=0.001]    


Epoch 9, Loss: 0.021263594637896004
Accuracy: 0.34225716560509556


100%|██████████| 938/938 [00:11<00:00, 79.66it/s, accuracy=0.969, loss=0.0237, lr=0.001]


Epoch 10, Loss: 0.02365722935065369
Accuracy: 0.3231488853503185
				TRAINING ON TASK 1 for 10 epochs


100%|██████████| 938/938 [00:11<00:00, 80.26it/s, accuracy=1, loss=0.00747, lr=0.001]   


Epoch 11, Loss: 0.0074746906638473625
Accuracy: 0.28244426751592355


100%|██████████| 938/938 [00:12<00:00, 75.37it/s, accuracy=1, loss=0.00376, lr=0.001]   


Epoch 12, Loss: 0.0037568708044690667
Accuracy: 0.3476313694267516


100%|██████████| 938/938 [00:11<00:00, 79.32it/s, accuracy=1, loss=0.00123, lr=0.001]   


Epoch 13, Loss: 0.001225815626377133
Accuracy: 0.36186305732484075


100%|██████████| 938/938 [00:11<00:00, 78.28it/s, accuracy=1, loss=0.000461, lr=0.001]  


Epoch 14, Loss: 0.00046126564874150207
Accuracy: 0.3335987261146497


100%|██████████| 938/938 [00:11<00:00, 82.30it/s, accuracy=1, loss=0.000994, lr=0.001]  


Epoch 15, Loss: 0.0009944117102907162
Accuracy: 0.33767914012738853


100%|██████████| 938/938 [00:11<00:00, 82.18it/s, accuracy=1, loss=0.000205, lr=0.001]  


Epoch 16, Loss: 0.0002053679657114599
Accuracy: 0.1979498407643312


100%|██████████| 938/938 [00:11<00:00, 81.93it/s, accuracy=0.969, loss=0.0278, lr=0.001]


Epoch 17, Loss: 0.027781268024435483
Accuracy: 0.3271297770700637


100%|██████████| 938/938 [00:11<00:00, 83.50it/s, accuracy=1, loss=6.98e-5, lr=0.001]   


Epoch 18, Loss: 6.980861427650621e-05
Accuracy: 0.29926353503184716


100%|██████████| 938/938 [00:11<00:00, 82.87it/s, accuracy=1, loss=0.000156, lr=0.001]  


Epoch 19, Loss: 0.00015636608544214507
Accuracy: 0.3919187898089172


100%|██████████| 938/938 [00:11<00:00, 82.76it/s, accuracy=1, loss=8.24e-5, lr=0.001]   


Epoch 20, Loss: 8.244020851881166e-05
Accuracy: 0.35509554140127386
				TRAINING ON TASK 2 for 10 epochs


100%|██████████| 938/938 [00:11<00:00, 82.30it/s, accuracy=1, loss=5.99e-5, lr=0.001]   


Epoch 21, Loss: 5.987012963802092e-05
Accuracy: 0.35957404458598724


100%|██████████| 938/938 [00:11<00:00, 80.80it/s, accuracy=1, loss=0.000289, lr=0.001]  


Epoch 22, Loss: 0.00028946571751239854
Accuracy: 0.37798566878980894


100%|██████████| 938/938 [00:11<00:00, 82.34it/s, accuracy=1, loss=0.000543, lr=0.001]  


Epoch 23, Loss: 0.000543470594816005
Accuracy: 0.36345541401273884


100%|██████████| 938/938 [00:12<00:00, 77.56it/s, accuracy=1, loss=6.47e-5, lr=0.001]   


Epoch 24, Loss: 6.471557713456317e-05
Accuracy: 0.42237261146496813


100%|██████████| 938/938 [00:11<00:00, 82.43it/s, accuracy=1, loss=0.000349, lr=0.001]  


Epoch 25, Loss: 0.0003494618428742788
Accuracy: 0.4122213375796178


100%|██████████| 938/938 [00:11<00:00, 80.38it/s, accuracy=1, loss=0.0024, lr=0.001]    


Epoch 26, Loss: 0.002395448518947993
Accuracy: 0.4133160828025478


100%|██████████| 938/938 [00:11<00:00, 82.44it/s, accuracy=1, loss=9.77e-5, lr=0.001]   


Epoch 27, Loss: 9.770370702376249e-05
Accuracy: 0.3919187898089172


100%|██████████| 938/938 [00:11<00:00, 82.12it/s, accuracy=1, loss=0.000121, lr=0.001]  


Epoch 28, Loss: 0.00012131712129320144
Accuracy: 0.44874601910828027


100%|██████████| 938/938 [00:11<00:00, 82.38it/s, accuracy=1, loss=0.00027, lr=0.001]   


Epoch 29, Loss: 0.00026982137414204954
Accuracy: 0.46466958598726116


100%|██████████| 938/938 [00:11<00:00, 82.29it/s, accuracy=1, loss=8.21e-5, lr=0.001]   


Epoch 30, Loss: 8.207605366485048e-05
Accuracy: 0.48039410828025475
				TRAINING ON TASK 3 for 10 epochs


100%|██████████| 938/938 [00:11<00:00, 81.92it/s, accuracy=1, loss=9.79e-5, lr=0.001]   


Epoch 31, Loss: 9.78673220214209e-05
Accuracy: 0.4713375796178344


100%|██████████| 938/938 [00:11<00:00, 81.43it/s, accuracy=1, loss=0.000105, lr=0.001]  


Epoch 32, Loss: 0.00010478774259994228
Accuracy: 0.4850716560509554


100%|██████████| 938/938 [00:11<00:00, 78.35it/s, accuracy=1, loss=5.12e-5, lr=0.001]   


Epoch 33, Loss: 5.1174370216823454e-05
Accuracy: 0.4941281847133758


100%|██████████| 938/938 [00:11<00:00, 78.65it/s, accuracy=1, loss=2.04e-5, lr=0.001]   


Epoch 34, Loss: 2.0445477689004558e-05
Accuracy: 0.46725716560509556


100%|██████████| 938/938 [00:11<00:00, 81.74it/s, accuracy=1, loss=4.89e-5, lr=0.001]   


Epoch 35, Loss: 4.893855284956412e-05
Accuracy: 0.45203025477707004


100%|██████████| 938/938 [00:11<00:00, 80.48it/s, accuracy=1, loss=1.95e-6, lr=0.001]   


Epoch 36, Loss: 1.9494519673330454e-06
Accuracy: 0.45451831210191085


100%|██████████| 938/938 [00:11<00:00, 81.72it/s, accuracy=1, loss=0.000318, lr=0.001]  


Epoch 37, Loss: 0.0003178591494160477
Accuracy: 0.5157245222929936


100%|██████████| 938/938 [00:11<00:00, 82.09it/s, accuracy=1, loss=0.000114, lr=0.001]  


Epoch 38, Loss: 0.00011410422681981975
Accuracy: 0.5428941082802548


100%|██████████| 938/938 [00:11<00:00, 82.37it/s, accuracy=1, loss=9.34e-7, lr=0.001]   


Epoch 39, Loss: 9.340227530838193e-07
Accuracy: 0.5004976114649682


100%|██████████| 938/938 [00:11<00:00, 81.09it/s, accuracy=1, loss=1.66e-5, lr=0.001]   


Epoch 40, Loss: 1.6592956130455e-05
Accuracy: 0.543093152866242
				TRAINING ON TASK 4 for 10 epochs


100%|██████████| 938/938 [00:11<00:00, 82.55it/s, accuracy=1, loss=1.17e-7, lr=0.001]   


Epoch 41, Loss: 1.1702512043912908e-07
Accuracy: 0.5341361464968153


100%|██████████| 938/938 [00:11<00:00, 82.87it/s, accuracy=1, loss=9.82e-7, lr=0.001]   


Epoch 42, Loss: 9.821581516884862e-07
Accuracy: 0.5361265923566879


100%|██████████| 938/938 [00:11<00:00, 82.48it/s, accuracy=1, loss=6.18e-7, lr=0.001]   


Epoch 43, Loss: 6.178445517705236e-07
Accuracy: 0.5007961783439491


100%|██████████| 938/938 [00:11<00:00, 82.52it/s, accuracy=0.969, loss=0.196, lr=0.001] 


Epoch 44, Loss: 0.19581500143936803
Accuracy: 0.5338375796178344


100%|██████████| 938/938 [00:11<00:00, 82.74it/s, accuracy=1, loss=1.87e-6, lr=0.001]   


Epoch 45, Loss: 1.867426038542774e-06
Accuracy: 0.5256767515923567


100%|██████████| 938/938 [00:11<00:00, 82.91it/s, accuracy=1, loss=0.00154, lr=0.001]   


Epoch 46, Loss: 0.0015372908857892679
Accuracy: 0.5831011146496815


100%|██████████| 938/938 [00:11<00:00, 82.85it/s, accuracy=0.969, loss=0.0236, lr=0.001]


Epoch 47, Loss: 0.023603221462660123
Accuracy: 0.6330613057324841


100%|██████████| 938/938 [00:11<00:00, 82.71it/s, accuracy=1, loss=0.00391, lr=0.001]   


Epoch 48, Loss: 0.0039100999339592764
Accuracy: 0.6398288216560509


100%|██████████| 938/938 [00:11<00:00, 84.06it/s, accuracy=1, loss=0.00691, lr=0.001]   


Epoch 49, Loss: 0.00690649979388787
Accuracy: 0.613953025477707


100%|██████████| 938/938 [00:11<00:00, 83.41it/s, accuracy=0.969, loss=0.129, lr=0.001] 


Epoch 50, Loss: 0.12881216133520457
Accuracy: 0.6022093949044586


In [11]:
# PyKAN custom training
for lr in [1e-0, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]:
    kan = Py_KAN()
    test(kan)
    # kan.train(lr=lr, train_loader=train_loader_tasks[0])

TypeError: Py_KAN.__init__() missing 2 required positional arguments: 'strategy' and 'device'