In [1]:
import os
import time
import math
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from tqdm import tqdm
from torch import device
from torchvision import datasets
from torch.utils.data import DataLoader, Subset

from models import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# DEFINE SETTINGS

In [2]:
out_path = os.path.join('..', 'results', 'mnist') # remove or add 'convs'
strategy = "classIL"     # ["taskIL", "classIL"] classIL is harder 
lrs = [1e-2]
decays = [0.8]
epochss = [10]
models = [Efficient_KAN_Fix(strategy, device)]
longer_last_tasks = False
reverse_taks = False

reverse_path = ""
if reverse_taks:
    reverse_path = "reverse_tasks"
longer_last_path = ""
if longer_last_tasks:
    longer_last_path = "longer_last_tasks"

out_path = os.path.join(out_path, strategy, longer_last_path, reverse_path, 'trainings')
cfgs = []
for model in models[:1]:
    for lr in lrs:
        for decay in decays:
            for epochs in epochss:
                cfgs.append([model, epochs, lr, decay])

# Train and test sets

In [7]:
dataset = [datasets.MNIST, datasets.CIFAR10][0]
dataset_name = dataset.__name__.lower()
input_size = 28 * 28 if dataset == datasets.MNIST \
    else 3 * 32 * 32 if dataset == datasets.CIFAR10 \
    else -1

In [8]:
transform = transforms.Compose([transforms.ToTensor(),
                                # transforms.Normalize((0.5,), (0.5,))
                                ])
