In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from EWC import EWC
from SimpleCNN import SimpleCNN

In [None]:
import wandb

# Start a new wandb run to track this script.
# NOTE: these values duplicate the hyperparameter cell that follows (it runs
# after this one, so its variables are not available yet). Keep them in sync
# so the logged run metadata matches what is actually trained.
wandb.init(
    # wandb project the run is logged under
    project="my-awesome-project",

    # hyperparameters and run metadata
    config={
        "learning_rate": 1e-6,
        "momentum": 0.9,
        "batch_size": 8,
        "lambda_ewc": 0.4,
        "architecture": "CNN",
        "dataset": "CIFAR-10",
        "epochs": 20,
    },
)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): 1e-6 is unusually small for SGD on CIFAR-10 — confirm intended.
learning_rate = 1e-6
momentum = 0.9
num_epochs = 20
batch_size = 8
lambda_ewc = 0.4

# Seed torch's RNGs (CPU and CUDA) so weight initialisation and DataLoader
# shuffling are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Data transformation: ToTensor yields values in [0, 1]; Normalize with
# mean=std=0.5 per channel maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Datasets and DataLoaders
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Use the configured `batch_size` (previously hard-coded to 64) so training
# matches the `bz-...` tag written into checkpoint/report filenames.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

dataloaders = {'train': train_loader, 'test': test_loader}

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
def train_with_ewc(model, dataloaders, criterion, optimizer, num_epochs, ewc=None, lambda_ewc=0.4):
    """Train ``model`` on ``dataloaders['train']``, optionally adding an EWC penalty.

    Args:
        model: network to train; moved to the module-level ``device``.
        dataloaders: dict whose ``'train'`` entry is a DataLoader yielding
            ``(inputs, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        num_epochs: number of full passes over the training set.
        ewc: optional object exposing ``penalty(model)``; when given, its value
            scaled by ``lambda_ewc`` is added to every batch loss.
        lambda_ewc: weight of the EWC penalty.

    Returns:
        The trained model (``model.to(device)``, i.e. the same module instance).
    """
    model = model.to(device)
    train_loader = dataloaders['train']
    n_train = len(train_loader.dataset)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # `is not None` rather than truthiness: an EWC object that defines
            # __len__/__bool__ must not silently disable the penalty.
            if ewc is not None:
                # Out-of-place add: avoids mutating a tensor autograd is tracking.
                loss = loss + lambda_ewc * ewc.penalty(model)

            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is per-sample even when
            # the last batch is smaller. Includes the EWC penalty, if any.
            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / n_train
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    return model

In [None]:
def test_model(model, filename):
    with torch.no_grad():
        n_correct = 0
        n_samples = 0
        n_class_correct = [0 for i in range(10)]
        n_class_samples = [0 for i in range(10)]
        with open(f'./models/{filename}.txt', 'w') as f:
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                n_samples += labels.size(0)
                n_correct += (predicted == labels).sum().item()

                for i in range(batch_size):
                    label = labels[i]
                    pred = predicted[i]
                    if label == pred:
                        n_class_correct[label] += 1
                    n_class_samples[label] += 1

            acc = 100.0 * n_correct / n_samples
            print(f'Accuracy of th/e network: {acc} ?%')
            f.write(f'Accuracy of the network: {acc:.2f} %\n')

            for i in range(10):
                acc = 100.0 * n_class_correct[i] / n_class_samples[i]
                print(f'Accuracy of {classes[i]}: {acc} %')
                f.write(f'Accuracy of {classes[i]}: {acc} %')

In [None]:
def save_model(model, filename, directory='./models'):
    """Save ``model``'s state dict to ``<directory>/<filename>.pth``.

    The directory is created if missing, so a fresh checkout does not fail
    with FileNotFoundError on the first save.
    """
    os.makedirs(directory, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(directory, f'{filename}.pth'))

In [None]:
# Baseline setup: plain cross-entropy objective, a freshly initialised
# SimpleCNN, and SGD with momentum over all of its parameters.
criterion = nn.CrossEntropyLoss()
model = SimpleCNN()
optimizer = optim.SGD(
    params=model.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

In [21]:
# Baseline CNN run (no EWC): derive the run name from the hyperparameters,
# train, then persist the weights and the evaluation report under that name.
filename = f'cnn-lr-{learning_rate}-m-{momentum}-bz-{batch_size}-ep-{num_epochs}'
model = train_with_ewc(
    model,
    dataloaders,
    criterion,
    optimizer,
    num_epochs,
)
save_model(model, filename)
test_model(model, filename)

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Applications/Xcode.app/Contents/Developer/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Applications/Xcode.app/Contents/Developer/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
  File "/Users/ryan/Desktop/CSCI566/.venv/lib/python3.9/site-packages/torchvision/__init__.py", line 10, in <module>
    from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils  # usort:skip
  File "/Users/ryan/Desktop/CSCI566/.venv/lib/python3.9/site-packages/torchvision/models/__init__.py", line 2, in <module>
    from .convnext import *
  File "/Users/ryan/Desktop/CSCI566/.venv/lib/python3.9/site-packages/torchvision/models/convnext.py", line 8, in <module>
    from ..o

KeyboardInterrupt: 

In [None]:
# EWC run: build the EWC helper from the already-trained baseline model and
# continue training with its penalty, weighted by the configured `lambda_ewc`.
# NOTE(review): EWC is built from the same CIFAR-10 `train_loader` that
# `dataloaders['train']` trains on, and the baseline `optimizer` (with its
# momentum state) is reused — confirm this is the intended continual-learning
# setup rather than a placeholder for a second task.
ewc = EWC(model, train_loader, device=device)
ewc_filename = f'ewc-lr-{learning_rate}-m-{momentum}-bz-{batch_size}-ep-{num_epochs}-lambda-{lambda_ewc}'
# Pass the config variable (not a literal) so the penalty weight actually used
# always matches the `lambda-...` tag in the checkpoint name.
model = train_with_ewc(model, dataloaders, criterion, optimizer, num_epochs, ewc=ewc, lambda_ewc=lambda_ewc)
save_model(model, ewc_filename)
# Evaluate like the baseline run so the two reports are comparable.
test_model(model, ewc_filename)