# Train set. Here we sort the MNIST by digits and disable data shuffling
train_dataset = dataset(root='../data', train=True, download=True, transform=transform)
sorted_indices = sorted(range(len(train_dataset) // 1), key=lambda idx: train_dataset.targets[idx])
train_subset = Subset(train_dataset, sorted_indices)
train_loader = DataLoader(train_subset, batch_size=64, shuffle=False)

# MultiTask training sets
train_loader_tasks = []
indices = []
for k in range(5):
    indices.append(list(
        filter(lambda idx: train_dataset.targets[idx] in range(k * 2, k * 2 + 2), range(len(train_dataset)))))
    train_loader_tasks.append(
        DataLoader(Subset(train_dataset, indices[-1]), batch_size=64, shuffle=True))

# Test set
test_dataset = dataset(root='../data', train=False, download=True, transform=transform)
test_subset = Subset(test_dataset, range(len(test_dataset) // 1))
test_loader = DataLoader(test_subset, batch_size=64, shuffle=False)

if reverse_taks:
    train_loader_tasks.reverse()

In [9]:
# stats = [0 for i in range(10)]
# for sample in test_dataset:
#     stats[sample[1]] += 1
# print(stats)
# mean = sum(stats)/len(stats)
# variance = sum([((x - mean) ** 2) for x in stats]) / len(stats) 
# res = variance ** 0.5
# print(mean, res)

## Trainset visualizer
The following code prints the images of the 5 domain IL scenarios. This way we can clearly see that for the MNIST dataset each task contains a pair of digits (0-1, 2-3, etc.), while for CIFAR10 each task contains a pair of objects (car-airplane, bird-dog, deer-dog, frog-horse and truck-ship).

In [10]:
# import numpy as np
# def imshow(img):
#     # img = (img / 2 + 0.5).numpy()
#     img = img.numpy()
#     plt.imshow(np.transpose(img, (1, 2, 0)))
#     plt.axis('off')
#     plt.show()


# def show_images(class_index, num_images=16):
#     dataiter = iter(train_loader_tasks[class_index])
#     images, labels = next(dataiter)
#     imshow(utils.make_grid(images))


# for class_index in range(5):
#     print(f"TASK ID = {class_index}")
#     show_images(class_index)

# Train and test functions

In [14]:
def train(model, save_dir, optimizer, lr, on_epoch_end, start_epoch=0, epochs=5, isKAN=False):
    criterion = nn.NLLLoss()
    for epoch in range(start_epoch, epochs + start_epoch):
        if not isKAN:
            model.train()
            model.to(device)
        epoch_start = time.time_ns()
        with tqdm(train_loader) as pbar:
            for images, labels in pbar:
                labels = labels.to(device)
                images = images.to(device)
                optimizer.zero_grad()
                output = model(images)
                loss = criterion(output, labels)
                loss.backward()
                optimizer.step(closure=lambda: loss)
                accuracy = (output.argmax(dim=1) == labels).float().mean()
                pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])
        print(f'Epoch {epoch + 1}, Loss: {loss.item()}')
        epoch_duration = (time.time_ns() - epoch_start) // 1000000
        if on_epoch_end is not None:
            on_epoch_end(model, save_dir, epoch, loss.item(), epoch_duration, lr, isKAN)

In [15]:
def test(model, isKAN=False):
    if not isKAN:
        model.eval()
    criterion = nn.NLLLoss()
    predictions = []
    ground_truths = []
    val_accuracy = 0
    loss = 0
    with torch.no_grad():
        for images, labels in test_loader:
            labels = labels.to(device)  #(labels % 2 if model.layers[-1] == 2 else labels).to(device)
            images = images.to(device)
            output = model(images)
            loss = criterion(output, labels)
            predictions.extend(output.argmax(dim=1).to('cpu').numpy())
            ground_truths.extend(labels.to('cpu').numpy())
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_accuracy /= len(test_loader)
    print(f"Accuracy: {val_accuracy}")
    return loss.item(), ground_truths, predictions

In [16]:
class EpochStat:
    @staticmethod
    def loadModelStats(dir, name, subdir) -> list['EpochStat']:
        return sorted([pickle.load(open(os.path.join(dir, subdir, file), 'rb')) for file in
                       filter(lambda e: name == '_'.join(e.split('_')[:-1]), os.listdir(os.path.join(dir, subdir)))],
                      key=lambda e: e.epoch)

    def __init__(self, name, save_dir, epoch, train_loss=0, test_loss=0, labels=None, predictions=None, epoch_duration=0, lr=0):
        self.name = name
        self.save_dir = save_dir
        self.train_loss = train_loss
        self.test_loss = test_loss
        self.epoch = epoch
        self.predictions = predictions
        self.labels = labels
        self.epoch_duration = epoch_duration
        self.lr = lr
        self.train_losses = []
        self.train_accuracies = []

    def save(self):
        os.makedirs(self.save_dir, exist_ok=True)
        pickle.dump(self, open(os.path.join(self.save_dir, self.name + '_e' + str(self.epoch) + '.pickle'), 'wb'))

    def get_accuracy(self):
        accuracy = 0
        for label, prediction in zip(self.labels, self.predictions):
            if label == prediction:
                accuracy += 1
        return accuracy / len(self.labels)


def onEpochEnd(model, save_dir, epoch, train_loss, epoch_duration, lr, isKAN):
    test_loss, labels, predictions = test(model, isKAN)
    stat = EpochStat(model.__class__.__name__, save_dir, epoch, train_loss, test_loss, labels, predictions, epoch_duration, lr)
    stat.save()

# Domain IL - training

In [None]:
for cfg in cfgs:
    model = cfg[0]
    epochs = cfg[1]
    lr = cfg[2]
    decay_f = cfg[3]
    if decay_f == 1:
        lr_decay = False
    else:
        lr_decay = True
    start_epochs_list = [int(epochs + epochs*i[0]) for i in enumerate(train_loader_tasks)]
    start_epochs_list.insert(0, 0)
    naam = model.__class__.__name__
    isKAN = False
    print("\n\n", naam)
    print(epochs, lr, decay_f, "\n")
    if 'Py_KAN' in naam:
        isKAN = True
    for i, task in enumerate(train_loader_tasks):
        epochs_act = epochs
        if longer_last_tasks and i > 3:
            epochs_act = epochs + epochs

        str_print = f'\t\t\t\tTRAINING ON TASK {i}'
        str_print +=  f' for {epochs_act} epochs' 
        print(str_print)
        # str_epoch = f"ep{epochs}_10fin_"
        str_epoch = f"ep{epochs}"
        str_lr = f"_lr{round(math.log10(lr))}"
        str_decay = '_dec'+ str(decay_f) if lr_decay else ''
        lr_act = lr * decay_f**(i)
        train(model, os.path.join(out_path,f"{str_epoch}{str_lr}{str_decay}", naam), optimizer=optim.Adam(model.parameters()),
                lr=lr_act, on_epoch_end=onEpochEnd, start_epoch=start_epochs_list[i], epochs=epochs_act, isKAN=isKAN)
    torch.cuda.empty_cache()

In [16]:
# PyKAN custom training
for lr in [1e-0, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]:
    kan = Py_KAN()
    test(kan)
    # kan.train(lr=lr, train_loader=train_loader_tasks[0])

Accuracy: 0.15684713375796178
Accuracy: 0.15684713375796178



KeyboardInterrupt

